【问题标题】:how to get jacobian with pytorch for log probability of multivariate normal distribution如何使用 pytorch 获取 jacobian 以获得多元正态分布的对数概率
【发布时间】:2020-02-24 23:16:24
【问题描述】:

我从多元正态分布中抽取样本,并希望获得它们的对数概率相对于均值的梯度。由于样本很多,这需要一个雅可比行列:

import torch

mu = torch.ones((2,), requires_grad=True)
sigma = torch.eye(2)
dist = torch.distributions.multivariate_normal.MultivariateNormal(mu, sigma)

num_samples=10
samples = dist.sample((num_samples,))
logprobs = dist.log_prob(samples)

现在我想得到logprobs 中每个条目相对于mu 中每个条目的导数。

一个简单的解决方案是python循环:

grads = []
for logprob in logprobs:
    grad = torch.autograd.grad(logprob, mu, retain_graph=True)
    grads.append(grad)

如果您堆叠 grads,则结果是所需的雅可比行列式。是否还有对此的内置和矢量化支持?

相关问题/互联网资源:

这是一个很大的话题,有很多相关的帖子。不过,我认为这个具体问题(关于分布)尚未得到解答:

  • 这个问题与我的基本相同(但没有示例代码和解决方案尝试),遗憾的是没有答案:Pytorch custom function jacobian gradient

  • 这个问题显示了 pytorch 中雅可比的计算,但我认为该解决方案不适用于我的问题:Pytorch most efficient Jacobian/Hessian calculation 它需要以一种似乎与分布不兼容的方式堆叠输入。我无法让它工作。

  • 这个要点有一些 Jacobians 的代码 sn-ps。原则上,它们与上述问题中的方法类似。

【问题讨论】:

  • 不要认为它是直截了当的,因为 Torch 实现了反向模式 AD,因此期望输出是标量(在你的情况下它是矢量)。您可以尝试通过复制将 \mu 设为 (10,2),因此 grad 可能会给您正确的 jacobian
  • 我同意,使用 pytorch 的 autograd 没有循环可能没有直接的方法来获取雅可比行列式。也就是说,在这种特殊情况下,您的雅可比行列式有一个相对简单的封闭形式,即 grads = torch.mm(samples - mu, torch.inverse(sigma))(您可以使用 torch.solve 获得更稳定的版本,它不会直接反转)。根据您实际问题的复杂性,也许硬编码雅可比表达式可能是要走的路?
  • @jodag,你说得对,这里的雅可比很容易计算。但这是 stackoverflow 的最小可行示例。真正的代码要复杂得多。

标签: python pytorch


【解决方案1】:

PyTorch 1.5.1 引入了torch.autograd.functional.jacobian 函数。这计算了函数 w.r.t 的雅可比行列式。输入张量。由于jacobian 需要一个python 函数作为第一个参数,因此使用它需要进行一些代码重组。

import torch

torch.manual_seed(0)    # for repeatable results

mu = torch.ones((2,), requires_grad=True)
sigma = torch.eye(2)
num_samples = 10

def f(mu):
    dist = torch.distributions.multivariate_normal.MultivariateNormal(mu, sigma)
    samples = dist.sample((num_samples,))
    logprobs = dist.log_prob(samples)
    return logprobs

grads = torch.autograd.functional.jacobian(f, mu)

print(grads)
tensor([[-1.1258, -1.1524],
        [-0.2506, -0.4339],
        [ 0.5988, -1.5551],
        [-0.3414,  1.8530],
        [ 0.4681, -0.1577],
        [ 1.4437,  0.2660],
        [ 1.3894,  1.5863],
        [ 0.9463, -0.8437],
        [ 0.9318,  1.2590],
        [ 2.0050,  0.0537]])

【讨论】:

  • 很好,感谢您回来并结束这一切:)
猜你喜欢
  • 1970-01-01
  • 2021-03-13
  • 2021-06-27
  • 1970-01-01
  • 2018-12-19
  • 2021-07-26
  • 1970-01-01
  • 2020-05-07
  • 1970-01-01
相关资源
最近更新 更多