[V1][PP] Fix memory profiling in PP (#13315)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon 2025-02-14 20:17:25 -08:00 committed by GitHub
parent 6a854c7a2b
commit 0c73026844
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1158,11 +1158,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Trigger compilation for general shape.
hidden_states = self._dummy_run(self.max_num_tokens,
dummy_kv_caches)
if not get_pp_group().is_last_rank:
return hidden_states
hidden_states = hidden_states[logit_indices]
logits = self.model.compute_logits(hidden_states, None)
# TODO(woosuk): Consider the memory usage of the sampler.
if get_pp_group().is_last_rank:
hidden_states = hidden_states[logit_indices]
logits = self.model.compute_logits(hidden_states, None)
# TODO(woosuk): Consider the memory usage of the sampler.
else:
logits = None
torch.cuda.synchronize()
del hidden_states, logits
self.encoder_cache.clear()