From 0c73026844e8a2d3ff017bf0e802b34bf8263aa0 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Fri, 14 Feb 2025 20:17:25 -0800
Subject: [PATCH] [V1][PP] Fix memory profiling in PP (#13315)

Signed-off-by: Woosuk Kwon
---
 vllm/v1/worker/gpu_model_runner.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index e90b76dcdd9a..821c9e138028 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1158,11 +1158,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         # Trigger compilation for general shape.
         hidden_states = self._dummy_run(self.max_num_tokens,
                                         dummy_kv_caches)
-        if not get_pp_group().is_last_rank:
-            return hidden_states
-        hidden_states = hidden_states[logit_indices]
-        logits = self.model.compute_logits(hidden_states, None)
-        # TODO(woosuk): Consider the memory usage of the sampler.
+        if get_pp_group().is_last_rank:
+            hidden_states = hidden_states[logit_indices]
+            logits = self.model.compute_logits(hidden_states, None)
+            # TODO(woosuk): Consider the memory usage of the sampler.
+        else:
+            logits = None
         torch.cuda.synchronize()
         del hidden_states, logits
         self.encoder_cache.clear()
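
For context, a minimal sketch of the control-flow change below, not vLLM's actual code: the names profile_run_before, profile_run_after, is_last_rank, and the sum()/slice stand-ins are hypothetical simplifications of GPUModelRunner.profile_run, get_pp_group().is_last_rank, the logit_indices slicing, and self.model.compute_logits. Before the patch, non-last pipeline-parallel ranks returned early and never reached torch.cuda.synchronize(), the del, or encoder_cache.clear(), so memory profiling on those ranks could observe stale dummy-run allocations.

    import torch

    def profile_run_before(hidden_states: torch.Tensor, is_last_rank: bool):
        """Pre-patch shape: the early return skips the shared cleanup."""
        if not is_last_rank:
            return hidden_states  # non-last ranks never reach the tail below
        hidden_states = hidden_states[:1]   # stand-in for logit_indices slicing
        logits = hidden_states.sum(dim=-1)  # stand-in for compute_logits
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        del hidden_states, logits           # cleanup runs only on the last rank

    def profile_run_after(hidden_states: torch.Tensor, is_last_rank: bool):
        """Post-patch shape: every rank falls through to the same cleanup."""
        if is_last_rank:
            hidden_states = hidden_states[:1]   # stand-in for logit_indices slicing
            logits = hidden_states.sum(dim=-1)  # stand-in for compute_logits
        else:
            logits = None  # keeps the shared `del` below valid on every rank
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        del hidden_states, logits

    profile_run_after(torch.randn(4, 8), is_last_rank=False)
    profile_run_after(torch.randn(4, 8), is_last_rank=True)

Assigning logits = None on non-last ranks is what lets the single `del hidden_states, logits` stay shared: both names are always bound, and every rank now synchronizes before peak memory is read.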