[Misc] Consolidate Audio tests into multimodal common generation tests (#18214)

Signed-off-by: Isotr0py <2037008807@qq.com>
Isotr0py 2025-05-16 17:18:08 +08:00 committed by GitHub
parent 541817670c
commit 390ec88905
9 changed files with 282 additions and 215 deletions

View File

@@ -8,14 +8,14 @@ from collections import defaultdict
 from pathlib import PosixPath
 
 import pytest
-from transformers import (AutoModelForImageTextToText,
+from transformers import (AutoModel, AutoModelForImageTextToText,
                           AutoModelForTextToWaveform, AutoModelForVision2Seq)
 
 from vllm.platforms import current_platform
 from vllm.utils import identity
 
-from ....conftest import (IMAGE_ASSETS, HfRunner, ImageTestAssets,
-                          VideoTestAssets, VllmRunner)
+from ....conftest import (IMAGE_ASSETS, AudioTestAssets, HfRunner,
+                          ImageTestAssets, VideoTestAssets, VllmRunner)
 from ....utils import (create_new_process_for_each_test, large_gpu_mark,
                        multi_gpu_marks)
 from ...utils import check_outputs_equal
@@ -158,6 +158,17 @@ VLM_TEST_SETTINGS = {
         image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
         marks=[pytest.mark.core_model, pytest.mark.cpu_model],
     ),
+    "ultravox": VLMTestInfo(
+        models=["fixie-ai/ultravox-v0_5-llama-3_2-1b"],
+        test_type=VLMTestType.AUDIO,
+        prompt_formatter=lambda audio_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{audio_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",  # noqa: E501
+        audio_idx_to_prompt=lambda idx: "<|audio|>",
+        max_model_len=4096,
+        max_num_seqs=2,
+        auto_cls=AutoModel,
+        hf_output_post_proc=model_utils.ultravox_trunc_hf_output,
+        marks=[pytest.mark.core_model, pytest.mark.cpu_model],
+    ),
     #### Extended model tests
     "aria": VLMTestInfo(
         models=["rhymes-ai/Aria"],
@@ -393,7 +404,6 @@ VLM_TEST_SETTINGS = {
                 formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n",  # noqa: E501
             ),
             limit_mm_per_prompt={"video": 4},
-            runner_mm_key="videos",
         )],
     ),
     "llava_next_video": VLMTestInfo(
@@ -706,6 +716,7 @@ VLM_TEST_SETTINGS = _mark_splits(VLM_TEST_SETTINGS, num_groups=2)
 # - multi-image
 # - image embeddings
 # - video
+# - audio
 # - custom inputs
 @pytest.mark.parametrize(
     "model_type,test_case",
@@ -803,6 +814,28 @@ def test_video_models(model_type: str, test_case: ExpandableVLMTestArgs,
     )
 
 
+@pytest.mark.parametrize(
+    "model_type,test_case",
+    get_parametrized_options(
+        VLM_TEST_SETTINGS,
+        test_type=VLMTestType.AUDIO,
+        create_new_process_for_each_test=False,
+    ))
+def test_audio_models(model_type: str, test_case: ExpandableVLMTestArgs,
+                      hf_runner: type[HfRunner], vllm_runner: type[VllmRunner],
+                      audio_assets: AudioTestAssets, monkeypatch):
+    if model_type in REQUIRES_V0_MODELS:
+        monkeypatch.setenv("VLLM_USE_V1", "0")
+    model_test_info = VLM_TEST_SETTINGS[model_type]
+    runners.run_audio_test(
+        model_test_info=model_test_info,
+        test_case=test_case,
+        hf_runner=hf_runner,
+        vllm_runner=vllm_runner,
+        audio_assets=audio_assets,
+    )
+
+
 @pytest.mark.parametrize(
     "model_type,test_case",
     get_parametrized_options(
@@ -930,6 +963,29 @@ def test_video_models_heavy(model_type: str, test_case: ExpandableVLMTestArgs,
     )
 
 
+@pytest.mark.parametrize(
+    "model_type,test_case",
+    get_parametrized_options(
+        VLM_TEST_SETTINGS,
+        test_type=VLMTestType.AUDIO,
+        create_new_process_for_each_test=True,
+    ))
+def test_audio_models_heavy(model_type: str, test_case: ExpandableVLMTestArgs,
+                            hf_runner: type[HfRunner],
+                            vllm_runner: type[VllmRunner],
+                            audio_assets: AudioTestAssets, monkeypatch):
+    if model_type in REQUIRES_V0_MODELS:
+        monkeypatch.setenv("VLLM_USE_V1", "0")
+    model_test_info = VLM_TEST_SETTINGS[model_type]
+    runners.run_audio_test(
+        model_test_info=model_test_info,
+        test_case=test_case,
+        hf_runner=hf_runner,
+        vllm_runner=vllm_runner,
+        audio_assets=audio_assets,
+    )
+
+
 @pytest.mark.parametrize(
     "model_type,test_case",
     get_parametrized_options(

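With the audio path consolidated above, a new audio model gets HF-vs-vLLM comparison coverage by registering a single VLMTestInfo entry instead of a dedicated test file. A rough sketch of what such a registration could look like, reusing the names already imported in this file; the model id "my-org/my-audio-model" and its chat template are hypothetical placeholders, not part of this commit:

# Hypothetical sketch only: register one more audio model so that
# test_audio_models / test_audio_models_heavy pick it up automatically.
VLM_TEST_SETTINGS["my_audio_model"] = VLMTestInfo(
    models=["my-org/my-audio-model"],          # placeholder model id
    test_type=VLMTestType.AUDIO,
    prompt_formatter=lambda audio_prompt: f"<|user|>\n{audio_prompt}<|assistant|>\n",  # placeholder template
    audio_idx_to_prompt=lambda idx: "<|audio|>",
    max_model_len=4096,
    max_num_seqs=2,
    auto_cls=AutoModel,
    marks=[pytest.mark.core_model],
)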
View File

@@ -1,20 +1,16 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import json
-from typing import Any, Optional
+from typing import Any
 
 import numpy as np
 import pytest
 import pytest_asyncio
-from transformers import AutoModel, AutoTokenizer
+from transformers import AutoTokenizer
 
-from vllm.multimodal.audio import resample_audio_librosa
-from vllm.sequence import SampleLogprobs
-
-from ....conftest import AUDIO_ASSETS, AudioTestAssets, HfRunner, VllmRunner
+from ....conftest import AUDIO_ASSETS, AudioTestAssets, VllmRunner
 from ....utils import RemoteOpenAIServer
 from ...registry import HF_EXAMPLE_MODELS
-from ...utils import check_logprobs_close
 
 MODEL_NAME = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
@@ -88,79 +84,6 @@ def _get_prompt(audio_count, question, placeholder):
                                          add_generation_prompt=True)
 
 
-def vllm_to_hf_output(vllm_output: tuple[list[int], str,
-                                         Optional[SampleLogprobs]],
-                      model: str):
-    """Sanitize vllm output to be comparable with hf output."""
-    output_ids, output_str, out_logprobs = vllm_output
-
-    tokenizer = AutoTokenizer.from_pretrained(model)
-    eos_token_id = tokenizer.eos_token_id
-
-    hf_output_ids = output_ids[:]
-    hf_output_str = output_str
-    if hf_output_ids[-1] == eos_token_id:
-        hf_output_str = hf_output_str + tokenizer.decode(eos_token_id)
-
-    return hf_output_ids, hf_output_str, out_logprobs
-
-
-def run_test(
-    hf_runner: type[HfRunner],
-    vllm_runner: type[VllmRunner],
-    prompts_and_audios: list[tuple[str, str, AudioTuple]],
-    model: str,
-    *,
-    dtype: str,
-    max_tokens: int,
-    num_logprobs: int,
-    **kwargs,
-):
-    """Inference result should be the same between hf and vllm."""
-    model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
-    model_info.check_available_online(on_fail="skip")
-    model_info.check_transformers_version(on_fail="skip")
-
-    # NOTE: take care of the order. run vLLM first, and then run HF.
-    # vLLM needs a fresh new process without cuda initialization.
-    # if we run HF first, the cuda initialization will be done and it
-    # will hurt multiprocessing backend with fork method (the default method).
-    with vllm_runner(model, dtype=dtype, enforce_eager=True,
-                     **kwargs) as vllm_model:
-        vllm_outputs_per_audio = [
-            vllm_model.generate_greedy_logprobs([vllm_prompt],
-                                                max_tokens,
-                                                num_logprobs=num_logprobs,
-                                                audios=[audio])
-            for vllm_prompt, _, audio in prompts_and_audios
-        ]
-
-    with hf_runner(model, dtype=dtype, auto_cls=AutoModel) as hf_model:
-        hf_outputs_per_audio = [
-            hf_model.generate_greedy_logprobs_limit(
-                [hf_prompt],
-                max_tokens,
-                num_logprobs=num_logprobs,
-                audios=[(resample_audio_librosa(audio[0],
-                                                orig_sr=audio[1],
-                                                target_sr=16000), 16000)])
-            for _, hf_prompt, audio in prompts_and_audios
-        ]
-
-    for hf_outputs, vllm_outputs in zip(hf_outputs_per_audio,
-                                        vllm_outputs_per_audio):
-        check_logprobs_close(
-            outputs_0_lst=hf_outputs,
-            outputs_1_lst=[
-                vllm_to_hf_output(vllm_output, model)
-                for vllm_output in vllm_outputs
-            ],
-            name_0="hf",
-            name_1="vllm",
-        )
-
-
 def run_multi_audio_test(
     vllm_runner: type[VllmRunner],
     prompts_and_audios: list[tuple[str, list[AudioTuple]]],
@@ -194,35 +117,6 @@ def run_multi_audio_test(
     assert all(tokens for tokens, *_ in vllm_outputs)
 
 
-@pytest.mark.core_model
-@pytest.mark.parametrize("dtype", ["bfloat16"])
-@pytest.mark.parametrize("max_tokens", [128])
-@pytest.mark.parametrize("num_logprobs", [5])
-@pytest.mark.parametrize("vllm_kwargs", [
-    pytest.param({}, marks=pytest.mark.cpu_model),
-    pytest.param(CHUNKED_PREFILL_KWARGS),
-])
-def test_models(hf_runner, vllm_runner, audio_assets: AudioTestAssets,
-                dtype: str, max_tokens: int, num_logprobs: int,
-                vllm_kwargs: dict) -> None:
-    audio_inputs = [(
-        _get_prompt(1, audio, VLLM_PLACEHOLDER),
-        _get_prompt(1, audio, HF_PLACEHOLDER),
-        audio.audio_and_sample_rate,
-    ) for audio in audio_assets]
-
-    run_test(
-        hf_runner,
-        vllm_runner,
-        audio_inputs,
-        MODEL_NAME,
-        dtype=dtype,
-        max_tokens=max_tokens,
-        num_logprobs=num_logprobs,
-        **vllm_kwargs,
-    )
-
-
 @pytest.mark.core_model
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [128])

View File

@@ -7,18 +7,21 @@ from typing import Callable, Optional, Union
 
 import torch
 
+from vllm.multimodal.audio import AudioResampler
 from vllm.multimodal.image import rescale_image_size
 from vllm.multimodal.video import (rescale_video_size, resize_video,
                                    sample_frames_from_video)
 
-from .....conftest import ImageTestAssets, VideoTestAssets
-from .types import (SINGLE_IMAGE_BASE_PROMPTS, TEST_IMG_PLACEHOLDER,
-                    TEST_VIDEO_PLACEHOLDER, VIDEO_BASE_PROMPT,
-                    ImageSizeWrapper, SizeType, VLMTestInfo)
+from .....conftest import AudioTestAssets, ImageTestAssets, VideoTestAssets
+from .types import (SINGLE_AUDIO_BASE_PROMPT, SINGLE_IMAGE_BASE_PROMPTS,
+                    TEST_AUDIO_PLACEHOLDER, TEST_IMG_PLACEHOLDER,
+                    TEST_VIDEO_PLACEHOLDER, VIDEO_BASE_PROMPT,
+                    ImageSizeWrapper, PromptWithMultiModalInput, SizeType,
+                    VLMTestInfo)
 
 
-def replace_test_placeholder(prompt: str, img_idx_to_prompt: Callable[[int],
-                                                                      str],
+def replace_test_placeholder(prompt: str, mm_idx_to_prompt: Callable[[int],
+                                                                     str],
                              test_placeholder: str) -> str:
     """Given a prompt, replaces each test placeholder with the
     model-specific tag.
@@ -26,7 +29,7 @@ def replace_test_placeholder(prompt: str, img_idx_to_prompt: Callable[[int],
     prompt_segments = prompt.split(test_placeholder)
     img_prompt = prompt_segments[0]
     for placeholder_idx, next_seg in enumerate(prompt_segments[1:], start=1):
-        img_prompt += img_idx_to_prompt(placeholder_idx)
+        img_prompt += mm_idx_to_prompt(placeholder_idx)
         img_prompt += next_seg
     return img_prompt
@@ -34,6 +37,7 @@ def replace_test_placeholder(prompt: str, img_idx_to_prompt: Callable[[int],
 def get_model_prompts(base_prompts: Iterable[str],
                       img_idx_to_prompt: Optional[Callable[[int], str]],
                       video_idx_to_prompt: Optional[Callable[[int], str]],
+                      audio_idx_to_prompt: Optional[Callable[[int], str]],
                       prompt_formatter: Callable[[str], str]) -> list[str]:
     """Given a model-agnostic base prompt and test configuration for a model(s)
     to be tested, update the media placeholders and apply the prompt formatting
@@ -60,6 +64,11 @@ def get_model_prompts(base_prompts: Iterable[str],
                                                    video_idx_to_prompt,
                                                    TEST_VIDEO_PLACEHOLDER)
 
+        if audio_idx_to_prompt:
+            base_prompt = replace_test_placeholder(base_prompt,
+                                                   audio_idx_to_prompt,
+                                                   TEST_AUDIO_PLACEHOLDER)
+
         # Apply the prompt formatter to wrap the base prompt with
         # the correct media placeholders to get the model test prompt
         model_prompt = prompt_formatter(base_prompt)
@@ -68,10 +77,11 @@ def get_model_prompts(base_prompts: Iterable[str],
 def build_single_image_inputs_from_test_info(
     test_info: VLMTestInfo,
     image_assets: ImageTestAssets,
     size_wrapper: ImageSizeWrapper,
-    tmp_path: Optional[PosixPath] = None):
+    tmp_path: Optional[PosixPath] = None,
+) -> list[PromptWithMultiModalInput]:
     if test_info.prompt_formatter is None:
         raise ValueError(
             "Prompt formatter must be set to build single image inputs")
@@ -79,6 +89,7 @@ def build_single_image_inputs_from_test_info(
     model_prompts = get_model_prompts(test_info.single_image_prompts,
                                       test_info.img_idx_to_prompt,
                                       test_info.video_idx_to_prompt,
+                                      test_info.audio_idx_to_prompt,
                                       test_info.prompt_formatter)
 
     # For models that require a local path / URL encoded in the image; export
@@ -97,28 +108,32 @@ def build_single_image_inputs_from_test_info(
     return build_single_image_inputs(images, model_prompts, size_wrapper)
 
 
-def build_single_image_inputs(images, model_prompts,
-                              size_wrapper: ImageSizeWrapper):
+def build_single_image_inputs(
+        images, model_prompts,
+        size_wrapper: ImageSizeWrapper) -> list[PromptWithMultiModalInput]:
     # For every image / prompt pair, get a pair containing two lists of
     # length size_factors, where the first contains duplicates of the model
     # prompt [str], and the second contains copies of the image after being
     # scaled by one of the size factors.
     #
     # NOTE: rescaling preserves the image aspect ratio.
-    return [(
-        [prompt for _ in size_wrapper.data],
-        [
-            apply_image_size_scaling(image, size, size_wrapper.type)
-            for size in size_wrapper.data
-        ],
-    ) for image, prompt in zip(images, model_prompts)]
+    return [
+        PromptWithMultiModalInput(
+            prompts=[prompt for _ in size_wrapper.data],
+            image_data=[
+                apply_image_size_scaling(image, size, size_wrapper.type)
+                for size in size_wrapper.data
+            ],
+        ) for image, prompt in zip(images, model_prompts)
+    ]
 
 
 def build_multi_image_inputs_from_test_info(
     test_info: VLMTestInfo,
     image_assets: ImageTestAssets,
     size_wrapper: ImageSizeWrapper,
-    tmp_path: Optional[PosixPath] = None):
+    tmp_path: Optional[PosixPath] = None,
+) -> list[PromptWithMultiModalInput]:
     if test_info.prompt_formatter is None:
         raise ValueError(
             "Prompt formatter must be set to build multi image inputs")
@@ -126,6 +141,7 @@ def build_multi_image_inputs_from_test_info(
     model_prompts = get_model_prompts([test_info.multi_image_prompt],
                                       test_info.img_idx_to_prompt,
                                       test_info.video_idx_to_prompt,
+                                      test_info.audio_idx_to_prompt,
                                       test_info.prompt_formatter)
 
     if test_info.prompt_path_encoder is not None:
@@ -146,15 +162,18 @@ def build_multi_image_inputs_from_test_info(
     )
 
 
-def build_multi_image_inputs(image_lists, model_prompts,
-                             size_wrapper: ImageSizeWrapper):
-    return [(
-        [prompt for _ in size_wrapper.data],
-        [[
-            apply_image_size_scaling(image, size, size_wrapper.type)
-            for image in images
-        ] for size in size_wrapper.data],
-    ) for images, prompt in zip(image_lists, model_prompts)]
+def build_multi_image_inputs(
+        image_lists, model_prompts,
+        size_wrapper: ImageSizeWrapper) -> list[PromptWithMultiModalInput]:
+    return [
+        PromptWithMultiModalInput(
+            prompts=[prompt for _ in size_wrapper.data],
+            image_data=[[
+                apply_image_size_scaling(image, size, size_wrapper.type)
+                for image in images
+            ] for size in size_wrapper.data],
+        ) for images, prompt in zip(image_lists, model_prompts)
+    ]
 
 
 def build_embedding_inputs_from_test_info(
@@ -177,6 +196,7 @@ def build_embedding_inputs_from_test_info(
         SINGLE_IMAGE_BASE_PROMPTS,
         test_info.img_idx_to_prompt,
         test_info.video_idx_to_prompt,
+        test_info.audio_idx_to_prompt,
         test_info.prompt_formatter,
     )
 
@@ -195,13 +215,14 @@ def build_video_inputs_from_test_info(
     video_assets: VideoTestAssets,
     size_wrapper: ImageSizeWrapper,
     num_frames: int,
-):
+) -> list[PromptWithMultiModalInput]:
     if test_info.prompt_formatter is None:
         raise ValueError("Prompt formatter must be set to build video inputs")
     model_prompts = get_model_prompts(
         [VIDEO_BASE_PROMPT],
         test_info.img_idx_to_prompt,
         test_info.video_idx_to_prompt,
+        test_info.audio_idx_to_prompt,
         test_info.prompt_formatter,
     )
 
@@ -213,10 +234,14 @@ def build_video_inputs_from_test_info(
     video_scaler = (resize_video if size_wrapper.type == SizeType.FIXED_SIZE
                     else rescale_video_size)
-    return [(
-        [prompt for _ in size_wrapper.data],
-        [video_scaler(video, size) for size in size_wrapper.data],
-    ) for video, prompt in zip(sampled_vids, model_prompts)]
+    return [
+        PromptWithMultiModalInput(
+            prompts=[prompt for _ in size_wrapper.data],
+            video_data=[
+                video_scaler(video, size) for size in size_wrapper.data
+            ],
+        ) for video, prompt in zip(sampled_vids, model_prompts)
+    ]
 
 
 def apply_image_size_scaling(image, size: Union[float, tuple[int, int]],
@@ -236,3 +261,37 @@ def apply_image_size_scaling(image, size: Union[float, tuple[int, int]],
         # We have a list of fixed sizes
         return image.resize(size)
     raise ValueError("ImageSizeWrapper type must be FIXED_SIZE or SIZE_FACTOR")
+
+
+def build_audio_inputs_from_test_info(
+    test_info: VLMTestInfo,
+    audio_assets: AudioTestAssets,
+) -> list[PromptWithMultiModalInput]:
+    if test_info.prompt_formatter is None:
+        raise ValueError("Prompt formatter must be set to build audio inputs")
+    model_prompts = get_model_prompts(
+        SINGLE_AUDIO_BASE_PROMPT,
+        test_info.img_idx_to_prompt,
+        test_info.video_idx_to_prompt,
+        test_info.audio_idx_to_prompt,
+        test_info.prompt_formatter,
+    )
+    resampler = AudioResampler(
+        target_sr=16000,
+        method="librosa",
+    )
+    audios = [asset.audio_and_sample_rate for asset in audio_assets]
+    resampled_audios = [(
+        resampler.resample(
+            audio,
+            orig_sr=sr,
+        ),
+        int(resampler.target_sr),
+    ) for audio, sr in audios]
+
+    return [
+        PromptWithMultiModalInput(
+            prompts=model_prompts,
+            audio_data=resampled_audios,
+        )
+    ]

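For reference, the resampling the new builder performs is roughly what the deleted run_test in test_ultravox did inline with resample_audio_librosa: bring every asset to 16 kHz before handing it to both runners. A standalone sketch of that step using librosa directly, with a synthetic waveform; this approximates AudioResampler(method="librosa") rather than quoting its implementation:

import librosa
import numpy as np

# Synthetic 1-second mono signal at 44.1 kHz standing in for an audio asset.
orig_sr = 44100
audio = np.zeros(orig_sr, dtype=np.float32)

# Roughly what AudioResampler(target_sr=16000, method="librosa") does per asset.
target_sr = 16000
resampled = librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr)

# The builder then pairs the resampled waveform with its new rate, matching
# the (audio, sample_rate) tuples the HF and vLLM runners expect.
audio_tuple = (resampled, target_sr)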
View File

@@ -83,7 +83,7 @@ def get_parametrized_options(test_settings: dict[str, VLMTestInfo],
                                       test_info.num_video_frames)
 
     # No sizes passed for custom inputs, since inputs are directly provided
-    if test_type != VLMTestType.CUSTOM_INPUTS:
+    if test_type not in (VLMTestType.CUSTOM_INPUTS, VLMTestType.AUDIO):
         wrapped_sizes = get_wrapped_test_sizes(test_info, test_type)
         if wrapped_sizes is None:
             raise ValueError(
@@ -91,7 +91,7 @@ def get_parametrized_options(test_settings: dict[str, VLMTestInfo],
         iter_kwargs["size_wrapper"] = wrapped_sizes
 
     #Otherwise expand the custom test options instead
-    else:
+    elif test_type == VLMTestType.CUSTOM_INPUTS:
         if test_info.custom_test_opts is None:
             raise ValueError("Test has type CUSTOM_INPUTS, but none given")
         iter_kwargs["custom_test_opts"] = test_info.custom_test_opts
@@ -136,8 +136,8 @@ def get_wrapped_test_sizes(
             ImageSizeWrapper(type=SizeType.SIZE_FACTOR, data=factor)
             for factor in EMBEDDING_SIZE_FACTORS
         ])
-    # Custom inputs have preprocessed inputs
-    elif test_type == VLMTestType.CUSTOM_INPUTS:
+    # Audio and Custom inputs have preprocessed inputs
+    elif test_type in (VLMTestType.AUDIO, VLMTestType.CUSTOM_INPUTS):
         return tuple()
 
     size_factors = test_info.image_size_factors \

View File

@@ -1,9 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 """Core test implementation to be shared across modalities."""
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Optional
 
 import torch
-from PIL.Image import Image
 from transformers.models.auto.auto_factory import _BaseAutoModelClass
 
 from vllm.config import TaskOption
@@ -11,14 +10,14 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
 
 from .....conftest import HfRunner, VllmRunner
 from ....registry import HF_EXAMPLE_MODELS
-from .types import RunnerOutput
+from .types import PromptWithMultiModalInput, RunnerOutput
 
 
 def run_test(
     *,
     hf_runner: type[HfRunner],
     vllm_runner: type[VllmRunner],
-    inputs: list[tuple[list[str], list[Union[list[Image], Image]]]],
+    inputs: list[PromptWithMultiModalInput],
     model: str,
     dtype: str,
     max_tokens: int,
@@ -38,7 +37,6 @@ def run_test(
     hf_model_kwargs: Optional[dict[str, Any]],
     patch_hf_runner: Optional[Callable[[HfRunner], HfRunner]],
     task: TaskOption = "auto",
-    runner_mm_key: str = "images",
     distributed_executor_backend: Optional[str] = None,
     tensor_parallel_size: int = 1,
     vllm_embeddings: Optional[torch.Tensor] = None,
@@ -94,10 +92,16 @@ def run_test(
         if stop_str:
             vllm_kwargs["stop"] = stop_str
 
-        for prompts, media in vllm_inputs:
-            vllm_kwargs[runner_mm_key] = media
+        for prompts, image_data, video_data, audio_data in vllm_inputs:
+            mm_data = dict(images=image_data,
+                           videos=video_data,
+                           audios=audio_data)
+            vllm_kwargs_with_mm_data = vllm_kwargs | mm_data
             vllm_output = vllm_model.generate_greedy_logprobs(
-                prompts, max_tokens, num_logprobs=num_logprobs, **vllm_kwargs)
+                prompts,
+                max_tokens,
+                num_logprobs=num_logprobs,
+                **vllm_kwargs_with_mm_data)
             vllm_outputs_per_mm.append(vllm_output)
 
     hf_model = hf_runner(model,
@@ -122,14 +126,17 @@ def run_test(
         if stop_str:
             hf_kwargs["stop_strings"] = stop_str
 
-        for prompts, media in inputs:
-            hf_kwargs[runner_mm_key] = media
+        for prompts, image_data, video_data, audio_data in inputs:
+            mm_data = dict(images=image_data,
+                           videos=video_data,
+                           audios=audio_data)
+            hf_kwargs_with_mm_data = hf_kwargs | mm_data
             hf_output = hf_model.generate_greedy_logprobs_limit(
                 prompts,
                 max_tokens,
                 num_logprobs=num_logprobs,
                 tokenizer=tokenizer,
-                **hf_kwargs)
+                **hf_kwargs_with_mm_data)
             hf_outputs_per_mm.append(hf_output)
 
     # Apply output processing / sanitation to the vLLM and HF runner results

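run_test now unpacks each PromptWithMultiModalInput into its prompts plus three optional modality lists and merges them into the runner kwargs with the PEP 584 dict union operator, which is what lets the per-test runner_mm_key plumbing go away. A small self-contained illustration of that merge, with made-up kwargs:

# Illustration of the kwargs merge in run_test (the values here are made up).
vllm_kwargs = {"stop": ["</s>"]}
mm_data = dict(images=None, videos=None, audios=[("<waveform>", 16000)])

# Dict union (Python 3.9+): right-hand keys win on conflict and neither
# operand is mutated, so the base kwargs stay reusable across iterations.
vllm_kwargs_with_mm_data = vllm_kwargs | mm_data

assert vllm_kwargs_with_mm_data == {
    "stop": ["</s>"],
    "images": None,
    "videos": None,
    "audios": [("<waveform>", 16000)],
}
assert "images" not in vllm_kwargs  # original kwargs untouched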
View File

@@ -12,7 +12,7 @@ from vllm.multimodal.video import (rescale_video_size, resize_video,
 
 from .....conftest import IMAGE_ASSETS, VIDEO_ASSETS
 from .builders import build_multi_image_inputs, build_single_image_inputs
-from .types import ImageSizeWrapper, SizeType
+from .types import ImageSizeWrapper, PromptWithMultiModalInput, SizeType
 
 
 def multi_image_multi_aspect_ratio_inputs(formatter: Callable[[str], str]):
@@ -32,24 +32,28 @@ def multi_image_multi_aspect_ratio_inputs(formatter: Callable[[str], str]):
         "<image>\nWhat is the season?",
     ]
     formatted_prompts = [formatter(prompt) for prompt in img_prompts]
+    aspect_ratio_images = [
+        [stop_sign, cherry_blossom],
+        # Images with different sizes and aspect-ratios
+        [
+            rescale_image_size(stop_sign, 0.1),
+            stop_sign,
+        ],
+        [
+            stop_sign,
+            rescale_image_size(stop_sign, 0.25),
+            cherry_blossom.resize((183, 488)),
+            cherry_blossom.resize((488, 183))
+        ],
+        cherry_blossom,
+    ]
 
-    return [(
-        formatted_prompts,
-        [
-            [stop_sign, cherry_blossom],
-            # Images with different sizes and aspect-ratios
-            [
-                rescale_image_size(stop_sign, 0.1),
-                stop_sign,
-            ],
-            [
-                stop_sign,
-                rescale_image_size(stop_sign, 0.25),
-                cherry_blossom.resize((183, 488)),
-                cherry_blossom.resize((488, 183))
-            ],
-            cherry_blossom,
-        ])]
+    return [
+        PromptWithMultiModalInput(
+            prompts=formatted_prompts,
+            image_data=aspect_ratio_images,
+        )
+    ]
 
 
 def multi_video_multi_aspect_ratio_inputs(formatter: Callable[[str], str],
@@ -68,24 +72,28 @@ def multi_video_multi_aspect_ratio_inputs(formatter: Callable[[str], str],
         "<video>\nWhy is this video funny?",
     ]
     formatted_prompts = [formatter(prompt) for prompt in video_prompts]
+    aspect_ratio_videos = [
+        [video, video],
+        # Videos with different sizes and aspect-ratios
+        [
+            rescale_video_size(video, 0.1),
+            video,
+        ],
+        [
+            video,
+            rescale_video_size(video, 0.25),
+            resize_video(video, (183, 488)),
+            resize_video(video, (488, 183))
+        ],
+        video,
+    ]
 
-    return [(
-        formatted_prompts,
-        [
-            [video, video],
-            # Videos with different sizes and aspect-ratios
-            [
-                rescale_video_size(video, 0.1),
-                video,
-            ],
-            [
-                video,
-                rescale_video_size(video, 0.25),
-                resize_video(video, (183, 488)),
-                resize_video(video, (488, 183))
-            ],
-            video,
-        ])]
+    return [
+        PromptWithMultiModalInput(
+            prompts=formatted_prompts,
+            video_data=aspect_ratio_videos,
+        )
+    ]
 
 
 def different_patch_input_cases_internvl():

View File

@@ -237,6 +237,18 @@ def minimax_vl_01_hf_output(hf_output: RunnerOutput,
     return output_ids, output_str, out_logprobs
 
 
+def ultravox_trunc_hf_output(hf_output: RunnerOutput,
+                             model: str) -> RunnerOutput:
+    output_ids, output_str, out_logprobs = hf_output
+
+    tokenizer = AutoTokenizer.from_pretrained(model)
+    eos_token_id = tokenizer.eos_token_id
+    eos_token = tokenizer.decode(eos_token_id)
+
+    if output_str.endswith(eos_token):
+        output_str = output_str.split(eos_token)[0]
+
+    return output_ids, output_str, out_logprobs
+
+
 ####### Functions for converting image assets to embeddings
 def get_llava_embeddings(image_assets: ImageTestAssets):
     return [asset.image_embeds for asset in image_assets]

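ultravox_trunc_hf_output takes over from the deleted vllm_to_hf_output in test_ultravox.py but works in the opposite direction: instead of appending the EOS token to the vLLM string, it trims a trailing EOS token from the HF string before check_logprobs_close compares the two. A toy illustration of the truncation; the "<|eot_id|>" literal is only a stand-in for whatever tokenizer.decode(eos_token_id) returns:

# Toy illustration of the truncation step with a placeholder EOS string.
eos_token = "<|eot_id|>"
output_str = "The audio says Mary had a little lamb.<|eot_id|>"

if output_str.endswith(eos_token):
    output_str = output_str.split(eos_token)[0]

assert output_str == "The audio says Mary had a little lamb."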
View File

@@ -4,8 +4,8 @@ types / modalities.
 """
 from pathlib import PosixPath
 
-from .....conftest import (HfRunner, ImageTestAssets, VideoTestAssets,
-                           VllmRunner)
+from .....conftest import (AudioTestAssets, HfRunner, ImageTestAssets,
+                           VideoTestAssets, VllmRunner)
 from . import builders, core
 from .types import ExpandableVLMTestArgs, VLMTestInfo
@@ -30,7 +30,6 @@ def run_single_image_test(*, tmp_path: PosixPath, model_test_info: VLMTestInfo,
         num_logprobs=test_case.num_logprobs,
         limit_mm_per_prompt={"image": 1},
         distributed_executor_backend=test_case.distributed_executor_backend,
-        runner_mm_key="images",
         **model_test_info.get_non_parametrized_runner_kwargs())
@@ -53,7 +52,6 @@ def run_multi_image_test(*, tmp_path: PosixPath, model_test_info: VLMTestInfo,
         num_logprobs=test_case.num_logprobs,
         limit_mm_per_prompt={"image": len(image_assets)},
         distributed_executor_backend=test_case.distributed_executor_backend,
-        runner_mm_key="images",
         **model_test_info.get_non_parametrized_runner_kwargs())
@@ -77,7 +75,6 @@ def run_embedding_test(*, model_test_info: VLMTestInfo,
         limit_mm_per_prompt={"image": 1},
         vllm_embeddings=vllm_embeddings,
         distributed_executor_backend=test_case.distributed_executor_backend,
-        runner_mm_key="images",
         **model_test_info.get_non_parametrized_runner_kwargs())
@@ -105,7 +102,30 @@ def run_video_test(
         num_logprobs=test_case.num_logprobs,
         limit_mm_per_prompt={"video": len(video_assets)},
         distributed_executor_backend=test_case.distributed_executor_backend,
-        runner_mm_key="videos",
+        **model_test_info.get_non_parametrized_runner_kwargs())
+
+
+def run_audio_test(
+    *,
+    model_test_info: VLMTestInfo,
+    test_case: ExpandableVLMTestArgs,
+    hf_runner: type[HfRunner],
+    vllm_runner: type[VllmRunner],
+    audio_assets: AudioTestAssets,
+):
+    inputs = builders.build_audio_inputs_from_test_info(
+        model_test_info, audio_assets)
+
+    core.run_test(
+        hf_runner=hf_runner,
+        vllm_runner=vllm_runner,
+        inputs=inputs,
+        model=test_case.model,
+        dtype=test_case.dtype,
+        max_tokens=test_case.max_tokens,
+        num_logprobs=test_case.num_logprobs,
+        limit_mm_per_prompt={"audio": 1},
+        distributed_executor_backend=test_case.distributed_executor_backend,
         **model_test_info.get_non_parametrized_runner_kwargs())
@@ -120,11 +140,9 @@ def run_custom_inputs_test(*, model_test_info: VLMTestInfo,
 
     inputs = test_case.custom_test_opts.inputs
     limit_mm_per_prompt = test_case.custom_test_opts.limit_mm_per_prompt
-    runner_mm_key = test_case.custom_test_opts.runner_mm_key
-    # Inputs, limit_mm_per_prompt, and runner_mm_key should all be set
+    # Inputs and limit_mm_per_prompt should all be set
     assert inputs is not None
     assert limit_mm_per_prompt is not None
-    assert runner_mm_key is not None
 
     core.run_test(
         hf_runner=hf_runner,
@@ -136,5 +154,4 @@ def run_custom_inputs_test(*, model_test_info: VLMTestInfo,
         num_logprobs=test_case.num_logprobs,
         limit_mm_per_prompt=limit_mm_per_prompt,
         distributed_executor_backend=test_case.distributed_executor_backend,
-        runner_mm_key=runner_mm_key,
         **model_test_info.get_non_parametrized_runner_kwargs())

View File

@@ -6,7 +6,6 @@ from pathlib import PosixPath
 from typing import Any, Callable, NamedTuple, Optional, Union
 
 import torch
-from PIL.Image import Image
 from pytest import MarkDecorator
 from transformers import AutoModelForCausalLM
 from transformers.models.auto.auto_factory import _BaseAutoModelClass
@@ -15,18 +14,25 @@ from vllm.config import TaskOption
 from vllm.sequence import SampleLogprobs
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 
-from .....conftest import IMAGE_ASSETS, HfRunner, ImageAsset, ImageTestAssets
+from .....conftest import (AUDIO_ASSETS, IMAGE_ASSETS, HfRunner, ImageAsset,
+                           ImageTestAssets, PromptAudioInput, PromptImageInput,
+                           PromptVideoInput)
 from ....utils import check_logprobs_close
 
 # meta image tag; will be replaced by the appropriate tag for the model
 TEST_IMG_PLACEHOLDER = "<vlm_image>"
 TEST_VIDEO_PLACEHOLDER = "<vlm_video>"
+TEST_AUDIO_PLACEHOLDER = "<lmm_audio>"
 
 # yapf: disable
 SINGLE_IMAGE_BASE_PROMPTS = IMAGE_ASSETS.prompts({
     "stop_sign": f"{TEST_IMG_PLACEHOLDER}What's the content of the image?",
     "cherry_blossom": f"{TEST_IMG_PLACEHOLDER}What is the season?",
 })
+SINGLE_AUDIO_BASE_PROMPT = AUDIO_ASSETS.prompts({
+    "mary_had_lamb": f"{TEST_AUDIO_PLACEHOLDER}Transcribe this audio into English.",  # noqa: E501
+    "winning_call": f"{TEST_AUDIO_PLACEHOLDER}What is happening in this audio clip?",  # noqa: E501
+})
 MULTI_IMAGE_BASE_PROMPT = f"Image-1: {TEST_IMG_PLACEHOLDER}Image-2: {TEST_IMG_PLACEHOLDER}Describe the two images in detail.\n"  # noqa: E501
 VIDEO_BASE_PROMPT = f"{TEST_VIDEO_PLACEHOLDER}Why is this video funny?"
@@ -38,12 +44,21 @@ RunnerOutput = tuple[list[int], str, Optional[SampleLogprobs]]
 # yapf: enable
 
 
+class PromptWithMultiModalInput(NamedTuple):
+    """Holds the multimodal input for a single test case."""
+    prompts: list[str]
+    image_data: Optional[PromptImageInput] = None
+    video_data: Optional[PromptVideoInput] = None
+    audio_data: Optional[PromptAudioInput] = None
+
+
 class VLMTestType(Enum):
     IMAGE = 1
     MULTI_IMAGE = 2
     EMBEDDING = 3
     VIDEO = 4
-    CUSTOM_INPUTS = 5
+    AUDIO = 5
+    CUSTOM_INPUTS = 6
 
 
 class SizeType(Enum):
@@ -52,10 +67,8 @@ class SizeType(Enum):
 
 
 class CustomTestOptions(NamedTuple):
-    inputs: list[tuple[list[str], list[Union[list[Image], Image]]]]
+    inputs: list[PromptWithMultiModalInput]
     limit_mm_per_prompt: dict[str, int]
-    # kwarg to pass multimodal data in as to vllm/hf runner instances.
-    runner_mm_key: str = "images"
 
 
 class ImageSizeWrapper(NamedTuple):
@@ -75,6 +88,7 @@ class VLMTestInfo(NamedTuple):
     prompt_formatter: Optional[Callable[[str], str]] = None
     img_idx_to_prompt: Callable[[int], str] = lambda idx: "<image>\n"
     video_idx_to_prompt: Callable[[int], str] = lambda idx: "<video>\n"
+    audio_idx_to_prompt: Callable[[int], str] = lambda idx: "<audio>\n"
 
     # Most models work on the single / multi-image prompts above, but in some
    # cases the log prob check fails, e.g., for paligemma. We allow passing
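PromptWithMultiModalInput is what lets a single inputs list carry any mix of modalities: unused fields default to None and core.run_test unpacks all four positionally. A minimal self-contained sketch of constructing and consuming one; the local class below is a stand-in that mirrors the NamedTuple added above (the real one uses the PromptImageInput/PromptVideoInput/PromptAudioInput aliases from conftest), and the waveform placeholder is fabricated:

from typing import NamedTuple, Optional

# Stand-in mirroring the NamedTuple from types.py so this sketch runs alone.
class PromptWithMultiModalInput(NamedTuple):
    prompts: list[str]
    image_data: Optional[list] = None
    video_data: Optional[list] = None
    audio_data: Optional[list] = None

inputs = [
    PromptWithMultiModalInput(
        prompts=["<|audio|>Transcribe this audio into English."],
        audio_data=[("<waveform>", 16000)],  # placeholder (audio, sample_rate)
    )
]

# This is how core.run_test iterates the consolidated inputs: every field is
# unpacked, and the ones left at None are simply forwarded as-is.
for prompts, image_data, video_data, audio_data in inputs:
    mm_data = dict(images=image_data, videos=video_data, audios=audio_data)
    print(prompts, mm_data)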