【发布时间】:2020-02-03 18:32:03
【问题描述】:
对于我的用例,我需要能够获取一个 pytorch 模块并解释模块中的层序列,以便我可以以某种文件格式在层之间创建“连接”。现在假设我有一个简单的模块,如下所示
class mymodel(nn.Module):
def __init__(self, input_channels):
super(mymodel, self).__init__()
self.fc = nn.Linear(input_channels, input_channels)
def forward(self, x):
out = self.fc(x)
out += x
return out
if __name__ == "__main__":
net = mymodel(5)
for mod in net.modules():
print(mod)
这里的输出结果:
mymodel(
(fc): Linear(in_features=5, out_features=5, bias=True)
)
Linear(in_features=5, out_features=5, bias=True)
如您所见,有关加号等于操作或加号操作的信息未被捕获,因为它不是 forward 函数中的 nnmodule。我的目标是能够从 pytorch 模块对象创建一个图形连接,以便在 json 中说这样的话:
layers {
"fc": {
"inputTensor" : "t0",
"outputTensor": "t1"
}
"addOp" : {
"inputTensor" : "t1",
"outputTensor" : "t2"
}
}
输入张量名称是任意的,但它捕捉到了图的本质和层之间的连接。
我的问题是,有没有办法从 pytorch 对象中提取信息?我正在考虑使用 .modules() 但后来意识到手写操作不是以这种方式作为模块捕获的。我想如果一切都是 nn.module 那么 .modules() 可能会给我网络层的安排。在这里寻求帮助。我希望能够知道张量之间的联系以创建上述格式。
【问题讨论】:
标签: python neural-network pytorch tensor