Remove runtime checks based on pooling params (#24051)
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
This commit is contained in:
parent 04d0c60770
commit d59c986444
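At a high level, this commit swaps per-batch runtime checks ("does the current batch carry pooling params?") for a flag fixed when the runner is constructed ("is this a pooling model?"). A minimal sketch of that pattern, with hypothetical names (MiniRunner, step) standing in for the real classes:

class MiniRunner:
    def __init__(self, runner_type: str) -> None:
        # Decided once, from config, at construction time.
        self.is_pooling_model = runner_type == "pooling"

    def step(self, batch_pooling_params: dict) -> str:
        # Old style: branch on whatever the current batch happens to contain,
        # e.g. `if batch_pooling_params: ...`.
        # New style: branch on the model's fixed type.
        return "pool" if self.is_pooling_model else "sample"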
@@ -704,17 +704,12 @@ class InputBatch:
             logitsprocs=self.logitsprocs,
         )
 
-    @property
-    def pooling_metadata(self) -> PoolingMetadata:
-        if len(self.pooling_params) == 0:
-            pooling_params = []
-        else:
-            # Note, for now this assumes that all request in the batch
-            # are either sampling or pooling requests
-            assert len(self.req_ids) == len(self.pooling_params)
-            pooling_params = [
-                self.pooling_params[req_id] for req_id in self.req_ids
-            ]
+    def get_pooling_params(self) -> list[PoolingParams]:
+        assert len(self.req_ids) == len(self.pooling_params)
+        return [self.pooling_params[req_id] for req_id in self.req_ids]
+
+    def get_pooling_metadata(self) -> PoolingMetadata:
+        pooling_params = self.get_pooling_params()
 
         return PoolingMetadata(
             prompt_lens=torch.from_numpy(
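The removed pooling_metadata property quietly substituted an empty parameter list when the batch had no pooling requests; the new accessors assert instead, and always return the parameters in batch order (req_ids order) rather than dict order. A standalone sketch with a simplified stand-in for InputBatch:

class MiniInputBatch:
    def __init__(self) -> None:
        self.req_ids: list[str] = []                # batch order
        self.pooling_params: dict[str, dict] = {}   # keyed by req_id

    def get_pooling_params(self) -> list[dict]:
        # Every request in the batch must be a pooling request.
        assert len(self.req_ids) == len(self.pooling_params)
        # Align the params with batch order, not dict-insertion order.
        return [self.pooling_params[req_id] for req_id in self.req_ids]

batch = MiniInputBatch()
batch.req_ids = ["req-b", "req-a"]
batch.pooling_params = {"req-a": {"task": "embed"}, "req-b": {"task": "embed"}}
print(batch.get_pooling_params())  # "req-b" params first, matching batch order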
@@ -138,7 +138,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
             cache_config.cache_dtype]
 
-        self.is_pooling_model = model_config.pooler_config is not None
+        self.is_pooling_model = (model_config.runner_type == 'pooling')
         self.is_multimodal_raw_input_only_model = (
             model_config.is_multimodal_raw_input_only_model)
 
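The flag is now derived from the resolved runner type rather than from the presence of a pooler config; either way it is computed once in __init__, so hot-path code only ever reads a bool. A tiny illustration (FakeModelConfig is a hypothetical stand-in, not vLLM's ModelConfig):

from dataclasses import dataclass

@dataclass
class FakeModelConfig:
    runner_type: str  # e.g. "generate" or "pooling"

class MiniRunner:
    def __init__(self, model_config: FakeModelConfig) -> None:
        # Computed once at construction; later code just reads the bool.
        self.is_pooling_model = model_config.runner_type == "pooling"

assert MiniRunner(FakeModelConfig("pooling")).is_pooling_model
assert not MiniRunner(FakeModelConfig("generate")).is_pooling_model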
@@ -332,17 +332,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
     def _init_model_kwargs(self, num_tokens: int):
         model_kwargs = dict[str, Any]()
-        num_reqs = self.input_batch.num_reqs
-
-        num_pooling_reqs = len(self.input_batch.pooling_params)
-
-        if num_pooling_reqs == 0:
+        if not self.is_pooling_model:
             return model_kwargs
 
-        # This does nontrivial work.
-        pooling_params = self.input_batch.pooling_metadata.pooling_params
-
-        assert num_pooling_reqs == num_reqs
+        num_reqs = self.input_batch.num_reqs
+        pooling_params = self.input_batch.get_pooling_params()
 
         token_type_id_requests = dict[int, Any]()
         for i, param in enumerate(pooling_params):
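With the static flag, _init_model_kwargs can bail out early for generative models and then assume every batched request is a pooling request. A condensed, hypothetical sketch of that control flow (names and the final kwargs key are simplified, not the full vLLM method):

from typing import Any

def init_model_kwargs(is_pooling_model: bool, req_ids: list[str],
                      pooling_params_by_req: dict[str, Any]) -> dict[str, Any]:
    model_kwargs: dict[str, Any] = {}
    if not is_pooling_model:
        # Generative model: nothing pooling-specific to add.
        return model_kwargs
    # Pooling model: every batched request must carry pooling params.
    assert len(req_ids) == len(pooling_params_by_req)
    model_kwargs["pooling_params"] = [pooling_params_by_req[r] for r in req_ids]
    return model_kwargs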
@@ -456,7 +451,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         else:
             generator = None
 
-        if pooling_params:
+        if self.is_pooling_model:
+            assert pooling_params is not None
             task = pooling_params.task
             assert task is not None, "You did not set `task` in the API"
 
@@ -1437,7 +1433,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             " a batch must be pooling request"
 
         hidden_states = hidden_states[:num_scheduled_tokens]
-        pooling_metadata = self.input_batch.pooling_metadata
+        pooling_metadata = self.input_batch.get_pooling_metadata()
         pooling_metadata.build_pooling_cursor(num_scheduled_tokens_np.tolist(),
                                               device=hidden_states.device)
         seq_lens_cpu = self.seq_lens.cpu[:self.input_batch.num_reqs]
@@ -1609,7 +1605,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 all_gather_group=get_tp_group())
             logits = None
         else:
-            if self.input_batch.pooling_params:
+            if self.is_pooling_model:
                 return self._pool(hidden_states, num_scheduled_tokens,
                                   num_scheduled_tokens_np, kv_connector_output)
 