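【Posted】: 2020-11-09 13:11:25
【Question】:
I want to use logging in one of the processes managed by DistributedDataParallel. However, logging prints nothing in the following code (adapted from this tutorial):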
#!/usr/bin/python
import os, logging

# logging.basicConfig(level=logging.DEBUG)
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP


def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    # Initialize the process group.
    dist.init_process_group('NCCL', rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def demo_basic(rank, world_size):
    setup(rank, world_size)
    # Only rank 0 logs.
    if rank == 0:
        logger = logging.getLogger('train')
        logger.setLevel(logging.DEBUG)
        logger.info(f'Running DDP on rank={rank}.')
    # Create model and move it to GPU.
    model = ToyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)  # optimizer takes DDP model.
    optimizer.zero_grad()
    inputs = torch.randn(20, 10)  # .to(rank)
    outputs = ddp_model(inputs)
    labels = torch.randn(20, 5).to(rank)
    loss_fn(outputs, labels).backward()
    optimizer.step()
    cleanup()


def run_demo(demo_func, world_size):
    # Spawn one worker process per rank.
    mp.spawn(
        demo_func,
        args=(world_size,),
        nprocs=world_size,
        join=True
    )


def main():
    run_demo(demo_basic, 4)


if __name__ == "__main__":
    main()
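However, when I uncomment line 4 (the logging.basicConfig(level=logging.DEBUG) call), logging works. What causes this behavior, and how can I fix it?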
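A likely explanation, offered as a sketch rather than a confirmed diagnosis: the 'train' logger has no handler of its own, and without logging.basicConfig the root logger has none either. In that case the standard library falls back to logging.lastResort, whose level is WARNING, so the INFO record is dropped even though the logger's level is DEBUG. Uncommenting basicConfig probably helps because mp.spawn starts fresh interpreters that re-import the script, so the module-level call runs in every worker and installs a root handler. The helper below, get_rank_logger, is a hypothetical name introduced for illustration; it attaches a handler directly to the logger so the output does not depend on root configuration. It uses only the standard library and an in-memory stream so it runs without GPUs.

```python
import io
import logging


def get_rank_logger(rank, stream=None, name="train"):
    """Return a per-rank logger with its own handler (hypothetical helper).

    stream=None makes StreamHandler write to sys.stderr, its default.
    """
    logger = logging.getLogger(f"{name}.rank{rank}")
    logger.setLevel(logging.DEBUG)
    if not logger.handlers:  # avoid stacking handlers on repeated calls
        handler = logging.StreamHandler(stream)
        handler.setLevel(logging.DEBUG)
        handler.setFormatter(
            logging.Formatter("%(levelname)s:%(name)s:%(message)s")
        )
        logger.addHandler(handler)
    # Assumption: we do not want records duplicated through root handlers.
    logger.propagate = False
    return logger


# The fallback handler used when no handler is configured only passes
# WARNING and above, which is why the INFO message disappears.
fallback_level = logging.getLevelName(logging.lastResort.level)

# Demonstration with an in-memory stream instead of stderr.
buffer = io.StringIO()
log = get_rank_logger(rank=0, stream=buffer)
log.info("Running DDP on rank=%d.", 0)
print(fallback_level)
print(buffer.getvalue(), end="")
```

In demo_basic, the same idea would mean calling get_rank_logger(rank) inside the rank == 0 branch, after setup(rank, world_size), instead of relying on a module-level basicConfig.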
【Comments】:
-
Unless I'm miscounting, line 6 is empty - I guess you mean line 4?
-
@Xtrem532 Thanks for pointing that out. Corrected now.
Tags: python-3.x pytorch python-multiprocessing python-logging