From d59c986444a701b39369453eff0a8ba324bd565f Mon Sep 17 00:00:00 2001
From: Maximilien de Bayser
Date: Tue, 2 Sep 2025 00:54:37 -0300
Subject: [PATCH] Remove runtime checks based on pooling params (#24051)

Signed-off-by: Max de Bayser
---
 vllm/v1/worker/gpu_input_batch.py  | 17 ++++++-----------
 vllm/v1/worker/gpu_model_runner.py | 20 ++++++++------------
 2 files changed, 14 insertions(+), 23 deletions(-)

diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index f4c2f45df5954..ef5a7e39a5b16 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -704,17 +704,12 @@ class InputBatch:
             logitsprocs=self.logitsprocs,
         )
 
-    @property
-    def pooling_metadata(self) -> PoolingMetadata:
-        if len(self.pooling_params) == 0:
-            pooling_params = []
-        else:
-            # Note, for now this assumes that all request in the batch
-            # are either sampling or pooling requests
-            assert len(self.req_ids) == len(self.pooling_params)
-            pooling_params = [
-                self.pooling_params[req_id] for req_id in self.req_ids
-            ]
+    def get_pooling_params(self) -> list[PoolingParams]:
+        assert len(self.req_ids) == len(self.pooling_params)
+        return [self.pooling_params[req_id] for req_id in self.req_ids]
+
+    def get_pooling_metadata(self) -> PoolingMetadata:
+        pooling_params = self.get_pooling_params()
 
         return PoolingMetadata(
             prompt_lens=torch.from_numpy(
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 08e13ab887bf9..96dafd6add679 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -138,7 +138,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
             cache_config.cache_dtype]
 
-        self.is_pooling_model = model_config.pooler_config is not None
+        self.is_pooling_model = (model_config.runner_type == 'pooling')
 
         self.is_multimodal_raw_input_only_model = (
             model_config.is_multimodal_raw_input_only_model)
@@ -332,17 +332,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
     def _init_model_kwargs(self, num_tokens: int):
         model_kwargs = dict[str, Any]()
-        num_reqs = self.input_batch.num_reqs
-        num_pooling_reqs = len(self.input_batch.pooling_params)
-
-        if num_pooling_reqs == 0:
+        if not self.is_pooling_model:
             return model_kwargs
 
-        # This does nontrivial work.
-        pooling_params = self.input_batch.pooling_metadata.pooling_params
-
-        assert num_pooling_reqs == num_reqs
+        num_reqs = self.input_batch.num_reqs
+        pooling_params = self.input_batch.get_pooling_params()
 
         token_type_id_requests = dict[int, Any]()
         for i, param in enumerate(pooling_params):
@@ -456,7 +451,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         else:
             generator = None
 
-        if pooling_params:
+        if self.is_pooling_model:
+            assert pooling_params is not None
             task = pooling_params.task
             assert task is not None, "You did not set `task` in the API"
@@ -1437,7 +1433,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             " a batch must be pooling request"
         hidden_states = hidden_states[:num_scheduled_tokens]
-        pooling_metadata = self.input_batch.pooling_metadata
+        pooling_metadata = self.input_batch.get_pooling_metadata()
         pooling_metadata.build_pooling_cursor(num_scheduled_tokens_np.tolist(),
                                               device=hidden_states.device)
         seq_lens_cpu = self.seq_lens.cpu[:self.input_batch.num_reqs]
@@ -1609,7 +1605,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                                   all_gather_group=get_tp_group())
             logits = None
         else:
-            if self.input_batch.pooling_params:
+            if self.is_pooling_model:
                 return self._pool(hidden_states, num_scheduled_tokens,
                                   num_scheduled_tokens_np,
                                   kv_connector_output)
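
Reviewer note (not part of the patch): below is a minimal sketch of the
control-flow change, using hypothetical stand-in classes rather than the real
vLLM types. Before, the pooling path was selected by probing per-batch state
(len(input_batch.pooling_params), truthiness of pooling_params) on every step;
after, is_pooling_model is fixed once from the model config's runner_type, and
"every request in the batch carries pooling params" is asserted in one place
(get_pooling_params) instead of being re-derived at each call site.
FakePoolingParams, FakeBatch, and FakeRunner are illustrative only; just the
runner_type gate and the shape of the new helpers mirror the diff.

from dataclasses import dataclass, field


@dataclass
class FakePoolingParams:
    task: str = "embed"


@dataclass
class FakeBatch:
    req_ids: list[str] = field(default_factory=list)
    pooling_params: dict[str, FakePoolingParams] = field(default_factory=dict)

    def get_pooling_params(self) -> list[FakePoolingParams]:
        # Mirrors the new InputBatch helper: in a pooling run every request
        # must carry pooling params, so this is an invariant, not a probe.
        assert len(self.req_ids) == len(self.pooling_params)
        return [self.pooling_params[req_id] for req_id in self.req_ids]


class FakeRunner:
    def __init__(self, runner_type: str, batch: FakeBatch):
        # Decided once at construction, as in the patched
        # GPUModelRunner.__init__; no per-step len(...) checks remain.
        self.is_pooling_model = runner_type == "pooling"
        self.input_batch = batch

    def step(self):
        if not self.is_pooling_model:
            return "sample"
        # Pooling path: call the accessor unconditionally.
        return [p.task for p in self.input_batch.get_pooling_params()]


batch = FakeBatch(req_ids=["a", "b"],
                  pooling_params={"a": FakePoolingParams(),
                                  "b": FakePoolingParams()})
assert FakeRunner("pooling", batch).step() == ["embed", "embed"]
assert FakeRunner("generate", batch).step() == "sample"

The design point: whether a model pools is a static property of its
configuration, so re-deriving it from per-batch pooling_params was redundant,
and the assert-based invariant now fails loudly if a batch ever mixes sampling
and pooling requests.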