[V1][Core] Fix memory issue with logits & sampling (#13721)

This commit is contained in:
Roger Wang 2025-02-24 06:10:06 -08:00 committed by GitHub
parent f90a375593
commit 437b76ff59
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 49 additions and 29 deletions

View File

@ -1179,6 +1179,43 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
return hidden_states
@torch.inference_mode()
def _dummy_sampler_run(
    self,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    """Run one sampling pass over `hidden_states` with placeholder settings.

    Logits are computed from `hidden_states` via
    `self.model.compute_logits`, a `SamplingMetadata` filled with constant
    per-request values is built, and both are handed to
    `self.model.sample`.

    Args:
        hidden_states: Input forwarded to `self.model.compute_logits`.
            The first dimension of the resulting logits is used as the
            number of requests.

    Returns:
        Whatever `self.model.sample` returns.
        NOTE(review): the annotation says `torch.Tensor`, but the value
        comes straight from the sampler -- confirm the actual type.
    """
    logits = self.model.compute_logits(hidden_states, None)
    batch_size = logits.shape[0]

    def per_request(fill_value):
        # One value per request, allocated on the runner's device.
        return torch.full((batch_size, ), fill_value, device=self.device)

    # Tensors are created in the same order as the metadata fields below.
    temperature = per_request(0.5)
    top_p = per_request(0.9)
    # Largest valid top-k index range: vocabulary dimension minus one.
    top_k = per_request(logits.shape[1] - 1)
    frequency_penalties = per_request(0.1)
    presence_penalties = per_request(0.1)
    repetition_penalties = per_request(0.1)

    # Each request gets its own (distinct) empty output-token list.
    empty_outputs = [[] for _ in range(batch_size)]
    no_logit_bias = [None] * batch_size

    # Penalty tensors are populated even though `no_penalties` is True.
    placeholder_metadata = SamplingMetadata(
        temperature=temperature,
        all_greedy=False,
        all_random=False,
        spec_token_ids=None,
        top_p=top_p,
        top_k=top_k,
        min_p=None,
        generators={},
        max_num_logprobs=None,
        no_penalties=True,
        prompt_token_ids=None,
        frequency_penalties=frequency_penalties,
        presence_penalties=presence_penalties,
        repetition_penalties=repetition_penalties,
        output_token_ids=empty_outputs,
        min_tokens={},
        logit_bias=no_logit_bias,
        allowed_token_ids_mask=None,
    )
    return self.model.sample(logits=logits,
                             sampling_metadata=placeholder_metadata)
def profile_run(self) -> None:
# Use an empty tensor instead of `None` to force Dynamo to pass
# it by reference, rather than by specializing on the value `None`.
@ -1306,38 +1343,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
dummy_kv_caches)
if get_pp_group().is_last_rank:
hidden_states = hidden_states[logit_indices]
logits = self.model.compute_logits(hidden_states, None)
dummy_tensors = lambda v: torch.full(
(num_reqs, ), v, device=self.device)
dummy_metadata = SamplingMetadata(
temperature=dummy_tensors(0.5),
all_greedy=False,
all_random=False,
spec_token_ids=None,
top_p=dummy_tensors(0.9),
top_k=dummy_tensors(logits.size(1) - 1),
min_p=None,
generators={},
max_num_logprobs=None,
no_penalties=True,
prompt_token_ids=torch.ones_like(logits,
dtype=torch.int64),
frequency_penalties=dummy_tensors(0.1),
presence_penalties=dummy_tensors(0.1),
repetition_penalties=dummy_tensors(0.1),
output_token_ids=[[] for _ in range(num_reqs)],
min_tokens={},
logit_bias=[None for _ in range(num_reqs)],
allowed_token_ids_mask=None,
)
sampler_output = self.model.sample(
logits=logits, sampling_metadata=dummy_metadata)
sampler_output = self._dummy_sampler_run(hidden_states)
else:
logits = None
sampler_output = None
dummy_metadata = None
torch.cuda.synchronize()
del hidden_states, logits, sampler_output, dummy_metadata
del hidden_states, sampler_output
self.encoder_cache.clear()
gc.collect()

View File

@ -211,6 +211,16 @@ class Worker(WorkerBase):
self.model_runner._dummy_run(size)
if not self.model_config.enforce_eager:
self.model_runner.capture_model()
# Warm up sampler and preallocate memory buffer for logits and other
# sampling related tensors of max possible shape to avoid memory
# fragmentation issue.
# NOTE: This is called after `capture_model` on purpose to prevent
# memory buffers from being cleared by `torch.cuda.empty_cache`.
self.model_runner._dummy_sampler_run(
hidden_states=self.model_runner._dummy_run(
num_tokens=self.scheduler_config.max_num_seqs))
# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.
set_random_seed(self.model_config.seed)