【问题标题】:How to get the output gradient w.r.t input如何获得输出梯度w.r.t输入
【发布时间】:2022-01-14 07:49:04
【问题描述】:

我在获取输入的输出梯度时遇到了一些问题。 是简单的mnist模型。

for num,(sample_img, sample_label) in enumerate(mnist_test):
    if num == 1:
        break

    sample_img = sample_img.to(device)
    sample_img.requires_grad = True
    prediction = model(sample_img.unsqueeze(dim=0))
    cost = criterion(prediction, torch.tensor([sample_label]).to(device))
    optimizer.zero_grad()
    cost.backward()
    print(sample_label)
    print(sample_img.shape)

    plt.imshow(sample_img.detach().cpu().squeeze(),cmap='gray')
    plt.show()

print(sample_img.grad)

sample_img.grad 为无

【问题讨论】:

    标签: input pytorch output gradient mnist


    【解决方案1】:

    如果您需要计算输入的梯度,您可以通过调用 sample_img.requires_grad_() 或设置 sample_img.requires_grad = True 来完成,如 cmets 中所建议的那样。

    这是一个小例子:

    import torch
    import torch.nn as nn
    import torch.optim as optim
    import matplotlib.pyplot as plt
    
    
    model = nn.Sequential(  # a dummy model
        nn.Conv2d(1, 1, 3),
        nn.ReLU(),
        nn.MaxPool2d(2),
        nn.Flatten()
    )
    
    sample_img = torch.rand(1, 5, 5)  # a dummy input
    sample_label = 0
    
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=1e-3)
    device = "cpu"
    
    sample_img = sample_img.to(device)
    sample_img.requires_grad = True
    
    prediction = model(sample_img.unsqueeze(dim=0))
    cost = criterion(prediction, torch.tensor([sample_label]).to(device))
    optimizer.zero_grad()
    cost.backward()
    print(sample_label)
    print(sample_img.shape)
    
    plt.imshow(sample_img.detach().cpu().squeeze(), cmap='gray')
    plt.show()
    
    print(sample_img.grad.shape)
    print(sample_img.grad)
    

    另外,如果你不需要模型的梯度,你可以关闭他们的梯度要求:

    for param in model.parameters():
        param.requires_grad = False
    

    【讨论】:

    猜你喜欢
    • 2022-01-13
    • 1970-01-01
    • 2019-11-19
    • 2020-02-27
    • 2018-10-14
    • 2019-01-10
    • 2017-01-26
    • 2019-04-14
    相关资源
    最近更新 更多