【问题标题】:Pytorch nn.parallel.DistributedDataParallel model loadPytorch nn.parallel.DistributedDataParallel 模型加载
【发布时间】:2021-09-29 04:39:11
【问题描述】:

模型保存方式:

torch.distributed.init_process_group(backend="nccl")
local_rank = torch.distributed.get_rank()
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)
save_model = f'./model'
Path(save_model).mkdir(parents=True, exist_ok=True)
net = Net(args)  # .to(device)
model_name = f"{save_model}/net.pt"
torch.save(net.state_dict(), model_name)   #
model = Model(net, args).to(device)
model_name = f"{save_model}/model.pt"

if torch.cuda.device_count() > 1:
    model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank],
                                                output_device=local_rank)
model.module.fit(tr_data, val_data, args)
torch.save(maml, model_name) 

我尝试加载模型:

save_model = f'./model'
net = Net(args)  # .to(device)
model_name = f"{save_model}/net.pt"
net.load_state_dict(
    torch.load(model_name, map_location=torch.torch.device("cpu")))
maml = Model(net, args).to(device)
model_name = f"{save_model}/model.pt"
maml = torch.load(model_name, map_location=torch.torch.device(
    "cuda" if torch.cuda.is_available() else "cpu"))  # .load_state_dict

“net”可以加载成功,但是加载“model”时报错:

文件“D:\Research\Traffic Prediction\maml\venv\lib\site-packages\torch\serialization.py”,第 607 行,在加载返回 _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)

文件“D:\Research\Traffic Prediction\maml\venv\lib\site-packages\torch\serialization.py”,第 882 行,在 _load result = unpickler.load()

TypeError: () 缺少 1 个必需的位置参数:'ddp_join_throw_on_early_termination'

任何意见将不胜感激。

【问题讨论】:

    标签: pytorch


    【解决方案1】:

    通常,使用 PyTorch 的 DistributedDataParallel,所有节点都保持相同的模型(因为它在反向传播期间是“同步的”)。

    保存它的最佳方法是只保存模型而不是整个 DistributedDataParallel(通常在主节点上,如果可能存在多个节点故障是一个问题):

    # or not only local_rank 0
    if local_rank == 0:
        torch.save(model.module.cpu(), path)
    

    请注意,如果您的模型包含在 DistributedDataParallel 中,那么您所使用的模型将保留在模块属性中。

    另一件事 - 将您的模型投射到 CPU,在这种情况下不需要映射(因为您可能使用多个 GPU,并且您必须在可能没有 GPU 的其他设备上适当地映射它)。

    【讨论】:

      猜你喜欢
      • 2020-05-24
      • 2021-12-15
      • 2019-09-26
      • 2021-04-05
      • 2021-09-17
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2021-08-14
      相关资源
      最近更新 更多