# SPDX-License-Identifier: Apache-2.0

import dataclasses
from typing import Any, Dict, List, Optional, Tuple, Type, Union

import torch

from vllm.config import VllmConfig
from vllm.distributed import get_pp_group
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.multimodal import MultiModalKwargs
from vllm.pooling_params import PoolingParams
from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData,
                           SequenceGroupMetadata)
from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPU,
                                      ModelInputForGPUBuilder)

logger = init_logger(__name__)


@dataclasses.dataclass(frozen=True)
class ModelInputForGPUWithPoolingMetadata(ModelInputForGPU):
    """
    Used by the PoolingModelRunner.
    """
    pooling_metadata: Optional["PoolingMetadata"] = None


class PoolingModelRunner(
        GPUModelRunnerBase[ModelInputForGPUWithPoolingMetadata]):
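    """GPU model runner for pooling models.

    Instead of sampling next tokens, this runner applies the model's pooler
    to the final hidden states (e.g. for embedding models).
    """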
|
|
_model_input_cls: Type[ModelInputForGPUWithPoolingMetadata] = (
|
|
ModelInputForGPUWithPoolingMetadata)
|
|
_builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder
|
|
|
|
def __init__(
|
|
self,
|
|
vllm_config: VllmConfig,
|
|
kv_cache_dtype: Optional[str] = "auto",
|
|
is_driver_worker: bool = False,
|
|
):
|
|
super().__init__(vllm_config=vllm_config,
|
|
kv_cache_dtype=kv_cache_dtype,
|
|
is_driver_worker=is_driver_worker)
|
|
|
|
@torch.inference_mode()
|
|
def execute_model(
|
|
self,
|
|
model_input: ModelInputForGPUWithPoolingMetadata,
|
|
kv_caches: List[torch.Tensor],
|
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
|
num_steps: int = 1,
|
|
) -> Optional[Union[List[PoolerOutput], IntermediateTensors]]:
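        """Run a single forward pass and return the pooler outputs.

        Returns intermediate tensors instead when this is not the last
        pipeline-parallel stage, and an empty list on non-driver workers.
        """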
        if num_steps > 1:
            raise ValueError(
                "PoolingModelRunner does not support multi-step execution.")

        if self.lora_config:
            assert model_input.lora_requests is not None
            assert model_input.lora_mapping is not None
            self.set_active_loras(model_input.lora_requests,
                                  model_input.lora_mapping)

        if self.prompt_adapter_config:
            assert model_input.prompt_adapter_requests is not None
            assert model_input.prompt_adapter_mapping is not None
            self.set_active_prompt_adapters(
                model_input.prompt_adapter_requests,
                model_input.prompt_adapter_mapping)

        # Currently, CUDA graphs are only supported by the decode phase.
        assert model_input.attn_metadata is not None
        prefill_meta = model_input.attn_metadata.prefill_metadata
        decode_meta = model_input.attn_metadata.decode_metadata
        virtual_engine = model_input.virtual_engine
        # Pooling models are also (ab)used to integrate non-text models that
        # are not autoregressive (e.g. PrithviGeospatialMAE).
        # These models might not use attention and do not really have a
        # prefill and decode phase. The model input is processed in one shot,
        # and both decode_metadata and prefill_metadata would be None for
        # such models. See the PlaceholderAttentionMetadata class.
        # TODO: Figure out if cuda_graph is of any use for these models and
        # explore how to leverage it.
        if (prefill_meta is None and decode_meta is not None
                and decode_meta.use_cuda_graph):
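            # Captured graph runners are keyed by (batch size, whether the
            # input is embeddings rather than token ids).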
            if model_input.inputs_embeds is None:
                assert model_input.input_tokens is not None
                graph_batch_size = model_input.input_tokens.shape[0]
                model_executable = (
                    self.graph_runners[model_input.virtual_engine][(
                        graph_batch_size, False)])
            else:
                graph_batch_size = model_input.inputs_embeds.shape[0]
                model_executable = (
                    self.graph_runners[model_input.virtual_engine][(
                        graph_batch_size, True)])
        else:
            model_executable = self.model

        multi_modal_kwargs = model_input.multi_modal_kwargs or {}
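        # Models with internal state (e.g. Mamba-style layers) keep
        # per-request state across steps, so pass the request bookkeeping
        # kwargs through to the model.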
        seqlen_agnostic_kwargs = {
            "finished_requests_ids": model_input.finished_requests_ids,
            "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
        } if self.has_inner_state else {}
        if (self.observability_config is not None
                and self.observability_config.collect_model_forward_time):
            model_forward_start = torch.cuda.Event(enable_timing=True)
            model_forward_end = torch.cuda.Event(enable_timing=True)
            model_forward_start.record()
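
        # Cross-encoder models take token type ids alongside the input ids.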
        cross_enc_kwargs = {}
        if model_input.token_types is not None:
            cross_enc_kwargs["token_type_ids"] = model_input.token_types

        with set_forward_context(model_input.attn_metadata, self.vllm_config,
                                 virtual_engine):
            hidden_or_intermediate_states = model_executable(
                input_ids=model_input.input_tokens,
                positions=model_input.input_positions,
                intermediate_tensors=intermediate_tensors,
                **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
                                             device=self.device),
                **cross_enc_kwargs,
                **seqlen_agnostic_kwargs)

        if (self.observability_config is not None
                and self.observability_config.collect_model_forward_time):
            model_forward_end.record()

        # Only perform pooling in the last pipeline stage.
        if not get_pp_group().is_last_rank:
            if (self.is_driver_worker
                    and hidden_or_intermediate_states is not None
                    and isinstance(hidden_or_intermediate_states,
                                   IntermediateTensors)
                    and self.observability_config is not None
                    and self.observability_config.collect_model_forward_time):
                model_forward_end.synchronize()
                model_forward_time = model_forward_start.elapsed_time(
                    model_forward_end)
                orig_model_forward_time = 0.0
                if intermediate_tensors is not None:
                    orig_model_forward_time = intermediate_tensors.tensors.get(
                        "model_forward_time", torch.tensor(0.0)).item()
                hidden_or_intermediate_states.tensors["model_forward_time"] = (
                    torch.tensor(model_forward_time + orig_model_forward_time))
            return hidden_or_intermediate_states

        # Only perform pooling in the driver worker.
        if not self.is_driver_worker:
            return []
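
        # Run the model's pooler over the final hidden states to produce
        # a PoolerOutput for the batch.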
        return [
            self.model.pooler(hidden_states=hidden_or_intermediate_states,
                              pooling_metadata=model_input.pooling_metadata)
        ]

    def make_model_input_from_broadcasted_tensor_dict(
            self,
            tensor_dict: Dict[str,
                              Any]) -> ModelInputForGPUWithPoolingMetadata:
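        """Reconstruct the model input from the tensor dict broadcast by the
        driver worker.
        """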
        return ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict(
            tensor_dict,
            attn_backend=self.attn_backend,
        )

    def prepare_model_input(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForGPUWithPoolingMetadata:
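        """Prepare the base GPU model input and attach PoolingMetadata."""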
        assert seq_group_metadata_list is not None
        model_input = self._prepare_model_input_tensors(
            seq_group_metadata_list, finished_requests_ids)
        # Prepare PoolingMetadata.
        assert model_input.seq_lens is not None
        pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
                                                 model_input.seq_lens)

        return dataclasses.replace(model_input,
                                   pooling_metadata=pooling_metadata)

    def _prepare_pooling(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        prompt_lens: List[int],
    ) -> PoolingMetadata:
        """Prepare PoolingMetadata for the sequence group metadata list."""
        seq_groups: List[Tuple[List[int], PoolingParams]] = []
        for seq_group_metadata in seq_group_metadata_list:
            seq_ids = list(seq_group_metadata.seq_data.keys())
            pooling_params = seq_group_metadata.pooling_params
            seq_groups.append((seq_ids, pooling_params))
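
        # Merge per-group sequence data into a single
        # seq id -> SequenceData map.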
        seq_data: Dict[int, SequenceData] = {}
        for seq_group_metadata in seq_group_metadata_list:
            seq_data.update(seq_group_metadata.seq_data)

        pooling_metadata = PoolingMetadata(
            seq_groups=seq_groups,
            seq_data=seq_data,
            prompt_lens=prompt_lens,
        )

        return pooling_metadata