【问题标题】:Pytorch DataParallel with custom modelPytorch DataParallel 与自定义模型
【发布时间】:2021-12-09 18:57:12
【问题描述】:

我想用多个 gpu 训练模型。我正在使用以下代码

model = load_model(path)
if torch.cuda.device_count() > 1:
  print("Let's use", torch.cuda.device_count(), "GPUs!")
  # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
  model = nn.DataParallel(model)

model.to(device)

它运行良好,只是 DataParallel 不包含原始模型中的函数,有没有办法解决它?谢谢

【问题讨论】:

  • "DataParallel 不包含原始模型中的函数",你到底是什么意思?
  • @Ivan 我对 ML 很陌生,它是 VQGan 模型,它包含 VectorQuantizer 作为 self.quantize 属性,当我们执行“model = nn.DataParallel(model)”时它丢失了
  • 您好,既然有 pytorch-lightning 的标签,您想查看那里的多 GPU 文档吗? pytorch-lightning.readthedocs.io/en/stable/advanced/…
  • @NanoBit 谢谢,是的模型继承了pl.LightningModule
  • 请澄清您的具体问题或提供其他详细信息以准确突出您的需求。正如目前所写的那样,很难准确地说出你在问什么。

标签: pytorch torch pytorch-lightning


【解决方案1】:

传递给nn.DataParallelnn.Module 最终将被类包装以处理数据并行性。您仍然可以使用 module 属性访问您的模型。

>>> p_model = nn.DataParallel(model)
>>> p_model.module # <- model

例如,要访问底层模型的 quantize 属性,您可以:

>>> p_model.module.quantize

【讨论】:

    猜你喜欢
    • 2021-09-17
    • 2021-09-21
    • 2019-09-20
    • 2020-07-03
    • 2021-09-08
    • 1970-01-01
    • 2020-03-07
    • 1970-01-01
    • 2019-04-26
    相关资源
    最近更新 更多