【问题标题】:Pytorch customize weightPytorch 自定义权重
【发布时间】:2020-03-15 17:29:28
【问题描述】:

我有一个网络

class Net(nn.Module)

还有两个不同的权重w0w1(将所有层的权重连接成一个向量)。现在我想优化连接w0w1 的线上的网络,这意味着权重将具有theta * w0 + (1-theta) * w1 的形式。所以现在我要优化的参数不再是权重本身,而是theta

我该如何实现呢?在 Pytorch 中,如何将参数定义为theta,并将权重设置为我想要的形式。具体来说,如果我创建一个新类

NetOnLine(nn.Module)

forward(self, X)函数应该怎么写?

【问题讨论】:

  • 查看nn.Moduleregister_parameter 方法和nn.Parameter 文档

标签: pytorch


【解决方案1】:

您可以在您的网络中将参数theta 定义为nn.Parameter。您可以像往常一样定义 forward 函数 - 通过所需的层或操作传递数据,然后返回它。

这是一个最小的例子,我训练一个“网络”来学习将张量乘以 2:

import numpy as np
import torch


class SampleNet(torch.nn.Module):
    def __init__(self):
        super(SampleNet, self).__init__()

        self.theta = torch.nn.Parameter(torch.rand(1))

    def forward(self, x):
        x = x * self.theta.expand_as(x)  # expand_as() to match sizes
        return x


train_data = np.random.rand(1000, 10)
train_data[:, 5:] = 2 * train_data[:, :5]
train_data = torch.Tensor(train_data)

sample_net = SampleNet()

optimizer = torch.optim.Adam(params=sample_net.parameters())
mse_loss = torch.nn.MSELoss()

for epoch in range(5):
    for data in train_data:
        x = data[:5]
        y = data[5:]

        optimizer.zero_grad()
        prediction = sample_net(x)
        loss = mse_loss(y, prediction)
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch}, Loss {loss.data.item()}")
print(f"Learned theta: {sample_net.theta.data.item()}")

打印出来

Epoch 0, Loss 0.03369491919875145
Epoch 1, Loss 0.0018534092232584953
Epoch 2, Loss 1.2343853995844256e-05
Epoch 3, Loss 2.2044337466553543e-09
Epoch 4, Loss 4.0527581290916714e-12
Learned theta: 1.999994158744812

【讨论】:

  • 但是我的网络结构非常复杂。它可能包含许多层,包括 Linear 层和 Conv2d 层,并被视为黑盒。此外,似乎这些层的所有权重都被自动视为参数。那么如何将它们变成数字并添加我自己的参数 theta。
  • 您可以直接使用layer.weight.data 访问图层权重,并根据需要将它们乘以theta。我不清楚你想如何在你的模型中使用权重和 theta,所以我很难说。也许这个讨论会帮助你访问权重:discuss.pytorch.org/t/how-to-extract-learned-weights-correctly/…
猜你喜欢
  • 1970-01-01
  • 2019-07-21
  • 2021-06-13
  • 2021-05-28
  • 2020-01-24
  • 2018-09-23
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多