mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 22:25:32 +08:00
[Refactor] Defer tensor data construction in MultiModalKwargs (#23030)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
94096a47c9
commit
5c32143b9d
@ -25,7 +25,7 @@ def _dummy_item(modality: str, size_by_key: dict[str, int]):
|
||||
|
||||
|
||||
def _dummy_kw(size_by_key_modality: dict[str, dict[str, int]]):
|
||||
return MultiModalKwargs.from_items([
|
||||
return MultiModalKwargs([
|
||||
_dummy_item(modality, size_by_key)
|
||||
for modality, size_by_key in size_by_key_modality.items()
|
||||
])
|
||||
|
||||
@ -100,38 +100,6 @@ class MyRequest(msgspec.Struct):
|
||||
|
||||
|
||||
def test_multimodal_kwargs():
|
||||
d = {
|
||||
"foo":
|
||||
torch.zeros(20000, dtype=torch.float16),
|
||||
"bar": [torch.zeros(i * 1000, dtype=torch.int8) for i in range(3)],
|
||||
"baz": [
|
||||
torch.rand((256), dtype=torch.float16),
|
||||
[
|
||||
torch.rand((1, 12), dtype=torch.float32),
|
||||
torch.rand((3, 5, 7), dtype=torch.float64),
|
||||
], [torch.rand((4, 4), dtype=torch.float16)]
|
||||
],
|
||||
}
|
||||
|
||||
# pack mm kwargs into a mock request so that it can be decoded properly
|
||||
req = MyRequest(mm=[MultiModalKwargs(d)])
|
||||
|
||||
encoder = MsgpackEncoder()
|
||||
decoder = MsgpackDecoder(MyRequest)
|
||||
|
||||
encoded = encoder.encode(req)
|
||||
|
||||
assert len(encoded) == 6
|
||||
|
||||
total_len = sum(memoryview(x).cast("B").nbytes for x in encoded)
|
||||
|
||||
# expected total encoding length, should be 44559, +-20 for minor changes
|
||||
assert 44539 <= total_len <= 44579
|
||||
decoded: MultiModalKwargs = decoder.decode(encoded).mm[0]
|
||||
assert all(nested_equal(d[k], decoded[k]) for k in d)
|
||||
|
||||
|
||||
def test_multimodal_items_by_modality():
|
||||
e1 = MultiModalFieldElem("audio", "a0",
|
||||
torch.zeros(1000, dtype=torch.bfloat16),
|
||||
MultiModalBatchedField())
|
||||
@ -151,7 +119,7 @@ def test_multimodal_items_by_modality():
|
||||
audio = MultiModalKwargsItem.from_elems([e1])
|
||||
video = MultiModalKwargsItem.from_elems([e2])
|
||||
image = MultiModalKwargsItem.from_elems([e3, e4])
|
||||
mm = MultiModalKwargs.from_items([audio, video, image])
|
||||
mm = MultiModalKwargs([audio, video, image])
|
||||
|
||||
# pack mm kwargs into a mock request so that it can be decoded properly
|
||||
req = MyRequest([mm])
|
||||
|
||||
@ -240,6 +240,6 @@ class InputRegistry:
|
||||
|
||||
return DummyData(
|
||||
seq_data=SequenceData.from_seqs(dec_data.prompt_token_ids),
|
||||
multi_modal_data=dec_data.multi_modal_data,
|
||||
multi_modal_data=dec_data.multi_modal_data.get_data(),
|
||||
multi_modal_placeholders=dec_data.multi_modal_placeholders,
|
||||
)
|
||||
|
||||
@ -136,7 +136,7 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
|
||||
type="multimodal",
|
||||
prompt=prompt,
|
||||
prompt_token_ids=[1],
|
||||
mm_kwargs=MultiModalKwargs.from_items(multimodal_kwargs_items),
|
||||
mm_kwargs=MultiModalKwargs(multimodal_kwargs_items),
|
||||
mm_hashes=None,
|
||||
mm_placeholders=mm_placeholders,
|
||||
)
|
||||
|
||||
@ -99,7 +99,7 @@ class MultiModalPlaceholderMap:
|
||||
seq_mm_placeholders = seq_group.multi_modal_placeholders
|
||||
|
||||
if not seq_mm_data or not seq_mm_placeholders:
|
||||
return MultiModalKwargs({}), {}
|
||||
return MultiModalKwargs(), {}
|
||||
|
||||
placeholder_maps = dict[str, MultiModalPlaceholderMap]()
|
||||
|
||||
|
||||
@ -46,7 +46,7 @@ class MultiModalCache:
|
||||
) -> int:
|
||||
# MultiModalKwargs is not a subclass of dict
|
||||
if isinstance(leaf, MultiModalKwargs):
|
||||
return cls.get_item_size(leaf.data, debug=debug)
|
||||
return cls.get_item_size(leaf.get_data(), debug=debug)
|
||||
|
||||
# MultiModalKwargsItem is not a subclass of dict
|
||||
if isinstance(leaf, MultiModalKwargsItem):
|
||||
|
||||
@ -653,7 +653,7 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
|
||||
def from_elems(elems: Sequence[MultiModalFieldElem]):
|
||||
return MultiModalKwargsItem({elem.key: elem for elem in elems})
|
||||
|
||||
def __init__(self, data: Mapping[str, MultiModalFieldElem]) -> None:
|
||||
def __init__(self, data: Mapping[str, MultiModalFieldElem] = {}) -> None:
|
||||
super().__init__(data)
|
||||
|
||||
modalities = {elem.modality for elem in self.data.values()}
|
||||
@ -668,9 +668,7 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
|
||||
return {key: elem.data for key, elem in self.items()}
|
||||
|
||||
|
||||
# NOTE: UserDict is for V0 compatibility.
|
||||
# V1 should access individual items via `get_item`.
|
||||
class MultiModalKwargs(UserDict[str, NestedTensors]):
|
||||
class MultiModalKwargs:
|
||||
"""
|
||||
A dictionary that represents the keyword arguments to
|
||||
[`torch.nn.Module.forward`][].
|
||||
@ -714,40 +712,16 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
|
||||
elems = [v[item_idx] for v in elems_in_modality.values()]
|
||||
items.append(MultiModalKwargsItem.from_elems(elems))
|
||||
|
||||
return MultiModalKwargs.from_items(items)
|
||||
return MultiModalKwargs(items)
|
||||
|
||||
@staticmethod
|
||||
def from_items(
|
||||
items: Sequence[MultiModalKwargsItem],
|
||||
*,
|
||||
pin_memory: bool = False,
|
||||
):
|
||||
"""Construct a new
|
||||
[`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs]
|
||||
from multiple items."""
|
||||
elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
|
||||
for item in items:
|
||||
for key, elem in item.items():
|
||||
elems_by_key[key].append(elem)
|
||||
def __init__(self, items: Sequence[MultiModalKwargsItem] = ()) -> None:
|
||||
super().__init__()
|
||||
|
||||
data = {
|
||||
key: elems[0].field.reduce_data(elems, pin_memory=pin_memory)
|
||||
for key, elems in elems_by_key.items() if len(elems) > 0
|
||||
}
|
||||
|
||||
return MultiModalKwargs(data, items=items)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data: Mapping[str, NestedTensors],
|
||||
*,
|
||||
items: Optional[Sequence[MultiModalKwargsItem]] = None,
|
||||
) -> None:
|
||||
super().__init__(data)
|
||||
|
||||
items_by_modality = full_groupby(items or [], key=lambda x: x.modality)
|
||||
items_by_modality = full_groupby(items, key=lambda x: x.modality)
|
||||
self._items_by_modality = dict(items_by_modality)
|
||||
|
||||
self._data: Optional[Mapping[str, NestedTensors]] = None
|
||||
|
||||
@property
|
||||
def modalities(self):
|
||||
return self._items_by_modality.keys()
|
||||
@ -839,22 +813,41 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
|
||||
|
||||
return cast(BatchedTensorInputs, json_mapped)
|
||||
|
||||
def __delitem__(self, key: str) -> None:
|
||||
super().__delitem__(key)
|
||||
def keys(self):
|
||||
return self.get_data().keys()
|
||||
|
||||
def values(self):
|
||||
return self.get_data().values()
|
||||
|
||||
def items(self):
|
||||
return self.get_data().items()
|
||||
|
||||
def get(self, key: str, /, default=None):
|
||||
return self.get_data().get(key, default)
|
||||
|
||||
def pop(self, key: str, *args, **kwargs):
|
||||
data = dict(self.get_data())
|
||||
res = data.pop(key, *args, **kwargs)
|
||||
|
||||
for items in self._items_by_modality.values():
|
||||
for item in items:
|
||||
item.pop(key, None)
|
||||
item.pop(key, *args, **kwargs)
|
||||
|
||||
self._data = None
|
||||
|
||||
return res
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.get_data())
|
||||
|
||||
def __getitem__(self, key: str):
|
||||
return self.get_data()[key]
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, self.__class__):
|
||||
return False
|
||||
if self._items_by_modality != other._items_by_modality:
|
||||
return False
|
||||
|
||||
ks = self.keys()
|
||||
return (ks == other.keys()
|
||||
and all(nested_tensors_equal(self[k], other[k]) for k in ks))
|
||||
return self._items_by_modality == other._items_by_modality
|
||||
|
||||
def _validate_modality(self, method_name: str, modality: str) -> None:
|
||||
if not self._items_by_modality:
|
||||
@ -888,6 +881,25 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
|
||||
self._validate_modality("get_items", modality)
|
||||
return self._items_by_modality[modality]
|
||||
|
||||
def get_data(self,
|
||||
*,
|
||||
pin_memory: bool = False) -> Mapping[str, NestedTensors]:
|
||||
if self._data is not None:
|
||||
return self._data
|
||||
|
||||
elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
|
||||
for items in self._items_by_modality.values():
|
||||
for item in items:
|
||||
for key, elem in item.items():
|
||||
elems_by_key[key].append(elem)
|
||||
|
||||
data = {
|
||||
key: elems[0].field.reduce_data(elems, pin_memory=pin_memory)
|
||||
for key, elems in elems_by_key.items() if len(elems) > 0
|
||||
}
|
||||
self._data = data
|
||||
return data
|
||||
|
||||
|
||||
MultiModalPlaceholderDict: TypeAlias = Mapping[str, Sequence[PlaceholderRange]]
|
||||
"""
|
||||
|
||||
@ -1480,7 +1480,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
|
||||
mm_missing_kwargs=mm_missing_kwargs,
|
||||
)
|
||||
|
||||
mm_kwargs = MultiModalKwargs.from_items([
|
||||
mm_kwargs = MultiModalKwargs([
|
||||
item for cache_items in mm_cache_items_merged.values()
|
||||
for item in cache_items
|
||||
])
|
||||
|
||||
@ -402,12 +402,14 @@ def group_mm_kwargs_by_modality(
|
||||
for modality, items in groupby(mm_kwargs, key=lambda item: item.modality):
|
||||
items_lst = list(items)
|
||||
|
||||
# mm_kwargs_group = MultiModalKwargs.from_items(items_lst,
|
||||
# pin_memory=pin_memory)
|
||||
# mm_kwargs_group = MultiModalKwargs(items_lst) \
|
||||
# .get_data(pin_memory=pin_memory)
|
||||
|
||||
# if device is not None:
|
||||
# mm_kwargs_group = json_map_leaves(lambda x: x.to(device=device),
|
||||
# mm_kwargs_group.data)
|
||||
# mm_kwargs_group = json_map_leaves(
|
||||
# lambda x: x.to(device=device),
|
||||
# mm_kwargs_group,
|
||||
# )
|
||||
|
||||
# TODO: Once V0 is removed, we can use the merging logic above
|
||||
# to avoid creating an extra batch dimension (except for fields
|
||||
@ -415,7 +417,7 @@ def group_mm_kwargs_by_modality(
|
||||
# We will also need to update each model to remove `flatten_bn`.
|
||||
mm_kwargs_group = MultiModalKwargs.as_kwargs(
|
||||
MultiModalKwargs.batch(
|
||||
[MultiModalKwargs.from_items([item]) for item in items_lst],
|
||||
[MultiModalKwargs([item]) for item in items_lst],
|
||||
pin_memory=pin_memory,
|
||||
),
|
||||
device=device,
|
||||
|
||||
@ -524,7 +524,7 @@ class Sequence:
|
||||
if self.inputs["type"] == "multimodal":
|
||||
return self.inputs["mm_kwargs"]
|
||||
|
||||
return MultiModalKwargs({})
|
||||
return MultiModalKwargs()
|
||||
|
||||
@property
|
||||
def multi_modal_placeholders(self) -> MultiModalPlaceholderDict:
|
||||
@ -780,7 +780,7 @@ class SequenceGroup:
|
||||
return self.first_seq.multi_modal_data
|
||||
elif self.encoder_seq is not None:
|
||||
return self.encoder_seq.multi_modal_data
|
||||
return MultiModalKwargs({})
|
||||
return MultiModalKwargs()
|
||||
|
||||
@property
|
||||
def multi_modal_placeholders(self) -> MultiModalPlaceholderDict:
|
||||
|
||||
@ -117,16 +117,9 @@ class MsgpackEncoder:
|
||||
return self._encode_mm_item(obj)
|
||||
|
||||
if isinstance(obj, MultiModalKwargs):
|
||||
mm: MultiModalKwargs = obj
|
||||
if not mm.modalities:
|
||||
# just return the main dict if there are no modalities.
|
||||
return dict(mm)
|
||||
|
||||
# ignore the main dict, it will be re-indexed.
|
||||
# Any tensors *not* indexed by modality will be ignored.
|
||||
return [
|
||||
self._encode_mm_item(item)
|
||||
for itemlist in mm._items_by_modality.values()
|
||||
for itemlist in obj._items_by_modality.values()
|
||||
for item in itemlist
|
||||
]
|
||||
|
||||
@ -268,13 +261,7 @@ class MsgpackDecoder:
|
||||
if issubclass(t, MultiModalKwargsItem):
|
||||
return self._decode_mm_item(obj)
|
||||
if issubclass(t, MultiModalKwargs):
|
||||
if isinstance(obj, list):
|
||||
return MultiModalKwargs.from_items(
|
||||
self._decode_mm_items(obj))
|
||||
return MultiModalKwargs({
|
||||
k: self._decode_nested_tensors(v)
|
||||
for k, v in obj.items()
|
||||
})
|
||||
return MultiModalKwargs(self._decode_mm_items(obj))
|
||||
if t is UtilityResult:
|
||||
return self._decode_utility_result(obj)
|
||||
return obj
|
||||
|
||||
@ -58,7 +58,7 @@ class CachedRequestState:
|
||||
@deprecated("`mm_inputs` is superseded by `mm_kwargs` and will be "
|
||||
"removed in v0.13. Please use `mm_kwargs` instead.")
|
||||
def mm_inputs(self) -> list[MultiModalKwargs]:
|
||||
return [MultiModalKwargs.from_items([item]) for item in self.mm_kwargs]
|
||||
return [MultiModalKwargs([item]) for item in self.mm_kwargs]
|
||||
|
||||
def get_token_id(self, idx: int) -> int:
|
||||
if idx < self.num_prompt_tokens:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user