From 19b927e52df8400084df1c8116af7d6f0a5f5d15 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 13 Aug 2025 22:18:07 +0800 Subject: [PATCH] [Core] Use individual MM items in P0/P1 cache and model runner (#22570) Signed-off-by: DarkLight1337 --- tests/multimodal/test_utils.py | 235 +++++++------------ tests/v1/core/test_kv_cache_utils.py | 48 ++-- tests/v1/core/test_prefix_caching.py | 31 ++- tests/v1/core/test_scheduler.py | 21 +- tests/v1/core/utils.py | 19 +- tests/v1/engine/test_engine_core.py | 2 +- tests/v1/engine/test_engine_core_client.py | 2 +- tests/v1/engine/test_output_processor.py | 10 +- tests/v1/kv_connector/unit/utils.py | 2 +- tests/v1/tpu/worker/test_tpu_model_runner.py | 2 +- tests/v1/worker/test_gpu_input_batch.py | 2 +- tests/v1/worker/test_gpu_model_runner.py | 2 +- vllm/multimodal/inputs.py | 141 +++++++++-- vllm/multimodal/utils.py | 135 ++++++----- vllm/v1/core/sched/output.py | 10 +- vllm/v1/engine/__init__.py | 6 +- vllm/v1/engine/core.py | 7 +- vllm/v1/engine/mm_input_cache.py | 78 +++--- vllm/v1/engine/processor.py | 64 ++--- vllm/v1/request.py | 21 +- vllm/v1/serial_utils.py | 48 ++-- vllm/v1/worker/gpu_input_batch.py | 13 +- vllm/v1/worker/gpu_model_runner.py | 97 ++++---- vllm/v1/worker/tpu_model_runner.py | 39 ++- 24 files changed, 549 insertions(+), 486 deletions(-) diff --git a/tests/multimodal/test_utils.py b/tests/multimodal/test_utils.py index 3fdf7e33ca5fc..41f4773a11c8d 100644 --- a/tests/multimodal/test_utils.py +++ b/tests/multimodal/test_utils.py @@ -5,7 +5,7 @@ import base64 import mimetypes import os from tempfile import NamedTemporaryFile, TemporaryDirectory -from typing import TYPE_CHECKING, NamedTuple, Optional +from typing import TYPE_CHECKING, NamedTuple import numpy as np import pytest @@ -19,14 +19,12 @@ from vllm.distributed.parallel_state import (init_distributed_environment, initialize_model_parallel) from vllm.multimodal.image import convert_image_mode from vllm.multimodal.inputs import PlaceholderRange -from vllm.multimodal.utils import (MediaConnector, - merge_and_sort_multimodal_metadata, +from vllm.multimodal.utils import (MediaConnector, argsort_mm_positions, run_dp_sharded_vision_model) from vllm.platforms import current_platform from vllm.utils import get_open_port, update_environment_variables if TYPE_CHECKING: - from vllm.multimodal.hasher import MultiModalHashDict from vllm.multimodal.inputs import MultiModalPlaceholderDict # Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA) @@ -178,19 +176,17 @@ async def test_fetch_video_http(video_url: str, num_frames: int): assert metadata_sync == metadata_async -# Used for the next two tests related to `merge_and_sort_multimodal_metadata`. +# Used for `test_argsort_mm_positions`. 
class TestCase(NamedTuple): mm_positions: "MultiModalPlaceholderDict" - mm_hashes: Optional["MultiModalHashDict"] - expected_modalities: list[str] - expected_ranges: list[PlaceholderRange] - expected_hashes: Optional[list[str]] + expected_modality_idxs: list[tuple[str, int]] -def test_merge_and_sort_multimodal_metadata(): +def test_argsort_mm_positions(): test_cases = [ - # Single modality should return result as is but flattened + # Single modality + ## Internally sorted TestCase( mm_positions={ "image": [ @@ -198,34 +194,27 @@ def test_merge_and_sort_multimodal_metadata(): PlaceholderRange(offset=3, length=2), ] }, - mm_hashes={"image": ["hash1", "hash2"]}, - expected_modalities=["image", "image"], - expected_ranges=[ - PlaceholderRange(offset=0, length=2), - PlaceholderRange(offset=3, length=2), + expected_modality_idxs=[ + ("image", 0), + ("image", 1), ], - expected_hashes=["hash1", "hash2"], ), - - # Single modality without hashes return None for mm hash. + ## Internally unsorted TestCase( mm_positions={ "image": [ + PlaceholderRange(offset=3, length=2), PlaceholderRange(offset=0, length=2), - PlaceholderRange(offset=2, length=2), ] }, - mm_hashes=None, - expected_modalities=["image", "image"], - expected_ranges=[ - PlaceholderRange(offset=0, length=2), - PlaceholderRange(offset=2, length=2), + expected_modality_idxs=[ + ("image", 1), + ("image", 0), ], - expected_hashes=None, ), - # Multiple modalities with hashes should return sorted modalities - # and flattened ranges and hashes. + # Two modalities + ## Internally sorted TestCase( mm_positions={ "image": [ @@ -237,47 +226,54 @@ def test_merge_and_sort_multimodal_metadata(): PlaceholderRange(offset=2, length=3), ] }, - mm_hashes={ - "image": ["image_hash1", "image_hash2"], - "audio": ["audio_hash1", "audio_hash2"], - }, - expected_modalities=["audio", "audio", "image", "image"], - expected_ranges=[ - PlaceholderRange(offset=0, length=2), - PlaceholderRange(offset=2, length=3), - PlaceholderRange(offset=7, length=4), - PlaceholderRange(offset=11, length=5), - ], - expected_hashes=[ - "audio_hash1", "audio_hash2", "image_hash1", "image_hash2" + expected_modality_idxs=[ + ("audio", 0), + ("audio", 1), + ("image", 0), + ("image", 1), ], ), - - # Multiple modalities without hashes should return sorted modalities - # and flattened ranges and None. 
+ ## Interleaved, internally sorted TestCase( mm_positions={ "image": [ - PlaceholderRange(offset=7, length=4), - PlaceholderRange(offset=11, length=5), + PlaceholderRange(offset=0, length=4), + PlaceholderRange(offset=8, length=2), ], "audio": [ - PlaceholderRange(offset=0, length=2), - PlaceholderRange(offset=2, length=3), + PlaceholderRange(offset=5, length=2), + PlaceholderRange(offset=11, length=4), ] }, - mm_hashes=None, - expected_modalities=["audio", "audio", "image", "image"], - expected_ranges=[ - PlaceholderRange(offset=0, length=2), - PlaceholderRange(offset=2, length=3), - PlaceholderRange(offset=7, length=4), - PlaceholderRange(offset=11, length=5), + expected_modality_idxs=[ + ("image", 0), + ("audio", 0), + ("image", 1), + ("audio", 1), + ], + ), + ## Interleaved, internally unsorted + TestCase( + mm_positions={ + "image": [ + PlaceholderRange(offset=8, length=2), + PlaceholderRange(offset=0, length=4), + ], + "audio": [ + PlaceholderRange(offset=11, length=4), + PlaceholderRange(offset=5, length=2), + ] + }, + expected_modality_idxs=[ + ("image", 1), + ("audio", 1), + ("image", 0), + ("audio", 0), ], - expected_hashes=None, ), # Three modalities + ## Internally sorted TestCase( mm_positions={ "image": [ @@ -293,72 +289,16 @@ def test_merge_and_sort_multimodal_metadata(): PlaceholderRange(offset=12, length=6), ] }, - mm_hashes={ - "image": ["image_hash1", "image_hash2"], - "audio": ["audio_hash1"], - "video": ["video_hash1", "video_hash2", "video_hash3"] - }, - expected_modalities=[ - "audio", "video", "video", "video", "image", "image" - ], - expected_ranges=[ - PlaceholderRange(offset=0, length=2), - PlaceholderRange(offset=3, length=4), - PlaceholderRange(offset=7, length=5), - PlaceholderRange(offset=12, length=6), - PlaceholderRange(offset=15, length=7), - PlaceholderRange(offset=22, length=8), - ], - expected_hashes=[ - "audio_hash1", "video_hash1", "video_hash2", "video_hash3", - "image_hash1", "image_hash2" + expected_modality_idxs=[ + ("audio", 0), + ("video", 0), + ("video", 1), + ("video", 2), + ("image", 0), + ("image", 1), ], ), - ] - - for (mm_positions, mm_hashes, expected_modalities, expected_ranges, - expected_hashes) in test_cases: - modalities, ranges, hashes = merge_and_sort_multimodal_metadata( - mm_positions, mm_hashes) - - assert modalities == expected_modalities - assert ranges == expected_ranges - assert hashes == expected_hashes - - -def test_merge_and_sort_multimodal_metadata_with_interleaving(): - - test_cases = [ - - #
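For reviewers, a minimal self-contained sketch of the ordering behavior that `test_argsort_mm_positions` exercises: flatten each modality's placeholders into (modality, item_index) pairs and order them by their offset in the prompt. This is an illustrative stand-in only (the local `PlaceholderRange` dataclass and the function name below are assumptions made for the example), not the actual `vllm.multimodal.utils.argsort_mm_positions` implementation from this patch.

    from dataclasses import dataclass


    @dataclass
    class PlaceholderRange:  # stand-in for vllm.multimodal.inputs.PlaceholderRange
        offset: int
        length: int


    def argsort_mm_positions_sketch(
        mm_positions: dict[str, list[PlaceholderRange]],
    ) -> list[tuple[str, int]]:
        """Return (modality, item_idx) pairs ordered by placeholder offset."""
        flat = [
            (item.offset, modality, idx)
            for modality, items in mm_positions.items()
            for idx, item in enumerate(items)
        ]
        # Sorting by offset yields the prompt-order interleaving of all items.
        return [(modality, idx) for _, modality, idx in sorted(flat)]


    # Mirrors the "Interleaved, internally unsorted" test case above.
    positions = {
        "image": [PlaceholderRange(8, 2), PlaceholderRange(0, 4)],
        "audio": [PlaceholderRange(11, 4), PlaceholderRange(5, 2)],
    }
    assert argsort_mm_positions_sketch(positions) == [
        ("image", 1), ("audio", 1), ("image", 0), ("audio", 0),
    ]

Under this reading, the new tests simply pin down prompt-order traversal of per-modality item indices, which is what lets the P0/P1 cache and model runner operate on individual multimodal items rather than flattened, pre-merged metadata.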