【问题标题】:How to add parameters in module class in pytorch custom model?如何在pytorch自定义模型的模块类中添加参数?
【发布时间】:2020-04-01 16:37:02
【问题描述】:

我试图找到答案,但我找不到。

我使用 pytorch 制作了一个自定义深度学习模型。例如,

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.nn_layers = nn.ModuleList()
        self.layer = nn.Linear(2,3).double()
        torch.nn.init.xavier_normal_(self.layer.weight)

        self.bias = torch.nn.Parameter(torch.randn(3))

        self.nn_layers.append(self.layer)

    def forward(self, x):
        activation = torch.tanh
        output = activation(self.layer(x)) + self.bias

        return output

如果我打印

model = Net()
print(list(model.parameters()))

它不包含model.bias,所以 optimizer = optimizer.Adam(model.parameters()) 不更新 model.bias。 我怎么能通过这个? 谢谢!

【问题讨论】:

  • 如果您是子类化模块,请参阅此处发布的解决方案。 discuss.pytorch.org/t/…。使用 print(list(model.parameters())) 进行测试。

标签: deep-learning pytorch


【解决方案1】:

你需要register你的参数:

self.register_parameter(name='bias', param=torch.nn.Parameter(torch.randn(3)))

【讨论】:

  • 谢谢!我认为 ``` self.register_parameter(name='bias', param=torch.nn.Parameter(self.bias)) ``` 不起作用。有没有办法摆脱“名称”部分?我认为这在保存和加载模型时会很麻烦。
  • @CSH 通过在注册参数时使用name='bias' 隐式创建self.bias 以在Net 类中使用。您确实不需要需要显式分配self.bias 以使其存在。
  • 天啊。多么方便的功能!感谢您的帮助!
猜你喜欢
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2017-11-08
  • 2018-11-29
  • 2022-01-18
  • 2023-04-06
相关资源
最近更新 更多