commit 89657a557c
parent 08d5f7113a
Signed-off-by: Roger Wang <hey@rogerw.me>
Co-authored-by: Roger Wang <hey@rogerw.me>
@@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Generic, NamedTuple, TypeVar
 if TYPE_CHECKING:
     from vllm.sequence import SequenceGroupMetadata
 
-    from .inputs import MultiModalKwargs, PlaceholderRange
+    from .inputs import MultiModalKwargs, NestedTensors, PlaceholderRange
 
 _T = TypeVar("_T")
 
@@ -56,7 +56,8 @@ class MultiModalPlaceholderMap
     @classmethod
     def from_seq_group(
         cls, seq_group: "SequenceGroupMetadata", positions: range
-    ) -> tuple[MultiModalKwargs, dict[str, "MultiModalPlaceholderMap"]]:
+    ) -> tuple[dict[str, NestedTensors], dict[str,
+                                              "MultiModalPlaceholderMap"]]:
         """
         Returns the multi-modal items that intersect with the portion of a
         prompt (``seq_group``) represented by ``positions``, as well as a
@@ -99,7 +100,7 @@ class MultiModalPlaceholderMap
         seq_mm_placeholders = seq_group.multi_modal_placeholders
 
         if not seq_mm_data or not seq_mm_placeholders:
-            return MultiModalKwargs(), {}
+            return MultiModalKwargs().get_data(), {}
 
         placeholder_maps = dict[str, MultiModalPlaceholderMap]()
 
@@ -116,6 +117,8 @@ class MultiModalPlaceholderMap
 
             placeholder_maps[modality] = placeholder_map
 
+        seq_mm_data = seq_mm_data if isinstance(
+            seq_mm_data, dict) else seq_mm_data.get_data()
         return seq_mm_data, placeholder_maps
 
     def append_items_from_seq_group(
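Illustrative usage sketch (not part of this commit): with the change above, MultiModalPlaceholderMap.from_seq_group returns a plain dict[str, NestedTensors] instead of a MultiModalKwargs. The helper name gather_mm_inputs and the import path vllm.multimodal.base are assumptions for illustration only.

# Illustrative only -- not part of this commit. Assumes
# MultiModalPlaceholderMap is importable from vllm.multimodal.base and a
# SequenceGroupMetadata instance is already available.
from typing import TYPE_CHECKING

from vllm.multimodal.base import MultiModalPlaceholderMap

if TYPE_CHECKING:
    from vllm.sequence import SequenceGroupMetadata


def gather_mm_inputs(seq_group: "SequenceGroupMetadata", positions: range):
    # mm_data is now a plain dict[str, NestedTensors]; placeholder_maps is
    # unchanged (dict[str, MultiModalPlaceholderMap]).
    mm_data, placeholder_maps = MultiModalPlaceholderMap.from_seq_group(
        seq_group, positions)

    # Ordinary dict iteration replaces MultiModalKwargs-specific handling.
    for modality, tensors in mm_data.items():
        _ = (modality, tensors)

    return mm_data, placeholder_maps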
@@ -664,7 +664,7 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem])
     def modality(self) -> str:
         return self._modality
 
-    def get_data(self) -> Mapping[str, NestedTensors]:
+    def get_data(self) -> dict[str, NestedTensors]:
         return {key: elem.data for key, elem in self.items()}
 
 
@@ -720,7 +720,7 @@ class MultiModalKwargs
         items_by_modality = full_groupby(items, key=lambda x: x.modality)
         self._items_by_modality = dict(items_by_modality)
 
-        self._data: Optional[Mapping[str, NestedTensors]] = None
+        self._data: Optional[dict[str, NestedTensors]] = None
 
     @property
     def modalities(self):
@@ -883,7 +883,7 @@ class MultiModalKwargs
 
     def get_data(self,
                  *,
-                 pin_memory: bool = False) -> Mapping[str, NestedTensors]:
+                 pin_memory: bool = False) -> dict[str, NestedTensors]:
         if self._data is not None:
             return self._data
 
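Illustrative usage sketch (not part of this commit): the narrowed annotations mean MultiModalKwargsItem.get_data() and MultiModalKwargs.get_data() now advertise a concrete dict[str, NestedTensors] rather than an abstract Mapping, matching what the implementations already build. The helper as_plain_dict is hypothetical.

# Illustrative only -- not part of this commit.
from vllm.multimodal.inputs import MultiModalKwargs


def as_plain_dict(mm_kwargs: MultiModalKwargs):
    # get_data() is annotated as dict[str, NestedTensors]; an empty
    # MultiModalKwargs (as in the early-return path above) yields {}.
    data = mm_kwargs.get_data()
    assert isinstance(data, dict)
    return data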
@@ -22,6 +22,7 @@ from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import RequestOutputKind, SamplingParams
 
 if TYPE_CHECKING:
+    from vllm.multimodal.inputs import NestedTensors
     from vllm.v1.worker.kv_connector_model_runner_mixin import (
         KVConnectorOutput)
 
@@ -978,7 +979,8 @@ class SequenceGroupMetadata(
     state: Optional[SequenceGroupState] = msgspec.field(
         default_factory=lambda: SequenceGroupState())
     token_type_ids: Optional[list[int]] = None
-    multi_modal_data: Optional[MultiModalKwargs] = None
+    multi_modal_data: Optional[Union[MultiModalKwargs,
+                                     dict[str, "NestedTensors"]]] = None
     multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
     encoder_seq_data: Optional[SequenceData] = None
     cross_block_table: Optional[list[int]] = None
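Illustrative usage sketch (not part of this commit): SequenceGroupMetadata.multi_modal_data may now hold either a MultiModalKwargs or a plain dict[str, NestedTensors], so consumers can normalize it the same way the isinstance check added to from_seq_group does. The helper normalize_mm_data is hypothetical.

# Illustrative only -- not part of this commit.
from typing import TYPE_CHECKING

from vllm.multimodal.inputs import MultiModalKwargs

if TYPE_CHECKING:
    from vllm.sequence import SequenceGroupMetadata


def normalize_mm_data(seq_group: "SequenceGroupMetadata"):
    # multi_modal_data may be None, a MultiModalKwargs, or already a plain
    # dict[str, NestedTensors]; always hand back a plain dict of tensors.
    mm_data = seq_group.multi_modal_data
    if mm_data is None:
        return {}
    if isinstance(mm_data, MultiModalKwargs):
        return mm_data.get_data()
    return mm_data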