diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 1f690b7111ee2..13aebc605af74 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -169,4 +169,5 @@ class RocmPlatform(Platform): device: Optional[torch.types.Device] = None ) -> float: torch.cuda.reset_peak_memory_stats(device) - return torch.cuda.max_memory_allocated(device) + return torch.cuda.mem_get_info(device)[1] - torch.cuda.mem_get_info( + device)[0]