【发布时间】:2021-07-12 06:01:43
【问题描述】:
我有一个预训练的 PyTorch 模型。我需要使用此模型计算损失相对于网络输入的梯度(无需再次训练,仅使用预训练模型)。
我写了以下代码,但我不确定它是否正确。
test_X, test_y = load_data(mode='test')
testset_original = MyDataset(test_X, test_y, transform=default_transform)
testloader = DataLoader(testset_original, batch_size=32, shuffle=True)
model = MyModel(device=device).to(device)
checkpoint = torch.load('checkpoint.pt')
model.load_state_dict(checkpoint['model_state_dict'])
gradient_losses = []
for i, data in enumerate(testloader):
inputs, labels = data
inputs= inputs.to(device)
labels = labels.to(device)
inputs.requires_grad = True
output = model(inputs)
loss = loss_function(output)
loss.backward()
gradient_losses.append(inputs.grad)
我的问题是,这个列表 gradient_losses 是否真的存储了我想要存储的内容?如果不是,那么正确的做法是什么?
【问题讨论】: