diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py
index 6b9d388f2b9b4..2e032ac4ca526 100644
--- a/tests/models/multimodal/processing/test_common.py
+++ b/tests/models/multimodal/processing/test_common.py
@@ -20,7 +20,7 @@ from vllm.config.multimodal import (
 )
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
 from vllm.multimodal.cache import MultiModalProcessorOnlyCache
-from vllm.multimodal.inputs import MultiModalInputs
+from vllm.multimodal.inputs import MultiModalInputs, batched_tensors_equal
 from vllm.multimodal.processing import BaseMultiModalProcessor, InputProcessingContext
 from vllm.tokenizers import (
     MistralTokenizer,
@@ -418,4 +418,4 @@ def _assert_inputs_equal(
         a_data.pop(key, None)
         b_data.pop(key, None)

-    assert a_data == b_data, msg
+    assert batched_tensors_equal(a_data, b_data), msg
diff --git a/tests/models/multimodal/processing/test_glm4_1v.py b/tests/models/multimodal/processing/test_glm4_1v.py
index 553a5f719bd35..51071c93531de 100644
--- a/tests/models/multimodal/processing/test_glm4_1v.py
+++ b/tests/models/multimodal/processing/test_glm4_1v.py
@@ -5,6 +5,7 @@ import pytest

 from vllm.assets.video import VideoAsset
 from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import batched_tensors_equal
 from vllm.multimodal.video import OpenCVDynamicVideoBackend, OpenCVVideoBackend

 from ...utils import build_model_context
@@ -103,7 +104,7 @@ def test_video_loader_consistency(
     dynamic_outputs = processor.apply(prompt, dynamic_mm_data, hf_processor_mm_kwargs)

     assert static_outputs["prompt_token_ids"] == dynamic_outputs["prompt_token_ids"]
-    assert (
-        static_outputs["mm_kwargs"].get_data()
-        == dynamic_outputs["mm_kwargs"].get_data()
+    assert batched_tensors_equal(
+        static_outputs["mm_kwargs"].get_data(),
+        dynamic_outputs["mm_kwargs"].get_data(),
     )
diff --git a/tests/models/multimodal/processing/test_tensor_schema.py b/tests/models/multimodal/processing/test_tensor_schema.py
index 7628ab4fe2349..5d489549c5b46 100644
--- a/tests/models/multimodal/processing/test_tensor_schema.py
+++ b/tests/models/multimodal/processing/test_tensor_schema.py
@@ -130,10 +130,9 @@ def create_batched_mm_kwargs(
         hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
         tokenization_kwargs=processor_inputs.tokenization_kwargs,
     )["mm_kwargs"].require_data()
-    items = [item for modality in supported_mm_limits for item in mm_kwargs[modality]]
+
     return group_mm_kwargs_by_modality(
-        items,
-        merge_by_field_config=model_cls.merge_by_field_config,
+        [item for modality in supported_mm_limits for item in mm_kwargs[modality]]
     )


diff --git a/tests/multimodal/test_cache.py b/tests/multimodal/test_cache.py
index 2ddc93f8daf7b..e4fcc34740edb 100644
--- a/tests/multimodal/test_cache.py
+++ b/tests/multimodal/test_cache.py
@@ -85,12 +85,6 @@ def _dummy_items(
         (_dummy_item("a", {"a1": 100}), 100),
         (_dummy_item("a", {"a1": 100, "a2": 110}), 210),
         (_dummy_items({"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}), 460),  # noqa: E501
-        (
-            _dummy_items(
-                {"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}
-            ).get_data(),
-            460,
-        ),  # noqa: E501
     ],
 )
 def test_cache_item_size(item, expected_size):
@@ -107,6 +101,9 @@ def test_cache_item_size(item, expected_size):
     cache[""] = MultiModalProcessorCacheItemMetadata(item, [prompt_update])
     assert cache.currsize == expected_size

+    cache[""] = item.get_data()
+    assert cache.currsize == expected_size
+

 def _create_vllm_config(
     *,
diff --git a/tests/multimodal/test_inputs.py b/tests/multimodal/test_inputs.py
deleted file mode 100644
index 88e92bee3a292..0000000000000
--- a/tests/multimodal/test_inputs.py
+++ /dev/null
@@ -1,91 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-import pytest
-import torch
-
-from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
-
-pytestmark = pytest.mark.cpu_test
-
-
-def assert_nested_tensors_equal(expected: NestedTensors, actual: NestedTensors):
-    assert type(expected) == type(actual)  # noqa: E721
-    if isinstance(expected, torch.Tensor):
-        assert torch.equal(expected, actual)
-    else:
-        for expected_item, actual_item in zip(expected, actual):
-            assert_nested_tensors_equal(expected_item, actual_item)
-
-
-def assert_multimodal_inputs_equal(
-    expected: MultiModalKwargs, actual: MultiModalKwargs
-):
-    assert set(expected.keys()) == set(actual.keys())
-    for key in expected:
-        assert_nested_tensors_equal(expected[key], actual[key])
-
-
-def test_multimodal_input_batch_single_tensor():
-    t = torch.rand([1, 2])
-    result = MultiModalKwargs.batch([{"image": t}])
-    assert_multimodal_inputs_equal(result, {"image": t.unsqueeze(0)})
-
-
-def test_multimodal_input_batch_multiple_tensors():
-    a = torch.rand([1, 1, 2])
-    b = torch.rand([1, 1, 2])
-    c = torch.rand([1, 1, 2])
-    result = MultiModalKwargs.batch([{"image": a}, {"image": b}, {"image": c}])
-    assert_multimodal_inputs_equal(result, {"image": torch.stack([a, b, c])})
-
-
-def test_multimodal_input_batch_multiple_heterogeneous_tensors():
-    a = torch.rand([1, 2, 2])
-    b = torch.rand([1, 3, 2])
-    c = torch.rand([1, 4, 2])
-    result = MultiModalKwargs.batch([{"image": a}, {"image": b}, {"image": c}])
-    assert_multimodal_inputs_equal(result, {"image": [a, b, c]})
-
-
-def test_multimodal_input_batch_nested_tensors():
-    a = torch.rand([2, 3])
-    b = torch.rand([2, 3])
-    c = torch.rand([2, 3])
-    result = MultiModalKwargs.batch([{"image": [a]}, {"image": [b]}, {"image": [c]}])
-    assert_multimodal_inputs_equal(
-        result, {"image": torch.stack([a.unsqueeze(0), b.unsqueeze(0), c.unsqueeze(0)])}
-    )
-
-
-def test_multimodal_input_batch_heterogeneous_lists():
-    a = torch.rand([1, 2, 3])
-    b = torch.rand([1, 2, 3])
-    c = torch.rand([1, 2, 3])
-    result = MultiModalKwargs.batch([{"image": [a, b]}, {"image": [c]}])
-    assert_multimodal_inputs_equal(
-        result, {"image": [torch.stack([a, b]), c.unsqueeze(0)]}
-    )
-
-
-def test_multimodal_input_batch_multiple_batchable_lists():
-    a = torch.rand([1, 2, 3])
-    b = torch.rand([1, 2, 3])
-    c = torch.rand([1, 2, 3])
-    d = torch.rand([1, 2, 3])
-    result = MultiModalKwargs.batch([{"image": [a, b]}, {"image": [c, d]}])
-    assert_multimodal_inputs_equal(
-        result, {"image": torch.stack([torch.stack([a, b]), torch.stack([c, d])])}
-    )
-
-
-def test_multimodal_input_batch_mixed_stacking_depths():
-    a = torch.rand([1, 2, 3])
-    b = torch.rand([1, 3, 3])
-    c = torch.rand([1, 4, 3])
-
-    result = MultiModalKwargs.batch([{"image": [a, b]}, {"image": [c]}])
-    assert_multimodal_inputs_equal(result, {"image": [[a, b], c.unsqueeze(0)]})
-
-    result = MultiModalKwargs.batch([{"image": [a]}, {"image": [b, c]}])
-    assert_multimodal_inputs_equal(result, {"image": [a.unsqueeze(0), [b, c]]})
diff --git a/vllm/model_executor/models/deepseek_ocr.py b/vllm/model_executor/models/deepseek_ocr.py
index 019fb3e29ab91..a612ebd956282 100644
--- a/vllm/model_executor/models/deepseek_ocr.py
+++ b/vllm/model_executor/models/deepseek_ocr.py
@@ -27,7 +27,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (
     MultiModalDataDict,
     MultiModalFieldConfig,
-    MultiModalKwargs,
+    MultiModalKwargsItems,
     NestedTensors,
 )
 from vllm.multimodal.parse import (
@@ -305,7 +305,7 @@ class DeepseekOCRMultiModalProcessor(
         self,
         mm_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, object],
-        out_mm_kwargs: MultiModalKwargs,
+        out_mm_kwargs: MultiModalKwargsItems,
     ) -> Sequence[PromptUpdate]:
         hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py
index 0f65683cf7c57..01b3e7827424d 100644
--- a/vllm/model_executor/models/interfaces.py
+++ b/vllm/model_executor/models/interfaces.py
@@ -78,7 +78,7 @@ class SupportsMultiModal(Protocol):
     `multimodal_config.mm_encoder_tp_mode="data"`.
     """

-    merge_by_field_config: ClassVar[bool] = False
+    merge_by_field_config: ClassVar[bool] = True
     """
     A flag that indicates which implementation of
     `vllm.multimodal.utils.group_mm_kwargs_by_modality` to use.
diff --git a/vllm/model_executor/models/lightonocr.py b/vllm/model_executor/models/lightonocr.py
index 9839e4f8f707e..353ee7806b1b1 100644
--- a/vllm/model_executor/models/lightonocr.py
+++ b/vllm/model_executor/models/lightonocr.py
@@ -28,7 +28,7 @@ from vllm.model_executor.models.utils import (
 )
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.cache import BaseMultiModalProcessorCache
-from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
+from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargsItems
 from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems
 from vllm.multimodal.processing import (
     BaseMultiModalProcessor,
@@ -103,7 +103,7 @@ class LightOnOCRMultiModalProcessor(BaseMultiModalProcessor[Mistral3ProcessingIn
         self,
         mm_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, object],
-        out_mm_kwargs: MultiModalKwargs,
+        out_mm_kwargs: MultiModalKwargsItems,
     ) -> Sequence[PromptUpdate]:
         hf_config = self.info.get_hf_config()
         image_token_id = hf_config.image_token_index
diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py
index 891a9ce080233..c4198d36b392e 100644
--- a/vllm/model_executor/models/nano_nemotron_vl.py
+++ b/vllm/model_executor/models/nano_nemotron_vl.py
@@ -52,7 +52,6 @@ from vllm.multimodal.evs import (
 from vllm.multimodal.inputs import (
     MultiModalDataDict,
     MultiModalFieldConfig,
-    MultiModalKwargs,
     MultiModalKwargsItems,
     VideoItem,
 )
@@ -849,17 +848,18 @@ class NanoNemotronBaseVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
         self,
         mm_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, object],
-        out_mm_kwargs: MultiModalKwargs,
+        out_mm_kwargs: MultiModalKwargsItems,
     ) -> Sequence[PromptUpdate]:
         hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

-        if "image_num_patches" in out_mm_kwargs:
-            image_num_patches = out_mm_kwargs["image_num_patches"]
+        out_mm_data = out_mm_kwargs.get_data()
+        if "image_num_patches" in out_mm_data:
+            image_num_patches = out_mm_data["image_num_patches"]
             assert isinstance(image_num_patches, torch.Tensor)
             image_num_patches = image_num_patches.tolist()
-        elif "image_embeds" in out_mm_kwargs:
+        elif "image_embeds" in out_mm_data:
             # to compute num_patches (similar to Qwen2-VL)
-            image_num_patches = [None] * len(out_mm_kwargs["image_embeds"])
+            image_num_patches = [None] * len(out_mm_data["image_embeds"])
         else:
             image_num_patches = []

diff --git a/vllm/model_executor/models/opencua.py b/vllm/model_executor/models/opencua.py
index 4338918663378..b92f0c9dac32b 100644
--- a/vllm/model_executor/models/opencua.py
+++ b/vllm/model_executor/models/opencua.py
@@ -23,7 +23,7 @@ from vllm.config import VllmConfig
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (
     MultiModalFieldConfig,
-    MultiModalKwargs,
+    MultiModalKwargsItems,
 )
 from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
 from vllm.multimodal.processing import (
@@ -153,7 +153,7 @@ class OpenCUAMultiModalProcessor(BaseMultiModalProcessor[OpenCUAProcessingInfo])
         self,
         mm_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, Any],
-        out_mm_kwargs: MultiModalKwargs,
+        out_mm_kwargs: MultiModalKwargsItems,
     ) -> Sequence[PromptUpdate]:
         hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
         image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)
diff --git a/vllm/model_executor/models/paddleocr_vl.py b/vllm/model_executor/models/paddleocr_vl.py
index 5256d8ba7fd86..1df5ff62fa5b5 100644
--- a/vllm/model_executor/models/paddleocr_vl.py
+++ b/vllm/model_executor/models/paddleocr_vl.py
@@ -62,7 +62,7 @@ from vllm.multimodal.inputs import (
     MultiModalDataDict,
     MultiModalFeatureSpec,
     MultiModalFieldConfig,
-    MultiModalKwargs,
+    MultiModalKwargsItems,
 )
 from vllm.multimodal.parse import (
     ImageProcessorItems,
@@ -307,7 +307,7 @@ class PaddleOCRVLMultiModalProcessor(
         self,
         mm_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, object],
-        out_mm_kwargs: MultiModalKwargs,
+        out_mm_kwargs: MultiModalKwargsItems,
     ) -> Sequence[PromptUpdate]:
         image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)
         hf_config = self.info.get_hf_config()
diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py
index ec5d0fa6226dd..9fa32f01d37a0 100644
--- a/vllm/model_executor/models/paligemma.py
+++ b/vllm/model_executor/models/paligemma.py
@@ -40,7 +40,6 @@ from .siglip import SiglipVisionModel
 from .utils import (
     AutoWeightsLoader,
     WeightsMapper,
-    flatten_bn,
     init_vllm_registered_model,
     maybe_prefix,
 )
@@ -252,6 +251,8 @@ class PaliGemmaMultiModalProcessor(BaseMultiModalProcessor[PaliGemmaProcessingIn
     dummy_inputs=PaliGemmaDummyInputsBuilder,
 )
 class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
+    merge_by_field_config = True
+
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -327,9 +328,8 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
             return None

         if pixel_values is not None:
-            pixel_values = flatten_bn(pixel_values, concat=True)
-
             h = w = self.config.vision_config.image_size
+
             return PaliGemmaImagePixelInputs(
                 type="pixel_values",
                 data=pixel_values,
@@ -337,8 +337,6 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
             )

         if image_embeds is not None:
-            image_embeds = flatten_bn(image_embeds, concat=True)
-
             return PaliGemmaImageEmbeddingInputs(
                 type="image_embeds",
                 data=image_embeds,
diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index 6ca490f467634..cb521ebdf0afb 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -77,7 +77,7 @@ from vllm.multimodal.evs import (
 from vllm.multimodal.inputs import (
     MultiModalFeatureSpec,
     MultiModalFieldConfig,
-    MultiModalKwargs,
+    MultiModalKwargsItems,
 )
 from vllm.multimodal.parse import MultiModalDataItems
 from vllm.multimodal.processing import PromptReplacement, PromptUpdate
@@ -973,7 +973,7 @@ class Qwen2_5_VLMultiModalProcessor(Qwen2VLMultiModalProcessor):
         self,
         mm_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, Any],
-        out_mm_kwargs: MultiModalKwargs,
+        out_mm_kwargs: MultiModalKwargsItems,
     ) -> Sequence[PromptUpdate]:
         hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
         image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)
diff --git a/vllm/multimodal/cache.py b/vllm/multimodal/cache.py
index 97f6aa461b90c..67bdf5e1557f9 100644
--- a/vllm/multimodal/cache.py
+++ b/vllm/multimodal/cache.py
@@ -25,7 +25,6 @@ from .inputs import (
     MultiModalBatchedField,
     MultiModalFeatureSpec,
     MultiModalFieldElem,
-    MultiModalKwargs,
     MultiModalKwargsItem,
     MultiModalKwargsItems,
     NestedTensors,
@@ -90,7 +89,6 @@ MultiModalCacheValue: TypeAlias = (
     | MultiModalProcessorCacheItemMetadata
     | MultiModalKwargsItems
     | MultiModalKwargsItem
-    | MultiModalKwargs
     | Mapping[str, NestedTensors]
 )

@@ -108,12 +106,7 @@ class MultiModalCache:
         # These are not subclasses of dict
         if isinstance(
             leaf,
-            (
-                MultiModalKwargs,
-                MultiModalKwargsItems,
-                MultiModalKwargsItem,
-                MultiModalFieldElem,
-            ),
+            (MultiModalKwargsItems, MultiModalKwargsItem, MultiModalFieldElem),
         ):
             return cls.get_item_size(leaf.data)  # type: ignore

diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py
index 397684fa2f83c..32f15240cb7da 100644
--- a/vllm/multimodal/inputs.py
+++ b/vllm/multimodal/inputs.py
@@ -3,7 +3,7 @@

 from abc import ABC, abstractmethod
 from collections import UserDict, defaultdict
-from collections.abc import Mapping, Sequence
+from collections.abc import Mapping, Sequence, Set
 from dataclasses import dataclass
 from functools import partial
 from itertools import accumulate
@@ -201,8 +201,10 @@ Uses a list instead of a tensor if the dimensions of each element do not match.


 def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:
-    """Equality check between
-    [`NestedTensors`][vllm.multimodal.inputs.NestedTensors] objects."""
+    """
+    Equality check between
+    [`NestedTensors`][vllm.multimodal.inputs.NestedTensors] objects.
+    """
     if isinstance(a, torch.Tensor):
         return isinstance(b, torch.Tensor) and torch.equal(a, b)
     elif isinstance(b, torch.Tensor):
@@ -224,10 +226,24 @@ def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:
 BatchedTensorInputs: TypeAlias = dict[str, NestedTensors]
 """
 A dictionary containing nested tensors which have been batched via
-[`MultiModalKwargs.batch`][vllm.multimodal.inputs.MultiModalKwargs.batch].
+[`MultiModalKwargsItems.get_data`][vllm.multimodal.inputs.MultiModalKwargsItems.get_data].
 """


+def batched_tensors_equal(a: BatchedTensorInputs, b: BatchedTensorInputs) -> bool:
+    """
+    Equality check between
+    [`BatchedTensorInputs`][vllm.multimodal.inputs.BatchedTensorInputs] objects.
+    """
+    for k in a:
+        if k not in b:
+            return False
+        if not nested_tensors_equal(a[k], b[k]):
+            return False
+
+    return True
+
+
 @dataclass
 class MultiModalFeatureSpec:
     """
@@ -823,7 +839,14 @@ class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):

         return self  # type: ignore[return-value]

-    def get_data(self, *, pin_memory: bool = False) -> "MultiModalKwargs":
+    def get_data(
+        self,
+        *,
+        device: torch.types.Device = None,
+        pin_memory: bool = False,
+        cpu_fields: Set[str] = frozenset(),
+    ) -> BatchedTensorInputs:
+        """Construct a dictionary of keyword arguments to pass to the model."""
         elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
         for modality, items in self.items():
             for i, item in enumerate(items):
@@ -835,12 +858,23 @@ class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):
                 for key, elem in item.items():
                     elems_by_key[key].append(elem)

-        return MultiModalKwargs(
-            {
-                key: elems[0].field.reduce_data(elems, pin_memory=pin_memory)
-                for key, elems in elems_by_key.items()
-            }
-        )
+        data = {
+            key: elems[0].field.reduce_data(elems, pin_memory=pin_memory)
+            for key, elems in elems_by_key.items()
+        }
+
+        if device is not None:
+            for k in data.keys() - cpu_fields:
+                data[k] = json_map_leaves(
+                    (
+                        lambda x: x.to(device=device, non_blocking=True)
+                        if isinstance(x, torch.Tensor)
+                        else x
+                    ),
+                    data[k],
+                )
+
+        return data


 MultiModalKwargsOptionalItems: TypeAlias = (
@@ -849,6 +883,7 @@ MultiModalKwargsOptionalItems: TypeAlias = (
 )


+@deprecated("`MultiModalKwargs` is deprecated and will be removed in v0.13.")
 class MultiModalKwargs(UserDict[str, NestedTensors]):
     """
     A dictionary that represents the keyword arguments to
@@ -882,91 +917,6 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
     ):
         return MultiModalKwargsItems.from_seq(items).get_data(pin_memory=pin_memory)

-    @staticmethod
-    def _try_stack(
-        nested_tensors: NestedTensors, pin_memory: bool = False
-    ) -> NestedTensors:
-        """
-        Stack the inner dimensions that have the same shape in
-        a nested list of tensors.
-
-        Thus, a dimension represented by a list means that the inner
-        dimensions are different for each element along that dimension.
-        """
-        if isinstance(nested_tensors, torch.Tensor):
-            return nested_tensors
-
-        # TODO: Remove these once all models have been migrated
-        if isinstance(nested_tensors, np.ndarray):
-            return torch.from_numpy(nested_tensors)
-        if isinstance(nested_tensors, (int, float)):
-            return torch.tensor(nested_tensors)
-
-        stacked = [MultiModalKwargs._try_stack(t, pin_memory) for t in nested_tensors]
-        if not is_list_of(stacked, torch.Tensor, check="all"):
-            # Only tensors (not lists) can be stacked.
-            return stacked
-
-        tensors_ = cast(list[torch.Tensor], stacked)
-        if len(tensors_) == 1:
-            # An optimization when `tensors_` contains only one tensor:
-            # - produce exactly same result as `torch.stack(tensors_)`
-            # - will achieve zero-copy if the tensor is contiguous
-            return tensors_[0].unsqueeze(0).contiguous()
-
-        if any(t.shape != tensors_[0].shape for t in tensors_):
-            # The tensors have incompatible shapes and can't be stacked.
-            return tensors_
-
-        outputs = torch.empty(
-            len(tensors_),
-            *tensors_[0].shape,
-            dtype=tensors_[0].dtype,
-            device=tensors_[0].device,
-            pin_memory=pin_memory,
-        )
-        return torch.stack(tensors_, out=outputs)
-
-    @staticmethod
-    def batch(
-        inputs_list: list["MultiModalKwargs"], pin_memory: bool = False
-    ) -> BatchedTensorInputs:
-        """
-        Batch multiple inputs together into a dictionary.
-
-        The resulting dictionary has the same keys as the inputs.
-        If the corresponding value from each input is a tensor and they all
-        share the same shape, the output value is a single batched tensor;
-        otherwise, the output value is a list containing the original value
-        from each input.
-        """
-        if len(inputs_list) == 0:
-            return {}
-
-        # We need to consider the case where each item in the batch
-        # contains different modalities (i.e. different keys).
-        item_lists = defaultdict[str, list[NestedTensors]](list)
-
-        for inputs in inputs_list:
-            for k, v in inputs.items():
-                item_lists[k].append(v)
-
-        return {
-            k: MultiModalKwargs._try_stack(item_list, pin_memory)
-            for k, item_list in item_lists.items()
-        }
-
-    @staticmethod
-    def as_kwargs(
-        batched_inputs: BatchedTensorInputs,
-        *,
-        device: torch.types.Device,
-    ) -> BatchedTensorInputs:
-        return json_map_leaves(
-            lambda x: x.to(device=device, non_blocking=True),
-            batched_inputs,
-        )
-
     def __getitem__(self, key: str):
         if key not in self:
             raise KeyError(
diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py
index 1840220854858..f8e8847e8e609 100644
--- a/vllm/multimodal/utils.py
+++ b/vllm/multimodal/utils.py
@@ -19,7 +19,6 @@ from PIL import Image, UnidentifiedImageError
 import vllm.envs as envs
 from vllm.connections import HTTPConnection, global_http_connection
 from vllm.logger import init_logger
-from vllm.utils.jsontree import json_map_leaves
 from vllm.utils.registry import ExtensionManager

 from .audio import AudioEmbeddingMediaIO, AudioMediaIO
@@ -427,59 +426,25 @@ def group_mm_kwargs_by_modality(
     Yields:
         A tuple `(modality, num_items, grouped_kwargs)`.
     """
-    if merge_by_field_config is None:
-        raise RuntimeError(
-            "`group_mm_kwargs_by_modality` now requires "
-            "`merge_by_field_config` arg, please update your model runner "
-            "according to https://github.com/vllm-project/vllm/pull/25676."
-        )
-    if merge_by_field_config is False:
+    # TODO: After v0.13, remove merge_by_field_config attribute from model impls
+    if merge_by_field_config is not None:
         logger.warning_once(
-            "The legacy code for batching multi-modal kwargs is deprecated and "
-            "will be removed in v0.12. Please update your model with "
-            "`merge_by_field_config=True` to use the new code defined by "
-            "`MultiModalFieldConfig`. You can refer to "
-            "https://github.com/vllm-project/vllm/issues/26149 "
-            "for some examples on how to do this."
+            "The `merge_by_field_config` argument of `group_mm_kwargs_by_modality` "
+            "is deprecated and will be removed in v0.13."
         )

-    from vllm.multimodal.inputs import MultiModalKwargs, MultiModalKwargsItems
+    from vllm.multimodal.inputs import MultiModalKwargsItems

     for modality, items in groupby(mm_kwargs, key=lambda item: item.modality):
         items_lst = list(items)

+        mm_kwargs_items = MultiModalKwargsItems.from_seq(items_lst)
+        mm_kwargs_data = mm_kwargs_items.get_data(
+            device=device,
+            pin_memory=pin_memory,
+            cpu_fields=multimodal_cpu_fields,
+        )
-        if merge_by_field_config:
-            mm_kwargs_group: BatchedTensorInputs = dict(
-                MultiModalKwargsItems.from_seq(items_lst).get_data(
-                    pin_memory=pin_memory
-                )
-            )
-
-            if device is not None:
-                mm_kwargs_group = {
-                    k: json_map_leaves(
-                        lambda x: x.to(device=device, non_blocking=True)
-                        if isinstance(x, torch.Tensor)
-                        else x,
-                        v,
-                    )
-                    if k not in multimodal_cpu_fields
-                    else v
-                    for k, v in mm_kwargs_group.items()
-                }
-        else:
-            mm_kwargs_group = MultiModalKwargs.as_kwargs(
-                MultiModalKwargs.batch(
-                    [
-                        MultiModalKwargsItems.from_seq([item]).get_data()
-                        for item in items_lst
-                    ],
-                    pin_memory=pin_memory,
-                ),
-                device=device,
-            )
-
-        yield modality, len(items_lst), mm_kwargs_group
+        yield modality, len(items_lst), mm_kwargs_data


 def fetch_audio(
diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py
index 0a6806390451d..14ae487f3eb13 100644
--- a/vllm/v1/serial_utils.py
+++ b/vllm/v1/serial_utils.py
@@ -27,7 +27,6 @@ from vllm.multimodal.inputs import (
     MultiModalFieldConfig,
     MultiModalFieldElem,
     MultiModalFlatField,
-    MultiModalKwargs,
     MultiModalKwargsItem,
     MultiModalKwargsItems,
     MultiModalSharedField,
@@ -176,9 +175,6 @@ class MsgpackEncoder:
         if isinstance(obj, MultiModalKwargsItems):
             return self._encode_mm_items(obj)

-        if isinstance(obj, MultiModalKwargs):
-            return self._encode_mm_kwargs(obj)
-
         if isinstance(obj, UtilityResult):
             result = obj.result
             if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
@@ -259,11 +255,6 @@ class MsgpackEncoder:
             "field": self._encode_mm_field(elem.field),
         }

-    def _encode_mm_kwargs(self, kw: MultiModalKwargs) -> dict[str, Any]:
-        return {
-            modality: self._encode_nested_tensors(data) for modality, data in kw.items()
-        }
-
     def _encode_nested_tensors(self, nt: NestedTensors) -> Any:
         if isinstance(nt, torch.Tensor):
             return self._encode_tensor(nt)
@@ -325,8 +316,6 @@ class MsgpackDecoder:
             return self._decode_mm_item(obj)
         if issubclass(t, MultiModalKwargsItems):
             return self._decode_mm_items(obj)
-        if issubclass(t, MultiModalKwargs):
-            return self._decode_mm_kwargs(obj)
         if t is UtilityResult:
             return self._decode_utility_result(obj)
         return obj
@@ -414,14 +403,6 @@ class MsgpackDecoder:
             obj["field"] = factory_meth(None, *field_args).field
         return MultiModalFieldElem(**obj)

-    def _decode_mm_kwargs(self, obj: dict[str, Any]) -> MultiModalKwargs:
-        return MultiModalKwargs(
-            {
-                modality: self._decode_nested_tensors(data)
-                for modality, data in obj.items()
-            }
-        )
-
     def _decode_nested_tensors(self, obj: Any) -> NestedTensors:
         if isinstance(obj, (int, float)):
             # Although it violates NestedTensors type, MultiModalKwargs
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index a7eb9cdae8b10..58043a42da94e 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -2106,7 +2106,6 @@ class GPUModelRunner(
             mm_kwargs,
             device=self.device,
             pin_memory=self.pin_memory,
-            merge_by_field_config=model.merge_by_field_config,
             multimodal_cpu_fields=model.multimodal_cpu_fields,
         ):
             curr_group_outputs: list[torch.Tensor] = []
@@ -2133,7 +2132,6 @@ class GPUModelRunner(
                     [video_mm_kwargs_item],
                     device=self.device,
                     pin_memory=self.pin_memory,
-                    merge_by_field_config=model.merge_by_field_config,
                     multimodal_cpu_fields=model.multimodal_cpu_fields,
                 )
             )
@@ -3849,7 +3847,6 @@ class GPUModelRunner(
                 dummy_mm_items,
                 device=self.device,
                 pin_memory=self.pin_memory,
-                merge_by_field_config=model.merge_by_field_config,
                 multimodal_cpu_fields=model.multimodal_cpu_fields,
             )
         )
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index f3dd9aa96d2ae..292f12969aaed 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -969,7 +969,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             mm_kwargs,
             device=self.device,
             pin_memory=self.pin_memory,
-            merge_by_field_config=model.merge_by_field_config,
             multimodal_cpu_fields=model.multimodal_cpu_fields,
         ):
             # Run the encoder.
@@ -2058,7 +2057,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 dummy_mm_items,
                 device=self.device,
                 pin_memory=self.pin_memory,
-                merge_by_field_config=model.merge_by_field_config,
                 multimodal_cpu_fields=model.multimodal_cpu_fields,
             )
         )
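For reviewers, a minimal usage sketch of the `batched_tensors_equal` helper added in `vllm/multimodal/inputs.py` above (assuming a build that includes this change; the tensor shapes and key names below are made up for illustration). Since `BatchedTensorInputs` is a plain `dict[str, NestedTensors]`, comparing two of them with `==` can raise on multi-element tensors, which is presumably why the tests above switch to this helper.

```python
import torch

from vllm.multimodal.inputs import batched_tensors_equal

# BatchedTensorInputs is just dict[str, NestedTensors]; these values are illustrative.
a = {
    "pixel_values": torch.zeros(2, 3, 32, 32),
    "image_grid_thw": torch.tensor([[1, 2, 2]]),
}
b = {
    "pixel_values": torch.zeros(2, 3, 32, 32),
    "image_grid_thw": torch.tensor([[1, 2, 2]]),
}

# Tensors are compared key by key via torch.equal (nested_tensors_equal).
assert batched_tensors_equal(a, b)
assert not batched_tensors_equal(a, {"pixel_values": torch.ones(2, 3, 32, 32)})
```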