【发布时间】:2019-09-19 01:25:08
【问题描述】:
我在 pytorch 中使用 GroupNorm 而不是 BatchNorm 并保持所有其他(网络架构)不变。它表明在 Imagenet 数据集中,使用 resnet50 架构,GroupNorm 比 BatchNorm 慢 40%,消耗的 GPU 内存比 BatchNorm 多 33%。我真的很困惑,因为 GroupNorm 不需要比 BatchNorm 更多的计算。详情如下。
关于Group Normalization的详细内容,可以看这篇论文:https://arxiv.org/pdf/1803.08494.pdf
对于 BatchNorm,一个 minibatch 消耗 12.8 秒,GPU 内存为 7.51GB;
对于 GroupNorm,一个 minibatch 消耗 17.9 秒,GPU 内存为 10.02GB。
我使用以下代码将所有 BatchNorm 图层转换为 GroupNorm 图层。
def convert_bn_model_to_gn(module, num_groups=16):
"""
Recursively traverse module and its children to replace all instances of
``torch.nn.modules.batchnorm._BatchNorm`` with :class:`torch.nn.GroupNorm`.
Args:
module: your network module
num_groups: num_groups of GN
"""
mod = module
if isinstance(module, nn.modules.batchnorm._BatchNorm):
mod = nn.GroupNorm(num_groups, module.num_features,
eps=module.eps, affine=module.affine)
# mod = nn.modules.linear.Identity()
if module.affine:
mod.weight.data = module.weight.data.clone().detach()
mod.bias.data = module.bias.data.clone().detach()
for name, child in module.named_children():
mod.add_module(name, convert_bn_model_to_gn(
child, num_groups=num_groups))
del module
return mod
【问题讨论】:
标签: pytorch