diff --git a/tests/models/multimodal/generation/test_common.py b/tests/models/multimodal/generation/test_common.py
index dead2edc4fa3..d51a03dfea7e 100644
--- a/tests/models/multimodal/generation/test_common.py
+++ b/tests/models/multimodal/generation/test_common.py
@@ -8,14 +8,14 @@ from collections import defaultdict
 from pathlib import PosixPath
 
 import pytest
-from transformers import (AutoModelForImageTextToText,
+from transformers import (AutoModel, AutoModelForImageTextToText,
                           AutoModelForTextToWaveform, AutoModelForVision2Seq)
 
 from vllm.platforms import current_platform
 from vllm.utils import identity
 
-from ....conftest import (IMAGE_ASSETS, HfRunner, ImageTestAssets,
-                          VideoTestAssets, VllmRunner)
+from ....conftest import (IMAGE_ASSETS, AudioTestAssets, HfRunner,
+                          ImageTestAssets, VideoTestAssets, VllmRunner)
 from ....utils import (create_new_process_for_each_test, large_gpu_mark,
                        multi_gpu_marks)
 from ...utils import check_outputs_equal
@@ -158,6 +158,17 @@ VLM_TEST_SETTINGS = {
         image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
         marks=[pytest.mark.core_model, pytest.mark.cpu_model],
     ),
+    "ultravox": VLMTestInfo(
+        models=["fixie-ai/ultravox-v0_5-llama-3_2-1b"],
+        test_type=VLMTestType.AUDIO,
+        prompt_formatter=lambda audio_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{audio_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",  # noqa: E501
+        audio_idx_to_prompt=lambda idx: "<|audio|>",
+        max_model_len=4096,
+        max_num_seqs=2,
+        auto_cls=AutoModel,
+        hf_output_post_proc=model_utils.ultravox_trunc_hf_output,
+        marks=[pytest.mark.core_model, pytest.mark.cpu_model],
+    ),
     #### Extended model tests
     "aria": VLMTestInfo(
         models=["rhymes-ai/Aria"],
@@ -393,7 +404,6 @@ VLM_TEST_SETTINGS = {
                 formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n",  # noqa: E501
             ),
             limit_mm_per_prompt={"video": 4},
-            runner_mm_key="videos",
         )],
     ),
     "llava_next_video": VLMTestInfo(
@@ -706,6 +716,7 @@ VLM_TEST_SETTINGS = _mark_splits(VLM_TEST_SETTINGS, num_groups=2)
 #    - multi-image
 #    - image embeddings
 #    - video
+#    - audio
 #    - custom inputs
 @pytest.mark.parametrize(
     "model_type,test_case",
@@ -803,6 +814,28 @@ def test_video_models(model_type: str, test_case: ExpandableVLMTestArgs,
     )
 
 
+@pytest.mark.parametrize(
+    "model_type,test_case",
+    get_parametrized_options(
+        VLM_TEST_SETTINGS,
+        test_type=VLMTestType.AUDIO,
+        create_new_process_for_each_test=False,
+    ))
+def test_audio_models(model_type: str, test_case: ExpandableVLMTestArgs,
+                      hf_runner: type[HfRunner], vllm_runner: type[VllmRunner],
+                      audio_assets: AudioTestAssets, monkeypatch):
+    if model_type in REQUIRES_V0_MODELS:
+        monkeypatch.setenv("VLLM_USE_V1", "0")
+    model_test_info = VLM_TEST_SETTINGS[model_type]
+    runners.run_audio_test(
+        model_test_info=model_test_info,
+        test_case=test_case,
+        hf_runner=hf_runner,
+        vllm_runner=vllm_runner,
+        audio_assets=audio_assets,
+    )
+
+
 @pytest.mark.parametrize(
     "model_type,test_case",
     get_parametrized_options(
@@ -930,6 +963,29 @@ def test_video_models_heavy(model_type: str, test_case: ExpandableVLMTestArgs,
     )
 
 
+@pytest.mark.parametrize(
+    "model_type,test_case",
+    get_parametrized_options(
+        VLM_TEST_SETTINGS,
+        test_type=VLMTestType.AUDIO,
+        create_new_process_for_each_test=True,
+    ))
+def test_audio_models_heavy(model_type: str, test_case: ExpandableVLMTestArgs,
+                            hf_runner: type[HfRunner],
+                            vllm_runner: type[VllmRunner],
+                            audio_assets: AudioTestAssets, monkeypatch):
+    if model_type in REQUIRES_V0_MODELS:
+        monkeypatch.setenv("VLLM_USE_V1", "0")
+    model_test_info = VLM_TEST_SETTINGS[model_type]
+    runners.run_audio_test(
+        model_test_info=model_test_info,
+        test_case=test_case,
+        hf_runner=hf_runner,
+        vllm_runner=vllm_runner,
+        audio_assets=audio_assets,
+    )
+
+
 @pytest.mark.parametrize(
     "model_type,test_case",
     get_parametrized_options(
diff --git a/tests/models/multimodal/generation/test_ultravox.py b/tests/models/multimodal/generation/test_ultravox.py
index 322d886a593d..2c8a06688ca0 100644
--- a/tests/models/multimodal/generation/test_ultravox.py
+++ b/tests/models/multimodal/generation/test_ultravox.py
@@ -1,20 +1,16 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import json
-from typing import Any, Optional
+from typing import Any
 
 import numpy as np
 import pytest
 import pytest_asyncio
-from transformers import AutoModel, AutoTokenizer
+from transformers import AutoTokenizer
 
-from vllm.multimodal.audio import resample_audio_librosa
-from vllm.sequence import SampleLogprobs
-
-from ....conftest import AUDIO_ASSETS, AudioTestAssets, HfRunner, VllmRunner
+from ....conftest import AUDIO_ASSETS, AudioTestAssets, VllmRunner
 from ....utils import RemoteOpenAIServer
 from ...registry import HF_EXAMPLE_MODELS
-from ...utils import check_logprobs_close
 
 MODEL_NAME = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
 
@@ -88,79 +84,6 @@ def _get_prompt(audio_count, question, placeholder):
                                          add_generation_prompt=True)
 
 
-def vllm_to_hf_output(vllm_output: tuple[list[int], str,
-                                         Optional[SampleLogprobs]],
-                      model: str):
-    """Sanitize vllm output to be comparable with hf output."""
-    output_ids, output_str, out_logprobs = vllm_output
-
-    tokenizer = AutoTokenizer.from_pretrained(model)
-    eos_token_id = tokenizer.eos_token_id
-
-    hf_output_ids = output_ids[:]
-    hf_output_str = output_str
-    if hf_output_ids[-1] == eos_token_id:
-        hf_output_str = hf_output_str + tokenizer.decode(eos_token_id)
-
-    return hf_output_ids, hf_output_str, out_logprobs
-
-
-def run_test(
-    hf_runner: type[HfRunner],
-    vllm_runner: type[VllmRunner],
-    prompts_and_audios: list[tuple[str, str, AudioTuple]],
-    model: str,
-    *,
-    dtype: str,
-    max_tokens: int,
-    num_logprobs: int,
-    **kwargs,
-):
-    """Inference result should be the same between hf and vllm."""
-    model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
-    model_info.check_available_online(on_fail="skip")
-    model_info.check_transformers_version(on_fail="skip")
-
-    # NOTE: take care of the order. run vLLM first, and then run HF.
-    # vLLM needs a fresh new process without cuda initialization.
-    # if we run HF first, the cuda initialization will be done and it
-    # will hurt multiprocessing backend with fork method (the default method).
-
-    with vllm_runner(model, dtype=dtype, enforce_eager=True,
-                     **kwargs) as vllm_model:
-        vllm_outputs_per_audio = [
-            vllm_model.generate_greedy_logprobs([vllm_prompt],
-                                                max_tokens,
-                                                num_logprobs=num_logprobs,
-                                                audios=[audio])
-            for vllm_prompt, _, audio in prompts_and_audios
-        ]
-
-    with hf_runner(model, dtype=dtype, auto_cls=AutoModel) as hf_model:
-        hf_outputs_per_audio = [
-            hf_model.generate_greedy_logprobs_limit(
-                [hf_prompt],
-                max_tokens,
-                num_logprobs=num_logprobs,
-                audios=[(resample_audio_librosa(audio[0],
-                                                orig_sr=audio[1],
-                                                target_sr=16000), 16000)])
-            for _, hf_prompt, audio in prompts_and_audios
-        ]
-
-    for hf_outputs, vllm_outputs in zip(hf_outputs_per_audio,
-                                        vllm_outputs_per_audio):
-        check_logprobs_close(
-            outputs_0_lst=hf_outputs,
-            outputs_1_lst=[
-                vllm_to_hf_output(vllm_output, model)
-                for vllm_output in vllm_outputs
-            ],
-            name_0="hf",
-            name_1="vllm",
-        )
-
-
 def run_multi_audio_test(
     vllm_runner: type[VllmRunner],
     prompts_and_audios: list[tuple[str, list[AudioTuple]]],
@@ -194,35 +117,6 @@ def run_multi_audio_test(
     assert all(tokens for tokens, *_ in vllm_outputs)
 
 
-@pytest.mark.core_model
-@pytest.mark.parametrize("dtype", ["bfloat16"])
-@pytest.mark.parametrize("max_tokens", [128])
-@pytest.mark.parametrize("num_logprobs", [5])
-@pytest.mark.parametrize("vllm_kwargs", [
-    pytest.param({}, marks=pytest.mark.cpu_model),
-    pytest.param(CHUNKED_PREFILL_KWARGS),
-])
-def test_models(hf_runner, vllm_runner, audio_assets: AudioTestAssets,
-                dtype: str, max_tokens: int, num_logprobs: int,
-                vllm_kwargs: dict) -> None:
-    audio_inputs = [(
-        _get_prompt(1, audio, VLLM_PLACEHOLDER),
-        _get_prompt(1, audio, HF_PLACEHOLDER),
-        audio.audio_and_sample_rate,
-    ) for audio in audio_assets]
-
-    run_test(
-        hf_runner,
-        vllm_runner,
-        audio_inputs,
-        MODEL_NAME,
-        dtype=dtype,
-        max_tokens=max_tokens,
-        num_logprobs=num_logprobs,
-        **vllm_kwargs,
-    )
-
-
 @pytest.mark.core_model
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [128])
diff --git a/tests/models/multimodal/generation/vlm_utils/builders.py b/tests/models/multimodal/generation/vlm_utils/builders.py
index e3ba955a96a6..32117c8d8dca 100644
--- a/tests/models/multimodal/generation/vlm_utils/builders.py
+++ b/tests/models/multimodal/generation/vlm_utils/builders.py
@@ -7,18 +7,21 @@ from typing import Callable, Optional, Union
 
 import torch
 
+from vllm.multimodal.audio import AudioResampler
 from vllm.multimodal.image import rescale_image_size
 from vllm.multimodal.video import (rescale_video_size, resize_video,
                                    sample_frames_from_video)
 
-from .....conftest import ImageTestAssets, VideoTestAssets
-from .types import (SINGLE_IMAGE_BASE_PROMPTS, TEST_IMG_PLACEHOLDER,
+from .....conftest import AudioTestAssets, ImageTestAssets, VideoTestAssets
+from .types import (SINGLE_AUDIO_BASE_PROMPT, SINGLE_IMAGE_BASE_PROMPTS,
+                    TEST_AUDIO_PLACEHOLDER, TEST_IMG_PLACEHOLDER,
                     TEST_VIDEO_PLACEHOLDER, VIDEO_BASE_PROMPT,
-                    ImageSizeWrapper, SizeType, VLMTestInfo)
+                    ImageSizeWrapper, PromptWithMultiModalInput, SizeType,
+                    VLMTestInfo)
 
 
-def replace_test_placeholder(prompt: str, img_idx_to_prompt: Callable[[int],
-                                                                      str],
+def replace_test_placeholder(prompt: str, mm_idx_to_prompt: Callable[[int],
+                                                                     str],
                              test_placeholder: str) -> str:
     """Given a prompt, replaces each test placeholder with the
     model-specific tag.
@@ -26,7 +29,7 @@ def replace_test_placeholder(prompt: str, img_idx_to_prompt: Callable[[int],
     prompt_segments = prompt.split(test_placeholder)
     img_prompt = prompt_segments[0]
     for placeholder_idx, next_seg in enumerate(prompt_segments[1:], start=1):
-        img_prompt += img_idx_to_prompt(placeholder_idx)
+        img_prompt += mm_idx_to_prompt(placeholder_idx)
         img_prompt += next_seg
     return img_prompt
 
@@ -34,6 +37,7 @@ def replace_test_placeholder(prompt: str, img_idx_to_prompt: Callable[[int],
 def get_model_prompts(base_prompts: Iterable[str],
                       img_idx_to_prompt: Optional[Callable[[int], str]],
                       video_idx_to_prompt: Optional[Callable[[int], str]],
+                      audio_idx_to_prompt: Optional[Callable[[int], str]],
                       prompt_formatter: Callable[[str], str]) -> list[str]:
     """Given a model-agnostic base prompt and test configuration for a model(s)
     to be tested, update the media placeholders and apply the prompt formatting
@@ -60,6 +64,11 @@ def get_model_prompts(base_prompts: Iterable[str],
                                                    video_idx_to_prompt,
                                                    TEST_VIDEO_PLACEHOLDER)
 
+        if audio_idx_to_prompt:
+            base_prompt = replace_test_placeholder(base_prompt,
+                                                   audio_idx_to_prompt,
+                                                   TEST_AUDIO_PLACEHOLDER)
+
         # Apply the prompt formatter to wrap the base prompt with
         # the correct media placeholders to get the model test prompt
         model_prompt = prompt_formatter(base_prompt)
@@ -68,10 +77,11 @@ def get_model_prompts(base_prompts: Iterable[str],
 
 
 def build_single_image_inputs_from_test_info(
-        test_info: VLMTestInfo,
-        image_assets: ImageTestAssets,
-        size_wrapper: ImageSizeWrapper,
-        tmp_path: Optional[PosixPath] = None):
+    test_info: VLMTestInfo,
+    image_assets: ImageTestAssets,
+    size_wrapper: ImageSizeWrapper,
+    tmp_path: Optional[PosixPath] = None,
+) -> list[PromptWithMultiModalInput]:
     if test_info.prompt_formatter is None:
         raise ValueError(
            "Prompt formatter must be set to build single image inputs")
@@ -79,6 +89,7 @@ def build_single_image_inputs_from_test_info(
     model_prompts = get_model_prompts(test_info.single_image_prompts,
                                       test_info.img_idx_to_prompt,
                                       test_info.video_idx_to_prompt,
+                                      test_info.audio_idx_to_prompt,
                                       test_info.prompt_formatter)
 
     # For models that require a local path / URL encoded in the image; export
@@ -97,28 +108,32 @@ def build_single_image_inputs_from_test_info(
     return build_single_image_inputs(images, model_prompts, size_wrapper)
 
 
-def build_single_image_inputs(images, model_prompts,
-                              size_wrapper: ImageSizeWrapper):
+def build_single_image_inputs(
+        images, model_prompts,
+        size_wrapper: ImageSizeWrapper) -> list[PromptWithMultiModalInput]:
     # For every image / prompt pair, get a pair containing two lists of
     # length size_factors, where the first contains duplicates of the model
     # prompt [str], and the second contains copies of the image after being
     # scaled by one of the size factors.
     #
     # NOTE: rescaling preserves the image aspect ratio.
-    return [(
-        [prompt for _ in size_wrapper.data],
-        [
-            apply_image_size_scaling(image, size, size_wrapper.type)
-            for size in size_wrapper.data
-        ],
-    ) for image, prompt in zip(images, model_prompts)]
+    return [
+        PromptWithMultiModalInput(
+            prompts=[prompt for _ in size_wrapper.data],
+            image_data=[
+                apply_image_size_scaling(image, size, size_wrapper.type)
+                for size in size_wrapper.data
+            ],
+        ) for image, prompt in zip(images, model_prompts)
+    ]
 
 
 def build_multi_image_inputs_from_test_info(
-        test_info: VLMTestInfo,
-        image_assets: ImageTestAssets,
-        size_wrapper: ImageSizeWrapper,
-        tmp_path: Optional[PosixPath] = None):
+    test_info: VLMTestInfo,
+    image_assets: ImageTestAssets,
+    size_wrapper: ImageSizeWrapper,
+    tmp_path: Optional[PosixPath] = None,
+) -> list[PromptWithMultiModalInput]:
     if test_info.prompt_formatter is None:
         raise ValueError(
             "Prompt formatter must be set to build multi image inputs")
@@ -126,6 +141,7 @@ def build_multi_image_inputs_from_test_info(
     model_prompts = get_model_prompts([test_info.multi_image_prompt],
                                       test_info.img_idx_to_prompt,
                                       test_info.video_idx_to_prompt,
+                                      test_info.audio_idx_to_prompt,
                                       test_info.prompt_formatter)
 
     if test_info.prompt_path_encoder is not None:
@@ -146,15 +162,18 @@ def build_multi_image_inputs_from_test_info(
     )
 
 
-def build_multi_image_inputs(image_lists, model_prompts,
-                             size_wrapper: ImageSizeWrapper):
-    return [(
-        [prompt for _ in size_wrapper.data],
-        [[
-            apply_image_size_scaling(image, size, size_wrapper.type)
-            for image in images
-        ] for size in size_wrapper.data],
-    ) for images, prompt in zip(image_lists, model_prompts)]
+def build_multi_image_inputs(
+        image_lists, model_prompts,
+        size_wrapper: ImageSizeWrapper) -> list[PromptWithMultiModalInput]:
+    return [
+        PromptWithMultiModalInput(
+            prompts=[prompt for _ in size_wrapper.data],
+            image_data=[[
+                apply_image_size_scaling(image, size, size_wrapper.type)
+                for image in images
+            ] for size in size_wrapper.data],
+        ) for images, prompt in zip(image_lists, model_prompts)
+    ]
 
 
 def build_embedding_inputs_from_test_info(
@@ -177,6 +196,7 @@ def build_embedding_inputs_from_test_info(
         SINGLE_IMAGE_BASE_PROMPTS,
         test_info.img_idx_to_prompt,
         test_info.video_idx_to_prompt,
+        test_info.audio_idx_to_prompt,
         test_info.prompt_formatter,
     )
 
@@ -195,13 +215,14 @@ def build_video_inputs_from_test_info(
     video_assets: VideoTestAssets,
     size_wrapper: ImageSizeWrapper,
     num_frames: int,
-):
+) -> list[PromptWithMultiModalInput]:
     if test_info.prompt_formatter is None:
         raise ValueError("Prompt formatter must be set to build video inputs")
 
     model_prompts = get_model_prompts(
         [VIDEO_BASE_PROMPT],
         test_info.img_idx_to_prompt,
         test_info.video_idx_to_prompt,
+        test_info.audio_idx_to_prompt,
         test_info.prompt_formatter,
     )
@@ -213,10 +234,14 @@ def build_video_inputs_from_test_info(
     video_scaler = (resize_video if size_wrapper.type == SizeType.FIXED_SIZE
                     else rescale_video_size)
 
-    return [(
-        [prompt for _ in size_wrapper.data],
-        [video_scaler(video, size) for size in size_wrapper.data],
-    ) for video, prompt in zip(sampled_vids, model_prompts)]
+    return [
+        PromptWithMultiModalInput(
+            prompts=[prompt for _ in size_wrapper.data],
+            video_data=[
+                video_scaler(video, size) for size in size_wrapper.data
+            ],
+        ) for video, prompt in zip(sampled_vids, model_prompts)
+    ]
 
 
 def apply_image_size_scaling(image, size: Union[float, tuple[int, int]],
@@ -236,3 +261,37 @@ def apply_image_size_scaling(image, size: Union[float, tuple[int, int]],
         # We have a list of fixed sizes
         return image.resize(size)
     raise ValueError("ImageSizeWrapper type must be FIXED_SIZE or SIZE_FACTOR")
+
+
+def build_audio_inputs_from_test_info(
+    test_info: VLMTestInfo,
+    audio_assets: AudioTestAssets,
+) -> list[PromptWithMultiModalInput]:
+    if test_info.prompt_formatter is None:
+        raise ValueError("Prompt formatter must be set to build audio inputs")
+    model_prompts = get_model_prompts(
+        SINGLE_AUDIO_BASE_PROMPT,
+        test_info.img_idx_to_prompt,
+        test_info.video_idx_to_prompt,
+        test_info.audio_idx_to_prompt,
+        test_info.prompt_formatter,
+    )
+    resampler = AudioResampler(
+        target_sr=16000,
+        method="librosa",
+    )
+    audios = [asset.audio_and_sample_rate for asset in audio_assets]
+    resampled_audios = [(
+        resampler.resample(
+            audio,
+            orig_sr=sr,
+        ),
+        int(resampler.target_sr),
+    ) for audio, sr in audios]
+
+    return [
+        PromptWithMultiModalInput(
+            prompts=model_prompts,
+            audio_data=resampled_audios,
+        )
+    ]
diff --git a/tests/models/multimodal/generation/vlm_utils/case_filtering.py b/tests/models/multimodal/generation/vlm_utils/case_filtering.py
index 8e825676b8f4..a5077a090b52 100644
--- a/tests/models/multimodal/generation/vlm_utils/case_filtering.py
+++ b/tests/models/multimodal/generation/vlm_utils/case_filtering.py
@@ -83,7 +83,7 @@ def get_parametrized_options(test_settings: dict[str, VLMTestInfo],
                                     test_info.num_video_frames)
 
     # No sizes passed for custom inputs, since inputs are directly provided
-    if test_type != VLMTestType.CUSTOM_INPUTS:
+    if test_type not in (VLMTestType.CUSTOM_INPUTS, VLMTestType.AUDIO):
         wrapped_sizes = get_wrapped_test_sizes(test_info, test_type)
         if wrapped_sizes is None:
             raise ValueError(
@@ -91,7 +91,7 @@ def get_parametrized_options(test_settings: dict[str, VLMTestInfo],
         iter_kwargs["size_wrapper"] = wrapped_sizes
 
     #Otherwise expand the custom test options instead
-    else:
+    elif test_type == VLMTestType.CUSTOM_INPUTS:
         if test_info.custom_test_opts is None:
             raise ValueError("Test has type CUSTOM_INPUTS, but none given")
         iter_kwargs["custom_test_opts"] = test_info.custom_test_opts
@@ -136,8 +136,8 @@ def get_wrapped_test_sizes(
             ImageSizeWrapper(type=SizeType.SIZE_FACTOR, data=factor)
             for factor in EMBEDDING_SIZE_FACTORS
         ])
-    # Custom inputs have preprocessed inputs
-    elif test_type == VLMTestType.CUSTOM_INPUTS:
+    # Audio and Custom inputs have preprocessed inputs
+    elif test_type in (VLMTestType.AUDIO, VLMTestType.CUSTOM_INPUTS):
         return tuple()
 
     size_factors = test_info.image_size_factors \
diff --git a/tests/models/multimodal/generation/vlm_utils/core.py b/tests/models/multimodal/generation/vlm_utils/core.py
index c3d20f56855f..ccd2799abd90 100644
--- a/tests/models/multimodal/generation/vlm_utils/core.py
+++ b/tests/models/multimodal/generation/vlm_utils/core.py
@@ -1,9 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 """Core test implementation to be shared across modalities."""
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Optional
 
 import torch
-from PIL.Image import Image
 from transformers.models.auto.auto_factory import _BaseAutoModelClass
 
 from vllm.config import TaskOption
@@ -11,14 +10,14 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
 
 from .....conftest import HfRunner, VllmRunner
 from ....registry import HF_EXAMPLE_MODELS
-from .types import RunnerOutput
+from .types import PromptWithMultiModalInput, RunnerOutput
 
 
 def run_test(
     *,
     hf_runner: type[HfRunner],
     vllm_runner: type[VllmRunner],
-    inputs: list[tuple[list[str], list[Union[list[Image], Image]]]],
+    inputs: list[PromptWithMultiModalInput],
     model: str,
     dtype: str,
     max_tokens: int,
@@ -38,7 +37,6 @@ def run_test(
     hf_model_kwargs: Optional[dict[str, Any]],
     patch_hf_runner: Optional[Callable[[HfRunner], HfRunner]],
     task: TaskOption = "auto",
-    runner_mm_key: str = "images",
     distributed_executor_backend: Optional[str] = None,
     tensor_parallel_size: int = 1,
     vllm_embeddings: Optional[torch.Tensor] = None,
@@ -94,10 +92,16 @@ def run_test(
         if stop_str:
             vllm_kwargs["stop"] = stop_str
 
-        for prompts, media in vllm_inputs:
-            vllm_kwargs[runner_mm_key] = media
+        for prompts, image_data, video_data, audio_data in vllm_inputs:
+            mm_data = dict(images=image_data,
+                           videos=video_data,
+                           audios=audio_data)
+            vllm_kwargs_with_mm_data = vllm_kwargs | mm_data
             vllm_output = vllm_model.generate_greedy_logprobs(
-                prompts, max_tokens, num_logprobs=num_logprobs, **vllm_kwargs)
+                prompts,
+                max_tokens,
+                num_logprobs=num_logprobs,
+                **vllm_kwargs_with_mm_data)
             vllm_outputs_per_mm.append(vllm_output)
 
     hf_model = hf_runner(model,
@@ -122,14 +126,17 @@ def run_test(
         if stop_str:
             hf_kwargs["stop_strings"] = stop_str
 
-        for prompts, media in inputs:
-            hf_kwargs[runner_mm_key] = media
+        for prompts, image_data, video_data, audio_data in inputs:
+            mm_data = dict(images=image_data,
+                           videos=video_data,
+                           audios=audio_data)
+            hf_kwargs_with_mm_data = hf_kwargs | mm_data
             hf_output = hf_model.generate_greedy_logprobs_limit(
                 prompts,
                 max_tokens,
                 num_logprobs=num_logprobs,
                 tokenizer=tokenizer,
-                **hf_kwargs)
+                **hf_kwargs_with_mm_data)
             hf_outputs_per_mm.append(hf_output)
 
     # Apply output processing / sanitation to the vLLM and HF runner results
diff --git a/tests/models/multimodal/generation/vlm_utils/custom_inputs.py b/tests/models/multimodal/generation/vlm_utils/custom_inputs.py
index 235618ae547e..cc1045561138 100644
--- a/tests/models/multimodal/generation/vlm_utils/custom_inputs.py
+++ b/tests/models/multimodal/generation/vlm_utils/custom_inputs.py
@@ -12,7 +12,7 @@ from vllm.multimodal.video import (rescale_video_size, resize_video,
 
 from .....conftest import IMAGE_ASSETS, VIDEO_ASSETS
 from .builders import build_multi_image_inputs, build_single_image_inputs
-from .types import ImageSizeWrapper, SizeType
+from .types import ImageSizeWrapper, PromptWithMultiModalInput, SizeType
 
 
 def multi_image_multi_aspect_ratio_inputs(formatter: Callable[[str], str]):
@@ -32,24 +32,28 @@ def multi_image_multi_aspect_ratio_inputs(formatter: Callable[[str], str]):
         "\nWhat is the season?",
     ]
     formatted_prompts = [formatter(prompt) for prompt in img_prompts]
-
-    return [(
-        formatted_prompts,
+    aspect_ratio_images = [
+        [stop_sign, cherry_blossom],
+        # Images with different sizes and aspect-ratios
         [
-            [stop_sign, cherry_blossom],
-            # Images with different sizes and aspect-ratios
-            [
-                rescale_image_size(stop_sign, 0.1),
-                stop_sign,
-            ],
-            [
-                stop_sign,
-                rescale_image_size(stop_sign, 0.25),
-                cherry_blossom.resize((183, 488)),
-                cherry_blossom.resize((488, 183))
-            ],
-            cherry_blossom,
-        ])]
+            rescale_image_size(stop_sign, 0.1),
+            stop_sign,
+        ],
+        [
+            stop_sign,
+            rescale_image_size(stop_sign, 0.25),
+            cherry_blossom.resize((183, 488)),
+            cherry_blossom.resize((488, 183))
+        ],
+        cherry_blossom,
+    ]
+
+    return [
+        PromptWithMultiModalInput(
+            prompts=formatted_prompts,
+            image_data=aspect_ratio_images,
+        )
+    ]
 
 
 def multi_video_multi_aspect_ratio_inputs(formatter: Callable[[str], str],
@@ -68,24 +72,28 @@ def multi_video_multi_aspect_ratio_inputs(formatter: Callable[[str], str],
 "
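Note: the refactored runner above unpacks each test case as (prompts, image_data, video_data, audio_data), so the per-modality keyword is derived from the container fields rather than a runner_mm_key string. The container itself, PromptWithMultiModalInput, is defined in vlm_utils/types.py and is not part of this diff; the sketch below only illustrates the shape implied by the hunks above, and the NamedTuple base and None defaults are assumptions.

# Illustrative sketch only -- not the real definition from vlm_utils/types.py.
from typing import NamedTuple, Optional

import numpy as np


class PromptWithMultiModalInput(NamedTuple):
    prompts: list[str]
    image_data: Optional[list] = None
    video_data: Optional[list] = None
    audio_data: Optional[list] = None


def to_runner_kwargs(case: PromptWithMultiModalInput) -> dict:
    # Mirrors how core.run_test maps each field onto the HF/vLLM runner
    # keywords (images=, videos=, audios=) instead of a single runner_mm_key.
    return dict(images=case.image_data,
                videos=case.video_data,
                audios=case.audio_data)


example = PromptWithMultiModalInput(
    prompts=["<|audio|>\nWhat is happening in this recording?"],
    # One second of silence at 16 kHz stands in for a resampled test asset.
    audio_data=[(np.zeros(16000, dtype=np.float32), 16000)],
)
print(to_runner_kwargs(example))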