[Refactor] Allow optional MultiModalKwargsItem in IPC (#23022)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2025-08-16 19:30:49 +08:00 committed by GitHub
parent de9cb61763
commit 4dff91c93d
10 changed files with 59 additions and 108 deletions

View File

@@ -7,9 +7,7 @@ import pytest
 import torch
 
 from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
-from vllm.multimodal.inputs import (MultiModalBatchedField,
-                                    MultiModalFieldElem, MultiModalKwargsItem,
-                                    PlaceholderRange)
+from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
 from vllm.sampling_params import SamplingParams
 from vllm.utils import GiB_bytes, sha256, sha256_cbor_64bit
 from vllm.v1.core.kv_cache_manager import KVCacheManager
@@ -42,13 +40,7 @@ def make_request(
     if mm_positions is None:
         mm_kwargs = None
     else:
-        mm_elem = MultiModalFieldElem(
-            modality="dummy_m",
-            key="dummy_k",
-            data=None,
-            field=MultiModalBatchedField(),
-        )
-        mm_item = MultiModalKwargsItem.from_elems([mm_elem])
+        mm_item = MultiModalKwargsItem.dummy("dummy_m")
         mm_kwargs = [mm_item] * len(mm_positions)
 
     return Request(request_id=request_id,

View File

@@ -9,9 +9,7 @@ import pytest
 import torch
 
 from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved
-from vllm.multimodal.inputs import (MultiModalBatchedField,
-                                    MultiModalFieldElem, MultiModalKwargsItem,
-                                    PlaceholderRange)
+from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
 from vllm.sampling_params import SamplingParams
 from vllm.utils import sha256, sha256_cbor_64bit
 from vllm.v1.core.block_pool import BlockPool
@@ -37,13 +35,7 @@ def make_request(
     if mm_positions is None:
         mm_kwargs = None
     else:
-        mm_elem = MultiModalFieldElem(
-            modality="dummy_m",
-            key="dummy_k",
-            data=None,
-            field=MultiModalBatchedField(),
-        )
-        mm_item = MultiModalKwargsItem.from_elems([mm_elem])
+        mm_item = MultiModalKwargsItem.dummy("dummy_m")
         mm_kwargs = [mm_item] * len(mm_positions)
 
     return Request(request_id=request_id,

View File

@@ -8,9 +8,7 @@ import torch
 
 from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
                          SchedulerConfig, SpeculativeConfig, VllmConfig)
-from vllm.multimodal.inputs import (MultiModalBatchedField,
-                                    MultiModalFieldElem, MultiModalKwargsItem,
-                                    PlaceholderRange)
+from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
 from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
 from vllm.v1.core.sched.scheduler import Scheduler
@@ -1328,13 +1326,7 @@ def create_requests_with_priority(
     for i in range(num_requests):
         if mm_positions is not None:
             mm_position = mm_positions[i]
-            mm_elem = MultiModalFieldElem(
-                modality="dummy_m",
-                key="dummy_k",
-                data=None,
-                field=MultiModalBatchedField(),
-            )
-            mm_item = MultiModalKwargsItem.from_elems([mm_elem])
+            mm_item = MultiModalKwargsItem.dummy("dummy_m")
             mm_kwargs = [mm_item] * len(mm_position)
         else:
             mm_position = None

View File

@@ -6,9 +6,7 @@ import torch
 
 from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
                          SchedulerConfig, SpeculativeConfig, VllmConfig)
-from vllm.multimodal.inputs import (MultiModalBatchedField,
-                                    MultiModalFieldElem, MultiModalKwargsItem,
-                                    PlaceholderRange)
+from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
 from vllm.sampling_params import SamplingParams
 from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
                                          init_none_hash)
@@ -143,13 +141,7 @@ def create_requests(
     for i in range(num_requests):
         if mm_positions is not None:
             mm_position = mm_positions[i]
-            mm_elem = MultiModalFieldElem(
-                modality="dummy_m",
-                key="dummy_k",
-                data=None,
-                field=MultiModalBatchedField(),
-            )
-            mm_item = MultiModalKwargsItem.from_elems([mm_elem])
+            mm_item = MultiModalKwargsItem.dummy("dummy_m")
             mm_kwargs = [mm_item] * len(mm_position)
             mm_hashes = ["hash"] * len(mm_position)
         else:

View File

@@ -4,7 +4,7 @@
 from abc import ABC, abstractmethod
 from collections import UserDict, defaultdict
 from collections.abc import Mapping, Sequence
-from dataclasses import dataclass, replace
+from dataclasses import dataclass
 from functools import partial
 from itertools import accumulate
 from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, TypeVar,
@@ -218,7 +218,7 @@ class MultiModalFieldElem:
     i.e. the name of the keyword argument to be passed to the model.
     """
 
-    data: Optional[NestedTensors]
+    data: NestedTensors
     """
     The tensor data of this field in
    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs],
@@ -315,13 +315,8 @@ class BaseMultiModalField(ABC):
         if len(set(field_types)) > 1:
             raise ValueError(f"Cannot merge different {field_types=}")
 
-        validated_data = list[NestedTensors]()
-        for i, elem in enumerate(elems):
-            assert elem.data is not None, (
-                f"Cannot merge with empty `elems[{i}]`")
-            validated_data.append(elem.data)
-
-        return self._reduce_data(validated_data, pin_memory=pin_memory)
+        batch = [elem.data for elem in elems]
+        return self._reduce_data(batch, pin_memory=pin_memory)
 
 
 @dataclass(frozen=True)
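
With the `Optional` removed from `data`, every element is guaranteed to carry a tensor, so the merge collapses to a plain gather. A minimal sketch of the effect, assuming `MultiModalBatchedField` stacks same-shaped tensors (its reduction is not shown in this hunk):

import torch

from vllm.multimodal.inputs import MultiModalBatchedField, MultiModalFieldElem

# Two elements of the same field type; `data` is now always a tensor.
elems = [
    MultiModalFieldElem(modality="image", key="pixel_values",
                        data=torch.zeros(3, 224, 224),
                        field=MultiModalBatchedField())
    for _ in range(2)
]

batch = [elem.data for elem in elems]  # the gather, with no None checks
stacked = torch.stack(batch)           # what the batched reduction produces
assert stacked.shape == (2, 3, 224, 224)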
@@ -643,6 +638,17 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
    [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].
     """
 
+    @staticmethod
+    def dummy(modality: str):
+        """Convenience class for testing."""
+        mm_elem = MultiModalFieldElem(
+            modality=modality,
+            key="dummy",
+            data=torch.empty(1),
+            field=MultiModalSharedField(1),
+        )
+        return MultiModalKwargsItem.from_elems([mm_elem])
+
     @staticmethod
     def from_elems(elems: Sequence[MultiModalFieldElem]):
         return MultiModalKwargsItem({elem.key: elem for elem in elems})
@@ -654,46 +660,12 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
         assert len(modalities) == 1, f"Found different modalities={modalities}"
         self._modality = next(iter(modalities))
-        self._is_empty = any(elem.data is None for elem in self.values())
 
     @property
     def modality(self) -> str:
         return self._modality
 
-    @property
-    def is_empty(self) -> bool:
-        return self._is_empty
-
-    def get_data(self) -> Optional[Mapping[str, NestedTensors]]:
-        if self._is_empty:
-            return None
-
-        out_data = dict[str, NestedTensors]()
-        for key, elem in self.items():
-            assert elem.data is not None, (
-                f"Cannot get data of empty `elem[{key!r}]`")
-            out_data[key] = elem.data
-
-        return out_data
-
-    def require_data(self) -> Mapping[str, NestedTensors]:
-        if (data := self.get_data()) is None:
-            raise RuntimeError("Cannot get data of empty item")
-
-        return data
-
-    # These methods create a new item to avoid mutating cached items in place
-    def with_data(self, data: Mapping[str, NestedTensors]):
-        return MultiModalKwargsItem({
-            key: replace(elem, data=data[key])
-            for key, elem in self.items()
-        })
-
-    def without_data(self):
-        return MultiModalKwargsItem({
-            key: replace(elem, data=None)
-            for key, elem in self.items()
-        })
+    def get_data(self) -> Mapping[str, NestedTensors]:
+        return {key: elem.data for key, elem in self.items()}
 
 
 # NOTE: UserDict is for V0 compatibility.
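
For reference, a minimal sketch of how this commit's tests use the two methods defined above (the modality string is arbitrary):

from vllm.multimodal.inputs import MultiModalKwargsItem

# Build a placeholder item; `dummy` wraps a single 1-element shared tensor.
mm_item = MultiModalKwargsItem.dummy("image")
assert mm_item.modality == "image"

# `get_data` now always returns a mapping; the empty-item case is gone.
data = mm_item.get_data()
assert data["dummy"].shape == (1,)

# Fan the item out across the multimodal positions of a fake request.
mm_kwargs = [mm_item] * 3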

View File

@@ -3,6 +3,7 @@
 
 import enum
 import time
+from collections.abc import Sequence
 from typing import Any, Optional, Union
 
 import msgspec
@@ -47,7 +48,7 @@ class EngineCoreRequest(
     request_id: str
     prompt_token_ids: list[int]
-    mm_kwargs: Optional[list[MultiModalKwargsItem]]
+    mm_kwargs: Optional[Sequence[Optional[MultiModalKwargsItem]]]
     mm_hashes: Optional[list[str]]
     mm_placeholders: Optional[list[PlaceholderRange]]
     sampling_params: Optional[SamplingParams]
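
The widened type is what lets a request carry "holes": a `None` entry tells the engine core to recover that item from its cache by hash. An illustrative payload after the frontend's cache pass (hash value hypothetical):

# First request: the full item travels over IPC alongside its hash.
mm_kwargs = [mm_item]
mm_hashes = ["3c9d0f"]

# Repeat request with the same image: only the hash travels.
mm_kwargs = [None]   # None => engine core restores the item from its cache
mm_hashes = ["3c9d0f"]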

View File

@@ -1,11 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from collections.abc import Mapping
-from typing import TYPE_CHECKING
+from collections.abc import Sequence
+from typing import TYPE_CHECKING, Optional
 
 from vllm.multimodal import MultiModalRegistry
 from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata
-from vllm.multimodal.inputs import MultiModalKwargsItem, NestedTensors
+from vllm.multimodal.inputs import MultiModalKwargsItem
+from vllm.utils import is_list_of
 
 if TYPE_CHECKING:
     from vllm.config import ModelConfig
@@ -58,21 +59,21 @@ class MultiModalInputCacheClient:
 
     def get_and_update(
         self,
-        mm_kwargs: list[MultiModalKwargsItem],
+        mm_kwargs: Sequence[MultiModalKwargsItem],
         mm_hashes: list[str],
-    ) -> list[MultiModalKwargsItem]:
+    ) -> list[Optional[MultiModalKwargsItem]]:
         if not self.enabled:
-            return mm_kwargs
+            return list(mm_kwargs)
 
         assert len(mm_kwargs) == len(mm_hashes)
 
-        out_mm_items = list[MultiModalKwargsItem]()
+        out_mm_items = list[Optional[MultiModalKwargsItem]]()
         for mm_item, mm_hash in zip(mm_kwargs, mm_hashes):
             if self.mm_cache.get(mm_hash) is not None:
-                out_mm_items.append(mm_item.without_data())
+                out_mm_items.append(None)
             else:
                 self.mm_cache[mm_hash] = \
-                    MultiModalCacheItemMetadata.wraps(mm_item.require_data())
+                    MultiModalCacheItemMetadata.wraps(mm_item)
                 out_mm_items.append(mm_item)
 
         return out_mm_items
@@ -91,25 +92,27 @@ class MultiModalInputCacheServer:
         self.enabled = mm_registry.enable_mm_input_cache(model_config)
         self.mm_cache = MultiModalCache.get_lru_cache(
             model_config.get_mm_input_cache_gb(),
-            Mapping[str, NestedTensors],
+            MultiModalKwargsItem,
         )
 
     def get_and_update(
         self,
-        mm_kwargs: list[MultiModalKwargsItem],
+        mm_kwargs: Sequence[Optional[MultiModalKwargsItem]],
         mm_hashes: list[str],
     ) -> list[MultiModalKwargsItem]:
         if not self.enabled:
-            return mm_kwargs
+            mm_kwargs_lst = list(mm_kwargs)
+            assert is_list_of(mm_kwargs_lst, MultiModalKwargsItem)
+            return mm_kwargs_lst
 
         assert len(mm_kwargs) == len(mm_hashes)
 
         out_mm_items = list[MultiModalKwargsItem]()
         for mm_item, mm_hash in zip(mm_kwargs, mm_hashes):
-            if (mm_data := mm_item.get_data()) is None:
-                out_mm_items.append(mm_item.with_data(self.mm_cache[mm_hash]))
+            if mm_item is None:
+                out_mm_items.append(self.mm_cache[mm_hash])
             else:
-                self.mm_cache[mm_hash] = mm_data
+                self.mm_cache[mm_hash] = mm_item
                 out_mm_items.append(mm_item)
 
         return out_mm_items
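
Together, the two classes form the send-once protocol this refactor enables. A sketch of the round trip, assuming `client` and `server` are constructed `MultiModalInputCacheClient`/`MultiModalInputCacheServer` instances with caching enabled, and `item`/"h0" are a processed `MultiModalKwargsItem` and its hash:

# Frontend process: round 1 forwards the item and caches its metadata.
sent = client.get_and_update([item], ["h0"])   # -> [item]
full = server.get_and_update(sent, ["h0"])     # -> [item], cached under "h0"

# Round 2 with the same item: the client substitutes None before IPC and
# the engine core restores the whole item from its own cache.
sent = client.get_and_update([item], ["h0"])   # -> [None]
full = server.get_and_update(sent, ["h0"])     # -> [item]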

View File

@@ -17,6 +17,7 @@ from vllm.multimodal.utils import argsort_mm_positions
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer_group import TokenizerGroup
+from vllm.utils import is_list_of
 from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.mm_input_cache import MultiModalInputCacheClient
 from vllm.v1.structured_output.backend_guidance import (
@@ -295,7 +296,7 @@ class Processor:
             pooling_params = params.clone()
 
         # Multimodal related.
-        sorted_mm_inputs: Optional[list[MultiModalKwargsItem]] = None
+        sorted_mm_inputs: Optional[list[Optional[MultiModalKwargsItem]]] = None
         sorted_mm_positions: Optional[list[PlaceholderRange]] = None
         sorted_mm_hashes: Optional[list[str]] = None
         if decoder_inputs["type"] == "multimodal":
@@ -308,7 +309,7 @@ class Processor:
             # in the input sequence.
             sorted_mm_idxs = argsort_mm_positions(decoder_mm_positions)
 
-            sorted_mm_inputs = [
+            orig_sorted_mm_inputs = [
                 decoder_mm_inputs.get_item(modality, idx)
                 for modality, idx in sorted_mm_idxs
             ]
@@ -323,9 +324,12 @@ class Processor:
             if sorted_mm_hashes is not None:
                 sorted_mm_inputs = self.mm_input_cache_client.get_and_update(
-                    sorted_mm_inputs,
+                    orig_sorted_mm_inputs,
                     sorted_mm_hashes,
                 )
+            else:
+                assert is_list_of(orig_sorted_mm_inputs, MultiModalKwargsItem)
+                sorted_mm_inputs = orig_sorted_mm_inputs
 
         return decoder_inputs.get("prompt"), EngineCoreRequest(
             request_id=request_id,

View File

@@ -125,14 +125,17 @@ class Request:
         block_hasher: Optional[Callable[["Request"], list["BlockHash"]]]
     ) -> "Request":
         if request.mm_kwargs is not None:
-            assert is_list_of(request.mm_kwargs, MultiModalKwargsItem), (
+            mm_kwargs_lst = list(request.mm_kwargs)
+            assert is_list_of(mm_kwargs_lst, MultiModalKwargsItem), (
                 "mm_kwargs was not updated in EngineCore.add_request")
+        else:
+            mm_kwargs_lst = None
 
         return cls(
             request_id=request.request_id,
             client_index=request.client_index,
             prompt_token_ids=request.prompt_token_ids,
-            multi_modal_kwargs=request.mm_kwargs,
+            multi_modal_kwargs=mm_kwargs_lst,
             multi_modal_hashes=request.mm_hashes,
             multi_modal_placeholders=request.mm_placeholders,
             sampling_params=request.sampling_params,

View File

@@ -500,8 +500,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 second_per_grid_ts = []
                 audio_feature_lengths = []
                 use_audio_in_video = False
-                for item in self.requests[req_id].mm_kwargs:
-                    mm_input = item.require_data()
+                for mm_item in self.requests[req_id].mm_kwargs:
+                    mm_input = mm_item.get_data()
                     if mm_input.get("image_grid_thw") is not None:
                         image_grid_thw.append(
                             mm_input["image_grid_thw"].tolist())