[V1] Move OOM check into sampler run (#14728)

Signed-off-by: Roger Wang <ywang@roblox.com>
Co-authored-by: Simon Mo <simon.mo@hey.com>
Roger Wang 2025-03-13 20:40:23 -07:00 committed by GitHub
parent 2a602b055a
commit ad19c8a003
2 changed files with 17 additions and 18 deletions


@@ -1288,9 +1288,18 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             allowed_token_ids_mask=None,
             bad_words_token_ids={},
         )
-        sampler_output = self.model.sample(logits=logits,
-                                           sampling_metadata=dummy_metadata)
+        try:
+            sampler_output = self.model.sample(
+                logits=logits, sampling_metadata=dummy_metadata)
+        except RuntimeError as e:
+            if 'out of memory' in str(e):
+                raise RuntimeError(
+                    "CUDA out of memory occurred when warming up sampler with "
+                    f"{num_reqs} dummy requests. Please try lowering "
+                    "`max_num_seqs` or `gpu_memory_utilization` when "
+                    "initializing the engine.") from e
+            else:
+                raise e
         return sampler_output
 
     def profile_run(self) -> None:
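
The hunk above relies on PyTorch surfacing a CUDA out-of-memory failure during the sampler call as a RuntimeError whose message contains "out of memory": that case is caught and re-raised with an actionable hint, while every other error propagates unchanged. A minimal, self-contained sketch of the pattern follows; the helper name run_with_oom_hint and its signature are illustrative and are not part of this commit.

from typing import Callable, TypeVar

T = TypeVar("T")

def run_with_oom_hint(fn: Callable[[], T], hint: str) -> T:
    # Call `fn`; if it fails with a CUDA OOM (a RuntimeError whose message
    # contains "out of memory"), re-raise with an actionable hint chained to
    # the original error via `from`. Any other exception propagates unchanged.
    try:
        return fn()
    except RuntimeError as e:
        if "out of memory" in str(e):
            raise RuntimeError(hint) from e
        raise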


@@ -221,21 +221,11 @@ class Worker(WorkerBase):
         # NOTE: This is called after `capture_model` on purpose to prevent
         # memory buffers from being cleared by `torch.cuda.empty_cache`.
         if get_pp_group().is_last_rank:
-            try:
-                max_num_reqs = min(
-                    self.scheduler_config.max_num_seqs,
-                    self.scheduler_config.max_num_batched_tokens)
-                self.model_runner._dummy_sampler_run(
-                    hidden_states=self.model_runner._dummy_run(
-                        num_tokens=max_num_reqs))
-            except RuntimeError as e:
-                if 'out of memory' in str(e):
-                    raise RuntimeError(
-                        "CUDA out of memory occurred when warming up sampler. "
-                        "Please try lowering `gpu_memory_utilization` when "
-                        "initializing the engine.") from None
-                else:
-                    raise e
+            max_num_reqs = min(self.scheduler_config.max_num_seqs,
+                               self.scheduler_config.max_num_batched_tokens)
+            self.model_runner._dummy_sampler_run(
+                hidden_states=self.model_runner._dummy_run(
+                    num_tokens=max_num_reqs))
 
         # Reset the seed to ensure that the random state is not affected by
         # the model initialization and profiling.
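
Because the check now runs inside the sampler warm-up, the error message can name the number of dummy requests and point at both tunable knobs, `max_num_seqs` and `gpu_memory_utilization`. If the warm-up does hit this error, one way to act on that advice when constructing the engine is sketched below; the model name and values are placeholders, not recommendations.

from vllm import LLM

# Reduce the warm-up/decoding batch via `max_num_seqs`, and/or leave more free
# VRAM for activations by lowering `gpu_memory_utilization` (default 0.9).
llm = LLM(
    model="facebook/opt-125m",
    max_num_seqs=128,
    gpu_memory_utilization=0.85,
)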