【Posted】: 2021-09-29 04:39:11
【Question】:
This is how the model is saved:
import torch
import torch.nn as nn
from pathlib import Path

torch.distributed.init_process_group(backend="nccl")
local_rank = torch.distributed.get_rank()
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)
save_model = f'./model'
Path(save_model).mkdir(parents=True, exist_ok=True)
net = Net(args) # .to(device)
model_name = f"{save_model}/net.pt"
torch.save(net.state_dict(), model_name)
model = Model(net, args).to(device)
model_name = f"{save_model}/model.pt"
if torch.cuda.device_count() > 1:
    model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank],
                                                output_device=local_rank)
model.module.fit(tr_data, val_data, args)
torch.save(model, model_name)  # pickles the whole DDP-wrapped model object
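The traceback below suggests the failure happens while unpickling the DistributedDataParallel wrapper itself (the missing argument name is a DDP-internal one), possibly because the saving and loading environments run different PyTorch versions. A common workaround, offered here as an assumption rather than a confirmed fix, is to save only the `state_dict` of the wrapped module instead of pickling the whole wrapper object. A minimal sketch; `unwrap` and `save_weights` are hypothetical helper names:

```python
import torch
import torch.nn as nn


def unwrap(model: nn.Module) -> nn.Module:
    """Return the wrapped model if `model` is a DDP/DataParallel wrapper, else `model` itself."""
    # Both wrappers keep the user's model in the `.module` attribute.
    if isinstance(model, (nn.parallel.DistributedDataParallel, nn.DataParallel)):
        return model.module
    return model


def save_weights(model: nn.Module, path) -> None:
    """Save only the parameters and buffers (state_dict) of the unwrapped model.

    Unlike torch.save(model, ...), this stores tensors rather than the pickled
    DDP wrapper, so loading never has to reconstruct DistributedDataParallel.
    """
    torch.save(unwrap(model).state_dict(), path)
```

In the training script above this would replace the final `torch.save` with `save_weights(model, model_name)`; in a multi-process DDP run it is usual to save from one process only, e.g. when `torch.distributed.get_rank() == 0`.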
Then I try to load the model:
save_model = f'./model'
net = Net(args) # .to(device)
model_name = f"{save_model}/net.pt"
net.load_state_dict(
    torch.load(model_name, map_location=torch.device("cpu")))
maml = Model(net, args).to(device)
model_name = f"{save_model}/model.pt"
maml = torch.load(model_name, map_location=torch.device(
    "cuda" if torch.cuda.is_available() else "cpu"))  # .load_state_dict
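On the loading side, a matching sketch rebuilds the architecture in code and then calls `load_state_dict`, the same pattern that already works for `net` here. `TinyModel` is a hypothetical stand-in for `Model(net, args)`, and stripping a `module.` key prefix is an assumption that covers checkpoints saved from a wrapper's own `state_dict()`:

```python
import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for the question's Model(net, args)."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


def load_weights(path, device=None) -> nn.Module:
    """Rebuild the model in code, then copy the saved tensors into it."""
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = TinyModel()
    state = torch.load(path, map_location=device)
    # DDP/DataParallel state_dicts prefix every key with "module."; strip it if present.
    state = {
        (k[len("module."):] if k.startswith("module.") else k): v
        for k, v in state.items()
    }
    model.load_state_dict(state)
    return model.to(device)
```

Because only tensors are stored, this path does not depend on how DistributedDataParallel is constructed in the installed PyTorch version.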
`net` loads successfully, but loading `model` fails with:
File "D:\Research\Traffic Prediction\maml\venv\lib\site-packages\torch\serialization.py", line 607, in load
    return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
File "D:\Research\Traffic Prediction\maml\venv\lib\site-packages\torch\serialization.py", line 882, in _load
    result = unpickler.load()
TypeError: () missing 1 required positional argument: 'ddp_join_throw_on_early_termination'
Any advice would be appreciated.
【Discussion】:
Tags: pytorch