diff --git a/tests/multimodal/test_cache.py b/tests/multimodal/test_cache.py
index e07b73bd257d..2149f05b6af0 100644
--- a/tests/multimodal/test_cache.py
+++ b/tests/multimodal/test_cache.py
@@ -25,7 +25,7 @@ def _dummy_item(modality: str, size_by_key: dict[str, int]):
 
 
 def _dummy_kw(size_by_key_modality: dict[str, dict[str, int]]):
-    return MultiModalKwargs.from_items([
+    return MultiModalKwargs([
         _dummy_item(modality, size_by_key)
         for modality, size_by_key in size_by_key_modality.items()
     ])
diff --git a/tests/v1/test_serial_utils.py b/tests/v1/test_serial_utils.py
index 0ab4e0bf59cf..586276ee08ae 100644
--- a/tests/v1/test_serial_utils.py
+++ b/tests/v1/test_serial_utils.py
@@ -100,38 +100,6 @@ class MyRequest(msgspec.Struct):
 
 
 def test_multimodal_kwargs():
-    d = {
-        "foo":
-        torch.zeros(20000, dtype=torch.float16),
-        "bar": [torch.zeros(i * 1000, dtype=torch.int8) for i in range(3)],
-        "baz": [
-            torch.rand((256), dtype=torch.float16),
-            [
-                torch.rand((1, 12), dtype=torch.float32),
-                torch.rand((3, 5, 7), dtype=torch.float64),
-            ], [torch.rand((4, 4), dtype=torch.float16)]
-        ],
-    }
-
-    # pack mm kwargs into a mock request so that it can be decoded properly
-    req = MyRequest(mm=[MultiModalKwargs(d)])
-
-    encoder = MsgpackEncoder()
-    decoder = MsgpackDecoder(MyRequest)
-
-    encoded = encoder.encode(req)
-
-    assert len(encoded) == 6
-
-    total_len = sum(memoryview(x).cast("B").nbytes for x in encoded)
-
-    # expected total encoding length, should be 44559, +-20 for minor changes
-    assert 44539 <= total_len <= 44579
-    decoded: MultiModalKwargs = decoder.decode(encoded).mm[0]
-    assert all(nested_equal(d[k], decoded[k]) for k in d)
-
-
-def test_multimodal_items_by_modality():
     e1 = MultiModalFieldElem("audio", "a0",
                              torch.zeros(1000, dtype=torch.bfloat16),
                              MultiModalBatchedField())
@@ -151,7 +119,7 @@ def test_multimodal_items_by_modality():
     audio = MultiModalKwargsItem.from_elems([e1])
     video = MultiModalKwargsItem.from_elems([e2])
     image = MultiModalKwargsItem.from_elems([e3, e4])
-    mm = MultiModalKwargs.from_items([audio, video, image])
+    mm = MultiModalKwargs([audio, video, image])
 
     # pack mm kwargs into a mock request so that it can be decoded properly
     req = MyRequest([mm])
diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py
index dc3236508348..ef146fdfbf97 100644
--- a/vllm/inputs/registry.py
+++ b/vllm/inputs/registry.py
@@ -240,6 +240,6 @@ class InputRegistry:
 
         return DummyData(
             seq_data=SequenceData.from_seqs(dec_data.prompt_token_ids),
-            multi_modal_data=dec_data.multi_modal_data,
+            multi_modal_data=dec_data.multi_modal_data.get_data(),
             multi_modal_placeholders=dec_data.multi_modal_placeholders,
         )
diff --git a/vllm/model_executor/models/prithvi_geospatial_mae.py b/vllm/model_executor/models/prithvi_geospatial_mae.py
index 20f423cc7603..68488829071f 100644
--- a/vllm/model_executor/models/prithvi_geospatial_mae.py
+++ b/vllm/model_executor/models/prithvi_geospatial_mae.py
@@ -136,7 +136,7 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
             type="multimodal",
             prompt=prompt,
             prompt_token_ids=[1],
-            mm_kwargs=MultiModalKwargs.from_items(multimodal_kwargs_items),
+            mm_kwargs=MultiModalKwargs(multimodal_kwargs_items),
             mm_hashes=None,
             mm_placeholders=mm_placeholders,
         )
diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py
index 7188ed14c573..ef8f1b2e17b4 100644
--- a/vllm/multimodal/base.py
+++ b/vllm/multimodal/base.py
@@ -99,7 +99,7 @@ class MultiModalPlaceholderMap:
         seq_mm_placeholders = seq_group.multi_modal_placeholders
 
         if not seq_mm_data or not seq_mm_placeholders:
-            return MultiModalKwargs({}), {}
+            return MultiModalKwargs(), {}
 
         placeholder_maps = dict[str, MultiModalPlaceholderMap]()
diff --git a/vllm/multimodal/cache.py b/vllm/multimodal/cache.py
index 6074a4d54f22..8c4136e06f81 100644
--- a/vllm/multimodal/cache.py
+++ b/vllm/multimodal/cache.py
@@ -46,7 +46,7 @@ class MultiModalCache:
     ) -> int:
         # MultiModalKwargs is not a subclass of dict
         if isinstance(leaf, MultiModalKwargs):
-            return cls.get_item_size(leaf.data, debug=debug)
+            return cls.get_item_size(leaf.get_data(), debug=debug)
 
         # MultiModalKwargsItem is not a subclass of dict
         if isinstance(leaf, MultiModalKwargsItem):
diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py
index a33ce146995d..d3f57cf5338d 100644
--- a/vllm/multimodal/inputs.py
+++ b/vllm/multimodal/inputs.py
@@ -653,7 +653,7 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
     def from_elems(elems: Sequence[MultiModalFieldElem]):
         return MultiModalKwargsItem({elem.key: elem for elem in elems})
 
-    def __init__(self, data: Mapping[str, MultiModalFieldElem]) -> None:
+    def __init__(self, data: Mapping[str, MultiModalFieldElem] = {}) -> None:
         super().__init__(data)
 
         modalities = {elem.modality for elem in self.data.values()}
@@ -668,9 +668,7 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
         return {key: elem.data for key, elem in self.items()}
 
 
-# NOTE: UserDict is for V0 compatibility.
-# V1 should access individual items via `get_item`.
-class MultiModalKwargs(UserDict[str, NestedTensors]):
+class MultiModalKwargs:
     """
     A dictionary that represents the keyword arguments to
     [`torch.nn.Module.forward`][].
@@ -714,40 +712,16 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
             elems = [v[item_idx] for v in elems_in_modality.values()]
             items.append(MultiModalKwargsItem.from_elems(elems))
 
-        return MultiModalKwargs.from_items(items)
+        return MultiModalKwargs(items)
 
-    @staticmethod
-    def from_items(
-        items: Sequence[MultiModalKwargsItem],
-        *,
-        pin_memory: bool = False,
-    ):
-        """Construct a new
-        [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs]
-        from multiple items."""
-        elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
-        for item in items:
-            for key, elem in item.items():
-                elems_by_key[key].append(elem)
+    def __init__(self, items: Sequence[MultiModalKwargsItem] = ()) -> None:
+        super().__init__()
 
-        data = {
-            key: elems[0].field.reduce_data(elems, pin_memory=pin_memory)
-            for key, elems in elems_by_key.items() if len(elems) > 0
-        }
-
-        return MultiModalKwargs(data, items=items)
-
-    def __init__(
-        self,
-        data: Mapping[str, NestedTensors],
-        *,
-        items: Optional[Sequence[MultiModalKwargsItem]] = None,
-    ) -> None:
-        super().__init__(data)
-
-        items_by_modality = full_groupby(items or [], key=lambda x: x.modality)
+        items_by_modality = full_groupby(items, key=lambda x: x.modality)
         self._items_by_modality = dict(items_by_modality)
 
+        self._data: Optional[Mapping[str, NestedTensors]] = None
+
     @property
     def modalities(self):
         return self._items_by_modality.keys()
@@ -839,22 +813,41 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
 
         return cast(BatchedTensorInputs, json_mapped)
 
-    def __delitem__(self, key: str) -> None:
-        super().__delitem__(key)
+    def keys(self):
+        return self.get_data().keys()
+
+    def values(self):
+        return self.get_data().values()
+
+    def items(self):
+        return self.get_data().items()
+
+    def get(self, key: str, /, default=None):
+        return self.get_data().get(key, default)
+
+    def pop(self, key: str, *args, **kwargs):
+        data = dict(self.get_data())
+        res = data.pop(key, *args, **kwargs)
 
         for items in self._items_by_modality.values():
             for item in items:
-                item.pop(key, None)
+                item.pop(key, *args, **kwargs)
+
+        self._data = None
+
+        return res
+
+    def __iter__(self):
+        return iter(self.get_data())
+
+    def __getitem__(self, key: str):
+        return self.get_data()[key]
 
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, self.__class__):
             return False
-        if self._items_by_modality != other._items_by_modality:
-            return False
 
-        ks = self.keys()
-        return (ks == other.keys()
-                and all(nested_tensors_equal(self[k], other[k]) for k in ks))
+        return self._items_by_modality == other._items_by_modality
 
     def _validate_modality(self, method_name: str, modality: str) -> None:
         if not self._items_by_modality:
@@ -888,6 +881,25 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
         self._validate_modality("get_items", modality)
         return self._items_by_modality[modality]
 
+    def get_data(self,
+                 *,
+                 pin_memory: bool = False) -> Mapping[str, NestedTensors]:
+        if self._data is not None:
+            return self._data
+
+        elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
+        for items in self._items_by_modality.values():
+            for item in items:
+                for key, elem in item.items():
+                    elems_by_key[key].append(elem)
+
+        data = {
+            key: elems[0].field.reduce_data(elems, pin_memory=pin_memory)
+            for key, elems in elems_by_key.items() if len(elems) > 0
+        }
+
+        self._data = data
+        return data
+
 
 MultiModalPlaceholderDict: TypeAlias = Mapping[str, Sequence[PlaceholderRange]]
 """
diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py
index 38c5d5d99f63..4684bf6f3d83 100644
--- a/vllm/multimodal/processing.py
+++ b/vllm/multimodal/processing.py
@@ -1480,7 +1480,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
             mm_missing_kwargs=mm_missing_kwargs,
         )
 
-        mm_kwargs = MultiModalKwargs.from_items([
+        mm_kwargs = MultiModalKwargs([
             item for cache_items in mm_cache_items_merged.values()
             for item in cache_items
         ])
diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py
index f914d0dc6c5e..a80f09bb1927 100644
--- a/vllm/multimodal/utils.py
+++ b/vllm/multimodal/utils.py
@@ -402,12 +402,14 @@ def group_mm_kwargs_by_modality(
     for modality, items in groupby(mm_kwargs, key=lambda item: item.modality):
         items_lst = list(items)
 
-        # mm_kwargs_group = MultiModalKwargs.from_items(items_lst,
-        #                                               pin_memory=pin_memory)
+        # mm_kwargs_group = MultiModalKwargs(items_lst) \
+        #     .get_data(pin_memory=pin_memory)
 
         # if device is not None:
-        #     mm_kwargs_group = json_map_leaves(lambda x: x.to(device=device),
-        #                                       mm_kwargs_group.data)
+        #     mm_kwargs_group = json_map_leaves(
+        #         lambda x: x.to(device=device),
+        #         mm_kwargs_group,
+        #     )
 
         # TODO: Once V0 is removed, we can use the merging logic above
         # to avoid creating an extra batch dimension (except for fields
         # that are meant to be stacked anyway).
         # We will also need to update each model to remove `flatten_bn`.
         mm_kwargs_group = MultiModalKwargs.as_kwargs(
             MultiModalKwargs.batch(
-                [MultiModalKwargs.from_items([item]) for item in items_lst],
+                [MultiModalKwargs([item]) for item in items_lst],
                 pin_memory=pin_memory,
             ),
             device=device,
         )
diff --git a/vllm/sequence.py b/vllm/sequence.py
index cbe63f8d1d4e..b3be10b6bb61 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -524,7 +524,7 @@ class Sequence:
         if self.inputs["type"] == "multimodal":
             return self.inputs["mm_kwargs"]
 
-        return MultiModalKwargs({})
+        return MultiModalKwargs()
 
     @property
     def multi_modal_placeholders(self) -> MultiModalPlaceholderDict:
@@ -780,7 +780,7 @@ class SequenceGroup:
             return self.first_seq.multi_modal_data
         elif self.encoder_seq is not None:
             return self.encoder_seq.multi_modal_data
-        return MultiModalKwargs({})
+        return MultiModalKwargs()
 
     @property
     def multi_modal_placeholders(self) -> MultiModalPlaceholderDict:
diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py
index 3f0fad8a64d0..2857d8ef4290 100644
--- a/vllm/v1/serial_utils.py
+++ b/vllm/v1/serial_utils.py
@@ -117,16 +117,9 @@ class MsgpackEncoder:
             return self._encode_mm_item(obj)
 
         if isinstance(obj, MultiModalKwargs):
-            mm: MultiModalKwargs = obj
-            if not mm.modalities:
-                # just return the main dict if there are no modalities.
-                return dict(mm)
-
-            # ignore the main dict, it will be re-indexed.
-            # Any tensors *not* indexed by modality will be ignored.
             return [
                 self._encode_mm_item(item)
-                for itemlist in mm._items_by_modality.values()
+                for itemlist in obj._items_by_modality.values()
                 for item in itemlist
             ]
 
@@ -268,13 +261,7 @@ class MsgpackDecoder:
         if issubclass(t, MultiModalKwargsItem):
             return self._decode_mm_item(obj)
         if issubclass(t, MultiModalKwargs):
-            if isinstance(obj, list):
-                return MultiModalKwargs.from_items(
-                    self._decode_mm_items(obj))
-            return MultiModalKwargs({
-                k: self._decode_nested_tensors(v)
-                for k, v in obj.items()
-            })
+            return MultiModalKwargs(self._decode_mm_items(obj))
         if t is UtilityResult:
             return self._decode_utility_result(obj)
         return obj
diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index e718d9d5e0fb..3d4cf27a6ccf 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -58,7 +58,7 @@ class CachedRequestState:
     @deprecated("`mm_inputs` is superseded by `mm_kwargs` and will be "
                 "removed in v0.13. Please use `mm_kwargs` instead.")
     def mm_inputs(self) -> list[MultiModalKwargs]:
-        return [MultiModalKwargs.from_items([item]) for item in self.mm_kwargs]
+        return [MultiModalKwargs([item]) for item in self.mm_kwargs]
 
     def get_token_id(self, idx: int) -> int:
         if idx < self.num_prompt_tokens:
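
Reviewer note, not part of the patch: a minimal sketch of the migration this diff performs, using only classes already exercised in `tests/v1/test_serial_utils.py` above. The `(1, 1000)` shape is an assumption based on `MultiModalBatchedField` stacking a single item into a batch of one.

```python
import torch

from vllm.multimodal.inputs import (MultiModalBatchedField,
                                    MultiModalFieldElem, MultiModalKwargs,
                                    MultiModalKwargsItem)

elem = MultiModalFieldElem("audio", "a0",
                           torch.zeros(1000, dtype=torch.bfloat16),
                           MultiModalBatchedField())
item = MultiModalKwargsItem.from_elems([elem])

# Before: mm = MultiModalKwargs.from_items([item]); tensors = mm.data
# After: the constructor takes the items directly, and the reduced
# per-key tensors are computed lazily (and cached) by get_data().
mm = MultiModalKwargs([item])
tensors = mm.get_data()

# Dict-style reads (keys/items/__getitem__) still work; they all go
# through get_data() under the hood.
assert list(mm.keys()) == ["a0"]
assert mm["a0"].shape == (1, 1000)  # assumed batched-field shape
```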