[Performance] V1 Pooling Models E2E Performance Optimization (#23162)

Signed-off-by: wang.yuqi <noooop@126.com>
wang.yuqi 2025-08-21 21:26:09 +08:00 committed by GitHub
parent 5cc54f7c5b
commit d70a16625d
8 changed files with 162 additions and 168 deletions

View File

@@ -19,7 +19,8 @@ from vllm.model_executor.pooling_metadata import PoolingTensors
from vllm.pooling_params import PoolingParams
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
from vllm.tasks import PoolingTask
from vllm.utils import resolve_obj_by_qualname
from vllm.utils import current_stream, resolve_obj_by_qualname
from vllm.v1.pool.metadata import PoolingCursor
from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata
PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata]
@@ -205,6 +206,13 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
def build_output(
all_data: Union[torch.Tensor, list[torch.Tensor]], ) -> PoolerOutput:
# Pooling models: D2H copy & synchronize happen here
if isinstance(all_data, list):
all_data = [d.to("cpu", non_blocking=True) for d in all_data]
else:
all_data = all_data.to("cpu", non_blocking=True)
current_stream().synchronize()
all_outputs = [PoolingSequenceGroupOutput(data) for data in all_data]
return PoolerOutput(outputs=all_outputs)
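The new build_output path queues every device-to-host copy with non_blocking=True and synchronizes the current stream once, instead of issuing one blocking .cpu() per request later in the model runner. A minimal sketch of that pattern, with torch.cuda.current_stream() standing in for vLLM's current_stream() helper:

```python
import torch

def copy_pooled_to_cpu(pooled: list[torch.Tensor]) -> list[torch.Tensor]:
    # Queue every D2H copy asynchronously on the current stream...
    out = [t.to("cpu", non_blocking=True) for t in pooled]
    # ...then synchronize once so all CPU tensors are safe to read.
    if torch.cuda.is_available():
        torch.cuda.current_stream().synchronize()
    return out

device = "cuda" if torch.cuda.is_available() else "cpu"
pooled = [torch.randn(8, device=device) for _ in range(4)]
print([t.device for t in copy_pooled_to_cpu(pooled)])
```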
@@ -231,40 +239,21 @@ class PoolingMethod(nn.Module, ABC):
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
return PoolingParamsUpdate()
@abstractmethod
def forward_one(
self,
hidden_states: torch.Tensor,
prompt_len: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Note:
`prompt_len=None` means `prompt_len=len(hidden_states)`.
"""
raise NotImplementedError
@abstractmethod
def forward_all(
self,
hidden_states: torch.Tensor,
prompt_lens: torch.Tensor,
pooling_cursor: PoolingCursor,
) -> Union[list[torch.Tensor], torch.Tensor]:
raise NotImplementedError
def forward(
self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Union[list[torch.Tensor], torch.Tensor]:
prompt_lens = get_prompt_lens(hidden_states, pooling_metadata)
if isinstance(hidden_states, list):
return [
self.forward_one(h, prompt_len)
for h, prompt_len in zip(hidden_states, prompt_lens)
]
return self.forward_all(hidden_states, prompt_lens)
pooling_cursor = pooling_metadata.pooling_cursor
return self.forward_all(hidden_states, pooling_cursor)
class CLSPool(PoolingMethod):
@@ -272,24 +261,15 @@ class CLSPool(PoolingMethod):
def get_supported_tasks(self) -> Set[PoolingTask]:
return {"encode", "embed", "classify", "score"}
def forward_one(
self,
hidden_states: torch.Tensor,
prompt_len: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert prompt_len is None or prompt_len == hidden_states.shape[0], \
"partial prefill not supported with CLS pooling"
return hidden_states[0]
def forward_all(
self,
hidden_states: torch.Tensor,
prompt_lens: torch.Tensor,
pooling_cursor: PoolingCursor,
) -> Union[list[torch.Tensor], torch.Tensor]:
first_token_flat_indices = torch.zeros_like(prompt_lens)
first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1]
return hidden_states[first_token_flat_indices]
assert not pooling_cursor.is_partial_prefill(), \
"partial prefill not supported with CLS pooling"
return hidden_states[pooling_cursor.first_token_indices_gpu]
class LastPool(PoolingMethod):
@@ -297,20 +277,12 @@ class LastPool(PoolingMethod):
def get_supported_tasks(self) -> Set[PoolingTask]:
return {"encode", "embed", "classify", "score"}
def forward_one(
self,
hidden_states: torch.Tensor,
prompt_len: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return hidden_states[-1]
def forward_all(
self,
hidden_states: torch.Tensor,
prompt_lens: torch.Tensor,
pooling_cursor: PoolingCursor,
) -> Union[list[torch.Tensor], torch.Tensor]:
last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1
return hidden_states[last_token_flat_indices]
return hidden_states[pooling_cursor.last_token_indices_gpu]
class AllPool(PoolingMethod):
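CLSPool and LastPool above now index the flattened [total_tokens, hidden_size] tensor with precomputed per-sequence offsets instead of looping over requests. A small illustration (hypothetical lengths, plain torch) of where first_token_indices_gpu and last_token_indices_gpu come from; in the commit they are built once in build_pooling_cursor:

```python
import torch

# Three sequences of lengths 2, 3 and 4 flattened into one [9, 1] tensor.
lens = torch.tensor([2, 3, 4])
hidden = torch.arange(9, dtype=torch.float32).unsqueeze(1)  # stand-in hidden states

cumsum = torch.cat([torch.zeros(1, dtype=torch.long), torch.cumsum(lens, dim=0)])
first_token_indices = cumsum[:-1]    # [0, 2, 5] -> rows used by CLSPool
last_token_indices = cumsum[1:] - 1  # [1, 4, 8] -> rows used by LastPool

print(hidden[first_token_indices].squeeze(1))  # tensor([0., 2., 5.])
print(hidden[last_token_indices].squeeze(1))   # tensor([1., 4., 8.])
```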
@@ -318,22 +290,19 @@ class AllPool(PoolingMethod):
def get_supported_tasks(self) -> Set[PoolingTask]:
return {"encode"}
def forward_one(
self,
hidden_states: torch.Tensor,
prompt_len: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert prompt_len is None or prompt_len == hidden_states.shape[0], \
"partial prefill not supported with ALL pooling"
return hidden_states
def forward_all(
self,
hidden_states: torch.Tensor,
prompt_lens: torch.Tensor,
pooling_cursor: PoolingCursor,
) -> Union[list[torch.Tensor], torch.Tensor]:
return list(hidden_states.split_with_sizes(prompt_lens.tolist()))
assert not pooling_cursor.is_partial_prefill(), \
"partial prefill not supported with ALL pooling"
hidden_states_lst = list(
hidden_states.split(
pooling_cursor.num_scheduled_tokens_cpu.tolist()))
return [hidden_states_lst[i] for i in pooling_cursor.index]
class MeanPool(PoolingMethod):
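AllPool above splits the flattened hidden states by the CPU copy of num_scheduled_tokens and reorders the chunks with the cursor's index list, which stays the identity unless the metadata has been sliced. A hedged sketch of that split with made-up sizes:

```python
import torch

hidden = torch.randn(9, 16)    # flattened [total_tokens, hidden_size]
num_scheduled = [2, 3, 4]      # per-sequence token counts (illustrative)
index = [0, 1, 2]              # identity unless the cursor was sliced

chunks = list(hidden.split(num_scheduled))
per_sequence = [chunks[i] for i in index]
print([tuple(c.shape) for c in per_sequence])  # [(2, 16), (3, 16), (4, 16)]
```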
@@ -341,31 +310,25 @@ class MeanPool(PoolingMethod):
def get_supported_tasks(self) -> Set[PoolingTask]:
return {"encode", "embed", "classify", "score"}
def forward_one(
self,
hidden_states: torch.Tensor,
prompt_len: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert prompt_len is None or prompt_len == hidden_states.shape[0], \
"partial prefill not supported with MEAN pooling"
return hidden_states.mean(dim=0, dtype=torch.float32)
def forward_all(
self,
hidden_states: torch.Tensor,
prompt_lens: torch.Tensor,
pooling_cursor: PoolingCursor,
) -> Union[list[torch.Tensor], torch.Tensor]:
assert not pooling_cursor.is_partial_prefill(), \
"partial prefill not supported with MEAN pooling"
prompt_lens = pooling_cursor.prompt_lens_cpu.to(hidden_states.device,
non_blocking=True)
# Use float32 for torch.cumsum in MeanPool,
# otherwise precision will be lost significantly.
cumsum = torch.cumsum(hidden_states, dim=0, dtype=torch.float32)
start_indices = torch.cat([
torch.tensor([0], device=hidden_states.device),
torch.cumsum(prompt_lens[:-1], dim=0)
])
end_indices = torch.cumsum(prompt_lens, dim=0)
return (cumsum[end_indices - 1] - cumsum[start_indices] +
start_indices = pooling_cursor.first_token_indices_gpu
end_indices = pooling_cursor.last_token_indices_gpu
return (cumsum[end_indices] - cumsum[start_indices] +
hidden_states[start_indices]) / prompt_lens.unsqueeze(1)
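MeanPool replaces the per-sequence loop with one float32 cumulative sum over the flattened tokens; each per-sequence sum is recovered as cumsum[last] - cumsum[first] + hidden_states[first] and divided by the prompt length. A small numeric check of that identity with hypothetical lengths:

```python
import torch

lens = torch.tensor([2, 3, 4])
hidden = torch.randn(9, 8)

cumsum = torch.cumsum(hidden, dim=0, dtype=torch.float32)
first = torch.cat([torch.zeros(1, dtype=torch.long), torch.cumsum(lens, dim=0)[:-1]])
last = torch.cumsum(lens, dim=0) - 1

# Sum over tokens [first, last] == cumsum[last] - cumsum[first] + hidden[first].
means = (cumsum[last] - cumsum[first] + hidden[first]) / lens.unsqueeze(1)

expected = torch.stack([c.mean(dim=0) for c in hidden.split(lens.tolist())])
assert torch.allclose(means, expected, atol=1e-5)
print(means.shape)  # torch.Size([3, 8])
```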
@@ -477,6 +440,10 @@ class EmbeddingPoolerHead(PoolerHead):
pooling_params = get_pooling_params(pooling_metadata)
if isinstance(pooled_data, list):
pooled_data = torch.stack(pooled_data)
# pooled_data shape: [batchsize, embedding_dimension]
# for matryoshka representation
dimensions_list = [
pooling_param.dimensions for pooling_param in pooling_params
@@ -667,6 +634,10 @@ class ClassifierPooler(Pooler):
) -> PoolerOutput:
pooled_data = self.pooling(hidden_states, pooling_metadata)
if isinstance(pooled_data, list):
pooled_data = torch.stack(pooled_data)
# pooled_data shape: [batchsize, hidden_size]
if self.classifier is not None:
# apply classifier once on the full batch if possible
if isinstance(pooled_data, torch.Tensor):
@@ -717,12 +688,6 @@ class DispatchPooler(Pooler):
) -> PoolerOutput:
poolers_by_task = self.poolers_by_task
if isinstance(hidden_states, list):
hidden_states_lst = hidden_states
else:
prompt_lens = get_prompt_lens(hidden_states, pooling_metadata)
hidden_states_lst = list(hidden_states.split(prompt_lens.tolist()))
outputs = list[PoolingSequenceGroupOutput]()
offset = 0
for task, group in groupby(get_tasks(pooling_metadata)):
@@ -733,7 +698,7 @@ class DispatchPooler(Pooler):
num_items = len(list(group))
group_output: PoolerOutput = pooler(
hidden_states_lst[offset:offset + num_items],
hidden_states,
pooling_metadata[offset:offset + num_items],
)
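DispatchPooler no longer pre-splits hidden_states per request: each task-specific pooler now receives the full flattened tensor plus a slice of the pooling metadata, and per-request selection happens through the sliced cursor. A simplified, hypothetical sketch of that dispatch loop (the real poolers and metadata types live in vLLM):

```python
from itertools import groupby

def dispatch(poolers_by_task, hidden_states, pooling_metadata, tasks):
    # Simplified: hidden_states is passed whole, metadata is sliced per task group.
    outputs, offset = [], 0
    for task, group in groupby(tasks):
        pooler = poolers_by_task[task]
        num_items = len(list(group))
        outputs.extend(
            pooler(hidden_states, pooling_metadata[offset:offset + num_items]))
        offset += num_items
    return outputs

# Toy driver: metadata is a plain list of per-request params, poolers echo it.
poolers = {"embed": lambda hs, meta: [("embed", m) for m in meta],
           "classify": lambda hs, meta: [("classify", m) for m in meta]}
print(dispatch(poolers, hidden_states=None,
               pooling_metadata=["p0", "p1", "p2"],
               tasks=["embed", "embed", "classify"]))
```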

View File

@@ -528,9 +528,9 @@ def _encode_token_type_ids(input_ids: torch.Tensor,
def _decode_token_type_ids(input_ids: torch.Tensor) -> torch.Tensor:
ids_mask = torch.ones(input_ids.shape,
dtype=torch.int32,
device=input_ids.device) << TOKEN_TYPE_SHIFT
ids_mask = torch.ones_like(input_ids,
dtype=torch.int32,
device=input_ids.device) << TOKEN_TYPE_SHIFT
tokens_mask = ids_mask.bitwise_not()
token_type_ids = input_ids.bitwise_and(ids_mask) >> TOKEN_TYPE_SHIFT
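The hunk above switches the mask construction to torch.ones_like so it follows the shape and placement of input_ids. For context, the token type bit is packed into a high bit of each input id and recovered with that mask; a round-trip sketch assuming TOKEN_TYPE_SHIFT = 30 (an illustrative value, the real constant is defined in the same file):

```python
import torch

TOKEN_TYPE_SHIFT = 30  # assumed here for illustration

def encode(input_ids: torch.Tensor, token_type_ids: torch.Tensor) -> torch.Tensor:
    # Pack the token type bit into the high bits of each int32 id.
    return input_ids.bitwise_or(token_type_ids.to(torch.int32) << TOKEN_TYPE_SHIFT)

def decode(packed: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    ids_mask = torch.ones_like(packed, dtype=torch.int32) << TOKEN_TYPE_SHIFT
    token_type_ids = packed.bitwise_and(ids_mask) >> TOKEN_TYPE_SHIFT
    input_ids = packed.bitwise_and(ids_mask.bitwise_not())
    return input_ids, token_type_ids

ids = torch.tensor([101, 2023, 102, 2003, 102], dtype=torch.int32)
types = torch.tensor([0, 0, 0, 1, 1], dtype=torch.int32)
out_ids, out_types = decode(encode(ids, types))
assert torch.equal(out_ids, ids) and torch.equal(out_types, types)
```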

View File

@@ -9,7 +9,6 @@ from torch import nn
from transformers import RobertaConfig
from vllm.config import VllmConfig
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool,
DispatchPooler, Pooler)
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -100,7 +99,7 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
self.padding_idx = vllm_config.model_config.hf_config.pad_token_id
self.padding_idx: int = vllm_config.model_config.hf_config.pad_token_id
def forward(
self,
@@ -178,7 +177,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
self.padding_idx = vllm_config.model_config.hf_config.pad_token_id
self.padding_idx: int = vllm_config.model_config.hf_config.pad_token_id
self.num_labels = config.num_labels
self.roberta = BertModel(vllm_config=vllm_config,
@@ -233,58 +232,14 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
intermediate_tensors=intermediate_tensors)
# Adapted from transformers
def create_position_ids_from_input_ids(input_ids,
padding_idx,
past_key_values_length=0):
"""
Replace non-padding symbols with their position numbers.
Position numbers begin at padding_idx+1. Padding symbols
are ignored. This is modified from fairseq's `utils.make_positions`.
Args:
x: torch.Tensor x:
Returns: torch.Tensor
"""
# The series of casts and type-conversions here are carefully
# balanced to both work with ONNX export and XLA.
mask = input_ids.ne(padding_idx).int()
incremental_indices = (torch.cumsum(mask, dim=0).type_as(mask) +
past_key_values_length) * mask
return incremental_indices.long() + padding_idx
def replace_roberta_positions(input_ids: torch.Tensor,
position_ids: torch.Tensor,
padding_idx: int) -> None:
seq_lens: Optional[torch.Tensor] = None
attn_metadata = get_forward_context().attn_metadata
if attn_metadata is not None: # can be None during warmup
if isinstance(attn_metadata, dict):
attn_metadata = next(iter(attn_metadata.values()))
# TODO: remove "seq_lens_tensor" after V0 is removed
seq_lens = getattr(attn_metadata, "seq_lens_tensor",
getattr(attn_metadata, "seq_lens", None))
if seq_lens is not None:
assert isinstance(seq_lens, torch.Tensor)
# Replace position ids because in RoBERTa models
# they have to start at padding_idx + 1 and ignore
# existing padding tokens
# References:
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
token_list = torch.split(input_ids[:torch.sum(seq_lens)],
seq_lens.tolist())
offset = 0
for tokens in token_list:
length = tokens.shape[0]
position_ids[offset:offset+length] = \
create_position_ids_from_input_ids(tokens, padding_idx)
offset = offset + length
# Replace position ids because in RoBERTa models
# they have to start at padding_idx + 1 and ignore
# existing padding tokens
# References:
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
# vllm does not use padding tokens, let's make things simpler
position_ids += padding_idx + 1
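The removed helpers collapse into a single constant offset: since vLLM never feeds padding tokens into the model, the old create_position_ids_from_input_ids path (positions start at padding_idx + 1, padding ignored) is equivalent to adding padding_idx + 1 to the standard position ids. A tiny check of that equivalence on an unpadded sequence, using padding_idx = 1 as in RoBERTa:

```python
import torch

def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
    # Old path: positions start at padding_idx + 1 and skip padding tokens.
    mask = input_ids.ne(padding_idx).int()
    incremental_indices = (torch.cumsum(mask, dim=0).type_as(mask) +
                           past_key_values_length) * mask
    return incremental_indices.long() + padding_idx

padding_idx = 1
input_ids = torch.tensor([0, 713, 16, 10, 1296, 2])  # no padding tokens present
position_ids = torch.arange(len(input_ids))          # vLLM's per-sequence positions

old = create_position_ids_from_input_ids(input_ids, padding_idx)
new = position_ids + padding_idx + 1
assert torch.equal(old, new)
print(new)  # tensor([2, 3, 4, 5, 6, 7])
```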

View File

@@ -2,12 +2,13 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Any
from typing import Any, Optional
import torch
from vllm.pooling_params import PoolingParams
from vllm.utils import is_pin_memory_available
from vllm.v1.pool.metadata import PoolingCursor, build_pooling_cursor
class PoolingMetadata:
@@ -23,14 +24,15 @@ class PoolingMetadata:
"""
def __init__(
self,
seq_groups: list[tuple[list[int], PoolingParams]],
seq_data: dict[int, Any], # Specific data related to sequences
prompt_lens: list[int],
) -> None:
self,
seq_groups: list[tuple[list[int], PoolingParams]],
seq_data: dict[int, Any], # Specific data related to sequences
prompt_lens: list[int],
pooling_cursor: Optional[PoolingCursor] = None) -> None:
self.seq_groups = seq_groups
self.seq_data = seq_data
self.prompt_lens = prompt_lens
self.pooling_cursor: Optional[PoolingCursor] = pooling_cursor
def __repr__(self) -> str:
return ("PoolingMetadata("
@@ -43,8 +45,17 @@ class PoolingMetadata:
seq_groups=self.seq_groups[indices],
seq_data=dict(list(self.seq_data.items())[indices]),
prompt_lens=self.prompt_lens[indices],
pooling_cursor=None
if self.pooling_cursor is None else self.pooling_cursor[indices],
)
def build_pooling_cursor(self, num_scheduled_tokens: list[int],
device: torch.device):
prompt_lens = torch.tensor(self.prompt_lens, device="cpu")
self.pooling_cursor = build_pooling_cursor(num_scheduled_tokens,
prompt_lens,
device=device)
@dataclass
class PoolingTensors:

View File

@@ -6,15 +6,40 @@ from typing import Optional
import torch
from vllm.pooling_params import PoolingParams
from vllm.utils import is_pin_memory_available
pin_memory = is_pin_memory_available()
@dataclass
class PoolingCursor:
index: list[int]
first_token_indices_gpu: torch.Tensor
last_token_indices_gpu: torch.Tensor
prompt_lens_cpu: torch.Tensor
num_scheduled_tokens_cpu: torch.Tensor
def __getitem__(self, indices: slice):
return PoolingCursor(
index=self.index[indices],
first_token_indices_gpu=self.first_token_indices_gpu[indices],
last_token_indices_gpu=self.last_token_indices_gpu[indices],
prompt_lens_cpu=self.prompt_lens_cpu[indices],
num_scheduled_tokens_cpu=self.num_scheduled_tokens_cpu[indices],
)
def is_partial_prefill(self):
return not torch.all(
self.prompt_lens_cpu == self.num_scheduled_tokens_cpu)
@dataclass
class PoolingMetadata:
"""Tensors for pooling."""
prompt_lens: torch.Tensor
prompt_lens: torch.Tensor # CPU Tensor
prompt_token_ids: Optional[torch.Tensor]
pooling_params: list[PoolingParams]
pooling_cursor: Optional[PoolingCursor] = None
def __getitem__(self, indices: slice):
return PoolingMetadata(
@@ -22,4 +47,31 @@ class PoolingMetadata:
prompt_token_ids=None if self.prompt_token_ids is None else
self.prompt_token_ids[indices],
pooling_params=self.pooling_params[indices],
pooling_cursor=None
if self.pooling_cursor is None else self.pooling_cursor[indices],
)
def build_pooling_cursor(self, num_scheduled_tokens: list[int],
device: torch.device):
self.pooling_cursor = build_pooling_cursor(num_scheduled_tokens,
self.prompt_lens, device)
def build_pooling_cursor(num_scheduled_tokens: list[int],
prompt_lens: torch.Tensor, device: torch.device):
assert len(prompt_lens) == len(num_scheduled_tokens)
n_seq = len(num_scheduled_tokens)
index = list(range(n_seq))
num_scheduled_tokens = torch.tensor(num_scheduled_tokens, device="cpu")
cumsum = torch.zeros(n_seq + 1,
dtype=torch.int64,
pin_memory=pin_memory,
device="cpu")
torch.cumsum(num_scheduled_tokens, dim=0, out=cumsum[1:])
cumsum = cumsum.to(device, non_blocking=True)
return PoolingCursor(index=index,
first_token_indices_gpu=cumsum[:n_seq],
last_token_indices_gpu=cumsum[1:] - 1,
prompt_lens_cpu=prompt_lens,
num_scheduled_tokens_cpu=num_scheduled_tokens)
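Everything in build_pooling_cursor derives from one cumulative sum of num_scheduled_tokens, computed on pinned host memory and copied to the device asynchronously: cumsum[:n_seq] are the first-token offsets, cumsum[1:] - 1 the last-token offsets, and is_partial_prefill() simply compares scheduled tokens against full prompt lengths. A CPU-only walk-through with hypothetical values:

```python
import torch

num_scheduled_tokens = [4, 1, 3]       # tokens scheduled per request this step
prompt_lens = torch.tensor([4, 6, 3])  # full prompt lengths

n_seq = len(num_scheduled_tokens)
scheduled = torch.tensor(num_scheduled_tokens)
cumsum = torch.zeros(n_seq + 1, dtype=torch.int64)
torch.cumsum(scheduled, dim=0, out=cumsum[1:])

first_token_indices = cumsum[:n_seq]   # tensor([0, 4, 5])
last_token_indices = cumsum[1:] - 1    # tensor([3, 4, 7])
is_partial_prefill = not torch.all(prompt_lens == scheduled)
print(first_token_indices, last_token_indices, is_partial_prefill)  # ... True
```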

View File

@@ -713,7 +713,7 @@ class InputBatch:
return PoolingMetadata(
prompt_lens=torch.from_numpy(
self.num_prompt_tokens[:self.num_reqs]).to(self.device),
self.num_prompt_tokens[:self.num_reqs]),
prompt_token_ids=self.sampling_metadata.prompt_token_ids,
pooling_params=pooling_params,
)

View File

@@ -1476,23 +1476,22 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
"Either all or none of the requests in" \
" a batch must be pooling request"
extracted_hidden_states = list(
torch.split(hidden_states[:num_scheduled_tokens],
num_scheduled_tokens_np.tolist()))
hidden_states = hidden_states[:num_scheduled_tokens]
pooling_metadata = self.input_batch.pooling_metadata
pooling_metadata.build_pooling_cursor(num_scheduled_tokens_np.tolist(),
device=hidden_states.device)
seq_lens_cpu = self.seq_lens_cpu[:self.input_batch.num_reqs]
# Pooling models: D2H copy & synchronize happen in pooler.py:build_output
raw_pooler_output = self.model.pooler(
hidden_states=extracted_hidden_states,
pooling_metadata=pooling_metadata)
hidden_states=hidden_states, pooling_metadata=pooling_metadata)
pooler_output: list[Optional[torch.Tensor]] = []
seq_lens = self.seq_lens[:self.input_batch.num_reqs]
for raw_output, seq_len, prompt_len in zip(
raw_pooler_output, seq_lens, pooling_metadata.prompt_lens):
raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens):
if seq_len == prompt_len:
pooler_output.append(raw_output.data.cpu())
pooler_output.append(raw_output.data)
else:
pooler_output.append(None)
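With build_output already moving pooled data to the host, the runner drops the per-output .cpu() call and compares seq_lens_cpu against prompt_lens (both host-side) to decide which requests have finished their prompt; outputs for still-prefilling requests stay None. A simplified sketch of that selection with stand-in data:

```python
import torch

raw_outputs = [torch.randn(16) for _ in range(3)]  # already on CPU after build_output
seq_lens_cpu = torch.tensor([8, 5, 12])            # tokens processed so far per request
prompt_lens = torch.tensor([8, 9, 12])             # full prompt lengths

pooler_output = [
    out if seq_len == prompt_len else None
    for out, seq_len, prompt_len in zip(raw_outputs, seq_lens_cpu, prompt_lens)
]
print([o is not None for o in pooler_output])  # [True, False, True]
```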
@@ -2524,13 +2523,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
assert sum(num_scheduled_tokens_list) == num_tokens
assert len(num_scheduled_tokens_list) == num_reqs
hidden_states_list = list(
torch.split(hidden_states, num_scheduled_tokens_list))
req_num_tokens = num_tokens // num_reqs
dummy_prompt_lens = torch.tensor(
[h.shape[0] for h in hidden_states_list],
device=self.device,
num_scheduled_tokens_list,
device="cpu",
)
dummy_token_ids = torch.zeros((num_reqs, req_num_tokens),
dtype=torch.int32,
@@ -2547,8 +2544,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
pooling_params=[dummy_pooling_params] * num_reqs,
)
dummy_metadata.build_pooling_cursor(num_scheduled_tokens_list,
device=hidden_states.device)
try:
return model.pooler(hidden_states=hidden_states_list,
return model.pooler(hidden_states=hidden_states,
pooling_metadata=dummy_metadata)
except RuntimeError as e:
if 'out of memory' in str(e):
@@ -3316,10 +3316,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
dummy_block_table = torch.zeros((num_reqs, 1),
dtype=torch.int32,
device=self.device)
pin_memory=self.pin_memory,
device="cpu").to(self.device,
non_blocking=True)
dummy_slot_mapping = torch.zeros((total_num_scheduled_tokens, ),
dtype=torch.int32,
device=self.device)
pin_memory=self.pin_memory,
device="cpu").to(self.device,
non_blocking=True)
group_metadata = dict[str, tuple[CommonAttentionMetadata, Any]]()
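The dummy block table and slot mapping are now allocated in pinned host memory and moved to the device with non_blocking=True rather than created directly on the GPU. The general pattern, guarded so the sketch also runs without CUDA:

```python
import torch

def pinned_zeros_to_device(shape, dtype, device):
    # Pinned host allocation lets the H2D copy run asynchronously.
    pin = torch.cuda.is_available()
    host = torch.zeros(shape, dtype=dtype, pin_memory=pin, device="cpu")
    return host.to(device, non_blocking=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
dummy_block_table = pinned_zeros_to_device((4, 1), torch.int32, device)
dummy_slot_mapping = pinned_zeros_to_device((128,), torch.int32, device)
print(dummy_block_table.device, dummy_slot_mapping.device)
```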

View File

@@ -149,9 +149,16 @@ class PoolingModelRunner(
if not self.is_driver_worker:
return []
pooling_metadata = model_input.pooling_metadata
assert pooling_metadata is not None
pooling_metadata.build_pooling_cursor(
num_scheduled_tokens=pooling_metadata.prompt_lens,
device=hidden_or_intermediate_states.device)
return [
self.model.pooler(hidden_states=hidden_or_intermediate_states,
pooling_metadata=model_input.pooling_metadata)
pooling_metadata=pooling_metadata)
]
def make_model_input_from_broadcasted_tensor_dict(