【问题标题】:How to use PyTorch multiprocessing?如何使用 PyTorch 多处理?
【发布时间】:2018-07-27 02:36:09
【问题描述】:

我正在尝试在pytorch 中使用python 的多处理Pool 方法来处理图像。代码如下:

from multiprocessing import Process, Pool
from torch.autograd import Variable
import numpy as np
from scipy.ndimage import zoom

def get_pred(args):

  img = args[0]
  scale = args[1]
  scales = args[2]
  img_scale = zoom(img.numpy(),
                     (1., 1., scale, scale),
                     order=1,
                     prefilter=False,
                     mode='nearest')

  # feed input data
  input_img = Variable(torch.from_numpy(img_scale),
                     volatile=True).cuda()
  return input_img

scales = [1,2,3,4,5]
scale_list = []
for scale in scales: 
    scale_list.append([img,scale,scales])
multi_pool = Pool(processes=5)
predictions = multi_pool.map(get_pred,scale_list)
multi_pool.close() 
multi_pool.join()

我收到此错误:

`RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method

` 在这一行:

predictions = multi_pool.map(get_pred,scale_list)

谁能告诉我我做错了什么?

【问题讨论】:

    标签: python computer-vision multiprocessing pytorch


    【解决方案1】:

    pytorch documentation 中所述,处理多处理的最佳实践是使用torch.multiprocessing 而不是multiprocessing

    请注意,仅 Python 3 支持在进程之间共享 CUDA 张量,使用 spawnforkserver 作为启动方法。

    在不接触您的代码的情况下,您遇到的错误的解决方法是替换

    from multiprocessing import Process, Pool
    

    与:

    from torch.multiprocessing import Pool, Process, set_start_method
    try:
         set_start_method('spawn')
    except RuntimeError:
        pass
    

    【讨论】:

    • 有时甚至torch.multiprocessing.set_start_method('spawn', force=True)
    • 在这种情况下,请确保您的主循环由if __name__ == '__main__': 分隔,因为全局语句将在生成时执行
    • 这解决了一个问题,但又引入了另一个问题,TypeError: 'NoneType' object is not callable。有人看过吗?
    【解决方案2】:

    我建议您阅读多处理模块的文档,尤其是 this section。您必须通过调用set_start_method 来更改创建子流程的方式。摘自那些引用的文档:

    import multiprocessing as mp
    
    def foo(q):
        q.put('hello')
    
    if __name__ == '__main__':
        mp.set_start_method('spawn')
        q = mp.Queue()
        p = mp.Process(target=foo, args=(q,))
        p.start()
        print(q.get())
        p.join()
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2019-05-28
      • 2020-01-28
      • 2021-02-12
      • 2022-12-06
      • 2020-09-04
      • 1970-01-01
      • 2018-12-08
      • 1970-01-01
      相关资源
      最近更新 更多