From 46ad73955a47a69411e37c7871480e834dbfeb22 Mon Sep 17 00:00:00 2001 From: yyzxw <34639446+yyzxw@users.noreply.github.com> Date: Mon, 13 Oct 2025 11:56:21 +0800 Subject: [PATCH] [FIX] Throwing an exception when the model does not support pool tasks (#25840) (#25855) Signed-off-by: zxw <1020938856@qq.com> Co-authored-by: wang.yuqi --- vllm/model_executor/models/adapters.py | 3 +++ vllm/v1/worker/gpu_model_runner.py | 22 +++++++++++++++++++++- 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py index 32073cb88de40..6d035f93dd9b7 100644 --- a/vllm/model_executor/models/adapters.py +++ b/vllm/model_executor/models/adapters.py @@ -399,6 +399,9 @@ def as_reward_model(cls: _T) -> _T: # Lazy import from vllm.model_executor.layers.pooler import DispatchPooler, Pooler + from .interfaces_base import default_pooling_type + + @default_pooling_type("ALL") class ModelForReward(_create_pooling_model_cls(cls)): def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""): pooler_config = vllm_config.model_config.pooler_config diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 0d99597fa641f..09e66a12d14f2 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3622,8 +3622,28 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): hidden_states: torch.Tensor, ) -> PoolerOutput: # Find the task that has the largest output for subsequent steps + supported_pooling_tasks = self.get_supported_pooling_tasks() + + if not supported_pooling_tasks: + if self.scheduler_config.chunked_prefill_enabled: + raise RuntimeError( + f"Model {self.model_config.model} does not support " + "any pooling tasks with chunked prefill enabled. " + "Please add --no-enable-chunked-prefill to your " + "config or CLI args. See " + "https://docs.vllm.ai/en/latest/models/pooling_models.html " + "to learn more." + ) + else: + raise RuntimeError( + f"Model {self.model_config.model} does not support " + "any pooling tasks. See " + "https://docs.vllm.ai/en/latest/models/pooling_models.html " + "to learn more." + ) + output_size = dict[PoolingTask, float]() - for task in self.get_supported_pooling_tasks(): + for task in supported_pooling_tasks: # Run a full batch with each task to ensure none of them OOMs output = self._dummy_pooler_run_task(hidden_states, task) output_size[task] = sum(o.nbytes for o in output)