Remove runtime checks based on pooling params (#24051)

Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Author: Maximilien de Bayser, 2025-09-02 00:54:37 -03:00 (committed by GitHub)
parent 04d0c60770
commit d59c986444
2 changed files with 14 additions and 23 deletions


```diff
@@ -704,17 +704,12 @@ class InputBatch:
             logitsprocs=self.logitsprocs,
         )
 
-    @property
-    def pooling_metadata(self) -> PoolingMetadata:
-        if len(self.pooling_params) == 0:
-            pooling_params = []
-        else:
-            # Note, for now this assumes that all request in the batch
-            # are either sampling or pooling requests
-            assert len(self.req_ids) == len(self.pooling_params)
-            pooling_params = [
-                self.pooling_params[req_id] for req_id in self.req_ids
-            ]
-
+    def get_pooling_params(self) -> list[PoolingParams]:
+        assert len(self.req_ids) == len(self.pooling_params)
+        return [self.pooling_params[req_id] for req_id in self.req_ids]
+
+    def get_pooling_metadata(self) -> PoolingMetadata:
+        pooling_params = self.get_pooling_params()
         return PoolingMetadata(
             prompt_lens=torch.from_numpy(
```
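
This hunk replaces a property that silently special-cased an empty batch with explicit getters that assert the batch-wide invariant. A minimal standalone sketch of the pattern, using simplified stand-ins (the `PoolingParams` and `InputBatchSketch` classes below are one-field illustrations, not the real vLLM types):

```python
# Minimal sketch of the getter-with-invariant pattern; simplified
# stand-ins, not the actual vLLM classes.
from dataclasses import dataclass, field


@dataclass
class PoolingParams:
    task: str | None = None


@dataclass
class InputBatchSketch:
    req_ids: list[str] = field(default_factory=list)
    pooling_params: dict[str, PoolingParams] = field(default_factory=dict)

    def get_pooling_params(self) -> list[PoolingParams]:
        # In a pooling batch every request must carry pooling params;
        # callers gate on the model kind before calling this.
        assert len(self.req_ids) == len(self.pooling_params)
        return [self.pooling_params[req_id] for req_id in self.req_ids]


batch = InputBatchSketch(req_ids=["r0"],
                         pooling_params={"r0": PoolingParams(task="embed")})
print(batch.get_pooling_params())  # [PoolingParams(task='embed')]
```

The empty-batch branch disappears because the caller now knows statically whether the model pools at all.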


```diff
@@ -138,7 +138,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
             cache_config.cache_dtype]
-        self.is_pooling_model = model_config.pooler_config is not None
+        self.is_pooling_model = (model_config.runner_type == 'pooling')
         self.is_multimodal_raw_input_only_model = (
             model_config.is_multimodal_raw_input_only_model)
```
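
This one-line change moves the decision from "does a pooler config object exist?" to the configured runner type, so the flag is fixed at construction time. A hedged sketch of that idea, with simplified stand-ins for the config and runner:

```python
# Sketch only: ModelConfigSketch/RunnerSketch are simplified stand-ins;
# in vLLM the flag is read from model_config.runner_type as in the diff.
from dataclasses import dataclass


@dataclass
class ModelConfigSketch:
    runner_type: str  # e.g. "generate" or "pooling"


class RunnerSketch:
    def __init__(self, model_config: ModelConfigSketch) -> None:
        # Decided once from static config; never re-derived per batch.
        self.is_pooling_model = (model_config.runner_type == "pooling")


assert RunnerSketch(ModelConfigSketch("pooling")).is_pooling_model
assert not RunnerSketch(ModelConfigSketch("generate")).is_pooling_model
```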
```diff
@@ -332,17 +332,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
     def _init_model_kwargs(self, num_tokens: int):
         model_kwargs = dict[str, Any]()
-        num_reqs = self.input_batch.num_reqs
-
-        num_pooling_reqs = len(self.input_batch.pooling_params)
-        if num_pooling_reqs == 0:
+        if not self.is_pooling_model:
             return model_kwargs
 
-        # This does nontrivial work.
-        pooling_params = self.input_batch.pooling_metadata.pooling_params
-        assert num_pooling_reqs == num_reqs
+        num_reqs = self.input_batch.num_reqs
+        pooling_params = self.input_batch.get_pooling_params()
 
         token_type_id_requests = dict[int, Any]()
         for i, param in enumerate(pooling_params):
```
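
With the static flag, `_init_model_kwargs` can exit early for generation models, and for pooling models it fetches the params with a cheap list lookup instead of building a full `PoolingMetadata` (the "nontrivial work" the removed comment warned about). A standalone sketch of that shape, with a hypothetical `BatchSketch` type and the token-type-id handling omitted:

```python
# Standalone sketch; BatchSketch is a hypothetical stand-in, and the
# real method goes on to build token-type-id kwargs per request.
from typing import Any


class BatchSketch:
    def __init__(self, params: list[str]) -> None:
        self._params = params

    def get_pooling_params(self) -> list[str]:
        # Cheap list copy; no metadata tensors are constructed here.
        return list(self._params)


def init_model_kwargs(is_pooling_model: bool, batch: BatchSketch) -> dict[str, Any]:
    model_kwargs: dict[str, Any] = {}
    if not is_pooling_model:
        # Generation model: nothing pooling-specific to prepare.
        return model_kwargs
    pooling_params = batch.get_pooling_params()
    model_kwargs["num_pooling_reqs"] = len(pooling_params)
    return model_kwargs


print(init_model_kwargs(True, BatchSketch(["embed", "embed"])))
```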
```diff
@@ -456,7 +451,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         else:
             generator = None
 
-        if pooling_params:
+        if self.is_pooling_model:
+            assert pooling_params is not None
             task = pooling_params.task
             assert task is not None, "You did not set `task` in the API"
```
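
Branching on the static flag and then asserting turns what used to be a silent skip (a falsy `pooling_params`) into a loud failure if a pooling model ever reaches this point without params. A small sketch of the check-then-assert pattern, with hypothetical minimal types:

```python
# Sketch of check-then-assert: branch on the static model kind, then
# make the per-request invariant explicit. Hypothetical minimal types.
from typing import Optional


class PoolingParamsSketch:
    def __init__(self, task: Optional[str]) -> None:
        self.task = task


def resolve_task(is_pooling_model: bool,
                 pooling_params: Optional[PoolingParamsSketch]) -> Optional[str]:
    if is_pooling_model:
        # A pooling model must always receive pooling params.
        assert pooling_params is not None
        task = pooling_params.task
        assert task is not None, "You did not set `task` in the API"
        return task
    return None  # sampling path: no pooling task to resolve


print(resolve_task(True, PoolingParamsSketch("embed")))  # embed
```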
```diff
@@ -1437,7 +1433,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             " a batch must be pooling request"
         hidden_states = hidden_states[:num_scheduled_tokens]
-        pooling_metadata = self.input_batch.pooling_metadata
+        pooling_metadata = self.input_batch.get_pooling_metadata()
         pooling_metadata.build_pooling_cursor(num_scheduled_tokens_np.tolist(),
                                               device=hidden_states.device)
         seq_lens_cpu = self.seq_lens.cpu[:self.input_batch.num_reqs]
```
```diff
@@ -1609,7 +1605,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 all_gather_group=get_tp_group())
             logits = None
         else:
-            if self.input_batch.pooling_params:
+            if self.is_pooling_model:
                 return self._pool(hidden_states, num_scheduled_tokens,
                                   num_scheduled_tokens_np, kv_connector_output)
```
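
The final hunk applies the same idea at the top-level dispatch: routing to the pooling path depends on the static model kind, not on per-batch request state. A minimal sketch of that dispatch (the function and its return values are hypothetical stand-ins for the runner's execute path, not vLLM's API):

```python
# Minimal dispatch sketch; names are hypothetical, not vLLM's API.
def execute_model(is_pooling_model: bool, hidden_states: list[float]) -> str:
    if is_pooling_model:
        # Pooling models never sample; everything goes to the pooler.
        return f"pooled {len(hidden_states)} hidden states"
    return f"sampled from {len(hidden_states)} hidden states"


print(execute_model(True, [0.1, 0.2]))   # pooled 2 hidden states
print(execute_model(False, [0.1, 0.2]))  # sampled from 2 hidden states
```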