mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-06 08:21:49 +08:00
[V1] Move OOM check into sampler run (#14728)
Signed-off-by: Roger Wang <ywang@roblox.com> Co-authored-by: Simon Mo <simon.mo@hey.com>
This commit is contained in:
parent
2a602b055a
commit
ad19c8a003
@ -1288,9 +1288,18 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
allowed_token_ids_mask=None,
|
||||
bad_words_token_ids={},
|
||||
)
|
||||
sampler_output = self.model.sample(logits=logits,
|
||||
sampling_metadata=dummy_metadata)
|
||||
|
||||
try:
|
||||
sampler_output = self.model.sample(
|
||||
logits=logits, sampling_metadata=dummy_metadata)
|
||||
except RuntimeError as e:
|
||||
if 'out of memory' in str(e):
|
||||
raise RuntimeError(
|
||||
"CUDA out of memory occurred when warming up sampler with "
|
||||
f"{num_reqs} dummy requests. Please try lowering "
|
||||
"`max_num_seqs` or `gpu_memory_utilization` when "
|
||||
"initializing the engine.") from e
|
||||
else:
|
||||
raise e
|
||||
return sampler_output
|
||||
|
||||
def profile_run(self) -> None:
|
||||
|
||||
@ -221,21 +221,11 @@ class Worker(WorkerBase):
|
||||
# NOTE: This is called after `capture_model` on purpose to prevent
|
||||
# memory buffers from being cleared by `torch.cuda.empty_cache`.
|
||||
if get_pp_group().is_last_rank:
|
||||
try:
|
||||
max_num_reqs = min(
|
||||
self.scheduler_config.max_num_seqs,
|
||||
self.scheduler_config.max_num_batched_tokens)
|
||||
self.model_runner._dummy_sampler_run(
|
||||
hidden_states=self.model_runner._dummy_run(
|
||||
num_tokens=max_num_reqs))
|
||||
except RuntimeError as e:
|
||||
if 'out of memory' in str(e):
|
||||
raise RuntimeError(
|
||||
"CUDA out of memory occurred when warming up sampler. "
|
||||
"Please try lowering `gpu_memory_utilization` when "
|
||||
"initializing the engine.") from None
|
||||
else:
|
||||
raise e
|
||||
max_num_reqs = min(self.scheduler_config.max_num_seqs,
|
||||
self.scheduler_config.max_num_batched_tokens)
|
||||
self.model_runner._dummy_sampler_run(
|
||||
hidden_states=self.model_runner._dummy_run(
|
||||
num_tokens=max_num_reqs))
|
||||
|
||||
# Reset the seed to ensure that the random state is not affected by
|
||||
# the model initialization and profiling.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user