From 004895fb5e602daed6a83e1a7e1bcffa3c950997 Mon Sep 17 00:00:00 2001
From: Shantanu <12621235+hauntsaninja@users.noreply.github.com>
Date: Sun, 6 Dec 2020 06:25:00 -0800
Subject: [PATCH] Catch RuntimeError in torch.cuda.reset_peak_memory_stats
 (#26)

The AttributeError was fixed in https://github.com/pytorch/pytorch/pull/48406
---
 pytorch_memlab/line_profiler/line_profiler.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/pytorch_memlab/line_profiler/line_profiler.py b/pytorch_memlab/line_profiler/line_profiler.py
index 7615ef1..b7528b8 100644
--- a/pytorch_memlab/line_profiler/line_profiler.py
+++ b/pytorch_memlab/line_profiler/line_profiler.py
@@ -88,8 +88,9 @@ def enable(self):
         try:
             torch.cuda.empty_cache()
             self._reset_cuda_stats()
-        # Pytorch-1.7.0 raises AttributeError while <1.6.0 raises AssertionError
-        except (AssertionError, AttributeError) as error:
+        # What error is raised depends on PyTorch version:
+        # latest raises RuntimeError, 1.7.0 raises AttributeError, <1.7.0 raises AssertionError
+        except (AssertionError, AttributeError, RuntimeError) as error:
             print('Could not reset CUDA stats and cache: ' + str(error))
 
         self.register_callback()
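
For context, the hunk above broadens the except clause around the CUDA stats reset so that newer PyTorch releases, which raise RuntimeError when CUDA is unavailable or uninitialized, are handled the same way as older ones. A minimal, self-contained sketch of the same guarded-reset pattern follows; the helper name `safe_reset_cuda_stats` is illustrative only and not part of pytorch_memlab.

import torch

def safe_reset_cuda_stats() -> None:
    """Reset CUDA cache and peak-memory stats, tolerating failures.

    Hypothetical helper mirroring what LineProfiler.enable() does after
    the patch above; the exact exception type depends on the PyTorch
    version when CUDA cannot be used: AssertionError (<1.7.0),
    AttributeError (1.7.0), RuntimeError (later releases).
    """
    try:
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
    except (AssertionError, AttributeError, RuntimeError) as error:
        print('Could not reset CUDA stats and cache: ' + str(error))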