[Core] Use individual MM items in P0/P1 cache and model runner (#22570)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-08-13 22:18:07 +08:00 committed by GitHub
parent 20d65aa755
commit 19b927e52d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
24 changed files with 549 additions and 486 deletions

View File

@ -5,7 +5,7 @@ import base64
import mimetypes import mimetypes
import os import os
from tempfile import NamedTemporaryFile, TemporaryDirectory from tempfile import NamedTemporaryFile, TemporaryDirectory
from typing import TYPE_CHECKING, NamedTuple, Optional from typing import TYPE_CHECKING, NamedTuple
import numpy as np import numpy as np
import pytest import pytest
@ -19,14 +19,12 @@ from vllm.distributed.parallel_state import (init_distributed_environment,
initialize_model_parallel) initialize_model_parallel)
from vllm.multimodal.image import convert_image_mode from vllm.multimodal.image import convert_image_mode
from vllm.multimodal.inputs import PlaceholderRange from vllm.multimodal.inputs import PlaceholderRange
from vllm.multimodal.utils import (MediaConnector, from vllm.multimodal.utils import (MediaConnector, argsort_mm_positions,
merge_and_sort_multimodal_metadata,
run_dp_sharded_vision_model) run_dp_sharded_vision_model)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import get_open_port, update_environment_variables from vllm.utils import get_open_port, update_environment_variables
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.multimodal.hasher import MultiModalHashDict
from vllm.multimodal.inputs import MultiModalPlaceholderDict from vllm.multimodal.inputs import MultiModalPlaceholderDict
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA) # Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
@ -178,19 +176,17 @@ async def test_fetch_video_http(video_url: str, num_frames: int):
assert metadata_sync == metadata_async assert metadata_sync == metadata_async
# Used for the next two tests related to `merge_and_sort_multimodal_metadata`. # Used for `test_argsort_mm_positions`.
class TestCase(NamedTuple): class TestCase(NamedTuple):
mm_positions: "MultiModalPlaceholderDict" mm_positions: "MultiModalPlaceholderDict"
mm_hashes: Optional["MultiModalHashDict"] expected_modality_idxs: list[tuple[str, int]]
expected_modalities: list[str]
expected_ranges: list[PlaceholderRange]
expected_hashes: Optional[list[str]]
def test_merge_and_sort_multimodal_metadata(): def test_argsort_mm_positions():
test_cases = [ test_cases = [
# Single modality should return result as is but flattened # Single modality
## Internally sorted
TestCase( TestCase(
mm_positions={ mm_positions={
"image": [ "image": [
@ -198,34 +194,27 @@ def test_merge_and_sort_multimodal_metadata():
PlaceholderRange(offset=3, length=2), PlaceholderRange(offset=3, length=2),
] ]
}, },
mm_hashes={"image": ["hash1", "hash2"]}, expected_modality_idxs=[
expected_modalities=["image", "image"], ("image", 0),
expected_ranges=[ ("image", 1),
PlaceholderRange(offset=0, length=2),
PlaceholderRange(offset=3, length=2),
], ],
expected_hashes=["hash1", "hash2"],
), ),
## Internally unsorted
# Single modality without hashes return None for mm hash.
TestCase( TestCase(
mm_positions={ mm_positions={
"image": [ "image": [
PlaceholderRange(offset=3, length=2),
PlaceholderRange(offset=0, length=2), PlaceholderRange(offset=0, length=2),
PlaceholderRange(offset=2, length=2),
] ]
}, },
mm_hashes=None, expected_modality_idxs=[
expected_modalities=["image", "image"], ("image", 1),
expected_ranges=[ ("image", 0),
PlaceholderRange(offset=0, length=2),
PlaceholderRange(offset=2, length=2),
], ],
expected_hashes=None,
), ),
# Multiple modalities with hashes should return sorted modalities # Two modalities
# and flattened ranges and hashes. ## Internally sorted
TestCase( TestCase(
mm_positions={ mm_positions={
"image": [ "image": [
@ -237,47 +226,54 @@ def test_merge_and_sort_multimodal_metadata():
PlaceholderRange(offset=2, length=3), PlaceholderRange(offset=2, length=3),
] ]
}, },
mm_hashes={ expected_modality_idxs=[
"image": ["image_hash1", "image_hash2"], ("audio", 0),
"audio": ["audio_hash1", "audio_hash2"], ("audio", 1),
}, ("image", 0),
expected_modalities=["audio", "audio", "image", "image"], ("image", 1),
expected_ranges=[
PlaceholderRange(offset=0, length=2),
PlaceholderRange(offset=2, length=3),
PlaceholderRange(offset=7, length=4),
PlaceholderRange(offset=11, length=5),
],
expected_hashes=[
"audio_hash1", "audio_hash2", "image_hash1", "image_hash2"
], ],
), ),
## Interleaved, internally sorted
# Multiple modalities without hashes should return sorted modalities
# and flattened ranges and None.
TestCase( TestCase(
mm_positions={ mm_positions={
"image": [ "image": [
PlaceholderRange(offset=7, length=4), PlaceholderRange(offset=0, length=4),
PlaceholderRange(offset=11, length=5), PlaceholderRange(offset=8, length=2),
], ],
"audio": [ "audio": [
PlaceholderRange(offset=0, length=2), PlaceholderRange(offset=5, length=2),
PlaceholderRange(offset=2, length=3), PlaceholderRange(offset=11, length=4),
] ]
}, },
mm_hashes=None, expected_modality_idxs=[
expected_modalities=["audio", "audio", "image", "image"], ("image", 0),
expected_ranges=[ ("audio", 0),
PlaceholderRange(offset=0, length=2), ("image", 1),
PlaceholderRange(offset=2, length=3), ("audio", 1),
PlaceholderRange(offset=7, length=4), ],
PlaceholderRange(offset=11, length=5), ),
## Interleaved, internally unsorted
TestCase(
mm_positions={
"image": [
PlaceholderRange(offset=8, length=2),
PlaceholderRange(offset=0, length=4),
],
"audio": [
PlaceholderRange(offset=11, length=4),
PlaceholderRange(offset=5, length=2),
]
},
expected_modality_idxs=[
("image", 1),
("audio", 1),
("image", 0),
("audio", 0),
], ],
expected_hashes=None,
), ),
# Three modalities # Three modalities
## Internally sorted
TestCase( TestCase(
mm_positions={ mm_positions={
"image": [ "image": [
@ -293,72 +289,16 @@ def test_merge_and_sort_multimodal_metadata():
PlaceholderRange(offset=12, length=6), PlaceholderRange(offset=12, length=6),
] ]
}, },
mm_hashes={ expected_modality_idxs=[
"image": ["image_hash1", "image_hash2"], ("audio", 0),
"audio": ["audio_hash1"], ("video", 0),
"video": ["video_hash1", "video_hash2", "video_hash3"] ("video", 1),
}, ("video", 2),
expected_modalities=[ ("image", 0),
"audio", "video", "video", "video", "image", "image" ("image", 1),
],
expected_ranges=[
PlaceholderRange(offset=0, length=2),
PlaceholderRange(offset=3, length=4),
PlaceholderRange(offset=7, length=5),
PlaceholderRange(offset=12, length=6),
PlaceholderRange(offset=15, length=7),
PlaceholderRange(offset=22, length=8),
],
expected_hashes=[
"audio_hash1", "video_hash1", "video_hash2", "video_hash3",
"image_hash1", "image_hash2"
], ],
), ),
] ## Interleaved, internally sorted
for (mm_positions, mm_hashes, expected_modalities, expected_ranges,
expected_hashes) in test_cases:
modalities, ranges, hashes = merge_and_sort_multimodal_metadata(
mm_positions, mm_hashes)
assert modalities == expected_modalities
assert ranges == expected_ranges
assert hashes == expected_hashes
def test_merge_and_sort_multimodal_metadata_with_interleaving():
test_cases = [
# <image> <audio> <image> <audio>
TestCase(
mm_positions={
"image": [
PlaceholderRange(offset=0, length=4),
PlaceholderRange(offset=8, length=2),
],
"audio": [
PlaceholderRange(offset=5, length=2),
PlaceholderRange(offset=11, length=4),
]
},
mm_hashes={
"image": ["image_hash1", "image_hash2"],
"audio": ["audio_hash1", "audio_hash2"],
},
expected_modalities=["image", "audio", "image", "audio"],
expected_ranges=[
PlaceholderRange(offset=0, length=4),
PlaceholderRange(offset=5, length=2),
PlaceholderRange(offset=8, length=2),
PlaceholderRange(offset=11, length=4),
],
expected_hashes=[
"image_hash1", "audio_hash1", "image_hash2", "audio_hash2"
],
),
# <image> <image> <audio> <video> <image>
TestCase( TestCase(
mm_positions={ mm_positions={
"image": [ "image": [
@ -373,58 +313,43 @@ def test_merge_and_sort_multimodal_metadata_with_interleaving():
PlaceholderRange(offset=8, length=5), PlaceholderRange(offset=8, length=5),
] ]
}, },
mm_hashes=None, expected_modality_idxs=[
expected_modalities=["image", "image", "audio", "video", "image"], ("image", 0),
expected_ranges=[ ("image", 1),
PlaceholderRange(offset=0, length=2), ("audio", 0),
PlaceholderRange(offset=2, length=3), ("video", 0),
PlaceholderRange(offset=5, length=2), ("image", 2),
PlaceholderRange(offset=8, length=5),
PlaceholderRange(offset=20, length=4),
], ],
expected_hashes=None,
), ),
## Interleaved, internally unsorted
# <image> <audio> <video> <image> with hashes
TestCase( TestCase(
mm_positions={ mm_positions={
"image": [ "image": [
PlaceholderRange(offset=0, length=2), PlaceholderRange(offset=0, length=2),
PlaceholderRange(offset=18, length=4), PlaceholderRange(offset=20, length=4),
PlaceholderRange(offset=2, length=3),
], ],
"audio": [ "audio": [
PlaceholderRange(offset=6, length=2), PlaceholderRange(offset=5, length=2),
], ],
"video": [ "video": [
PlaceholderRange(offset=10, length=5), PlaceholderRange(offset=8, length=5),
] ]
}, },
mm_hashes={ expected_modality_idxs=[
"image": ["image_hash1", "image_hash2"], ("image", 0),
"audio": ["audio_hash1"], ("image", 2),
"video": ["video_hash1"], ("audio", 0),
}, ("video", 0),
expected_modalities=["image", "audio", "video", "image"], ("image", 1),
expected_ranges=[
PlaceholderRange(offset=0, length=2),
PlaceholderRange(offset=6, length=2),
PlaceholderRange(offset=10, length=5),
PlaceholderRange(offset=18, length=4),
],
expected_hashes=[
"image_hash1", "audio_hash1", "video_hash1", "image_hash2"
], ],
), ),
] ]
for (mm_positions, mm_hashes, expected_modalities, expected_ranges, for mm_positions, expected_modality_idxs in test_cases:
expected_hashes) in test_cases: modality_idxs = argsort_mm_positions(mm_positions)
modalities, ranges, hashes = merge_and_sort_multimodal_metadata(
mm_positions, mm_hashes)
assert modalities == expected_modalities assert modality_idxs == expected_modality_idxs
assert ranges == expected_ranges
assert hashes == expected_hashes
class SimpleLinearModel(torch.nn.Module): class SimpleLinearModel(torch.nn.Module):

View File

@ -1,12 +1,15 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib import importlib
from typing import Optional
import pytest import pytest
import torch import torch
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.multimodal.inputs import (MultiModalBatchedField,
MultiModalFieldElem, MultiModalKwargsItem,
PlaceholderRange)
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils import GiB_bytes, sha256, sha256_cbor_64bit from vllm.utils import GiB_bytes, sha256, sha256_cbor_64bit
from vllm.v1.core.kv_cache_manager import KVCacheManager from vllm.v1.core.kv_cache_manager import KVCacheManager
@ -27,20 +30,29 @@ from vllm.v1.request import Request
# yapf: enable # yapf: enable
def make_request(request_id, def make_request(
prompt_token_ids, request_id: str,
mm_positions=None, prompt_token_ids: list[int],
mm_hashes=None, mm_positions: Optional[list[PlaceholderRange]] = None,
cache_salt=None): mm_hashes: Optional[list[str]] = None,
cache_salt: Optional[str] = None,
):
if mm_positions is None: if mm_positions is None:
multi_modal_inputs = None mm_kwargs = None
else: else:
multi_modal_inputs = [MultiModalKwargs({})] * len(mm_positions) mm_elem = MultiModalFieldElem(
modality="dummy_m",
key="dummy_k",
data=None,
field=MultiModalBatchedField(),
)
mm_item = MultiModalKwargsItem.from_elems([mm_elem])
mm_kwargs = [mm_item] * len(mm_positions)
return Request( return Request(
request_id=request_id, request_id=request_id,
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
multi_modal_inputs=multi_modal_inputs, multi_modal_kwargs=mm_kwargs,
multi_modal_hashes=mm_hashes, multi_modal_hashes=mm_hashes,
multi_modal_placeholders=mm_positions, multi_modal_placeholders=mm_positions,
sampling_params=SamplingParams(max_tokens=17), sampling_params=SamplingParams(max_tokens=17),
@ -316,7 +328,7 @@ def test_free_kv_cache_block_queue_get_all_free_blocks():
def test_generate_block_hash_extra_keys(): def test_generate_block_hash_extra_keys():
request = make_request( request = make_request(
request_id=0, request_id="0",
prompt_token_ids=[_ for _ in range(20)], prompt_token_ids=[_ for _ in range(20)],
mm_positions=[ mm_positions=[
PlaceholderRange(offset=0, length=5), PlaceholderRange(offset=0, length=5),
@ -348,7 +360,7 @@ def test_generate_block_hash_extra_keys():
def test_generate_block_hash_extra_keys_no_mm_inputs(): def test_generate_block_hash_extra_keys_no_mm_inputs():
request = make_request( request = make_request(
request_id=0, request_id="0",
prompt_token_ids=[_ for _ in range(6)], prompt_token_ids=[_ for _ in range(6)],
mm_positions=None, mm_positions=None,
mm_hashes=None, mm_hashes=None,
@ -361,7 +373,7 @@ def test_generate_block_hash_extra_keys_no_mm_inputs():
def test_generate_block_hash_extra_keys_cache_salt(): def test_generate_block_hash_extra_keys_cache_salt():
request = make_request( request = make_request(
request_id=0, request_id="0",
prompt_token_ids=[_ for _ in range(6)], prompt_token_ids=[_ for _ in range(6)],
mm_positions=None, mm_positions=None,
mm_hashes=None, mm_hashes=None,
@ -382,7 +394,7 @@ def test_generate_block_hash_extra_keys_cache_salt():
# works together with other extra keys # works together with other extra keys
request_mm = make_request( request_mm = make_request(
request_id=0, request_id="0",
prompt_token_ids=[_ for _ in range(20)], prompt_token_ids=[_ for _ in range(20)],
mm_positions=[ mm_positions=[
PlaceholderRange(offset=0, length=5), PlaceholderRange(offset=0, length=5),
@ -420,7 +432,7 @@ def test_hash_request_tokens(hash_fn):
import vllm.v1.core.kv_cache_utils import vllm.v1.core.kv_cache_utils
init_none_hash(hash_fn) init_none_hash(hash_fn)
request = make_request( request = make_request(
request_id=0, request_id="0",
prompt_token_ids=[_ for _ in range(6)], prompt_token_ids=[_ for _ in range(6)],
mm_positions=[ mm_positions=[
PlaceholderRange(offset=0, length=3), PlaceholderRange(offset=0, length=3),
@ -450,7 +462,7 @@ def test_hash_tokens_different_mm_input(hash_fn):
init_none_hash(hash_fn) init_none_hash(hash_fn)
request1 = make_request( request1 = make_request(
request_id=0, request_id="0",
prompt_token_ids=[_ for _ in range(6)], prompt_token_ids=[_ for _ in range(6)],
mm_positions=[ mm_positions=[
PlaceholderRange(offset=0, length=3), PlaceholderRange(offset=0, length=3),
@ -459,7 +471,7 @@ def test_hash_tokens_different_mm_input(hash_fn):
mm_hashes=["hash1", "hash2"], mm_hashes=["hash1", "hash2"],
) )
request2 = make_request( request2 = make_request(
request_id=1, request_id="1",
prompt_token_ids=[_ for _ in range(6)], prompt_token_ids=[_ for _ in range(6)],
mm_positions=[ mm_positions=[
PlaceholderRange(offset=0, length=3), PlaceholderRange(offset=0, length=3),
@ -479,7 +491,7 @@ def test_hash_request_tokens_no_mm_inputs(hash_fn):
init_none_hash(hash_fn) init_none_hash(hash_fn)
request = make_request( request = make_request(
request_id=0, request_id="0",
prompt_token_ids=[_ for _ in range(6)], prompt_token_ids=[_ for _ in range(6)],
mm_positions=None, mm_positions=None,
mm_hashes=None, mm_hashes=None,
@ -844,7 +856,7 @@ def test_allocate_with_lookahead():
) )
request = make_request( request = make_request(
request_id=0, request_id="0",
prompt_token_ids=[], prompt_token_ids=[],
mm_positions=None, mm_positions=None,
mm_hashes=None, mm_hashes=None,

View File

@ -9,7 +9,9 @@ import pytest
import torch import torch
from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.multimodal.inputs import (MultiModalBatchedField,
MultiModalFieldElem, MultiModalKwargsItem,
PlaceholderRange)
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils import sha256, sha256_cbor_64bit from vllm.utils import sha256, sha256_cbor_64bit
from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.block_pool import BlockPool
@ -21,21 +23,30 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, SlidingWindowSpec) KVCacheGroupSpec, SlidingWindowSpec)
def make_request(request_id, def make_request(
prompt_token_ids, request_id: str,
mm_positions=None, prompt_token_ids: list[int],
mm_hashes=None, mm_positions: Optional[list[PlaceholderRange]] = None,
prompt_logprobs: Optional[int] = None, mm_hashes: Optional[list[str]] = None,
cache_salt: Optional[str] = None): prompt_logprobs: Optional[int] = None,
cache_salt: Optional[str] = None,
):
if mm_positions is None: if mm_positions is None:
multi_modal_inputs = None mm_kwargs = None
else: else:
multi_modal_inputs = [MultiModalKwargs({})] * len(mm_positions) mm_elem = MultiModalFieldElem(
modality="dummy_m",
key="dummy_k",
data=None,
field=MultiModalBatchedField(),
)
mm_item = MultiModalKwargsItem.from_elems([mm_elem])
mm_kwargs = [mm_item] * len(mm_positions)
return Request( return Request(
request_id=request_id, request_id=request_id,
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
multi_modal_inputs=multi_modal_inputs, multi_modal_kwargs=mm_kwargs,
multi_modal_hashes=mm_hashes, multi_modal_hashes=mm_hashes,
multi_modal_placeholders=mm_positions, multi_modal_placeholders=mm_positions,
sampling_params=SamplingParams(max_tokens=17, sampling_params=SamplingParams(max_tokens=17,

View File

@ -8,7 +8,9 @@ import torch
from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig, from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
SchedulerConfig, SpeculativeConfig, VllmConfig) SchedulerConfig, SpeculativeConfig, VllmConfig)
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.multimodal.inputs import (MultiModalBatchedField,
MultiModalFieldElem, MultiModalKwargsItem,
PlaceholderRange)
from vllm.sampling_params import GuidedDecodingParams, SamplingParams from vllm.sampling_params import GuidedDecodingParams, SamplingParams
from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.core.sched.scheduler import Scheduler
@ -1304,7 +1306,7 @@ def create_requests_with_priority(
priorities: list[int], priorities: list[int],
arrival_times: Optional[list[float]] = None, arrival_times: Optional[list[float]] = None,
num_tokens: int = 10, num_tokens: int = 10,
mm_positions: Optional[list[PlaceholderRange]] = None, mm_positions: Optional[list[list[PlaceholderRange]]] = None,
max_tokens: int = 16, max_tokens: int = 16,
stop_token_ids: Optional[list[int]] = None, stop_token_ids: Optional[list[int]] = None,
prompt_logprobs: Optional[int] = None): prompt_logprobs: Optional[int] = None):
@ -1323,16 +1325,23 @@ def create_requests_with_priority(
for i in range(num_requests): for i in range(num_requests):
if mm_positions is not None: if mm_positions is not None:
mm_position = mm_positions[i] mm_position = mm_positions[i]
mm_inputs = [MultiModalKwargs({})] * len(mm_position) mm_elem = MultiModalFieldElem(
modality="dummy_m",
key="dummy_k",
data=None,
field=MultiModalBatchedField(),
)
mm_item = MultiModalKwargsItem.from_elems([mm_elem])
mm_kwargs = [mm_item] * len(mm_position)
else: else:
mm_position = None mm_position = None
mm_inputs = None mm_kwargs = None
request = Request( request = Request(
request_id=f"{i}", request_id=f"{i}",
prompt_token_ids=[i] * num_tokens, prompt_token_ids=[i] * num_tokens,
sampling_params=sampling_params, sampling_params=sampling_params,
pooling_params=None, pooling_params=None,
multi_modal_inputs=mm_inputs, multi_modal_kwargs=mm_kwargs,
multi_modal_placeholders=mm_position, multi_modal_placeholders=mm_position,
multi_modal_hashes=None, multi_modal_hashes=None,
eos_token_id=EOS_TOKEN_ID, eos_token_id=EOS_TOKEN_ID,
@ -1816,7 +1825,7 @@ def test_schedule_skip_tokenizer_init_structured_output_request():
request = Request( request = Request(
request_id="0", request_id="0",
prompt_token_ids=[0, 1], prompt_token_ids=[0, 1],
multi_modal_inputs=None, multi_modal_kwargs=None,
multi_modal_hashes=None, multi_modal_hashes=None,
multi_modal_placeholders=None, multi_modal_placeholders=None,
sampling_params=sampling_params, sampling_params=sampling_params,

View File

@ -6,7 +6,9 @@ import torch
from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig, from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
SchedulerConfig, SpeculativeConfig, VllmConfig) SchedulerConfig, SpeculativeConfig, VllmConfig)
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.multimodal.inputs import (MultiModalBatchedField,
MultiModalFieldElem, MultiModalKwargsItem,
PlaceholderRange)
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.v1.core.sched.async_scheduler import AsyncScheduler from vllm.v1.core.sched.async_scheduler import AsyncScheduler
from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.core.sched.scheduler import Scheduler
@ -115,7 +117,7 @@ def create_scheduler(
def create_requests( def create_requests(
num_requests: int, num_requests: int,
num_tokens: int = 10, num_tokens: int = 10,
mm_positions: Optional[list[PlaceholderRange]] = None, mm_positions: Optional[list[list[PlaceholderRange]]] = None,
max_tokens: int = 16, max_tokens: int = 16,
stop_token_ids: Optional[list[int]] = None, stop_token_ids: Optional[list[int]] = None,
prompt_logprobs: Optional[int] = None, prompt_logprobs: Optional[int] = None,
@ -129,10 +131,17 @@ def create_requests(
for i in range(num_requests): for i in range(num_requests):
if mm_positions is not None: if mm_positions is not None:
mm_position = mm_positions[i] mm_position = mm_positions[i]
mm_inputs = [MultiModalKwargs({})] * len(mm_position) mm_elem = MultiModalFieldElem(
modality="dummy_m",
key="dummy_k",
data=None,
field=MultiModalBatchedField(),
)
mm_item = MultiModalKwargsItem.from_elems([mm_elem])
mm_kwargs = [mm_item] * len(mm_position)
else: else:
mm_position = None mm_position = None
mm_inputs = None mm_kwargs = None
prompt_token_ids = ([0] * num_tokens if same_prompt else [i] * prompt_token_ids = ([0] * num_tokens if same_prompt else [i] *
num_tokens) num_tokens)
request = Request( request = Request(
@ -140,7 +149,7 @@ def create_requests(
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params, sampling_params=sampling_params,
pooling_params=None, pooling_params=None,
multi_modal_inputs=mm_inputs, multi_modal_kwargs=mm_kwargs,
multi_modal_placeholders=mm_position, multi_modal_placeholders=mm_position,
multi_modal_hashes=None, multi_modal_hashes=None,
eos_token_id=EOS_TOKEN_ID, eos_token_id=EOS_TOKEN_ID,

View File

@ -35,7 +35,7 @@ def make_request() -> EngineCoreRequest:
return EngineCoreRequest( return EngineCoreRequest(
request_id=str(uuid.uuid4()), request_id=str(uuid.uuid4()),
prompt_token_ids=PROMPT_TOKENS, prompt_token_ids=PROMPT_TOKENS,
mm_inputs=None, mm_kwargs=None,
mm_hashes=None, mm_hashes=None,
mm_placeholders=None, mm_placeholders=None,
sampling_params=SamplingParams(), sampling_params=SamplingParams(),

View File

@ -52,7 +52,7 @@ def make_request(
return EngineCoreRequest( return EngineCoreRequest(
request_id=str(uuid.uuid4()), request_id=str(uuid.uuid4()),
prompt_token_ids=prompt_tokens_ids, prompt_token_ids=prompt_tokens_ids,
mm_inputs=None, mm_kwargs=None,
mm_hashes=None, mm_hashes=None,
mm_placeholders=None, mm_placeholders=None,
sampling_params=params, sampling_params=params,

View File

@ -53,7 +53,7 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind,
EngineCoreRequest(request_id=f"request-{idx}", EngineCoreRequest(request_id=f"request-{idx}",
prompt_token_ids=prompt_tokens, prompt_token_ids=prompt_tokens,
arrival_time=0, arrival_time=0,
mm_inputs=None, mm_kwargs=None,
mm_hashes=None, mm_hashes=None,
mm_placeholders=None, mm_placeholders=None,
eos_token_id=None, eos_token_id=None,
@ -402,7 +402,7 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
EngineCoreRequest(request_id=request_id_list[idx], EngineCoreRequest(request_id=request_id_list[idx],
prompt_token_ids=prompt_tokens, prompt_token_ids=prompt_tokens,
arrival_time=0, arrival_time=0,
mm_inputs=None, mm_kwargs=None,
mm_hashes=None, mm_hashes=None,
mm_placeholders=None, mm_placeholders=None,
eos_token_id=None, eos_token_id=None,
@ -567,7 +567,7 @@ def test_stop_token(include_stop_str_in_output: bool,
request_id=request_id, request_id=request_id,
prompt_token_ids=prompt_tokens, prompt_token_ids=prompt_tokens,
arrival_time=0, arrival_time=0,
mm_inputs=None, mm_kwargs=None,
mm_hashes=None, mm_hashes=None,
mm_placeholders=None, mm_placeholders=None,
eos_token_id=eos_token_id, eos_token_id=eos_token_id,
@ -666,7 +666,7 @@ def test_stop_string(include_stop_str_in_output: bool,
request_id=request_id_list[idx], request_id=request_id_list[idx],
prompt_token_ids=prompt_tokens, prompt_token_ids=prompt_tokens,
arrival_time=0, arrival_time=0,
mm_inputs=None, mm_kwargs=None,
mm_hashes=None, mm_hashes=None,
mm_placeholders=None, mm_placeholders=None,
eos_token_id=None, eos_token_id=None,
@ -782,7 +782,7 @@ def test_iteration_stats(dummy_test_vectors):
request_id=f"request-{idx}", request_id=f"request-{idx}",
prompt_token_ids=prompt_tokens, prompt_token_ids=prompt_tokens,
arrival_time=0, arrival_time=0,
mm_inputs=None, mm_kwargs=None,
mm_hashes=None, mm_hashes=None,
mm_placeholders=None, mm_placeholders=None,
eos_token_id=None, eos_token_id=None,

View File

@ -154,7 +154,7 @@ def create_request(
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params, sampling_params=sampling_params,
pooling_params=None, pooling_params=None,
multi_modal_inputs=None, multi_modal_kwargs=None,
multi_modal_placeholders=None, multi_modal_placeholders=None,
multi_modal_hashes=None, multi_modal_hashes=None,
eos_token_id=EOS_TOKEN_ID, eos_token_id=EOS_TOKEN_ID,

View File

@ -64,7 +64,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
NewRequestData( NewRequestData(
req_id=req_id, req_id=req_id,
prompt_token_ids=[1, 2, 3], prompt_token_ids=[1, 2, 3],
mm_inputs=[], mm_kwargs=[],
mm_hashes=[], mm_hashes=[],
mm_positions=[], mm_positions=[],
sampling_params=SamplingParams(), sampling_params=SamplingParams(),

View File

@ -203,7 +203,7 @@ def _construct_cached_request_state(req_id_suffix: int):
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
sampling_params=_create_sampling_params(), sampling_params=_create_sampling_params(),
pooling_params=None, pooling_params=None,
mm_inputs=[], mm_kwargs=[],
mm_positions=[], mm_positions=[],
block_ids=([], ), block_ids=([], ),
generator=None, generator=None,

View File

@ -120,7 +120,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
NewRequestData( NewRequestData(
req_id=req_id, req_id=req_id,
prompt_token_ids=[1, 2, 3], prompt_token_ids=[1, 2, 3],
mm_inputs=[], mm_kwargs=[],
mm_hashes=[], mm_hashes=[],
mm_positions=[], mm_positions=[],
sampling_params=SamplingParams(), sampling_params=SamplingParams(),

View File

@ -4,7 +4,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import UserDict, defaultdict from collections import UserDict, defaultdict
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from dataclasses import dataclass from dataclasses import dataclass, replace
from functools import partial from functools import partial
from itertools import accumulate from itertools import accumulate
from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, TypeVar, from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, TypeVar,
@ -198,7 +198,7 @@ A dictionary containing nested tensors which have been batched via
""" """
@dataclass(frozen=True) @dataclass
class MultiModalFieldElem: class MultiModalFieldElem:
""" """
Represents a keyword argument corresponding to a multi-modal item Represents a keyword argument corresponding to a multi-modal item
@ -218,11 +218,14 @@ class MultiModalFieldElem:
i.e. the name of the keyword argument to be passed to the model. i.e. the name of the keyword argument to be passed to the model.
""" """
data: NestedTensors data: Optional[NestedTensors]
""" """
The tensor data of this field in The tensor data of this field in
[`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs], [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs],
i.e. the value of the keyword argument to be passed to the model. i.e. the value of the keyword argument to be passed to the model.
It may be set to `None` if it is determined that the item is cached
in `EngineCore`.
""" """
field: "BaseMultiModalField" field: "BaseMultiModalField"
@ -235,8 +238,15 @@ class MultiModalFieldElem:
if not isinstance(other, self.__class__): if not isinstance(other, self.__class__):
return False return False
if self.data is None:
data_equal = other.data is None
elif other.data is None:
data_equal = self.data is None
else:
data_equal = nested_tensors_equal(self.data, other.data)
return ((self.modality, self.key) == (other.modality, other.key) return ((self.modality, self.key) == (other.modality, other.key)
and nested_tensors_equal(self.data, other.data) and data_equal
and type(self.field) == type(other.field)) # noqa: E721 and type(self.field) == type(other.field)) # noqa: E721
@ -280,10 +290,20 @@ class BaseMultiModalField(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors: def _reduce_data(
self,
batch: list[NestedTensors],
*,
pin_memory: bool,
) -> NestedTensors:
raise NotImplementedError raise NotImplementedError
def reduce_data(self, elems: list[MultiModalFieldElem]) -> NestedTensors: def reduce_data(
self,
elems: list[MultiModalFieldElem],
*,
pin_memory: bool = False,
) -> NestedTensors:
""" """
Merge the data from multiple instances of Merge the data from multiple instances of
[`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]. [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem].
@ -295,7 +315,13 @@ class BaseMultiModalField(ABC):
if len(set(field_types)) > 1: if len(set(field_types)) > 1:
raise ValueError(f"Cannot merge different {field_types=}") raise ValueError(f"Cannot merge different {field_types=}")
return self._reduce_data([item.data for item in elems]) validated_data = list[NestedTensors]()
for i, elem in enumerate(elems):
assert elem.data is not None, (
f"Cannot merge with empty `elems[{i}]`")
validated_data.append(elem.data)
return self._reduce_data(validated_data, pin_memory=pin_memory)
@dataclass(frozen=True) @dataclass(frozen=True)
@ -314,7 +340,12 @@ class MultiModalBatchedField(BaseMultiModalField):
field_factory = self._field_factory(modality=modality, key=key) field_factory = self._field_factory(modality=modality, key=key)
return [field_factory(item) for item in data] return [field_factory(item) for item in data]
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors: def _reduce_data(
self,
batch: list[NestedTensors],
*,
pin_memory: bool,
) -> NestedTensors:
if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"): if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
if len(batch) == 1: if len(batch) == 1:
# An optimization when `batch` contains only one tensor: # An optimization when `batch` contains only one tensor:
@ -323,7 +354,11 @@ class MultiModalBatchedField(BaseMultiModalField):
return batch[0].unsqueeze(0).contiguous() return batch[0].unsqueeze(0).contiguous()
first_shape = batch[0].shape first_shape = batch[0].shape
if all(elem.shape == first_shape for elem in batch): if all(elem.shape == first_shape for elem in batch):
return torch.stack(batch) out = torch.empty((len(batch), *batch[0].shape),
dtype=batch[0].dtype,
device=batch[0].device,
pin_memory=pin_memory)
return torch.stack(batch, out=out)
return batch return batch
@ -350,7 +385,12 @@ class MultiModalFlatField(BaseMultiModalField):
"torch.Tensor is required for multiple slices" "torch.Tensor is required for multiple slices"
return [field_factory(data[cast(slice, s)]) for s in self.slices] return [field_factory(data[cast(slice, s)]) for s in self.slices]
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors: def _reduce_data(
self,
batch: list[NestedTensors],
*,
pin_memory: bool,
) -> NestedTensors:
if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"): if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
if len(batch) == 1: if len(batch) == 1:
# An optimization when `batch` contains only one tensor: # An optimization when `batch` contains only one tensor:
@ -358,13 +398,21 @@ class MultiModalFlatField(BaseMultiModalField):
# - will achieve zero-copy if the tensor is contiguous # - will achieve zero-copy if the tensor is contiguous
return batch[0].contiguous() return batch[0].contiguous()
def _expect_same_shape(tensor: torch.Tensor): dim = self.dim + (self.dim < 0) * len(batch[0].shape)
return tensor.shape[:self.dim] + tensor.shape[self.dim + 1:]
first_shape = _expect_same_shape(batch[0]) def _shape_before_after(tensor: torch.Tensor):
return tensor.shape[:dim], tensor.shape[dim + 1:]
if all(_expect_same_shape(elem) == first_shape for elem in batch): first_shape = _shape_before_after(batch[0])
return torch.concat(batch, dim=self.dim)
if all(_shape_before_after(elem) == first_shape for elem in batch):
shape_before, shape_after = first_shape
shape_concat = sum(item.shape[dim] for item in batch)
out = torch.empty((*shape_before, shape_concat, *shape_after),
dtype=batch[0].dtype,
device=batch[0].device,
pin_memory=pin_memory)
return torch.concat(batch, dim=self.dim, out=out)
assert self.dim == 0, "dim == 0 is required for nested list" assert self.dim == 0, "dim == 0 is required for nested list"
return [e for elem in batch for e in elem] return [e for elem in batch for e in elem]
@ -387,7 +435,12 @@ class MultiModalSharedField(BaseMultiModalField):
field_factory = self._field_factory(modality=modality, key=key) field_factory = self._field_factory(modality=modality, key=key)
return [field_factory(data)] * self.batch_size return [field_factory(data)] * self.batch_size
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors: def _reduce_data(
self,
batch: list[NestedTensors],
*,
pin_memory: bool,
) -> NestedTensors:
return batch[0] return batch[0]
@ -594,11 +647,53 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
def from_elems(elems: Sequence[MultiModalFieldElem]): def from_elems(elems: Sequence[MultiModalFieldElem]):
return MultiModalKwargsItem({elem.key: elem for elem in elems}) return MultiModalKwargsItem({elem.key: elem for elem in elems})
@property def __init__(self, data: Mapping[str, MultiModalFieldElem]) -> None:
def modality(self) -> str: super().__init__(data)
modalities = {elem.modality for elem in self.data.values()} modalities = {elem.modality for elem in self.data.values()}
assert len(modalities) == 1, f"Found different modalities={modalities}" assert len(modalities) == 1, f"Found different modalities={modalities}"
return next(iter(modalities)) self._modality = next(iter(modalities))
self._is_empty = any(elem.data is None for elem in self.values())
@property
def modality(self) -> str:
    """The modality shared by every field element of this item.

    The value is resolved once in `__init__` (which asserts that all
    elements report the same modality) and stored in `_modality`.
    """
    return self._modality
@property
def is_empty(self) -> bool:
    """Whether this item carries no tensor data.

    The flag is computed once in `__init__` and is `True` when any field
    element of the item has `data is None`.
    """
    return self._is_empty
def get_data(self) -> Optional[Mapping[str, NestedTensors]]:
    """Collect the `data` of every field element, keyed by field key.

    Returns:
        `None` if this item is flagged as empty; otherwise a new dict
        mapping each key of this item to the corresponding element's data.
    """
    if self._is_empty:
        return None

    collected: dict[str, NestedTensors] = {}
    for name, field_elem in self.items():
        payload = field_elem.data
        # Guards against an element being emptied after `_is_empty` was
        # computed in `__init__`.
        assert payload is not None, (
            f"Cannot get data of empty `elem[{name!r}]`")
        collected[name] = payload

    return collected
def require_data(self) -> Mapping[str, NestedTensors]:
    """Return the data of this item, failing loudly if there is none.

    Returns:
        The mapping produced by `get_data()`.

    Raises:
        RuntimeError: If `get_data()` returns `None`.
    """
    data = self.get_data()
    if data is None:
        raise RuntimeError("Cannot get data of empty item")

    return data
# These methods create a new item to avoid mutating cached items in place
def with_data(self, data: Mapping[str, NestedTensors]):
    """Build a new item whose elements carry the given data.

    Args:
        data: Replacement data for each key of this item. Every key of
            this item must be present, otherwise the lookup raises
            `KeyError`.

    Returns:
        A new `MultiModalKwargsItem`; each element is a copy (made with
        `replace`) of the matching element here with `data` swapped in.
    """
    filled = {}
    for name, field_elem in self.items():
        filled[name] = replace(field_elem, data=data[name])

    return MultiModalKwargsItem(filled)
def without_data(self):
    """Build a new item with the same elements but no data.

    Returns:
        A new `MultiModalKwargsItem`; each element is a copy (made with
        `replace`) of the matching element here with `data=None`. This
        item itself is left untouched.
    """
    stripped = {}
    for name, field_elem in self.items():
        stripped[name] = replace(field_elem, data=None)

    return MultiModalKwargsItem(stripped)
# NOTE: UserDict is for V0 compatibility. # NOTE: UserDict is for V0 compatibility.
@ -650,7 +745,11 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
return MultiModalKwargs.from_items(items) return MultiModalKwargs.from_items(items)
@staticmethod @staticmethod
def from_items(items: Sequence[MultiModalKwargsItem]): def from_items(
items: Sequence[MultiModalKwargsItem],
*,
pin_memory: bool = False,
):
"""Construct a new """Construct a new
[`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs] [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs]
from multiple items.""" from multiple items."""
@ -660,7 +759,7 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
elems_by_key[key].append(elem) elems_by_key[key].append(elem)
data = { data = {
key: elems[0].field.reduce_data(elems) key: elems[0].field.reduce_data(elems, pin_memory=pin_memory)
for key, elems in elems_by_key.items() if len(elems) > 0 for key, elems in elems_by_key.items() if len(elems) > 0
} }

View File

@ -3,6 +3,7 @@
import asyncio import asyncio
import atexit import atexit
from collections.abc import Iterable
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from itertools import groupby from itertools import groupby
from pathlib import Path from pathlib import Path
@ -13,6 +14,7 @@ import numpy as np
import numpy.typing as npt import numpy.typing as npt
import torch import torch
from PIL import Image, UnidentifiedImageError from PIL import Image, UnidentifiedImageError
from typing_extensions import deprecated
import vllm.envs as envs import vllm.envs as envs
from vllm.connections import HTTPConnection, global_http_connection from vllm.connections import HTTPConnection, global_http_connection
@ -23,17 +25,17 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
from .audio import AudioMediaIO from .audio import AudioMediaIO
from .base import MediaIO from .base import MediaIO
from .image import ImageEmbeddingMediaIO, ImageMediaIO from .image import ImageEmbeddingMediaIO, ImageMediaIO
from .inputs import PlaceholderRange
from .video import VideoMediaIO from .video import VideoMediaIO
_M = TypeVar("_M") _M = TypeVar("_M")
if TYPE_CHECKING: if TYPE_CHECKING:
from .hasher import MultiModalHashDict from .inputs import (BatchedTensorInputs, MultiModalKwargs,
from .inputs import MultiModalKwargs, MultiModalPlaceholderDict MultiModalKwargsItem, MultiModalPlaceholderDict)
else: else:
MultiModalHashDict = Any BatchedTensorInputs = Any
MultiModalKwargs = Any MultiModalKwargs = Any
MultiModalKwargsItem = Any
MultiModalPlaceholderDict = Any MultiModalPlaceholderDict = Any
global_thread_pool = ThreadPoolExecutor( global_thread_pool = ThreadPoolExecutor(
@ -331,79 +333,32 @@ def encode_video_base64(frames: npt.NDArray) -> str:
return video_io.encode_base64(frames) return video_io.encode_base64(frames)
def merge_and_sort_multimodal_metadata( def argsort_mm_positions(
mm_positions: MultiModalPlaceholderDict, mm_positions: MultiModalPlaceholderDict) -> list[tuple[str, int]]:
mm_hashes: Optional[MultiModalHashDict],
) -> tuple[list[str], list[PlaceholderRange], Optional[list[str]]]:
"""Given a MultiModalPlaceholderDict, merge all PlaceholderRange
objects from all available modalities into a single list of
PlaceholderRange, sorted by their offset (starting index in the input
sequence) in the ascending order.
Optionally if a `MultiModalHashDict` is given, same operation will be
applied to the object and the sorted list of hashes will be returned.
Returns:
list[str]: List of item modalities in order of their positions in the
input sequence.
list[PlaceholderRange]: Sorted list of all PlaceholderRanges from
mm_positions.
Optional[list[str]]: Sorted list of all hashes from mm_hashes if given,
None otherwise.
""" """
Given a `MultiModalPlaceholderDict`, output a sequence of keys to
sort the dictionary by `offset` (starting index in the input sequence)
in ascending order.
modalities = list(mm_positions.keys()) Returns:
A list of `(modality, idx)`, which can be used to access an item
by `mm_positions[modality][idx]`.
"""
flat_items = ((modality, idx, item)
for modality, items in mm_positions.items()
for idx, item in enumerate(items))
assert len(modalities) > 0, "No modalities found in the mm_positions." sorted_flat_items = sorted(flat_items, key=lambda x: x[2].offset)
# For single modality, placeholder ranges and hashes are already sorted return [(modality, idx) for modality, idx, _ in sorted_flat_items]
# so we can return the list directly.
if len(modalities) == 1:
modality = modalities[0]
placeholder_list = list(mm_positions[modality])
return [modality] * len(
placeholder_list
), placeholder_list, None if not mm_hashes else mm_hashes[modality]
# Create a list of (modality, placeholder, hash) tuples for all placeholders
all_items = []
for modality in modalities:
placeholder_list = list(mm_positions[modality])
hash_list: list[Optional[str]] = list(
mm_hashes[modality]) if mm_hashes and modality in mm_hashes else [
None
] * len(placeholder_list)
for placeholder, hash_value in zip(placeholder_list, hash_list):
all_items.append((modality, placeholder, hash_value))
# Sort all items by offset
all_items.sort(key=lambda x: x[1].offset)
# Split into separate lists
sorted_modalities = [item[0] for item in all_items]
merged_placeholders = [item[1] for item in all_items]
merged_hashes = [str(item[2])
for item in all_items] if mm_hashes is not None else None
return sorted_modalities, merged_placeholders, merged_hashes
# Temporary back-compatibility for plugins that define model runner
@deprecated("`group_mm_inputs_by_modality` is superseded by "
"`group_mm_kwargs_by_modality` and will be removed in v0.13. "
"Please use `group_mm_kwargs_by_modality` instead.")
def group_mm_inputs_by_modality( def group_mm_inputs_by_modality(
mm_inputs: list[MultiModalKwargs]) -> list[list[MultiModalKwargs]]: mm_inputs: list[MultiModalKwargs]) -> list[list[MultiModalKwargs]]:
"""Group consecutive MultiModalKwargs from mm_inputs with the same modality
together into the same list for batching purpose. For MultiModalKwargs with
multiple modalities, put them into their own list.
Args:
mm_inputs: List of MultiModalKwargs.
Returns:
list[list[vllm.multimodal.MultiModalKwargs]]: List of list of
`MultiModalKwargs`, each inner list contains consecutive
`MultiModalKwargs` with same modality.
"""
if not mm_inputs: if not mm_inputs:
return [] return []
@ -426,6 +381,48 @@ def group_mm_inputs_by_modality(
] ]
def group_mm_kwargs_by_modality(
    mm_kwargs: list[MultiModalKwargsItem],
    *,
    device: torch.types.Device = None,
    pin_memory: bool = False,
) -> Iterable[tuple[str, int, BatchedTensorInputs]]:
    """Group consecutive `MultiModalKwargsItem`s from `mm_kwargs` with the same
    modality together into the same `MultiModalKwargs` instance.

    Grouping is done with `itertools.groupby`, so only *adjacent* items with
    the same modality are merged; a modality that reappears later in the
    list yields a separate group.

    Args:
        mm_kwargs: List of `MultiModalKwargsItem`.
        device: Forwarded as `device=` to `MultiModalKwargs.as_kwargs`
            for each group.
        pin_memory: Forwarded as `pin_memory=` to `MultiModalKwargs.batch`
            for each group.

    Yields:
        A tuple `(modality, num_items, grouped_kwargs)`.
    """
    # Imported here rather than at module level; the module only imports
    # this name under `TYPE_CHECKING`.
    from vllm.multimodal.inputs import MultiModalKwargs

    for modality, items in groupby(mm_kwargs, key=lambda item: item.modality):
        # `groupby` hands out a one-shot iterator; materialize it because
        # the items are both iterated and counted below.
        items_lst = list(items)

        # mm_kwargs_group = MultiModalKwargs.from_items(items_lst,
        #                                               pin_memory=pin_memory)

        # if device is not None:
        #     mm_kwargs_group = json_map_leaves(lambda x: x.to(device=device),
        #                                       mm_kwargs_group.data)

        # TODO: Once V0 is removed, we can use the merging logic above
        # to avoid creating an extra batch dimension (except for fields
        # that are meant to be stacked anyway).
        # We will also need to update each model to remove `flatten_bn`.
        mm_kwargs_group = MultiModalKwargs.as_kwargs(
            MultiModalKwargs.batch(
                [MultiModalKwargs.from_items([item]) for item in items_lst],
                pin_memory=pin_memory,
            ),
            device=device,
        )

        yield modality, len(items_lst), mm_kwargs_group
def run_dp_sharded_vision_model(image_input: torch.Tensor, def run_dp_sharded_vision_model(image_input: torch.Tensor,
vision_model: torch.nn.Module) -> torch.Tensor: vision_model: torch.nn.Module) -> torch.Tensor:
"""Run a vision model with data parallelism (DP) sharding. The function """Run a vision model with data parallelism (DP) sharding. The function

View File

@ -13,7 +13,7 @@ if TYPE_CHECKING:
from vllm.distributed.kv_transfer.kv_connector.v1.base import ( from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorMetadata) KVConnectorMetadata)
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.v1.request import Request from vllm.v1.request import Request
@ -24,7 +24,7 @@ class NewRequestData:
req_id: str req_id: str
prompt_token_ids: list[int] prompt_token_ids: list[int]
mm_inputs: list[MultiModalKwargs] mm_kwargs: list[MultiModalKwargsItem]
mm_hashes: list[str] mm_hashes: list[str]
mm_positions: list[PlaceholderRange] mm_positions: list[PlaceholderRange]
sampling_params: Optional[SamplingParams] sampling_params: Optional[SamplingParams]
@ -42,7 +42,7 @@ class NewRequestData:
return cls( return cls(
req_id=request.request_id, req_id=request.request_id,
prompt_token_ids=request.prompt_token_ids, prompt_token_ids=request.prompt_token_ids,
mm_inputs=request.mm_inputs, mm_kwargs=request.mm_kwargs,
mm_hashes=request.mm_hashes, mm_hashes=request.mm_hashes,
mm_positions=request.mm_positions, mm_positions=request.mm_positions,
sampling_params=request.sampling_params, sampling_params=request.sampling_params,
@ -56,7 +56,7 @@ class NewRequestData:
return (f"NewRequestData(" return (f"NewRequestData("
f"req_id={self.req_id}," f"req_id={self.req_id},"
f"prompt_token_ids={self.prompt_token_ids}," f"prompt_token_ids={self.prompt_token_ids},"
f"mm_inputs={self.mm_inputs}," f"mm_kwargs={self.mm_kwargs},"
f"mm_hashes={self.mm_hashes}," f"mm_hashes={self.mm_hashes},"
f"mm_positions={self.mm_positions}," f"mm_positions={self.mm_positions},"
f"sampling_params={self.sampling_params}," f"sampling_params={self.sampling_params},"
@ -70,7 +70,7 @@ class NewRequestData:
return (f"NewRequestData(" return (f"NewRequestData("
f"req_id={self.req_id}," f"req_id={self.req_id},"
f"prompt_token_ids_len={len(self.prompt_token_ids)}," f"prompt_token_ids_len={len(self.prompt_token_ids)},"
f"mm_inputs={self.mm_inputs}," f"mm_kwargs={self.mm_kwargs},"
f"mm_hashes={self.mm_hashes}," f"mm_hashes={self.mm_hashes},"
f"mm_positions={self.mm_positions}," f"mm_positions={self.mm_positions},"
f"sampling_params={self.sampling_params}," f"sampling_params={self.sampling_params},"

View File

@ -3,15 +3,13 @@
import enum import enum
import time import time
from collections.abc import Sequence
from typing import Any, Optional, Union from typing import Any, Optional, Union
import msgspec import msgspec
import torch import torch
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
from vllm.multimodal.inputs import PlaceholderRange
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.metrics.stats import SchedulerStats
@ -49,7 +47,7 @@ class EngineCoreRequest(
request_id: str request_id: str
prompt_token_ids: list[int] prompt_token_ids: list[int]
mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]] mm_kwargs: Optional[list[MultiModalKwargsItem]]
mm_hashes: Optional[list[str]] mm_hashes: Optional[list[str]]
mm_placeholders: Optional[list[PlaceholderRange]] mm_placeholders: Optional[list[PlaceholderRange]]
sampling_params: Optional[SamplingParams] sampling_params: Optional[SamplingParams]

View File

@ -409,12 +409,13 @@ class EngineCore:
request initialization running in parallel with Model forward request initialization running in parallel with Model forward
""" """
if request.mm_hashes is not None: if request.mm_hashes is not None:
assert request.mm_inputs is not None assert request.mm_kwargs is not None
# Note on thread safety: no race condition. # Note on thread safety: no race condition.
# `mm_input_cache_server` is reset at the end of LLMEngine init, # `mm_input_cache_server` is reset at the end of LLMEngine init,
# and will only be accessed in the input processing thread afterwards. # and will only be accessed in the input processing thread afterwards.
request.mm_inputs = self.mm_input_cache_server.get_and_update( request.mm_kwargs = self.mm_input_cache_server.get_and_update(
request.mm_inputs, request.mm_hashes) request.mm_kwargs, request.mm_hashes)
req = Request.from_engine_core_request(request) req = Request.from_engine_core_request(request)
if req.use_structured_output: if req.use_structured_output:

View File

@ -1,11 +1,11 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence from collections.abc import Mapping
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING
from vllm.multimodal import MultiModalKwargs, MultiModalRegistry from vllm.multimodal import MultiModalRegistry
from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata
from vllm.utils import is_list_of from vllm.multimodal.inputs import MultiModalKwargsItem, NestedTensors
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import ModelConfig from vllm.config import ModelConfig
@ -17,23 +17,23 @@ if TYPE_CHECKING:
# -- P0: # -- P0:
# - BaseMultiModalProcessor calls MultiModalHasher to get the `mm_hash` of # - BaseMultiModalProcessor calls MultiModalHasher to get the `mm_hash` of
# each input multi-modal item (e.g. image), # each input multi-modal item (e.g. image),
# - BaseMultiModalProcessor processes the input items into `mm_inputs`, # - BaseMultiModalProcessor processes the input items into `mm_kwargs`,
# which are MultiModalKwargsItem instances that each correspond to an # which are MultiModalKwargsItem instances that each correspond to an
# input multi-modal item. # input multi-modal item.
# - MultiModalInputCacheClient accepts the `mm_inputs` and corresponding # - MultiModalInputCacheClient accepts the `mm_kwargs` and corresponding
# `mm_hash` for each item. It stores the `mm_hash` as keys and the size # `mm_hash` for each item. It stores the `mm_hash` as keys and the size
# of `mm_inputs`, but not the `mm_inputs` themselves, to avoid taking # of `mm_kwargs`, but not the `mm_kwargs` themselves, to avoid taking
# up additional memory in P0. # up additional memory in P0.
# - The `mm_hash` is always sent to P1. # - The `mm_hash` is always sent to P1.
# - The corresponding `mm_inputs` are only sent to P1 if they are not cached # - The corresponding `mm_kwargs` are only sent to P1 if they are not cached
# in MultiModalInputCacheServer. # in MultiModalInputCacheServer.
# #
# -- P1: # -- P1:
# - If the `mm_hash` is cached (i.e. `mm_inputs` are not sent from P0), # - If the `mm_hash` is cached (i.e. `mm_kwargs` are not sent from P0),
# MultiModalInputCacheServer retrieves the corresponding `mm_inputs`. # MultiModalInputCacheServer retrieves the corresponding `mm_kwargs`.
# - If the `mm_hash` is not cached (i.e. `mm_inputs` are sent from P0), # - If the `mm_hash` is not cached (i.e. `mm_kwargs` are sent from P0),
# MultiModalInputCacheServer stores `mm_inputs` under the key `mm_hash`. # MultiModalInputCacheServer stores `mm_kwargs` under the key `mm_hash`.
# - Either way, the `mm_hash` and corresponding `mm_inputs` are sent to # - Either way, the `mm_hash` and corresponding `mm_kwargs` are sent to
# the engine for model execution. # the engine for model execution.
# #
# Both Client and Server must perform cache update and eviction based on the # Both Client and Server must perform cache update and eviction based on the
@ -58,26 +58,24 @@ class MultiModalInputCacheClient:
def get_and_update( def get_and_update(
self, self,
mm_inputs: Sequence[MultiModalKwargs], mm_kwargs: list[MultiModalKwargsItem],
mm_hashes: list[str], mm_hashes: list[str],
) -> Sequence[Optional[MultiModalKwargs]]: ) -> list[MultiModalKwargsItem]:
assert len(mm_inputs) == len(mm_hashes)
if not self.enabled: if not self.enabled:
assert is_list_of(mm_inputs, MultiModalKwargs) return mm_kwargs
return mm_inputs
full_mm_inputs = list[Optional[MultiModalKwargs]]() assert len(mm_kwargs) == len(mm_hashes)
for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
out_mm_items = list[MultiModalKwargsItem]()
for mm_item, mm_hash in zip(mm_kwargs, mm_hashes):
if self.mm_cache.get(mm_hash) is not None: if self.mm_cache.get(mm_hash) is not None:
mm_input = None out_mm_items.append(mm_item.without_data())
else: else:
self.mm_cache[mm_hash] = \ self.mm_cache[mm_hash] = \
MultiModalCacheItemMetadata.wraps(mm_input) MultiModalCacheItemMetadata.wraps(mm_item.require_data())
out_mm_items.append(mm_item)
full_mm_inputs.append(mm_input) return out_mm_items
return full_mm_inputs
def reset(self) -> None: def reset(self) -> None:
self.mm_cache.clear() self.mm_cache.clear()
@ -93,30 +91,28 @@ class MultiModalInputCacheServer:
self.enabled = mm_registry.enable_mm_input_cache(model_config) self.enabled = mm_registry.enable_mm_input_cache(model_config)
self.mm_cache = MultiModalCache.get_lru_cache( self.mm_cache = MultiModalCache.get_lru_cache(
model_config.get_mm_input_cache_gb(), model_config.get_mm_input_cache_gb(),
MultiModalKwargs, Mapping[str, NestedTensors],
) )
def get_and_update( def get_and_update(
self, self,
mm_inputs: Sequence[Optional[MultiModalKwargs]], mm_kwargs: list[MultiModalKwargsItem],
mm_hashes: list[str], mm_hashes: list[str],
) -> Sequence[MultiModalKwargs]: ) -> list[MultiModalKwargsItem]:
assert len(mm_inputs) == len(mm_hashes)
if not self.enabled: if not self.enabled:
assert is_list_of(mm_inputs, MultiModalKwargs) return mm_kwargs
return mm_inputs
full_mm_inputs = list[MultiModalKwargs]() assert len(mm_kwargs) == len(mm_hashes)
for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
if mm_input is None: out_mm_items = list[MultiModalKwargsItem]()
mm_input = self.mm_cache[mm_hash] for mm_item, mm_hash in zip(mm_kwargs, mm_hashes):
if (mm_data := mm_item.get_data()) is None:
out_mm_items.append(mm_item.with_data(self.mm_cache[mm_hash]))
else: else:
self.mm_cache[mm_hash] = mm_input self.mm_cache[mm_hash] = mm_data
out_mm_items.append(mm_item)
full_mm_inputs.append(mm_input) return out_mm_items
return full_mm_inputs
def reset(self) -> None: def reset(self) -> None:
self.mm_cache.clear() self.mm_cache.clear()

View File

@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time import time
from collections.abc import Mapping, Sequence from collections.abc import Mapping
from typing import Any, Literal, Optional, Union from typing import Any, Literal, Optional, Union
from vllm.config import VllmConfig from vllm.config import VllmConfig
@ -10,11 +10,10 @@ from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
from vllm.inputs.parse import split_enc_dec_inputs from vllm.inputs.parse import split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor from vllm.inputs.preprocess import InputPreprocessor
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs, from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
MultiModalRegistry) from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
from vllm.multimodal.inputs import PlaceholderRange
from vllm.multimodal.processing import EncDecMultiModalProcessor from vllm.multimodal.processing import EncDecMultiModalProcessor
from vllm.multimodal.utils import merge_and_sort_multimodal_metadata from vllm.multimodal.utils import argsort_mm_positions
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.transformers_utils.tokenizer_group import TokenizerGroup
@ -296,57 +295,42 @@ class Processor:
pooling_params = params.clone() pooling_params = params.clone()
# Multimodal related. # Multimodal related.
sorted_mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]] = None sorted_mm_inputs: Optional[list[MultiModalKwargsItem]] = None
sorted_mm_positions: Optional[list[PlaceholderRange]] = None sorted_mm_positions: Optional[list[PlaceholderRange]] = None
sorted_mm_hashes: Optional[list[str]] = None sorted_mm_hashes: Optional[list[str]] = None
if decoder_inputs["type"] == "multimodal": if decoder_inputs["type"] == "multimodal":
decoder_mm_inputs = decoder_inputs["mm_kwargs"] decoder_mm_inputs = decoder_inputs["mm_kwargs"]
decoder_mm_positions = decoder_inputs["mm_placeholders"]
decoder_mm_hashes = decoder_inputs.get("mm_hashes")
# Merge and flatten multimodal placeholders, hashes and inputs # Merge and flatten multimodal placeholders, hashes and inputs
# from dictionaries to lists, and sort them by each item's position # from dictionaries to lists, and sort them by each item's position
# in the input sequence. # in the input sequence.
( sorted_mm_idxs = argsort_mm_positions(decoder_mm_positions)
sorted_item_modalities,
sorted_mm_positions,
sorted_mm_hashes,
) = merge_and_sort_multimodal_metadata(
decoder_inputs["mm_placeholders"],
decoder_inputs["mm_hashes"] if return_mm_hashes else None,
)
# The output of merged multi-modal processor (`decoder_mm_inputs`) sorted_mm_inputs = [
# is a single MultiModalKwargs for all items from all modalities. decoder_mm_inputs.get_item(modality, idx)
# This code flattens kwargs for individual items in a list and for modality, idx in sorted_mm_idxs
# sorts them by each item's position in the input sequence if there ]
# are multiple modalities. sorted_mm_positions = [
unique_modalities = set(sorted_item_modalities) decoder_mm_positions[modality][idx]
if len(unique_modalities) > 1: for modality, idx in sorted_mm_idxs
orig_sorted_mm_inputs = [] ]
used_indices = {modality: 0 for modality in unique_modalities} sorted_mm_hashes = None if decoder_mm_hashes is None else [
decoder_mm_hashes[modality][idx]
for modality in sorted_item_modalities: for modality, idx in sorted_mm_idxs
items = decoder_mm_inputs.get_items(modality) ]
item = items[used_indices[modality]]
orig_sorted_mm_inputs.append(
MultiModalKwargs.from_items([item]))
used_indices[modality] += 1
else:
orig_sorted_mm_inputs = [
MultiModalKwargs.from_items([item]) for item in
decoder_mm_inputs.get_items(sorted_item_modalities[0])
]
if sorted_mm_hashes is not None: if sorted_mm_hashes is not None:
sorted_mm_inputs = self.mm_input_cache_client.get_and_update( sorted_mm_inputs = self.mm_input_cache_client.get_and_update(
orig_sorted_mm_inputs, sorted_mm_hashes) sorted_mm_inputs,
else: sorted_mm_hashes,
sorted_mm_inputs = orig_sorted_mm_inputs )
return decoder_inputs.get("prompt"), EngineCoreRequest( return decoder_inputs.get("prompt"), EngineCoreRequest(
request_id=request_id, request_id=request_id,
prompt_token_ids=decoder_inputs["prompt_token_ids"], prompt_token_ids=decoder_inputs["prompt_token_ids"],
mm_inputs=sorted_mm_inputs, mm_kwargs=sorted_mm_inputs,
mm_hashes=sorted_mm_hashes, mm_hashes=sorted_mm_hashes,
mm_placeholders=sorted_mm_positions, mm_placeholders=sorted_mm_positions,
sampling_params=sampling_params, sampling_params=sampling_params,

View File

@ -5,7 +5,7 @@ import enum
import time import time
from typing import TYPE_CHECKING, Any, Optional, Union from typing import TYPE_CHECKING, Any, Optional, Union
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils import is_list_of from vllm.utils import is_list_of
@ -24,7 +24,7 @@ class Request:
self, self,
request_id: str, request_id: str,
prompt_token_ids: list[int], prompt_token_ids: list[int],
multi_modal_inputs: Optional[list[MultiModalKwargs]], multi_modal_kwargs: Optional[list[MultiModalKwargsItem]],
multi_modal_hashes: Optional[list[str]], multi_modal_hashes: Optional[list[str]],
multi_modal_placeholders: Optional[list[PlaceholderRange]], multi_modal_placeholders: Optional[list[PlaceholderRange]],
sampling_params: Optional[SamplingParams], sampling_params: Optional[SamplingParams],
@ -84,15 +84,15 @@ class Request:
# Multi-modal related # Multi-modal related
self.mm_positions = multi_modal_placeholders or [] self.mm_positions = multi_modal_placeholders or []
self.mm_inputs = multi_modal_inputs or [] self.mm_kwargs = multi_modal_kwargs or []
self.mm_hashes: list[str] = multi_modal_hashes or [] self.mm_hashes: list[str] = multi_modal_hashes or []
self.num_encoder_inputs = len(self.mm_inputs) self.num_encoder_inputs = len(self.mm_kwargs)
self.has_encoder_inputs = self.num_encoder_inputs > 0 self.has_encoder_inputs = self.num_encoder_inputs > 0
# Sanity check # Sanity check
assert len(self.mm_inputs) == len(self.mm_positions) assert len(self.mm_kwargs) == len(self.mm_positions)
if self.mm_hashes: if self.mm_hashes:
assert len(self.mm_inputs) == len(self.mm_hashes) assert len(self.mm_kwargs) == len(self.mm_hashes)
# Read-only views # Read-only views
# Prevent directly appending to these lists since # Prevent directly appending to these lists since
@ -110,16 +110,15 @@ class Request:
@classmethod @classmethod
def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
if request.mm_inputs is not None: if request.mm_kwargs is not None:
assert isinstance(request.mm_inputs, list) assert is_list_of(request.mm_kwargs, MultiModalKwargsItem), (
assert is_list_of(request.mm_inputs, MultiModalKwargs), ( "mm_kwargs was not updated in EngineCore.add_request")
"mm_inputs was not updated in EngineCore.add_request")
return cls( return cls(
request_id=request.request_id, request_id=request.request_id,
client_index=request.client_index, client_index=request.client_index,
prompt_token_ids=request.prompt_token_ids, prompt_token_ids=request.prompt_token_ids,
multi_modal_inputs=request.mm_inputs, multi_modal_kwargs=request.mm_kwargs,
multi_modal_hashes=request.mm_hashes, multi_modal_hashes=request.mm_hashes,
multi_modal_placeholders=request.mm_placeholders, multi_modal_placeholders=request.mm_placeholders,
sampling_params=request.sampling_params, sampling_params=request.sampling_params,

View File

@ -113,6 +113,9 @@ class MsgpackEncoder:
int(v) if v is not None else None int(v) if v is not None else None
for v in (obj.start, obj.stop, obj.step)) for v in (obj.start, obj.stop, obj.step))
if isinstance(obj, MultiModalKwargsItem):
return self._encode_mm_item(obj)
if isinstance(obj, MultiModalKwargs): if isinstance(obj, MultiModalKwargs):
mm: MultiModalKwargs = obj mm: MultiModalKwargs = obj
if not mm.modalities: if not mm.modalities:
@ -120,17 +123,12 @@ class MsgpackEncoder:
return dict(mm) return dict(mm)
# ignore the main dict, it will be re-indexed. # ignore the main dict, it will be re-indexed.
# Encode a list of MultiModalKwargsItems as plain dicts
# + special handling for .field.
# Any tensors *not* indexed by modality will be ignored. # Any tensors *not* indexed by modality will be ignored.
return [[{ return [
"modality": elem.modality, self._encode_mm_item(item)
"key": elem.key, for itemlist in mm._items_by_modality.values()
"data": self._encode_nested_tensors(elem.data), for item in itemlist
"field": self._encode_mm_field(elem.field), ]
} for elem in item.values()]
for itemlist in mm._items_by_modality.values()
for item in itemlist]
if isinstance(obj, UtilityResult): if isinstance(obj, UtilityResult):
result = obj.result result = obj.result
@ -192,6 +190,23 @@ class MsgpackEncoder:
dtype = str(obj.dtype).removeprefix("torch.") dtype = str(obj.dtype).removeprefix("torch.")
return dtype, obj.shape, data return dtype, obj.shape, data
def _encode_mm_item(self,
item: MultiModalKwargsItem) -> list[dict[str, Any]]:
return [self._encode_mm_field_elem(elem) for elem in item.values()]
def _encode_mm_field_elem(self,
elem: MultiModalFieldElem) -> dict[str, Any]:
return {
"modality":
elem.modality,
"key":
elem.key,
"data": (None if elem.data is None else
self._encode_nested_tensors(elem.data)),
"field":
self._encode_mm_field(elem.field),
}
def _encode_nested_tensors(self, nt: NestedTensors) -> Any: def _encode_nested_tensors(self, nt: NestedTensors) -> Any:
if isinstance(nt, torch.Tensor): if isinstance(nt, torch.Tensor):
return self._encode_tensor(nt) return self._encode_tensor(nt)
@ -250,6 +265,8 @@ class MsgpackDecoder:
return self._decode_tensor(obj) return self._decode_tensor(obj)
if t is slice: if t is slice:
return slice(*obj) return slice(*obj)
if issubclass(t, MultiModalKwargsItem):
return self._decode_mm_item(obj)
if issubclass(t, MultiModalKwargs): if issubclass(t, MultiModalKwargs):
if isinstance(obj, list): if isinstance(obj, list):
return MultiModalKwargs.from_items( return MultiModalKwargs.from_items(
@ -311,15 +328,18 @@ class MsgpackDecoder:
# Convert back to proper shape & type # Convert back to proper shape & type
return arr.view(torch_dtype).view(shape) return arr.view(torch_dtype).view(shape)
def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]: def _decode_mm_items(self, obj: list[Any]) -> list[MultiModalKwargsItem]:
return [self._decode_mm_item(v) for v in obj] return [self._decode_mm_item(v) for v in obj]
def _decode_mm_item(self, obj: list) -> MultiModalKwargsItem: def _decode_mm_item(self, obj: list[Any]) -> MultiModalKwargsItem:
return MultiModalKwargsItem.from_elems( return MultiModalKwargsItem.from_elems(
[self._decode_mm_field_elem(v) for v in obj]) [self._decode_mm_field_elem(v) for v in obj])
def _decode_mm_field_elem(self, obj: dict) -> MultiModalFieldElem: def _decode_mm_field_elem(self, obj: dict[str,
obj["data"] = self._decode_nested_tensors(obj["data"]) Any]) -> MultiModalFieldElem:
if obj["data"] is not None:
obj["data"] = self._decode_nested_tensors(obj["data"])
# Reconstruct the field processor using MultiModalFieldConfig # Reconstruct the field processor using MultiModalFieldConfig
factory_meth_name, *field_args = obj["field"] factory_meth_name, *field_args = obj["field"]
factory_meth = getattr(MultiModalFieldConfig, factory_meth_name) factory_meth = getattr(MultiModalFieldConfig, factory_meth_name)

View File

@ -7,9 +7,11 @@ from typing import Optional, cast
import numpy as np import numpy as np
import torch import torch
from typing_extensions import deprecated
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.multimodal.inputs import (MultiModalKwargs, MultiModalKwargsItem,
PlaceholderRange)
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams, SamplingType from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import swap_dict_values from vllm.utils import swap_dict_values
@ -29,7 +31,7 @@ class CachedRequestState:
req_id: str req_id: str
prompt_token_ids: list[int] prompt_token_ids: list[int]
mm_inputs: list[MultiModalKwargs] mm_kwargs: list[MultiModalKwargsItem]
mm_positions: list[PlaceholderRange] mm_positions: list[PlaceholderRange]
sampling_params: Optional[SamplingParams] sampling_params: Optional[SamplingParams]
pooling_params: Optional[PoolingParams] pooling_params: Optional[PoolingParams]
@ -51,6 +53,13 @@ class CachedRequestState:
def num_tokens(self) -> int: def num_tokens(self) -> int:
return self.num_prompt_tokens + len(self.output_token_ids) return self.num_prompt_tokens + len(self.output_token_ids)
# Temporary back-compatibility shim for plugins that define their own
# model runner and still read the old `mm_inputs` attribute.
@property
@deprecated("`mm_inputs` is superseded by `mm_kwargs` and will be "
            "removed in v0.13. Please use `mm_kwargs` instead.")
def mm_inputs(self) -> list[MultiModalKwargs]:
    """Deprecated view of ``mm_kwargs`` in the old per-input format.

    Each entry of ``self.mm_kwargs`` is wrapped in its own
    ``MultiModalKwargs`` via ``MultiModalKwargs.from_items([item])``.
    """
    legacy_inputs = []
    for kwargs_item in self.mm_kwargs:
        legacy_inputs.append(MultiModalKwargs.from_items([kwargs_item]))
    return legacy_inputs
def get_token_id(self, idx: int) -> int: def get_token_id(self, idx: int) -> int:
if idx < self.num_prompt_tokens: if idx < self.num_prompt_tokens:
return self.prompt_token_ids[idx] return self.prompt_token_ids[idx]

View File

@ -40,9 +40,9 @@ from vllm.model_executor.models.interfaces import (is_mixture_of_experts,
from vllm.model_executor.models.interfaces_base import ( from vllm.model_executor.models.interfaces_base import (
VllmModelForPooling, is_pooling_model, is_text_generation_model) VllmModelForPooling, is_pooling_model, is_text_generation_model)
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs, from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargsItem,
PlaceholderRange) PlaceholderRange)
from vllm.multimodal.utils import group_mm_inputs_by_modality from vllm.multimodal.utils import group_mm_kwargs_by_modality
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingType from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.sequence import IntermediateTensors, PoolerOutput
@ -478,7 +478,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.requests[req_id] = CachedRequestState( self.requests[req_id] = CachedRequestState(
req_id=req_id, req_id=req_id,
prompt_token_ids=new_req_data.prompt_token_ids, prompt_token_ids=new_req_data.prompt_token_ids,
mm_inputs=new_req_data.mm_inputs, mm_kwargs=new_req_data.mm_kwargs,
mm_positions=new_req_data.mm_positions, mm_positions=new_req_data.mm_positions,
sampling_params=sampling_params, sampling_params=sampling_params,
pooling_params=pooling_params, pooling_params=pooling_params,
@ -496,18 +496,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
second_per_grid_ts = [] second_per_grid_ts = []
audio_feature_lengths = [] audio_feature_lengths = []
use_audio_in_video = False use_audio_in_video = False
for mm_input in self.requests[req_id].mm_inputs: for item in self.requests[req_id].mm_kwargs:
mm_input = item.require_data()
if mm_input.get("image_grid_thw") is not None: if mm_input.get("image_grid_thw") is not None:
image_grid_thw.extend( image_grid_thw.append(
mm_input["image_grid_thw"].tolist()) mm_input["image_grid_thw"].tolist())
if mm_input.get("video_grid_thw") is not None: if mm_input.get("video_grid_thw") is not None:
video_grid_thw.extend( video_grid_thw.append(
mm_input["video_grid_thw"].tolist()) mm_input["video_grid_thw"].tolist())
if mm_input.get("second_per_grid_ts") is not None: if mm_input.get("second_per_grid_ts") is not None:
second_per_grid_ts.extend( second_per_grid_ts.append(
mm_input["second_per_grid_ts"]) mm_input["second_per_grid_ts"])
if mm_input.get("audio_feature_lengths") is not None: if mm_input.get("audio_feature_lengths") is not None:
audio_feature_lengths.extend( audio_feature_lengths.append(
mm_input["audio_feature_lengths"]) mm_input["audio_feature_lengths"])
if mm_input.get("use_audio_in_video") is True: if mm_input.get("use_audio_in_video") is True:
use_audio_in_video = True use_audio_in_video = True
@ -624,14 +625,23 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
) -> BatchedTensorInputs: ) -> BatchedTensorInputs:
if self.is_multimodal_raw_input_supported: # noqa: SIM102 if self.is_multimodal_raw_input_supported: # noqa: SIM102
if scheduler_output: if scheduler_output:
multi_modal_kwargs_list = list[MultiModalKwargs]() mm_kwargs = list[MultiModalKwargsItem]()
for req in scheduler_output.scheduled_new_reqs: for req in scheduler_output.scheduled_new_reqs:
req_mm_inputs = req.mm_inputs req_mm_kwargs = req.mm_kwargs
if not isinstance(req_mm_inputs, list): if not isinstance(req_mm_kwargs, list):
req_mm_inputs = list(req_mm_inputs) req_mm_kwargs = list(req_mm_kwargs)
multi_modal_kwargs_list.extend(req_mm_inputs) mm_kwargs.extend(req_mm_kwargs)
return MultiModalKwargs.batch(multi_modal_kwargs_list) # Input all modalities at once
mm_kwargs_combined: BatchedTensorInputs = {}
for _, _, mm_kwargs_group in group_mm_kwargs_by_modality(
mm_kwargs,
device=self.device,
pin_memory=self.pin_memory,
):
mm_kwargs_combined.update(mm_kwargs_group)
return mm_kwargs_combined
return {} return {}
@ -1146,13 +1156,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return return
# Batch the multi-modal inputs. # Batch the multi-modal inputs.
mm_inputs = list[MultiModalKwargs]() mm_kwargs = list[MultiModalKwargsItem]()
req_ids_pos = list[tuple[str, int, PlaceholderRange]]() req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
for req_id, encoder_input_ids in scheduled_encoder_inputs.items(): for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
req_state = self.requests[req_id] req_state = self.requests[req_id]
for mm_input_id in encoder_input_ids: for mm_input_id in encoder_input_ids:
mm_inputs.append(req_state.mm_inputs[mm_input_id]) mm_kwargs.append(req_state.mm_kwargs[mm_input_id])
req_ids_pos.append( req_ids_pos.append(
(req_id, mm_input_id, req_state.mm_positions[mm_input_id])) (req_id, mm_input_id, req_state.mm_positions[mm_input_id]))
@ -1163,17 +1173,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# in the same batch while still being able to benefit from batching # in the same batch while still being able to benefit from batching
# multimodal inputs. The proper solution should be reordering the # multimodal inputs. The proper solution should be reordering the
# encoder outputs. # encoder outputs.
grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_inputs)
encoder_outputs = [] encoder_outputs = []
for grouped_mm_inputs in grouped_mm_inputs_list: for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
batched_mm_inputs = MultiModalKwargs.batch( mm_kwargs,
grouped_mm_inputs, pin_memory=self.pin_memory)
batched_mm_inputs = MultiModalKwargs.as_kwargs(
batched_mm_inputs,
device=self.device, device=self.device,
) pin_memory=self.pin_memory,
):
# Run the encoder. # Run the encoder.
# `curr_group_outputs` is either of the following: # `curr_group_outputs` is either of the following:
# 1. A tensor of shape (num_items, feature_size, hidden_size) # 1. A tensor of shape (num_items, feature_size, hidden_size)
@ -1182,11 +1187,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# (feature_size, hidden_size) in case the feature size is dynamic # (feature_size, hidden_size) in case the feature size is dynamic
# depending on the input multimodal items. # depending on the input multimodal items.
curr_group_outputs = self.model.get_multimodal_embeddings( curr_group_outputs = self.model.get_multimodal_embeddings(
**batched_mm_inputs) **mm_kwargs_group)
sanity_check_mm_encoder_outputs( sanity_check_mm_encoder_outputs(
curr_group_outputs, curr_group_outputs,
expected_num_items=len(grouped_mm_inputs), expected_num_items=num_items,
) )
for output in curr_group_outputs: for output in curr_group_outputs:
@ -1553,17 +1558,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
input_ids = None input_ids = None
inputs_embeds = self.inputs_embeds[:num_input_tokens] inputs_embeds = self.inputs_embeds[:num_input_tokens]
model_mm_kwargs = self._extract_mm_kwargs(scheduler_output) model_kwargs = {
model_kwargs = self._init_model_kwargs(num_scheduled_tokens) **self._init_model_kwargs(num_scheduled_tokens),
**self._extract_mm_kwargs(scheduler_output),
}
else: else:
# For text-only models, we use token ids as input. # For text-only models, we use token ids as input.
# While it is possible to use embeddings as input just like the # While it is possible to use embeddings as input just like the
# multimodal models, it is not desirable for performance since # multimodal models, it is not desirable for performance since
# then the embedding layer is not included in the CUDA graph. # then the embedding layer is not included in the CUDA graph.
input_ids = self.input_ids[:num_input_tokens] input_ids = self.input_ids[:num_input_tokens]
model_kwargs = self._init_model_kwargs(num_input_tokens)
inputs_embeds = None inputs_embeds = None
model_mm_kwargs = {} model_kwargs = self._init_model_kwargs(num_input_tokens)
if self.uses_mrope: if self.uses_mrope:
positions = self.mrope_positions[:, :num_input_tokens] positions = self.mrope_positions[:, :num_input_tokens]
else: else:
@ -1596,10 +1602,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
positions=positions, positions=positions,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
**MultiModalKwargs.as_kwargs(
model_mm_kwargs,
device=self.device,
),
**model_kwargs, **model_kwargs,
) )
@ -2196,14 +2198,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Result in the maximum GPU consumption of the model # Result in the maximum GPU consumption of the model
dummy_mm_item = dummy_mm_data.get_item(modality=modality, item_index=0) dummy_mm_item = dummy_mm_data.get_item(modality=modality, item_index=0)
dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item])
batched_dummy_mm_inputs = MultiModalKwargs.batch([dummy_mm_kwargs] * return next(mm_kwargs_group
max_items_per_batch) for _, _, mm_kwargs_group in group_mm_kwargs_by_modality(
return MultiModalKwargs.as_kwargs( [dummy_mm_item] * max_items_per_batch,
batched_dummy_mm_inputs, device=self.device,
device=self.device, pin_memory=self.pin_memory,
) ))
@torch.inference_mode() @torch.inference_mode()
def _dummy_run( def _dummy_run(
@ -2269,15 +2270,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
with self.maybe_dummy_run_with_lora(self.lora_config, with self.maybe_dummy_run_with_lora(self.lora_config,
num_scheduled_tokens): num_scheduled_tokens):
model_kwargs = self._init_model_kwargs(num_tokens)
if self.supports_mm_inputs: if self.supports_mm_inputs:
input_ids = None input_ids = None
inputs_embeds = self.inputs_embeds[:num_tokens] inputs_embeds = self.inputs_embeds[:num_tokens]
model_mm_kwargs = self._dummy_mm_kwargs(num_reqs) model_kwargs = {
**self._init_model_kwargs(num_tokens),
**self._dummy_mm_kwargs(num_reqs),
}
else: else:
input_ids = self.input_ids[:num_tokens] input_ids = self.input_ids[:num_tokens]
inputs_embeds = None inputs_embeds = None
model_mm_kwargs = {} model_kwargs = self._init_model_kwargs(num_tokens)
if self.uses_mrope: if self.uses_mrope:
positions = self.mrope_positions[:, :num_tokens] positions = self.mrope_positions[:, :num_tokens]
@ -2307,10 +2310,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
positions=positions, positions=positions,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
**MultiModalKwargs.as_kwargs(
model_mm_kwargs,
device=self.device,
),
**model_kwargs, **model_kwargs,
) )

View File

@ -32,9 +32,9 @@ from vllm.model_executor.models.interfaces import supports_transcription
from vllm.model_executor.models.interfaces_base import ( from vllm.model_executor.models.interfaces_base import (
is_pooling_model, is_text_generation_model) is_pooling_model, is_text_generation_model)
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs, from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargsItem,
PlaceholderRange) PlaceholderRange)
from vllm.multimodal.utils import group_mm_inputs_by_modality from vllm.multimodal.utils import group_mm_kwargs_by_modality
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
from vllm.utils import (LayerBlockType, cdiv, is_pin_memory_available, from vllm.utils import (LayerBlockType, cdiv, is_pin_memory_available,
@ -394,7 +394,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.requests[req_id] = CachedRequestState( self.requests[req_id] = CachedRequestState(
req_id=req_id, req_id=req_id,
prompt_token_ids=new_req_data.prompt_token_ids, prompt_token_ids=new_req_data.prompt_token_ids,
mm_inputs=new_req_data.mm_inputs, mm_kwargs=new_req_data.mm_kwargs,
mm_positions=new_req_data.mm_positions, mm_positions=new_req_data.mm_positions,
sampling_params=sampling_params, sampling_params=sampling_params,
pooling_params=None, pooling_params=None,
@ -842,13 +842,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return return
# Batch the multi-modal inputs. # Batch the multi-modal inputs.
mm_inputs = list[MultiModalKwargs]() mm_kwargs = list[MultiModalKwargsItem]()
req_ids_pos = list[tuple[str, int, PlaceholderRange]]() req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
for req_id, encoder_input_ids in scheduled_encoder_inputs.items(): for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
req_state = self.requests[req_id] req_state = self.requests[req_id]
for mm_input_id in encoder_input_ids: for mm_input_id in encoder_input_ids:
mm_inputs.append(req_state.mm_inputs[mm_input_id]) mm_kwargs.append(req_state.mm_kwargs[mm_input_id])
req_ids_pos.append( req_ids_pos.append(
(req_id, mm_input_id, req_state.mm_positions[mm_input_id])) (req_id, mm_input_id, req_state.mm_positions[mm_input_id]))
@ -859,16 +859,12 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# in the same batch while still being able to benefit from batching # in the same batch while still being able to benefit from batching
# multimodal inputs. The proper solution should be reordering the # multimodal inputs. The proper solution should be reordering the
# encoder outputs. # encoder outputs.
grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_inputs)
encoder_outputs = [] encoder_outputs = []
for grouped_mm_inputs in grouped_mm_inputs_list: for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs) mm_kwargs,
batched_mm_inputs = MultiModalKwargs.as_kwargs(
batched_mm_inputs,
device=self.device, device=self.device,
) pin_memory=self.pin_memory,
):
# Run the encoder. # Run the encoder.
# `curr_group_outputs` is either of the following: # `curr_group_outputs` is either of the following:
# 1. A tensor of shape (num_items, feature_size, hidden_size) # 1. A tensor of shape (num_items, feature_size, hidden_size)
@ -878,12 +874,12 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# depending on the input multimodal items. # depending on the input multimodal items.
xm.mark_step() xm.mark_step()
curr_group_outputs = self.model.get_multimodal_embeddings( curr_group_outputs = self.model.get_multimodal_embeddings(
**batched_mm_inputs) **mm_kwargs_group)
xm.mark_step() xm.mark_step()
sanity_check_mm_encoder_outputs( sanity_check_mm_encoder_outputs(
curr_group_outputs, curr_group_outputs,
expected_num_items=len(grouped_mm_inputs), expected_num_items=num_items,
) )
if isinstance(curr_group_outputs, torch.Tensor): if isinstance(curr_group_outputs, torch.Tensor):
@ -1823,14 +1819,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Result in the maximum GPU consumption of the model # Result in the maximum GPU consumption of the model
dummy_mm_item = dummy_mm_data.get_item(modality=modality, item_index=0) dummy_mm_item = dummy_mm_data.get_item(modality=modality, item_index=0)
dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item])
batched_dummy_mm_inputs = MultiModalKwargs.batch([dummy_mm_kwargs] * return next(grouped_mm_kwargs
max_items_per_batch) for _, _, grouped_mm_kwargs in group_mm_kwargs_by_modality(
return MultiModalKwargs.as_kwargs( [dummy_mm_item] * max_items_per_batch,
batched_dummy_mm_inputs, device=self.device,
device=self.device, pin_memory=self.pin_memory,
) ))
def _get_req_paddings(min_req_size: int, max_req_size: int) -> list[int]: def _get_req_paddings(min_req_size: int, max_req_size: int) -> list[int]: