From 68eb5c8d970a453a440776211f8dbff215fb40c3 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 4 Dec 2025 16:21:19 +0800 Subject: [PATCH] [Misc] Move functions into `PoolingMetadata` (#30027) Signed-off-by: DarkLight1337 --- vllm/model_executor/layers/pooler.py | 50 ++++------------------------ vllm/model_executor/models/gritlm.py | 6 ++-- vllm/v1/pool/metadata.py | 21 ++++++++++++ 3 files changed, 30 insertions(+), 47 deletions(-) diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index 7dd02e32ff211..185e03e5f3bd7 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -64,42 +64,6 @@ class PoolingParamsUpdate: params.requires_token_ids = self.requires_token_ids -def get_prompt_lens( - hidden_states: torch.Tensor | list[torch.Tensor], - pooling_metadata: PoolingMetadata, -) -> torch.Tensor: - return pooling_metadata.prompt_lens - - -def get_prompt_token_ids(pooling_metadata: PoolingMetadata) -> list[torch.Tensor]: - assert pooling_metadata.prompt_token_ids is not None, ( - "Please set `requires_token_ids=True` in `get_pooling_updates`" - ) - - return [ - pooling_metadata.prompt_token_ids[i, :num] - for i, num in enumerate(pooling_metadata.prompt_lens) - ] - - -def get_pooling_params(pooling_metadata: PoolingMetadata) -> list[PoolingParams]: - pooling_params = pooling_metadata.pooling_params - return pooling_params - - -def get_tasks(pooling_metadata: PoolingMetadata) -> list[PoolingTask]: - pooling_params = get_pooling_params(pooling_metadata) - - tasks: list[PoolingTask] = [ - task - for pooling_param in pooling_params - if (task := pooling_param.task) is not None - ] - assert len(pooling_params) == len(tasks) - - return tasks - - def get_classification_activation_function(config: PretrainedConfig): # Implement alignment with transformers ForSequenceClassificationLoss # https://github.com/huggingface/transformers/blob/57bb6db6ee4cfaccc45b8d474dfad5a17811ca60/src/transformers/loss/loss_utils.py#L92 @@ -466,7 +430,7 @@ class EmbeddingPoolerHead(PoolerHead): pooled_data = self.projector(pooled_data) # pooled_data shape: [batchsize, embedding_dimension] - pooling_params = get_pooling_params(pooling_metadata) + pooling_params = pooling_metadata.pooling_params # for matryoshka representation dimensions_list = [pooling_param.dimensions for pooling_param in pooling_params] @@ -606,7 +570,7 @@ class ClassifierPooler(Pooler): if self.logit_bias is not None: pooled_data -= self.logit_bias - pooling_params = get_pooling_params(pooling_metadata) + pooling_params = pooling_metadata.pooling_params flags = [p.use_activation for p in pooling_params] if len(set(flags)) == 1: @@ -704,7 +668,7 @@ class AllPooler(Pooler): pooling_metadata: PoolingMetadata, ) -> PoolerOutput: pooled_data = self.pooling(hidden_states, pooling_metadata) - pooling_params = get_pooling_params(pooling_metadata) + pooling_params = pooling_metadata.pooling_params assert len(pooled_data) == len(pooling_params) pooled_data = [self.head(d, p) for d, p in zip(pooled_data, pooling_params)] @@ -724,11 +688,11 @@ class StepPooler(Pooler): pooling_metadata: PoolingMetadata, ) -> torch.Tensor | list[torch.Tensor]: pooled_data_lst = self.pooling(hidden_states, pooling_metadata) - prompt_token_ids = get_prompt_token_ids(pooling_metadata) + prompt_token_ids = pooling_metadata.get_prompt_token_ids() pooled_data = list[torch.Tensor]() - pooling_params = get_pooling_params(pooling_metadata) + pooling_params = pooling_metadata.pooling_params for data, token_id, pooling_param in zip( pooled_data_lst, prompt_token_ids, pooling_params @@ -757,7 +721,7 @@ class StepPooler(Pooler): pooling_metadata: PoolingMetadata, ) -> PoolerOutput: pooled_data = self.extract_states(hidden_states, pooling_metadata) - pooling_params = get_pooling_params(pooling_metadata) + pooling_params = pooling_metadata.pooling_params assert len(pooled_data) == len(pooling_params) pooled_data = [self.head(d, p) for d, p in zip(pooled_data, pooling_params)] @@ -794,7 +758,7 @@ class DispatchPooler(Pooler): outputs = list[torch.Tensor]() offset = 0 - for task, group in groupby(get_tasks(pooling_metadata)): + for task, group in groupby(pooling_metadata.tasks): if not (pooler := poolers_by_task.get(task)): raise ValueError( f"Unsupported task: {task} " diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py index 550e8b014d5e7..2aba626a7c737 100644 --- a/vllm/model_executor/models/gritlm.py +++ b/vllm/model_executor/models/gritlm.py @@ -14,8 +14,6 @@ from vllm.model_executor.layers.pooler import ( PoolerHead, PoolerNormalize, PoolingParamsUpdate, - get_prompt_lens, - get_prompt_token_ids, ) from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.tasks import PoolingTask @@ -153,11 +151,11 @@ class GritLMMeanPool(nn.Module): hidden_states: torch.Tensor | list[torch.Tensor], pooling_metadata: PoolingMetadata, ) -> list[torch.Tensor] | torch.Tensor: - prompt_lens = get_prompt_lens(hidden_states, pooling_metadata) + prompt_lens = pooling_metadata.prompt_lens instr_lens = torch.tensor( [ self._get_instruction_len(token_ids.cpu().numpy()) - for token_ids in get_prompt_token_ids(pooling_metadata) + for token_ids in pooling_metadata.get_prompt_token_ids() ], device="cpu", ) diff --git a/vllm/v1/pool/metadata.py b/vllm/v1/pool/metadata.py index 7bd2c7415dafe..9ee588ea44ca4 100644 --- a/vllm/v1/pool/metadata.py +++ b/vllm/v1/pool/metadata.py @@ -5,6 +5,7 @@ from dataclasses import dataclass import torch from vllm.pooling_params import PoolingParams +from vllm.tasks import PoolingTask from vllm.utils.platform_utils import is_pin_memory_available pin_memory = is_pin_memory_available() @@ -40,6 +41,18 @@ class PoolingMetadata: pooling_params: list[PoolingParams] pooling_cursor: PoolingCursor | None = None + def __post_init__(self) -> None: + pooling_params = self.pooling_params + + tasks: list[PoolingTask] = [ + task + for pooling_param in pooling_params + if (task := pooling_param.task) is not None + ] + assert len(pooling_params) == len(tasks) + + self.tasks = tasks + def __getitem__(self, indices: slice): return PoolingMetadata( prompt_lens=self.prompt_lens[indices], @@ -52,6 +65,14 @@ class PoolingMetadata: else self.pooling_cursor[indices], ) + def get_prompt_token_ids(self) -> list[torch.Tensor]: + prompt_token_ids = self.prompt_token_ids + assert prompt_token_ids is not None, ( + "Please set `requires_token_ids=True` in `get_pooling_updates`" + ) + + return [prompt_token_ids[i, :num] for i, num in enumerate(self.prompt_lens)] + def build_pooling_cursor( self, num_scheduled_tokens: list[int], device: torch.device ):