【问题标题】:How to transform output of neural network and still train?如何转换神经网络的输出并仍然训练?
【发布时间】:2021-12-30 14:11:38
【问题描述】:

我有一个输出output 的神经网络。我想在损失和反向传播发生之前转换output

这是我的通用代码:

with torch.set_grad_enabled(training):
                  outputs = net(x_batch[:, 0], x_batch[:, 1]) # the prediction of the NN
                  # My issue is here:
                  outputs = transform_torch(outputs)
                  loss = my_loss(outputs, y_batch)

                  if training:
                      scheduler.step()
                      loss.backward()
                      optimizer.step()

我有一个转换函数,我将输出通过:

def transform_torch(predictions):
    torch_dimensions = predictions.size()
    torch_grad = predictions.grad_fn
    cuda0 = torch.device('cuda:0')
    new_tensor = torch.ones(torch_dimensions, dtype=torch.float64, device=cuda0, requires_grad=True)
    for i in range(int(len(predictions))):
      a = predictions[i]
      # with torch.no_grad(): # Note: no training happens if this line is kept in
      new_tensor[i] = torch.flip(torch.cumsum(torch.flip(a, dims = [0]), dim = 0), dims = [0])
    return new_tensor

我的问题是在倒数第二行出现错误:

RuntimeError: a view of a leaf Variable that requires grad is being used in an in-place operation.

有什么建议吗?我已经尝试过使用“with torch.no_grad():”(已评论),但这会导致训练效果很差,而且我相信梯度在转换函数后无法正确反向传播。

谢谢!

【问题讨论】:

  • 转换中的a 是什么?
  • 我修复了它 - a = predictions[i]。我在移除 cmets 时不小心把它遗漏了。感谢您的澄清。

标签: python deep-learning neural-network pytorch backpropagation


【解决方案1】:

关于问题所在的错误是非常正确的 - 当您使用 requires_grad = True 创建一个新张量时,您会在图中创建一个叶节点(就像模型的参数一样)并且不允许进行就地操作就可以了。

解决方法很简单,不需要提前创建new_tensor。它不应该是叶节点;即时创建它

new_tensor = [ ]
for i in range(int(len(predictions))):
    a = predictions[i]
    new_tensor.append(torch.flip(torch.cumsum(torch.flip(a, ...), ...), ...))

new_tensor = torch.stack(new_tensor, 0)    

new_tensor 将从predictions 继承所有属性,例如dtypedevice,并且已经拥有require_grad = True

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 2011-04-07
    • 1970-01-01
    • 2020-04-28
    • 2010-11-20
    • 2019-09-15
    • 1970-01-01
    • 1970-01-01
    • 2020-12-18
    相关资源
    最近更新 更多