【发布时间】:2021-12-09 18:57:12
【问题描述】:
我想用多个 gpu 训练模型。我正在使用以下代码
model = load_model(path)
if torch.cuda.device_count() > 1:
print("Let's use", torch.cuda.device_count(), "GPUs!")
# dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
model = nn.DataParallel(model)
model.to(device)
它运行良好,只是 DataParallel 不包含原始模型中的函数,有没有办法解决它?谢谢
【问题讨论】:
-
"DataParallel 不包含原始模型中的函数",你到底是什么意思?
-
@Ivan 我对 ML 很陌生,它是 VQGan 模型,它包含 VectorQuantizer 作为 self.quantize 属性,当我们执行“model = nn.DataParallel(model)”时它丢失了
-
您好,既然有 pytorch-lightning 的标签,您想查看那里的多 GPU 文档吗? pytorch-lightning.readthedocs.io/en/stable/advanced/…
-
@NanoBit 谢谢,是的模型继承了pl.LightningModule
-
请澄清您的具体问题或提供其他详细信息以准确突出您的需求。正如目前所写的那样,很难准确地说出你在问什么。
标签: pytorch torch pytorch-lightning