[Misc] Fix backward compatibility from #23030 (#23070)

Signed-off-by: Roger Wang <hey@rogerw.me>
Co-authored-by: Roger Wang <hey@rogerw.me>
This commit is contained in:
Roger Wang 2025-08-17 23:33:29 -07:00 committed by GitHub
parent 08d5f7113a
commit 89657a557c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 12 additions and 7 deletions

View File

@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Generic, NamedTuple, TypeVar
if TYPE_CHECKING:
from vllm.sequence import SequenceGroupMetadata
from .inputs import MultiModalKwargs, PlaceholderRange
from .inputs import MultiModalKwargs, NestedTensors, PlaceholderRange
_T = TypeVar("_T")
@ -56,7 +56,8 @@ class MultiModalPlaceholderMap:
@classmethod
def from_seq_group(
cls, seq_group: "SequenceGroupMetadata", positions: range
) -> tuple[MultiModalKwargs, dict[str, "MultiModalPlaceholderMap"]]:
) -> tuple[dict[str, NestedTensors], dict[str,
"MultiModalPlaceholderMap"]]:
"""
Returns the multi-modal items that intersect with the portion of a
prompt (``seq_group``) represented by ``positions``, as well as a
@ -99,7 +100,7 @@ class MultiModalPlaceholderMap:
seq_mm_placeholders = seq_group.multi_modal_placeholders
if not seq_mm_data or not seq_mm_placeholders:
return MultiModalKwargs(), {}
return MultiModalKwargs().get_data(), {}
placeholder_maps = dict[str, MultiModalPlaceholderMap]()
@ -116,6 +117,8 @@ class MultiModalPlaceholderMap:
placeholder_maps[modality] = placeholder_map
seq_mm_data = seq_mm_data if isinstance(
seq_mm_data, dict) else seq_mm_data.get_data()
return seq_mm_data, placeholder_maps
def append_items_from_seq_group(

View File

@ -664,7 +664,7 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
def modality(self) -> str:
return self._modality
def get_data(self) -> Mapping[str, NestedTensors]:
def get_data(self) -> dict[str, NestedTensors]:
return {key: elem.data for key, elem in self.items()}
@ -720,7 +720,7 @@ class MultiModalKwargs:
items_by_modality = full_groupby(items, key=lambda x: x.modality)
self._items_by_modality = dict(items_by_modality)
self._data: Optional[Mapping[str, NestedTensors]] = None
self._data: Optional[dict[str, NestedTensors]] = None
@property
def modalities(self):
@ -883,7 +883,7 @@ class MultiModalKwargs:
def get_data(self,
*,
pin_memory: bool = False) -> Mapping[str, NestedTensors]:
pin_memory: bool = False) -> dict[str, NestedTensors]:
if self._data is not None:
return self._data

View File

@ -22,6 +22,7 @@ from vllm.pooling_params import PoolingParams
from vllm.sampling_params import RequestOutputKind, SamplingParams
if TYPE_CHECKING:
from vllm.multimodal.inputs import NestedTensors
from vllm.v1.worker.kv_connector_model_runner_mixin import (
KVConnectorOutput)
@ -978,7 +979,8 @@ class SequenceGroupMetadata(
state: Optional[SequenceGroupState] = msgspec.field(
default_factory=lambda: SequenceGroupState())
token_type_ids: Optional[list[int]] = None
multi_modal_data: Optional[MultiModalKwargs] = None
multi_modal_data: Optional[Union[MultiModalKwargs,
dict[str, "NestedTensors"]]] = None
multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
encoder_seq_data: Optional[SequenceData] = None
cross_block_table: Optional[list[int]] = None