diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index df7ca70924bf5..c2a976108e4d4 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1288,9 +1288,18 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             allowed_token_ids_mask=None,
             bad_words_token_ids={},
         )
-        sampler_output = self.model.sample(logits=logits,
-                                           sampling_metadata=dummy_metadata)
-
+        try:
+            sampler_output = self.model.sample(
+                logits=logits, sampling_metadata=dummy_metadata)
+        except RuntimeError as e:
+            if 'out of memory' in str(e):
+                raise RuntimeError(
+                    "CUDA out of memory occurred when warming up sampler with "
+                    f"{num_reqs} dummy requests. Please try lowering "
+                    "`max_num_seqs` or `gpu_memory_utilization` when "
+                    "initializing the engine.") from e
+            else:
+                raise e
         return sampler_output
 
     def profile_run(self) -> None:
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index 5527a105f8670..241869e35c620 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -221,21 +221,11 @@ class Worker(WorkerBase):
         # NOTE: This is called after `capture_model` on purpose to prevent
         # memory buffers from being cleared by `torch.cuda.empty_cache`.
         if get_pp_group().is_last_rank:
-            try:
-                max_num_reqs = min(
-                    self.scheduler_config.max_num_seqs,
-                    self.scheduler_config.max_num_batched_tokens)
-                self.model_runner._dummy_sampler_run(
-                    hidden_states=self.model_runner._dummy_run(
-                        num_tokens=max_num_reqs))
-            except RuntimeError as e:
-                if 'out of memory' in str(e):
-                    raise RuntimeError(
-                        "CUDA out of memory occurred when warming up sampler. "
-                        "Please try lowering `gpu_memory_utilization` when "
-                        "initializing the engine.") from None
-                else:
-                    raise e
+            max_num_reqs = min(self.scheduler_config.max_num_seqs,
+                               self.scheduler_config.max_num_batched_tokens)
+            self.model_runner._dummy_sampler_run(
+                hidden_states=self.model_runner._dummy_run(
+                    num_tokens=max_num_reqs))
 
         # Reset the seed to ensure that the random state is not affected by
         # the model initialization and profiling.