【问题标题】:Deriving the structure of a pytorch network推导pytorch网络的结构
【发布时间】:2020-02-03 18:32:03
【问题描述】:

对于我的用例,我需要能够获取一个 pytorch 模块并解释模块中的层序列,以便我可以以某种文件格式在层之间创建“连接”。现在假设我有一个简单的模块,如下所示

class mymodel(nn.Module):
    def __init__(self, input_channels):
        super(mymodel, self).__init__()
        self.fc = nn.Linear(input_channels, input_channels)
    def forward(self, x):
        out = self.fc(x)
        out += x
        return out


if __name__ == "__main__":
    net = mymodel(5)

    for mod in net.modules():
        print(mod) 

这里的输出结果:

mymodel(
  (fc): Linear(in_features=5, out_features=5, bias=True)
)
Linear(in_features=5, out_features=5, bias=True)

如您所见,有关加号等于操作或加号操作的信息未被捕获,因为它不是 forward 函数中的 nnmodule。我的目标是能够从 pytorch 模块对象创建一个图形连接,以便在 json 中说这样的话:

layers {
"fc": {
"inputTensor" : "t0",
"outputTensor": "t1"
}
"addOp" : {
"inputTensor" : "t1",
"outputTensor" : "t2"
}
}

输入张量名称是任意的,但它捕捉到了图的本质和层之间的连接。

我的问题是,有没有办法从 pytorch 对象中提取信息?我正在考虑使用 .modules() 但后来意识到手写操作不是以这种方式作为模块捕获的。我想如果一切都是 nn.module 那么 .modules() 可能会给我网络层的安排。在这里寻求帮助。我希望能够知道张量之间的联系以创建上述格式。

【问题讨论】:

    标签: python neural-network pytorch tensor


    【解决方案1】:

    您要查找的信息不是存储在nn.Module 中,而是存储在输出张量的grad_fn 属性中:

    model = mymodel(channels)
    pred = model(torch.rand((1, channels))
    pred.grad_fn  # all the information is in the computation graph of the output tensor
    

    提取这些信息并非易事。你可能想看看torchviz 包,它从grad_fn 信息中绘制了一个漂亮的图表。

    【讨论】:

      猜你喜欢
      • 2019-03-08
      • 2021-12-30
      • 2018-12-16
      • 2018-12-11
      • 1970-01-01
      • 2011-03-31
      • 2010-10-03
      • 2019-01-17
      相关资源
      最近更新 更多