diff --git a/vllm/utils.py b/vllm/utils.py
index d5d8d4efa95c0..7ec9e3289c971 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -40,11 +40,6 @@ def get_max_shared_memory_bytes(gpu: int = 0) -> int:
     return int(max_shared_mem)
 
 
-def get_gpu_memory(gpu: int = 0) -> int:
-    """Returns the total memory of the GPU in bytes."""
-    return torch.cuda.get_device_properties(gpu).total_memory
-
-
 def get_cpu_memory() -> int:
     """Returns the total CPU memory of the node in bytes."""
     return psutil.virtual_memory().total
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 6f5e16f0011f6..e32949115178b 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -13,7 +13,6 @@ from vllm.model_executor.parallel_utils.parallel_state import (
 from vllm.sequence import SamplerOutput, SequenceGroupMetadata
 from vllm.worker.cache_engine import CacheEngine
 from vllm.worker.model_runner import ModelRunner
-from vllm.utils import get_gpu_memory
 
 
 class Worker:
@@ -81,7 +80,6 @@ class Worker:
         # Profile the memory usage of the model and get the maximum number of
         # cache blocks that can be allocated with the remaining free memory.
         torch.cuda.empty_cache()
-        torch.cuda.reset_peak_memory_stats()
 
         # Execute a forward pass with dummy inputs to profile the memory usage
         # of the model.
@@ -90,8 +88,9 @@ class Worker:
         # Calculate the number of blocks that can be allocated with the
         # profiled peak memory.
         torch.cuda.synchronize()
-        peak_memory = torch.cuda.max_memory_allocated()
-        total_gpu_memory = get_gpu_memory()
+        free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
+        peak_memory = total_gpu_memory - free_gpu_memory
+
         cache_block_size = CacheEngine.get_cache_block_size(
             block_size, self.model_config, self.parallel_config)
         num_gpu_blocks = int(
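
For context, here is a minimal standalone sketch of the profiling logic as it looks after this change. The diff is truncated before the `num_gpu_blocks` formula, so the `gpu_memory_utilization` factor and the exact block-count expression below are assumptions for illustration, not a verbatim copy of the worker code:

```python
import torch


def estimate_num_gpu_blocks(cache_block_size: int,
                            gpu_memory_utilization: float = 0.9) -> int:
    """Sketch of the post-change peak-memory estimate.

    torch.cuda.max_memory_allocated() only tracks allocations made through
    PyTorch's caching allocator; querying the driver via
    torch.cuda.mem_get_info() also accounts for memory held by other
    libraries on the same device.
    """
    torch.cuda.synchronize()
    # (free, total) in bytes, reported by the CUDA driver.
    free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
    # Peak usage is inferred as everything that is currently not free.
    peak_memory = total_gpu_memory - free_gpu_memory
    # Assumed formula: fill the utilization budget left over after the
    # profiled peak with KV-cache blocks of cache_block_size bytes.
    num_gpu_blocks = int(
        (total_gpu_memory * gpu_memory_utilization - peak_memory)
        // cache_block_size)
    return max(num_gpu_blocks, 0)
```

The practical effect of switching to `torch.cuda.mem_get_info()` is that the estimate reflects all device memory in use, not just PyTorch-tracked allocations, which also removes the need for `torch.cuda.reset_peak_memory_stats()` before profiling.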