[Core] Use individual MM items in P0/P1 cache and model runner (#22570)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-08-13 22:18:07 +08:00 committed by GitHub
parent 20d65aa755
commit 19b927e52d
24 changed files with 549 additions and 486 deletions

View File

@ -5,7 +5,7 @@ import base64
import mimetypes
import os
from tempfile import NamedTemporaryFile, TemporaryDirectory
from typing import TYPE_CHECKING, NamedTuple, Optional
from typing import TYPE_CHECKING, NamedTuple
import numpy as np
import pytest
@ -19,14 +19,12 @@ from vllm.distributed.parallel_state import (init_distributed_environment,
initialize_model_parallel)
from vllm.multimodal.image import convert_image_mode
from vllm.multimodal.inputs import PlaceholderRange
from vllm.multimodal.utils import (MediaConnector,
merge_and_sort_multimodal_metadata,
from vllm.multimodal.utils import (MediaConnector, argsort_mm_positions,
run_dp_sharded_vision_model)
from vllm.platforms import current_platform
from vllm.utils import get_open_port, update_environment_variables
if TYPE_CHECKING:
from vllm.multimodal.hasher import MultiModalHashDict
from vllm.multimodal.inputs import MultiModalPlaceholderDict
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
@ -178,19 +176,17 @@ async def test_fetch_video_http(video_url: str, num_frames: int):
assert metadata_sync == metadata_async
# Used for the next two tests related to `merge_and_sort_multimodal_metadata`.
# Used for `test_argsort_mm_positions`.
class TestCase(NamedTuple):
mm_positions: "MultiModalPlaceholderDict"
mm_hashes: Optional["MultiModalHashDict"]
expected_modalities: list[str]
expected_ranges: list[PlaceholderRange]
expected_hashes: Optional[list[str]]
expected_modality_idxs: list[tuple[str, int]]
def test_merge_and_sort_multimodal_metadata():
def test_argsort_mm_positions():
test_cases = [
# Single modality should return result as is but flattened
# Single modality
## Internally sorted
TestCase(
mm_positions={
"image": [
@ -198,34 +194,27 @@ def test_merge_and_sort_multimodal_metadata():
PlaceholderRange(offset=3, length=2),
]
},
mm_hashes={"image": ["hash1", "hash2"]},
expected_modalities=["image", "image"],
expected_ranges=[
PlaceholderRange(offset=0, length=2),
PlaceholderRange(offset=3, length=2),
expected_modality_idxs=[
("image", 0),
("image", 1),
],
expected_hashes=["hash1", "hash2"],
),
# Single modality without hashes return None for mm hash.
## Internally unsorted
TestCase(
mm_positions={
"image": [
PlaceholderRange(offset=3, length=2),
PlaceholderRange(offset=0, length=2),
PlaceholderRange(offset=2, length=2),
]
},
mm_hashes=None,
expected_modalities=["image", "image"],
expected_ranges=[
PlaceholderRange(offset=0, length=2),
PlaceholderRange(offset=2, length=2),
expected_modality_idxs=[
("image", 1),
("image", 0),
],
expected_hashes=None,
),
# Multiple modalities with hashes should return sorted modalities
# and flattened ranges and hashes.
# Two modalities
## Internally sorted
TestCase(
mm_positions={
"image": [
@ -237,47 +226,54 @@ def test_merge_and_sort_multimodal_metadata():
PlaceholderRange(offset=2, length=3),
]
},
mm_hashes={
"image": ["image_hash1", "image_hash2"],
"audio": ["audio_hash1", "audio_hash2"],
},
expected_modalities=["audio", "audio", "image", "image"],
expected_ranges=[
PlaceholderRange(offset=0, length=2),
PlaceholderRange(offset=2, length=3),
PlaceholderRange(offset=7, length=4),
PlaceholderRange(offset=11, length=5),
],
expected_hashes=[
"audio_hash1", "audio_hash2", "image_hash1", "image_hash2"
expected_modality_idxs=[
("audio", 0),
("audio", 1),
("image", 0),
("image", 1),
],
),
# Multiple modalities without hashes should return sorted modalities
# and flattened ranges and None.
## Interleaved, internally sorted
TestCase(
mm_positions={
"image": [
PlaceholderRange(offset=7, length=4),
PlaceholderRange(offset=11, length=5),
PlaceholderRange(offset=0, length=4),
PlaceholderRange(offset=8, length=2),
],
"audio": [
PlaceholderRange(offset=0, length=2),
PlaceholderRange(offset=2, length=3),
PlaceholderRange(offset=5, length=2),
PlaceholderRange(offset=11, length=4),
]
},
mm_hashes=None,
expected_modalities=["audio", "audio", "image", "image"],
expected_ranges=[
PlaceholderRange(offset=0, length=2),
PlaceholderRange(offset=2, length=3),
PlaceholderRange(offset=7, length=4),
PlaceholderRange(offset=11, length=5),
expected_modality_idxs=[
("image", 0),
("audio", 0),
("image", 1),
("audio", 1),
],
),
## Interleaved, internally unsorted
TestCase(
mm_positions={
"image": [
PlaceholderRange(offset=8, length=2),
PlaceholderRange(offset=0, length=4),
],
"audio": [
PlaceholderRange(offset=11, length=4),
PlaceholderRange(offset=5, length=2),
]
},
expected_modality_idxs=[
("image", 1),
("audio", 1),
("image", 0),
("audio", 0),
],
expected_hashes=None,
),
# Three modalities
## Internally sorted
TestCase(
mm_positions={
"image": [
@ -293,72 +289,16 @@ def test_merge_and_sort_multimodal_metadata():
PlaceholderRange(offset=12, length=6),
]
},
mm_hashes={
"image": ["image_hash1", "image_hash2"],
"audio": ["audio_hash1"],
"video": ["video_hash1", "video_hash2", "video_hash3"]
},
expected_modalities=[
"audio", "video", "video", "video", "image", "image"
],
expected_ranges=[
PlaceholderRange(offset=0, length=2),
PlaceholderRange(offset=3, length=4),
PlaceholderRange(offset=7, length=5),
PlaceholderRange(offset=12, length=6),
PlaceholderRange(offset=15, length=7),
PlaceholderRange(offset=22, length=8),
],
expected_hashes=[
"audio_hash1", "video_hash1", "video_hash2", "video_hash3",
"image_hash1", "image_hash2"
expected_modality_idxs=[
("audio", 0),
("video", 0),
("video", 1),
("video", 2),
("image", 0),
("image", 1),
],
),
]
for (mm_positions, mm_hashes, expected_modalities, expected_ranges,
expected_hashes) in test_cases:
modalities, ranges, hashes = merge_and_sort_multimodal_metadata(
mm_positions, mm_hashes)
assert modalities == expected_modalities
assert ranges == expected_ranges
assert hashes == expected_hashes
def test_merge_and_sort_multimodal_metadata_with_interleaving():
test_cases = [
# <image> <audio> <image> <audio>
TestCase(
mm_positions={
"image": [
PlaceholderRange(offset=0, length=4),
PlaceholderRange(offset=8, length=2),
],
"audio": [
PlaceholderRange(offset=5, length=2),
PlaceholderRange(offset=11, length=4),
]
},
mm_hashes={
"image": ["image_hash1", "image_hash2"],
"audio": ["audio_hash1", "audio_hash2"],
},
expected_modalities=["image", "audio", "image", "audio"],
expected_ranges=[
PlaceholderRange(offset=0, length=4),
PlaceholderRange(offset=5, length=2),
PlaceholderRange(offset=8, length=2),
PlaceholderRange(offset=11, length=4),
],
expected_hashes=[
"image_hash1", "audio_hash1", "image_hash2", "audio_hash2"
],
),
# <image> <image> <audio> <video> <image>
## Interleaved, internally sorted
TestCase(
mm_positions={
"image": [
@ -373,58 +313,43 @@ def test_merge_and_sort_multimodal_metadata_with_interleaving():
PlaceholderRange(offset=8, length=5),
]
},
mm_hashes=None,
expected_modalities=["image", "image", "audio", "video", "image"],
expected_ranges=[
PlaceholderRange(offset=0, length=2),
PlaceholderRange(offset=2, length=3),
PlaceholderRange(offset=5, length=2),
PlaceholderRange(offset=8, length=5),
PlaceholderRange(offset=20, length=4),
expected_modality_idxs=[
("image", 0),
("image", 1),
("audio", 0),
("video", 0),
("image", 2),
],
expected_hashes=None,
),
# <image> <audio> <video> <image> with hashes
## Interleaved, internally unsorted
TestCase(
mm_positions={
"image": [
PlaceholderRange(offset=0, length=2),
PlaceholderRange(offset=18, length=4),
PlaceholderRange(offset=20, length=4),
PlaceholderRange(offset=2, length=3),
],
"audio": [
PlaceholderRange(offset=6, length=2),
PlaceholderRange(offset=5, length=2),
],
"video": [
PlaceholderRange(offset=10, length=5),
PlaceholderRange(offset=8, length=5),
]
},
mm_hashes={
"image": ["image_hash1", "image_hash2"],
"audio": ["audio_hash1"],
"video": ["video_hash1"],
},
expected_modalities=["image", "audio", "video", "image"],
expected_ranges=[
PlaceholderRange(offset=0, length=2),
PlaceholderRange(offset=6, length=2),
PlaceholderRange(offset=10, length=5),
PlaceholderRange(offset=18, length=4),
],
expected_hashes=[
"image_hash1", "audio_hash1", "video_hash1", "image_hash2"
expected_modality_idxs=[
("image", 0),
("image", 2),
("audio", 0),
("video", 0),
("image", 1),
],
),
]
for (mm_positions, mm_hashes, expected_modalities, expected_ranges,
expected_hashes) in test_cases:
modalities, ranges, hashes = merge_and_sort_multimodal_metadata(
mm_positions, mm_hashes)
for mm_positions, expected_modality_idxs in test_cases:
modality_idxs = argsort_mm_positions(mm_positions)
assert modalities == expected_modalities
assert ranges == expected_ranges
assert hashes == expected_hashes
assert modality_idxs == expected_modality_idxs
class SimpleLinearModel(torch.nn.Module):

View File

@ -1,12 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib
from typing import Optional
import pytest
import torch
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.multimodal.inputs import (MultiModalBatchedField,
MultiModalFieldElem, MultiModalKwargsItem,
PlaceholderRange)
from vllm.sampling_params import SamplingParams
from vllm.utils import GiB_bytes, sha256, sha256_cbor_64bit
from vllm.v1.core.kv_cache_manager import KVCacheManager
@ -27,20 +30,29 @@ from vllm.v1.request import Request
# yapf: enable
def make_request(request_id,
prompt_token_ids,
mm_positions=None,
mm_hashes=None,
cache_salt=None):
def make_request(
request_id: str,
prompt_token_ids: list[int],
mm_positions: Optional[list[PlaceholderRange]] = None,
mm_hashes: Optional[list[str]] = None,
cache_salt: Optional[str] = None,
):
if mm_positions is None:
multi_modal_inputs = None
mm_kwargs = None
else:
multi_modal_inputs = [MultiModalKwargs({})] * len(mm_positions)
mm_elem = MultiModalFieldElem(
modality="dummy_m",
key="dummy_k",
data=None,
field=MultiModalBatchedField(),
)
mm_item = MultiModalKwargsItem.from_elems([mm_elem])
mm_kwargs = [mm_item] * len(mm_positions)
return Request(
request_id=request_id,
prompt_token_ids=prompt_token_ids,
multi_modal_inputs=multi_modal_inputs,
multi_modal_kwargs=mm_kwargs,
multi_modal_hashes=mm_hashes,
multi_modal_placeholders=mm_positions,
sampling_params=SamplingParams(max_tokens=17),
@ -316,7 +328,7 @@ def test_free_kv_cache_block_queue_get_all_free_blocks():
def test_generate_block_hash_extra_keys():
request = make_request(
request_id=0,
request_id="0",
prompt_token_ids=[_ for _ in range(20)],
mm_positions=[
PlaceholderRange(offset=0, length=5),
@ -348,7 +360,7 @@ def test_generate_block_hash_extra_keys():
def test_generate_block_hash_extra_keys_no_mm_inputs():
request = make_request(
request_id=0,
request_id="0",
prompt_token_ids=[_ for _ in range(6)],
mm_positions=None,
mm_hashes=None,
@ -361,7 +373,7 @@ def test_generate_block_hash_extra_keys_no_mm_inputs():
def test_generate_block_hash_extra_keys_cache_salt():
request = make_request(
request_id=0,
request_id="0",
prompt_token_ids=[_ for _ in range(6)],
mm_positions=None,
mm_hashes=None,
@ -382,7 +394,7 @@ def test_generate_block_hash_extra_keys_cache_salt():
# works together with other extra keys
request_mm = make_request(
request_id=0,
request_id="0",
prompt_token_ids=[_ for _ in range(20)],
mm_positions=[
PlaceholderRange(offset=0, length=5),
@ -420,7 +432,7 @@ def test_hash_request_tokens(hash_fn):
import vllm.v1.core.kv_cache_utils
init_none_hash(hash_fn)
request = make_request(
request_id=0,
request_id="0",
prompt_token_ids=[_ for _ in range(6)],
mm_positions=[
PlaceholderRange(offset=0, length=3),
@ -450,7 +462,7 @@ def test_hash_tokens_different_mm_input(hash_fn):
init_none_hash(hash_fn)
request1 = make_request(
request_id=0,
request_id="0",
prompt_token_ids=[_ for _ in range(6)],
mm_positions=[
PlaceholderRange(offset=0, length=3),
@ -459,7 +471,7 @@ def test_hash_tokens_different_mm_input(hash_fn):
mm_hashes=["hash1", "hash2"],
)
request2 = make_request(
request_id=1,
request_id="1",
prompt_token_ids=[_ for _ in range(6)],
mm_positions=[
PlaceholderRange(offset=0, length=3),
@ -479,7 +491,7 @@ def test_hash_request_tokens_no_mm_inputs(hash_fn):
init_none_hash(hash_fn)
request = make_request(
request_id=0,
request_id="0",
prompt_token_ids=[_ for _ in range(6)],
mm_positions=None,
mm_hashes=None,
@ -844,7 +856,7 @@ def test_allocate_with_lookahead():
)
request = make_request(
request_id=0,
request_id="0",
prompt_token_ids=[],
mm_positions=None,
mm_hashes=None,

View File

@ -9,7 +9,9 @@ import pytest
import torch
from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.multimodal.inputs import (MultiModalBatchedField,
MultiModalFieldElem, MultiModalKwargsItem,
PlaceholderRange)
from vllm.sampling_params import SamplingParams
from vllm.utils import sha256, sha256_cbor_64bit
from vllm.v1.core.block_pool import BlockPool
@ -21,21 +23,30 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, SlidingWindowSpec)
def make_request(request_id,
prompt_token_ids,
mm_positions=None,
mm_hashes=None,
prompt_logprobs: Optional[int] = None,
cache_salt: Optional[str] = None):
def make_request(
request_id: str,
prompt_token_ids: list[int],
mm_positions: Optional[list[PlaceholderRange]] = None,
mm_hashes: Optional[list[str]] = None,
prompt_logprobs: Optional[int] = None,
cache_salt: Optional[str] = None,
):
if mm_positions is None:
multi_modal_inputs = None
mm_kwargs = None
else:
multi_modal_inputs = [MultiModalKwargs({})] * len(mm_positions)
mm_elem = MultiModalFieldElem(
modality="dummy_m",
key="dummy_k",
data=None,
field=MultiModalBatchedField(),
)
mm_item = MultiModalKwargsItem.from_elems([mm_elem])
mm_kwargs = [mm_item] * len(mm_positions)
return Request(
request_id=request_id,
prompt_token_ids=prompt_token_ids,
multi_modal_inputs=multi_modal_inputs,
multi_modal_kwargs=mm_kwargs,
multi_modal_hashes=mm_hashes,
multi_modal_placeholders=mm_positions,
sampling_params=SamplingParams(max_tokens=17,

View File

@ -8,7 +8,9 @@ import torch
from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
SchedulerConfig, SpeculativeConfig, VllmConfig)
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.multimodal.inputs import (MultiModalBatchedField,
MultiModalFieldElem, MultiModalKwargsItem,
PlaceholderRange)
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler
@ -1304,7 +1306,7 @@ def create_requests_with_priority(
priorities: list[int],
arrival_times: Optional[list[float]] = None,
num_tokens: int = 10,
mm_positions: Optional[list[PlaceholderRange]] = None,
mm_positions: Optional[list[list[PlaceholderRange]]] = None,
max_tokens: int = 16,
stop_token_ids: Optional[list[int]] = None,
prompt_logprobs: Optional[int] = None):
@ -1323,16 +1325,23 @@ def create_requests_with_priority(
for i in range(num_requests):
if mm_positions is not None:
mm_position = mm_positions[i]
mm_inputs = [MultiModalKwargs({})] * len(mm_position)
mm_elem = MultiModalFieldElem(
modality="dummy_m",
key="dummy_k",
data=None,
field=MultiModalBatchedField(),
)
mm_item = MultiModalKwargsItem.from_elems([mm_elem])
mm_kwargs = [mm_item] * len(mm_position)
else:
mm_position = None
mm_inputs = None
mm_kwargs = None
request = Request(
request_id=f"{i}",
prompt_token_ids=[i] * num_tokens,
sampling_params=sampling_params,
pooling_params=None,
multi_modal_inputs=mm_inputs,
multi_modal_kwargs=mm_kwargs,
multi_modal_placeholders=mm_position,
multi_modal_hashes=None,
eos_token_id=EOS_TOKEN_ID,
@ -1816,7 +1825,7 @@ def test_schedule_skip_tokenizer_init_structured_output_request():
request = Request(
request_id="0",
prompt_token_ids=[0, 1],
multi_modal_inputs=None,
multi_modal_kwargs=None,
multi_modal_hashes=None,
multi_modal_placeholders=None,
sampling_params=sampling_params,

View File

@ -6,7 +6,9 @@ import torch
from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
SchedulerConfig, SpeculativeConfig, VllmConfig)
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.multimodal.inputs import (MultiModalBatchedField,
MultiModalFieldElem, MultiModalKwargsItem,
PlaceholderRange)
from vllm.sampling_params import SamplingParams
from vllm.v1.core.sched.async_scheduler import AsyncScheduler
from vllm.v1.core.sched.scheduler import Scheduler
@ -115,7 +117,7 @@ def create_scheduler(
def create_requests(
num_requests: int,
num_tokens: int = 10,
mm_positions: Optional[list[PlaceholderRange]] = None,
mm_positions: Optional[list[list[PlaceholderRange]]] = None,
max_tokens: int = 16,
stop_token_ids: Optional[list[int]] = None,
prompt_logprobs: Optional[int] = None,
@ -129,10 +131,17 @@ def create_requests(
for i in range(num_requests):
if mm_positions is not None:
mm_position = mm_positions[i]
mm_inputs = [MultiModalKwargs({})] * len(mm_position)
mm_elem = MultiModalFieldElem(
modality="dummy_m",
key="dummy_k",
data=None,
field=MultiModalBatchedField(),
)
mm_item = MultiModalKwargsItem.from_elems([mm_elem])
mm_kwargs = [mm_item] * len(mm_position)
else:
mm_position = None
mm_inputs = None
mm_kwargs = None
prompt_token_ids = ([0] * num_tokens if same_prompt else [i] *
num_tokens)
request = Request(
@ -140,7 +149,7 @@ def create_requests(
prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params,
pooling_params=None,
multi_modal_inputs=mm_inputs,
multi_modal_kwargs=mm_kwargs,
multi_modal_placeholders=mm_position,
multi_modal_hashes=None,
eos_token_id=EOS_TOKEN_ID,

View File

@ -35,7 +35,7 @@ def make_request() -> EngineCoreRequest:
return EngineCoreRequest(
request_id=str(uuid.uuid4()),
prompt_token_ids=PROMPT_TOKENS,
mm_inputs=None,
mm_kwargs=None,
mm_hashes=None,
mm_placeholders=None,
sampling_params=SamplingParams(),

View File

@ -52,7 +52,7 @@ def make_request(
return EngineCoreRequest(
request_id=str(uuid.uuid4()),
prompt_token_ids=prompt_tokens_ids,
mm_inputs=None,
mm_kwargs=None,
mm_hashes=None,
mm_placeholders=None,
sampling_params=params,

View File

@ -53,7 +53,7 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind,
EngineCoreRequest(request_id=f"request-{idx}",
prompt_token_ids=prompt_tokens,
arrival_time=0,
mm_inputs=None,
mm_kwargs=None,
mm_hashes=None,
mm_placeholders=None,
eos_token_id=None,
@ -402,7 +402,7 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
EngineCoreRequest(request_id=request_id_list[idx],
prompt_token_ids=prompt_tokens,
arrival_time=0,
mm_inputs=None,
mm_kwargs=None,
mm_hashes=None,
mm_placeholders=None,
eos_token_id=None,
@ -567,7 +567,7 @@ def test_stop_token(include_stop_str_in_output: bool,
request_id=request_id,
prompt_token_ids=prompt_tokens,
arrival_time=0,
mm_inputs=None,
mm_kwargs=None,
mm_hashes=None,
mm_placeholders=None,
eos_token_id=eos_token_id,
@ -666,7 +666,7 @@ def test_stop_string(include_stop_str_in_output: bool,
request_id=request_id_list[idx],
prompt_token_ids=prompt_tokens,
arrival_time=0,
mm_inputs=None,
mm_kwargs=None,
mm_hashes=None,
mm_placeholders=None,
eos_token_id=None,
@ -782,7 +782,7 @@ def test_iteration_stats(dummy_test_vectors):
request_id=f"request-{idx}",
prompt_token_ids=prompt_tokens,
arrival_time=0,
mm_inputs=None,
mm_kwargs=None,
mm_hashes=None,
mm_placeholders=None,
eos_token_id=None,

View File

@ -154,7 +154,7 @@ def create_request(
prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params,
pooling_params=None,
multi_modal_inputs=None,
multi_modal_kwargs=None,
multi_modal_placeholders=None,
multi_modal_hashes=None,
eos_token_id=EOS_TOKEN_ID,

View File

@ -64,7 +64,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
NewRequestData(
req_id=req_id,
prompt_token_ids=[1, 2, 3],
mm_inputs=[],
mm_kwargs=[],
mm_hashes=[],
mm_positions=[],
sampling_params=SamplingParams(),

View File

@ -203,7 +203,7 @@ def _construct_cached_request_state(req_id_suffix: int):
prompt_token_ids=prompt_token_ids,
sampling_params=_create_sampling_params(),
pooling_params=None,
mm_inputs=[],
mm_kwargs=[],
mm_positions=[],
block_ids=([], ),
generator=None,

View File

@ -120,7 +120,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
NewRequestData(
req_id=req_id,
prompt_token_ids=[1, 2, 3],
mm_inputs=[],
mm_kwargs=[],
mm_hashes=[],
mm_positions=[],
sampling_params=SamplingParams(),

View File

@ -4,7 +4,7 @@
from abc import ABC, abstractmethod
from collections import UserDict, defaultdict
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from dataclasses import dataclass, replace
from functools import partial
from itertools import accumulate
from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, TypeVar,
@ -198,7 +198,7 @@ A dictionary containing nested tensors which have been batched via
"""
@dataclass(frozen=True)
@dataclass
class MultiModalFieldElem:
"""
Represents a keyword argument corresponding to a multi-modal item
@ -218,11 +218,14 @@ class MultiModalFieldElem:
i.e. the name of the keyword argument to be passed to the model.
"""
data: NestedTensors
data: Optional[NestedTensors]
"""
The tensor data of this field in
[`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs],
i.e. the value of the keyword argument to be passed to the model.
It may be set to `None` if it is determined that the item is cached
in `EngineCore`.
"""
field: "BaseMultiModalField"
@ -235,8 +238,15 @@ class MultiModalFieldElem:
if not isinstance(other, self.__class__):
return False
if self.data is None:
data_equal = other.data is None
elif other.data is None:
data_equal = self.data is None
else:
data_equal = nested_tensors_equal(self.data, other.data)
return ((self.modality, self.key) == (other.modality, other.key)
and nested_tensors_equal(self.data, other.data)
and data_equal
and type(self.field) == type(other.field)) # noqa: E721
@ -280,10 +290,20 @@ class BaseMultiModalField(ABC):
raise NotImplementedError
@abstractmethod
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
def _reduce_data(
self,
batch: list[NestedTensors],
*,
pin_memory: bool,
) -> NestedTensors:
raise NotImplementedError
def reduce_data(self, elems: list[MultiModalFieldElem]) -> NestedTensors:
def reduce_data(
self,
elems: list[MultiModalFieldElem],
*,
pin_memory: bool = False,
) -> NestedTensors:
"""
Merge the data from multiple instances of
[`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem].
@ -295,7 +315,13 @@ class BaseMultiModalField(ABC):
if len(set(field_types)) > 1:
raise ValueError(f"Cannot merge different {field_types=}")
return self._reduce_data([item.data for item in elems])
validated_data = list[NestedTensors]()
for i, elem in enumerate(elems):
assert elem.data is not None, (
f"Cannot merge with empty `elems[{i}]`")
validated_data.append(elem.data)
return self._reduce_data(validated_data, pin_memory=pin_memory)
@dataclass(frozen=True)
@ -314,7 +340,12 @@ class MultiModalBatchedField(BaseMultiModalField):
field_factory = self._field_factory(modality=modality, key=key)
return [field_factory(item) for item in data]
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
def _reduce_data(
self,
batch: list[NestedTensors],
*,
pin_memory: bool,
) -> NestedTensors:
if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
if len(batch) == 1:
# An optimization when `batch` contains only one tensor:
@ -323,7 +354,11 @@ class MultiModalBatchedField(BaseMultiModalField):
return batch[0].unsqueeze(0).contiguous()
first_shape = batch[0].shape
if all(elem.shape == first_shape for elem in batch):
return torch.stack(batch)
out = torch.empty((len(batch), *batch[0].shape),
dtype=batch[0].dtype,
device=batch[0].device,
pin_memory=pin_memory)
return torch.stack(batch, out=out)
return batch
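
For context, the new pinned-memory stacking path is roughly equivalent to the following standalone sketch (an assumption-laden illustration, not code from this commit; it presumes a list of same-shape CPU tensors, and `pin_memory=True` only pays off when the result is later copied to a CUDA device):

import torch

batch = [torch.randn(3, 4) for _ in range(5)]
# Pre-allocate the output buffer, optionally in pinned (page-locked) memory,
# so a subsequent host-to-device copy can be made asynchronous.
out = torch.empty((len(batch), *batch[0].shape),
                  dtype=batch[0].dtype,
                  device=batch[0].device,
                  pin_memory=False)  # set True before an async .to("cuda")
stacked = torch.stack(batch, out=out)
assert stacked.shape == (5, 3, 4)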
@ -350,7 +385,12 @@ class MultiModalFlatField(BaseMultiModalField):
"torch.Tensor is required for multiple slices"
return [field_factory(data[cast(slice, s)]) for s in self.slices]
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
def _reduce_data(
self,
batch: list[NestedTensors],
*,
pin_memory: bool,
) -> NestedTensors:
if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
if len(batch) == 1:
# An optimization when `batch` contains only one tensor:
@ -358,13 +398,21 @@ class MultiModalFlatField(BaseMultiModalField):
# - will achieve zero-copy if the tensor is contiguous
return batch[0].contiguous()
def _expect_same_shape(tensor: torch.Tensor):
return tensor.shape[:self.dim] + tensor.shape[self.dim + 1:]
dim = self.dim + (self.dim < 0) * len(batch[0].shape)
first_shape = _expect_same_shape(batch[0])
def _shape_before_after(tensor: torch.Tensor):
return tensor.shape[:dim], tensor.shape[dim + 1:]
if all(_expect_same_shape(elem) == first_shape for elem in batch):
return torch.concat(batch, dim=self.dim)
first_shape = _shape_before_after(batch[0])
if all(_shape_before_after(elem) == first_shape for elem in batch):
shape_before, shape_after = first_shape
shape_concat = sum(item.shape[dim] for item in batch)
out = torch.empty((*shape_before, shape_concat, *shape_after),
dtype=batch[0].dtype,
device=batch[0].device,
pin_memory=pin_memory)
return torch.concat(batch, dim=self.dim, out=out)
assert self.dim == 0, "dim == 0 is required for nested list"
return [e for elem in batch for e in elem]
@ -387,7 +435,12 @@ class MultiModalSharedField(BaseMultiModalField):
field_factory = self._field_factory(modality=modality, key=key)
return [field_factory(data)] * self.batch_size
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
def _reduce_data(
self,
batch: list[NestedTensors],
*,
pin_memory: bool,
) -> NestedTensors:
return batch[0]
@ -594,11 +647,53 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
def from_elems(elems: Sequence[MultiModalFieldElem]):
return MultiModalKwargsItem({elem.key: elem for elem in elems})
@property
def modality(self) -> str:
def __init__(self, data: Mapping[str, MultiModalFieldElem]) -> None:
super().__init__(data)
modalities = {elem.modality for elem in self.data.values()}
assert len(modalities) == 1, f"Found different modalities={modalities}"
return next(iter(modalities))
self._modality = next(iter(modalities))
self._is_empty = any(elem.data is None for elem in self.values())
@property
def modality(self) -> str:
return self._modality
@property
def is_empty(self) -> bool:
return self._is_empty
def get_data(self) -> Optional[Mapping[str, NestedTensors]]:
if self._is_empty:
return None
out_data = dict[str, NestedTensors]()
for key, elem in self.items():
assert elem.data is not None, (
f"Cannot get data of empty `elem[{key!r}]`")
out_data[key] = elem.data
return out_data
def require_data(self) -> Mapping[str, NestedTensors]:
if (data := self.get_data()) is None:
raise RuntimeError("Cannot get data of empty item")
return data
# These methods create a new item to avoid mutating cached items in place
def with_data(self, data: Mapping[str, NestedTensors]):
return MultiModalKwargsItem({
key: replace(elem, data=data[key])
for key, elem in self.items()
})
def without_data(self):
return MultiModalKwargsItem({
key: replace(elem, data=None)
for key, elem in self.items()
})
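
A minimal sketch of how these new helpers fit together (assuming `item` is a populated `MultiModalKwargsItem`, e.g. built via `MultiModalKwargsItem.from_elems`; this is illustrative, not part of the diff):

# Strip the tensor payload before sending an item whose hash is already
# cached on the receiving side; only modality/key/field metadata remain.
stripped = item.without_data()
assert stripped.is_empty and stripped.get_data() is None
# stripped.require_data() would raise RuntimeError here.

# Re-attach tensors on the receiving side without mutating the cached item.
restored = stripped.with_data(item.require_data())
assert not restored.is_empty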
# NOTE: UserDict is for V0 compatibility.
@ -650,7 +745,11 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
return MultiModalKwargs.from_items(items)
@staticmethod
def from_items(items: Sequence[MultiModalKwargsItem]):
def from_items(
items: Sequence[MultiModalKwargsItem],
*,
pin_memory: bool = False,
):
"""Construct a new
[`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs]
from multiple items."""
@ -660,7 +759,7 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
elems_by_key[key].append(elem)
data = {
key: elems[0].field.reduce_data(elems)
key: elems[0].field.reduce_data(elems, pin_memory=pin_memory)
for key, elems in elems_by_key.items() if len(elems) > 0
}

View File

@ -3,6 +3,7 @@
import asyncio
import atexit
from collections.abc import Iterable
from concurrent.futures import ThreadPoolExecutor
from itertools import groupby
from pathlib import Path
@ -13,6 +14,7 @@ import numpy as np
import numpy.typing as npt
import torch
from PIL import Image, UnidentifiedImageError
from typing_extensions import deprecated
import vllm.envs as envs
from vllm.connections import HTTPConnection, global_http_connection
@ -23,17 +25,17 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
from .audio import AudioMediaIO
from .base import MediaIO
from .image import ImageEmbeddingMediaIO, ImageMediaIO
from .inputs import PlaceholderRange
from .video import VideoMediaIO
_M = TypeVar("_M")
if TYPE_CHECKING:
from .hasher import MultiModalHashDict
from .inputs import MultiModalKwargs, MultiModalPlaceholderDict
from .inputs import (BatchedTensorInputs, MultiModalKwargs,
MultiModalKwargsItem, MultiModalPlaceholderDict)
else:
MultiModalHashDict = Any
BatchedTensorInputs = Any
MultiModalKwargs = Any
MultiModalKwargsItem = Any
MultiModalPlaceholderDict = Any
global_thread_pool = ThreadPoolExecutor(
@ -331,79 +333,32 @@ def encode_video_base64(frames: npt.NDArray) -> str:
return video_io.encode_base64(frames)
def merge_and_sort_multimodal_metadata(
mm_positions: MultiModalPlaceholderDict,
mm_hashes: Optional[MultiModalHashDict],
) -> tuple[list[str], list[PlaceholderRange], Optional[list[str]]]:
"""Given a MultiModalPlaceholderDict, merge all PlaceholderRange
objects from all available modalities into a single list of
PlaceholderRange, sorted by their offset (starting index in the input
sequence) in the ascending order.
Optionally if a `MultiModalHashDict` is given, same operation will be
applied to the object and the sorted list of hashes will be returned.
Returns:
list[str]: List of item modalities in order of their positions in the
input sequence.
list[PlaceholderRange]: Sorted list of all PlaceholderRanges from
mm_positions.
Optional[list[str]]: Sorted list of all hashes from mm_hashes if given,
None otherwise.
def argsort_mm_positions(
mm_positions: MultiModalPlaceholderDict) -> list[tuple[str, int]]:
"""
Given a `MultiModalPlaceholderDict`, output a sequence of keys to
sort the dictionary by `offset` (starting index in the input sequence)
in ascending order.
modalities = list(mm_positions.keys())
Returns:
A list of `(modality, idx)`, which can be used to access an item
by `mm_positions[modality][idx]`.
"""
flat_items = ((modality, idx, item)
for modality, items in mm_positions.items()
for idx, item in enumerate(items))
assert len(modalities) > 0, "No modalities found in the mm_positions."
sorted_flat_items = sorted(flat_items, key=lambda x: x[2].offset)
# For single modality, placeholder ranges and hashes are already sorted
# so we can return the list directly.
if len(modalities) == 1:
modality = modalities[0]
placeholder_list = list(mm_positions[modality])
return [modality] * len(
placeholder_list
), placeholder_list, None if not mm_hashes else mm_hashes[modality]
# Create a list of (modality, placeholder, hash) tuples for all placeholders
all_items = []
for modality in modalities:
placeholder_list = list(mm_positions[modality])
hash_list: list[Optional[str]] = list(
mm_hashes[modality]) if mm_hashes and modality in mm_hashes else [
None
] * len(placeholder_list)
for placeholder, hash_value in zip(placeholder_list, hash_list):
all_items.append((modality, placeholder, hash_value))
# Sort all items by offset
all_items.sort(key=lambda x: x[1].offset)
# Split into separate lists
sorted_modalities = [item[0] for item in all_items]
merged_placeholders = [item[1] for item in all_items]
merged_hashes = [str(item[2])
for item in all_items] if mm_hashes is not None else None
return sorted_modalities, merged_placeholders, merged_hashes
return [(modality, idx) for modality, idx, _ in sorted_flat_items]
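
A usage sketch for the new helper (not part of the diff), matching the interleaved test cases above:

from vllm.multimodal.inputs import PlaceholderRange
from vllm.multimodal.utils import argsort_mm_positions

mm_positions = {
    "image": [
        PlaceholderRange(offset=8, length=2),
        PlaceholderRange(offset=0, length=4),
    ],
    "audio": [
        PlaceholderRange(offset=5, length=2),
    ],
}

# Keys in prompt order; each (modality, idx) indexes mm_positions[modality][idx].
assert argsort_mm_positions(mm_positions) == [
    ("image", 1), ("audio", 0), ("image", 0),
]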
# Temporary back-compatibility for plugins that define model runner
@deprecated("`group_mm_inputs_by_modality` is superseded by "
"`group_mm_kwargs_by_modality` and will be removed in v0.13. "
"Please use `group_mm_kwargs_by_modality` instead.")
def group_mm_inputs_by_modality(
mm_inputs: list[MultiModalKwargs]) -> list[list[MultiModalKwargs]]:
"""Group consecutive MultiModalKwargs from mm_inputs with the same modality
together into the same list for batching purpose. For MultiModalKwargs with
multiple modalities, put them into their own list.
Args:
mm_inputs: List of MultiModalKwargs.
Returns:
list[list[vllm.multimodal.MultiModalKwargs]]: List of list of
`MultiModalKwargs`, each inner list contains consecutive
`MultiModalKwargs` with same modality.
"""
if not mm_inputs:
return []
@ -426,6 +381,48 @@ def group_mm_inputs_by_modality(
]
def group_mm_kwargs_by_modality(
mm_kwargs: list[MultiModalKwargsItem],
*,
device: torch.types.Device = None,
pin_memory: bool = False,
) -> Iterable[tuple[str, int, BatchedTensorInputs]]:
"""Group consecutive `MultiModalKwargsItem`s from `mm_kwargs` with the same
modality together into the same `MultiModalKwargs` instance.
Args:
mm_kwargs: List of `MultiModalKwargsItem`.
Yields:
A tuple `(modality, num_items, grouped_kwargs)`.
"""
from vllm.multimodal.inputs import MultiModalKwargs
for modality, items in groupby(mm_kwargs, key=lambda item: item.modality):
items_lst = list(items)
# mm_kwargs_group = MultiModalKwargs.from_items(items_lst,
# pin_memory=pin_memory)
# if device is not None:
# mm_kwargs_group = json_map_leaves(lambda x: x.to(device=device),
# mm_kwargs_group.data)
# TODO: Once V0 is removed, we can use the merging logic above
# to avoid creating an extra batch dimension (except for fields
# that are meant to be stacked anyway).
# We will also need to update each model to remove `flatten_bn`.
mm_kwargs_group = MultiModalKwargs.as_kwargs(
MultiModalKwargs.batch(
[MultiModalKwargs.from_items([item]) for item in items_lst],
pin_memory=pin_memory,
),
device=device,
)
yield modality, len(items_lst), mm_kwargs_group
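
A sketch of how a model runner consumes this generator (assuming `mm_kwargs` is a list of `MultiModalKwargsItem` in prompt order, `device` is the target device, and the model implements `get_multimodal_embeddings`), mirroring the GPU/TPU runner changes later in this commit:

for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
        mm_kwargs,
        device=device,          # e.g. torch.device("cuda:0")
        pin_memory=True,
):
    # Each group batches consecutive items of a single modality only.
    curr_outputs = model.get_multimodal_embeddings(**mm_kwargs_group)
    assert len(curr_outputs) == num_items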
def run_dp_sharded_vision_model(image_input: torch.Tensor,
vision_model: torch.nn.Module) -> torch.Tensor:
"""Run a vision model with data parallelism (DP) sharding. The function

View File

@ -13,7 +13,7 @@ if TYPE_CHECKING:
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorMetadata)
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.v1.request import Request
@ -24,7 +24,7 @@ class NewRequestData:
req_id: str
prompt_token_ids: list[int]
mm_inputs: list[MultiModalKwargs]
mm_kwargs: list[MultiModalKwargsItem]
mm_hashes: list[str]
mm_positions: list[PlaceholderRange]
sampling_params: Optional[SamplingParams]
@ -42,7 +42,7 @@ class NewRequestData:
return cls(
req_id=request.request_id,
prompt_token_ids=request.prompt_token_ids,
mm_inputs=request.mm_inputs,
mm_kwargs=request.mm_kwargs,
mm_hashes=request.mm_hashes,
mm_positions=request.mm_positions,
sampling_params=request.sampling_params,
@ -56,7 +56,7 @@ class NewRequestData:
return (f"NewRequestData("
f"req_id={self.req_id},"
f"prompt_token_ids={self.prompt_token_ids},"
f"mm_inputs={self.mm_inputs},"
f"mm_kwargs={self.mm_kwargs},"
f"mm_hashes={self.mm_hashes},"
f"mm_positions={self.mm_positions},"
f"sampling_params={self.sampling_params},"
@ -70,7 +70,7 @@ class NewRequestData:
return (f"NewRequestData("
f"req_id={self.req_id},"
f"prompt_token_ids_len={len(self.prompt_token_ids)},"
f"mm_inputs={self.mm_inputs},"
f"mm_kwargs={self.mm_kwargs},"
f"mm_hashes={self.mm_hashes},"
f"mm_positions={self.mm_positions},"
f"sampling_params={self.sampling_params},"

View File

@ -3,15 +3,13 @@
import enum
import time
from collections.abc import Sequence
from typing import Any, Optional, Union
import msgspec
import torch
from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.inputs import PlaceholderRange
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.v1.metrics.stats import SchedulerStats
@ -49,7 +47,7 @@ class EngineCoreRequest(
request_id: str
prompt_token_ids: list[int]
mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]]
mm_kwargs: Optional[list[MultiModalKwargsItem]]
mm_hashes: Optional[list[str]]
mm_placeholders: Optional[list[PlaceholderRange]]
sampling_params: Optional[SamplingParams]

View File

@ -409,12 +409,13 @@ class EngineCore:
request initialization running in parallel with Model forward
"""
if request.mm_hashes is not None:
assert request.mm_inputs is not None
assert request.mm_kwargs is not None
# Note on thread safety: no race condition.
# `mm_input_cache_server` is reset at the end of LLMEngine init,
# and will only be accessed in the input processing thread afterwards.
request.mm_inputs = self.mm_input_cache_server.get_and_update(
request.mm_inputs, request.mm_hashes)
request.mm_kwargs = self.mm_input_cache_server.get_and_update(
request.mm_kwargs, request.mm_hashes)
req = Request.from_engine_core_request(request)
if req.use_structured_output:

View File

@ -1,11 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from typing import TYPE_CHECKING, Optional
from collections.abc import Mapping
from typing import TYPE_CHECKING
from vllm.multimodal import MultiModalKwargs, MultiModalRegistry
from vllm.multimodal import MultiModalRegistry
from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata
from vllm.utils import is_list_of
from vllm.multimodal.inputs import MultiModalKwargsItem, NestedTensors
if TYPE_CHECKING:
from vllm.config import ModelConfig
@ -17,23 +17,23 @@ if TYPE_CHECKING:
# -- P0:
# - BaseMultiModalProcessor calls MultiModalHasher to get the `mm_hash` of
# each input multi-modal item (e.g. image),
# - BaseMultiModalProcessor processes the input items into `mm_inputs`,
# - BaseMultiModalProcessor processes the input items into `mm_kwargs`,
# which are MultiModalKwargsItem instances that each correspond to an
# input multi-modal item.
# - MultiModalInputCacheClient accepts the `mm_inputs` and corresponding
# - MultiModalInputCacheClient accepts the `mm_kwargs` and corresponding
# `mm_hash` for each item. It stores the `mm_hash` as keys and the size
# of `mm_inputs`, but not the `mm_inputs` themselves, to avoid taking
# of `mm_kwargs`, but not the `mm_kwargs` themselves, to avoid taking
# up additional memory in P0.
# - The `mm_hash` is always sent to P1.
# - The corresponding `mm_inputs` are only sent to P1 if they are not cached
# - The corresponding `mm_kwargs` are only sent to P1 if they are not cached
# in MultiModalInputCacheServer.
#
# -- P1:
# - If the `mm_hash` is cached (i.e. `mm_inputs` are not sent from P0),
# MultiModalInputCacheServer retrieves the corresponding `mm_inputs`.
# - If the `mm_hash` is not cached (i.e. `mm_inputs` are sent from P0),
# MultiModalInputCacheServer stores `mm_inputs` under the key `mm_hash`.
# - Either way, the `mm_hash` and corresponding `mm_inputs` are sent to
# - If the `mm_hash` is cached (i.e. `mm_kwargs` are not sent from P0),
# MultiModalInputCacheServer retrieves the corresponding `mm_kwargs`.
# - If the `mm_hash` is not cached (i.e. `mm_kwargs` are sent from P0),
# MultiModalInputCacheServer stores `mm_kwargs` under the key `mm_hash`.
# - Either way, the `mm_hash` and corresponding `mm_kwargs` are sent to
# the engine for model execution.
#
# Both Client and Server must perform cache update and eviction based on the
@ -58,26 +58,24 @@ class MultiModalInputCacheClient:
def get_and_update(
self,
mm_inputs: Sequence[MultiModalKwargs],
mm_kwargs: list[MultiModalKwargsItem],
mm_hashes: list[str],
) -> Sequence[Optional[MultiModalKwargs]]:
assert len(mm_inputs) == len(mm_hashes)
) -> list[MultiModalKwargsItem]:
if not self.enabled:
assert is_list_of(mm_inputs, MultiModalKwargs)
return mm_inputs
return mm_kwargs
full_mm_inputs = list[Optional[MultiModalKwargs]]()
for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
assert len(mm_kwargs) == len(mm_hashes)
out_mm_items = list[MultiModalKwargsItem]()
for mm_item, mm_hash in zip(mm_kwargs, mm_hashes):
if self.mm_cache.get(mm_hash) is not None:
mm_input = None
out_mm_items.append(mm_item.without_data())
else:
self.mm_cache[mm_hash] = \
MultiModalCacheItemMetadata.wraps(mm_input)
MultiModalCacheItemMetadata.wraps(mm_item.require_data())
out_mm_items.append(mm_item)
full_mm_inputs.append(mm_input)
return full_mm_inputs
return out_mm_items
def reset(self) -> None:
self.mm_cache.clear()
@ -93,30 +91,28 @@ class MultiModalInputCacheServer:
self.enabled = mm_registry.enable_mm_input_cache(model_config)
self.mm_cache = MultiModalCache.get_lru_cache(
model_config.get_mm_input_cache_gb(),
MultiModalKwargs,
Mapping[str, NestedTensors],
)
def get_and_update(
self,
mm_inputs: Sequence[Optional[MultiModalKwargs]],
mm_kwargs: list[MultiModalKwargsItem],
mm_hashes: list[str],
) -> Sequence[MultiModalKwargs]:
assert len(mm_inputs) == len(mm_hashes)
) -> list[MultiModalKwargsItem]:
if not self.enabled:
assert is_list_of(mm_inputs, MultiModalKwargs)
return mm_inputs
return mm_kwargs
full_mm_inputs = list[MultiModalKwargs]()
for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
if mm_input is None:
mm_input = self.mm_cache[mm_hash]
assert len(mm_kwargs) == len(mm_hashes)
out_mm_items = list[MultiModalKwargsItem]()
for mm_item, mm_hash in zip(mm_kwargs, mm_hashes):
if (mm_data := mm_item.get_data()) is None:
out_mm_items.append(mm_item.with_data(self.mm_cache[mm_hash]))
else:
self.mm_cache[mm_hash] = mm_input
self.mm_cache[mm_hash] = mm_data
out_mm_items.append(mm_item)
full_mm_inputs.append(mm_input)
return full_mm_inputs
return out_mm_items
def reset(self) -> None:
self.mm_cache.clear()
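
To summarize the P0/P1 protocol described in the comment block above, a conceptual round-trip (with hypothetical `client` and `server` instances standing in for the objects created in the API process and in `EngineCore`):

# P0: items whose hash hits the client cache are stripped of tensor data
# before being sent over IPC; only metadata and the hash travel.
sent_items = client.get_and_update(mm_kwargs, mm_hashes)

# P1: empty items are re-filled from the server cache; new items populate it.
received_items = server.get_and_update(sent_items, mm_hashes)
assert all(not item.is_empty for item in received_items)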

View File

@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from collections.abc import Mapping, Sequence
from collections.abc import Mapping
from typing import Any, Literal, Optional, Union
from vllm.config import VllmConfig
@ -10,11 +10,10 @@ from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
from vllm.inputs.parse import split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor
from vllm.lora.request import LoRARequest
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
MultiModalRegistry)
from vllm.multimodal.inputs import PlaceholderRange
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
from vllm.multimodal.processing import EncDecMultiModalProcessor
from vllm.multimodal.utils import merge_and_sort_multimodal_metadata
from vllm.multimodal.utils import argsort_mm_positions
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
@ -296,57 +295,42 @@ class Processor:
pooling_params = params.clone()
# Multimodal related.
sorted_mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]] = None
sorted_mm_inputs: Optional[list[MultiModalKwargsItem]] = None
sorted_mm_positions: Optional[list[PlaceholderRange]] = None
sorted_mm_hashes: Optional[list[str]] = None
if decoder_inputs["type"] == "multimodal":
decoder_mm_inputs = decoder_inputs["mm_kwargs"]
decoder_mm_positions = decoder_inputs["mm_placeholders"]
decoder_mm_hashes = decoder_inputs.get("mm_hashes")
# Merge and flatten multimodal placeholders, hashes and inputs
# from dictionaries to lists, and sort them by each item's position
# in the input sequence.
(
sorted_item_modalities,
sorted_mm_positions,
sorted_mm_hashes,
) = merge_and_sort_multimodal_metadata(
decoder_inputs["mm_placeholders"],
decoder_inputs["mm_hashes"] if return_mm_hashes else None,
)
sorted_mm_idxs = argsort_mm_positions(decoder_mm_positions)
# The output of merged multi-modal processor (`decoder_mm_inputs`)
# is a single MultiModalKwargs for all items from all modalities.
# This code flattens kwargs for individual items in a list and
# sorts them by each item's position in the input sequence if there
# are multiple modalities.
unique_modalities = set(sorted_item_modalities)
if len(unique_modalities) > 1:
orig_sorted_mm_inputs = []
used_indices = {modality: 0 for modality in unique_modalities}
for modality in sorted_item_modalities:
items = decoder_mm_inputs.get_items(modality)
item = items[used_indices[modality]]
orig_sorted_mm_inputs.append(
MultiModalKwargs.from_items([item]))
used_indices[modality] += 1
else:
orig_sorted_mm_inputs = [
MultiModalKwargs.from_items([item]) for item in
decoder_mm_inputs.get_items(sorted_item_modalities[0])
]
sorted_mm_inputs = [
decoder_mm_inputs.get_item(modality, idx)
for modality, idx in sorted_mm_idxs
]
sorted_mm_positions = [
decoder_mm_positions[modality][idx]
for modality, idx in sorted_mm_idxs
]
sorted_mm_hashes = None if decoder_mm_hashes is None else [
decoder_mm_hashes[modality][idx]
for modality, idx in sorted_mm_idxs
]
if sorted_mm_hashes is not None:
sorted_mm_inputs = self.mm_input_cache_client.get_and_update(
orig_sorted_mm_inputs, sorted_mm_hashes)
else:
sorted_mm_inputs = orig_sorted_mm_inputs
sorted_mm_inputs,
sorted_mm_hashes,
)
return decoder_inputs.get("prompt"), EngineCoreRequest(
request_id=request_id,
prompt_token_ids=decoder_inputs["prompt_token_ids"],
mm_inputs=sorted_mm_inputs,
mm_kwargs=sorted_mm_inputs,
mm_hashes=sorted_mm_hashes,
mm_placeholders=sorted_mm_positions,
sampling_params=sampling_params,

View File

@ -5,7 +5,7 @@ import enum
import time
from typing import TYPE_CHECKING, Any, Optional, Union
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.utils import is_list_of
@ -24,7 +24,7 @@ class Request:
self,
request_id: str,
prompt_token_ids: list[int],
multi_modal_inputs: Optional[list[MultiModalKwargs]],
multi_modal_kwargs: Optional[list[MultiModalKwargsItem]],
multi_modal_hashes: Optional[list[str]],
multi_modal_placeholders: Optional[list[PlaceholderRange]],
sampling_params: Optional[SamplingParams],
@ -84,15 +84,15 @@ class Request:
# Multi-modal related
self.mm_positions = multi_modal_placeholders or []
self.mm_inputs = multi_modal_inputs or []
self.mm_kwargs = multi_modal_kwargs or []
self.mm_hashes: list[str] = multi_modal_hashes or []
self.num_encoder_inputs = len(self.mm_inputs)
self.num_encoder_inputs = len(self.mm_kwargs)
self.has_encoder_inputs = self.num_encoder_inputs > 0
# Sanity check
assert len(self.mm_inputs) == len(self.mm_positions)
assert len(self.mm_kwargs) == len(self.mm_positions)
if self.mm_hashes:
assert len(self.mm_inputs) == len(self.mm_hashes)
assert len(self.mm_kwargs) == len(self.mm_hashes)
# Read-only views
# Prevent directly appending to these lists since
@ -110,16 +110,15 @@ class Request:
@classmethod
def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
if request.mm_inputs is not None:
assert isinstance(request.mm_inputs, list)
assert is_list_of(request.mm_inputs, MultiModalKwargs), (
"mm_inputs was not updated in EngineCore.add_request")
if request.mm_kwargs is not None:
assert is_list_of(request.mm_kwargs, MultiModalKwargsItem), (
"mm_kwargs was not updated in EngineCore.add_request")
return cls(
request_id=request.request_id,
client_index=request.client_index,
prompt_token_ids=request.prompt_token_ids,
multi_modal_inputs=request.mm_inputs,
multi_modal_kwargs=request.mm_kwargs,
multi_modal_hashes=request.mm_hashes,
multi_modal_placeholders=request.mm_placeholders,
sampling_params=request.sampling_params,

View File

@ -113,6 +113,9 @@ class MsgpackEncoder:
int(v) if v is not None else None
for v in (obj.start, obj.stop, obj.step))
if isinstance(obj, MultiModalKwargsItem):
return self._encode_mm_item(obj)
if isinstance(obj, MultiModalKwargs):
mm: MultiModalKwargs = obj
if not mm.modalities:
@ -120,17 +123,12 @@ class MsgpackEncoder:
return dict(mm)
# ignore the main dict, it will be re-indexed.
# Encode a list of MultiModalKwargsItems as plain dicts
# + special handling for .field.
# Any tensors *not* indexed by modality will be ignored.
return [[{
"modality": elem.modality,
"key": elem.key,
"data": self._encode_nested_tensors(elem.data),
"field": self._encode_mm_field(elem.field),
} for elem in item.values()]
for itemlist in mm._items_by_modality.values()
for item in itemlist]
return [
self._encode_mm_item(item)
for itemlist in mm._items_by_modality.values()
for item in itemlist
]
if isinstance(obj, UtilityResult):
result = obj.result
@ -192,6 +190,23 @@ class MsgpackEncoder:
dtype = str(obj.dtype).removeprefix("torch.")
return dtype, obj.shape, data
def _encode_mm_item(self,
item: MultiModalKwargsItem) -> list[dict[str, Any]]:
return [self._encode_mm_field_elem(elem) for elem in item.values()]
def _encode_mm_field_elem(self,
elem: MultiModalFieldElem) -> dict[str, Any]:
return {
"modality":
elem.modality,
"key":
elem.key,
"data": (None if elem.data is None else
self._encode_nested_tensors(elem.data)),
"field":
self._encode_mm_field(elem.field),
}
def _encode_nested_tensors(self, nt: NestedTensors) -> Any:
if isinstance(nt, torch.Tensor):
return self._encode_tensor(nt)
@ -250,6 +265,8 @@ class MsgpackDecoder:
return self._decode_tensor(obj)
if t is slice:
return slice(*obj)
if issubclass(t, MultiModalKwargsItem):
return self._decode_mm_item(obj)
if issubclass(t, MultiModalKwargs):
if isinstance(obj, list):
return MultiModalKwargs.from_items(
@ -311,15 +328,18 @@ class MsgpackDecoder:
# Convert back to proper shape & type
return arr.view(torch_dtype).view(shape)
def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]:
def _decode_mm_items(self, obj: list[Any]) -> list[MultiModalKwargsItem]:
return [self._decode_mm_item(v) for v in obj]
def _decode_mm_item(self, obj: list) -> MultiModalKwargsItem:
def _decode_mm_item(self, obj: list[Any]) -> MultiModalKwargsItem:
return MultiModalKwargsItem.from_elems(
[self._decode_mm_field_elem(v) for v in obj])
def _decode_mm_field_elem(self, obj: dict) -> MultiModalFieldElem:
obj["data"] = self._decode_nested_tensors(obj["data"])
def _decode_mm_field_elem(self, obj: dict[str,
Any]) -> MultiModalFieldElem:
if obj["data"] is not None:
obj["data"] = self._decode_nested_tensors(obj["data"])
# Reconstruct the field processor using MultiModalFieldConfig
factory_meth_name, *field_args = obj["field"]
factory_meth = getattr(MultiModalFieldConfig, factory_meth_name)

View File

@ -7,9 +7,11 @@ from typing import Optional, cast
import numpy as np
import torch
from typing_extensions import deprecated
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.multimodal.inputs import (MultiModalKwargs, MultiModalKwargsItem,
PlaceholderRange)
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import swap_dict_values
@ -29,7 +31,7 @@ class CachedRequestState:
req_id: str
prompt_token_ids: list[int]
mm_inputs: list[MultiModalKwargs]
mm_kwargs: list[MultiModalKwargsItem]
mm_positions: list[PlaceholderRange]
sampling_params: Optional[SamplingParams]
pooling_params: Optional[PoolingParams]
@ -51,6 +53,13 @@ class CachedRequestState:
def num_tokens(self) -> int:
return self.num_prompt_tokens + len(self.output_token_ids)
# Temporary back-compatibility for plugins that define model runner
@property
@deprecated("`mm_inputs` is superseded by `mm_kwargs` and will be "
"removed in v0.13. Please use `mm_kwargs` instead.")
def mm_inputs(self) -> list[MultiModalKwargs]:
return [MultiModalKwargs.from_items([item]) for item in self.mm_kwargs]
def get_token_id(self, idx: int) -> int:
if idx < self.num_prompt_tokens:
return self.prompt_token_ids[idx]

View File

@ -40,9 +40,9 @@ from vllm.model_executor.models.interfaces import (is_mixture_of_experts,
from vllm.model_executor.models.interfaces_base import (
VllmModelForPooling, is_pooling_model, is_text_generation_model)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs,
from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargsItem,
PlaceholderRange)
from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.multimodal.utils import group_mm_kwargs_by_modality
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors, PoolerOutput
@ -478,7 +478,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.requests[req_id] = CachedRequestState(
req_id=req_id,
prompt_token_ids=new_req_data.prompt_token_ids,
mm_inputs=new_req_data.mm_inputs,
mm_kwargs=new_req_data.mm_kwargs,
mm_positions=new_req_data.mm_positions,
sampling_params=sampling_params,
pooling_params=pooling_params,
@ -496,18 +496,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
second_per_grid_ts = []
audio_feature_lengths = []
use_audio_in_video = False
for mm_input in self.requests[req_id].mm_inputs:
for item in self.requests[req_id].mm_kwargs:
mm_input = item.require_data()
if mm_input.get("image_grid_thw") is not None:
image_grid_thw.extend(
image_grid_thw.append(
mm_input["image_grid_thw"].tolist())
if mm_input.get("video_grid_thw") is not None:
video_grid_thw.extend(
video_grid_thw.append(
mm_input["video_grid_thw"].tolist())
if mm_input.get("second_per_grid_ts") is not None:
second_per_grid_ts.extend(
second_per_grid_ts.append(
mm_input["second_per_grid_ts"])
if mm_input.get("audio_feature_lengths") is not None:
audio_feature_lengths.extend(
audio_feature_lengths.append(
mm_input["audio_feature_lengths"])
if mm_input.get("use_audio_in_video") is True:
use_audio_in_video = True
@ -624,14 +625,23 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
) -> BatchedTensorInputs:
if self.is_multimodal_raw_input_supported: # noqa: SIM102
if scheduler_output:
multi_modal_kwargs_list = list[MultiModalKwargs]()
mm_kwargs = list[MultiModalKwargsItem]()
for req in scheduler_output.scheduled_new_reqs:
req_mm_inputs = req.mm_inputs
if not isinstance(req_mm_inputs, list):
req_mm_inputs = list(req_mm_inputs)
multi_modal_kwargs_list.extend(req_mm_inputs)
req_mm_kwargs = req.mm_kwargs
if not isinstance(req_mm_kwargs, list):
req_mm_kwargs = list(req_mm_kwargs)
mm_kwargs.extend(req_mm_kwargs)
return MultiModalKwargs.batch(multi_modal_kwargs_list)
# Input all modalities at once
mm_kwargs_combined: BatchedTensorInputs = {}
for _, _, mm_kwargs_group in group_mm_kwargs_by_modality(
mm_kwargs,
device=self.device,
pin_memory=self.pin_memory,
):
mm_kwargs_combined.update(mm_kwargs_group)
return mm_kwargs_combined
return {}
@ -1146,13 +1156,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return
# Batch the multi-modal inputs.
mm_inputs = list[MultiModalKwargs]()
mm_kwargs = list[MultiModalKwargsItem]()
req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
req_state = self.requests[req_id]
for mm_input_id in encoder_input_ids:
mm_inputs.append(req_state.mm_inputs[mm_input_id])
mm_kwargs.append(req_state.mm_kwargs[mm_input_id])
req_ids_pos.append(
(req_id, mm_input_id, req_state.mm_positions[mm_input_id]))
@ -1163,17 +1173,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# in the same batch while still being able to benefit from batching
# multimodal inputs. The proper solution should be reordering the
# encoder outputs.
grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_inputs)
encoder_outputs = []
for grouped_mm_inputs in grouped_mm_inputs_list:
batched_mm_inputs = MultiModalKwargs.batch(
grouped_mm_inputs, pin_memory=self.pin_memory)
batched_mm_inputs = MultiModalKwargs.as_kwargs(
batched_mm_inputs,
for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
mm_kwargs,
device=self.device,
)
pin_memory=self.pin_memory,
):
# Run the encoder.
# `curr_group_outputs` is either of the following:
# 1. A tensor of shape (num_items, feature_size, hidden_size)
@ -1182,11 +1187,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# (feature_size, hidden_size) in case the feature size is dynamic
# depending on the input multimodal items.
curr_group_outputs = self.model.get_multimodal_embeddings(
**batched_mm_inputs)
**mm_kwargs_group)
sanity_check_mm_encoder_outputs(
curr_group_outputs,
expected_num_items=len(grouped_mm_inputs),
expected_num_items=num_items,
)
for output in curr_group_outputs:
@ -1553,17 +1558,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
input_ids = None
inputs_embeds = self.inputs_embeds[:num_input_tokens]
model_mm_kwargs = self._extract_mm_kwargs(scheduler_output)
model_kwargs = self._init_model_kwargs(num_scheduled_tokens)
model_kwargs = {
**self._init_model_kwargs(num_scheduled_tokens),
**self._extract_mm_kwargs(scheduler_output),
}
else:
# For text-only models, we use token ids as input.
# While it is possible to use embeddings as input just like the
# multimodal models, it is not desirable for performance since
# then the embedding layer is not included in the CUDA graph.
input_ids = self.input_ids[:num_input_tokens]
model_kwargs = self._init_model_kwargs(num_input_tokens)
inputs_embeds = None
model_mm_kwargs = {}
model_kwargs = self._init_model_kwargs(num_input_tokens)
if self.uses_mrope:
positions = self.mrope_positions[:, :num_input_tokens]
else:
@ -1596,10 +1602,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
**MultiModalKwargs.as_kwargs(
model_mm_kwargs,
device=self.device,
),
**model_kwargs,
)
@ -2196,14 +2198,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Result in the maximum GPU consumption of the model
dummy_mm_item = dummy_mm_data.get_item(modality=modality, item_index=0)
dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item])
batched_dummy_mm_inputs = MultiModalKwargs.batch([dummy_mm_kwargs] *
max_items_per_batch)
return MultiModalKwargs.as_kwargs(
batched_dummy_mm_inputs,
device=self.device,
)
return next(mm_kwargs_group
for _, _, mm_kwargs_group in group_mm_kwargs_by_modality(
[dummy_mm_item] * max_items_per_batch,
device=self.device,
pin_memory=self.pin_memory,
))
@torch.inference_mode()
def _dummy_run(
@ -2269,15 +2270,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
with self.maybe_dummy_run_with_lora(self.lora_config,
num_scheduled_tokens):
model_kwargs = self._init_model_kwargs(num_tokens)
if self.supports_mm_inputs:
input_ids = None
inputs_embeds = self.inputs_embeds[:num_tokens]
model_mm_kwargs = self._dummy_mm_kwargs(num_reqs)
model_kwargs = {
**self._init_model_kwargs(num_tokens),
**self._dummy_mm_kwargs(num_reqs),
}
else:
input_ids = self.input_ids[:num_tokens]
inputs_embeds = None
model_mm_kwargs = {}
model_kwargs = self._init_model_kwargs(num_tokens)
if self.uses_mrope:
positions = self.mrope_positions[:, :num_tokens]
@ -2307,10 +2310,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
**MultiModalKwargs.as_kwargs(
model_mm_kwargs,
device=self.device,
),
**model_kwargs,
)

View File

@ -32,9 +32,9 @@ from vllm.model_executor.models.interfaces import supports_transcription
from vllm.model_executor.models.interfaces_base import (
is_pooling_model, is_text_generation_model)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs,
from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargsItem,
PlaceholderRange)
from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.multimodal.utils import group_mm_kwargs_by_modality
from vllm.sequence import IntermediateTensors
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
from vllm.utils import (LayerBlockType, cdiv, is_pin_memory_available,
@ -394,7 +394,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.requests[req_id] = CachedRequestState(
req_id=req_id,
prompt_token_ids=new_req_data.prompt_token_ids,
mm_inputs=new_req_data.mm_inputs,
mm_kwargs=new_req_data.mm_kwargs,
mm_positions=new_req_data.mm_positions,
sampling_params=sampling_params,
pooling_params=None,
@ -842,13 +842,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return
# Batch the multi-modal inputs.
mm_inputs = list[MultiModalKwargs]()
mm_kwargs = list[MultiModalKwargsItem]()
req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
req_state = self.requests[req_id]
for mm_input_id in encoder_input_ids:
mm_inputs.append(req_state.mm_inputs[mm_input_id])
mm_kwargs.append(req_state.mm_kwargs[mm_input_id])
req_ids_pos.append(
(req_id, mm_input_id, req_state.mm_positions[mm_input_id]))
@ -859,16 +859,12 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# in the same batch while still being able to benefit from batching
# multimodal inputs. The proper solution should be reordering the
# encoder outputs.
grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_inputs)
encoder_outputs = []
for grouped_mm_inputs in grouped_mm_inputs_list:
batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
batched_mm_inputs = MultiModalKwargs.as_kwargs(
batched_mm_inputs,
for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
mm_kwargs,
device=self.device,
)
pin_memory=self.pin_memory,
):
# Run the encoder.
# `curr_group_outputs` is either of the following:
# 1. A tensor of shape (num_items, feature_size, hidden_size)
@ -878,12 +874,12 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# depending on the input multimodal items.
xm.mark_step()
curr_group_outputs = self.model.get_multimodal_embeddings(
**batched_mm_inputs)
**mm_kwargs_group)
xm.mark_step()
sanity_check_mm_encoder_outputs(
curr_group_outputs,
expected_num_items=len(grouped_mm_inputs),
expected_num_items=num_items,
)
if isinstance(curr_group_outputs, torch.Tensor):
@ -1823,14 +1819,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Result in the maximum GPU consumption of the model
dummy_mm_item = dummy_mm_data.get_item(modality=modality, item_index=0)
dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item])
batched_dummy_mm_inputs = MultiModalKwargs.batch([dummy_mm_kwargs] *
max_items_per_batch)
return MultiModalKwargs.as_kwargs(
batched_dummy_mm_inputs,
device=self.device,
)
return next(grouped_mm_kwargs
for _, _, grouped_mm_kwargs in group_mm_kwargs_by_modality(
[dummy_mm_item] * max_items_per_batch,
device=self.device,
pin_memory=self.pin_memory,
))
def _get_req_paddings(min_req_size: int, max_req_size: int) -> list[int]: