Mirror of https://git.datalinker.icu/vllm-project/vllm.git

[Refactor] Allow optional MultiModalKwargsItem in IPC (#23022)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

Commit: 4dff91c93d
Parent: de9cb61763

@@ -7,9 +7,7 @@ import pytest
 import torch
 
 from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
-from vllm.multimodal.inputs import (MultiModalBatchedField,
-                                    MultiModalFieldElem, MultiModalKwargsItem,
-                                    PlaceholderRange)
+from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
 from vllm.sampling_params import SamplingParams
 from vllm.utils import GiB_bytes, sha256, sha256_cbor_64bit
 from vllm.v1.core.kv_cache_manager import KVCacheManager

@@ -42,13 +40,7 @@ def make_request(
     if mm_positions is None:
         mm_kwargs = None
     else:
-        mm_elem = MultiModalFieldElem(
-            modality="dummy_m",
-            key="dummy_k",
-            data=None,
-            field=MultiModalBatchedField(),
-        )
-        mm_item = MultiModalKwargsItem.from_elems([mm_elem])
+        mm_item = MultiModalKwargsItem.dummy("dummy_m")
         mm_kwargs = [mm_item] * len(mm_positions)
 
     return Request(request_id=request_id,

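The two test hunks above swap a hand-rolled `MultiModalFieldElem` (with `data=None`) for the new `MultiModalKwargsItem.dummy()` factory. A minimal sketch of what the updated helper boils down to, assuming this commit's vLLM API; the `mm_positions` value is a made-up test input:

    # Sketch of the updated test helper, assuming this commit's vLLM API;
    # the mm_positions value is a made-up test input.
    from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange

    mm_positions = [PlaceholderRange(offset=0, length=4)]

    # Old: six lines building a MultiModalFieldElem with data=None.
    # New: one call producing a well-formed, non-empty item.
    mm_item = MultiModalKwargsItem.dummy("dummy_m")
    mm_kwargs = [mm_item] * len(mm_positions)
    assert mm_kwargs[0].modality == "dummy_m"
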
@@ -9,9 +9,7 @@ import pytest
 import torch
 
 from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved
-from vllm.multimodal.inputs import (MultiModalBatchedField,
-                                    MultiModalFieldElem, MultiModalKwargsItem,
-                                    PlaceholderRange)
+from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
 from vllm.sampling_params import SamplingParams
 from vllm.utils import sha256, sha256_cbor_64bit
 from vllm.v1.core.block_pool import BlockPool

@@ -37,13 +35,7 @@ def make_request(
     if mm_positions is None:
         mm_kwargs = None
     else:
-        mm_elem = MultiModalFieldElem(
-            modality="dummy_m",
-            key="dummy_k",
-            data=None,
-            field=MultiModalBatchedField(),
-        )
-        mm_item = MultiModalKwargsItem.from_elems([mm_elem])
+        mm_item = MultiModalKwargsItem.dummy("dummy_m")
         mm_kwargs = [mm_item] * len(mm_positions)
 
     return Request(request_id=request_id,

@@ -8,9 +8,7 @@ import torch
 
 from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
                          SchedulerConfig, SpeculativeConfig, VllmConfig)
-from vllm.multimodal.inputs import (MultiModalBatchedField,
-                                    MultiModalFieldElem, MultiModalKwargsItem,
-                                    PlaceholderRange)
+from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
 from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
 from vllm.v1.core.sched.scheduler import Scheduler

@@ -1328,13 +1326,7 @@ def create_requests_with_priority(
     for i in range(num_requests):
         if mm_positions is not None:
             mm_position = mm_positions[i]
-            mm_elem = MultiModalFieldElem(
-                modality="dummy_m",
-                key="dummy_k",
-                data=None,
-                field=MultiModalBatchedField(),
-            )
-            mm_item = MultiModalKwargsItem.from_elems([mm_elem])
+            mm_item = MultiModalKwargsItem.dummy("dummy_m")
             mm_kwargs = [mm_item] * len(mm_position)
         else:
             mm_position = None

@@ -6,9 +6,7 @@ import torch
 
 from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
                          SchedulerConfig, SpeculativeConfig, VllmConfig)
-from vllm.multimodal.inputs import (MultiModalBatchedField,
-                                    MultiModalFieldElem, MultiModalKwargsItem,
-                                    PlaceholderRange)
+from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
 from vllm.sampling_params import SamplingParams
 from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
                                          init_none_hash)

@@ -143,13 +141,7 @@ def create_requests(
     for i in range(num_requests):
         if mm_positions is not None:
             mm_position = mm_positions[i]
-            mm_elem = MultiModalFieldElem(
-                modality="dummy_m",
-                key="dummy_k",
-                data=None,
-                field=MultiModalBatchedField(),
-            )
-            mm_item = MultiModalKwargsItem.from_elems([mm_elem])
+            mm_item = MultiModalKwargsItem.dummy("dummy_m")
             mm_kwargs = [mm_item] * len(mm_position)
             mm_hashes = ["hash"] * len(mm_position)
         else:

@@ -4,7 +4,7 @@
 from abc import ABC, abstractmethod
 from collections import UserDict, defaultdict
 from collections.abc import Mapping, Sequence
-from dataclasses import dataclass, replace
+from dataclasses import dataclass
 from functools import partial
 from itertools import accumulate
 from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, TypeVar,

@@ -218,7 +218,7 @@ class MultiModalFieldElem:
     i.e. the name of the keyword argument to be passed to the model.
     """
 
-    data: Optional[NestedTensors]
+    data: NestedTensors
     """
     The tensor data of this field in
     [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs],

@@ -315,13 +315,8 @@ class BaseMultiModalField(ABC):
         if len(set(field_types)) > 1:
             raise ValueError(f"Cannot merge different {field_types=}")
 
-        validated_data = list[NestedTensors]()
-        for i, elem in enumerate(elems):
-            assert elem.data is not None, (
-                f"Cannot merge with empty `elems[{i}]`")
-            validated_data.append(elem.data)
-
-        return self._reduce_data(validated_data, pin_memory=pin_memory)
+        batch = [elem.data for elem in elems]
+        return self._reduce_data(batch, pin_memory=pin_memory)
 
 
 @dataclass(frozen=True)

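With `data` no longer optional on `MultiModalFieldElem`, the merge path above drops its per-element None checks. A self-contained sketch of the simplified batching; `Elem` is a stand-in for `MultiModalFieldElem`, and `torch.stack` is only one plausible `_reduce_data` behavior, not the actual implementation:

    # Sketch of the simplified merge: with `data` guaranteed present, the
    # batch is a plain comprehension.  `Elem` is a stand-in for
    # MultiModalFieldElem; torch.stack is one plausible _reduce_data.
    import torch

    class Elem:
        def __init__(self, data: torch.Tensor) -> None:
            self.data = data

    elems = [Elem(torch.zeros(2)), Elem(torch.ones(2))]
    batch = [elem.data for elem in elems]   # no per-elem None checks needed
    merged = torch.stack(batch)
    assert merged.shape == (2, 2)
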
@@ -643,6 +638,17 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
     [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].
     """
 
+    @staticmethod
+    def dummy(modality: str):
+        """Convenience class for testing."""
+        mm_elem = MultiModalFieldElem(
+            modality=modality,
+            key="dummy",
+            data=torch.empty(1),
+            field=MultiModalSharedField(1),
+        )
+        return MultiModalKwargsItem.from_elems([mm_elem])
+
     @staticmethod
     def from_elems(elems: Sequence[MultiModalFieldElem]):
         return MultiModalKwargsItem({elem.key: elem for elem in elems})

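The new `dummy()` factory always produces a well-formed, one-element item backed by real tensor data (`torch.empty(1)` under a `MultiModalSharedField`), which is what lets the tests above drop their hand-built empty elems. A small usage sketch, assuming this commit's vLLM API:

    # Usage sketch for the new factory, assuming this commit's vLLM API.
    import torch
    from vllm.multimodal.inputs import MultiModalKwargsItem

    item = MultiModalKwargsItem.dummy("image")
    elem = item["dummy"]                       # items are keyed by elem.key
    assert elem.modality == "image"
    assert elem.data.shape == torch.Size([1])  # torch.empty(1) placeholder
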
@@ -654,46 +660,12 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
         assert len(modalities) == 1, f"Found different modalities={modalities}"
         self._modality = next(iter(modalities))
 
-        self._is_empty = any(elem.data is None for elem in self.values())
-
     @property
     def modality(self) -> str:
         return self._modality
 
-    @property
-    def is_empty(self) -> bool:
-        return self._is_empty
-
-    def get_data(self) -> Optional[Mapping[str, NestedTensors]]:
-        if self._is_empty:
-            return None
-
-        out_data = dict[str, NestedTensors]()
-        for key, elem in self.items():
-            assert elem.data is not None, (
-                f"Cannot get data of empty `elem[{key!r}]`")
-            out_data[key] = elem.data
-
-        return out_data
-
-    def require_data(self) -> Mapping[str, NestedTensors]:
-        if (data := self.get_data()) is None:
-            raise RuntimeError("Cannot get data of empty item")
-
-        return data
-
-    # These methods create a new item to avoid mutating cached items in place
-    def with_data(self, data: Mapping[str, NestedTensors]):
-        return MultiModalKwargsItem({
-            key: replace(elem, data=data[key])
-            for key, elem in self.items()
-        })
-
-    def without_data(self):
-        return MultiModalKwargsItem({
-            key: replace(elem, data=None)
-            for key, elem in self.items()
-        })
+    def get_data(self) -> Mapping[str, NestedTensors]:
+        return {key: elem.data for key, elem in self.items()}
 
 
 # NOTE: UserDict is for V0 compatibility.

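The `is_empty`/`with_data`/`without_data` machinery disappears because "item already cached" is now modeled as a `None` slot in the IPC payload rather than as an item whose elems carry `data=None`. A toy sketch of the new shape of the protocol; `hydrate` is a hypothetical helper, not part of vLLM:

    # Toy sketch: "already cached" is now a None slot in the payload, not an
    # item stripped of its data.  `hydrate` is a hypothetical helper.
    from typing import Optional
    from vllm.multimodal.inputs import MultiModalKwargsItem

    def hydrate(
        slot: Optional[MultiModalKwargsItem],
        cached: MultiModalKwargsItem,
    ) -> MultiModalKwargsItem:
        # Old protocol: slot.with_data(cached_data) re-filled stripped elems.
        # New protocol: the whole cached item simply takes the slot's place.
        return cached if slot is None else slot
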
@@ -3,6 +3,7 @@
 
 import enum
 import time
+from collections.abc import Sequence
 from typing import Any, Optional, Union
 
 import msgspec

@@ -47,7 +48,7 @@ class EngineCoreRequest(
 
     request_id: str
     prompt_token_ids: list[int]
-    mm_kwargs: Optional[list[MultiModalKwargsItem]]
+    mm_kwargs: Optional[Sequence[Optional[MultiModalKwargsItem]]]
     mm_hashes: Optional[list[str]]
     mm_placeholders: Optional[list[PlaceholderRange]]
     sampling_params: Optional[SamplingParams]

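`mm_kwargs` slots may now individually be `None` on the wire: the serialized request carries one slot per multimodal item, and a `None` slot signals that the receiver can restore the item from its own cache. A minimal msgspec sketch with the same field shape; `TinyRequest` and `payloads` are invented names, not vLLM's:

    # Minimal msgspec sketch with the same field shape as mm_kwargs above.
    # `TinyRequest` and `payloads` are invented names, not vLLM's.
    from collections.abc import Sequence
    from typing import Optional

    import msgspec

    class TinyRequest(msgspec.Struct):
        request_id: str
        payloads: Optional[Sequence[Optional[bytes]]]

    # One slot per multimodal item; None means "receiver already has it".
    req = TinyRequest(request_id="r0", payloads=[None, b"raw-item-bytes"])
    assert req.payloads[1] is not None
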
@@ -1,11 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from collections.abc import Mapping
-from typing import TYPE_CHECKING
+from collections.abc import Sequence
+from typing import TYPE_CHECKING, Optional
 
 from vllm.multimodal import MultiModalRegistry
 from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata
-from vllm.multimodal.inputs import MultiModalKwargsItem, NestedTensors
+from vllm.multimodal.inputs import MultiModalKwargsItem
+from vllm.utils import is_list_of
 
 if TYPE_CHECKING:
     from vllm.config import ModelConfig

@@ -58,21 +59,21 @@ class MultiModalInputCacheClient:
 
     def get_and_update(
         self,
-        mm_kwargs: list[MultiModalKwargsItem],
+        mm_kwargs: Sequence[MultiModalKwargsItem],
         mm_hashes: list[str],
-    ) -> list[MultiModalKwargsItem]:
+    ) -> list[Optional[MultiModalKwargsItem]]:
         if not self.enabled:
-            return mm_kwargs
+            return list(mm_kwargs)
 
         assert len(mm_kwargs) == len(mm_hashes)
 
-        out_mm_items = list[MultiModalKwargsItem]()
+        out_mm_items = list[Optional[MultiModalKwargsItem]]()
         for mm_item, mm_hash in zip(mm_kwargs, mm_hashes):
             if self.mm_cache.get(mm_hash) is not None:
-                out_mm_items.append(mm_item.without_data())
+                out_mm_items.append(None)
             else:
                 self.mm_cache[mm_hash] = \
-                    MultiModalCacheItemMetadata.wraps(mm_item.require_data())
+                    MultiModalCacheItemMetadata.wraps(mm_item)
                 out_mm_items.append(mm_item)
 
         return out_mm_items

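On the client (send) side, an item whose hash is already cached is now replaced by `None` instead of a data-stripped item. A hypothetical mini-version of the protocol, with a plain `set` standing in for the real LRU cache of item metadata:

    # Send-side sketch: a hypothetical mini-version of
    # MultiModalInputCacheClient.get_and_update; a plain set stands in for
    # the real LRU cache of item metadata.
    from typing import Optional

    def client_get_and_update(
        items: list[str],
        hashes: list[str],
        seen: set[str],
    ) -> list[Optional[str]]:
        out: list[Optional[str]] = []
        for item, h in zip(items, hashes):
            if h in seen:
                out.append(None)      # receiver already holds this item
            else:
                seen.add(h)           # first sighting: ship the full item
                out.append(item)
        return out

    seen: set[str] = set()
    assert client_get_and_update(["img0"], ["h0"], seen) == ["img0"]
    assert client_get_and_update(["img0"], ["h0"], seen) == [None]
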
@@ -91,25 +92,27 @@ class MultiModalInputCacheServer:
         self.enabled = mm_registry.enable_mm_input_cache(model_config)
         self.mm_cache = MultiModalCache.get_lru_cache(
             model_config.get_mm_input_cache_gb(),
-            Mapping[str, NestedTensors],
+            MultiModalKwargsItem,
         )
 
     def get_and_update(
         self,
-        mm_kwargs: list[MultiModalKwargsItem],
+        mm_kwargs: Sequence[Optional[MultiModalKwargsItem]],
         mm_hashes: list[str],
     ) -> list[MultiModalKwargsItem]:
         if not self.enabled:
-            return mm_kwargs
+            mm_kwargs_lst = list(mm_kwargs)
+            assert is_list_of(mm_kwargs_lst, MultiModalKwargsItem)
+            return mm_kwargs_lst
 
         assert len(mm_kwargs) == len(mm_hashes)
 
         out_mm_items = list[MultiModalKwargsItem]()
         for mm_item, mm_hash in zip(mm_kwargs, mm_hashes):
-            if (mm_data := mm_item.get_data()) is None:
-                out_mm_items.append(mm_item.with_data(self.mm_cache[mm_hash]))
+            if mm_item is None:
+                out_mm_items.append(self.mm_cache[mm_hash])
             else:
-                self.mm_cache[mm_hash] = mm_data
+                self.mm_cache[mm_hash] = mm_item
                 out_mm_items.append(mm_item)
 
         return out_mm_items

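On the server (receive) side, `None` slots are re-hydrated from the local cache, which now stores whole `MultiModalKwargsItem` objects rather than raw tensor mappings. A matching mini-sketch, with a plain `dict` standing in for the LRU:

    # Receive-side sketch: a hypothetical mini-version of
    # MultiModalInputCacheServer.get_and_update; a plain dict stands in for
    # the LRU that now stores whole items keyed by hash.
    from typing import Optional

    def server_get_and_update(
        items: list[Optional[str]],
        hashes: list[str],
        cache: dict[str, str],
    ) -> list[str]:
        out: list[str] = []
        for item, h in zip(items, hashes):
            if item is None:
                out.append(cache[h])  # slot was elided: restore from cache
            else:
                cache[h] = item       # fresh item: remember it for next time
                out.append(item)
        return out

    cache: dict[str, str] = {}
    assert server_get_and_update(["img0"], ["h0"], cache) == ["img0"]
    assert server_get_and_update([None], ["h0"], cache) == ["img0"]
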
@@ -17,6 +17,7 @@ from vllm.multimodal.utils import argsort_mm_positions
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer_group import TokenizerGroup
+from vllm.utils import is_list_of
 from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.mm_input_cache import MultiModalInputCacheClient
 from vllm.v1.structured_output.backend_guidance import (

@@ -295,7 +296,7 @@ class Processor:
             pooling_params = params.clone()
 
         # Multimodal related.
-        sorted_mm_inputs: Optional[list[MultiModalKwargsItem]] = None
+        sorted_mm_inputs: Optional[list[Optional[MultiModalKwargsItem]]] = None
         sorted_mm_positions: Optional[list[PlaceholderRange]] = None
         sorted_mm_hashes: Optional[list[str]] = None
         if decoder_inputs["type"] == "multimodal":

@@ -308,7 +309,7 @@ class Processor:
             # in the input sequence.
             sorted_mm_idxs = argsort_mm_positions(decoder_mm_positions)
 
-            sorted_mm_inputs = [
+            orig_sorted_mm_inputs = [
                 decoder_mm_inputs.get_item(modality, idx)
                 for modality, idx in sorted_mm_idxs
             ]

@@ -323,9 +324,12 @@ class Processor:
 
         if sorted_mm_hashes is not None:
             sorted_mm_inputs = self.mm_input_cache_client.get_and_update(
-                sorted_mm_inputs,
+                orig_sorted_mm_inputs,
                 sorted_mm_hashes,
             )
+        else:
+            assert is_list_of(orig_sorted_mm_inputs, MultiModalKwargsItem)
+            sorted_mm_inputs = orig_sorted_mm_inputs
 
         return decoder_inputs.get("prompt"), EngineCoreRequest(
             request_id=request_id,

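When no hashes are available, the cache round-trip is skipped and the assertion narrows `orig_sorted_mm_inputs` back to a list with no `None` slots. A sketch of that narrowing, with a hypothetical stand-in for vLLM's `is_list_of`:

    # Sketch of the narrowing performed above; `all_present` is a
    # hypothetical stand-in for vLLM's is_list_of runtime check.
    from typing import Optional

    def all_present(items: list[Optional[str]]) -> bool:
        return all(item is not None for item in items)

    slots: list[Optional[str]] = ["a", "b"]
    assert all_present(slots)   # safe to treat as list[str] from here on
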
@@ -125,14 +125,17 @@ class Request:
         block_hasher: Optional[Callable[["Request"], list["BlockHash"]]]
     ) -> "Request":
         if request.mm_kwargs is not None:
-            assert is_list_of(request.mm_kwargs, MultiModalKwargsItem), (
+            mm_kwargs_lst = list(request.mm_kwargs)
+            assert is_list_of(mm_kwargs_lst, MultiModalKwargsItem), (
                 "mm_kwargs was not updated in EngineCore.add_request")
+        else:
+            mm_kwargs_lst = None
 
         return cls(
             request_id=request.request_id,
             client_index=request.client_index,
             prompt_token_ids=request.prompt_token_ids,
-            multi_modal_kwargs=request.mm_kwargs,
+            multi_modal_kwargs=mm_kwargs_lst,
             multi_modal_hashes=request.mm_hashes,
             multi_modal_placeholders=request.mm_placeholders,
             sampling_params=request.sampling_params,

@@ -500,8 +500,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             second_per_grid_ts = []
             audio_feature_lengths = []
             use_audio_in_video = False
-            for item in self.requests[req_id].mm_kwargs:
-                mm_input = item.require_data()
+            for mm_item in self.requests[req_id].mm_kwargs:
+                mm_input = mm_item.get_data()
                 if mm_input.get("image_grid_thw") is not None:
                     image_grid_thw.append(
                         mm_input["image_grid_thw"].tolist())

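By the time the model runner iterates `mm_kwargs`, every cached slot has been re-hydrated in `EngineCore.add_request`, so the stricter `require_data()` is no longer needed and `get_data()` returns a plain mapping. A small sketch, assuming this commit's vLLM API:

    # Sketch, assuming this commit's vLLM API: every slot is concrete by the
    # time the runner sees it, so get_data() returns a plain mapping.
    from vllm.multimodal.inputs import MultiModalKwargsItem

    mm_item = MultiModalKwargsItem.dummy("video")
    mm_input = mm_item.get_data()
    assert set(mm_input) == {"dummy"}   # keyed by each elem's `key`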