【发布时间】:2021-03-02 04:07:19
【问题描述】:
我想在我的代码中多次捕获运行时错误CUDA out of memory。我这样做是为了以较小的批量重新运行整个训练工作流程。最好的方法是什么?
我目前正在这样做:
try:
result = model(input)
# if the GPU runs out of memory, start the experiment again with a smaller batch size
except RuntimeError as e:
if str(e).startswith('CUDA out of memory.') and batch_size > 10:
raise CudaOutOfMemory(e)
else:
raise e
然后我在我的主函数之外捕获错误CudaOutOfMemory。
但是,这是一段相当长的代码,我需要重复很多次。有没有办法为此创建一个上下文管理器?
这样我可以运行:
with catch_cuda_out_of_mem_error:
result = model(input)
编辑: 我想创建一个上下文管理器而不是一个函数,因为我想包装“try, except”的函数并不总是相同的。在我的工作流程中,我有许多使用大量 GPU 内存的函数,我想在其中任何一个函数中发现这个错误。
【问题讨论】:
标签: python pytorch contextmanager