From d7543862bd6b5a47496a53b3c3625ad5110215f8 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 2 May 2025 18:29:25 +0800 Subject: [PATCH] [Misc] Rename assets for testing (#17575) Signed-off-by: DarkLight1337 --- .../qwen2_5_omni/only_thinker.py | 4 +- examples/offline_inference/vision_language.py | 2 +- tests/conftest.py | 56 +++++++------------ .../multimodal/generation/test_common.py | 21 +++---- .../multimodal/generation/test_florence2.py | 4 +- .../generation/test_granite_speech.py | 9 +-- .../multimodal/generation/test_interleaved.py | 2 +- .../multimodal/generation/test_mllama.py | 16 +++--- .../multimodal/generation/test_qwen2_vl.py | 2 +- .../multimodal/generation/test_ultravox.py | 44 ++++++++------- .../generation/vlm_utils/builders.py | 10 ++-- .../generation/vlm_utils/model_utils.py | 8 +-- .../generation/vlm_utils/runners.py | 11 ++-- .../multimodal/generation/vlm_utils/types.py | 6 +- .../multimodal/pooling/test_intern_vit.py | 4 +- .../multimodal/processing/test_h2ovl.py | 4 +- .../multimodal/processing/test_idefics3.py | 4 +- .../multimodal/processing/test_internvl.py | 4 +- .../multimodal/processing/test_llama4.py | 4 +- .../processing/test_minimax_vl_01.py | 4 +- .../multimodal/processing/test_phi3v.py | 4 +- .../multimodal/processing/test_phi4mm.py | 4 +- .../multimodal/processing/test_qwen2_vl.py | 4 +- .../multimodal/processing/test_smolvlm.py | 4 +- tests/models/quantization/test_awq.py | 4 +- vllm/assets/audio.py | 12 +++- vllm/assets/image.py | 4 +- vllm/assets/video.py | 21 +++++-- 28 files changed, 145 insertions(+), 131 deletions(-) diff --git a/examples/offline_inference/qwen2_5_omni/only_thinker.py b/examples/offline_inference/qwen2_5_omni/only_thinker.py index 031e924d34cf..c2c28d5ae6ae 100644 --- a/examples/offline_inference/qwen2_5_omni/only_thinker.py +++ b/examples/offline_inference/qwen2_5_omni/only_thinker.py @@ -47,7 +47,7 @@ def get_mixed_modalities_query() -> QueryResult: "image": ImageAsset("cherry_blossom").pil_image.convert("RGB"), "video": - VideoAsset(name="sample_demo_1", num_frames=16).np_ndarrays, + VideoAsset(name="baby_reading", num_frames=16).np_ndarrays, }, }, limit_mm_per_prompt={ @@ -65,7 +65,7 @@ def get_use_audio_in_video_query() -> QueryResult: "<|im_start|>user\n<|vision_bos|><|VIDEO|><|vision_eos|>" f"{question}<|im_end|>\n" f"<|im_start|>assistant\n") - asset = VideoAsset(name="sample_demo_1", num_frames=16) + asset = VideoAsset(name="baby_reading", num_frames=16) audio = asset.get_audio(sampling_rate=16000) assert not envs.VLLM_USE_V1, ("V1 does not support use_audio_in_video. 
" "Please launch this example with " diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 6cd2a774a03d..aca11f5c50ba 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -1109,7 +1109,7 @@ def get_multi_modal_input(args): if args.modality == "video": # Input video and question - video = VideoAsset(name="sample_demo_1", + video = VideoAsset(name="baby_reading", num_frames=args.num_frames).np_ndarrays vid_questions = ["Why is this video funny?"] diff --git a/tests/conftest.py b/tests/conftest.py index 571cca8eeccb..b1b4af86fab7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,9 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 - import json import os import tempfile -from collections import UserList from enum import Enum from typing import Any, Callable, Optional, TypedDict, TypeVar, Union @@ -58,16 +56,12 @@ def _read_prompts(filename: str) -> list[str]: return prompts -class _ImageAssetPrompts(TypedDict): +class ImageAssetPrompts(TypedDict): stop_sign: str cherry_blossom: str -class _ImageAssetsBase(UserList[ImageAsset]): - pass - - -class _ImageAssets(_ImageAssetsBase): +class ImageTestAssets(list[ImageAsset]): def __init__(self) -> None: super().__init__([ @@ -75,7 +69,7 @@ class _ImageAssets(_ImageAssetsBase): ImageAsset("cherry_blossom"), ]) - def prompts(self, prompts: _ImageAssetPrompts) -> list[str]: + def prompts(self, prompts: ImageAssetPrompts) -> list[str]: """ Convenience method to define the prompt for each test image. @@ -85,35 +79,27 @@ class _ImageAssets(_ImageAssetsBase): return [prompts["stop_sign"], prompts["cherry_blossom"]] -class _VideoAssetPrompts(TypedDict): - sample_demo_1: str +class VideoAssetPrompts(TypedDict): + baby_reading: str -class _VideoAssetsBase(UserList[VideoAsset]): - pass - - -class _VideoAssets(_VideoAssetsBase): +class VideoTestAssets(list[VideoAsset]): def __init__(self) -> None: super().__init__([ - VideoAsset("sample_demo_1"), + VideoAsset("baby_reading"), ]) - def prompts(self, prompts: _VideoAssetPrompts) -> list[str]: - return [prompts["sample_demo_1"]] + def prompts(self, prompts: VideoAssetPrompts) -> list[str]: + return [prompts["baby_reading"]] -class _AudioAssetPrompts(TypedDict): +class AudioAssetPrompts(TypedDict): mary_had_lamb: str winning_call: str -class _AudioAssetsBase(UserList[AudioAsset]): - pass - - -class _AudioAssets(_AudioAssetsBase): +class AudioTestAssets(list[AudioAsset]): def __init__(self) -> None: super().__init__([ @@ -121,16 +107,16 @@ class _AudioAssets(_AudioAssetsBase): AudioAsset("winning_call"), ]) - def prompts(self, prompts: _AudioAssetPrompts) -> list[str]: + def prompts(self, prompts: AudioAssetPrompts) -> list[str]: return [prompts["mary_had_lamb"], prompts["winning_call"]] -IMAGE_ASSETS = _ImageAssets() -"""Singleton instance of :class:`_ImageAssets`.""" -VIDEO_ASSETS = _VideoAssets() -"""Singleton instance of :class:`_VideoAssets`.""" -AUDIO_ASSETS = _AudioAssets() -"""Singleton instance of :class:`_AudioAssets`.""" +IMAGE_ASSETS = ImageTestAssets() +"""Singleton instance of :class:`ImageTestAssets`.""" +VIDEO_ASSETS = VideoTestAssets() +"""Singleton instance of :class:`VideoTestAssets`.""" +AUDIO_ASSETS = AudioTestAssets() +"""Singleton instance of :class:`AudioTestAssets`.""" @pytest.fixture(scope="function", autouse=True) @@ -278,17 +264,17 @@ def example_long_prompts() -> list[str]: @pytest.fixture(scope="session") -def image_assets() -> _ImageAssets: +def image_assets() -> 
ImageTestAssets: return IMAGE_ASSETS @pytest.fixture(scope="session") -def video_assets() -> _VideoAssets: +def video_assets() -> VideoTestAssets: return VIDEO_ASSETS @pytest.fixture(scope="session") -def audio_assets() -> _AudioAssets: +def audio_assets() -> AudioTestAssets: return AUDIO_ASSETS diff --git a/tests/models/multimodal/generation/test_common.py b/tests/models/multimodal/generation/test_common.py index b21c80bef927..44cdd6f44aa9 100644 --- a/tests/models/multimodal/generation/test_common.py +++ b/tests/models/multimodal/generation/test_common.py @@ -13,8 +13,8 @@ from transformers import AutoModelForImageTextToText, AutoModelForVision2Seq from vllm.platforms import current_platform from vllm.utils import identity -from ....conftest import (IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets, - _VideoAssets) +from ....conftest import (IMAGE_ASSETS, HfRunner, ImageTestAssets, + VideoTestAssets, VllmRunner) from ....utils import (create_new_process_for_each_test, large_gpu_mark, multi_gpu_marks) from ...utils import check_outputs_equal @@ -691,7 +691,7 @@ def test_single_image_models(tmp_path: PosixPath, model_type: str, test_case: ExpandableVLMTestArgs, hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - image_assets: _ImageAssets, monkeypatch): + image_assets: ImageTestAssets, monkeypatch): if model_type in REQUIRES_V0_MODELS: monkeypatch.setenv("VLLM_USE_V1", "0") model_test_info = VLM_TEST_SETTINGS[model_type] @@ -716,7 +716,7 @@ def test_multi_image_models(tmp_path: PosixPath, model_type: str, test_case: ExpandableVLMTestArgs, hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - image_assets: _ImageAssets, monkeypatch): + image_assets: ImageTestAssets, monkeypatch): if model_type in REQUIRES_V0_MODELS: monkeypatch.setenv("VLLM_USE_V1", "0") model_test_info = VLM_TEST_SETTINGS[model_type] @@ -741,7 +741,7 @@ def test_image_embedding_models(model_type: str, test_case: ExpandableVLMTestArgs, hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - image_assets: _ImageAssets, monkeypatch): + image_assets: ImageTestAssets, monkeypatch): if model_type in REQUIRES_V0_MODELS: monkeypatch.setenv("VLLM_USE_V1", "0") model_test_info = VLM_TEST_SETTINGS[model_type] @@ -763,7 +763,7 @@ def test_image_embedding_models(model_type: str, )) def test_video_models(model_type: str, test_case: ExpandableVLMTestArgs, hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - video_assets: _VideoAssets, monkeypatch): + video_assets: VideoTestAssets, monkeypatch): if model_type in REQUIRES_V0_MODELS: monkeypatch.setenv("VLLM_USE_V1", "0") model_test_info = VLM_TEST_SETTINGS[model_type] @@ -814,7 +814,7 @@ def test_single_image_models_heavy(tmp_path: PosixPath, model_type: str, test_case: ExpandableVLMTestArgs, hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - image_assets: _ImageAssets, monkeypatch): + image_assets: ImageTestAssets, monkeypatch): if model_type in REQUIRES_V0_MODELS: monkeypatch.setenv("VLLM_USE_V1", "0") model_test_info = VLM_TEST_SETTINGS[model_type] @@ -840,7 +840,7 @@ def test_multi_image_models_heavy(tmp_path: PosixPath, model_type: str, test_case: ExpandableVLMTestArgs, hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - image_assets: _ImageAssets, monkeypatch): + image_assets: ImageTestAssets, monkeypatch): if model_type in REQUIRES_V0_MODELS: monkeypatch.setenv("VLLM_USE_V1", "0") model_test_info = VLM_TEST_SETTINGS[model_type] @@ -866,7 +866,8 @@ def test_image_embedding_models_heavy(model_type: str, test_case: ExpandableVLMTestArgs, 
hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - image_assets: _ImageAssets, monkeypatch): + image_assets: ImageTestAssets, + monkeypatch): if model_type in REQUIRES_V0_MODELS: monkeypatch.setenv("VLLM_USE_V1", "0") model_test_info = VLM_TEST_SETTINGS[model_type] @@ -889,7 +890,7 @@ def test_image_embedding_models_heavy(model_type: str, def test_video_models_heavy(model_type: str, test_case: ExpandableVLMTestArgs, hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - video_assets: _VideoAssets, monkeypatch): + video_assets: VideoTestAssets, monkeypatch): if model_type in REQUIRES_V0_MODELS: monkeypatch.setenv("VLLM_USE_V1", "0") model_test_info = VLM_TEST_SETTINGS[model_type] diff --git a/tests/models/multimodal/generation/test_florence2.py b/tests/models/multimodal/generation/test_florence2.py index 14b64393bf52..b8225f5f1243 100644 --- a/tests/models/multimodal/generation/test_florence2.py +++ b/tests/models/multimodal/generation/test_florence2.py @@ -9,7 +9,7 @@ from vllm.inputs.data import ExplicitEncoderDecoderPrompt, TextPrompt from vllm.multimodal.image import rescale_image_size from vllm.sequence import SampleLogprobs -from ....conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets +from ....conftest import IMAGE_ASSETS, HfRunner, ImageTestAssets, VllmRunner from ...utils import check_logprobs_close MODELS = ["microsoft/Florence-2-base"] @@ -118,7 +118,7 @@ def run_test( @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) def test_models(hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - image_assets: _ImageAssets, model: str, + image_assets: ImageTestAssets, model: str, size_factors: list[int], dtype: str, max_tokens: int, num_logprobs: int) -> None: images = [asset.pil_image for asset in image_assets] diff --git a/tests/models/multimodal/generation/test_granite_speech.py b/tests/models/multimodal/generation/test_granite_speech.py index 7c14845ec54d..96c444441e3d 100644 --- a/tests/models/multimodal/generation/test_granite_speech.py +++ b/tests/models/multimodal/generation/test_granite_speech.py @@ -9,7 +9,8 @@ from transformers import AutoModelForSpeechSeq2Seq from vllm.lora.request import LoRARequest from vllm.sequence import SampleLogprobs -from ....conftest import HfRunner, PromptAudioInput, VllmRunner, _AudioAssets +from ....conftest import (AudioTestAssets, HfRunner, PromptAudioInput, + VllmRunner) from ...registry import HF_EXAMPLE_MODELS from ...utils import check_logprobs_close @@ -116,9 +117,9 @@ def run_test( @pytest.mark.parametrize("max_model_len", [2048]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [10]) -def test_models(hf_runner, vllm_runner, model: str, audio_assets: _AudioAssets, - dtype: str, max_model_len: int, max_tokens: int, - num_logprobs: int) -> None: +def test_models(hf_runner, vllm_runner, model: str, + audio_assets: AudioTestAssets, dtype: str, max_model_len: int, + max_tokens: int, num_logprobs: int) -> None: model_info = HF_EXAMPLE_MODELS.find_hf_info(model) model_info.check_available_online(on_fail="skip") model_info.check_transformers_version(on_fail="skip") diff --git a/tests/models/multimodal/generation/test_interleaved.py b/tests/models/multimodal/generation/test_interleaved.py index 10052da9b0bd..eec84751e450 100644 --- a/tests/models/multimodal/generation/test_interleaved.py +++ b/tests/models/multimodal/generation/test_interleaved.py @@ -29,7 +29,7 @@ def test_models(vllm_runner, model, dtype: str, max_tokens: int) -> None: 
image_cherry = ImageAsset("cherry_blossom").pil_image.convert("RGB") image_stop = ImageAsset("stop_sign").pil_image.convert("RGB") images = [image_cherry, image_stop] - video = VideoAsset(name="sample_demo_1", num_frames=16).np_ndarrays + video = VideoAsset(name="baby_reading", num_frames=16).np_ndarrays inputs = [ ( diff --git a/tests/models/multimodal/generation/test_mllama.py b/tests/models/multimodal/generation/test_mllama.py index 1e09c8673dc3..99aa3c2d3bd9 100644 --- a/tests/models/multimodal/generation/test_mllama.py +++ b/tests/models/multimodal/generation/test_mllama.py @@ -14,8 +14,8 @@ from vllm.model_executor.models.mllama import MllamaForConditionalGeneration from vllm.multimodal.image import rescale_image_size from vllm.sequence import SampleLogprobs -from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner, - _ImageAssets) +from ....conftest import (IMAGE_ASSETS, HfRunner, ImageTestAssets, + PromptImageInput, VllmRunner) from ....quantization.utils import is_quant_method_supported from ....utils import (create_new_process_for_each_test, large_gpu_test, multi_gpu_test) @@ -90,7 +90,7 @@ def vllm_to_hf_output(vllm_output: tuple[list[int], str, def _get_inputs( - image_assets: _ImageAssets, + image_assets: ImageTestAssets, *, size_factors: Optional[list[float]] = None, sizes: Optional[list[tuple[int, int]]] = None, @@ -126,7 +126,7 @@ def _get_inputs( def run_test( hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - image_assets: _ImageAssets, + image_assets: ImageTestAssets, model: str, *, size_factors: list[float], @@ -143,7 +143,7 @@ def run_test( def run_test( hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - image_assets: _ImageAssets, + image_assets: ImageTestAssets, model: str, *, sizes: list[tuple[int, int]], @@ -159,7 +159,7 @@ def run_test( def run_test( hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - image_assets: _ImageAssets, + image_assets: ImageTestAssets, model: str, *, size_factors: Optional[list[float]] = None, @@ -433,7 +433,7 @@ def test_models_distributed( @pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"), reason='bitsandbytes is not supported on this GPU type.') def test_bnb_regression( - image_assets: _ImageAssets, + image_assets: ImageTestAssets, model: str, dtype: str, max_tokens: int, @@ -473,7 +473,7 @@ def test_bnb_regression( @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [32]) def test_explicit_implicit_prompt( - image_assets: _ImageAssets, + image_assets: ImageTestAssets, model: str, dtype: str, max_tokens: int, diff --git a/tests/models/multimodal/generation/test_qwen2_vl.py b/tests/models/multimodal/generation/test_qwen2_vl.py index 0b27a4caf6eb..6be401b775ec 100644 --- a/tests/models/multimodal/generation/test_qwen2_vl.py +++ b/tests/models/multimodal/generation/test_qwen2_vl.py @@ -50,7 +50,7 @@ IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ }) VIDEO_PROMPTS = VIDEO_ASSETS.prompts({ - "sample_demo_1": + "baby_reading": qwen2_vl_chat_template( VIDEO_PLACEHOLDER, "Describe this video with a short sentence ", diff --git a/tests/models/multimodal/generation/test_ultravox.py b/tests/models/multimodal/generation/test_ultravox.py index 1d7de946a3f8..322d886a593d 100644 --- a/tests/models/multimodal/generation/test_ultravox.py +++ b/tests/models/multimodal/generation/test_ultravox.py @@ -11,13 +11,22 @@ from transformers import AutoModel, AutoTokenizer from vllm.multimodal.audio import resample_audio_librosa from vllm.sequence import SampleLogprobs 
-from ....conftest import HfRunner, VllmRunner, _AudioAssets +from ....conftest import AUDIO_ASSETS, AudioTestAssets, HfRunner, VllmRunner from ....utils import RemoteOpenAIServer from ...registry import HF_EXAMPLE_MODELS from ...utils import check_logprobs_close MODEL_NAME = "fixie-ai/ultravox-v0_5-llama-3_2-1b" +AUDIO_PROMPTS = AUDIO_ASSETS.prompts({ + "mary_had_lamb": + "Transcribe this into English.", + "winning_call": + "What is happening in this audio clip?", +}) + +MULTI_AUDIO_PROMPT = "Describe each of the audios above." + AudioTuple = tuple[np.ndarray, int] VLLM_PLACEHOLDER = "<|audio|>" @@ -31,12 +40,6 @@ CHUNKED_PREFILL_KWARGS = { } -@pytest.fixture(scope="module", params=("mary_had_lamb", "winning_call")) -def audio(request): - from vllm.assets.audio import AudioAsset - return AudioAsset(request.param) - - def params_kwargs_to_cli_args(params_kwargs: dict[str, Any]) -> list[str]: """Convert kwargs to CLI args.""" args = [] @@ -53,7 +56,7 @@ def params_kwargs_to_cli_args(params_kwargs: dict[str, Any]) -> list[str]: pytest.param({}, marks=pytest.mark.cpu_model), pytest.param(CHUNKED_PREFILL_KWARGS), ]) -def server(request, audio_assets: _AudioAssets): +def server(request, audio_assets: AudioTestAssets): args = [ "--dtype", "bfloat16", "--max-model-len", "4096", "--enforce-eager", "--limit-mm-per-prompt", @@ -199,15 +202,19 @@ def run_multi_audio_test( pytest.param({}, marks=pytest.mark.cpu_model), pytest.param(CHUNKED_PREFILL_KWARGS), ]) -def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int, - num_logprobs: int, vllm_kwargs: dict) -> None: +def test_models(hf_runner, vllm_runner, audio_assets: AudioTestAssets, + dtype: str, max_tokens: int, num_logprobs: int, + vllm_kwargs: dict) -> None: + audio_inputs = [( + _get_prompt(1, audio, VLLM_PLACEHOLDER), + _get_prompt(1, audio, HF_PLACEHOLDER), + audio.audio_and_sample_rate, + ) for audio in audio_assets] - vllm_prompt = _get_prompt(1, "Describe the audio above.", VLLM_PLACEHOLDER) - hf_prompt = _get_prompt(1, "Describe the audio above.", HF_PLACEHOLDER) run_test( hf_runner, vllm_runner, - [(vllm_prompt, hf_prompt, audio.audio_and_sample_rate)], + audio_inputs, MODEL_NAME, dtype=dtype, max_tokens=max_tokens, @@ -224,13 +231,12 @@ def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int, pytest.param({}, marks=pytest.mark.cpu_model), pytest.param(CHUNKED_PREFILL_KWARGS), ]) -def test_models_with_multiple_audios(vllm_runner, audio_assets: _AudioAssets, - dtype: str, max_tokens: int, - num_logprobs: int, +def test_models_with_multiple_audios(vllm_runner, + audio_assets: AudioTestAssets, dtype: str, + max_tokens: int, num_logprobs: int, vllm_kwargs: dict) -> None: - vllm_prompt = _get_prompt(len(audio_assets), - "Describe each of the audios above.", + vllm_prompt = _get_prompt(len(audio_assets), MULTI_AUDIO_PROMPT, VLLM_PLACEHOLDER) run_multi_audio_test( vllm_runner, @@ -245,7 +251,7 @@ def test_models_with_multiple_audios(vllm_runner, audio_assets: _AudioAssets, @pytest.mark.asyncio -async def test_online_serving(client, audio_assets: _AudioAssets): +async def test_online_serving(client, audio_assets: AudioTestAssets): """Exercises online serving with/without chunked prefill enabled.""" messages = [{ diff --git a/tests/models/multimodal/generation/vlm_utils/builders.py b/tests/models/multimodal/generation/vlm_utils/builders.py index bf5f87ebf984..e3ba955a96a6 100644 --- a/tests/models/multimodal/generation/vlm_utils/builders.py +++ b/tests/models/multimodal/generation/vlm_utils/builders.py @@ 
-11,7 +11,7 @@ from vllm.multimodal.image import rescale_image_size from vllm.multimodal.video import (rescale_video_size, resize_video, sample_frames_from_video) -from .....conftest import _ImageAssets, _VideoAssets +from .....conftest import ImageTestAssets, VideoTestAssets from .types import (SINGLE_IMAGE_BASE_PROMPTS, TEST_IMG_PLACEHOLDER, TEST_VIDEO_PLACEHOLDER, VIDEO_BASE_PROMPT, ImageSizeWrapper, SizeType, VLMTestInfo) @@ -69,7 +69,7 @@ def get_model_prompts(base_prompts: Iterable[str], def build_single_image_inputs_from_test_info( test_info: VLMTestInfo, - image_assets: _ImageAssets, + image_assets: ImageTestAssets, size_wrapper: ImageSizeWrapper, tmp_path: Optional[PosixPath] = None): if test_info.prompt_formatter is None: @@ -116,7 +116,7 @@ def build_single_image_inputs(images, model_prompts, def build_multi_image_inputs_from_test_info( test_info: VLMTestInfo, - image_assets: _ImageAssets, + image_assets: ImageTestAssets, size_wrapper: ImageSizeWrapper, tmp_path: Optional[PosixPath] = None): if test_info.prompt_formatter is None: @@ -159,7 +159,7 @@ def build_multi_image_inputs(image_lists, model_prompts, def build_embedding_inputs_from_test_info( test_info: VLMTestInfo, - image_assets: _ImageAssets, + image_assets: ImageTestAssets, size_wrapper: ImageSizeWrapper, ): # These conditions will always be true if invoked through filtering, @@ -192,7 +192,7 @@ def build_embedding_inputs_from_test_info( def build_video_inputs_from_test_info( test_info: VLMTestInfo, - video_assets: _VideoAssets, + video_assets: VideoTestAssets, size_wrapper: ImageSizeWrapper, num_frames: int, ): diff --git a/tests/models/multimodal/generation/vlm_utils/model_utils.py b/tests/models/multimodal/generation/vlm_utils/model_utils.py index c856fb198b32..aa9d3901fa36 100644 --- a/tests/models/multimodal/generation/vlm_utils/model_utils.py +++ b/tests/models/multimodal/generation/vlm_utils/model_utils.py @@ -16,7 +16,7 @@ from transformers import (AutoConfig, AutoTokenizer, BatchFeature, from vllm.sequence import SampleLogprobs from vllm.transformers_utils.tokenizer import patch_padding_side -from .....conftest import HfRunner, ImageAsset, _ImageAssets +from .....conftest import HfRunner, ImageAsset, ImageTestAssets from .types import RunnerOutput @@ -238,14 +238,14 @@ def minimax_vl_01_hf_output(hf_output: RunnerOutput, ####### Functions for converting image assets to embeddings -def get_llava_embeddings(image_assets: _ImageAssets): +def get_llava_embeddings(image_assets: ImageTestAssets): return [asset.image_embeds for asset in image_assets] ####### Prompt path encoders for models that need models on disk def qwen_prompt_path_encoder( - tmp_path: PosixPath, prompt: str, assets: Union[list[ImageAsset], - _ImageAssets]) -> str: + tmp_path: PosixPath, prompt: str, + assets: Union[list[ImageAsset], ImageTestAssets]) -> str: """Given a temporary dir path, export one or more image assets into the tempdir & replace its contents with the local path to the string so that the HF version of Qwen-VL can resolve the path and load the image in its diff --git a/tests/models/multimodal/generation/vlm_utils/runners.py b/tests/models/multimodal/generation/vlm_utils/runners.py index 023df5f16188..34753121ea90 100644 --- a/tests/models/multimodal/generation/vlm_utils/runners.py +++ b/tests/models/multimodal/generation/vlm_utils/runners.py @@ -4,7 +4,8 @@ types / modalities. 
""" from pathlib import PosixPath -from .....conftest import HfRunner, VllmRunner, _ImageAssets, _VideoAssets +from .....conftest import (HfRunner, ImageTestAssets, VideoTestAssets, + VllmRunner) from . import builders, core from .types import ExpandableVLMTestArgs, VLMTestInfo @@ -14,7 +15,7 @@ def run_single_image_test(*, tmp_path: PosixPath, model_test_info: VLMTestInfo, test_case: ExpandableVLMTestArgs, hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - image_assets: _ImageAssets): + image_assets: ImageTestAssets): assert test_case.size_wrapper is not None inputs = builders.build_single_image_inputs_from_test_info( model_test_info, image_assets, test_case.size_wrapper, tmp_path) @@ -37,7 +38,7 @@ def run_multi_image_test(*, tmp_path: PosixPath, model_test_info: VLMTestInfo, test_case: ExpandableVLMTestArgs, hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - image_assets: _ImageAssets): + image_assets: ImageTestAssets): assert test_case.size_wrapper is not None inputs = builders.build_multi_image_inputs_from_test_info( model_test_info, image_assets, test_case.size_wrapper, tmp_path) @@ -60,7 +61,7 @@ def run_embedding_test(*, model_test_info: VLMTestInfo, test_case: ExpandableVLMTestArgs, hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - image_assets: _ImageAssets): + image_assets: ImageTestAssets): assert test_case.size_wrapper is not None inputs, vllm_embeddings = builders.build_embedding_inputs_from_test_info( model_test_info, image_assets, test_case.size_wrapper) @@ -86,7 +87,7 @@ def run_video_test( test_case: ExpandableVLMTestArgs, hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - video_assets: _VideoAssets, + video_assets: VideoTestAssets, ): assert test_case.size_wrapper is not None assert test_case.num_video_frames is not None diff --git a/tests/models/multimodal/generation/vlm_utils/types.py b/tests/models/multimodal/generation/vlm_utils/types.py index 1ae61ea47229..56629323394d 100644 --- a/tests/models/multimodal/generation/vlm_utils/types.py +++ b/tests/models/multimodal/generation/vlm_utils/types.py @@ -15,7 +15,7 @@ from vllm.config import TaskOption from vllm.sequence import SampleLogprobs from vllm.transformers_utils.tokenizer import AnyTokenizer -from .....conftest import IMAGE_ASSETS, HfRunner, ImageAsset, _ImageAssets +from .....conftest import IMAGE_ASSETS, HfRunner, ImageAsset, ImageTestAssets from ....utils import check_logprobs_close # meta image tag; will be replaced by the appropriate tag for the model @@ -85,7 +85,7 @@ class VLMTestInfo(NamedTuple): # Function for converting ImageAssets to image embeddings; # We need to define this explicitly for embedding tests - convert_assets_to_embeddings: Optional[Callable[[_ImageAssets], + convert_assets_to_embeddings: Optional[Callable[[ImageTestAssets], torch.Tensor]] = None # Exposed options for vLLM runner; we change these in a several tests, @@ -141,7 +141,7 @@ class VLMTestInfo(NamedTuple): # for Qwen-VL, which requires encoding the image path / url into the prompt # for HF runner prompt_path_encoder: Optional[ - Callable[[PosixPath, str, Union[list[ImageAsset], _ImageAssets]], + Callable[[PosixPath, str, Union[list[ImageAsset], ImageTestAssets]], str]] = None # noqa: E501 # Allows configuring a test to run with custom inputs diff --git a/tests/models/multimodal/pooling/test_intern_vit.py b/tests/models/multimodal/pooling/test_intern_vit.py index c15913b4225b..038405ded9eb 100644 --- a/tests/models/multimodal/pooling/test_intern_vit.py +++ 
b/tests/models/multimodal/pooling/test_intern_vit.py @@ -7,7 +7,7 @@ from transformers import AutoConfig, AutoModel, CLIPImageProcessor from vllm.distributed import cleanup_dist_env_and_memory -from ....conftest import _ImageAssets +from ....conftest import ImageTestAssets # we use snapshot_download to prevent conflicts between # dynamic_module and trust_remote_code for hf_runner @@ -15,7 +15,7 @@ DOWNLOAD_PATTERN = ["*.json", "*.py", "*.safetensors", "*.txt", "*.model"] def run_intern_vit_test( - image_assets: _ImageAssets, + image_assets: ImageTestAssets, model_id: str, *, dtype: str, diff --git a/tests/models/multimodal/processing/test_h2ovl.py b/tests/models/multimodal/processing/test_h2ovl.py index 709a686577f3..37142b6dd36f 100644 --- a/tests/models/multimodal/processing/test_h2ovl.py +++ b/tests/models/multimodal/processing/test_h2ovl.py @@ -11,7 +11,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.image import rescale_image_size from vllm.multimodal.processing import BaseMultiModalProcessor -from ....conftest import _ImageAssets +from ....conftest import ImageTestAssets from ...utils import build_model_context @@ -137,7 +137,7 @@ def _run_check( @pytest.mark.parametrize("kwargs_on_init", [True, False]) def test_processor_override( model_id: str, - image_assets: _ImageAssets, + image_assets: ImageTestAssets, size_factors: list[int], min_dynamic_patch: int, max_dynamic_patch: int, diff --git a/tests/models/multimodal/processing/test_idefics3.py b/tests/models/multimodal/processing/test_idefics3.py index f5b5cf6b5ba9..c35ce2f6ab29 100644 --- a/tests/models/multimodal/processing/test_idefics3.py +++ b/tests/models/multimodal/processing/test_idefics3.py @@ -5,7 +5,7 @@ from transformers import Idefics3Config from vllm.multimodal import MULTIMODAL_REGISTRY -from ....conftest import _ImageAssets +from ....conftest import ImageTestAssets from ...utils import build_model_context @@ -21,7 +21,7 @@ from ...utils import build_model_context @pytest.mark.parametrize("num_imgs", [1, 2]) @pytest.mark.parametrize("kwargs_on_init", [True, False]) def test_processor_override( - image_assets: _ImageAssets, + image_assets: ImageTestAssets, model_id: str, mm_processor_kwargs: dict[str, object], expected_toks_per_img: int, diff --git a/tests/models/multimodal/processing/test_internvl.py b/tests/models/multimodal/processing/test_internvl.py index 5ac47ecc5cc1..7ec81197a3db 100644 --- a/tests/models/multimodal/processing/test_internvl.py +++ b/tests/models/multimodal/processing/test_internvl.py @@ -11,7 +11,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.image import rescale_image_size from vllm.multimodal.processing import BaseMultiModalProcessor -from ....conftest import _ImageAssets +from ....conftest import ImageTestAssets from ...utils import build_model_context @@ -94,7 +94,7 @@ def _run_check( @pytest.mark.parametrize("kwargs_on_init", [True, False]) def test_processor_override( model_id: str, - image_assets: _ImageAssets, + image_assets: ImageTestAssets, size_factors: list[int], min_dynamic_patch: int, max_dynamic_patch: int, diff --git a/tests/models/multimodal/processing/test_llama4.py b/tests/models/multimodal/processing/test_llama4.py index 2bfc2785feb6..614f17dbbeda 100644 --- a/tests/models/multimodal/processing/test_llama4.py +++ b/tests/models/multimodal/processing/test_llama4.py @@ -6,7 +6,7 @@ import pytest from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.transformers_utils.tokenizer import encode_tokens -from ....conftest import 
_ImageAssets +from ....conftest import ImageTestAssets from ...utils import build_model_context @@ -17,7 +17,7 @@ from ...utils import build_model_context @pytest.mark.parametrize("disable_mm_preprocessor_cache", [True, False]) @pytest.mark.parametrize("tokenized_prompt", [True, False]) def test_processor_override( - image_assets: _ImageAssets, + image_assets: ImageTestAssets, model_id: str, mm_processor_kwargs: dict, num_imgs: int, diff --git a/tests/models/multimodal/processing/test_minimax_vl_01.py b/tests/models/multimodal/processing/test_minimax_vl_01.py index 10de28ab54ce..9bd2b9887294 100644 --- a/tests/models/multimodal/processing/test_minimax_vl_01.py +++ b/tests/models/multimodal/processing/test_minimax_vl_01.py @@ -7,14 +7,14 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.parse import ImageSize from vllm.multimodal.processing import BaseMultiModalProcessor -from ....conftest import _ImageAssets +from ....conftest import ImageTestAssets from ...utils import build_model_context @pytest.mark.parametrize("model_id", ["MiniMaxAI/MiniMax-VL-01"]) @pytest.mark.parametrize("num_imgs", [1, 2]) def test_processor_override( - image_assets: _ImageAssets, + image_assets: ImageTestAssets, model_id: str, num_imgs: int, ): diff --git a/tests/models/multimodal/processing/test_phi3v.py b/tests/models/multimodal/processing/test_phi3v.py index ed0d04c5c5f5..b53351544c45 100644 --- a/tests/models/multimodal/processing/test_phi3v.py +++ b/tests/models/multimodal/processing/test_phi3v.py @@ -4,7 +4,7 @@ import pytest from vllm.multimodal import MULTIMODAL_REGISTRY -from ....conftest import _ImageAssets +from ....conftest import ImageTestAssets from ...utils import build_model_context @@ -22,7 +22,7 @@ from ...utils import build_model_context @pytest.mark.parametrize("num_imgs", [1, 2]) @pytest.mark.parametrize("kwargs_on_init", [True, False]) def test_processor_override( - image_assets: _ImageAssets, + image_assets: ImageTestAssets, model_id: str, mm_processor_kwargs: dict[str, int], expected_toks_per_img: int, diff --git a/tests/models/multimodal/processing/test_phi4mm.py b/tests/models/multimodal/processing/test_phi4mm.py index 797986adba4a..c6e272650e08 100644 --- a/tests/models/multimodal/processing/test_phi4mm.py +++ b/tests/models/multimodal/processing/test_phi4mm.py @@ -4,7 +4,7 @@ import pytest from vllm.multimodal import MULTIMODAL_REGISTRY -from ....conftest import _ImageAssets +from ....conftest import ImageTestAssets from ...utils import build_model_context @@ -22,7 +22,7 @@ from ...utils import build_model_context @pytest.mark.parametrize("num_imgs", [1, 2]) @pytest.mark.parametrize("kwargs_on_init", [True, False]) def test_processor_override( - image_assets: _ImageAssets, + image_assets: ImageTestAssets, model_id: str, mm_processor_kwargs: dict[str, int], expected_toks_per_img: int, diff --git a/tests/models/multimodal/processing/test_qwen2_vl.py b/tests/models/multimodal/processing/test_qwen2_vl.py index d8c2ca414d41..02abe1ca8b02 100644 --- a/tests/models/multimodal/processing/test_qwen2_vl.py +++ b/tests/models/multimodal/processing/test_qwen2_vl.py @@ -4,7 +4,7 @@ import pytest from vllm.multimodal import MULTIMODAL_REGISTRY -from ....conftest import _ImageAssets +from ....conftest import ImageTestAssets from ...utils import build_model_context @@ -19,7 +19,7 @@ from ...utils import build_model_context @pytest.mark.parametrize("num_imgs", [1, 2]) @pytest.mark.parametrize("kwargs_on_init", [True, False]) def test_processor_override( - image_assets: 
_ImageAssets, + image_assets: ImageTestAssets, model_id: str, mm_processor_kwargs: dict[str, object], expected_toks_per_img: int, diff --git a/tests/models/multimodal/processing/test_smolvlm.py b/tests/models/multimodal/processing/test_smolvlm.py index 56edc58a71ba..224d1bcedb96 100644 --- a/tests/models/multimodal/processing/test_smolvlm.py +++ b/tests/models/multimodal/processing/test_smolvlm.py @@ -5,7 +5,7 @@ from transformers import SmolVLMConfig from vllm.multimodal import MULTIMODAL_REGISTRY -from ....conftest import _ImageAssets +from ....conftest import ImageTestAssets from ...utils import build_model_context @@ -21,7 +21,7 @@ from ...utils import build_model_context @pytest.mark.parametrize("num_imgs", [1, 2]) @pytest.mark.parametrize("kwargs_on_init", [True, False]) def test_processor_override( - image_assets: _ImageAssets, + image_assets: ImageTestAssets, model_id: str, mm_processor_kwargs: dict[str, object], expected_toks_per_img: int, diff --git a/tests/models/quantization/test_awq.py b/tests/models/quantization/test_awq.py index c02c3d90e345..597c8e48fb64 100644 --- a/tests/models/quantization/test_awq.py +++ b/tests/models/quantization/test_awq.py @@ -7,7 +7,7 @@ import torch from vllm.multimodal.image import rescale_image_size -from ...conftest import IMAGE_ASSETS, VllmRunner, _ImageAssets +from ...conftest import IMAGE_ASSETS, ImageTestAssets, VllmRunner from ..utils import check_logprobs_close HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ @@ -20,7 +20,7 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ def run_awq_test( vllm_runner: type[VllmRunner], - image_assets: _ImageAssets, + image_assets: ImageTestAssets, source_model: str, quant_model: str, *, diff --git a/vllm/assets/audio.py b/vllm/assets/audio.py index 0203dc092a71..a21eb7f599fa 100644 --- a/vllm/assets/audio.py +++ b/vllm/assets/audio.py @@ -18,19 +18,25 @@ except ImportError: ASSET_DIR = "multimodal_asset" +AudioAssetName = Literal["winning_call", "mary_had_lamb"] + @dataclass(frozen=True) class AudioAsset: - name: Literal["winning_call", "mary_had_lamb"] + name: AudioAssetName + + @property + def filename(self) -> str: + return f"{self.name}.ogg" @property def audio_and_sample_rate(self) -> tuple[npt.NDArray, float]: - audio_path = get_vllm_public_assets(filename=f"{self.name}.ogg", + audio_path = get_vllm_public_assets(filename=self.filename, s3_prefix=ASSET_DIR) return librosa.load(audio_path, sr=None) def get_local_path(self) -> Path: - return get_vllm_public_assets(filename=f"{self.name}.ogg", + return get_vllm_public_assets(filename=self.filename, s3_prefix=ASSET_DIR) @property diff --git a/vllm/assets/image.py b/vllm/assets/image.py index 2b1d258da9c7..d8cca9b74edd 100644 --- a/vllm/assets/image.py +++ b/vllm/assets/image.py @@ -10,10 +10,12 @@ from .base import get_vllm_public_assets VLM_IMAGES_DIR = "vision_model_images" +ImageAssetName = Literal["stop_sign", "cherry_blossom"] + @dataclass(frozen=True) class ImageAsset: - name: Literal["stop_sign", "cherry_blossom"] + name: ImageAssetName @property def pil_image(self) -> Image.Image: diff --git a/vllm/assets/video.py b/vllm/assets/video.py index fc3d47341b30..bf06746a9ff6 100644 --- a/vllm/assets/video.py +++ b/vllm/assets/video.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from functools import lru_cache -from typing import Literal, Optional +from typing import ClassVar, Literal, Optional import cv2 import numpy as np @@ -76,20 +76,31 @@ def video_to_pil_images_list(path: str, ] +VideoAssetName = Literal["baby_reading"] + + @dataclass(frozen=True) 
class VideoAsset: - name: Literal["sample_demo_1"] + name: VideoAssetName num_frames: int = -1 + _NAME_TO_FILE: ClassVar[dict[VideoAssetName, str]] = { + "baby_reading": "sample_demo_1.mp4", + } + + @property + def filename(self) -> str: + return self._NAME_TO_FILE[self.name] + @property def pil_images(self) -> list[Image.Image]: - video_path = download_video_asset(self.name + ".mp4") + video_path = download_video_asset(self.filename) ret = video_to_pil_images_list(video_path, self.num_frames) return ret @property def np_ndarrays(self) -> npt.NDArray: - video_path = download_video_asset(self.name + ".mp4") + video_path = download_video_asset(self.filename) ret = video_to_ndarrays(video_path, self.num_frames) return ret @@ -99,5 +110,5 @@ class VideoAsset: See also: examples/offline_inference/qwen2_5_omni/only_thinker.py """ - video_path = download_video_asset(self.name + ".mp4") + video_path = download_video_asset(self.filename) return librosa.load(video_path, sr=sampling_rate)[0]
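
For anyone trying the renamed assets locally, here is a minimal usage sketch (illustrative only, not part of the patch). It exercises only the names and properties visible in the diff above: the video asset is now addressed as "baby_reading" (its `filename` property still resolves to `sample_demo_1.mp4` on the asset server), while image and audio asset names are unchanged. The 16 kHz sampling rate mirrors the example script touched by this patch; treat any other detail as an assumption.

    # Usage sketch based on the properties introduced/kept in this patch.
    from vllm.assets.audio import AudioAsset
    from vllm.assets.image import ImageAsset
    from vllm.assets.video import VideoAsset

    # Image assets keep their existing names ("stop_sign", "cherry_blossom").
    image = ImageAsset("cherry_blossom").pil_image.convert("RGB")

    # Video asset renamed to "baby_reading"; filename maps back to the
    # original "sample_demo_1.mp4" via the new _NAME_TO_FILE table.
    video_asset = VideoAsset(name="baby_reading", num_frames=16)
    frames = video_asset.np_ndarrays                      # numpy frame array
    audio_track = video_asset.get_audio(sampling_rate=16000)

    # Audio assets are unchanged; the file name now comes from .filename.
    waveform, sample_rate = AudioAsset("mary_had_lamb").audio_and_sample_rate

Test code should likewise import the renamed fixture classes (ImageTestAssets, VideoTestAssets, AudioTestAssets) from tests/conftest.py instead of the old underscored _ImageAssets / _VideoAssets / _AudioAssets, as shown in the conftest.py hunks above.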