mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-21 15:57:00 +08:00
[Misc] Move functions into PoolingMetadata (#30027)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
5430e110c0
commit
68eb5c8d97
@ -64,42 +64,6 @@ class PoolingParamsUpdate:
|
|||||||
params.requires_token_ids = self.requires_token_ids
|
params.requires_token_ids = self.requires_token_ids
|
||||||
|
|
||||||
|
|
||||||
def get_prompt_lens(
|
|
||||||
hidden_states: torch.Tensor | list[torch.Tensor],
|
|
||||||
pooling_metadata: PoolingMetadata,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
return pooling_metadata.prompt_lens
|
|
||||||
|
|
||||||
|
|
||||||
def get_prompt_token_ids(pooling_metadata: PoolingMetadata) -> list[torch.Tensor]:
|
|
||||||
assert pooling_metadata.prompt_token_ids is not None, (
|
|
||||||
"Please set `requires_token_ids=True` in `get_pooling_updates`"
|
|
||||||
)
|
|
||||||
|
|
||||||
return [
|
|
||||||
pooling_metadata.prompt_token_ids[i, :num]
|
|
||||||
for i, num in enumerate(pooling_metadata.prompt_lens)
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def get_pooling_params(pooling_metadata: PoolingMetadata) -> list[PoolingParams]:
|
|
||||||
pooling_params = pooling_metadata.pooling_params
|
|
||||||
return pooling_params
|
|
||||||
|
|
||||||
|
|
||||||
def get_tasks(pooling_metadata: PoolingMetadata) -> list[PoolingTask]:
|
|
||||||
pooling_params = get_pooling_params(pooling_metadata)
|
|
||||||
|
|
||||||
tasks: list[PoolingTask] = [
|
|
||||||
task
|
|
||||||
for pooling_param in pooling_params
|
|
||||||
if (task := pooling_param.task) is not None
|
|
||||||
]
|
|
||||||
assert len(pooling_params) == len(tasks)
|
|
||||||
|
|
||||||
return tasks
|
|
||||||
|
|
||||||
|
|
||||||
def get_classification_activation_function(config: PretrainedConfig):
|
def get_classification_activation_function(config: PretrainedConfig):
|
||||||
# Implement alignment with transformers ForSequenceClassificationLoss
|
# Implement alignment with transformers ForSequenceClassificationLoss
|
||||||
# https://github.com/huggingface/transformers/blob/57bb6db6ee4cfaccc45b8d474dfad5a17811ca60/src/transformers/loss/loss_utils.py#L92
|
# https://github.com/huggingface/transformers/blob/57bb6db6ee4cfaccc45b8d474dfad5a17811ca60/src/transformers/loss/loss_utils.py#L92
|
||||||
@ -466,7 +430,7 @@ class EmbeddingPoolerHead(PoolerHead):
|
|||||||
pooled_data = self.projector(pooled_data)
|
pooled_data = self.projector(pooled_data)
|
||||||
# pooled_data shape: [batchsize, embedding_dimension]
|
# pooled_data shape: [batchsize, embedding_dimension]
|
||||||
|
|
||||||
pooling_params = get_pooling_params(pooling_metadata)
|
pooling_params = pooling_metadata.pooling_params
|
||||||
|
|
||||||
# for matryoshka representation
|
# for matryoshka representation
|
||||||
dimensions_list = [pooling_param.dimensions for pooling_param in pooling_params]
|
dimensions_list = [pooling_param.dimensions for pooling_param in pooling_params]
|
||||||
@ -606,7 +570,7 @@ class ClassifierPooler(Pooler):
|
|||||||
if self.logit_bias is not None:
|
if self.logit_bias is not None:
|
||||||
pooled_data -= self.logit_bias
|
pooled_data -= self.logit_bias
|
||||||
|
|
||||||
pooling_params = get_pooling_params(pooling_metadata)
|
pooling_params = pooling_metadata.pooling_params
|
||||||
flags = [p.use_activation for p in pooling_params]
|
flags = [p.use_activation for p in pooling_params]
|
||||||
|
|
||||||
if len(set(flags)) == 1:
|
if len(set(flags)) == 1:
|
||||||
@ -704,7 +668,7 @@ class AllPooler(Pooler):
|
|||||||
pooling_metadata: PoolingMetadata,
|
pooling_metadata: PoolingMetadata,
|
||||||
) -> PoolerOutput:
|
) -> PoolerOutput:
|
||||||
pooled_data = self.pooling(hidden_states, pooling_metadata)
|
pooled_data = self.pooling(hidden_states, pooling_metadata)
|
||||||
pooling_params = get_pooling_params(pooling_metadata)
|
pooling_params = pooling_metadata.pooling_params
|
||||||
assert len(pooled_data) == len(pooling_params)
|
assert len(pooled_data) == len(pooling_params)
|
||||||
|
|
||||||
pooled_data = [self.head(d, p) for d, p in zip(pooled_data, pooling_params)]
|
pooled_data = [self.head(d, p) for d, p in zip(pooled_data, pooling_params)]
|
||||||
@ -724,11 +688,11 @@ class StepPooler(Pooler):
|
|||||||
pooling_metadata: PoolingMetadata,
|
pooling_metadata: PoolingMetadata,
|
||||||
) -> torch.Tensor | list[torch.Tensor]:
|
) -> torch.Tensor | list[torch.Tensor]:
|
||||||
pooled_data_lst = self.pooling(hidden_states, pooling_metadata)
|
pooled_data_lst = self.pooling(hidden_states, pooling_metadata)
|
||||||
prompt_token_ids = get_prompt_token_ids(pooling_metadata)
|
prompt_token_ids = pooling_metadata.get_prompt_token_ids()
|
||||||
|
|
||||||
pooled_data = list[torch.Tensor]()
|
pooled_data = list[torch.Tensor]()
|
||||||
|
|
||||||
pooling_params = get_pooling_params(pooling_metadata)
|
pooling_params = pooling_metadata.pooling_params
|
||||||
|
|
||||||
for data, token_id, pooling_param in zip(
|
for data, token_id, pooling_param in zip(
|
||||||
pooled_data_lst, prompt_token_ids, pooling_params
|
pooled_data_lst, prompt_token_ids, pooling_params
|
||||||
@ -757,7 +721,7 @@ class StepPooler(Pooler):
|
|||||||
pooling_metadata: PoolingMetadata,
|
pooling_metadata: PoolingMetadata,
|
||||||
) -> PoolerOutput:
|
) -> PoolerOutput:
|
||||||
pooled_data = self.extract_states(hidden_states, pooling_metadata)
|
pooled_data = self.extract_states(hidden_states, pooling_metadata)
|
||||||
pooling_params = get_pooling_params(pooling_metadata)
|
pooling_params = pooling_metadata.pooling_params
|
||||||
assert len(pooled_data) == len(pooling_params)
|
assert len(pooled_data) == len(pooling_params)
|
||||||
|
|
||||||
pooled_data = [self.head(d, p) for d, p in zip(pooled_data, pooling_params)]
|
pooled_data = [self.head(d, p) for d, p in zip(pooled_data, pooling_params)]
|
||||||
@ -794,7 +758,7 @@ class DispatchPooler(Pooler):
|
|||||||
|
|
||||||
outputs = list[torch.Tensor]()
|
outputs = list[torch.Tensor]()
|
||||||
offset = 0
|
offset = 0
|
||||||
for task, group in groupby(get_tasks(pooling_metadata)):
|
for task, group in groupby(pooling_metadata.tasks):
|
||||||
if not (pooler := poolers_by_task.get(task)):
|
if not (pooler := poolers_by_task.get(task)):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unsupported task: {task} "
|
f"Unsupported task: {task} "
|
||||||
|
|||||||
@ -14,8 +14,6 @@ from vllm.model_executor.layers.pooler import (
|
|||||||
PoolerHead,
|
PoolerHead,
|
||||||
PoolerNormalize,
|
PoolerNormalize,
|
||||||
PoolingParamsUpdate,
|
PoolingParamsUpdate,
|
||||||
get_prompt_lens,
|
|
||||||
get_prompt_token_ids,
|
|
||||||
)
|
)
|
||||||
from vllm.model_executor.models.llama import LlamaForCausalLM
|
from vllm.model_executor.models.llama import LlamaForCausalLM
|
||||||
from vllm.tasks import PoolingTask
|
from vllm.tasks import PoolingTask
|
||||||
@ -153,11 +151,11 @@ class GritLMMeanPool(nn.Module):
|
|||||||
hidden_states: torch.Tensor | list[torch.Tensor],
|
hidden_states: torch.Tensor | list[torch.Tensor],
|
||||||
pooling_metadata: PoolingMetadata,
|
pooling_metadata: PoolingMetadata,
|
||||||
) -> list[torch.Tensor] | torch.Tensor:
|
) -> list[torch.Tensor] | torch.Tensor:
|
||||||
prompt_lens = get_prompt_lens(hidden_states, pooling_metadata)
|
prompt_lens = pooling_metadata.prompt_lens
|
||||||
instr_lens = torch.tensor(
|
instr_lens = torch.tensor(
|
||||||
[
|
[
|
||||||
self._get_instruction_len(token_ids.cpu().numpy())
|
self._get_instruction_len(token_ids.cpu().numpy())
|
||||||
for token_ids in get_prompt_token_ids(pooling_metadata)
|
for token_ids in pooling_metadata.get_prompt_token_ids()
|
||||||
],
|
],
|
||||||
device="cpu",
|
device="cpu",
|
||||||
)
|
)
|
||||||
|
|||||||
@ -5,6 +5,7 @@ from dataclasses import dataclass
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.pooling_params import PoolingParams
|
from vllm.pooling_params import PoolingParams
|
||||||
|
from vllm.tasks import PoolingTask
|
||||||
from vllm.utils.platform_utils import is_pin_memory_available
|
from vllm.utils.platform_utils import is_pin_memory_available
|
||||||
|
|
||||||
pin_memory = is_pin_memory_available()
|
pin_memory = is_pin_memory_available()
|
||||||
@ -40,6 +41,18 @@ class PoolingMetadata:
|
|||||||
pooling_params: list[PoolingParams]
|
pooling_params: list[PoolingParams]
|
||||||
pooling_cursor: PoolingCursor | None = None
|
pooling_cursor: PoolingCursor | None = None
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
pooling_params = self.pooling_params
|
||||||
|
|
||||||
|
tasks: list[PoolingTask] = [
|
||||||
|
task
|
||||||
|
for pooling_param in pooling_params
|
||||||
|
if (task := pooling_param.task) is not None
|
||||||
|
]
|
||||||
|
assert len(pooling_params) == len(tasks)
|
||||||
|
|
||||||
|
self.tasks = tasks
|
||||||
|
|
||||||
def __getitem__(self, indices: slice):
|
def __getitem__(self, indices: slice):
|
||||||
return PoolingMetadata(
|
return PoolingMetadata(
|
||||||
prompt_lens=self.prompt_lens[indices],
|
prompt_lens=self.prompt_lens[indices],
|
||||||
@ -52,6 +65,14 @@ class PoolingMetadata:
|
|||||||
else self.pooling_cursor[indices],
|
else self.pooling_cursor[indices],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_prompt_token_ids(self) -> list[torch.Tensor]:
|
||||||
|
prompt_token_ids = self.prompt_token_ids
|
||||||
|
assert prompt_token_ids is not None, (
|
||||||
|
"Please set `requires_token_ids=True` in `get_pooling_updates`"
|
||||||
|
)
|
||||||
|
|
||||||
|
return [prompt_token_ids[i, :num] for i, num in enumerate(self.prompt_lens)]
|
||||||
|
|
||||||
def build_pooling_cursor(
|
def build_pooling_cursor(
|
||||||
self, num_scheduled_tokens: list[int], device: torch.device
|
self, num_scheduled_tokens: list[int], device: torch.device
|
||||||
):
|
):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user