【问题标题】:Pytorch Problem: My jupyter stuck when num_workers > 0Pytorch 问题:当 num_workers > 0 时我的 jupyter 卡住了
【发布时间】:2021-10-15 17:49:41
【问题描述】:

这是我在 PyTorch 中的代码的 sn-p,当我使用 num_workers > 0 时,我的 jupiter notebook 卡住了,我在这个问题上花了很多时间没有任何答案。我没有 GPU,我只使用 CPU。

class IndexedDataset(Dataset):

def __init__(self,data,targets, test=False):
    self.dataset = data 
    if not test:
        self.labels = targets.numpy()
        self.mask =  np.concatenate((np.zeros(NUM_LABELED), np.ones(NUM_UNLABELED)))


    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image = self.dataset[idx]
        return image, self.labels[idx]
    
    def display(self, idx):
        plt.imshow(self.dataset[idx], cmap='gray')
        plt.show()

train_set = IndexedDataset(train_data, train_target, test = False)

test_set = IndexedDataset(test_data, test_target, test = True)

train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, num_workers=2)

test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, num_workers=2)

任何帮助,不胜感激。

【问题讨论】:

    标签: jupyter-notebook pytorch pytorch-dataloader


    【解决方案1】:

    由于 jupyter Notebook 不支持 python 多处理,有两个瘦库,您应该安装其中一个,如 12 中所述。

    我更喜欢在不使用任何外部库的情况下通过两种方式解决我的问题:

    1. 通过将我的文件从 .ipynb 格式转换为 .py 格式并在终端中运行它,我在 main() 函数中编写代码如下:

      ...
      ...
      
      train_set = IndexedDataset(train_data, train_target, test = False)
      
      train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, num_workers=4)
      
       if `__name__ ==  '__main__'`:
           for images,label in train_loader:
               print(images.shape)
      
    2. 多处理库如下:

    try.ipynb:

    import multiprocessing as mp
    import processing as ps
    
    ...
    ...
    
    train_set = IndexedDataset(train_data, train_target, test = False)
    train_loader = DataLoader(train_set, batch_size=BATCH_SIZE)
        
    if __name__=="__main__":
        p = mp.Pool(8)
        r = p.map(ps.getShape,train_loader) 
        print(r)
        p.close()
    

    processing.py 文件中:

    def getShape(data):
        for i in data:
            return i[0].shape
    

    【讨论】:

      【解决方案2】:

      num_workers 大于 0 时,PyTorch 使用多个进程进行数据加载。

      Jupyter 笔记本存在多处理问题。

      解决此问题的一种方法是不使用 Jupyter 笔记本 - 只需编写一个普通的 .py 文件并通过命令行运行它。

      或者尝试使用这里的建议:Jupyter notebook never finishes processing using multiprocessing (Python 3)

      【讨论】:

        猜你喜欢
        • 2022-12-22
        • 2022-01-22
        • 1970-01-01
        • 2012-04-18
        • 2018-09-28
        • 1970-01-01
        • 2019-08-17
        • 1970-01-01
        • 1970-01-01
        相关资源
        最近更新 更多