[Misc] Consolidate Audio tests into multimodal common generation tests (#18214)

Signed-off-by: Isotr0py <2037008807@qq.com>
Isotr0py 2025-05-16 17:18:08 +08:00 committed by GitHub
parent 541817670c
commit 390ec88905
9 changed files with 282 additions and 215 deletions


@ -8,14 +8,14 @@ from collections import defaultdict
from pathlib import PosixPath
import pytest
from transformers import (AutoModelForImageTextToText,
from transformers import (AutoModel, AutoModelForImageTextToText,
AutoModelForTextToWaveform, AutoModelForVision2Seq)
from vllm.platforms import current_platform
from vllm.utils import identity
from ....conftest import (IMAGE_ASSETS, HfRunner, ImageTestAssets,
VideoTestAssets, VllmRunner)
from ....conftest import (IMAGE_ASSETS, AudioTestAssets, HfRunner,
ImageTestAssets, VideoTestAssets, VllmRunner)
from ....utils import (create_new_process_for_each_test, large_gpu_mark,
multi_gpu_marks)
from ...utils import check_outputs_equal
@ -158,6 +158,17 @@ VLM_TEST_SETTINGS = {
image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
marks=[pytest.mark.core_model, pytest.mark.cpu_model],
),
"ultravox": VLMTestInfo(
models = ["fixie-ai/ultravox-v0_5-llama-3_2-1b"],
test_type=VLMTestType.AUDIO,
prompt_formatter=lambda audio_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{audio_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501
audio_idx_to_prompt=lambda idx: "<|audio|>",
max_model_len=4096,
max_num_seqs=2,
auto_cls=AutoModel,
hf_output_post_proc=model_utils.ultravox_trunc_hf_output,
marks=[pytest.mark.core_model, pytest.mark.cpu_model],
),
#### Extended model tests
"aria": VLMTestInfo(
models=["rhymes-ai/Aria"],
@ -393,7 +404,6 @@ VLM_TEST_SETTINGS = {
formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501
),
limit_mm_per_prompt={"video": 4},
runner_mm_key="videos",
)],
),
"llava_next_video": VLMTestInfo(
@ -706,6 +716,7 @@ VLM_TEST_SETTINGS = _mark_splits(VLM_TEST_SETTINGS, num_groups=2)
# - multi-image
# - image embeddings
# - video
# - audio
# - custom inputs
@pytest.mark.parametrize(
"model_type,test_case",
@ -803,6 +814,28 @@ def test_video_models(model_type: str, test_case: ExpandableVLMTestArgs,
)
@pytest.mark.parametrize(
"model_type,test_case",
get_parametrized_options(
VLM_TEST_SETTINGS,
test_type=VLMTestType.AUDIO,
create_new_process_for_each_test=False,
))
def test_audio_models(model_type: str, test_case: ExpandableVLMTestArgs,
hf_runner: type[HfRunner], vllm_runner: type[VllmRunner],
audio_assets: AudioTestAssets, monkeypatch):
if model_type in REQUIRES_V0_MODELS:
monkeypatch.setenv("VLLM_USE_V1", "0")
model_test_info = VLM_TEST_SETTINGS[model_type]
runners.run_audio_test(
model_test_info=model_test_info,
test_case=test_case,
hf_runner=hf_runner,
vllm_runner=vllm_runner,
audio_assets=audio_assets,
)
@pytest.mark.parametrize(
"model_type,test_case",
get_parametrized_options(
@ -930,6 +963,29 @@ def test_video_models_heavy(model_type: str, test_case: ExpandableVLMTestArgs,
)
@pytest.mark.parametrize(
"model_type,test_case",
get_parametrized_options(
VLM_TEST_SETTINGS,
test_type=VLMTestType.AUDIO,
create_new_process_for_each_test=True,
))
def test_audio_models_heavy(model_type: str, test_case: ExpandableVLMTestArgs,
hf_runner: type[HfRunner],
vllm_runner: type[VllmRunner],
audio_assets: AudioTestAssets, monkeypatch):
if model_type in REQUIRES_V0_MODELS:
monkeypatch.setenv("VLLM_USE_V1", "0")
model_test_info = VLM_TEST_SETTINGS[model_type]
runners.run_audio_test(
model_test_info=model_test_info,
test_case=test_case,
hf_runner=hf_runner,
vllm_runner=vllm_runner,
audio_assets=audio_assets,
)
@pytest.mark.parametrize(
"model_type,test_case",
get_parametrized_options(


@ -1,20 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
import json
from typing import Any, Optional
from typing import Any
import numpy as np
import pytest
import pytest_asyncio
from transformers import AutoModel, AutoTokenizer
from transformers import AutoTokenizer
from vllm.multimodal.audio import resample_audio_librosa
from vllm.sequence import SampleLogprobs
from ....conftest import AUDIO_ASSETS, AudioTestAssets, HfRunner, VllmRunner
from ....conftest import AUDIO_ASSETS, AudioTestAssets, VllmRunner
from ....utils import RemoteOpenAIServer
from ...registry import HF_EXAMPLE_MODELS
from ...utils import check_logprobs_close
MODEL_NAME = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
@ -88,79 +84,6 @@ def _get_prompt(audio_count, question, placeholder):
add_generation_prompt=True)
def vllm_to_hf_output(vllm_output: tuple[list[int], str,
Optional[SampleLogprobs]],
model: str):
"""Sanitize vllm output to be comparable with hf output."""
output_ids, output_str, out_logprobs = vllm_output
tokenizer = AutoTokenizer.from_pretrained(model)
eos_token_id = tokenizer.eos_token_id
hf_output_ids = output_ids[:]
hf_output_str = output_str
if hf_output_ids[-1] == eos_token_id:
hf_output_str = hf_output_str + tokenizer.decode(eos_token_id)
return hf_output_ids, hf_output_str, out_logprobs
def run_test(
hf_runner: type[HfRunner],
vllm_runner: type[VllmRunner],
prompts_and_audios: list[tuple[str, str, AudioTuple]],
model: str,
*,
dtype: str,
max_tokens: int,
num_logprobs: int,
**kwargs,
):
"""Inference result should be the same between hf and vllm."""
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
model_info.check_transformers_version(on_fail="skip")
# NOTE: take care of the order. run vLLM first, and then run HF.
# vLLM needs a fresh new process without cuda initialization.
# if we run HF first, the cuda initialization will be done and it
# will hurt multiprocessing backend with fork method (the default method).
with vllm_runner(model, dtype=dtype, enforce_eager=True,
**kwargs) as vllm_model:
vllm_outputs_per_audio = [
vllm_model.generate_greedy_logprobs([vllm_prompt],
max_tokens,
num_logprobs=num_logprobs,
audios=[audio])
for vllm_prompt, _, audio in prompts_and_audios
]
with hf_runner(model, dtype=dtype, auto_cls=AutoModel) as hf_model:
hf_outputs_per_audio = [
hf_model.generate_greedy_logprobs_limit(
[hf_prompt],
max_tokens,
num_logprobs=num_logprobs,
audios=[(resample_audio_librosa(audio[0],
orig_sr=audio[1],
target_sr=16000), 16000)])
for _, hf_prompt, audio in prompts_and_audios
]
for hf_outputs, vllm_outputs in zip(hf_outputs_per_audio,
vllm_outputs_per_audio):
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=[
vllm_to_hf_output(vllm_output, model)
for vllm_output in vllm_outputs
],
name_0="hf",
name_1="vllm",
)
def run_multi_audio_test(
vllm_runner: type[VllmRunner],
prompts_and_audios: list[tuple[str, list[AudioTuple]]],
@ -194,35 +117,6 @@ def run_multi_audio_test(
assert all(tokens for tokens, *_ in vllm_outputs)
@pytest.mark.core_model
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("vllm_kwargs", [
pytest.param({}, marks=pytest.mark.cpu_model),
pytest.param(CHUNKED_PREFILL_KWARGS),
])
def test_models(hf_runner, vllm_runner, audio_assets: AudioTestAssets,
dtype: str, max_tokens: int, num_logprobs: int,
vllm_kwargs: dict) -> None:
audio_inputs = [(
_get_prompt(1, audio, VLLM_PLACEHOLDER),
_get_prompt(1, audio, HF_PLACEHOLDER),
audio.audio_and_sample_rate,
) for audio in audio_assets]
run_test(
hf_runner,
vllm_runner,
audio_inputs,
MODEL_NAME,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
**vllm_kwargs,
)
@pytest.mark.core_model
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])


@ -7,18 +7,21 @@ from typing import Callable, Optional, Union
import torch
from vllm.multimodal.audio import AudioResampler
from vllm.multimodal.image import rescale_image_size
from vllm.multimodal.video import (rescale_video_size, resize_video,
sample_frames_from_video)
from .....conftest import ImageTestAssets, VideoTestAssets
from .types import (SINGLE_IMAGE_BASE_PROMPTS, TEST_IMG_PLACEHOLDER,
from .....conftest import AudioTestAssets, ImageTestAssets, VideoTestAssets
from .types import (SINGLE_AUDIO_BASE_PROMPT, SINGLE_IMAGE_BASE_PROMPTS,
TEST_AUDIO_PLACEHOLDER, TEST_IMG_PLACEHOLDER,
TEST_VIDEO_PLACEHOLDER, VIDEO_BASE_PROMPT,
ImageSizeWrapper, SizeType, VLMTestInfo)
ImageSizeWrapper, PromptWithMultiModalInput, SizeType,
VLMTestInfo)
def replace_test_placeholder(prompt: str, img_idx_to_prompt: Callable[[int],
str],
def replace_test_placeholder(prompt: str, mm_idx_to_prompt: Callable[[int],
str],
test_placeholder: str) -> str:
"""Given a prompt, replaces each test placeholder with the
model-specific tag.
@ -26,7 +29,7 @@ def replace_test_placeholder(prompt: str, img_idx_to_prompt: Callable[[int],
prompt_segments = prompt.split(test_placeholder)
img_prompt = prompt_segments[0]
for placeholder_idx, next_seg in enumerate(prompt_segments[1:], start=1):
img_prompt += img_idx_to_prompt(placeholder_idx)
img_prompt += mm_idx_to_prompt(placeholder_idx)
img_prompt += next_seg
return img_prompt
@ -34,6 +37,7 @@ def replace_test_placeholder(prompt: str, img_idx_to_prompt: Callable[[int],
def get_model_prompts(base_prompts: Iterable[str],
img_idx_to_prompt: Optional[Callable[[int], str]],
video_idx_to_prompt: Optional[Callable[[int], str]],
audio_idx_to_prompt: Optional[Callable[[int], str]],
prompt_formatter: Callable[[str], str]) -> list[str]:
"""Given a model-agnostic base prompt and test configuration for a model(s)
to be tested, update the media placeholders and apply the prompt formatting
@ -60,6 +64,11 @@ def get_model_prompts(base_prompts: Iterable[str],
video_idx_to_prompt,
TEST_VIDEO_PLACEHOLDER)
if audio_idx_to_prompt:
base_prompt = replace_test_placeholder(base_prompt,
audio_idx_to_prompt,
TEST_AUDIO_PLACEHOLDER)
# Apply the prompt formatter to wrap the base prompt with
# the correct media placeholders to get the model test prompt
model_prompt = prompt_formatter(base_prompt)
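
For reference, a minimal self-contained sketch (not part of the diff) of the prompt-building flow above: the generic test placeholder is swapped for the model-specific tag, then the whole prompt is wrapped by the model's prompt formatter. The "<|audio|>" tag and the formatter string mirror the ultravox entry earlier in this diff; everything else is illustrative.

def _replace_test_placeholder(prompt: str, mm_idx_to_prompt, test_placeholder: str) -> str:
    # Same logic as replace_test_placeholder above: split on the generic
    # placeholder and re-join with the model-specific tag per occurrence.
    segments = prompt.split(test_placeholder)
    result = segments[0]
    for idx, segment in enumerate(segments[1:], start=1):
        result += mm_idx_to_prompt(idx)
        result += segment
    return result

base_prompt = "<lmm_audio>Transcribe this audio into English."
tagged = _replace_test_placeholder(base_prompt, lambda idx: "<|audio|>", "<lmm_audio>")
# The model's prompt_formatter then wraps the tagged prompt, e.g. for ultravox:
formatted = (f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
             f"{tagged}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n")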
@ -68,10 +77,11 @@ def get_model_prompts(base_prompts: Iterable[str],
def build_single_image_inputs_from_test_info(
test_info: VLMTestInfo,
image_assets: ImageTestAssets,
size_wrapper: ImageSizeWrapper,
tmp_path: Optional[PosixPath] = None):
test_info: VLMTestInfo,
image_assets: ImageTestAssets,
size_wrapper: ImageSizeWrapper,
tmp_path: Optional[PosixPath] = None,
) -> list[PromptWithMultiModalInput]:
if test_info.prompt_formatter is None:
raise ValueError(
"Prompt formatter must be set to build single image inputs")
@ -79,6 +89,7 @@ def build_single_image_inputs_from_test_info(
model_prompts = get_model_prompts(test_info.single_image_prompts,
test_info.img_idx_to_prompt,
test_info.video_idx_to_prompt,
test_info.audio_idx_to_prompt,
test_info.prompt_formatter)
# For models that require a local path / URL encoded in the image; export
@ -97,28 +108,32 @@ def build_single_image_inputs_from_test_info(
return build_single_image_inputs(images, model_prompts, size_wrapper)
def build_single_image_inputs(images, model_prompts,
size_wrapper: ImageSizeWrapper):
def build_single_image_inputs(
images, model_prompts,
size_wrapper: ImageSizeWrapper) -> list[PromptWithMultiModalInput]:
# For every image / prompt pair, get a pair containing two lists of
# length size_factors, where the first contains duplicates of the model
# prompt [str], and the second contains copies of the image after being
# scaled by one of the size factors.
#
# NOTE: rescaling preserves the image aspect ratio.
return [(
[prompt for _ in size_wrapper.data],
[
apply_image_size_scaling(image, size, size_wrapper.type)
for size in size_wrapper.data
],
) for image, prompt in zip(images, model_prompts)]
return [
PromptWithMultiModalInput(
prompts=[prompt for _ in size_wrapper.data],
image_data=[
apply_image_size_scaling(image, size, size_wrapper.type)
for size in size_wrapper.data
],
) for image, prompt in zip(images, model_prompts)
]
def build_multi_image_inputs_from_test_info(
test_info: VLMTestInfo,
image_assets: ImageTestAssets,
size_wrapper: ImageSizeWrapper,
tmp_path: Optional[PosixPath] = None):
test_info: VLMTestInfo,
image_assets: ImageTestAssets,
size_wrapper: ImageSizeWrapper,
tmp_path: Optional[PosixPath] = None,
) -> list[PromptWithMultiModalInput]:
if test_info.prompt_formatter is None:
raise ValueError(
"Prompt formatter must be set to build multi image inputs")
@ -126,6 +141,7 @@ def build_multi_image_inputs_from_test_info(
model_prompts = get_model_prompts([test_info.multi_image_prompt],
test_info.img_idx_to_prompt,
test_info.video_idx_to_prompt,
test_info.audio_idx_to_prompt,
test_info.prompt_formatter)
if test_info.prompt_path_encoder is not None:
@ -146,15 +162,18 @@ def build_multi_image_inputs_from_test_info(
)
def build_multi_image_inputs(image_lists, model_prompts,
size_wrapper: ImageSizeWrapper):
return [(
[prompt for _ in size_wrapper.data],
[[
apply_image_size_scaling(image, size, size_wrapper.type)
for image in images
] for size in size_wrapper.data],
) for images, prompt in zip(image_lists, model_prompts)]
def build_multi_image_inputs(
image_lists, model_prompts,
size_wrapper: ImageSizeWrapper) -> list[PromptWithMultiModalInput]:
return [
PromptWithMultiModalInput(
prompts=[prompt for _ in size_wrapper.data],
image_data=[[
apply_image_size_scaling(image, size, size_wrapper.type)
for image in images
] for size in size_wrapper.data],
) for images, prompt in zip(image_lists, model_prompts)
]
def build_embedding_inputs_from_test_info(
@ -177,6 +196,7 @@ def build_embedding_inputs_from_test_info(
SINGLE_IMAGE_BASE_PROMPTS,
test_info.img_idx_to_prompt,
test_info.video_idx_to_prompt,
test_info.audio_idx_to_prompt,
test_info.prompt_formatter,
)
@ -195,13 +215,14 @@ def build_video_inputs_from_test_info(
video_assets: VideoTestAssets,
size_wrapper: ImageSizeWrapper,
num_frames: int,
):
) -> list[PromptWithMultiModalInput]:
if test_info.prompt_formatter is None:
raise ValueError("Prompt formatter must be set to build video inputs")
model_prompts = get_model_prompts(
[VIDEO_BASE_PROMPT],
test_info.img_idx_to_prompt,
test_info.video_idx_to_prompt,
test_info.audio_idx_to_prompt,
test_info.prompt_formatter,
)
@ -213,10 +234,14 @@ def build_video_inputs_from_test_info(
video_scaler = (resize_video if size_wrapper.type == SizeType.FIXED_SIZE
else rescale_video_size)
return [(
[prompt for _ in size_wrapper.data],
[video_scaler(video, size) for size in size_wrapper.data],
) for video, prompt in zip(sampled_vids, model_prompts)]
return [
PromptWithMultiModalInput(
prompts=[prompt for _ in size_wrapper.data],
video_data=[
video_scaler(video, size) for size in size_wrapper.data
],
) for video, prompt in zip(sampled_vids, model_prompts)
]
def apply_image_size_scaling(image, size: Union[float, tuple[int, int]],
@ -236,3 +261,37 @@ def apply_image_size_scaling(image, size: Union[float, tuple[int, int]],
# We have a list of fixed sizes
return image.resize(size)
raise ValueError("ImageSizeWrapper type must be FIXED_SIZE or SIZE_FACTOR")
def build_audio_inputs_from_test_info(
test_info: VLMTestInfo,
audio_assets: AudioTestAssets,
) -> list[PromptWithMultiModalInput]:
if test_info.prompt_formatter is None:
raise ValueError("Prompt formatter must be set to build audio inputs")
model_prompts = get_model_prompts(
SINGLE_AUDIO_BASE_PROMPT,
test_info.img_idx_to_prompt,
test_info.video_idx_to_prompt,
test_info.audio_idx_to_prompt,
test_info.prompt_formatter,
)
resampler = AudioResampler(
target_sr=16000,
method="librosa",
)
audios = [asset.audio_and_sample_rate for asset in audio_assets]
resampled_audios = [(
resampler.resample(
audio,
orig_sr=sr,
),
int(resampler.target_sr),
) for audio, sr in audios]
return [
PromptWithMultiModalInput(
prompts=model_prompts,
audio_data=resampled_audios,
)
]
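
A minimal usage sketch (not part of the diff) of the resampling step this builder performs, assuming only the AudioResampler API visible above; the synthetic sine wave is a stand-in for a real asset's audio_and_sample_rate.

import numpy as np

from vllm.multimodal.audio import AudioResampler

orig_sr = 44100
# One second of a 440 Hz tone as a stand-in for a real audio asset.
waveform = np.sin(2 * np.pi * 440 * np.arange(orig_sr) / orig_sr).astype(np.float32)

resampler = AudioResampler(target_sr=16000, method="librosa")
resampled = resampler.resample(waveform, orig_sr=orig_sr)
# The builder pairs each resampled waveform with the target sample rate.
audio_data = [(resampled, int(resampler.target_sr))]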


@ -83,7 +83,7 @@ def get_parametrized_options(test_settings: dict[str, VLMTestInfo],
test_info.num_video_frames)
# No sizes passed for custom inputs, since inputs are directly provided
if test_type != VLMTestType.CUSTOM_INPUTS:
if test_type not in (VLMTestType.CUSTOM_INPUTS, VLMTestType.AUDIO):
wrapped_sizes = get_wrapped_test_sizes(test_info, test_type)
if wrapped_sizes is None:
raise ValueError(
@ -91,7 +91,7 @@ def get_parametrized_options(test_settings: dict[str, VLMTestInfo],
iter_kwargs["size_wrapper"] = wrapped_sizes
#Otherwise expand the custom test options instead
else:
elif test_type == VLMTestType.CUSTOM_INPUTS:
if test_info.custom_test_opts is None:
raise ValueError("Test has type CUSTOM_INPUTS, but none given")
iter_kwargs["custom_test_opts"] = test_info.custom_test_opts
@ -136,8 +136,8 @@ def get_wrapped_test_sizes(
ImageSizeWrapper(type=SizeType.SIZE_FACTOR, data=factor)
for factor in EMBEDDING_SIZE_FACTORS
])
# Custom inputs have preprocessed inputs
elif test_type == VLMTestType.CUSTOM_INPUTS:
# Audio and Custom inputs have preprocessed inputs
elif test_type in (VLMTestType.AUDIO, VLMTestType.CUSTOM_INPUTS):
return tuple()
size_factors = test_info.image_size_factors \


@ -1,9 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
"""Core test implementation to be shared across modalities."""
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Optional
import torch
from PIL.Image import Image
from transformers.models.auto.auto_factory import _BaseAutoModelClass
from vllm.config import TaskOption
@ -11,14 +10,14 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
from .....conftest import HfRunner, VllmRunner
from ....registry import HF_EXAMPLE_MODELS
from .types import RunnerOutput
from .types import PromptWithMultiModalInput, RunnerOutput
def run_test(
*,
hf_runner: type[HfRunner],
vllm_runner: type[VllmRunner],
inputs: list[tuple[list[str], list[Union[list[Image], Image]]]],
inputs: list[PromptWithMultiModalInput],
model: str,
dtype: str,
max_tokens: int,
@ -38,7 +37,6 @@ def run_test(
hf_model_kwargs: Optional[dict[str, Any]],
patch_hf_runner: Optional[Callable[[HfRunner], HfRunner]],
task: TaskOption = "auto",
runner_mm_key: str = "images",
distributed_executor_backend: Optional[str] = None,
tensor_parallel_size: int = 1,
vllm_embeddings: Optional[torch.Tensor] = None,
@ -94,10 +92,16 @@ def run_test(
if stop_str:
vllm_kwargs["stop"] = stop_str
for prompts, media in vllm_inputs:
vllm_kwargs[runner_mm_key] = media
for prompts, image_data, video_data, audio_data in vllm_inputs:
mm_data = dict(images=image_data,
videos=video_data,
audios=audio_data)
vllm_kwargs_with_mm_data = vllm_kwargs | mm_data
vllm_output = vllm_model.generate_greedy_logprobs(
prompts, max_tokens, num_logprobs=num_logprobs, **vllm_kwargs)
prompts,
max_tokens,
num_logprobs=num_logprobs,
**vllm_kwargs_with_mm_data)
vllm_outputs_per_mm.append(vllm_output)
hf_model = hf_runner(model,
@ -122,14 +126,17 @@ def run_test(
if stop_str:
hf_kwargs["stop_strings"] = stop_str
for prompts, media in inputs:
hf_kwargs[runner_mm_key] = media
for prompts, image_data, video_data, audio_data in inputs:
mm_data = dict(images=image_data,
videos=video_data,
audios=audio_data)
hf_kwargs_with_mm_data = hf_kwargs | mm_data
hf_output = hf_model.generate_greedy_logprobs_limit(
prompts,
max_tokens,
num_logprobs=num_logprobs,
tokenizer=tokenizer,
**hf_kwargs)
**hf_kwargs_with_mm_data)
hf_outputs_per_mm.append(hf_output)
# Apply output processing / sanitation to the vLLM and HF runner results
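
The per-modality handling above replaces the old runner_mm_key string with a plain dict merge; a small sketch (not part of the diff, illustrative values only) of that pattern:

# Unused modalities stay None; the diff relies on the runners accepting a
# None value for the images/videos/audios kwargs.
vllm_kwargs = {"stop": ["<|eot_id|>"]}  # illustrative base kwargs
image_data, video_data, audio_data = None, None, [("resampled-ndarray", 16000)]

mm_data = dict(images=image_data, videos=video_data, audios=audio_data)
vllm_kwargs_with_mm_data = vllm_kwargs | mm_data  # dict union, Python 3.9+
assert vllm_kwargs_with_mm_data["audios"][0][1] == 16000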


@ -12,7 +12,7 @@ from vllm.multimodal.video import (rescale_video_size, resize_video,
from .....conftest import IMAGE_ASSETS, VIDEO_ASSETS
from .builders import build_multi_image_inputs, build_single_image_inputs
from .types import ImageSizeWrapper, SizeType
from .types import ImageSizeWrapper, PromptWithMultiModalInput, SizeType
def multi_image_multi_aspect_ratio_inputs(formatter: Callable[[str], str]):
@ -32,24 +32,28 @@ def multi_image_multi_aspect_ratio_inputs(formatter: Callable[[str], str]):
"<image>\nWhat is the season?",
]
formatted_prompts = [formatter(prompt) for prompt in img_prompts]
return [(
formatted_prompts,
aspect_ratio_images = [
[stop_sign, cherry_blossom],
# Images with different sizes and aspect-ratios
[
[stop_sign, cherry_blossom],
# Images with different sizes and aspect-ratios
[
rescale_image_size(stop_sign, 0.1),
stop_sign,
],
[
stop_sign,
rescale_image_size(stop_sign, 0.25),
cherry_blossom.resize((183, 488)),
cherry_blossom.resize((488, 183))
],
cherry_blossom,
])]
rescale_image_size(stop_sign, 0.1),
stop_sign,
],
[
stop_sign,
rescale_image_size(stop_sign, 0.25),
cherry_blossom.resize((183, 488)),
cherry_blossom.resize((488, 183))
],
cherry_blossom,
]
return [
PromptWithMultiModalInput(
prompts=formatted_prompts,
image_data=aspect_ratio_images,
)
]
def multi_video_multi_aspect_ratio_inputs(formatter: Callable[[str], str],
@ -68,24 +72,28 @@ def multi_video_multi_aspect_ratio_inputs(formatter: Callable[[str], str],
"<video>\nWhy is this video funny?",
]
formatted_prompts = [formatter(prompt) for prompt in video_prompts]
return [(
formatted_prompts,
aspect_ratio_videos = [
[video, video],
# Videos with different sizes and aspect-ratios
[
[video, video],
# Videos with different sizes and aspect-ratios
[
rescale_video_size(video, 0.1),
video,
],
[
video,
rescale_video_size(video, 0.25),
resize_video(video, (183, 488)),
resize_video(video, (488, 183))
],
rescale_video_size(video, 0.1),
video,
])]
],
[
video,
rescale_video_size(video, 0.25),
resize_video(video, (183, 488)),
resize_video(video, (488, 183))
],
video,
]
return [
PromptWithMultiModalInput(
prompts=formatted_prompts,
video_data=aspect_ratio_videos,
)
]
def different_patch_input_cases_internvl():


@ -237,6 +237,18 @@ def minimax_vl_01_hf_output(hf_output: RunnerOutput,
return output_ids, output_str, out_logprobs
def ultravox_trunc_hf_output(hf_output: RunnerOutput,
model: str) -> RunnerOutput:
output_ids, output_str, out_logprobs = hf_output
tokenizer = AutoTokenizer.from_pretrained(model)
eos_token_id = tokenizer.eos_token_id
eos_token = tokenizer.decode(eos_token_id)
if output_str.endswith(eos_token):
output_str = output_str.split(eos_token)[0]
return output_ids, output_str, out_logprobs
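
This helper appears to take over from the removed vllm_to_hf_output in the old ultravox test, working in the opposite direction: instead of appending the EOS token to the vLLM string, it strips a trailing EOS from the HF string. A tiny sketch (not part of the diff) with a stand-in EOS token:

eos_token = "<|eot_id|>"  # stand-in for tokenizer.decode(tokenizer.eos_token_id)
output_str = "Mary had a little lamb, its fleece was white as snow.<|eot_id|>"
if output_str.endswith(eos_token):
    output_str = output_str.split(eos_token)[0]
assert output_str == "Mary had a little lamb, its fleece was white as snow."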
####### Functions for converting image assets to embeddings
def get_llava_embeddings(image_assets: ImageTestAssets):
return [asset.image_embeds for asset in image_assets]


@ -4,8 +4,8 @@ types / modalities.
"""
from pathlib import PosixPath
from .....conftest import (HfRunner, ImageTestAssets, VideoTestAssets,
VllmRunner)
from .....conftest import (AudioTestAssets, HfRunner, ImageTestAssets,
VideoTestAssets, VllmRunner)
from . import builders, core
from .types import ExpandableVLMTestArgs, VLMTestInfo
@ -30,7 +30,6 @@ def run_single_image_test(*, tmp_path: PosixPath, model_test_info: VLMTestInfo,
num_logprobs=test_case.num_logprobs,
limit_mm_per_prompt={"image": 1},
distributed_executor_backend=test_case.distributed_executor_backend,
runner_mm_key="images",
**model_test_info.get_non_parametrized_runner_kwargs())
@ -53,7 +52,6 @@ def run_multi_image_test(*, tmp_path: PosixPath, model_test_info: VLMTestInfo,
num_logprobs=test_case.num_logprobs,
limit_mm_per_prompt={"image": len(image_assets)},
distributed_executor_backend=test_case.distributed_executor_backend,
runner_mm_key="images",
**model_test_info.get_non_parametrized_runner_kwargs())
@ -77,7 +75,6 @@ def run_embedding_test(*, model_test_info: VLMTestInfo,
limit_mm_per_prompt={"image": 1},
vllm_embeddings=vllm_embeddings,
distributed_executor_backend=test_case.distributed_executor_backend,
runner_mm_key="images",
**model_test_info.get_non_parametrized_runner_kwargs())
@ -105,7 +102,30 @@ def run_video_test(
num_logprobs=test_case.num_logprobs,
limit_mm_per_prompt={"video": len(video_assets)},
distributed_executor_backend=test_case.distributed_executor_backend,
runner_mm_key="videos",
**model_test_info.get_non_parametrized_runner_kwargs())
def run_audio_test(
*,
model_test_info: VLMTestInfo,
test_case: ExpandableVLMTestArgs,
hf_runner: type[HfRunner],
vllm_runner: type[VllmRunner],
audio_assets: AudioTestAssets,
):
inputs = builders.build_audio_inputs_from_test_info(
model_test_info, audio_assets)
core.run_test(
hf_runner=hf_runner,
vllm_runner=vllm_runner,
inputs=inputs,
model=test_case.model,
dtype=test_case.dtype,
max_tokens=test_case.max_tokens,
num_logprobs=test_case.num_logprobs,
limit_mm_per_prompt={"audio": 1},
distributed_executor_backend=test_case.distributed_executor_backend,
**model_test_info.get_non_parametrized_runner_kwargs())
@ -120,11 +140,9 @@ def run_custom_inputs_test(*, model_test_info: VLMTestInfo,
inputs = test_case.custom_test_opts.inputs
limit_mm_per_prompt = test_case.custom_test_opts.limit_mm_per_prompt
runner_mm_key = test_case.custom_test_opts.runner_mm_key
# Inputs, limit_mm_per_prompt, and runner_mm_key should all be set
# Inputs and limit_mm_per_prompt should all be set
assert inputs is not None
assert limit_mm_per_prompt is not None
assert runner_mm_key is not None
core.run_test(
hf_runner=hf_runner,
@ -136,5 +154,4 @@ def run_custom_inputs_test(*, model_test_info: VLMTestInfo,
num_logprobs=test_case.num_logprobs,
limit_mm_per_prompt=limit_mm_per_prompt,
distributed_executor_backend=test_case.distributed_executor_backend,
runner_mm_key=runner_mm_key,
**model_test_info.get_non_parametrized_runner_kwargs())


@ -6,7 +6,6 @@ from pathlib import PosixPath
from typing import Any, Callable, NamedTuple, Optional, Union
import torch
from PIL.Image import Image
from pytest import MarkDecorator
from transformers import AutoModelForCausalLM
from transformers.models.auto.auto_factory import _BaseAutoModelClass
@ -15,18 +14,25 @@ from vllm.config import TaskOption
from vllm.sequence import SampleLogprobs
from vllm.transformers_utils.tokenizer import AnyTokenizer
from .....conftest import IMAGE_ASSETS, HfRunner, ImageAsset, ImageTestAssets
from .....conftest import (AUDIO_ASSETS, IMAGE_ASSETS, HfRunner, ImageAsset,
ImageTestAssets, PromptAudioInput, PromptImageInput,
PromptVideoInput)
from ....utils import check_logprobs_close
# meta image tag; will be replaced by the appropriate tag for the model
TEST_IMG_PLACEHOLDER = "<vlm_image>"
TEST_VIDEO_PLACEHOLDER = "<vlm_video>"
TEST_AUDIO_PLACEHOLDER = "<lmm_audio>"
# yapf: disable
SINGLE_IMAGE_BASE_PROMPTS = IMAGE_ASSETS.prompts({
"stop_sign": f"{TEST_IMG_PLACEHOLDER}What's the content of the image?",
"cherry_blossom": f"{TEST_IMG_PLACEHOLDER}What is the season?",
})
SINGLE_AUDIO_BASE_PROMPT = AUDIO_ASSETS.prompts({
"mary_had_lamb": f"{TEST_AUDIO_PLACEHOLDER}Transcribe this audio into English.", # noqa: E501
"winning_call": f"{TEST_AUDIO_PLACEHOLDER}What is happening in this audio clip?", # noqa: E501
})
MULTI_IMAGE_BASE_PROMPT = f"Image-1: {TEST_IMG_PLACEHOLDER}Image-2: {TEST_IMG_PLACEHOLDER}Describe the two images in detail.\n" # noqa: E501
VIDEO_BASE_PROMPT = f"{TEST_VIDEO_PLACEHOLDER}Why is this video funny?"
@ -38,12 +44,21 @@ RunnerOutput = tuple[list[int], str, Optional[SampleLogprobs]]
# yapf: enable
class PromptWithMultiModalInput(NamedTuple):
"""Holds the multimodal input for a single test case."""
prompts: list[str]
image_data: Optional[PromptImageInput] = None
video_data: Optional[PromptVideoInput] = None
audio_data: Optional[PromptAudioInput] = None
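
A short sketch (not part of the diff) of how a test case populates this container, assuming the NamedTuple above is in scope. Unused modalities default to None, so the existing image/video builders only set their own field; the waveform here is a stand-in for a real asset.

import numpy as np

fake_waveform = np.zeros(16000, dtype=np.float32)  # 1 s of silence at 16 kHz

audio_case = PromptWithMultiModalInput(
    prompts=["<|audio|>Transcribe this audio into English."],
    audio_data=[(fake_waveform, 16000)],
)
# image_data and video_data are left at their None defaults.
assert audio_case.image_data is None and audio_case.video_data is None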
class VLMTestType(Enum):
IMAGE = 1
MULTI_IMAGE = 2
EMBEDDING = 3
VIDEO = 4
CUSTOM_INPUTS = 5
AUDIO = 5
CUSTOM_INPUTS = 6
class SizeType(Enum):
@ -52,10 +67,8 @@ class SizeType(Enum):
class CustomTestOptions(NamedTuple):
inputs: list[tuple[list[str], list[Union[list[Image], Image]]]]
inputs: list[PromptWithMultiModalInput]
limit_mm_per_prompt: dict[str, int]
# kwarg to pass multimodal data in as to vllm/hf runner instances.
runner_mm_key: str = "images"
class ImageSizeWrapper(NamedTuple):
@ -75,6 +88,7 @@ class VLMTestInfo(NamedTuple):
prompt_formatter: Optional[Callable[[str], str]] = None
img_idx_to_prompt: Callable[[int], str] = lambda idx: "<image>\n"
video_idx_to_prompt: Callable[[int], str] = lambda idx: "<video>\n"
audio_idx_to_prompt: Callable[[int], str] = lambda idx: "<audio>\n"
# Most models work on the single / multi-image prompts above, but in some
# cases the log prob check fails, e.g., for paligemma. We allow passing