【Question title】: Trouble with minimal hvp on PyTorch model
【Posted】: 2021-09-30 04:50:23
【Question】:

While autograd's hvp tool seems to work very well for plain functions, as soon as a model is involved the Hessian-vector products seem to become 0. Here is some code.

First, I define the world's simplest model:

import torch
from torch import nn

class SimpleMLP(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(in_dim, out_dim),
        )

    def forward(self, x):
        '''Forward pass'''
        return self.layers(x)

Then, the loss function:

def objective(x):
    return torch.sum(0.25 * torch.sum(x)**4)
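
As a baseline for the claim that hvp works on plain functions, the objective can be checked on its own. A minimal sketch, assuming x = torch.ones(2) and v = torch.ones(2) (values chosen only for illustration): the loss is 0.25 * s**4 with s = sum(x) = 2, so every Hessian entry is 3 * s**2 = 12 and each entry of H v is 24.

```python
import torch
from torch.autograd.functional import hvp

def objective(x):
    return torch.sum(0.25 * torch.sum(x) ** 4)

x = torch.ones(2)  # s = sum(x) = 2
v = torch.ones(2)

# hvp returns (objective(x), H @ v); a single tensor input gives a single tensor back
loss, hv = hvp(objective, x, v)
print(loss, hv)  # loss = 4, hv = [24, 24]
```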

We instantiate it:

Arows = 2
Acols = 2

mlp = SimpleMLP(Arows, Acols)

Finally, I define a "forward" function (distinct from the model's forward method) that represents the full model + loss we want to analyze:

def forward(*params_list):
    for param_val, model_param in zip(params_list, mlp.parameters()):
        model_param.data = param_val

    x = torch.ones((Arows,))
    return objective(mlp(x))

This passes a vector through the single-layer "mlp" and then into our quartic loss.

Now, I try to compute:

v = torch.ones((6,))
v_tensors = []
idx = 0
# this code "reshapes" the v vector as needed
for param in mlp.parameters():
    numel = param.numel()
    v_tensors.append(v[idx:idx + numel].reshape(param.shape))
    idx += numel
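
The slicing loop can also be written with torch.split, which cuts a flat vector into chunks of the given sizes. A minimal sketch; unflatten_like is a hypothetical helper name, and a bare nn.Linear(2, 2) stands in for the question's mlp:

```python
import torch
from torch import nn

def unflatten_like(flat, params):
    """Hypothetical helper: split a 1-D tensor into pieces shaped like params."""
    sizes = [p.numel() for p in params]
    assert flat.numel() == sum(sizes), "flat vector has the wrong length"
    chunks = torch.split(flat, sizes)
    return tuple(c.reshape(p.shape) for c, p in zip(chunks, params))

layer = nn.Linear(2, 2)             # stands in for mlp: weight (2, 2), bias (2,)
params = tuple(layer.parameters())  # nn.Linear yields weight first, then bias
v = torch.arange(6, dtype=torch.float32)

reshaped_v = unflatten_like(v, params)
# reshaped_v[0] holds [[0, 1], [2, 3]] (weight), reshaped_v[1] holds [4, 5] (bias)
```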

Finally:

param_tensors = tuple(mlp.parameters())
reshaped_v = tuple(v_tensors)
soln = torch.autograd.functional.hvp(forward, param_tensors, v=reshaped_v)

But, alas, the Hessian-vector products in soln are all 0. What is going on?

【Question discussion】:

    Tags: python pytorch autograd


    【Answer 1】:

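    What happens is that strict defaults to False in the hvp() function, so a tensor of zeros is returned as the Hessian-vector product instead of an error being raised (source).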
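
The silent-zeros behaviour can be reproduced without any model. A minimal sketch with a hypothetical function ignores_input that never uses its argument and depends only on an outside leaf tensor w (playing the role of the model's own parameters): with the default strict=False, hvp returns zeros, while strict=True raises a RuntimeError.

```python
import torch
from torch.autograd.functional import hvp

w = torch.randn(2, requires_grad=True)  # outside leaf, like the model's own parameters

def ignores_input(x):
    # The result depends on w only, never on x.
    return (w ** 4).sum()

x = torch.ones(2)
v = torch.ones(2)

_, hv = hvp(ignores_input, x, v)  # strict=False by default
print(hv)  # all zeros

raised = False
try:
    hvp(ignores_input, x, v, strict=True)
except RuntimeError as err:
    raised = True
    print(err)  # mentions that the output is independent of input 0
```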

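    If you try strict=True, you get an error instead: RuntimeError: The output of the user-provided function is independent of input 0. This is not allowed in strict mode. Looking at the full traceback, I suspect it comes from _check_requires_grad(jac, "jacobian", strict=strict), which means the Jacobian jac is None. That happens because assigning to model_param.data is not recorded by autograd: mlp keeps computing with its own leaf parameters, so the loss never depends on the tensors that hvp passes into forward.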
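
Why the Jacobian comes out as None can be shown in isolation: assigning through .data is invisible to autograd, so the graph only contains the module's own parameters. A minimal sketch, assuming a bare nn.Linear in place of the question's mlp and a fresh leaf tensor new_weight standing in for one of the tensors hvp hands to forward: the loss still requires grad (through lin.weight and lin.bias), but its gradient with respect to new_weight is None.

```python
import torch
from torch import nn

torch.manual_seed(0)
lin = nn.Linear(2, 2)

# Stand-in for one of the inputs that hvp() passes to forward().
new_weight = torch.randn(2, 2, requires_grad=True)

# Same pattern as the question: copy the values in through .data.
lin.weight.data = new_weight

loss = lin(torch.ones(2)).sum()

# The graph records lin.weight (its own leaf), not new_weight.
(grad,) = torch.autograd.grad(loss, new_weight, allow_unused=True)
print(loss.requires_grad, grad)  # True None
```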

    Update:

    Here is a complete working example:

    import torch
    from torch import nn
    
    # your loss function
    def objective(x):
        return torch.sum(0.25 * torch.sum(x)**4)
    
    # Following are utilities to make nn.Module functional
    # borrowed from the link I posted in comment
    def del_attr(obj, names):
        if len(names) == 1:
            delattr(obj, names[0])
        else:
            del_attr(getattr(obj, names[0]), names[1:])
    
    def set_attr(obj, names, val):
        if len(names) == 1:
            setattr(obj, names[0], val)
        else:
            set_attr(getattr(obj, names[0]), names[1:], val)
    
    def make_functional(mod):
        orig_params = tuple(mod.parameters())
        # Remove all the parameters in the model
        names = []
        for name, p in list(mod.named_parameters()):
            del_attr(mod, name.split("."))
            names.append(name)
        return orig_params, names
    
    def load_weights(mod, names, params):
        for name, p in zip(names, params):
            set_attr(mod, name.split("."), p)
    
    # your forward function with update
    def forward(*new_params):
        # this line replace your for loop
        load_weights(mlp, names, new_params)
    
        x = torch.ones((Arows,))
        out = mlp(x)
        loss = objective(out)
        return loss
    
    # your simple MLP model
    class SimpleMLP(nn.Module):
        def __init__(self, in_dim, out_dim):
            super().__init__()
            self.layers = nn.Sequential(
                nn.Linear(in_dim, out_dim),
            )
    
        def forward(self, x):
            '''Forward pass'''
            return self.layers(x)
    
    
    if __name__ == '__main__':
        # your model instantiation
        Arows = 2
        Acols = 2
        mlp = SimpleMLP(Arows, Acols)
    
        # your vector computation
        v = torch.ones((6,))
        v_tensors = []
        idx = 0
        # this code "reshapes" the v vector as needed
        for param in mlp.parameters():
            numel = param.numel()
            v_tensors.append(v[idx:idx + numel].reshape(param.shape))
            idx += numel
        reshaped_v = tuple(v_tensors)
    
        #make model's parameters functional
        params, names = make_functional(mlp)
        params = tuple(p.detach().requires_grad_() for p in params)
    
        # compute the product with vhp; the Hessian is symmetric, so vhp equals hvp here
        soln = torch.autograd.functional.vhp(forward, params, reshaped_v, strict=True)
        print(soln)

    【Discussion】:

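    • Thanks for finding this. But in this case, why would the Jacobian be None?
    • After some searching, I found this. It works for your case.
    • Could you post a complete working example, or at least the lines that need to change and how to change them, in your answer? If so, I can accept your answer and award the bounty.
    • Updated the answer with a complete working example. Happy coding!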
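
On newer PyTorch releases, the del_attr/set_attr helpers from Answer 1 can be replaced by torch.func.functional_call, which runs a module with a dictionary of substitute tensors without mutating it. A minimal sketch, assuming PyTorch 2.0 or later (where torch.func exists; earlier releases exposed a similar function under torch.nn.utils.stateless) and an nn.Sequential(nn.Linear(2, 2)) standing in for SimpleMLP(2, 2). Because the input is all ones, the loss is 0.25 * s**4 with s the sum of all six parameters, so every entry of H v (with v all ones) should be 18 * s**2.

```python
import torch
from torch import nn
from torch.autograd.functional import hvp
from torch.func import functional_call  # assumes PyTorch >= 2.0

def objective(x):
    return torch.sum(0.25 * torch.sum(x) ** 4)

torch.manual_seed(0)
mlp = nn.Sequential(nn.Linear(2, 2))  # same layout as SimpleMLP(2, 2).layers
names = [name for name, _ in mlp.named_parameters()]  # "0.weight", "0.bias"
params = tuple(p.detach().clone() for p in mlp.parameters())

def forward(*new_params):
    # Run mlp with new_params in place of its parameters; mlp itself is untouched.
    out = functional_call(mlp, dict(zip(names, new_params)), (torch.ones(2),))
    return objective(out)

v = tuple(torch.ones_like(p) for p in params)
loss, hv = hvp(forward, params, v, strict=True)

s = sum(p.sum() for p in params)
print(hv[0], 18 * s ** 2)  # every entry of hv[0] should be close to 18 * s**2
```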
    【Answer 2】:

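    Have you tried using doubles instead of floats? In my own tests, backpropagation with 32-bit floats showed fairly large errors (around 1e-5) compared with double precision.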

    【Discussion】:

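    • That is not the issue: if I skip the model and just hand-code Ax + b, the exact same system backpropagates correctly. Besides, the system is far too small to run into floating-point precision problems.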
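
For comparison, the hand-coded Ax + b variant mentioned in the last comment can be sketched as follows. The names A, b and forward_manual are illustrative, and float64 is used only to keep the check against the closed form tight. Here the parameters are the very tensors passed to hvp, so autograd can trace the loss back to them; with x all ones, s = A.sum() + b.sum() and each entry of H v (v all ones) equals 18 * s**2.

```python
import torch
from torch.autograd.functional import hvp

def objective(y):
    return torch.sum(0.25 * torch.sum(y) ** 4)

def forward_manual(A, b):
    # Hand-coded single linear layer: y = A x + b with x = ones
    x = torch.ones(2, dtype=A.dtype)
    return objective(A @ x + b)

torch.manual_seed(0)
A = torch.randn(2, 2, dtype=torch.float64)
b = torch.randn(2, dtype=torch.float64)
v = (torch.ones_like(A), torch.ones_like(b))

loss, (hv_A, hv_b) = hvp(forward_manual, (A, b), v, strict=True)
s = A.sum() + b.sum()
# every entry of hv_A and hv_b equals 18 * s**2 up to rounding
```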