【发布时间】:2020-12-07 04:31:07
【问题描述】:
我正在尝试使用来自外部函数的信息来决定要返回哪些数据。在这里,我添加了一个简化的代码来演示这个问题。当我使用num_workers = 0 时,我得到了所需的行为(3 个时期后的输出为 18)。但是,当我增加 num_workers 的值时,每个 epoch 之后的输出都是相同的。并且全局变量保持不变。
from torch.utils.data import Dataset, DataLoader
x = 6
def getx():
global x
x+=1
print("x: ", x)
return x
class MyDataset(Dataset):
def __init__(self):
pass
def __getitem__(self, index):
global x
x = getx()
return x
def __len__(self):
return 3
dataset = MyDataset()
loader = DataLoader(
dataset,
num_workers=0,
shuffle=False
)
for epoch in range(4):
for idx, data in enumerate(loader):
print('Epoch {}, idx {}, val: {}'.format(epoch, idx, data))
num_workers=0 时的最终输出为 18 符合预期。但是num_workers>0时,x保持不变(最终输出为6)。
如何使用num_workers>0获得与num_workers=0类似的行为(即如何确保dataloader的__getitem__函数改变全局变量x的值)?
【问题讨论】:
标签: python multiprocessing dataset pytorch dataloader