【发布时间】:2021-12-04 09:41:33
【问题描述】:
我正在尝试各种深度网络,我一直想知道参数是如何参与的。我使用的是 pytorch summary,但是我注意到如果我在前向传递中多次使用同一个模块,则其相关参数会被计算多次。
一个例子是这样的:
class Net(nn.Module):
def __init__(self,):
super(Net, self).__init__()
self.lin = nn.Linear(3,3)
def forward(self,x):
x = self.lin(x)
x = self.lin(x)
x = self.lin(x)
x = self.lin(x)
return x
net = Net()
from torchsummary import summary
summary(net.to(device),(1,3))
你得到 48 个总参数,即 12*4。 12,在这种情况下,实际上是网络可训练参数的数量。
因此,我的问题是,有没有办法让 Pytorch summary 打印出模型的“单个”可训练参数的数量?
否则,我知道使用这样的脚本
model_parameters = filter(lambda p: p.requires_grad, net.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
print(f"The network has {params} trainable parameters")
为了得到想要的结果,但我喜欢 pytorch summary 的工作原理。
【问题讨论】:
标签: deep-learning pytorch