【Question title】: RuntimeError: Trying to backward through the graph a second time
【Posted】: 2021-11-18 03:53:48
【Question description】:

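I am trying to train a "trainable Bernoulli distribution" using pyro.

I want to train the parameters of the Bernoulli distribution (the success probabilities) with an NLL loss.

train_data is a one-hot encoded sparse matrix of shape (2034, 19475), and train_labels takes 4 values (4 classes, [0, 1, 2, 3]).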

import torch
import pyro
pyd = pyro.distributions

print("torch version:", torch.__version__)
print("pyro version:", pyro.__version__)

import matplotlib.pyplot as plt
import numpy as np
torch.manual_seed(123)


### 0. define Negative Log Likelihood(NLL) loss function
def nll(x_train, distribution):    
    return -torch.mean(distribution.log_prob(torch.tensor(x_train, dtype=torch.float)))


### 1. initialize bernoulli distribution(trainable distribution)
train_vars = (pyd.Uniform(low=torch.FloatTensor([0.01]),
                          high=torch.FloatTensor([0.1])).rsample([train_data.shape[-1]]).squeeze())
distribution = pyd.Bernoulli(probs=train_vars)

### 2. initialize 'label 0' data
class_mask = (train_labels==0)
class_data = train_data[class_mask, :]

### 3. initialize optimizer
optim = torch.optim.Adam([train_vars])

train_vars.requires_grad=True

### 4. train loop
for i in range(0,100):
    
    loss = nll(class_data, distribution)
    
    loss.backward()

When I run this code, I get the RuntimeError shown below.

How should I handle this error?

Any comments would be greatly appreciated.

torch version: 1.9.0+cu102
pyro version: 1.7.0
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-269-0081bb1bb843> in <module>
     25     loss = nll(class_data, distribution)
     26 
---> 27     loss.backward()
     28 

/nf/yes/lib/python3.8/site-packages/torch/_tensor.py in backward(self, gradient, retain_graph, create_graph, inputs)
    253                 create_graph=create_graph,
    254                 inputs=inputs)
--> 255         torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
    256 
    257     def register_hook(self, hook):

/nf/yes/lib/python3.8/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    145         retain_graph = create_graph
    146 
--> 147     Variable._execution_engine.run_backward(
    148         tensors, grad_tensors_, retain_graph, create_graph, inputs,
    149         allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag

RuntimeError: Trying to backward through the graph a second time (or directly access saved variables after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved variables after calling backward.

【Comments】:

    Tags: python pytorch autograd pyro


    【Solution 1】:

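    You need to move

    distribution = pyd.Bernoulli(probs=train_vars)

    inside the loop, because it is built from train_vars, which requires grad. Otherwise the graph connecting train_vars to the distribution is built only once and reused across iterations, and its saved values are freed by the first loss.backward() call.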
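
A minimal sketch of the corrected loop follows. It is not the asker's exact code, and it makes several assumptions: `torch.distributions` stands in for `pyro.distributions` so the example depends only on PyTorch, `class_data` is a small synthetic 0/1 matrix instead of the real one-hot data, and the learning rate is arbitrary. The `optim.zero_grad()` and `optim.step()` calls and the clamp are additions that the question's loop does not have.

```python
import torch
from torch.distributions import Bernoulli, Uniform

torch.manual_seed(123)

# Hypothetical stand-in for the rows of one class (20 rows, 8 binary features).
class_data = (torch.rand(20, 8) < 0.3).float()

# Trainable success probabilities, created as a leaf tensor that requires grad.
train_vars = Uniform(0.01, 0.1).sample([class_data.shape[-1]]).requires_grad_()

def nll(x, distribution):
    # Negative mean log-likelihood of x under the distribution.
    return -distribution.log_prob(x).mean()

optim = torch.optim.Adam([train_vars], lr=0.05)  # lr chosen only for this toy data

losses = []
for step in range(100):
    optim.zero_grad()
    # Rebuild the distribution every iteration so each backward() sees a fresh graph.
    distribution = Bernoulli(probs=train_vars)
    loss = nll(class_data, distribution)
    loss.backward()
    optim.step()
    # Keep the probabilities strictly inside (0, 1) after the update (an addition).
    with torch.no_grad():
        train_vars.clamp_(1e-4, 1 - 1e-4)
    losses.append(loss.item())
```

Passing `retain_graph=True` to `backward()`, as the error message suggests, would also silence the error, but rebuilding the distribution is the cheaper option here because nothing from the previous iteration's graph needs to be kept.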

    【Discussion】:

    • Thanks! It works. :) It helped me a lot.