diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py
index ef8f1b2e17b47..c4bb8d56ce3eb 100644
--- a/vllm/multimodal/base.py
+++ b/vllm/multimodal/base.py
@@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Generic, NamedTuple, TypeVar
 
 if TYPE_CHECKING:
     from vllm.sequence import SequenceGroupMetadata
 
-from .inputs import MultiModalKwargs, PlaceholderRange
+from .inputs import MultiModalKwargs, NestedTensors, PlaceholderRange
 
 _T = TypeVar("_T")
@@ -56,7 +56,8 @@ class MultiModalPlaceholderMap:
     @classmethod
     def from_seq_group(
         cls, seq_group: "SequenceGroupMetadata", positions: range
-    ) -> tuple[MultiModalKwargs, dict[str, "MultiModalPlaceholderMap"]]:
+    ) -> tuple[dict[str, NestedTensors], dict[str,
+               "MultiModalPlaceholderMap"]]:
         """
         Returns the multi-modal items that intersect with the portion of a
         prompt (``seq_group``) represented by ``positions``, as well as a
@@ -99,7 +100,7 @@ class MultiModalPlaceholderMap:
         seq_mm_placeholders = seq_group.multi_modal_placeholders
 
         if not seq_mm_data or not seq_mm_placeholders:
-            return MultiModalKwargs(), {}
+            return MultiModalKwargs().get_data(), {}
 
         placeholder_maps = dict[str, MultiModalPlaceholderMap]()
 
@@ -116,6 +117,8 @@ class MultiModalPlaceholderMap:
 
             placeholder_maps[modality] = placeholder_map
 
+        seq_mm_data = seq_mm_data if isinstance(
+            seq_mm_data, dict) else seq_mm_data.get_data()
         return seq_mm_data, placeholder_maps
 
     def append_items_from_seq_group(
diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py
index d3f57cf5338d5..3e0bfce59c5fe 100644
--- a/vllm/multimodal/inputs.py
+++ b/vllm/multimodal/inputs.py
@@ -664,7 +664,7 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
     def modality(self) -> str:
         return self._modality
 
-    def get_data(self) -> Mapping[str, NestedTensors]:
+    def get_data(self) -> dict[str, NestedTensors]:
         return {key: elem.data for key, elem in self.items()}
 
@@ -720,7 +720,7 @@ class MultiModalKwargs:
         items_by_modality = full_groupby(items, key=lambda x: x.modality)
         self._items_by_modality = dict(items_by_modality)
 
-        self._data: Optional[Mapping[str, NestedTensors]] = None
+        self._data: Optional[dict[str, NestedTensors]] = None
 
     @property
     def modalities(self):
@@ -883,7 +883,7 @@ class MultiModalKwargs:
 
     def get_data(self,
                  *,
-                 pin_memory: bool = False) -> Mapping[str, NestedTensors]:
+                 pin_memory: bool = False) -> dict[str, NestedTensors]:
         if self._data is not None:
             return self._data
 
diff --git a/vllm/sequence.py b/vllm/sequence.py
index b3be10b6bb612..2cb254381eff4 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -22,6 +22,7 @@ from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import RequestOutputKind, SamplingParams
 
 if TYPE_CHECKING:
+    from vllm.multimodal.inputs import NestedTensors
     from vllm.v1.worker.kv_connector_model_runner_mixin import (
         KVConnectorOutput)
 
@@ -978,7 +979,8 @@ class SequenceGroupMetadata(
     state: Optional[SequenceGroupState] = msgspec.field(
         default_factory=lambda: SequenceGroupState())
     token_type_ids: Optional[list[int]] = None
-    multi_modal_data: Optional[MultiModalKwargs] = None
+    multi_modal_data: Optional[Union[MultiModalKwargs,
+                                     dict[str, "NestedTensors"]]] = None
     multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
     encoder_seq_data: Optional[SequenceData] = None
     cross_block_table: Optional[list[int]] = None
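
Caller-side illustration (not part of the patch): after this change, `MultiModalPlaceholderMap.from_seq_group` returns a plain `dict[str, NestedTensors]` instead of a `MultiModalKwargs`, and `SequenceGroupMetadata.multi_modal_data` may hold either form. The sketch below only shows the new contract; the helper name `prepare_mm_inputs` is hypothetical, not vLLM source.

```python
# Hypothetical caller sketch (not vLLM source): illustrates the new
# return-type contract of MultiModalPlaceholderMap.from_seq_group.
from vllm.multimodal.base import MultiModalPlaceholderMap


def prepare_mm_inputs(seq_group, positions: range):
    # from_seq_group now yields a plain dict[str, NestedTensors] rather than
    # a MultiModalKwargs, regardless of whether
    # SequenceGroupMetadata.multi_modal_data held a MultiModalKwargs or an
    # already-flattened dict (both are accepted after this patch).
    mm_data, placeholder_maps = MultiModalPlaceholderMap.from_seq_group(
        seq_group, positions)

    # mm_data is keyed by input name (e.g. "pixel_values"); callers no
    # longer need to call .get_data() themselves.
    for key, tensors in mm_data.items():
        ...  # e.g. move `tensors` to the target device

    return mm_data, placeholder_maps
```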