mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 08:25:01 +08:00
[Refactor] Defer tensor data construction in MultiModalKwargs (#23030)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
94096a47c9
commit
5c32143b9d
@ -25,7 +25,7 @@ def _dummy_item(modality: str, size_by_key: dict[str, int]):
|
|||||||
|
|
||||||
|
|
||||||
def _dummy_kw(size_by_key_modality: dict[str, dict[str, int]]):
|
def _dummy_kw(size_by_key_modality: dict[str, dict[str, int]]):
|
||||||
return MultiModalKwargs.from_items([
|
return MultiModalKwargs([
|
||||||
_dummy_item(modality, size_by_key)
|
_dummy_item(modality, size_by_key)
|
||||||
for modality, size_by_key in size_by_key_modality.items()
|
for modality, size_by_key in size_by_key_modality.items()
|
||||||
])
|
])
|
||||||
|
|||||||
@ -100,38 +100,6 @@ class MyRequest(msgspec.Struct):
|
|||||||
|
|
||||||
|
|
||||||
def test_multimodal_kwargs():
|
def test_multimodal_kwargs():
|
||||||
d = {
|
|
||||||
"foo":
|
|
||||||
torch.zeros(20000, dtype=torch.float16),
|
|
||||||
"bar": [torch.zeros(i * 1000, dtype=torch.int8) for i in range(3)],
|
|
||||||
"baz": [
|
|
||||||
torch.rand((256), dtype=torch.float16),
|
|
||||||
[
|
|
||||||
torch.rand((1, 12), dtype=torch.float32),
|
|
||||||
torch.rand((3, 5, 7), dtype=torch.float64),
|
|
||||||
], [torch.rand((4, 4), dtype=torch.float16)]
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
# pack mm kwargs into a mock request so that it can be decoded properly
|
|
||||||
req = MyRequest(mm=[MultiModalKwargs(d)])
|
|
||||||
|
|
||||||
encoder = MsgpackEncoder()
|
|
||||||
decoder = MsgpackDecoder(MyRequest)
|
|
||||||
|
|
||||||
encoded = encoder.encode(req)
|
|
||||||
|
|
||||||
assert len(encoded) == 6
|
|
||||||
|
|
||||||
total_len = sum(memoryview(x).cast("B").nbytes for x in encoded)
|
|
||||||
|
|
||||||
# expected total encoding length, should be 44559, +-20 for minor changes
|
|
||||||
assert 44539 <= total_len <= 44579
|
|
||||||
decoded: MultiModalKwargs = decoder.decode(encoded).mm[0]
|
|
||||||
assert all(nested_equal(d[k], decoded[k]) for k in d)
|
|
||||||
|
|
||||||
|
|
||||||
def test_multimodal_items_by_modality():
|
|
||||||
e1 = MultiModalFieldElem("audio", "a0",
|
e1 = MultiModalFieldElem("audio", "a0",
|
||||||
torch.zeros(1000, dtype=torch.bfloat16),
|
torch.zeros(1000, dtype=torch.bfloat16),
|
||||||
MultiModalBatchedField())
|
MultiModalBatchedField())
|
||||||
@ -151,7 +119,7 @@ def test_multimodal_items_by_modality():
|
|||||||
audio = MultiModalKwargsItem.from_elems([e1])
|
audio = MultiModalKwargsItem.from_elems([e1])
|
||||||
video = MultiModalKwargsItem.from_elems([e2])
|
video = MultiModalKwargsItem.from_elems([e2])
|
||||||
image = MultiModalKwargsItem.from_elems([e3, e4])
|
image = MultiModalKwargsItem.from_elems([e3, e4])
|
||||||
mm = MultiModalKwargs.from_items([audio, video, image])
|
mm = MultiModalKwargs([audio, video, image])
|
||||||
|
|
||||||
# pack mm kwargs into a mock request so that it can be decoded properly
|
# pack mm kwargs into a mock request so that it can be decoded properly
|
||||||
req = MyRequest([mm])
|
req = MyRequest([mm])
|
||||||
|
|||||||
@ -240,6 +240,6 @@ class InputRegistry:
|
|||||||
|
|
||||||
return DummyData(
|
return DummyData(
|
||||||
seq_data=SequenceData.from_seqs(dec_data.prompt_token_ids),
|
seq_data=SequenceData.from_seqs(dec_data.prompt_token_ids),
|
||||||
multi_modal_data=dec_data.multi_modal_data,
|
multi_modal_data=dec_data.multi_modal_data.get_data(),
|
||||||
multi_modal_placeholders=dec_data.multi_modal_placeholders,
|
multi_modal_placeholders=dec_data.multi_modal_placeholders,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -136,7 +136,7 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
|
|||||||
type="multimodal",
|
type="multimodal",
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
prompt_token_ids=[1],
|
prompt_token_ids=[1],
|
||||||
mm_kwargs=MultiModalKwargs.from_items(multimodal_kwargs_items),
|
mm_kwargs=MultiModalKwargs(multimodal_kwargs_items),
|
||||||
mm_hashes=None,
|
mm_hashes=None,
|
||||||
mm_placeholders=mm_placeholders,
|
mm_placeholders=mm_placeholders,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -99,7 +99,7 @@ class MultiModalPlaceholderMap:
|
|||||||
seq_mm_placeholders = seq_group.multi_modal_placeholders
|
seq_mm_placeholders = seq_group.multi_modal_placeholders
|
||||||
|
|
||||||
if not seq_mm_data or not seq_mm_placeholders:
|
if not seq_mm_data or not seq_mm_placeholders:
|
||||||
return MultiModalKwargs({}), {}
|
return MultiModalKwargs(), {}
|
||||||
|
|
||||||
placeholder_maps = dict[str, MultiModalPlaceholderMap]()
|
placeholder_maps = dict[str, MultiModalPlaceholderMap]()
|
||||||
|
|
||||||
|
|||||||
@ -46,7 +46,7 @@ class MultiModalCache:
|
|||||||
) -> int:
|
) -> int:
|
||||||
# MultiModalKwargs is not a subclass of dict
|
# MultiModalKwargs is not a subclass of dict
|
||||||
if isinstance(leaf, MultiModalKwargs):
|
if isinstance(leaf, MultiModalKwargs):
|
||||||
return cls.get_item_size(leaf.data, debug=debug)
|
return cls.get_item_size(leaf.get_data(), debug=debug)
|
||||||
|
|
||||||
# MultiModalKwargsItem is not a subclass of dict
|
# MultiModalKwargsItem is not a subclass of dict
|
||||||
if isinstance(leaf, MultiModalKwargsItem):
|
if isinstance(leaf, MultiModalKwargsItem):
|
||||||
|
|||||||
@ -653,7 +653,7 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
|
|||||||
def from_elems(elems: Sequence[MultiModalFieldElem]):
|
def from_elems(elems: Sequence[MultiModalFieldElem]):
|
||||||
return MultiModalKwargsItem({elem.key: elem for elem in elems})
|
return MultiModalKwargsItem({elem.key: elem for elem in elems})
|
||||||
|
|
||||||
def __init__(self, data: Mapping[str, MultiModalFieldElem]) -> None:
|
def __init__(self, data: Mapping[str, MultiModalFieldElem] = {}) -> None:
|
||||||
super().__init__(data)
|
super().__init__(data)
|
||||||
|
|
||||||
modalities = {elem.modality for elem in self.data.values()}
|
modalities = {elem.modality for elem in self.data.values()}
|
||||||
@ -668,9 +668,7 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
|
|||||||
return {key: elem.data for key, elem in self.items()}
|
return {key: elem.data for key, elem in self.items()}
|
||||||
|
|
||||||
|
|
||||||
# NOTE: UserDict is for V0 compatibility.
|
class MultiModalKwargs:
|
||||||
# V1 should access individual items via `get_item`.
|
|
||||||
class MultiModalKwargs(UserDict[str, NestedTensors]):
|
|
||||||
"""
|
"""
|
||||||
A dictionary that represents the keyword arguments to
|
A dictionary that represents the keyword arguments to
|
||||||
[`torch.nn.Module.forward`][].
|
[`torch.nn.Module.forward`][].
|
||||||
@ -714,40 +712,16 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
|
|||||||
elems = [v[item_idx] for v in elems_in_modality.values()]
|
elems = [v[item_idx] for v in elems_in_modality.values()]
|
||||||
items.append(MultiModalKwargsItem.from_elems(elems))
|
items.append(MultiModalKwargsItem.from_elems(elems))
|
||||||
|
|
||||||
return MultiModalKwargs.from_items(items)
|
return MultiModalKwargs(items)
|
||||||
|
|
||||||
@staticmethod
|
def __init__(self, items: Sequence[MultiModalKwargsItem] = ()) -> None:
|
||||||
def from_items(
|
super().__init__()
|
||||||
items: Sequence[MultiModalKwargsItem],
|
|
||||||
*,
|
|
||||||
pin_memory: bool = False,
|
|
||||||
):
|
|
||||||
"""Construct a new
|
|
||||||
[`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs]
|
|
||||||
from multiple items."""
|
|
||||||
elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
|
|
||||||
for item in items:
|
|
||||||
for key, elem in item.items():
|
|
||||||
elems_by_key[key].append(elem)
|
|
||||||
|
|
||||||
data = {
|
items_by_modality = full_groupby(items, key=lambda x: x.modality)
|
||||||
key: elems[0].field.reduce_data(elems, pin_memory=pin_memory)
|
|
||||||
for key, elems in elems_by_key.items() if len(elems) > 0
|
|
||||||
}
|
|
||||||
|
|
||||||
return MultiModalKwargs(data, items=items)
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
data: Mapping[str, NestedTensors],
|
|
||||||
*,
|
|
||||||
items: Optional[Sequence[MultiModalKwargsItem]] = None,
|
|
||||||
) -> None:
|
|
||||||
super().__init__(data)
|
|
||||||
|
|
||||||
items_by_modality = full_groupby(items or [], key=lambda x: x.modality)
|
|
||||||
self._items_by_modality = dict(items_by_modality)
|
self._items_by_modality = dict(items_by_modality)
|
||||||
|
|
||||||
|
self._data: Optional[Mapping[str, NestedTensors]] = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def modalities(self):
|
def modalities(self):
|
||||||
return self._items_by_modality.keys()
|
return self._items_by_modality.keys()
|
||||||
@ -839,22 +813,41 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
|
|||||||
|
|
||||||
return cast(BatchedTensorInputs, json_mapped)
|
return cast(BatchedTensorInputs, json_mapped)
|
||||||
|
|
||||||
def __delitem__(self, key: str) -> None:
|
def keys(self):
|
||||||
super().__delitem__(key)
|
return self.get_data().keys()
|
||||||
|
|
||||||
|
def values(self):
|
||||||
|
return self.get_data().values()
|
||||||
|
|
||||||
|
def items(self):
|
||||||
|
return self.get_data().items()
|
||||||
|
|
||||||
|
def get(self, key: str, /, default=None):
|
||||||
|
return self.get_data().get(key, default)
|
||||||
|
|
||||||
|
def pop(self, key: str, *args, **kwargs):
|
||||||
|
data = dict(self.get_data())
|
||||||
|
res = data.pop(key, *args, **kwargs)
|
||||||
|
|
||||||
for items in self._items_by_modality.values():
|
for items in self._items_by_modality.values():
|
||||||
for item in items:
|
for item in items:
|
||||||
item.pop(key, None)
|
item.pop(key, *args, **kwargs)
|
||||||
|
|
||||||
|
self._data = None
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return iter(self.get_data())
|
||||||
|
|
||||||
|
def __getitem__(self, key: str):
|
||||||
|
return self.get_data()[key]
|
||||||
|
|
||||||
def __eq__(self, other: object) -> bool:
|
def __eq__(self, other: object) -> bool:
|
||||||
if not isinstance(other, self.__class__):
|
if not isinstance(other, self.__class__):
|
||||||
return False
|
return False
|
||||||
if self._items_by_modality != other._items_by_modality:
|
|
||||||
return False
|
|
||||||
|
|
||||||
ks = self.keys()
|
return self._items_by_modality == other._items_by_modality
|
||||||
return (ks == other.keys()
|
|
||||||
and all(nested_tensors_equal(self[k], other[k]) for k in ks))
|
|
||||||
|
|
||||||
def _validate_modality(self, method_name: str, modality: str) -> None:
|
def _validate_modality(self, method_name: str, modality: str) -> None:
|
||||||
if not self._items_by_modality:
|
if not self._items_by_modality:
|
||||||
@ -888,6 +881,25 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
|
|||||||
self._validate_modality("get_items", modality)
|
self._validate_modality("get_items", modality)
|
||||||
return self._items_by_modality[modality]
|
return self._items_by_modality[modality]
|
||||||
|
|
||||||
|
def get_data(self,
|
||||||
|
*,
|
||||||
|
pin_memory: bool = False) -> Mapping[str, NestedTensors]:
|
||||||
|
if self._data is not None:
|
||||||
|
return self._data
|
||||||
|
|
||||||
|
elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
|
||||||
|
for items in self._items_by_modality.values():
|
||||||
|
for item in items:
|
||||||
|
for key, elem in item.items():
|
||||||
|
elems_by_key[key].append(elem)
|
||||||
|
|
||||||
|
data = {
|
||||||
|
key: elems[0].field.reduce_data(elems, pin_memory=pin_memory)
|
||||||
|
for key, elems in elems_by_key.items() if len(elems) > 0
|
||||||
|
}
|
||||||
|
self._data = data
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
MultiModalPlaceholderDict: TypeAlias = Mapping[str, Sequence[PlaceholderRange]]
|
MultiModalPlaceholderDict: TypeAlias = Mapping[str, Sequence[PlaceholderRange]]
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -1480,7 +1480,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
|
|||||||
mm_missing_kwargs=mm_missing_kwargs,
|
mm_missing_kwargs=mm_missing_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
mm_kwargs = MultiModalKwargs.from_items([
|
mm_kwargs = MultiModalKwargs([
|
||||||
item for cache_items in mm_cache_items_merged.values()
|
item for cache_items in mm_cache_items_merged.values()
|
||||||
for item in cache_items
|
for item in cache_items
|
||||||
])
|
])
|
||||||
|
|||||||
@ -402,12 +402,14 @@ def group_mm_kwargs_by_modality(
|
|||||||
for modality, items in groupby(mm_kwargs, key=lambda item: item.modality):
|
for modality, items in groupby(mm_kwargs, key=lambda item: item.modality):
|
||||||
items_lst = list(items)
|
items_lst = list(items)
|
||||||
|
|
||||||
# mm_kwargs_group = MultiModalKwargs.from_items(items_lst,
|
# mm_kwargs_group = MultiModalKwargs(items_lst) \
|
||||||
# pin_memory=pin_memory)
|
# .get_data(pin_memory=pin_memory)
|
||||||
|
|
||||||
# if device is not None:
|
# if device is not None:
|
||||||
# mm_kwargs_group = json_map_leaves(lambda x: x.to(device=device),
|
# mm_kwargs_group = json_map_leaves(
|
||||||
# mm_kwargs_group.data)
|
# lambda x: x.to(device=device),
|
||||||
|
# mm_kwargs_group,
|
||||||
|
# )
|
||||||
|
|
||||||
# TODO: Once V0 is removed, we can use the merging logic above
|
# TODO: Once V0 is removed, we can use the merging logic above
|
||||||
# to avoid creating an extra batch dimension (except for fields
|
# to avoid creating an extra batch dimension (except for fields
|
||||||
@ -415,7 +417,7 @@ def group_mm_kwargs_by_modality(
|
|||||||
# We will also need to update each model to remove `flatten_bn`.
|
# We will also need to update each model to remove `flatten_bn`.
|
||||||
mm_kwargs_group = MultiModalKwargs.as_kwargs(
|
mm_kwargs_group = MultiModalKwargs.as_kwargs(
|
||||||
MultiModalKwargs.batch(
|
MultiModalKwargs.batch(
|
||||||
[MultiModalKwargs.from_items([item]) for item in items_lst],
|
[MultiModalKwargs([item]) for item in items_lst],
|
||||||
pin_memory=pin_memory,
|
pin_memory=pin_memory,
|
||||||
),
|
),
|
||||||
device=device,
|
device=device,
|
||||||
|
|||||||
@ -524,7 +524,7 @@ class Sequence:
|
|||||||
if self.inputs["type"] == "multimodal":
|
if self.inputs["type"] == "multimodal":
|
||||||
return self.inputs["mm_kwargs"]
|
return self.inputs["mm_kwargs"]
|
||||||
|
|
||||||
return MultiModalKwargs({})
|
return MultiModalKwargs()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def multi_modal_placeholders(self) -> MultiModalPlaceholderDict:
|
def multi_modal_placeholders(self) -> MultiModalPlaceholderDict:
|
||||||
@ -780,7 +780,7 @@ class SequenceGroup:
|
|||||||
return self.first_seq.multi_modal_data
|
return self.first_seq.multi_modal_data
|
||||||
elif self.encoder_seq is not None:
|
elif self.encoder_seq is not None:
|
||||||
return self.encoder_seq.multi_modal_data
|
return self.encoder_seq.multi_modal_data
|
||||||
return MultiModalKwargs({})
|
return MultiModalKwargs()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def multi_modal_placeholders(self) -> MultiModalPlaceholderDict:
|
def multi_modal_placeholders(self) -> MultiModalPlaceholderDict:
|
||||||
|
|||||||
@ -117,16 +117,9 @@ class MsgpackEncoder:
|
|||||||
return self._encode_mm_item(obj)
|
return self._encode_mm_item(obj)
|
||||||
|
|
||||||
if isinstance(obj, MultiModalKwargs):
|
if isinstance(obj, MultiModalKwargs):
|
||||||
mm: MultiModalKwargs = obj
|
|
||||||
if not mm.modalities:
|
|
||||||
# just return the main dict if there are no modalities.
|
|
||||||
return dict(mm)
|
|
||||||
|
|
||||||
# ignore the main dict, it will be re-indexed.
|
|
||||||
# Any tensors *not* indexed by modality will be ignored.
|
|
||||||
return [
|
return [
|
||||||
self._encode_mm_item(item)
|
self._encode_mm_item(item)
|
||||||
for itemlist in mm._items_by_modality.values()
|
for itemlist in obj._items_by_modality.values()
|
||||||
for item in itemlist
|
for item in itemlist
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -268,13 +261,7 @@ class MsgpackDecoder:
|
|||||||
if issubclass(t, MultiModalKwargsItem):
|
if issubclass(t, MultiModalKwargsItem):
|
||||||
return self._decode_mm_item(obj)
|
return self._decode_mm_item(obj)
|
||||||
if issubclass(t, MultiModalKwargs):
|
if issubclass(t, MultiModalKwargs):
|
||||||
if isinstance(obj, list):
|
return MultiModalKwargs(self._decode_mm_items(obj))
|
||||||
return MultiModalKwargs.from_items(
|
|
||||||
self._decode_mm_items(obj))
|
|
||||||
return MultiModalKwargs({
|
|
||||||
k: self._decode_nested_tensors(v)
|
|
||||||
for k, v in obj.items()
|
|
||||||
})
|
|
||||||
if t is UtilityResult:
|
if t is UtilityResult:
|
||||||
return self._decode_utility_result(obj)
|
return self._decode_utility_result(obj)
|
||||||
return obj
|
return obj
|
||||||
|
|||||||
@ -58,7 +58,7 @@ class CachedRequestState:
|
|||||||
@deprecated("`mm_inputs` is superseded by `mm_kwargs` and will be "
|
@deprecated("`mm_inputs` is superseded by `mm_kwargs` and will be "
|
||||||
"removed in v0.13. Please use `mm_kwargs` instead.")
|
"removed in v0.13. Please use `mm_kwargs` instead.")
|
||||||
def mm_inputs(self) -> list[MultiModalKwargs]:
|
def mm_inputs(self) -> list[MultiModalKwargs]:
|
||||||
return [MultiModalKwargs.from_items([item]) for item in self.mm_kwargs]
|
return [MultiModalKwargs([item]) for item in self.mm_kwargs]
|
||||||
|
|
||||||
def get_token_id(self, idx: int) -> int:
|
def get_token_id(self, idx: int) -> int:
|
||||||
if idx < self.num_prompt_tokens:
|
if idx < self.num_prompt_tokens:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user