【问题标题】:Accessing functions in the class modules of nn.Sequential访问 nn.Sequential 的类模块中的函数
【发布时间】:2021-02-12 19:21:20
【问题描述】:

在运行 nn.Sequential 时,我包含一个类模块列表(这将是神经网络的层)。运行 nn.Sequential 时,它调用模块的转发函数。然而,每个类模块也有一个我想在 nn.Sequential 运行时访问的函数。运行 nn.Sequential 时如何访问和运行该函数?

【问题讨论】:

    标签: pytorch


    【解决方案1】:

    您可以为此使用 hook。让我们考虑以下在 VGG16 上演示的示例:

    这是网络架构:

    假设我们要监控 features Sequential(您在上面看到的 Conv2d 层)中第 (2) 层的输入和输出。 为此,我们注册了一个名为 my_hook 的前向钩子,它将在任何前向传递中被调用:

    import torch
    from torchvision.models import vgg16
    
    def my_hook(self, input, output):
        print('my_hook\'s output')
        print('input: ', input)
        print('output: ', output)
    
    # Sample net:
    net = vgg16()
    
    #Register forward hook:
    net.features[2].register_forward_hook(my_hook)
    
    # Test:
    img = torch.randn(1,3,512,512)
    out = net(img) # Will trigger my_hook and the data you are looking for will be printed
    
    

    【讨论】:

      猜你喜欢
      • 2017-11-01
      • 2015-03-29
      • 2012-01-18
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2014-07-20
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多