【发布时间】:2021-02-12 19:21:20
【问题描述】:
在运行 nn.Sequential 时,我包含一个类模块列表(这将是神经网络的层)。运行 nn.Sequential 时,它调用模块的转发函数。然而,每个类模块也有一个我想在 nn.Sequential 运行时访问的函数。运行 nn.Sequential 时如何访问和运行该函数?
【问题讨论】:
标签: pytorch
在运行 nn.Sequential 时,我包含一个类模块列表(这将是神经网络的层)。运行 nn.Sequential 时,它调用模块的转发函数。然而,每个类模块也有一个我想在 nn.Sequential 运行时访问的函数。运行 nn.Sequential 时如何访问和运行该函数?
【问题讨论】:
标签: pytorch
您可以为此使用 hook。让我们考虑以下在 VGG16 上演示的示例:
这是网络架构:
假设我们要监控 features Sequential(您在上面看到的 Conv2d 层)中第 (2) 层的输入和输出。
为此,我们注册了一个名为 my_hook 的前向钩子,它将在任何前向传递中被调用:
import torch
from torchvision.models import vgg16
def my_hook(self, input, output):
print('my_hook\'s output')
print('input: ', input)
print('output: ', output)
# Sample net:
net = vgg16()
#Register forward hook:
net.features[2].register_forward_hook(my_hook)
# Test:
img = torch.randn(1,3,512,512)
out = net(img) # Will trigger my_hook and the data you are looking for will be printed
【讨论】: