diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index 035766289aebd..a1ca420bef149 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -169,4 +169,7 @@ def get_current_memory_usage(cls,
                                  device: Optional[torch.types.Device] = None
                                  ) -> float:
         torch.cuda.reset_peak_memory_stats(device)
-        return torch.cuda.max_memory_allocated(device)
+        # mem_get_info() yields (free, total) bytes for the device; unpack a
+        # single call so both figures describe the same instant.
+        free_bytes, total_bytes = torch.cuda.mem_get_info(device)
+        return total_bytes - free_bytes