diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py
index e0b91e6dd7ee..47c74aff1e75 100644
--- a/tests/v1/core/test_kv_cache_utils.py
+++ b/tests/v1/core/test_kv_cache_utils.py
@@ -7,9 +7,7 @@ import pytest
 import torch
 
 from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
-from vllm.multimodal.inputs import (MultiModalBatchedField,
-                                    MultiModalFieldElem, MultiModalKwargsItem,
-                                    PlaceholderRange)
+from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
 from vllm.sampling_params import SamplingParams
 from vllm.utils import GiB_bytes, sha256, sha256_cbor_64bit
 from vllm.v1.core.kv_cache_manager import KVCacheManager
@@ -42,13 +40,7 @@ def make_request(
     if mm_positions is None:
         mm_kwargs = None
     else:
-        mm_elem = MultiModalFieldElem(
-            modality="dummy_m",
-            key="dummy_k",
-            data=None,
-            field=MultiModalBatchedField(),
-        )
-        mm_item = MultiModalKwargsItem.from_elems([mm_elem])
+        mm_item = MultiModalKwargsItem.dummy("dummy_m")
         mm_kwargs = [mm_item] * len(mm_positions)
 
     return Request(request_id=request_id,
diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py
index 28cfca6767b1..89824768ed90 100644
--- a/tests/v1/core/test_prefix_caching.py
+++ b/tests/v1/core/test_prefix_caching.py
@@ -9,9 +9,7 @@ import pytest
 import torch
 
 from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved
-from vllm.multimodal.inputs import (MultiModalBatchedField,
-                                    MultiModalFieldElem, MultiModalKwargsItem,
-                                    PlaceholderRange)
+from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
 from vllm.sampling_params import SamplingParams
 from vllm.utils import sha256, sha256_cbor_64bit
 from vllm.v1.core.block_pool import BlockPool
@@ -37,13 +35,7 @@ def make_request(
     if mm_positions is None:
         mm_kwargs = None
     else:
-        mm_elem = MultiModalFieldElem(
-            modality="dummy_m",
-            key="dummy_k",
-            data=None,
-            field=MultiModalBatchedField(),
-        )
-        mm_item = MultiModalKwargsItem.from_elems([mm_elem])
+        mm_item = MultiModalKwargsItem.dummy("dummy_m")
         mm_kwargs = [mm_item] * len(mm_positions)
 
     return Request(request_id=request_id,
diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py
index ac70c90d92ad..23762a0fb622 100644
--- a/tests/v1/core/test_scheduler.py
+++ b/tests/v1/core/test_scheduler.py
@@ -8,9 +8,7 @@ import torch
 
 from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
                          SchedulerConfig, SpeculativeConfig, VllmConfig)
-from vllm.multimodal.inputs import (MultiModalBatchedField,
-                                    MultiModalFieldElem, MultiModalKwargsItem,
-                                    PlaceholderRange)
+from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
 from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
 from vllm.v1.core.sched.scheduler import Scheduler
@@ -1328,13 +1326,7 @@ def create_requests_with_priority(
     for i in range(num_requests):
         if mm_positions is not None:
             mm_position = mm_positions[i]
-            mm_elem = MultiModalFieldElem(
-                modality="dummy_m",
-                key="dummy_k",
-                data=None,
-                field=MultiModalBatchedField(),
-            )
-            mm_item = MultiModalKwargsItem.from_elems([mm_elem])
+            mm_item = MultiModalKwargsItem.dummy("dummy_m")
             mm_kwargs = [mm_item] * len(mm_position)
         else:
             mm_position = None
diff --git a/tests/v1/core/utils.py b/tests/v1/core/utils.py
index 52093d3d381a..849c3f59ae52 100644
--- a/tests/v1/core/utils.py
+++ b/tests/v1/core/utils.py
@@ -6,9 +6,7 @@ import torch
 
 from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
                          SchedulerConfig, SpeculativeConfig, VllmConfig)
-from vllm.multimodal.inputs import (MultiModalBatchedField,
-                                    MultiModalFieldElem, MultiModalKwargsItem,
-                                    PlaceholderRange)
+from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
 from vllm.sampling_params import SamplingParams
 from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
                                          init_none_hash)
@@ -143,13 +141,7 @@ def create_requests(
     for i in range(num_requests):
         if mm_positions is not None:
             mm_position = mm_positions[i]
-            mm_elem = MultiModalFieldElem(
-                modality="dummy_m",
-                key="dummy_k",
-                data=None,
-                field=MultiModalBatchedField(),
-            )
-            mm_item = MultiModalKwargsItem.from_elems([mm_elem])
+            mm_item = MultiModalKwargsItem.dummy("dummy_m")
             mm_kwargs = [mm_item] * len(mm_position)
             mm_hashes = ["hash"] * len(mm_position)
         else:
diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py
index 0bbac45c121b..a33ce146995d 100644
--- a/vllm/multimodal/inputs.py
+++ b/vllm/multimodal/inputs.py
@@ -4,7 +4,7 @@
 from abc import ABC, abstractmethod
 from collections import UserDict, defaultdict
 from collections.abc import Mapping, Sequence
-from dataclasses import dataclass, replace
+from dataclasses import dataclass
 from functools import partial
 from itertools import accumulate
 from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, TypeVar,
@@ -218,7 +218,7 @@ class MultiModalFieldElem:
     i.e. the name of the keyword argument to be passed to the model.
     """
 
-    data: Optional[NestedTensors]
+    data: NestedTensors
     """
     The tensor data of this field in
    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs],
@@ -315,13 +315,8 @@ class BaseMultiModalField(ABC):
         if len(set(field_types)) > 1:
             raise ValueError(f"Cannot merge different {field_types=}")
 
-        validated_data = list[NestedTensors]()
-        for i, elem in enumerate(elems):
-            assert elem.data is not None, (
-                f"Cannot merge with empty `elems[{i}]`")
-            validated_data.append(elem.data)
-
-        return self._reduce_data(validated_data, pin_memory=pin_memory)
+        batch = [elem.data for elem in elems]
+        return self._reduce_data(batch, pin_memory=pin_memory)
 
 
 @dataclass(frozen=True)
@@ -643,6 +638,17 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
     [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].
""" + @staticmethod + def dummy(modality: str): + """Convenience class for testing.""" + mm_elem = MultiModalFieldElem( + modality=modality, + key="dummy", + data=torch.empty(1), + field=MultiModalSharedField(1), + ) + return MultiModalKwargsItem.from_elems([mm_elem]) + @staticmethod def from_elems(elems: Sequence[MultiModalFieldElem]): return MultiModalKwargsItem({elem.key: elem for elem in elems}) @@ -654,46 +660,12 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]): assert len(modalities) == 1, f"Found different modalities={modalities}" self._modality = next(iter(modalities)) - self._is_empty = any(elem.data is None for elem in self.values()) - @property def modality(self) -> str: return self._modality - @property - def is_empty(self) -> bool: - return self._is_empty - - def get_data(self) -> Optional[Mapping[str, NestedTensors]]: - if self._is_empty: - return None - - out_data = dict[str, NestedTensors]() - for key, elem in self.items(): - assert elem.data is not None, ( - f"Cannot get data of empty `elem[{key!r}]`") - out_data[key] = elem.data - - return out_data - - def require_data(self) -> Mapping[str, NestedTensors]: - if (data := self.get_data()) is None: - raise RuntimeError("Cannot get data of empty item") - - return data - - # These methods create a new item to avoid mutating cached items in place - def with_data(self, data: Mapping[str, NestedTensors]): - return MultiModalKwargsItem({ - key: replace(elem, data=data[key]) - for key, elem in self.items() - }) - - def without_data(self): - return MultiModalKwargsItem({ - key: replace(elem, data=None) - for key, elem in self.items() - }) + def get_data(self) -> Mapping[str, NestedTensors]: + return {key: elem.data for key, elem in self.items()} # NOTE: UserDict is for V0 compatibility. 
diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py
index b29394f3e676..f7ec982db41b 100644
--- a/vllm/v1/engine/__init__.py
+++ b/vllm/v1/engine/__init__.py
@@ -3,6 +3,7 @@
 
 import enum
 import time
+from collections.abc import Sequence
 from typing import Any, Optional, Union
 
 import msgspec
@@ -47,7 +48,7 @@ class EngineCoreRequest(
 
     request_id: str
     prompt_token_ids: list[int]
-    mm_kwargs: Optional[list[MultiModalKwargsItem]]
+    mm_kwargs: Optional[Sequence[Optional[MultiModalKwargsItem]]]
     mm_hashes: Optional[list[str]]
     mm_placeholders: Optional[list[PlaceholderRange]]
     sampling_params: Optional[SamplingParams]
diff --git a/vllm/v1/engine/mm_input_cache.py b/vllm/v1/engine/mm_input_cache.py
index 1fed74330f0e..aa7dc62fd4ac 100644
--- a/vllm/v1/engine/mm_input_cache.py
+++ b/vllm/v1/engine/mm_input_cache.py
@@ -1,11 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from collections.abc import Mapping
-from typing import TYPE_CHECKING
+from collections.abc import Sequence
+from typing import TYPE_CHECKING, Optional
 
 from vllm.multimodal import MultiModalRegistry
 from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata
-from vllm.multimodal.inputs import MultiModalKwargsItem, NestedTensors
+from vllm.multimodal.inputs import MultiModalKwargsItem
+from vllm.utils import is_list_of
 
 if TYPE_CHECKING:
     from vllm.config import ModelConfig
@@ -58,21 +59,21 @@ class MultiModalInputCacheClient:
 
     def get_and_update(
         self,
-        mm_kwargs: list[MultiModalKwargsItem],
+        mm_kwargs: Sequence[MultiModalKwargsItem],
         mm_hashes: list[str],
-    ) -> list[MultiModalKwargsItem]:
+    ) -> list[Optional[MultiModalKwargsItem]]:
         if not self.enabled:
-            return mm_kwargs
+            return list(mm_kwargs)
 
         assert len(mm_kwargs) == len(mm_hashes)
 
-        out_mm_items = list[MultiModalKwargsItem]()
+        out_mm_items = list[Optional[MultiModalKwargsItem]]()
         for mm_item, mm_hash in zip(mm_kwargs, mm_hashes):
             if self.mm_cache.get(mm_hash) is not None:
-                out_mm_items.append(mm_item.without_data())
+                out_mm_items.append(None)
             else:
                 self.mm_cache[mm_hash] = \
-                    MultiModalCacheItemMetadata.wraps(mm_item.require_data())
+                    MultiModalCacheItemMetadata.wraps(mm_item)
                 out_mm_items.append(mm_item)
 
         return out_mm_items
@@ -91,25 +92,27 @@ class MultiModalInputCacheServer:
         self.enabled = mm_registry.enable_mm_input_cache(model_config)
         self.mm_cache = MultiModalCache.get_lru_cache(
             model_config.get_mm_input_cache_gb(),
-            Mapping[str, NestedTensors],
+            MultiModalKwargsItem,
         )
 
     def get_and_update(
         self,
-        mm_kwargs: list[MultiModalKwargsItem],
+        mm_kwargs: Sequence[Optional[MultiModalKwargsItem]],
         mm_hashes: list[str],
     ) -> list[MultiModalKwargsItem]:
         if not self.enabled:
-            return mm_kwargs
+            mm_kwargs_lst = list(mm_kwargs)
+            assert is_list_of(mm_kwargs_lst, MultiModalKwargsItem)
+            return mm_kwargs_lst
 
         assert len(mm_kwargs) == len(mm_hashes)
 
         out_mm_items = list[MultiModalKwargsItem]()
         for mm_item, mm_hash in zip(mm_kwargs, mm_hashes):
-            if (mm_data := mm_item.get_data()) is None:
-                out_mm_items.append(mm_item.with_data(self.mm_cache[mm_hash]))
+            if mm_item is None:
+                out_mm_items.append(self.mm_cache[mm_hash])
             else:
-                self.mm_cache[mm_hash] = mm_data
+                self.mm_cache[mm_hash] = mm_item
                 out_mm_items.append(mm_item)
 
         return out_mm_items
diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py
index 376c76a7e728..c6a23cdbf65a 100644
--- a/vllm/v1/engine/processor.py
+++ b/vllm/v1/engine/processor.py
@@ -17,6 +17,7 @@ from vllm.multimodal.utils import argsort_mm_positions
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer_group import TokenizerGroup
+from vllm.utils import is_list_of
 from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.mm_input_cache import MultiModalInputCacheClient
 from vllm.v1.structured_output.backend_guidance import (
@@ -295,7 +296,7 @@ class Processor:
             pooling_params = params.clone()
 
         # Multimodal related.
-        sorted_mm_inputs: Optional[list[MultiModalKwargsItem]] = None
+        sorted_mm_inputs: Optional[list[Optional[MultiModalKwargsItem]]] = None
         sorted_mm_positions: Optional[list[PlaceholderRange]] = None
         sorted_mm_hashes: Optional[list[str]] = None
         if decoder_inputs["type"] == "multimodal":
@@ -308,7 +309,7 @@ class Processor:
             # in the input sequence.
             sorted_mm_idxs = argsort_mm_positions(decoder_mm_positions)
 
-            sorted_mm_inputs = [
+            orig_sorted_mm_inputs = [
                 decoder_mm_inputs.get_item(modality, idx)
                 for modality, idx in sorted_mm_idxs
             ]
@@ -323,9 +324,12 @@ class Processor:
 
             if sorted_mm_hashes is not None:
                 sorted_mm_inputs = self.mm_input_cache_client.get_and_update(
-                    sorted_mm_inputs,
+                    orig_sorted_mm_inputs,
                     sorted_mm_hashes,
                 )
+            else:
+                assert is_list_of(orig_sorted_mm_inputs, MultiModalKwargsItem)
+                sorted_mm_inputs = orig_sorted_mm_inputs
 
         return decoder_inputs.get("prompt"), EngineCoreRequest(
             request_id=request_id,
diff --git a/vllm/v1/request.py b/vllm/v1/request.py
index 562925bde669..8b703b6191fe 100644
--- a/vllm/v1/request.py
+++ b/vllm/v1/request.py
@@ -125,14 +125,17 @@ class Request:
         block_hasher: Optional[Callable[["Request"], list["BlockHash"]]]
     ) -> "Request":
         if request.mm_kwargs is not None:
-            assert is_list_of(request.mm_kwargs, MultiModalKwargsItem), (
+            mm_kwargs_lst = list(request.mm_kwargs)
+            assert is_list_of(mm_kwargs_lst, MultiModalKwargsItem), (
                 "mm_kwargs was not updated in EngineCore.add_request")
+        else:
+            mm_kwargs_lst = None
 
         return cls(
             request_id=request.request_id,
             client_index=request.client_index,
             prompt_token_ids=request.prompt_token_ids,
-            multi_modal_kwargs=request.mm_kwargs,
+            multi_modal_kwargs=mm_kwargs_lst,
             multi_modal_hashes=request.mm_hashes,
             multi_modal_placeholders=request.mm_placeholders,
             sampling_params=request.sampling_params,
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 4c919b392fbd..5ee44a82574c 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -500,8 +500,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             second_per_grid_ts = []
             audio_feature_lengths = []
             use_audio_in_video = False
-            for item in self.requests[req_id].mm_kwargs:
-                mm_input = item.require_data()
+            for mm_item in self.requests[req_id].mm_kwargs:
+                mm_input = mm_item.get_data()
                 if mm_input.get("image_grid_thw") is not None:
                     image_grid_thw.append(
                         mm_input["image_grid_thw"].tolist())
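
Illustrative usage (not part of the patch): after this change, MultiModalInputCacheClient forwards None in place of items whose hash it has already sent, and MultiModalInputCacheServer fills the cached item back in. A rough sketch of the round trip, assuming caching is enabled on both sides and that `client`, `server`, `item`, and "h0" are hypothetical stand-ins for constructed cache objects, a real kwargs item, and its hash:

    # First time the hash is seen: the full item is forwarded and cached on both ends.
    sent = client.get_and_update([item], ["h0"])      # -> [item]
    received = server.get_and_update(sent, ["h0"])    # -> [item]

    # Repeated hash: the client sends a None placeholder,
    # and the server restores the item from its own cache.
    sent = client.get_and_update([item], ["h0"])      # -> [None]
    received = server.get_and_update(sent, ["h0"])    # -> [item], recovered from the server cache
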