【问题标题】:How do I fix the Dataset to return desired output (pytorch)如何修复数据集以返回所需的输出(pytorch)
【发布时间】:2020-12-07 04:31:07
【问题描述】:

我正在尝试使用来自外部函数的信息来决定要返回哪些数据。在这里,我添加了一个简化的代码来演示这个问题。当我使用num_workers = 0 时,我得到了所需的行为(3 个时期后的输出为 18)。但是,当我增加 num_workers 的值时,每个 epoch 之后的输出都是相同的。并且全局变量保持不变。

from torch.utils.data import Dataset, DataLoader

x = 6
def getx():
    global x
    x+=1
    print("x: ", x)
    return x

class MyDataset(Dataset):
    def __init__(self):
        pass

    def __getitem__(self, index):
        global x
        x = getx()
        return x
    
    def __len__(self):
        return 3

dataset = MyDataset()
loader = DataLoader(
    dataset,
    num_workers=0,
    shuffle=False
)

for epoch in range(4):
    for idx, data in enumerate(loader):
        print('Epoch {}, idx {}, val: {}'.format(epoch, idx, data))

num_workers=0 时的最终输出为 18 符合预期。但是num_workers>0时,x保持不变(最终输出为6)。

如何使用num_workers>0获得与num_workers=0类似的行为(即如何确保dataloader的__getitem__函数改变全局变量x的值)?

【问题讨论】:

    标签: python multiprocessing dataset pytorch dataloader


    【解决方案1】:

    原因在于 python 中多处理的基本性质。设置num_workers 意味着您的DataLoader 创建了该数量的子进程。每个子进程实际上是一个单独的 Python 实例,具有自己的全局状态,并且不知道其他进程中发生了什么。

    在 python 的多处理中,一个典型的解决方案是使用Manager。但是,由于您的多处理是通过 DataLoader 提供的,因此您无法在其中进行处理。

    幸运的是,可以做些别的事情。 DataLoader 实际上依赖于 torch.multiprocessing,这反过来又允许在进程之间共享张量,只要它们在共享内存中。

    所以你可以做的是,简单地使用 x 作为共享张量。

    from torch.utils.data import Dataset, DataLoader
    import torch 
    
    x = torch.tensor([6])
    x.share_memory_()
    
    def getx():
        global x
        x+=1
        print("x: ", x.item())
        return x
    
    class MyDataset(Dataset):
        def __init__(self):
            pass
    
        def __getitem__(self, index):
            global x
            x = getx()
            return x
        
        def __len__(self):
            return 3
    
    dataset = MyDataset()
    loader = DataLoader(
        dataset,
        num_workers=2,
        shuffle=False
    )
    
    for epoch in range(4):
        for idx, data in enumerate(loader):
            print('Epoch {}, idx {}, val: {}'.format(epoch, idx, data))
    

    输出:

    x:  7
    x:  8
    x:  9
    Epoch 0, idx 0, val: tensor([[7]])
    Epoch 0, idx 1, val: tensor([[8]])
    Epoch 0, idx 2, val: tensor([[9]])
    x:  10
    x:  11
    x:  12
    Epoch 1, idx 0, val: tensor([[10]])
    Epoch 1, idx 1, val: tensor([[12]])
    Epoch 1, idx 2, val: tensor([[12]])
    x:  13
    x:  14
    x:  15
    Epoch 2, idx 0, val: tensor([[13]])
    Epoch 2, idx 1, val: tensor([[15]])
    Epoch 2, idx 2, val: tensor([[14]])
    x:  16
    x:  17
    x:  18
    Epoch 3, idx 0, val: tensor([[16]])
    Epoch 3, idx 1, val: tensor([[18]])
    Epoch 3, idx 2, val: tensor([[17]])
    

    虽然这可行,但并不完美。查看 epoch 1,注意到有 2 个 12,而不是 11 和 12。这意味着两个独立的进程在执行 print 之前已经执行了行 x+=1。这是不可避免的,因为并行进程正在共享内存上工作。

    如果您熟悉操作系统概念,则可以进一步实现某种semaphore,并根据需要使用一个额外的变量来控制对 x 的访问——但这超出了问题的范围,不再赘述。

    【讨论】:

      猜你喜欢
      • 2022-07-30
      • 1970-01-01
      • 2023-02-01
      • 2022-07-17
      • 2016-12-20
      • 2019-07-16
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多