[Refactor] Defer tensor data construction in MultiModalKwargs (#23030)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-08-17 12:05:50 +08:00 committed by GitHub
parent 94096a47c9
commit 5c32143b9d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 73 additions and 104 deletions

View File

@@ -25,7 +25,7 @@ def _dummy_item(modality: str, size_by_key: dict[str, int]):
def _dummy_kw(size_by_key_modality: dict[str, dict[str, int]]):
return MultiModalKwargs.from_items([
return MultiModalKwargs([
_dummy_item(modality, size_by_key)
for modality, size_by_key in size_by_key_modality.items()
])

View File

@@ -100,38 +100,6 @@ class MyRequest(msgspec.Struct):
def test_multimodal_kwargs():
d = {
"foo":
torch.zeros(20000, dtype=torch.float16),
"bar": [torch.zeros(i * 1000, dtype=torch.int8) for i in range(3)],
"baz": [
torch.rand((256), dtype=torch.float16),
[
torch.rand((1, 12), dtype=torch.float32),
torch.rand((3, 5, 7), dtype=torch.float64),
], [torch.rand((4, 4), dtype=torch.float16)]
],
}
# pack mm kwargs into a mock request so that it can be decoded properly
req = MyRequest(mm=[MultiModalKwargs(d)])
encoder = MsgpackEncoder()
decoder = MsgpackDecoder(MyRequest)
encoded = encoder.encode(req)
assert len(encoded) == 6
total_len = sum(memoryview(x).cast("B").nbytes for x in encoded)
# expected total encoding length, should be 44559, +-20 for minor changes
assert 44539 <= total_len <= 44579
decoded: MultiModalKwargs = decoder.decode(encoded).mm[0]
assert all(nested_equal(d[k], decoded[k]) for k in d)
def test_multimodal_items_by_modality():
e1 = MultiModalFieldElem("audio", "a0",
torch.zeros(1000, dtype=torch.bfloat16),
MultiModalBatchedField())
@@ -151,7 +119,7 @@ def test_multimodal_items_by_modality():
audio = MultiModalKwargsItem.from_elems([e1])
video = MultiModalKwargsItem.from_elems([e2])
image = MultiModalKwargsItem.from_elems([e3, e4])
mm = MultiModalKwargs.from_items([audio, video, image])
mm = MultiModalKwargs([audio, video, image])
# pack mm kwargs into a mock request so that it can be decoded properly
req = MyRequest([mm])

View File

@@ -240,6 +240,6 @@ class InputRegistry:
return DummyData(
seq_data=SequenceData.from_seqs(dec_data.prompt_token_ids),
multi_modal_data=dec_data.multi_modal_data,
multi_modal_data=dec_data.multi_modal_data.get_data(),
multi_modal_placeholders=dec_data.multi_modal_placeholders,
)

View File

@@ -136,7 +136,7 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
type="multimodal",
prompt=prompt,
prompt_token_ids=[1],
mm_kwargs=MultiModalKwargs.from_items(multimodal_kwargs_items),
mm_kwargs=MultiModalKwargs(multimodal_kwargs_items),
mm_hashes=None,
mm_placeholders=mm_placeholders,
)

View File

@@ -99,7 +99,7 @@ class MultiModalPlaceholderMap:
seq_mm_placeholders = seq_group.multi_modal_placeholders
if not seq_mm_data or not seq_mm_placeholders:
return MultiModalKwargs({}), {}
return MultiModalKwargs(), {}
placeholder_maps = dict[str, MultiModalPlaceholderMap]()

View File

@@ -46,7 +46,7 @@ class MultiModalCache:
) -> int:
# MultiModalKwargs is not a subclass of dict
if isinstance(leaf, MultiModalKwargs):
return cls.get_item_size(leaf.data, debug=debug)
return cls.get_item_size(leaf.get_data(), debug=debug)
# MultiModalKwargsItem is not a subclass of dict
if isinstance(leaf, MultiModalKwargsItem):

View File

@@ -653,7 +653,7 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
def from_elems(elems: Sequence[MultiModalFieldElem]):
return MultiModalKwargsItem({elem.key: elem for elem in elems})
def __init__(self, data: Mapping[str, MultiModalFieldElem]) -> None:
def __init__(self, data: Mapping[str, MultiModalFieldElem] = {}) -> None:
super().__init__(data)
modalities = {elem.modality for elem in self.data.values()}
@@ -668,9 +668,7 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
return {key: elem.data for key, elem in self.items()}
# NOTE: UserDict is for V0 compatibility.
# V1 should access individual items via `get_item`.
class MultiModalKwargs(UserDict[str, NestedTensors]):
class MultiModalKwargs:
"""
A dictionary that represents the keyword arguments to
[`torch.nn.Module.forward`][].
@@ -714,40 +712,16 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
elems = [v[item_idx] for v in elems_in_modality.values()]
items.append(MultiModalKwargsItem.from_elems(elems))
return MultiModalKwargs.from_items(items)
return MultiModalKwargs(items)
@staticmethod
def from_items(
items: Sequence[MultiModalKwargsItem],
*,
pin_memory: bool = False,
):
"""Construct a new
[`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs]
from multiple items."""
elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
for item in items:
for key, elem in item.items():
elems_by_key[key].append(elem)
def __init__(self, items: Sequence[MultiModalKwargsItem] = ()) -> None:
super().__init__()
data = {
key: elems[0].field.reduce_data(elems, pin_memory=pin_memory)
for key, elems in elems_by_key.items() if len(elems) > 0
}
return MultiModalKwargs(data, items=items)
def __init__(
self,
data: Mapping[str, NestedTensors],
*,
items: Optional[Sequence[MultiModalKwargsItem]] = None,
) -> None:
super().__init__(data)
items_by_modality = full_groupby(items or [], key=lambda x: x.modality)
items_by_modality = full_groupby(items, key=lambda x: x.modality)
self._items_by_modality = dict(items_by_modality)
self._data: Optional[Mapping[str, NestedTensors]] = None
@property
def modalities(self):
return self._items_by_modality.keys()
@@ -839,22 +813,41 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
return cast(BatchedTensorInputs, json_mapped)
def __delitem__(self, key: str) -> None:
super().__delitem__(key)
def keys(self):
return self.get_data().keys()
def values(self):
return self.get_data().values()
def items(self):
return self.get_data().items()
def get(self, key: str, /, default=None):
return self.get_data().get(key, default)
def pop(self, key: str, *args, **kwargs):
data = dict(self.get_data())
res = data.pop(key, *args, **kwargs)
for items in self._items_by_modality.values():
for item in items:
item.pop(key, None)
item.pop(key, *args, **kwargs)
self._data = None
return res
def __iter__(self):
return iter(self.get_data())
def __getitem__(self, key: str):
return self.get_data()[key]
def __eq__(self, other: object) -> bool:
if not isinstance(other, self.__class__):
return False
if self._items_by_modality != other._items_by_modality:
return False
ks = self.keys()
return (ks == other.keys()
and all(nested_tensors_equal(self[k], other[k]) for k in ks))
return self._items_by_modality == other._items_by_modality
def _validate_modality(self, method_name: str, modality: str) -> None:
if not self._items_by_modality:
@@ -888,6 +881,25 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
self._validate_modality("get_items", modality)
return self._items_by_modality[modality]
def get_data(self,
*,
pin_memory: bool = False) -> Mapping[str, NestedTensors]:
if self._data is not None:
return self._data
elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
for items in self._items_by_modality.values():
for item in items:
for key, elem in item.items():
elems_by_key[key].append(elem)
data = {
key: elems[0].field.reduce_data(elems, pin_memory=pin_memory)
for key, elems in elems_by_key.items() if len(elems) > 0
}
self._data = data
return data
MultiModalPlaceholderDict: TypeAlias = Mapping[str, Sequence[PlaceholderRange]]
"""

View File

@@ -1480,7 +1480,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_missing_kwargs=mm_missing_kwargs,
)
mm_kwargs = MultiModalKwargs.from_items([
mm_kwargs = MultiModalKwargs([
item for cache_items in mm_cache_items_merged.values()
for item in cache_items
])

View File

@@ -402,12 +402,14 @@ def group_mm_kwargs_by_modality(
for modality, items in groupby(mm_kwargs, key=lambda item: item.modality):
items_lst = list(items)
# mm_kwargs_group = MultiModalKwargs.from_items(items_lst,
# pin_memory=pin_memory)
# mm_kwargs_group = MultiModalKwargs(items_lst) \
# .get_data(pin_memory=pin_memory)
# if device is not None:
# mm_kwargs_group = json_map_leaves(lambda x: x.to(device=device),
# mm_kwargs_group.data)
# mm_kwargs_group = json_map_leaves(
# lambda x: x.to(device=device),
# mm_kwargs_group,
# )
# TODO: Once V0 is removed, we can use the merging logic above
# to avoid creating an extra batch dimension (except for fields
@@ -415,7 +417,7 @@
# We will also need to update each model to remove `flatten_bn`.
mm_kwargs_group = MultiModalKwargs.as_kwargs(
MultiModalKwargs.batch(
[MultiModalKwargs.from_items([item]) for item in items_lst],
[MultiModalKwargs([item]) for item in items_lst],
pin_memory=pin_memory,
),
device=device,

View File

@@ -524,7 +524,7 @@ class Sequence:
if self.inputs["type"] == "multimodal":
return self.inputs["mm_kwargs"]
return MultiModalKwargs({})
return MultiModalKwargs()
@property
def multi_modal_placeholders(self) -> MultiModalPlaceholderDict:
@@ -780,7 +780,7 @@ class SequenceGroup:
return self.first_seq.multi_modal_data
elif self.encoder_seq is not None:
return self.encoder_seq.multi_modal_data
return MultiModalKwargs({})
return MultiModalKwargs()
@property
def multi_modal_placeholders(self) -> MultiModalPlaceholderDict:

View File

@@ -117,16 +117,9 @@ class MsgpackEncoder:
return self._encode_mm_item(obj)
if isinstance(obj, MultiModalKwargs):
mm: MultiModalKwargs = obj
if not mm.modalities:
# just return the main dict if there are no modalities.
return dict(mm)
# ignore the main dict, it will be re-indexed.
# Any tensors *not* indexed by modality will be ignored.
return [
self._encode_mm_item(item)
for itemlist in mm._items_by_modality.values()
for itemlist in obj._items_by_modality.values()
for item in itemlist
]
@@ -268,13 +261,7 @@
if issubclass(t, MultiModalKwargsItem):
return self._decode_mm_item(obj)
if issubclass(t, MultiModalKwargs):
if isinstance(obj, list):
return MultiModalKwargs.from_items(
self._decode_mm_items(obj))
return MultiModalKwargs({
k: self._decode_nested_tensors(v)
for k, v in obj.items()
})
return MultiModalKwargs(self._decode_mm_items(obj))
if t is UtilityResult:
return self._decode_utility_result(obj)
return obj

View File

@@ -58,7 +58,7 @@ class CachedRequestState:
@deprecated("`mm_inputs` is superseded by `mm_kwargs` and will be "
"removed in v0.13. Please use `mm_kwargs` instead.")
def mm_inputs(self) -> list[MultiModalKwargs]:
return [MultiModalKwargs.from_items([item]) for item in self.mm_kwargs]
return [MultiModalKwargs([item]) for item in self.mm_kwargs]
def get_token_id(self, idx: int) -> int:
if idx < self.num_prompt_tokens: