From d70a16625dc74d9517641aa82f4ae7367854da96 Mon Sep 17 00:00:00 2001
From: "wang.yuqi"
Date: Thu, 21 Aug 2025 21:26:09 +0800
Subject: [PATCH] [Performance] V1 Pooling Models E2E Performance Optimization (#23162)

Signed-off-by: wang.yuqi
---
 vllm/model_executor/layers/pooler.py    | 133 +++++++++---------
 vllm/model_executor/models/bert.py      |   6 +-
 vllm/model_executor/models/roberta.py   |  65 ++----------
 vllm/model_executor/pooling_metadata.py |  23 ++--
 vllm/v1/pool/metadata.py                |  56 +++++++++-
 vllm/v1/worker/gpu_input_batch.py       |   2 +-
 vllm/v1/worker/gpu_model_runner.py      |  36 ++++---
 vllm/worker/pooling_model_runner.py     |   9 +-
 8 files changed, 162 insertions(+), 168 deletions(-)

diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py
index 75e65072b7016..d34fb58cb5cb2 100644
--- a/vllm/model_executor/layers/pooler.py
+++ b/vllm/model_executor/layers/pooler.py
@@ -19,7 +19,8 @@ from vllm.model_executor.pooling_metadata import PoolingTensors
 from vllm.pooling_params import PoolingParams
 from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
 from vllm.tasks import PoolingTask
-from vllm.utils import resolve_obj_by_qualname
+from vllm.utils import current_stream, resolve_obj_by_qualname
+from vllm.v1.pool.metadata import PoolingCursor
 from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata
 
 PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata]
@@ -205,6 +206,13 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
 def build_output(
         all_data: Union[torch.Tensor, list[torch.Tensor]], ) -> PoolerOutput:
+    # Pooling models D2H & synchronize occurs here
+    if isinstance(all_data, list):
+        all_data = [d.to("cpu", non_blocking=True) for d in all_data]
+    else:
+        all_data = all_data.to("cpu", non_blocking=True)
+    current_stream().synchronize()
+
     all_outputs = [PoolingSequenceGroupOutput(data) for data in all_data]
     return PoolerOutput(outputs=all_outputs)
 
@@ -231,40 +239,21 @@ class PoolingMethod(nn.Module, ABC):
     def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
         return PoolingParamsUpdate()
 
-    @abstractmethod
-    def forward_one(
-        self,
-        hidden_states: torch.Tensor,
-        prompt_len: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        """
-        Note:
-            `prompt_len=None` means `prompt_len=len(hidden_states)`.
-        """
-        raise NotImplementedError
-
     @abstractmethod
     def forward_all(
         self,
         hidden_states: torch.Tensor,
-        prompt_lens: torch.Tensor,
+        pooling_cursor: PoolingCursor,
     ) -> Union[list[torch.Tensor], torch.Tensor]:
         raise NotImplementedError
 
     def forward(
         self,
-        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
+        hidden_states: torch.Tensor,
         pooling_metadata: PoolingMetadata,
     ) -> Union[list[torch.Tensor], torch.Tensor]:
-        prompt_lens = get_prompt_lens(hidden_states, pooling_metadata)
-
-        if isinstance(hidden_states, list):
-            return [
-                self.forward_one(h, prompt_len)
-                for h, prompt_len in zip(hidden_states, prompt_lens)
-            ]
-
-        return self.forward_all(hidden_states, prompt_lens)
+        pooling_cursor = pooling_metadata.pooling_cursor
+        return self.forward_all(hidden_states, pooling_cursor)
 
 
 class CLSPool(PoolingMethod):
@@ -272,24 +261,15 @@ class CLSPool(PoolingMethod):
     def get_supported_tasks(self) -> Set[PoolingTask]:
         return {"encode", "embed", "classify", "score"}
 
-    def forward_one(
-        self,
-        hidden_states: torch.Tensor,
-        prompt_len: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        assert prompt_len is None or prompt_len == hidden_states.shape[0], \
-            "partial prefill not supported with CLS pooling"
-
-        return hidden_states[0]
-
     def forward_all(
         self,
         hidden_states: torch.Tensor,
-        prompt_lens: torch.Tensor,
+        pooling_cursor: PoolingCursor,
     ) -> Union[list[torch.Tensor], torch.Tensor]:
-        first_token_flat_indices = torch.zeros_like(prompt_lens)
-        first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1]
-        return hidden_states[first_token_flat_indices]
+        assert not pooling_cursor.is_partial_prefill(), \
+            "partial prefill not supported with CLS pooling"
+
+        return hidden_states[pooling_cursor.first_token_indices_gpu]
 
 
 class LastPool(PoolingMethod):
@@ -297,20 +277,12 @@ class LastPool(PoolingMethod):
     def get_supported_tasks(self) -> Set[PoolingTask]:
         return {"encode", "embed", "classify", "score"}
 
-    def forward_one(
-        self,
-        hidden_states: torch.Tensor,
-        prompt_len: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        return hidden_states[-1]
-
     def forward_all(
         self,
        hidden_states: torch.Tensor,
-        prompt_lens: torch.Tensor,
+        pooling_cursor: PoolingCursor,
     ) -> Union[list[torch.Tensor], torch.Tensor]:
-        last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1
-        return hidden_states[last_token_flat_indices]
+        return hidden_states[pooling_cursor.last_token_indices_gpu]
 
 
 class AllPool(PoolingMethod):
@@ -318,22 +290,19 @@ class AllPool(PoolingMethod):
     def get_supported_tasks(self) -> Set[PoolingTask]:
         return {"encode"}
 
-    def forward_one(
-        self,
-        hidden_states: torch.Tensor,
-        prompt_len: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        assert prompt_len is None or prompt_len == hidden_states.shape[0], \
-            "partial prefill not supported with ALL pooling"
-
-        return hidden_states
-
     def forward_all(
         self,
         hidden_states: torch.Tensor,
-        prompt_lens: torch.Tensor,
+        pooling_cursor: PoolingCursor,
     ) -> Union[list[torch.Tensor], torch.Tensor]:
-        return list(hidden_states.split_with_sizes(prompt_lens.tolist()))
+
+        assert not pooling_cursor.is_partial_prefill(), \
+            "partial prefill not supported with ALL pooling"
+
+        hidden_states_lst = list(
+            hidden_states.split(
+                pooling_cursor.num_scheduled_tokens_cpu.tolist()))
+        return [hidden_states_lst[i] for i in pooling_cursor.index]
 
 
 class MeanPool(PoolingMethod):
@@ -341,31 +310,25 @@ class MeanPool(PoolingMethod):
     def get_supported_tasks(self) -> Set[PoolingTask]:
         return {"encode", "embed", "classify", "score"}
 
-    def forward_one(
-        self,
-        hidden_states: torch.Tensor,
-        prompt_len: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        assert prompt_len is None or prompt_len == hidden_states.shape[0], \
-            "partial prefill not supported with MEAN pooling"
-
-        return hidden_states.mean(dim=0, dtype=torch.float32)
-
     def forward_all(
         self,
         hidden_states: torch.Tensor,
-        prompt_lens: torch.Tensor,
+        pooling_cursor: PoolingCursor,
     ) -> Union[list[torch.Tensor], torch.Tensor]:
+
+        assert not pooling_cursor.is_partial_prefill(), \
+            "partial prefill not supported with MEAN pooling"
+
+        prompt_lens = pooling_cursor.prompt_lens_cpu.to(hidden_states.device,
+                                                        non_blocking=True)
+
         # Use float32 for torch.cumsum in MeanPool,
         # otherwise precision will be lost significantly.
         cumsum = torch.cumsum(hidden_states, dim=0, dtype=torch.float32)
 
-        start_indices = torch.cat([
-            torch.tensor([0], device=hidden_states.device),
-            torch.cumsum(prompt_lens[:-1], dim=0)
-        ])
-        end_indices = torch.cumsum(prompt_lens, dim=0)
-        return (cumsum[end_indices - 1] - cumsum[start_indices] +
+        start_indices = pooling_cursor.first_token_indices_gpu
+        end_indices = pooling_cursor.last_token_indices_gpu
+        return (cumsum[end_indices] - cumsum[start_indices] +
                 hidden_states[start_indices]) / prompt_lens.unsqueeze(1)
 
@@ -477,6 +440,10 @@ class EmbeddingPoolerHead(PoolerHead):
         pooling_params = get_pooling_params(pooling_metadata)
 
+        if isinstance(pooled_data, list):
+            pooled_data = torch.stack(pooled_data)
+        # pooled_data shape: [batchsize, embedding_dimension]
+
         # for matryoshka representation
         dimensions_list = [
             pooling_param.dimensions for pooling_param in pooling_params
@@ -667,6 +634,10 @@ class ClassifierPooler(Pooler):
     ) -> PoolerOutput:
         pooled_data = self.pooling(hidden_states, pooling_metadata)
 
+        if isinstance(pooled_data, list):
+            pooled_data = torch.stack(pooled_data)
+        # pooled_data shape: [batchsize, hidden_size]
+
         if self.classifier is not None:
             # apply classifier once on the full batch if possible
             if isinstance(pooled_data, torch.Tensor):
@@ -717,12 +688,6 @@ class DispatchPooler(Pooler):
     ) -> PoolerOutput:
         poolers_by_task = self.poolers_by_task
 
-        if isinstance(hidden_states, list):
-            hidden_states_lst = hidden_states
-        else:
-            prompt_lens = get_prompt_lens(hidden_states, pooling_metadata)
-            hidden_states_lst = list(hidden_states.split(prompt_lens.tolist()))
-
         outputs = list[PoolingSequenceGroupOutput]()
         offset = 0
         for task, group in groupby(get_tasks(pooling_metadata)):
@@ -733,7 +698,7 @@ class DispatchPooler(Pooler):
             num_items = len(list(group))
 
             group_output: PoolerOutput = pooler(
-                hidden_states_lst[offset:offset + num_items],
+                hidden_states,
                 pooling_metadata[offset:offset + num_items],
             )
 
diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py
index 6638f06f98261..2bd5eb5bb7aa8 100644
--- a/vllm/model_executor/models/bert.py
+++ b/vllm/model_executor/models/bert.py
@@ -528,9 +528,9 @@ def _encode_token_type_ids(input_ids: torch.Tensor,
 
 def _decode_token_type_ids(input_ids: torch.Tensor) -> torch.Tensor:
 
-    ids_mask = torch.ones(input_ids.shape,
-                          dtype=torch.int32,
-                          device=input_ids.device) << TOKEN_TYPE_SHIFT
+    ids_mask = torch.ones_like(input_ids,
+                               dtype=torch.int32,
+                               device=input_ids.device) << TOKEN_TYPE_SHIFT
     tokens_mask = ids_mask.bitwise_not()
 
     token_type_ids = input_ids.bitwise_and(ids_mask) >> TOKEN_TYPE_SHIFT
diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py
index 32a4a2c9a2694..49a37342c67fa 100644
--- a/vllm/model_executor/models/roberta.py
+++ b/vllm/model_executor/models/roberta.py
@@ -9,7 +9,6 @@ from torch import nn
 from transformers import RobertaConfig
 
 from vllm.config import VllmConfig
-from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool,
                                                DispatchPooler, Pooler)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -100,7 +99,7 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__(vllm_config=vllm_config, prefix=prefix)
-        self.padding_idx = vllm_config.model_config.hf_config.pad_token_id
+        self.padding_idx: int = vllm_config.model_config.hf_config.pad_token_id
 
     def forward(
         self,
@@ -178,7 +177,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
-        self.padding_idx = vllm_config.model_config.hf_config.pad_token_id
+        self.padding_idx: int = vllm_config.model_config.hf_config.pad_token_id
 
         self.num_labels = config.num_labels
         self.roberta = BertModel(vllm_config=vllm_config,
@@ -233,58 +232,14 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
                              intermediate_tensors=intermediate_tensors)
 
 
-# Adapted from transformers
-def create_position_ids_from_input_ids(input_ids,
-                                       padding_idx,
-                                       past_key_values_length=0):
-    """
-    Replace non-padding symbols with their position numbers.
-    Position numbers begin at padding_idx+1. Padding symbols
-    are ignored. This is modified from fairseq's `utils.make_positions`.
-
-    Args:
-        x: torch.Tensor x:
-
-    Returns: torch.Tensor
-    """
-    # The series of casts and type-conversions here are carefully
-    # balanced to both work with ONNX export and XLA.
-    mask = input_ids.ne(padding_idx).int()
-
-    incremental_indices = (torch.cumsum(mask, dim=0).type_as(mask) +
-                           past_key_values_length) * mask
-
-    return incremental_indices.long() + padding_idx
-
-
 def replace_roberta_positions(input_ids: torch.Tensor,
                               position_ids: torch.Tensor,
                               padding_idx: int) -> None:
-
-    seq_lens: Optional[torch.Tensor] = None
-    attn_metadata = get_forward_context().attn_metadata
-    if attn_metadata is not None:  # can be None during warmup
-        if isinstance(attn_metadata, dict):
-            attn_metadata = next(iter(attn_metadata.values()))
-        # TODO: remove "seq_lens_tensor" after V0 is removed
-        seq_lens = getattr(attn_metadata, "seq_lens_tensor",
-                           getattr(attn_metadata, "seq_lens", None))
-
-    if seq_lens is not None:
-        assert isinstance(seq_lens, torch.Tensor)
-
-        # Replace position ids because in RoBERTa models
-        # they have to start at padding_idx + 1 and ignore
-        # existing padding tokens
-        # References:
-        # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
-        # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
-        token_list = torch.split(input_ids[:torch.sum(seq_lens)],
-                                 seq_lens.tolist())
-
-        offset = 0
-        for tokens in token_list:
-            length = tokens.shape[0]
-            position_ids[offset:offset+length] = \
-                create_position_ids_from_input_ids(tokens, padding_idx)
-            offset = offset + length
+    # Replace position ids because in RoBERTa models
+    # they have to start at padding_idx + 1 and ignore
+    # existing padding tokens
+    # References:
+    # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
+    # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
+    # vllm does not use padding tokens, let's make things simpler
+    position_ids += padding_idx + 1
diff --git a/vllm/model_executor/pooling_metadata.py b/vllm/model_executor/pooling_metadata.py
index e6f1ca61dd291..3209879193453 100644
--- a/vllm/model_executor/pooling_metadata.py
+++ b/vllm/model_executor/pooling_metadata.py
@@ -2,12 +2,13 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from dataclasses import dataclass
-from typing import Any
+from typing import Any, Optional
 
 import torch
 
 from vllm.pooling_params import PoolingParams
 from vllm.utils import is_pin_memory_available
+from vllm.v1.pool.metadata import PoolingCursor, build_pooling_cursor
 
 
 class PoolingMetadata:
@@ -23,14 +24,15 @@ class PoolingMetadata:
     """
 
     def __init__(
-        self,
-        seq_groups: list[tuple[list[int], PoolingParams]],
-        seq_data: dict[int, Any],  # Specific data related to sequences
-        prompt_lens: list[int],
-    ) -> None:
+            self,
+            seq_groups: list[tuple[list[int], PoolingParams]],
+            seq_data: dict[int, Any],  # Specific data related to sequences
+            prompt_lens: list[int],
+            pooling_cursor: Optional[PoolingCursor] = None) -> None:
         self.seq_groups = seq_groups
         self.seq_data = seq_data
         self.prompt_lens = prompt_lens
+        self.pooling_cursor: Optional[PoolingCursor] = pooling_cursor
 
     def __repr__(self) -> str:
         return ("PoolingMetadata("
@@ -43,8 +45,17 @@ class PoolingMetadata:
             seq_groups=self.seq_groups[indices],
             seq_data=dict(list(self.seq_data.items())[indices]),
             prompt_lens=self.prompt_lens[indices],
+            pooling_cursor=None
+            if self.pooling_cursor is None else self.pooling_cursor[indices],
         )
 
+    def build_pooling_cursor(self, num_scheduled_tokens: list[int],
+                             device: torch.device):
+        prompt_lens = torch.tensor(self.prompt_lens, device="cpu")
+        self.pooling_cursor = build_pooling_cursor(num_scheduled_tokens,
+                                                   prompt_lens,
+                                                   device=device)
+
 
 @dataclass
 class PoolingTensors:
diff --git a/vllm/v1/pool/metadata.py b/vllm/v1/pool/metadata.py
index 28af720d05fd1..46506d272e90a 100644
--- a/vllm/v1/pool/metadata.py
+++ b/vllm/v1/pool/metadata.py
@@ -6,15 +6,40 @@ from typing import Optional
 import torch
 
 from vllm.pooling_params import PoolingParams
+from vllm.utils import is_pin_memory_available
+
+pin_memory = is_pin_memory_available()
+
+
+@dataclass
+class PoolingCursor:
+    index: list[int]
+    first_token_indices_gpu: torch.Tensor
+    last_token_indices_gpu: torch.Tensor
+    prompt_lens_cpu: torch.Tensor
+    num_scheduled_tokens_cpu: torch.Tensor
+
+    def __getitem__(self, indices: slice):
+        return PoolingCursor(
+            index=self.index[indices],
+            first_token_indices_gpu=self.first_token_indices_gpu[indices],
+            last_token_indices_gpu=self.last_token_indices_gpu[indices],
+            prompt_lens_cpu=self.prompt_lens_cpu[indices],
+            num_scheduled_tokens_cpu=self.num_scheduled_tokens_cpu[indices],
+        )
+
+    def is_partial_prefill(self):
+        return not torch.all(
+            self.prompt_lens_cpu == self.num_scheduled_tokens_cpu)
 
 
 @dataclass
 class PoolingMetadata:
     """Tensors for pooling."""
-
-    prompt_lens: torch.Tensor
+    prompt_lens: torch.Tensor  # CPU Tensor
     prompt_token_ids: Optional[torch.Tensor]
     pooling_params: list[PoolingParams]
+    pooling_cursor: Optional[PoolingCursor] = None
 
     def __getitem__(self, indices: slice):
         return PoolingMetadata(
@@ -22,4 +47,31 @@ class PoolingMetadata:
             prompt_token_ids=None if self.prompt_token_ids is None else
             self.prompt_token_ids[indices],
             pooling_params=self.pooling_params[indices],
+            pooling_cursor=None
+            if self.pooling_cursor is None else self.pooling_cursor[indices],
         )
+
+    def build_pooling_cursor(self, num_scheduled_tokens: list[int],
+                             device: torch.device):
+        self.pooling_cursor = build_pooling_cursor(num_scheduled_tokens,
+                                                   self.prompt_lens, device)
+
+
+def build_pooling_cursor(num_scheduled_tokens: list[int],
+                         prompt_lens: torch.Tensor, device: torch.device):
+    assert len(prompt_lens) == len(num_scheduled_tokens)
+
+    n_seq = len(num_scheduled_tokens)
+    index = list(range(n_seq))
+    num_scheduled_tokens = torch.tensor(num_scheduled_tokens, device="cpu")
+    cumsum = torch.zeros(n_seq + 1,
+                         dtype=torch.int64,
+                         pin_memory=pin_memory,
+                         device="cpu")
+    torch.cumsum(num_scheduled_tokens, dim=0, out=cumsum[1:])
+    cumsum = cumsum.to(device, non_blocking=True)
+    return PoolingCursor(index=index,
+                         first_token_indices_gpu=cumsum[:n_seq],
+                         last_token_indices_gpu=cumsum[1:] - 1,
+                         prompt_lens_cpu=prompt_lens,
+                         num_scheduled_tokens_cpu=num_scheduled_tokens)
diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index 8d08bd7742ffc..154b77ae6301d 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -713,7 +713,7 @@ class InputBatch:
 
         return PoolingMetadata(
             prompt_lens=torch.from_numpy(
-                self.num_prompt_tokens[:self.num_reqs]).to(self.device),
+                self.num_prompt_tokens[:self.num_reqs]),
             prompt_token_ids=self.sampling_metadata.prompt_token_ids,
             pooling_params=pooling_params,
         )
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 7caa873be4442..43a9888d8ea22 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1476,23 +1476,22 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             "Either all or none of the requests in" \
             " a batch must be pooling request"
 
-        extracted_hidden_states = list(
-            torch.split(hidden_states[:num_scheduled_tokens],
-                        num_scheduled_tokens_np.tolist()))
-
+        hidden_states = hidden_states[:num_scheduled_tokens]
         pooling_metadata = self.input_batch.pooling_metadata
+        pooling_metadata.build_pooling_cursor(num_scheduled_tokens_np.tolist(),
+                                              device=hidden_states.device)
+        seq_lens_cpu = self.seq_lens_cpu[:self.input_batch.num_reqs]
 
+        # Pooling models D2H & synchronize occurs in pooler.py:build_output
         raw_pooler_output = self.model.pooler(
-            hidden_states=extracted_hidden_states,
-            pooling_metadata=pooling_metadata)
+            hidden_states=hidden_states, pooling_metadata=pooling_metadata)
 
         pooler_output: list[Optional[torch.Tensor]] = []
-        seq_lens = self.seq_lens[:self.input_batch.num_reqs]
         for raw_output, seq_len, prompt_len in zip(
-                raw_pooler_output, seq_lens, pooling_metadata.prompt_lens):
+                raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens):
 
             if seq_len == prompt_len:
-                pooler_output.append(raw_output.data.cpu())
+                pooler_output.append(raw_output.data)
             else:
                 pooler_output.append(None)
 
@@ -2524,13 +2523,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         assert sum(num_scheduled_tokens_list) == num_tokens
         assert len(num_scheduled_tokens_list) == num_reqs
 
-        hidden_states_list = list(
-            torch.split(hidden_states, num_scheduled_tokens_list))
         req_num_tokens = num_tokens // num_reqs
 
         dummy_prompt_lens = torch.tensor(
-            [h.shape[0] for h in hidden_states_list],
-            device=self.device,
+            num_scheduled_tokens_list,
+            device="cpu",
         )
         dummy_token_ids = torch.zeros((num_reqs, req_num_tokens),
                                       dtype=torch.int32,
@@ -2547,8 +2544,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             pooling_params=[dummy_pooling_params] * num_reqs,
         )
 
+        dummy_metadata.build_pooling_cursor(num_scheduled_tokens_list,
+                                            device=hidden_states.device)
+
         try:
-            return model.pooler(hidden_states=hidden_states_list,
+            return model.pooler(hidden_states=hidden_states,
                                 pooling_metadata=dummy_metadata)
         except RuntimeError as e:
             if 'out of memory' in str(e):
@@ -3316,10 +3316,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
         dummy_block_table = torch.zeros((num_reqs, 1),
                                         dtype=torch.int32,
-                                        device=self.device)
+                                        pin_memory=self.pin_memory,
+                                        device="cpu").to(self.device,
+                                                         non_blocking=True)
         dummy_slot_mapping = torch.zeros((total_num_scheduled_tokens, ),
                                          dtype=torch.int32,
-                                         device=self.device)
+                                         pin_memory=self.pin_memory,
+                                         device="cpu").to(self.device,
+                                                          non_blocking=True)
 
         group_metadata = dict[str, tuple[CommonAttentionMetadata, Any]]()
diff --git a/vllm/worker/pooling_model_runner.py b/vllm/worker/pooling_model_runner.py
index e49783ad9b244..8d8d9b4d0503f 100644
--- a/vllm/worker/pooling_model_runner.py
+++ b/vllm/worker/pooling_model_runner.py
@@ -149,9 +149,16 @@ class PoolingModelRunner(
         if not self.is_driver_worker:
             return []
 
+        pooling_metadata = model_input.pooling_metadata
+        assert pooling_metadata is not None
+
+        pooling_metadata.build_pooling_cursor(
+            num_scheduled_tokens=pooling_metadata.prompt_lens,
+            device=hidden_or_intermediate_states.device)
+
         return [
             self.model.pooler(hidden_states=hidden_or_intermediate_states,
-                              pooling_metadata=model_input.pooling_metadata)
+                              pooling_metadata=pooling_metadata)
         ]
 
     def make_model_input_from_broadcasted_tensor_dict(
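Illustrative note (not part of the patch): the core of the change is that pooling no longer splits hidden_states into per-request tensors; build_pooling_cursor instead precomputes flat row indices once per batch. Below is a minimal, self-contained sketch of that indexing scheme, assuming full prefills (prompt_lens == num_scheduled_tokens) and keeping everything on the CPU; the real helper additionally pins memory and moves the index tensors to the GPU with non_blocking=True.

import torch


def sketch_pooling_cursor(num_scheduled_tokens: list[int]):
    # Requests are concatenated along dim 0 of hidden_states, so request i
    # occupies rows [cumsum[i], cumsum[i + 1]).
    n_seq = len(num_scheduled_tokens)
    lens = torch.tensor(num_scheduled_tokens)
    cumsum = torch.zeros(n_seq + 1, dtype=torch.int64)
    torch.cumsum(lens, dim=0, out=cumsum[1:])
    first_token_indices = cumsum[:n_seq]  # rows gathered by CLS pooling
    last_token_indices = cumsum[1:] - 1   # rows gathered by last-token pooling
    return first_token_indices, last_token_indices


# Three requests of lengths 3, 5 and 2 packed into 10 rows.
hidden_states = torch.arange(10, dtype=torch.float32).unsqueeze(1)
first, last = sketch_pooling_cursor([3, 5, 2])
print(first.tolist())                           # [0, 3, 8]
print(last.tolist())                            # [2, 7, 9]
print(hidden_states[first].flatten().tolist())  # [0.0, 3.0, 8.0]
print(hidden_states[last].flatten().tolist())   # [2.0, 7.0, 9.0]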
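Also for illustration only: build_output now issues the device-to-host copies for the whole batch with non_blocking=True and synchronizes the current stream once, instead of one blocking .cpu() per request in the model runner. A generic sketch of that pattern, using torch.cuda.current_stream() in place of vLLM's current_stream() helper:

import torch


def copy_outputs_to_cpu(outputs: list[torch.Tensor]) -> list[torch.Tensor]:
    # Issue every device-to-host copy first...
    cpu_outputs = [t.to("cpu", non_blocking=True) for t in outputs]
    # ...then wait on the stream once, so all copies have completed before
    # the CPU tensors are read.
    torch.cuda.current_stream().synchronize()
    return cpu_outputs


if torch.cuda.is_available():
    pooled = [torch.randn(8, device="cuda") for _ in range(4)]
    print([t.device for t in copy_outputs_to_cpu(pooled)])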