From 2d8476e465bb861c1e7d1e65c800b725282271c4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Luka=20Govedi=C4=8D?=
Date: Sat, 7 Jun 2025 13:34:51 -0400
Subject: [PATCH] [BugFix][V1] Fix memory profiling bug (#18974)

Signed-off-by: luka
---
 tests/models/test_initialization.py |  2 ++
 tests/v1/sample/test_logprobs.py    | 13 ++++---
 vllm/v1/worker/gpu_worker.py        | 54 +++++++++++++++++++++++------
 3 files changed, 53 insertions(+), 16 deletions(-)

diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py
index 98a58d01e2a18..54e8cd597bfc4 100644
--- a/tests/models/test_initialization.py
+++ b/tests/models/test_initialization.py
@@ -86,6 +86,8 @@ def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
             } if model_info.speculative_model else None,
             trust_remote_code=model_info.trust_remote_code,
             max_model_len=model_info.max_model_len,
+            # these tests seem to produce leftover memory
+            gpu_memory_utilization=0.80,
             load_format="dummy",
             hf_overrides=hf_overrides,
         )
diff --git a/tests/v1/sample/test_logprobs.py b/tests/v1/sample/test_logprobs.py
index 612eca116f231..69180e6e5db49 100644
--- a/tests/v1/sample/test_logprobs.py
+++ b/tests/v1/sample/test_logprobs.py
@@ -42,7 +42,7 @@ def vllm_model(vllm_runner, request) -> Generator[VllmRunner, None, None]:
             #TODO: enable this once we support it for
             # prompt logprobs.
             enable_prefix_caching=request.param,
-            gpu_memory_utilization=0.5,
+            gpu_memory_utilization=0.4,  # up to 2 alive concurrently
     ) as vllm_model:
         yield vllm_model
 
@@ -343,10 +343,13 @@ def test_max_logprobs(monkeypatch: pytest.MonkeyPatch):
 
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
-        runner = VllmRunner("facebook/opt-125m",
-                            max_logprobs=1,
-                            enable_prefix_caching=False,
-                            max_model_len=256)
+        runner = VllmRunner(
+            "facebook/opt-125m",
+            max_logprobs=1,
+            enable_prefix_caching=False,
+            # 2 other llms alive during whole session
+            gpu_memory_utilization=0.15,
+            max_model_len=256)
         vllm_sampling_params = SamplingParams(logprobs=1)
         # should pass
         runner.generate(["Hello world"], sampling_params=vllm_sampling_params)
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index 3bf3b2221a447..1dfccc9b31bd5 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -130,7 +130,20 @@ class Worker(WorkerBase):
             _check_if_gpu_supports_dtype(self.model_config.dtype)
             gc.collect()
             torch.cuda.empty_cache()
-            self.init_gpu_memory = torch.cuda.mem_get_info()[0]
+            self.init_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
+            requested_memory = (total_gpu_memory *
+                                self.cache_config.gpu_memory_utilization)
+            if self.init_gpu_memory < requested_memory:
+                GiB = lambda b: round(b / GiB_bytes, 2)
+                raise ValueError(
+                    f"Free memory on device ({GiB(self.init_gpu_memory)}/"
+                    f"{GiB(total_gpu_memory)} GiB) on startup is less than "
+                    f"desired GPU memory utilization "
+                    f"({self.cache_config.gpu_memory_utilization}, "
+                    f"{GiB(requested_memory)} GiB). Decrease GPU memory "
+                    f"utilization or reduce GPU memory used by other processes."
+                )
+
         else:
             raise RuntimeError(
                 f"Not support device type: {self.device_config.device}")
@@ -190,28 +203,47 @@ class Worker(WorkerBase):
         # GPU did not change their memory usage during the profiling.
         assert self.init_gpu_memory > free_gpu_memory, (
             "Error in memory profiling. "
-            f"Initial free memory {self.init_gpu_memory}, current free memory"
-            f" {free_gpu_memory}. This happens when the GPU memory was "
-            "not properly cleaned up before initializing the vLLM instance.")
+            f"Initial free memory {self.init_gpu_memory/GiB_bytes} GiB, "
+            f"current free memory {free_gpu_memory/GiB_bytes} GiB. "
+            f"This happens when the GPU memory was not properly cleaned up "
+            f"before initializing the vLLM instance.")
 
         # Get the peak memory allocation recorded by torch
-        peak_memory = torch.cuda.memory_stats()["allocated_bytes.all.peak"]
+        peak_torch_memory = torch.cuda.memory_stats(
+        )["allocated_bytes.all.peak"]
 
         # Check for any memory left around that may have been allocated on the
         # gpu outside of `torch`. NCCL operations, for example, can use a few
-        # GB during a forward pass
+        # GB during a forward pass.
         torch.cuda.empty_cache()
         torch_allocated_bytes = torch.cuda.memory_stats(
         )["allocated_bytes.all.current"]
-        total_allocated_bytes = torch.cuda.mem_get_info(
-        )[1] - torch.cuda.mem_get_info()[0]
-        non_torch_allocations = total_allocated_bytes - torch_allocated_bytes
-        if non_torch_allocations > 0:
-            peak_memory += non_torch_allocations
+
+        # Reset after emptying torch cache
+        free_gpu_memory = torch.cuda.mem_get_info()[0]
+
+        # Total forward allocation (current) is equal to the diff in free memory
+        fwd_alloc_bytes = self.init_gpu_memory - free_gpu_memory
+        # We assume current non-torch allocation is equal to peak
+        non_torch_alloc_bytes = max(0, fwd_alloc_bytes - torch_allocated_bytes)
+        # Total forward allocation (peak) is peak torch + non-torch
+        peak_memory = peak_torch_memory + non_torch_alloc_bytes
+
         available_kv_cache_memory = (
             total_gpu_memory * self.cache_config.gpu_memory_utilization -
             peak_memory)
 
+        GiB = lambda b: b / GiB_bytes
+        logger.debug(
+            "Initial free memory: %.2f GiB, free memory: %.2f GiB, "
+            "total GPU memory: %.2f GiB", GiB(self.init_gpu_memory),
+            GiB(free_gpu_memory), GiB(total_gpu_memory))
+        logger.debug(
+            "Peak torch memory: %.2f GiB, non-torch forward-pass memory: "
+            "%.2f GiB, available KVCache memory: %.2f GiB",
+            GiB(peak_torch_memory), GiB(non_torch_alloc_bytes),
+            GiB(available_kv_cache_memory))
+
         return int(available_kv_cache_memory)
 
     def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
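
Note on the new accounting in determine_available_memory(): the drop in free
memory since startup is taken as the total footprint of the profiling run;
whatever torch does not account for is assumed to be non-torch memory (e.g.
NCCL buffers) still held at its peak, and the KV cache receives the requested
fraction of total memory minus peak torch memory and that non-torch remainder.
The standalone sketch below restates this arithmetic; the helper name
estimate_kv_cache_memory and the sample byte counts are illustrative only and
are not part of the patch.

# Minimal sketch of the accounting above. In the patch the inputs come from
# torch.cuda.mem_get_info() and torch.cuda.memory_stats(); here they are
# plain integers so the arithmetic can be checked in isolation.
GiB_bytes = 1 << 30


def estimate_kv_cache_memory(init_free: int, current_free: int, total: int,
                             torch_current: int, torch_peak: int,
                             gpu_memory_utilization: float) -> int:
    # Total forward allocation (current) is the drop in free memory.
    fwd_alloc_bytes = init_free - current_free
    # Anything torch does not account for is assumed to be non-torch memory
    # (NCCL buffers, CUDA graphs, ...) still held at its peak.
    non_torch_alloc_bytes = max(0, fwd_alloc_bytes - torch_current)
    # Peak usage during profiling = peak torch allocation + non-torch.
    peak_memory = torch_peak + non_torch_alloc_bytes
    # The KV cache gets whatever is left of the requested budget.
    return int(total * gpu_memory_utilization - peak_memory)


# Hypothetical numbers for an 80 GiB device at 90% utilization: 78 GiB free at
# startup, 60 GiB free after profiling, torch currently holding 16 GiB with a
# 17 GiB peak -> 72 - (17 + 2) = 53 GiB left for the KV cache.
kv_bytes = estimate_kv_cache_memory(
    init_free=78 * GiB_bytes,
    current_free=60 * GiB_bytes,
    total=80 * GiB_bytes,
    torch_current=16 * GiB_bytes,
    torch_peak=17 * GiB_bytes,
    gpu_memory_utilization=0.9)
print(f"available KV cache memory: {kv_bytes / GiB_bytes:.2f} GiB")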