mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-26 17:29:44 +08:00
[Misc] Move functions into PoolingMetadata (#30027)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
5430e110c0
commit
68eb5c8d97
@ -64,42 +64,6 @@ class PoolingParamsUpdate:
|
||||
params.requires_token_ids = self.requires_token_ids
|
||||
|
||||
|
||||
def get_prompt_lens(
    hidden_states: torch.Tensor | list[torch.Tensor],
    pooling_metadata: PoolingMetadata,
) -> torch.Tensor:
    """Return the per-request prompt lengths recorded on *pooling_metadata*.

    ``hidden_states`` is accepted for interface compatibility only; the
    lengths come entirely from the metadata object.
    """
    del hidden_states  # unused; kept so existing call sites keep working
    return pooling_metadata.prompt_lens
|
||||
|
||||
|
||||
def get_prompt_token_ids(pooling_metadata: PoolingMetadata) -> list[torch.Tensor]:
    """Slice the padded token-id matrix into one 1-D tensor per prompt.

    Row ``i`` of ``prompt_token_ids`` is trimmed to the first
    ``prompt_lens[i]`` entries; requires the metadata to actually carry
    token ids.
    """
    token_ids = pooling_metadata.prompt_token_ids
    assert token_ids is not None, (
        "Please set `requires_token_ids=True` in `get_pooling_updates`"
    )

    per_prompt: list[torch.Tensor] = []
    for row, length in enumerate(pooling_metadata.prompt_lens):
        per_prompt.append(token_ids[row, :length])
    return per_prompt
|
||||
|
||||
|
||||
def get_pooling_params(pooling_metadata: PoolingMetadata) -> list[PoolingParams]:
    """Return the per-request pooling parameters stored on the metadata."""
    return pooling_metadata.pooling_params
|
||||
|
||||
|
||||
def get_tasks(pooling_metadata: PoolingMetadata) -> list[PoolingTask]:
    """Return every request's pooling task, in request order.

    Each request is expected to carry a resolved (non-``None``) task by
    the time this is called; the length assertion below enforces that.
    """
    pooling_params = pooling_metadata.pooling_params

    tasks: list[PoolingTask] = []
    for pooling_param in pooling_params:
        task = pooling_param.task
        if task is not None:
            tasks.append(task)
    # Any dropped entry means a request still had task=None.
    assert len(pooling_params) == len(tasks)

    return tasks
|
||||
|
||||
|
||||
def get_classification_activation_function(config: PretrainedConfig):
|
||||
# Implement alignment with transformers ForSequenceClassificationLoss
|
||||
# https://github.com/huggingface/transformers/blob/57bb6db6ee4cfaccc45b8d474dfad5a17811ca60/src/transformers/loss/loss_utils.py#L92
|
||||
@ -466,7 +430,7 @@ class EmbeddingPoolerHead(PoolerHead):
|
||||
pooled_data = self.projector(pooled_data)
|
||||
# pooled_data shape: [batchsize, embedding_dimension]
|
||||
|
||||
pooling_params = get_pooling_params(pooling_metadata)
|
||||
pooling_params = pooling_metadata.pooling_params
|
||||
|
||||
# for matryoshka representation
|
||||
dimensions_list = [pooling_param.dimensions for pooling_param in pooling_params]
|
||||
@ -606,7 +570,7 @@ class ClassifierPooler(Pooler):
|
||||
if self.logit_bias is not None:
|
||||
pooled_data -= self.logit_bias
|
||||
|
||||
pooling_params = get_pooling_params(pooling_metadata)
|
||||
pooling_params = pooling_metadata.pooling_params
|
||||
flags = [p.use_activation for p in pooling_params]
|
||||
|
||||
if len(set(flags)) == 1:
|
||||
@ -704,7 +668,7 @@ class AllPooler(Pooler):
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> PoolerOutput:
|
||||
pooled_data = self.pooling(hidden_states, pooling_metadata)
|
||||
pooling_params = get_pooling_params(pooling_metadata)
|
||||
pooling_params = pooling_metadata.pooling_params
|
||||
assert len(pooled_data) == len(pooling_params)
|
||||
|
||||
pooled_data = [self.head(d, p) for d, p in zip(pooled_data, pooling_params)]
|
||||
@ -724,11 +688,11 @@ class StepPooler(Pooler):
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> torch.Tensor | list[torch.Tensor]:
|
||||
pooled_data_lst = self.pooling(hidden_states, pooling_metadata)
|
||||
prompt_token_ids = get_prompt_token_ids(pooling_metadata)
|
||||
prompt_token_ids = pooling_metadata.get_prompt_token_ids()
|
||||
|
||||
pooled_data = list[torch.Tensor]()
|
||||
|
||||
pooling_params = get_pooling_params(pooling_metadata)
|
||||
pooling_params = pooling_metadata.pooling_params
|
||||
|
||||
for data, token_id, pooling_param in zip(
|
||||
pooled_data_lst, prompt_token_ids, pooling_params
|
||||
@ -757,7 +721,7 @@ class StepPooler(Pooler):
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> PoolerOutput:
|
||||
pooled_data = self.extract_states(hidden_states, pooling_metadata)
|
||||
pooling_params = get_pooling_params(pooling_metadata)
|
||||
pooling_params = pooling_metadata.pooling_params
|
||||
assert len(pooled_data) == len(pooling_params)
|
||||
|
||||
pooled_data = [self.head(d, p) for d, p in zip(pooled_data, pooling_params)]
|
||||
@ -794,7 +758,7 @@ class DispatchPooler(Pooler):
|
||||
|
||||
outputs = list[torch.Tensor]()
|
||||
offset = 0
|
||||
for task, group in groupby(get_tasks(pooling_metadata)):
|
||||
for task, group in groupby(pooling_metadata.tasks):
|
||||
if not (pooler := poolers_by_task.get(task)):
|
||||
raise ValueError(
|
||||
f"Unsupported task: {task} "
|
||||
|
||||
@ -14,8 +14,6 @@ from vllm.model_executor.layers.pooler import (
|
||||
PoolerHead,
|
||||
PoolerNormalize,
|
||||
PoolingParamsUpdate,
|
||||
get_prompt_lens,
|
||||
get_prompt_token_ids,
|
||||
)
|
||||
from vllm.model_executor.models.llama import LlamaForCausalLM
|
||||
from vllm.tasks import PoolingTask
|
||||
@ -153,11 +151,11 @@ class GritLMMeanPool(nn.Module):
|
||||
hidden_states: torch.Tensor | list[torch.Tensor],
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> list[torch.Tensor] | torch.Tensor:
|
||||
prompt_lens = get_prompt_lens(hidden_states, pooling_metadata)
|
||||
prompt_lens = pooling_metadata.prompt_lens
|
||||
instr_lens = torch.tensor(
|
||||
[
|
||||
self._get_instruction_len(token_ids.cpu().numpy())
|
||||
for token_ids in get_prompt_token_ids(pooling_metadata)
|
||||
for token_ids in pooling_metadata.get_prompt_token_ids()
|
||||
],
|
||||
device="cpu",
|
||||
)
|
||||
|
||||
@ -5,6 +5,7 @@ from dataclasses import dataclass
|
||||
import torch
|
||||
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.tasks import PoolingTask
|
||||
from vllm.utils.platform_utils import is_pin_memory_available
|
||||
|
||||
pin_memory = is_pin_memory_available()
|
||||
@ -40,6 +41,18 @@ class PoolingMetadata:
|
||||
pooling_params: list[PoolingParams]
|
||||
pooling_cursor: PoolingCursor | None = None
|
||||
|
||||
def __post_init__(self) -> None:
    """Derive ``self.tasks`` from the per-request pooling parameters.

    Every request must carry a resolved (non-``None``) task; the
    assertion below fails if any task is still unset.
    """
    params = self.pooling_params

    resolved: list[PoolingTask] = []
    for param in params:
        task = param.task
        if task is not None:
            resolved.append(task)
    # A length mismatch means at least one request had task=None.
    assert len(params) == len(resolved)

    self.tasks = resolved
|
||||
|
||||
def __getitem__(self, indices: slice):
|
||||
return PoolingMetadata(
|
||||
prompt_lens=self.prompt_lens[indices],
|
||||
@ -52,6 +65,14 @@ class PoolingMetadata:
|
||||
else self.pooling_cursor[indices],
|
||||
)
|
||||
|
||||
def get_prompt_token_ids(self) -> list[torch.Tensor]:
    """Return each prompt's token ids, trimmed to its true length.

    Row ``i`` of the padded ``prompt_token_ids`` matrix is sliced to the
    first ``prompt_lens[i]`` entries.
    """
    token_ids = self.prompt_token_ids
    assert token_ids is not None, (
        "Please set `requires_token_ids=True` in `get_pooling_updates`"
    )

    trimmed: list[torch.Tensor] = []
    for row, length in enumerate(self.prompt_lens):
        trimmed.append(token_ids[row, :length])
    return trimmed
|
||||
|
||||
def build_pooling_cursor(
|
||||
self, num_scheduled_tokens: list[int], device: torch.device
|
||||
):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user