【发布时间】:2022-01-15 17:01:44
【问题描述】:
我想在 pytorch 中建立一个仅扩展数据的网络。
这意味着如果我的输入是 [1, 2] 并且我的输出是 [2, 6]。
那么线性层将如下所示:
[ [ 2, 0],
[ 0, 3] ].
我有这个网络是用 pytorch 写的:
class ScalingNetwork(nn.Module):
def __init__(self, input_shape, output_shape):
super().__init__()
self.linear_layer = nn.Linear(in_features=input_shape, out_features=output_shape)
self.mask = torch.diag(torch.ones(input_shape))
self.linear_layer.weight.data = self.linear_layer.weight * self.mask
self.linear_layer.weight.requires_grad = True
def get_tranformation_matrix(self):
return self.linear_layer.weight
def forward(self, X):
X = self.linear_layer(X)
return X
但在训练结束时,我的 self.linear 不是对角线。 我做错了什么?
【问题讨论】:
标签: python deep-learning neural-network pytorch