【问题标题】:How to write a contextmanager to throw and catch errors如何编写上下文管理器来抛出和捕获错误
【发布时间】:2021-03-02 04:07:19
【问题描述】:

我想在我的代码中多次捕获运行时错误CUDA out of memory。我这样做是为了以较小的批量重新运行整个训练工作流程。最好的方法是什么?

我目前正在这样做:

try:
    result = model(input)
# if the GPU runs out of memory, start the experiment again with a smaller batch size
except RuntimeError as e:
    if str(e).startswith('CUDA out of memory.') and batch_size > 10:
        raise CudaOutOfMemory(e)
    else:
        raise e

然后我在我的主函数之外捕获错误CudaOutOfMemory

但是,这是一段相当长的代码,我需要重复很多次。有没有办法为此创建一个上下文管理器?

这样我可以运行:

with catch_cuda_out_of_mem_error:
  result = model(input)

编辑: 我想创建一个上下文管理器而不是一个函数,因为我想包装“try, except”的函数并不总是相同的。在我的工作流程中,我有许多使用大量 GPU 内存的函数,我想在其中任何一个函数中发现这个错误。

【问题讨论】:

    标签: python pytorch contextmanager


    【解决方案1】:

    使用上下文管理器是关于正确获取和释放资源。在这里,您实际上并没有任何要获取和释放的资源,因此我认为上下文管理器不合适。只使用一个函数怎么样?

    def try_compute_model(input):
        try:
            return model(input)
        # if the GPU runs out of memory, start the experiment again with a smaller batch size
        except RuntimeError as e:
            if str(e).startswith('CUDA out of memory.') and batch_size > 10:
                raise CudaOutOfMemory(e)
            else:
                raise e
    

    然后像这样使用它

    result = try_compute_model(input)
    

    【讨论】:

    • Python 开发者没有那么严格。例如。 contexlib.suppress 管理器不获取/释放资源。
    • 嗨@mCoding 感谢您的回复!我想创建一个上下文管理器而不是一个函数,因为我想包装“try, except”的函数并不总是相同的。在我的工作流程中,我有许多使用大量 GPU 内存的函数,我想在其中任何一个函数中捕获这个错误。抱歉,如果我的问题在这个意义上具有误导性
    【解决方案2】:

    受这篇文章的启发:General decorator to wrap try except in python? 我找到了问题的答案:

    import torch
    from contextlib import contextmanager
    
    
    class CudaOutOfMemory(Exception):
        pass
    
    
    @contextmanager
    def catching_cuda_out_of_memory():
        """
        Context that throws CudaOutOfMemory error if GPU is out of memory.
        """
        try:
            yield
        except RuntimeError as e:
            if str(e).startswith('CUDA out of memory.'):
                raise CudaOutOfMemory(e)
            else:
                raise e
    
    
    def oom():
        x = torch.randn(100, 10000, device=1)
        for _ in range(100):
            l = torch.nn.Linear(10000, 10000)
            l.to(1)
            x = l(x)
    
    
    try:
        with catching_cuda_out_of_memory():
            oom()
    except CudaOutOfMemory:
        print('GOTCHA!')
    

    【讨论】:

      猜你喜欢
      • 2018-01-21
      • 1970-01-01
      • 2011-01-25
      • 2023-01-28
      • 1970-01-01
      • 2021-12-12
      • 1970-01-01
      • 2023-03-28
      • 2022-07-26
      相关资源
      最近更新 更多