【发布时间】:2021-09-30 04:50:23
【问题描述】:
虽然 autograd 的 hvp 工具似乎对函数非常有效,但一旦涉及模型,Hessian-vector 积似乎会变为 0。一些代码。
首先,我定义了世界上最简单的模型:
class SimpleMLP(nn.Module):
def __init__(self, in_dim, out_dim):
super().__init__()
self.layers = nn.Sequential(
nn.Linear(in_dim, out_dim),
)
def forward(self, x):
'''Forward pass'''
return self.layers(x)
然后,损失函数:
def objective(x):
return torch.sum(0.25 * torch.sum(x)**4)
我们实例化它:
Arows = 2
Acols = 2
mlp = SimpleMLP(Arows, Acols)
最后,我将定义一个“前向”函数(与模型的前向函数不同),它将作为我们要分析的完整模型+损失:
def forward(*params_list):
for param_val, model_param in zip(params_list, mlp.parameters()):
model_param.data = param_val
x = torch.ones((Arows,))
return objective(mlp(x))
这会将一个向量传递给单层“mlp”,并将其传递给我们的二次损失。
现在,我尝试计算:
v = torch.ones((6,))
v_tensors = []
idx = 0
#this code "reshapes" the v vector as needed
for i, param in enumerate(mlp.parameters()):
numel = param.numel()
v_tensors.append(torch.reshape(torch.tensor(v[idx:idx+numel]), param.shape))
idx += numel
最后:
param_tensors = tuple(mlp.parameters())
reshaped_v = tuple(v_tensors)
soln = torch.autograd.functional.hvp(forward, param_tensors, v=reshaped_v)
但是,可惜的是,soln 中的 Hessian-Vector Product 都是 0。发生了什么?
【问题讨论】: