[Performance] V1 Pooling Models E2E Performance Optimization (#23162)
Signed-off-by: wang.yuqi <noooop@126.com>
parent 5cc54f7c5b
commit d70a16625d
@@ -19,7 +19,8 @@ from vllm.model_executor.pooling_metadata import PoolingTensors
 from vllm.pooling_params import PoolingParams
 from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
 from vllm.tasks import PoolingTask
-from vllm.utils import resolve_obj_by_qualname
+from vllm.utils import current_stream, resolve_obj_by_qualname
+from vllm.v1.pool.metadata import PoolingCursor
 from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata

 PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata]
@@ -205,6 +206,13 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):

 def build_output(
         all_data: Union[torch.Tensor, list[torch.Tensor]], ) -> PoolerOutput:
+    # Pooling models D2H & synchronize occurs here
+    if isinstance(all_data, list):
+        all_data = [d.to("cpu", non_blocking=True) for d in all_data]
+    else:
+        all_data = all_data.to("cpu", non_blocking=True)
+    current_stream().synchronize()
+
     all_outputs = [PoolingSequenceGroupOutput(data) for data in all_data]
     return PoolerOutput(outputs=all_outputs)

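The change above batches the device-to-host transfer for pooling outputs: every copy is queued with non_blocking=True and the stream is synchronized exactly once, instead of paying one blocking .cpu() per request later in the model runner. A minimal standalone sketch of the same pattern (CUDA assumed; this uses torch.cuda.current_stream directly rather than vLLM's current_stream helper):

import torch

def copy_outputs_to_cpu(outputs: list[torch.Tensor]) -> list[torch.Tensor]:
    # Queue all D2H copies on the current stream without blocking per tensor...
    cpu_outputs = [t.to("cpu", non_blocking=True) for t in outputs]
    # ...then synchronize once, after all copies are in flight.
    torch.cuda.current_stream().synchronize()
    return cpu_outputs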
@@ -231,40 +239,21 @@ class PoolingMethod(nn.Module, ABC):
     def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
         return PoolingParamsUpdate()

-    @abstractmethod
-    def forward_one(
-        self,
-        hidden_states: torch.Tensor,
-        prompt_len: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        """
-        Note:
-            `prompt_len=None` means `prompt_len=len(hidden_states)`.
-        """
-        raise NotImplementedError
-
     @abstractmethod
     def forward_all(
         self,
         hidden_states: torch.Tensor,
-        prompt_lens: torch.Tensor,
+        pooling_cursor: PoolingCursor,
     ) -> Union[list[torch.Tensor], torch.Tensor]:
         raise NotImplementedError

     def forward(
         self,
-        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
+        hidden_states: torch.Tensor,
         pooling_metadata: PoolingMetadata,
     ) -> Union[list[torch.Tensor], torch.Tensor]:
-        prompt_lens = get_prompt_lens(hidden_states, pooling_metadata)
-
-        if isinstance(hidden_states, list):
-            return [
-                self.forward_one(h, prompt_len)
-                for h, prompt_len in zip(hidden_states, prompt_lens)
-            ]
-
-        return self.forward_all(hidden_states, prompt_lens)
+        pooling_cursor = pooling_metadata.pooling_cursor
+        return self.forward_all(hidden_states, pooling_cursor)


 class CLSPool(PoolingMethod):
@@ -272,24 +261,15 @@ class CLSPool(PoolingMethod):
     def get_supported_tasks(self) -> Set[PoolingTask]:
         return {"encode", "embed", "classify", "score"}

-    def forward_one(
-        self,
-        hidden_states: torch.Tensor,
-        prompt_len: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        assert prompt_len is None or prompt_len == hidden_states.shape[0], \
-            "partial prefill not supported with CLS pooling"
-
-        return hidden_states[0]
-
     def forward_all(
         self,
         hidden_states: torch.Tensor,
-        prompt_lens: torch.Tensor,
+        pooling_cursor: PoolingCursor,
     ) -> Union[list[torch.Tensor], torch.Tensor]:
-        first_token_flat_indices = torch.zeros_like(prompt_lens)
-        first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1]
-        return hidden_states[first_token_flat_indices]
+        assert not pooling_cursor.is_partial_prefill(), \
+            "partial prefill not supported with CLS pooling"
+
+        return hidden_states[pooling_cursor.first_token_indices_gpu]


 class LastPool(PoolingMethod):
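For intuition, first_token_indices_gpu holds the flat index of each request's first (CLS) token inside the flattened [total_tokens, hidden_size] tensor, so CLS pooling becomes a single gather. A toy illustration with invented sizes, not taken from the patch:

import torch

hidden_states = torch.randn(10, 8)               # 3 prompts of lengths 3, 5, 2, flattened
first_token_indices = torch.tensor([0, 3, 8])    # what the pooling cursor would carry
cls_states = hidden_states[first_token_indices]  # shape [3, 8], one row per prompt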
@@ -297,20 +277,12 @@ class LastPool(PoolingMethod):
     def get_supported_tasks(self) -> Set[PoolingTask]:
         return {"encode", "embed", "classify", "score"}

-    def forward_one(
-        self,
-        hidden_states: torch.Tensor,
-        prompt_len: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        return hidden_states[-1]
-
     def forward_all(
         self,
         hidden_states: torch.Tensor,
-        prompt_lens: torch.Tensor,
+        pooling_cursor: PoolingCursor,
     ) -> Union[list[torch.Tensor], torch.Tensor]:
-        last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1
-        return hidden_states[last_token_flat_indices]
+        return hidden_states[pooling_cursor.last_token_indices_gpu]


 class AllPool(PoolingMethod):
@@ -318,22 +290,19 @@ class AllPool(PoolingMethod):
     def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"encode"}

-    def forward_one(
-        self,
-        hidden_states: torch.Tensor,
-        prompt_len: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        assert prompt_len is None or prompt_len == hidden_states.shape[0], \
-            "partial prefill not supported with ALL pooling"
-
-        return hidden_states
-
     def forward_all(
         self,
         hidden_states: torch.Tensor,
-        prompt_lens: torch.Tensor,
+        pooling_cursor: PoolingCursor,
     ) -> Union[list[torch.Tensor], torch.Tensor]:
-        return list(hidden_states.split_with_sizes(prompt_lens.tolist()))
+
+        assert not pooling_cursor.is_partial_prefill(), \
+            "partial prefill not supported with ALL pooling"
+
+        hidden_states_lst = list(
+            hidden_states.split(
+                pooling_cursor.num_scheduled_tokens_cpu.tolist()))
+        return [hidden_states_lst[i] for i in pooling_cursor.index]


 class MeanPool(PoolingMethod):
@@ -341,31 +310,25 @@ class MeanPool(PoolingMethod):
     def get_supported_tasks(self) -> Set[PoolingTask]:
         return {"encode", "embed", "classify", "score"}

-    def forward_one(
-        self,
-        hidden_states: torch.Tensor,
-        prompt_len: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        assert prompt_len is None or prompt_len == hidden_states.shape[0], \
-            "partial prefill not supported with MEAN pooling"
-
-        return hidden_states.mean(dim=0, dtype=torch.float32)
-
     def forward_all(
         self,
         hidden_states: torch.Tensor,
-        prompt_lens: torch.Tensor,
+        pooling_cursor: PoolingCursor,
     ) -> Union[list[torch.Tensor], torch.Tensor]:
+
+        assert not pooling_cursor.is_partial_prefill(), \
+            "partial prefill not supported with MEAN pooling"
+
+        prompt_lens = pooling_cursor.prompt_lens_cpu.to(hidden_states.device,
+                                                        non_blocking=True)
+
         # Use float32 for torch.cumsum in MeanPool,
         # otherwise precision will be lost significantly.
         cumsum = torch.cumsum(hidden_states, dim=0, dtype=torch.float32)

-        start_indices = torch.cat([
-            torch.tensor([0], device=hidden_states.device),
-            torch.cumsum(prompt_lens[:-1], dim=0)
-        ])
-        end_indices = torch.cumsum(prompt_lens, dim=0)
-        return (cumsum[end_indices - 1] - cumsum[start_indices] +
+        start_indices = pooling_cursor.first_token_indices_gpu
+        end_indices = pooling_cursor.last_token_indices_gpu
+        return (cumsum[end_indices] - cumsum[start_indices] +
                 hidden_states[start_indices]) / prompt_lens.unsqueeze(1)

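The arithmetic behind the new MeanPool body: with an inclusive cumulative sum over the flattened tokens, the sum of request i's rows is cumsum[last_i] - cumsum[first_i] + hidden_states[first_i], so per-request means need no Python loop and no extra cumsum over prompt_lens on the GPU. A small worked example with invented values:

import torch

hidden_states = torch.arange(10, dtype=torch.float32).reshape(10, 1)  # lengths 3, 5, 2
first = torch.tensor([0, 3, 8])
last = torch.tensor([2, 7, 9])
prompt_lens = torch.tensor([3.0, 5.0, 2.0])

cumsum = torch.cumsum(hidden_states, dim=0, dtype=torch.float32)
means = (cumsum[last] - cumsum[first] + hidden_states[first]) / prompt_lens.unsqueeze(1)
# tensor([[1.0], [5.0], [8.5]]): the per-request averages of the three segments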
@@ -477,6 +440,10 @@ class EmbeddingPoolerHead(PoolerHead):

         pooling_params = get_pooling_params(pooling_metadata)

+        if isinstance(pooled_data, list):
+            pooled_data = torch.stack(pooled_data)
+        # pooled_data shape: [batchsize, embedding_dimension]
+
         # for matryoshka representation
         dimensions_list = [
             pooling_param.dimensions for pooling_param in pooling_params
@@ -667,6 +634,10 @@ class ClassifierPooler(Pooler):
     ) -> PoolerOutput:
         pooled_data = self.pooling(hidden_states, pooling_metadata)

+        if isinstance(pooled_data, list):
+            pooled_data = torch.stack(pooled_data)
+        # pooled_data shape: [batchsize, hidden_size]
+
         if self.classifier is not None:
             # apply classifier once on the full batch if possible
             if isinstance(pooled_data, torch.Tensor):
@@ -717,12 +688,6 @@ class DispatchPooler(Pooler):
     ) -> PoolerOutput:
         poolers_by_task = self.poolers_by_task

-        if isinstance(hidden_states, list):
-            hidden_states_lst = hidden_states
-        else:
-            prompt_lens = get_prompt_lens(hidden_states, pooling_metadata)
-            hidden_states_lst = list(hidden_states.split(prompt_lens.tolist()))
-
         outputs = list[PoolingSequenceGroupOutput]()
         offset = 0
         for task, group in groupby(get_tasks(pooling_metadata)):
@@ -733,7 +698,7 @@ class DispatchPooler(Pooler):

             num_items = len(list(group))
             group_output: PoolerOutput = pooler(
-                hidden_states_lst[offset:offset + num_items],
+                hidden_states,
                 pooling_metadata[offset:offset + num_items],
             )

@@ -528,9 +528,9 @@ def _encode_token_type_ids(input_ids: torch.Tensor,

 def _decode_token_type_ids(input_ids: torch.Tensor) -> torch.Tensor:

-    ids_mask = torch.ones(input_ids.shape,
-                          dtype=torch.int32,
-                          device=input_ids.device) << TOKEN_TYPE_SHIFT
+    ids_mask = torch.ones_like(input_ids,
+                               dtype=torch.int32,
+                               device=input_ids.device) << TOKEN_TYPE_SHIFT
     tokens_mask = ids_mask.bitwise_not()

     token_type_ids = input_ids.bitwise_and(ids_mask) >> TOKEN_TYPE_SHIFT
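For context, the token-type ids for cross-encoder inputs are packed into spare high bits of input_ids (via TOKEN_TYPE_SHIFT) so that a single tensor carries both, and _decode_token_type_ids masks them back out as shown above. A rough sketch of the round trip, with an assumed shift of 30 purely for illustration:

import torch

TOKEN_TYPE_SHIFT = 30  # assumed here; the real constant lives in the model file

input_ids = torch.tensor([101, 2023, 102, 2003, 102], dtype=torch.int32)
token_type_ids = torch.tensor([0, 0, 0, 1, 1], dtype=torch.int32)

packed = input_ids | (token_type_ids << TOKEN_TYPE_SHIFT)     # encode

ids_mask = torch.ones_like(packed) << TOKEN_TYPE_SHIFT        # decode, as in the diff
recovered_types = packed.bitwise_and(ids_mask) >> TOKEN_TYPE_SHIFT
recovered_ids = packed.bitwise_and(ids_mask.bitwise_not())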
@@ -9,7 +9,6 @@ from torch import nn
 from transformers import RobertaConfig

 from vllm.config import VllmConfig
-from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool,
                                                DispatchPooler, Pooler)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -100,7 +99,7 @@ class RobertaEmbeddingModel(BertEmbeddingModel):

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__(vllm_config=vllm_config, prefix=prefix)
-        self.padding_idx = vllm_config.model_config.hf_config.pad_token_id
+        self.padding_idx: int = vllm_config.model_config.hf_config.pad_token_id

     def forward(
         self,
@@ -178,7 +177,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
-        self.padding_idx = vllm_config.model_config.hf_config.pad_token_id
+        self.padding_idx: int = vllm_config.model_config.hf_config.pad_token_id

         self.num_labels = config.num_labels
         self.roberta = BertModel(vllm_config=vllm_config,
@@ -233,58 +232,14 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
                             intermediate_tensors=intermediate_tensors)


-# Adapted from transformers
-def create_position_ids_from_input_ids(input_ids,
-                                       padding_idx,
-                                       past_key_values_length=0):
-    """
-    Replace non-padding symbols with their position numbers.
-    Position numbers begin at padding_idx+1. Padding symbols
-    are ignored. This is modified from fairseq's `utils.make_positions`.
-
-    Args:
-        x: torch.Tensor x:
-
-    Returns: torch.Tensor
-    """
-    # The series of casts and type-conversions here are carefully
-    # balanced to both work with ONNX export and XLA.
-    mask = input_ids.ne(padding_idx).int()
-
-    incremental_indices = (torch.cumsum(mask, dim=0).type_as(mask) +
-                           past_key_values_length) * mask
-
-    return incremental_indices.long() + padding_idx
-
-
 def replace_roberta_positions(input_ids: torch.Tensor,
                               position_ids: torch.Tensor,
                               padding_idx: int) -> None:
-
-    seq_lens: Optional[torch.Tensor] = None
-    attn_metadata = get_forward_context().attn_metadata
-    if attn_metadata is not None:  # can be None during warmup
-        if isinstance(attn_metadata, dict):
-            attn_metadata = next(iter(attn_metadata.values()))
-        # TODO: remove "seq_lens_tensor" after V0 is removed
-        seq_lens = getattr(attn_metadata, "seq_lens_tensor",
-                           getattr(attn_metadata, "seq_lens", None))
-
-    if seq_lens is not None:
-        assert isinstance(seq_lens, torch.Tensor)
-
-        # Replace position ids because in RoBERTa models
-        # they have to start at padding_idx + 1 and ignore
-        # existing padding tokens
-        # References:
-        # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
-        # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
-        token_list = torch.split(input_ids[:torch.sum(seq_lens)],
-                                 seq_lens.tolist())
-
-        offset = 0
-        for tokens in token_list:
-            length = tokens.shape[0]
-            position_ids[offset:offset+length] = \
-                create_position_ids_from_input_ids(tokens, padding_idx)
-            offset = offset + length
+    # Replace position ids because in RoBERTa models
+    # they have to start at padding_idx + 1 and ignore
+    # existing padding tokens
+    # References:
+    # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
+    # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
+    # vllm does not use padding tokens, let's make things simpler
+    position_ids += padding_idx + 1
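The simplification works because vLLM never feeds padding tokens through the model, so the transformers-style scan that skips padding is unnecessary: RoBERTa position ids are just the usual 0..len-1 positions shifted so they start at padding_idx + 1. Roughly, for a single 6-token prompt (padding_idx is 1 in common RoBERTa configs; an illustrative value, not taken from the patch):

import torch

padding_idx = 1
position_ids = torch.arange(6)          # positions as vLLM builds them
position_ids = position_ids + padding_idx + 1
# tensor([2, 3, 4, 5, 6, 7]), matching HF RoBERTa, whose positions start at padding_idx + 1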
@@ -2,12 +2,13 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 from dataclasses import dataclass
-from typing import Any
+from typing import Any, Optional

 import torch

 from vllm.pooling_params import PoolingParams
 from vllm.utils import is_pin_memory_available
+from vllm.v1.pool.metadata import PoolingCursor, build_pooling_cursor


 class PoolingMetadata:
@@ -23,14 +24,15 @@ class PoolingMetadata:
     """

     def __init__(
-        self,
-        seq_groups: list[tuple[list[int], PoolingParams]],
-        seq_data: dict[int, Any],  # Specific data related to sequences
-        prompt_lens: list[int],
-    ) -> None:
+            self,
+            seq_groups: list[tuple[list[int], PoolingParams]],
+            seq_data: dict[int, Any],  # Specific data related to sequences
+            prompt_lens: list[int],
+            pooling_cursor: Optional[PoolingCursor] = None) -> None:
         self.seq_groups = seq_groups
         self.seq_data = seq_data
         self.prompt_lens = prompt_lens
+        self.pooling_cursor: Optional[PoolingCursor] = pooling_cursor

     def __repr__(self) -> str:
         return ("PoolingMetadata("
@@ -43,8 +45,17 @@ class PoolingMetadata:
             seq_groups=self.seq_groups[indices],
             seq_data=dict(list(self.seq_data.items())[indices]),
             prompt_lens=self.prompt_lens[indices],
+            pooling_cursor=None
+            if self.pooling_cursor is None else self.pooling_cursor[indices],
         )

+    def build_pooling_cursor(self, num_scheduled_tokens: list[int],
+                             device: torch.device):
+        prompt_lens = torch.tensor(self.prompt_lens, device="cpu")
+        self.pooling_cursor = build_pooling_cursor(num_scheduled_tokens,
+                                                   prompt_lens,
+                                                   device=device)
+

 @dataclass
 class PoolingTensors:
@@ -6,15 +6,40 @@ from typing import Optional
 import torch

 from vllm.pooling_params import PoolingParams
+from vllm.utils import is_pin_memory_available
+
+pin_memory = is_pin_memory_available()
+
+
+@dataclass
+class PoolingCursor:
+    index: list[int]
+    first_token_indices_gpu: torch.Tensor
+    last_token_indices_gpu: torch.Tensor
+    prompt_lens_cpu: torch.Tensor
+    num_scheduled_tokens_cpu: torch.Tensor
+
+    def __getitem__(self, indices: slice):
+        return PoolingCursor(
+            index=self.index[indices],
+            first_token_indices_gpu=self.first_token_indices_gpu[indices],
+            last_token_indices_gpu=self.last_token_indices_gpu[indices],
+            prompt_lens_cpu=self.prompt_lens_cpu[indices],
+            num_scheduled_tokens_cpu=self.num_scheduled_tokens_cpu[indices],
+        )
+
+    def is_partial_prefill(self):
+        return not torch.all(
+            self.prompt_lens_cpu == self.num_scheduled_tokens_cpu)


 @dataclass
 class PoolingMetadata:
     """Tensors for pooling."""

-    prompt_lens: torch.Tensor
+    prompt_lens: torch.Tensor  # CPU Tensor
     prompt_token_ids: Optional[torch.Tensor]
     pooling_params: list[PoolingParams]
+    pooling_cursor: Optional[PoolingCursor] = None

     def __getitem__(self, indices: slice):
         return PoolingMetadata(
@@ -22,4 +47,31 @@ class PoolingMetadata:
             prompt_token_ids=None if self.prompt_token_ids is None else
             self.prompt_token_ids[indices],
             pooling_params=self.pooling_params[indices],
+            pooling_cursor=None
+            if self.pooling_cursor is None else self.pooling_cursor[indices],
         )
+
+    def build_pooling_cursor(self, num_scheduled_tokens: list[int],
+                             device: torch.device):
+        self.pooling_cursor = build_pooling_cursor(num_scheduled_tokens,
+                                                   self.prompt_lens, device)
+
+
+def build_pooling_cursor(num_scheduled_tokens: list[int],
+                         prompt_lens: torch.Tensor, device: torch.device):
+    assert len(prompt_lens) == len(num_scheduled_tokens)
+
+    n_seq = len(num_scheduled_tokens)
+    index = list(range(n_seq))
+    num_scheduled_tokens = torch.tensor(num_scheduled_tokens, device="cpu")
+    cumsum = torch.zeros(n_seq + 1,
+                         dtype=torch.int64,
+                         pin_memory=pin_memory,
+                         device="cpu")
+    torch.cumsum(num_scheduled_tokens, dim=0, out=cumsum[1:])
+    cumsum = cumsum.to(device, non_blocking=True)
+    return PoolingCursor(index=index,
+                         first_token_indices_gpu=cumsum[:n_seq],
+                         last_token_indices_gpu=cumsum[1:] - 1,
+                         prompt_lens_cpu=prompt_lens,
+                         num_scheduled_tokens_cpu=num_scheduled_tokens)
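A worked example of what build_pooling_cursor computes (CPU-only here for illustration): for three requests with 3, 5 and 2 scheduled tokens, the exclusive cumulative sum [0, 3, 8, 10] gives first-token indices [0, 3, 8] and last-token indices [2, 7, 9] into the flattened hidden states, and a request counts as a partial prefill whenever its scheduled tokens fall short of its prompt length.

import torch

num_scheduled_tokens = torch.tensor([3, 5, 2])
cumsum = torch.zeros(4, dtype=torch.int64)
torch.cumsum(num_scheduled_tokens, dim=0, out=cumsum[1:])

first_token_indices = cumsum[:3]       # tensor([0, 3, 8])
last_token_indices = cumsum[1:] - 1    # tensor([2, 7, 9])

prompt_lens = torch.tensor([3, 8, 2])  # request 1 still has prompt tokens left
is_partial_prefill = not torch.all(prompt_lens == num_scheduled_tokens)  # True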
@@ -713,7 +713,7 @@ class InputBatch:

         return PoolingMetadata(
             prompt_lens=torch.from_numpy(
-                self.num_prompt_tokens[:self.num_reqs]).to(self.device),
+                self.num_prompt_tokens[:self.num_reqs]),
             prompt_token_ids=self.sampling_metadata.prompt_token_ids,
             pooling_params=pooling_params,
         )
@@ -1476,23 +1476,22 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             "Either all or none of the requests in" \
             " a batch must be pooling request"

-        extracted_hidden_states = list(
-            torch.split(hidden_states[:num_scheduled_tokens],
-                        num_scheduled_tokens_np.tolist()))
-
+        hidden_states = hidden_states[:num_scheduled_tokens]
         pooling_metadata = self.input_batch.pooling_metadata
+        pooling_metadata.build_pooling_cursor(num_scheduled_tokens_np.tolist(),
+                                              device=hidden_states.device)
+        seq_lens_cpu = self.seq_lens_cpu[:self.input_batch.num_reqs]

+        # Pooling models D2H & synchronize occurs in pooler.py:build_output
         raw_pooler_output = self.model.pooler(
-            hidden_states=extracted_hidden_states,
-            pooling_metadata=pooling_metadata)
+            hidden_states=hidden_states, pooling_metadata=pooling_metadata)

         pooler_output: list[Optional[torch.Tensor]] = []
-        seq_lens = self.seq_lens[:self.input_batch.num_reqs]
         for raw_output, seq_len, prompt_len in zip(
-                raw_pooler_output, seq_lens, pooling_metadata.prompt_lens):
+                raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens):

             if seq_len == prompt_len:
-                pooler_output.append(raw_output.data.cpu())
+                pooler_output.append(raw_output.data)
             else:
                 pooler_output.append(None)

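With chunked prefill, a pooling request may have seen only part of its prompt in this step, so it gets a None placeholder and only fully prefilled requests emit a pooled tensor. A toy version of the selection loop above, with placeholder strings standing in for pooled outputs:

raw_pooler_output = ["emb0", "emb1", "emb2"]   # stand-ins for pooled tensors
seq_lens_cpu = [3, 4, 2]                       # tokens computed so far per request
prompt_lens = [3, 8, 2]                        # full prompt lengths
pooler_output = [raw if seq_len == prompt_len else None
                 for raw, seq_len, prompt_len
                 in zip(raw_pooler_output, seq_lens_cpu, prompt_lens)]
# ['emb0', None, 'emb2']: request 1 is still mid-prefill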
@@ -2524,13 +2523,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         assert sum(num_scheduled_tokens_list) == num_tokens
         assert len(num_scheduled_tokens_list) == num_reqs

-        hidden_states_list = list(
-            torch.split(hidden_states, num_scheduled_tokens_list))
         req_num_tokens = num_tokens // num_reqs

         dummy_prompt_lens = torch.tensor(
-            [h.shape[0] for h in hidden_states_list],
-            device=self.device,
+            num_scheduled_tokens_list,
+            device="cpu",
         )
         dummy_token_ids = torch.zeros((num_reqs, req_num_tokens),
                                       dtype=torch.int32,
@@ -2547,8 +2544,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             pooling_params=[dummy_pooling_params] * num_reqs,
         )

+        dummy_metadata.build_pooling_cursor(num_scheduled_tokens_list,
+                                            device=hidden_states.device)
+
         try:
-            return model.pooler(hidden_states=hidden_states_list,
+            return model.pooler(hidden_states=hidden_states,
                                 pooling_metadata=dummy_metadata)
         except RuntimeError as e:
             if 'out of memory' in str(e):
@@ -3316,10 +3316,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

         dummy_block_table = torch.zeros((num_reqs, 1),
                                         dtype=torch.int32,
-                                        device=self.device)
+                                        pin_memory=self.pin_memory,
+                                        device="cpu").to(self.device,
+                                                         non_blocking=True)
         dummy_slot_mapping = torch.zeros((total_num_scheduled_tokens, ),
                                          dtype=torch.int32,
-                                         device=self.device)
+                                         pin_memory=self.pin_memory,
+                                         device="cpu").to(self.device,
+                                                          non_blocking=True)

         group_metadata = dict[str, tuple[CommonAttentionMetadata, Any]]()

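The dummy block table and slot mapping now follow the usual pinned-host staging pattern: allocate on page-locked CPU memory and copy to the device with non_blocking=True, rather than allocating directly on the GPU during the dummy run. In isolation (CUDA assumed, shapes invented):

import torch

dummy = torch.zeros((8, 1), dtype=torch.int32, pin_memory=True)  # page-locked host buffer
dummy = dummy.to("cuda", non_blocking=True)                      # asynchronous H2D copy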
@@ -149,9 +149,16 @@ class PoolingModelRunner(
         if not self.is_driver_worker:
             return []

+        pooling_metadata = model_input.pooling_metadata
+        assert pooling_metadata is not None
+
+        pooling_metadata.build_pooling_cursor(
+            num_scheduled_tokens=pooling_metadata.prompt_lens,
+            device=hidden_or_intermediate_states.device)
+
         return [
             self.model.pooler(hidden_states=hidden_or_intermediate_states,
-                              pooling_metadata=model_input.pooling_metadata)
+                              pooling_metadata=pooling_metadata)
         ]

     def make_model_input_from_broadcasted_tensor_dict(