[Chore] Deprecate merge_by_field_config arg (#30035)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Parent: 990f806473
Commit: b286a311c2
@@ -20,7 +20,7 @@ from vllm.config.multimodal import (
 )
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
 from vllm.multimodal.cache import MultiModalProcessorOnlyCache
-from vllm.multimodal.inputs import MultiModalInputs
+from vllm.multimodal.inputs import MultiModalInputs, batched_tensors_equal
 from vllm.multimodal.processing import BaseMultiModalProcessor, InputProcessingContext
 from vllm.tokenizers import (
     MistralTokenizer,
@@ -418,4 +418,4 @@ def _assert_inputs_equal(
         a_data.pop(key, None)
         b_data.pop(key, None)

-    assert a_data == b_data, msg
+    assert batched_tensors_equal(a_data, b_data), msg
@@ -5,6 +5,7 @@ import pytest

 from vllm.assets.video import VideoAsset
 from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import batched_tensors_equal
 from vllm.multimodal.video import OpenCVDynamicVideoBackend, OpenCVVideoBackend

 from ...utils import build_model_context
@@ -103,7 +104,7 @@ def test_video_loader_consistency(
     dynamic_outputs = processor.apply(prompt, dynamic_mm_data, hf_processor_mm_kwargs)

     assert static_outputs["prompt_token_ids"] == dynamic_outputs["prompt_token_ids"]
-    assert (
-        static_outputs["mm_kwargs"].get_data()
-        == dynamic_outputs["mm_kwargs"].get_data()
+    assert batched_tensors_equal(
+        static_outputs["mm_kwargs"].get_data(),
+        dynamic_outputs["mm_kwargs"].get_data(),
     )
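Why the tests switch to the helper: dict equality on tensor values does not reduce to a single boolean, so the old `assert a_data == b_data` style breaks down once real tensors are involved. A standalone illustration (not part of the commit):

    import torch

    a = {"pixel_values": torch.zeros(2, 3)}
    b = {"pixel_values": torch.zeros(2, 3)}

    # dict equality calls Tensor.__eq__, which returns an elementwise
    # tensor; coercing that result to bool raises "Boolean value of
    # Tensor with more than one element is ambiguous".
    try:
        result = a == b
    except RuntimeError as exc:
        print(exc)

    # torch.equal collapses the comparison to one bool; this is what
    # nested_tensors_equal/batched_tensors_equal build on.
    assert torch.equal(a["pixel_values"], b["pixel_values"])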
@@ -130,10 +130,9 @@ def create_batched_mm_kwargs(
         hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
         tokenization_kwargs=processor_inputs.tokenization_kwargs,
     )["mm_kwargs"].require_data()
-    items = [item for modality in supported_mm_limits for item in mm_kwargs[modality]]

     return group_mm_kwargs_by_modality(
-        items,
-        merge_by_field_config=model_cls.merge_by_field_config,
+        [item for modality in supported_mm_limits for item in mm_kwargs[modality]]
     )
@@ -85,12 +85,6 @@ def _dummy_items(
         (_dummy_item("a", {"a1": 100}), 100),
         (_dummy_item("a", {"a1": 100, "a2": 110}), 210),
         (_dummy_items({"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}), 460),  # noqa: E501
-        (
-            _dummy_items(
-                {"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}
-            ).get_data(),
-            460,
-        ),  # noqa: E501
     ],
 )
 def test_cache_item_size(item, expected_size):
@@ -107,6 +101,9 @@ def test_cache_item_size(item, expected_size):
     cache[""] = MultiModalProcessorCacheItemMetadata(item, [prompt_update])
     assert cache.currsize == expected_size

+    cache[""] = item.get_data()
+    assert cache.currsize == expected_size
+

 def _create_vllm_config(
     *,
@@ -1,91 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-import pytest
-import torch
-
-from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
-
-pytestmark = pytest.mark.cpu_test
-
-
-def assert_nested_tensors_equal(expected: NestedTensors, actual: NestedTensors):
-    assert type(expected) == type(actual)  # noqa: E721
-    if isinstance(expected, torch.Tensor):
-        assert torch.equal(expected, actual)
-    else:
-        for expected_item, actual_item in zip(expected, actual):
-            assert_nested_tensors_equal(expected_item, actual_item)
-
-
-def assert_multimodal_inputs_equal(
-    expected: MultiModalKwargs, actual: MultiModalKwargs
-):
-    assert set(expected.keys()) == set(actual.keys())
-    for key in expected:
-        assert_nested_tensors_equal(expected[key], actual[key])
-
-
-def test_multimodal_input_batch_single_tensor():
-    t = torch.rand([1, 2])
-    result = MultiModalKwargs.batch([{"image": t}])
-    assert_multimodal_inputs_equal(result, {"image": t.unsqueeze(0)})
-
-
-def test_multimodal_input_batch_multiple_tensors():
-    a = torch.rand([1, 1, 2])
-    b = torch.rand([1, 1, 2])
-    c = torch.rand([1, 1, 2])
-    result = MultiModalKwargs.batch([{"image": a}, {"image": b}, {"image": c}])
-    assert_multimodal_inputs_equal(result, {"image": torch.stack([a, b, c])})
-
-
-def test_multimodal_input_batch_multiple_heterogeneous_tensors():
-    a = torch.rand([1, 2, 2])
-    b = torch.rand([1, 3, 2])
-    c = torch.rand([1, 4, 2])
-    result = MultiModalKwargs.batch([{"image": a}, {"image": b}, {"image": c}])
-    assert_multimodal_inputs_equal(result, {"image": [a, b, c]})
-
-
-def test_multimodal_input_batch_nested_tensors():
-    a = torch.rand([2, 3])
-    b = torch.rand([2, 3])
-    c = torch.rand([2, 3])
-    result = MultiModalKwargs.batch([{"image": [a]}, {"image": [b]}, {"image": [c]}])
-    assert_multimodal_inputs_equal(
-        result, {"image": torch.stack([a.unsqueeze(0), b.unsqueeze(0), c.unsqueeze(0)])}
-    )
-
-
-def test_multimodal_input_batch_heterogeneous_lists():
-    a = torch.rand([1, 2, 3])
-    b = torch.rand([1, 2, 3])
-    c = torch.rand([1, 2, 3])
-    result = MultiModalKwargs.batch([{"image": [a, b]}, {"image": [c]}])
-    assert_multimodal_inputs_equal(
-        result, {"image": [torch.stack([a, b]), c.unsqueeze(0)]}
-    )
-
-
-def test_multimodal_input_batch_multiple_batchable_lists():
-    a = torch.rand([1, 2, 3])
-    b = torch.rand([1, 2, 3])
-    c = torch.rand([1, 2, 3])
-    d = torch.rand([1, 2, 3])
-    result = MultiModalKwargs.batch([{"image": [a, b]}, {"image": [c, d]}])
-    assert_multimodal_inputs_equal(
-        result, {"image": torch.stack([torch.stack([a, b]), torch.stack([c, d])])}
-    )
-
-
-def test_multimodal_input_batch_mixed_stacking_depths():
-    a = torch.rand([1, 2, 3])
-    b = torch.rand([1, 3, 3])
-    c = torch.rand([1, 4, 3])
-
-    result = MultiModalKwargs.batch([{"image": [a, b]}, {"image": [c]}])
-    assert_multimodal_inputs_equal(result, {"image": [[a, b], c.unsqueeze(0)]})
-
-    result = MultiModalKwargs.batch([{"image": [a]}, {"image": [b, c]}])
-    assert_multimodal_inputs_equal(result, {"image": [a.unsqueeze(0), [b, c]]})
@@ -27,7 +27,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (
     MultiModalDataDict,
     MultiModalFieldConfig,
-    MultiModalKwargs,
+    MultiModalKwargsItems,
     NestedTensors,
 )
 from vllm.multimodal.parse import (
@@ -305,7 +305,7 @@ class DeepseekOCRMultiModalProcessor(
         self,
         mm_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, object],
-        out_mm_kwargs: MultiModalKwargs,
+        out_mm_kwargs: MultiModalKwargsItems,
     ) -> Sequence[PromptUpdate]:
         hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
@@ -78,7 +78,7 @@ class SupportsMultiModal(Protocol):
     `multimodal_config.mm_encoder_tp_mode="data"`.
     """

-    merge_by_field_config: ClassVar[bool] = False
+    merge_by_field_config: ClassVar[bool] = True
     """
     A flag that indicates which implementation of
     `vllm.multimodal.utils.group_mm_kwargs_by_modality` to use.
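With the default flipped to True, migrated models no longer need to set the flag at all; PaliGemma below is moved onto the merged path in the same commit. A hypothetical skeleton of a newly added model (illustrative only; the class name is an assumption):

    from torch import nn

    from vllm.model_executor.models.interfaces import SupportsMultiModal

    class MyNewVLModel(nn.Module, SupportsMultiModal):
        # No `merge_by_field_config = True` line needed: the protocol
        # default is now True, and the attribute itself is scheduled
        # for removal after v0.13 (see the TODO in
        # group_mm_kwargs_by_modality further down).
        ...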
@@ -28,7 +28,7 @@ from vllm.model_executor.models.utils import (
 )
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.cache import BaseMultiModalProcessorCache
-from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
+from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargsItems
 from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems
 from vllm.multimodal.processing import (
     BaseMultiModalProcessor,
@@ -103,7 +103,7 @@ class LightOnOCRMultiModalProcessor(BaseMultiModalProcessor[Mistral3ProcessingIn
         self,
         mm_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, object],
-        out_mm_kwargs: MultiModalKwargs,
+        out_mm_kwargs: MultiModalKwargsItems,
     ) -> Sequence[PromptUpdate]:
         hf_config = self.info.get_hf_config()
         image_token_id = hf_config.image_token_index
@@ -52,7 +52,6 @@ from vllm.multimodal.evs import (
 from vllm.multimodal.inputs import (
     MultiModalDataDict,
     MultiModalFieldConfig,
-    MultiModalKwargs,
     MultiModalKwargsItems,
     VideoItem,
 )
@@ -849,17 +848,18 @@ class NanoNemotronBaseVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
         self,
         mm_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, object],
-        out_mm_kwargs: MultiModalKwargs,
+        out_mm_kwargs: MultiModalKwargsItems,
     ) -> Sequence[PromptUpdate]:
         hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

-        if "image_num_patches" in out_mm_kwargs:
-            image_num_patches = out_mm_kwargs["image_num_patches"]
+        out_mm_data = out_mm_kwargs.get_data()
+        if "image_num_patches" in out_mm_data:
+            image_num_patches = out_mm_data["image_num_patches"]
             assert isinstance(image_num_patches, torch.Tensor)
             image_num_patches = image_num_patches.tolist()
-        elif "image_embeds" in out_mm_kwargs:
+        elif "image_embeds" in out_mm_data:
             # to compute num_patches (similar to Qwen2-VL)
-            image_num_patches = [None] * len(out_mm_kwargs["image_embeds"])
+            image_num_patches = [None] * len(out_mm_data["image_embeds"])
         else:
             image_num_patches = []
@@ -23,7 +23,7 @@ from vllm.config import VllmConfig
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (
     MultiModalFieldConfig,
-    MultiModalKwargs,
+    MultiModalKwargsItems,
 )
 from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
 from vllm.multimodal.processing import (
@@ -153,7 +153,7 @@ class OpenCUAMultiModalProcessor(BaseMultiModalProcessor[OpenCUAProcessingInfo])
         self,
         mm_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, Any],
-        out_mm_kwargs: MultiModalKwargs,
+        out_mm_kwargs: MultiModalKwargsItems,
     ) -> Sequence[PromptUpdate]:
         hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
         image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)
@@ -62,7 +62,7 @@ from vllm.multimodal.inputs import (
     MultiModalDataDict,
     MultiModalFeatureSpec,
     MultiModalFieldConfig,
-    MultiModalKwargs,
+    MultiModalKwargsItems,
 )
 from vllm.multimodal.parse import (
     ImageProcessorItems,
@@ -307,7 +307,7 @@ class PaddleOCRVLMultiModalProcessor(
         self,
         mm_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, object],
-        out_mm_kwargs: MultiModalKwargs,
+        out_mm_kwargs: MultiModalKwargsItems,
     ) -> Sequence[PromptUpdate]:
         image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)
         hf_config = self.info.get_hf_config()
@@ -40,7 +40,6 @@ from .siglip import SiglipVisionModel
 from .utils import (
     AutoWeightsLoader,
     WeightsMapper,
-    flatten_bn,
     init_vllm_registered_model,
     maybe_prefix,
 )
@@ -252,6 +251,8 @@ class PaliGemmaMultiModalProcessor(BaseMultiModalProcessor[PaliGemmaProcessingIn
     dummy_inputs=PaliGemmaDummyInputsBuilder,
 )
 class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
+    merge_by_field_config = True
+
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -327,9 +328,8 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
             return None

         if pixel_values is not None:
-            pixel_values = flatten_bn(pixel_values, concat=True)
-
             h = w = self.config.vision_config.image_size

             return PaliGemmaImagePixelInputs(
                 type="pixel_values",
                 data=pixel_values,
@@ -337,8 +337,6 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
         )

         if image_embeds is not None:
-            image_embeds = flatten_bn(image_embeds, concat=True)
-
             return PaliGemmaImageEmbeddingInputs(
                 type="image_embeds",
                 data=image_embeds,
@@ -77,7 +77,7 @@ from vllm.multimodal.evs import (
 from vllm.multimodal.inputs import (
     MultiModalFeatureSpec,
     MultiModalFieldConfig,
-    MultiModalKwargs,
+    MultiModalKwargsItems,
 )
 from vllm.multimodal.parse import MultiModalDataItems
 from vllm.multimodal.processing import PromptReplacement, PromptUpdate
@@ -973,7 +973,7 @@ class Qwen2_5_VLMultiModalProcessor(Qwen2VLMultiModalProcessor):
         self,
         mm_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, Any],
-        out_mm_kwargs: MultiModalKwargs,
+        out_mm_kwargs: MultiModalKwargsItems,
     ) -> Sequence[PromptUpdate]:
         hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
         image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)
@@ -25,7 +25,6 @@ from .inputs import (
     MultiModalBatchedField,
     MultiModalFeatureSpec,
     MultiModalFieldElem,
-    MultiModalKwargs,
     MultiModalKwargsItem,
     MultiModalKwargsItems,
     NestedTensors,
@@ -90,7 +89,6 @@ MultiModalCacheValue: TypeAlias = (
    | MultiModalProcessorCacheItemMetadata
    | MultiModalKwargsItems
    | MultiModalKwargsItem
-   | MultiModalKwargs
    | Mapping[str, NestedTensors]
 )
@@ -108,12 +106,7 @@ class MultiModalCache:
         # These are not subclasses of dict
         if isinstance(
             leaf,
-            (
-                MultiModalKwargs,
-                MultiModalKwargsItems,
-                MultiModalKwargsItem,
-                MultiModalFieldElem,
-            ),
+            (MultiModalKwargsItems, MultiModalKwargsItem, MultiModalFieldElem),
         ):
             return cls.get_item_size(leaf.data)  # type: ignore
@@ -3,7 +3,7 @@

 from abc import ABC, abstractmethod
 from collections import UserDict, defaultdict
-from collections.abc import Mapping, Sequence
+from collections.abc import Mapping, Sequence, Set
 from dataclasses import dataclass
 from functools import partial
 from itertools import accumulate
@@ -201,8 +201,10 @@ Uses a list instead of a tensor if the dimensions of each element do not match.


 def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:
-    """Equality check between
-    [`NestedTensors`][vllm.multimodal.inputs.NestedTensors] objects."""
+    """
+    Equality check between
+    [`NestedTensors`][vllm.multimodal.inputs.NestedTensors] objects.
+    """
     if isinstance(a, torch.Tensor):
         return isinstance(b, torch.Tensor) and torch.equal(a, b)
     elif isinstance(b, torch.Tensor):
@@ -224,10 +226,24 @@ def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:
 BatchedTensorInputs: TypeAlias = dict[str, NestedTensors]
 """
 A dictionary containing nested tensors which have been batched via
-[`MultiModalKwargs.batch`][vllm.multimodal.inputs.MultiModalKwargs.batch].
+[`MultiModalKwargsItems.get_data`][vllm.multimodal.inputs.MultiModalKwargsItems.get_data].
 """
+
+
+def batched_tensors_equal(a: BatchedTensorInputs, b: BatchedTensorInputs) -> bool:
+    """
+    Equality check between
+    [`BatchedTensorInputs`][vllm.multimodal.inputs.BatchedTensorInputs] objects.
+    """
+    for k in a:
+        if k not in b:
+            return False
+        if not nested_tensors_equal(a[k], b[k]):
+            return False
+
+    return True


 @dataclass
 class MultiModalFeatureSpec:
     """
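A quick usage sketch for the new helper (standalone, assuming only torch and vLLM are installed):

    import torch

    from vllm.multimodal.inputs import batched_tensors_equal

    x = torch.arange(6).reshape(2, 3)
    a = {"pixel_values": x, "image_grid_thw": [torch.tensor([1, 2, 2])]}
    b = {"pixel_values": x.clone(), "image_grid_thw": [torch.tensor([1, 2, 2])]}

    assert batched_tensors_equal(a, b)       # tensors and nested lists compare equal
    assert not batched_tensors_equal(a, {})  # key missing from the right-hand side

Note that the loop only walks the keys of `a`; both call sites in this commit compare dicts produced by the same pipeline, where the key sets always match.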
@@ -823,7 +839,14 @@ class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):
         return self  # type: ignore[return-value]

-    def get_data(self, *, pin_memory: bool = False) -> "MultiModalKwargs":
+    def get_data(
+        self,
+        *,
+        device: torch.types.Device = None,
+        pin_memory: bool = False,
+        cpu_fields: Set[str] = frozenset(),
+    ) -> BatchedTensorInputs:
         """Construct a dictionary of keyword arguments to pass to the model."""
         elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
         for modality, items in self.items():
             for i, item in enumerate(items):
@@ -835,12 +858,23 @@ class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):
                 for key, elem in item.items():
                     elems_by_key[key].append(elem)

-        return MultiModalKwargs(
-            {
-                key: elems[0].field.reduce_data(elems, pin_memory=pin_memory)
-                for key, elems in elems_by_key.items()
-            }
-        )
+        data = {
+            key: elems[0].field.reduce_data(elems, pin_memory=pin_memory)
+            for key, elems in elems_by_key.items()
+        }
+
+        if device is not None:
+            for k in data.keys() - cpu_fields:
+                data[k] = json_map_leaves(
+                    (
+                        lambda x: x.to(device=device, non_blocking=True)
+                        if isinstance(x, torch.Tensor)
+                        else x
+                    ),
+                    data[k],
+                )
+
+        return data


 MultiModalKwargsOptionalItems: TypeAlias = (
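The device transfer is now folded into get_data itself rather than done by callers after batching. A hedged call sketch (the variable name and the CPU field are illustrative, not from the commit):

    # `items` is a MultiModalKwargsItems produced by the processor.
    batched = items.get_data(
        device="cuda:0",              # move tensors while batching
        pin_memory=True,              # pinned staging for non-blocking H2D copies
        cpu_fields={"image_sizes"},   # hypothetical field that must stay on CPU
    )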
@@ -849,6 +883,7 @@ MultiModalKwargsOptionalItems: TypeAlias = (
 )


+@deprecated("`MultiModalKwargs` is deprecated and will be removed in v0.13.")
 class MultiModalKwargs(UserDict[str, NestedTensors]):
     """
     A dictionary that represents the keyword arguments to
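Assuming the decorator is the PEP 702-style `deprecated` (typing_extensions), instantiating the class now emits a DeprecationWarning; a minimal check:

    import warnings

    from vllm.multimodal.inputs import MultiModalKwargs

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        MultiModalKwargs({})  # deprecated class, still constructible until v0.13
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)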
@@ -882,91 +917,6 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
     ):
         return MultiModalKwargsItems.from_seq(items).get_data(pin_memory=pin_memory)

-    @staticmethod
-    def _try_stack(
-        nested_tensors: NestedTensors, pin_memory: bool = False
-    ) -> NestedTensors:
-        """
-        Stack the inner dimensions that have the same shape in
-        a nested list of tensors.
-
-        Thus, a dimension represented by a list means that the inner
-        dimensions are different for each element along that dimension.
-        """
-        if isinstance(nested_tensors, torch.Tensor):
-            return nested_tensors
-
-        # TODO: Remove these once all models have been migrated
-        if isinstance(nested_tensors, np.ndarray):
-            return torch.from_numpy(nested_tensors)
-        if isinstance(nested_tensors, (int, float)):
-            return torch.tensor(nested_tensors)
-
-        stacked = [MultiModalKwargs._try_stack(t, pin_memory) for t in nested_tensors]
-        if not is_list_of(stacked, torch.Tensor, check="all"):
-            # Only tensors (not lists) can be stacked.
-            return stacked
-
-        tensors_ = cast(list[torch.Tensor], stacked)
-        if len(tensors_) == 1:
-            # An optimization when `tensors_` contains only one tensor:
-            # - produce exactly same result as `torch.stack(tensors_)`
-            # - will achieve zero-copy if the tensor is contiguous
-            return tensors_[0].unsqueeze(0).contiguous()
-
-        if any(t.shape != tensors_[0].shape for t in tensors_):
-            # The tensors have incompatible shapes and can't be stacked.
-            return tensors_
-
-        outputs = torch.empty(
-            len(tensors_),
-            *tensors_[0].shape,
-            dtype=tensors_[0].dtype,
-            device=tensors_[0].device,
-            pin_memory=pin_memory,
-        )
-        return torch.stack(tensors_, out=outputs)
-
-    @staticmethod
-    def batch(
-        inputs_list: list["MultiModalKwargs"], pin_memory: bool = False
-    ) -> BatchedTensorInputs:
-        """
-        Batch multiple inputs together into a dictionary.
-
-        The resulting dictionary has the same keys as the inputs.
-        If the corresponding value from each input is a tensor and they all
-        share the same shape, the output value is a single batched tensor;
-        otherwise, the output value is a list containing the original value
-        from each input.
-        """
-        if len(inputs_list) == 0:
-            return {}
-
-        # We need to consider the case where each item in the batch
-        # contains different modalities (i.e. different keys).
-        item_lists = defaultdict[str, list[NestedTensors]](list)
-
-        for inputs in inputs_list:
-            for k, v in inputs.items():
-                item_lists[k].append(v)
-
-        return {
-            k: MultiModalKwargs._try_stack(item_list, pin_memory)
-            for k, item_list in item_lists.items()
-        }
-
-    @staticmethod
-    def as_kwargs(
-        batched_inputs: BatchedTensorInputs,
-        *,
-        device: torch.types.Device,
-    ) -> BatchedTensorInputs:
-        return json_map_leaves(
-            lambda x: x.to(device=device, non_blocking=True),
-            batched_inputs,
-        )
-
     def __getitem__(self, key: str):
         if key not in self:
             raise KeyError(
@@ -19,7 +19,6 @@ from PIL import Image, UnidentifiedImageError
 import vllm.envs as envs
 from vllm.connections import HTTPConnection, global_http_connection
 from vllm.logger import init_logger
-from vllm.utils.jsontree import json_map_leaves
 from vllm.utils.registry import ExtensionManager

 from .audio import AudioEmbeddingMediaIO, AudioMediaIO
@@ -427,59 +426,25 @@ def group_mm_kwargs_by_modality(
     Yields:
         A tuple `(modality, num_items, grouped_kwargs)`.
     """
-    if merge_by_field_config is None:
-        raise RuntimeError(
-            "`group_mm_kwargs_by_modality` now requires "
-            "`merge_by_field_config` arg, please update your model runner "
-            "according to https://github.com/vllm-project/vllm/pull/25676."
-        )
-    if merge_by_field_config is False:
+    # TODO: After v0.13, remove merge_by_field_config attribute from model impls
+    if merge_by_field_config is not None:
         logger.warning_once(
-            "The legacy code for batching multi-modal kwargs is deprecated and "
-            "will be removed in v0.12. Please update your model with "
-            "`merge_by_field_config=True` to use the new code defined by "
-            "`MultiModalFieldConfig`. You can refer to "
-            "https://github.com/vllm-project/vllm/issues/26149 "
-            "for some examples on how to do this."
+            "The `merge_by_field_config` argument of `group_mm_kwargs_by_modality` "
+            "is deprecated and will be removed in v0.13."
         )

-    from vllm.multimodal.inputs import MultiModalKwargs, MultiModalKwargsItems
+    from vllm.multimodal.inputs import MultiModalKwargsItems

     for modality, items in groupby(mm_kwargs, key=lambda item: item.modality):
         items_lst = list(items)
+        mm_kwargs_items = MultiModalKwargsItems.from_seq(items_lst)
+        mm_kwargs_data = mm_kwargs_items.get_data(
+            device=device,
+            pin_memory=pin_memory,
+            cpu_fields=multimodal_cpu_fields,
+        )

-        if merge_by_field_config:
-            mm_kwargs_group: BatchedTensorInputs = dict(
-                MultiModalKwargsItems.from_seq(items_lst).get_data(
-                    pin_memory=pin_memory
-                )
-            )
-
-            if device is not None:
-                mm_kwargs_group = {
-                    k: json_map_leaves(
-                        lambda x: x.to(device=device, non_blocking=True)
-                        if isinstance(x, torch.Tensor)
-                        else x,
-                        v,
-                    )
-                    if k not in multimodal_cpu_fields
-                    else v
-                    for k, v in mm_kwargs_group.items()
-                }
-        else:
-            mm_kwargs_group = MultiModalKwargs.as_kwargs(
-                MultiModalKwargs.batch(
-                    [
-                        MultiModalKwargsItems.from_seq([item]).get_data()
-                        for item in items_lst
-                    ],
-                    pin_memory=pin_memory,
-                ),
-                device=device,
-            )
-
-        yield modality, len(items_lst), mm_kwargs_group
+        yield modality, len(items_lst), mm_kwargs_data


 def fetch_audio(
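Call sites in the model runners simplify accordingly; a condensed sketch of the new usage (the loop variables stand in for runner state, and the encoder call is shortened from the runner code):

    from vllm.multimodal.utils import group_mm_kwargs_by_modality

    # mm_kwargs: list[MultiModalKwargsItem]; device/pin_memory come from the runner.
    for modality, num_items, grouped in group_mm_kwargs_by_modality(
        mm_kwargs,
        device=device,
        pin_memory=pin_memory,
        multimodal_cpu_fields=model.multimodal_cpu_fields,
    ):
        # Run the encoder on this modality group; no per-model
        # merge_by_field_config plumbing anymore.
        outputs = model.get_multimodal_embeddings(**grouped)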
@@ -27,7 +27,6 @@ from vllm.multimodal.inputs import (
     MultiModalFieldConfig,
     MultiModalFieldElem,
     MultiModalFlatField,
-    MultiModalKwargs,
     MultiModalKwargsItem,
     MultiModalKwargsItems,
     MultiModalSharedField,
@@ -176,9 +175,6 @@ class MsgpackEncoder:
         if isinstance(obj, MultiModalKwargsItems):
             return self._encode_mm_items(obj)

-        if isinstance(obj, MultiModalKwargs):
-            return self._encode_mm_kwargs(obj)
-
         if isinstance(obj, UtilityResult):
             result = obj.result
             if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
@@ -259,11 +255,6 @@ class MsgpackEncoder:
             "field": self._encode_mm_field(elem.field),
         }

-    def _encode_mm_kwargs(self, kw: MultiModalKwargs) -> dict[str, Any]:
-        return {
-            modality: self._encode_nested_tensors(data) for modality, data in kw.items()
-        }
-
     def _encode_nested_tensors(self, nt: NestedTensors) -> Any:
         if isinstance(nt, torch.Tensor):
             return self._encode_tensor(nt)
@@ -325,8 +316,6 @@ class MsgpackDecoder:
             return self._decode_mm_item(obj)
         if issubclass(t, MultiModalKwargsItems):
             return self._decode_mm_items(obj)
-        if issubclass(t, MultiModalKwargs):
-            return self._decode_mm_kwargs(obj)
         if t is UtilityResult:
             return self._decode_utility_result(obj)
         return obj
@@ -414,14 +403,6 @@ class MsgpackDecoder:
         obj["field"] = factory_meth(None, *field_args).field
         return MultiModalFieldElem(**obj)

-    def _decode_mm_kwargs(self, obj: dict[str, Any]) -> MultiModalKwargs:
-        return MultiModalKwargs(
-            {
-                modality: self._decode_nested_tensors(data)
-                for modality, data in obj.items()
-            }
-        )
-
     def _decode_nested_tensors(self, obj: Any) -> NestedTensors:
         if isinstance(obj, (int, float)):
             # Although it violates NestedTensors type, MultiModalKwargs
@@ -2106,7 +2106,6 @@ class GPUModelRunner(
             mm_kwargs,
             device=self.device,
             pin_memory=self.pin_memory,
-            merge_by_field_config=model.merge_by_field_config,
             multimodal_cpu_fields=model.multimodal_cpu_fields,
         ):
             curr_group_outputs: list[torch.Tensor] = []
@@ -2133,7 +2132,6 @@ class GPUModelRunner(
                 [video_mm_kwargs_item],
                 device=self.device,
                 pin_memory=self.pin_memory,
-                merge_by_field_config=model.merge_by_field_config,
                 multimodal_cpu_fields=model.multimodal_cpu_fields,
             )
         )
@@ -3849,7 +3847,6 @@ class GPUModelRunner(
                 dummy_mm_items,
                 device=self.device,
                 pin_memory=self.pin_memory,
-                merge_by_field_config=model.merge_by_field_config,
                 multimodal_cpu_fields=model.multimodal_cpu_fields,
             )
         )
@@ -969,7 +969,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             mm_kwargs,
             device=self.device,
             pin_memory=self.pin_memory,
-            merge_by_field_config=model.merge_by_field_config,
             multimodal_cpu_fields=model.multimodal_cpu_fields,
         ):
             # Run the encoder.
@@ -2058,7 +2057,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 dummy_mm_items,
                 device=self.device,
                 pin_memory=self.pin_memory,
-                merge_by_field_config=model.merge_by_field_config,
                 multimodal_cpu_fields=model.multimodal_cpu_fields,
             )
         )