需要定义一个回调函数:

def get_features_hook(self, input, output):
    print("hook", output.data.cpu().numpy().shape)

然后对需要查看的层注册钩子:

handle = self.model.fc_loc[2].register_forward_hook(get_features_hook)

在查看完后移除钩子:

handle.remove()

相关文章: