[Misc] Move functions into PoolingMetadata (#30027)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-12-04 16:21:19 +08:00 committed by GitHub
parent 5430e110c0
commit 68eb5c8d97
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 30 additions and 47 deletions

View File

@ -64,42 +64,6 @@ class PoolingParamsUpdate:
params.requires_token_ids = self.requires_token_ids
def get_prompt_lens(
hidden_states: torch.Tensor | list[torch.Tensor],
pooling_metadata: PoolingMetadata,
) -> torch.Tensor:
return pooling_metadata.prompt_lens
def get_prompt_token_ids(pooling_metadata: PoolingMetadata) -> list[torch.Tensor]:
assert pooling_metadata.prompt_token_ids is not None, (
"Please set `requires_token_ids=True` in `get_pooling_updates`"
)
return [
pooling_metadata.prompt_token_ids[i, :num]
for i, num in enumerate(pooling_metadata.prompt_lens)
]
def get_pooling_params(pooling_metadata: PoolingMetadata) -> list[PoolingParams]:
pooling_params = pooling_metadata.pooling_params
return pooling_params
def get_tasks(pooling_metadata: PoolingMetadata) -> list[PoolingTask]:
pooling_params = get_pooling_params(pooling_metadata)
tasks: list[PoolingTask] = [
task
for pooling_param in pooling_params
if (task := pooling_param.task) is not None
]
assert len(pooling_params) == len(tasks)
return tasks
def get_classification_activation_function(config: PretrainedConfig):
# Implement alignment with transformers ForSequenceClassificationLoss
# https://github.com/huggingface/transformers/blob/57bb6db6ee4cfaccc45b8d474dfad5a17811ca60/src/transformers/loss/loss_utils.py#L92
@ -466,7 +430,7 @@ class EmbeddingPoolerHead(PoolerHead):
pooled_data = self.projector(pooled_data)
# pooled_data shape: [batchsize, embedding_dimension]
pooling_params = get_pooling_params(pooling_metadata)
pooling_params = pooling_metadata.pooling_params
# for matryoshka representation
dimensions_list = [pooling_param.dimensions for pooling_param in pooling_params]
@ -606,7 +570,7 @@ class ClassifierPooler(Pooler):
if self.logit_bias is not None:
pooled_data -= self.logit_bias
pooling_params = get_pooling_params(pooling_metadata)
pooling_params = pooling_metadata.pooling_params
flags = [p.use_activation for p in pooling_params]
if len(set(flags)) == 1:
@ -704,7 +668,7 @@ class AllPooler(Pooler):
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
pooled_data = self.pooling(hidden_states, pooling_metadata)
pooling_params = get_pooling_params(pooling_metadata)
pooling_params = pooling_metadata.pooling_params
assert len(pooled_data) == len(pooling_params)
pooled_data = [self.head(d, p) for d, p in zip(pooled_data, pooling_params)]
@ -724,11 +688,11 @@ class StepPooler(Pooler):
pooling_metadata: PoolingMetadata,
) -> torch.Tensor | list[torch.Tensor]:
pooled_data_lst = self.pooling(hidden_states, pooling_metadata)
prompt_token_ids = get_prompt_token_ids(pooling_metadata)
prompt_token_ids = pooling_metadata.get_prompt_token_ids()
pooled_data = list[torch.Tensor]()
pooling_params = get_pooling_params(pooling_metadata)
pooling_params = pooling_metadata.pooling_params
for data, token_id, pooling_param in zip(
pooled_data_lst, prompt_token_ids, pooling_params
@ -757,7 +721,7 @@ class StepPooler(Pooler):
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
pooled_data = self.extract_states(hidden_states, pooling_metadata)
pooling_params = get_pooling_params(pooling_metadata)
pooling_params = pooling_metadata.pooling_params
assert len(pooled_data) == len(pooling_params)
pooled_data = [self.head(d, p) for d, p in zip(pooled_data, pooling_params)]
@ -794,7 +758,7 @@ class DispatchPooler(Pooler):
outputs = list[torch.Tensor]()
offset = 0
for task, group in groupby(get_tasks(pooling_metadata)):
for task, group in groupby(pooling_metadata.tasks):
if not (pooler := poolers_by_task.get(task)):
raise ValueError(
f"Unsupported task: {task} "

View File

@ -14,8 +14,6 @@ from vllm.model_executor.layers.pooler import (
PoolerHead,
PoolerNormalize,
PoolingParamsUpdate,
get_prompt_lens,
get_prompt_token_ids,
)
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.tasks import PoolingTask
@ -153,11 +151,11 @@ class GritLMMeanPool(nn.Module):
hidden_states: torch.Tensor | list[torch.Tensor],
pooling_metadata: PoolingMetadata,
) -> list[torch.Tensor] | torch.Tensor:
prompt_lens = get_prompt_lens(hidden_states, pooling_metadata)
prompt_lens = pooling_metadata.prompt_lens
instr_lens = torch.tensor(
[
self._get_instruction_len(token_ids.cpu().numpy())
for token_ids in get_prompt_token_ids(pooling_metadata)
for token_ids in pooling_metadata.get_prompt_token_ids()
],
device="cpu",
)

View File

@ -5,6 +5,7 @@ from dataclasses import dataclass
import torch
from vllm.pooling_params import PoolingParams
from vllm.tasks import PoolingTask
from vllm.utils.platform_utils import is_pin_memory_available
pin_memory = is_pin_memory_available()
@ -40,6 +41,18 @@ class PoolingMetadata:
pooling_params: list[PoolingParams]
pooling_cursor: PoolingCursor | None = None
def __post_init__(self) -> None:
pooling_params = self.pooling_params
tasks: list[PoolingTask] = [
task
for pooling_param in pooling_params
if (task := pooling_param.task) is not None
]
assert len(pooling_params) == len(tasks)
self.tasks = tasks
def __getitem__(self, indices: slice):
return PoolingMetadata(
prompt_lens=self.prompt_lens[indices],
@ -52,6 +65,14 @@ class PoolingMetadata:
else self.pooling_cursor[indices],
)
def get_prompt_token_ids(self) -> list[torch.Tensor]:
prompt_token_ids = self.prompt_token_ids
assert prompt_token_ids is not None, (
"Please set `requires_token_ids=True` in `get_pooling_updates`"
)
return [prompt_token_ids[i, :num] for i, num in enumerate(self.prompt_lens)]
def build_pooling_cursor(
self, num_scheduled_tokens: list[int], device: torch.device
):