【发布时间】:2023-04-07 05:49:01
【问题描述】:
我有一个torchvision.models.ResNet 的实例,我有我的类CondBatchNorm2d,它是一个类似于BatchNorm2d 的模块,但是forward 方法接受一个额外的输入y,它不是来自上一层,因为它是整个网络的输入:
def forward(self, x, y=None):
...
我知道如何用CondBatchNorm2d 的实例替换每个BatchNorm2d 实例,但我不确定如何编写自己的转发方法来包含中间CondBatchNorm2d 层的新输入。我应该在 resnet 孩子上迭代还是有更合适的方法?
【问题讨论】:
标签: python pytorch torchvision