diff --git a/pytorch_memlab/line_profiler/line_profiler.py b/pytorch_memlab/line_profiler/line_profiler.py index 7615ef1..b7528b8 100644 --- a/pytorch_memlab/line_profiler/line_profiler.py +++ b/pytorch_memlab/line_profiler/line_profiler.py @@ -88,8 +88,9 @@ def enable(self): try: torch.cuda.empty_cache() self._reset_cuda_stats() - # Pytorch-1.7.0 raises AttributeError while <1.6.0 raises AssertionError - except (AssertionError, AttributeError) as error: + # What error is raised depends on PyTorch version: + # latest raises RuntimeError, 1.7.0 raises AttributeError, <1.7.0 raises AssertionError + except (AssertionError, AttributeError, RuntimeError) as error: print('Could not reset CUDA stats and cache: ' + str(error)) self.register_callback()