[Misc] Consolidate Audio tests into multimodal common generation tests (#18214)
Signed-off-by: Isotr0py <2037008807@qq.com>
parent 541817670c
commit 390ec88905
@@ -8,14 +8,14 @@ from collections import defaultdict
 from pathlib import PosixPath
 
 import pytest
-from transformers import (AutoModelForImageTextToText,
+from transformers import (AutoModel, AutoModelForImageTextToText,
                           AutoModelForTextToWaveform, AutoModelForVision2Seq)
 
 from vllm.platforms import current_platform
 from vllm.utils import identity
 
-from ....conftest import (IMAGE_ASSETS, HfRunner, ImageTestAssets,
-                          VideoTestAssets, VllmRunner)
+from ....conftest import (IMAGE_ASSETS, AudioTestAssets, HfRunner,
+                          ImageTestAssets, VideoTestAssets, VllmRunner)
 from ....utils import (create_new_process_for_each_test, large_gpu_mark,
                        multi_gpu_marks)
 from ...utils import check_outputs_equal
@@ -158,6 +158,17 @@ VLM_TEST_SETTINGS = {
         image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
         marks=[pytest.mark.core_model, pytest.mark.cpu_model],
     ),
+    "ultravox": VLMTestInfo(
+        models=["fixie-ai/ultravox-v0_5-llama-3_2-1b"],
+        test_type=VLMTestType.AUDIO,
+        prompt_formatter=lambda audio_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{audio_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",  # noqa: E501
+        audio_idx_to_prompt=lambda idx: "<|audio|>",
+        max_model_len=4096,
+        max_num_seqs=2,
+        auto_cls=AutoModel,
+        hf_output_post_proc=model_utils.ultravox_trunc_hf_output,
+        marks=[pytest.mark.core_model, pytest.mark.cpu_model],
+    ),
     #### Extended model tests
     "aria": VLMTestInfo(
         models=["rhymes-ai/Aria"],
@@ -393,7 +404,6 @@ VLM_TEST_SETTINGS = {
                 formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n",  # noqa: E501
             ),
             limit_mm_per_prompt={"video": 4},
-            runner_mm_key="videos",
         )],
     ),
     "llava_next_video": VLMTestInfo(
@@ -706,6 +716,7 @@ VLM_TEST_SETTINGS = _mark_splits(VLM_TEST_SETTINGS, num_groups=2)
 # - multi-image
 # - image embeddings
 # - video
+# - audio
 # - custom inputs
 @pytest.mark.parametrize(
     "model_type,test_case",
@@ -803,6 +814,28 @@ def test_video_models(model_type: str, test_case: ExpandableVLMTestArgs,
     )
 
 
+@pytest.mark.parametrize(
+    "model_type,test_case",
+    get_parametrized_options(
+        VLM_TEST_SETTINGS,
+        test_type=VLMTestType.AUDIO,
+        create_new_process_for_each_test=False,
+    ))
+def test_audio_models(model_type: str, test_case: ExpandableVLMTestArgs,
+                      hf_runner: type[HfRunner], vllm_runner: type[VllmRunner],
+                      audio_assets: AudioTestAssets, monkeypatch):
+    if model_type in REQUIRES_V0_MODELS:
+        monkeypatch.setenv("VLLM_USE_V1", "0")
+    model_test_info = VLM_TEST_SETTINGS[model_type]
+    runners.run_audio_test(
+        model_test_info=model_test_info,
+        test_case=test_case,
+        hf_runner=hf_runner,
+        vllm_runner=vllm_runner,
+        audio_assets=audio_assets,
+    )
+
+
 @pytest.mark.parametrize(
     "model_type,test_case",
     get_parametrized_options(
@@ -930,6 +963,29 @@ def test_video_models_heavy(model_type: str, test_case: ExpandableVLMTestArgs,
     )
 
 
+@pytest.mark.parametrize(
+    "model_type,test_case",
+    get_parametrized_options(
+        VLM_TEST_SETTINGS,
+        test_type=VLMTestType.AUDIO,
+        create_new_process_for_each_test=True,
+    ))
+def test_audio_models_heavy(model_type: str, test_case: ExpandableVLMTestArgs,
+                            hf_runner: type[HfRunner],
+                            vllm_runner: type[VllmRunner],
+                            audio_assets: AudioTestAssets, monkeypatch):
+    if model_type in REQUIRES_V0_MODELS:
+        monkeypatch.setenv("VLLM_USE_V1", "0")
+    model_test_info = VLM_TEST_SETTINGS[model_type]
+    runners.run_audio_test(
+        model_test_info=model_test_info,
+        test_case=test_case,
+        hf_runner=hf_runner,
+        vllm_runner=vllm_runner,
+        audio_assets=audio_assets,
+    )
+
+
 @pytest.mark.parametrize(
     "model_type,test_case",
     get_parametrized_options(
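Note: the new audio entries plug into the same parametrization machinery already used for the image and video tests: get_parametrized_options filters VLM_TEST_SETTINGS by VLMTestType and emits one pytest.param per model and test case. Below is a minimal, self-contained sketch of that pattern; SETTINGS and options_for are illustrative stand-ins, not the actual vLLM helpers.

import pytest

# Illustrative stand-in for VLM_TEST_SETTINGS: each model declares which
# test types it supports.
SETTINGS = {
    "ultravox": {"test_types": {"AUDIO"}},
    "llava": {"test_types": {"IMAGE"}},
}


def options_for(test_type: str):
    # One pytest.param per model that declares the requested test type.
    return [
        pytest.param(name, cfg, id=f"{name}-{test_type.lower()}")
        for name, cfg in SETTINGS.items() if test_type in cfg["test_types"]
    ]


@pytest.mark.parametrize("model_type,test_case", options_for("AUDIO"))
def test_audio_models(model_type, test_case):
    # Only "ultravox" is collected here; the real test dispatches to
    # runners.run_audio_test with the selected model's VLMTestInfo.
    assert "AUDIO" in test_case["test_types"]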
@@ -1,20 +1,16 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import json
-from typing import Any, Optional
+from typing import Any
 
 import numpy as np
 import pytest
 import pytest_asyncio
-from transformers import AutoModel, AutoTokenizer
+from transformers import AutoTokenizer
 
-from vllm.multimodal.audio import resample_audio_librosa
-from vllm.sequence import SampleLogprobs
-
-from ....conftest import AUDIO_ASSETS, AudioTestAssets, HfRunner, VllmRunner
+from ....conftest import AUDIO_ASSETS, AudioTestAssets, VllmRunner
 from ....utils import RemoteOpenAIServer
 from ...registry import HF_EXAMPLE_MODELS
-from ...utils import check_logprobs_close
 
 MODEL_NAME = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
 
@@ -88,79 +84,6 @@ def _get_prompt(audio_count, question, placeholder):
                                          add_generation_prompt=True)
 
 
-def vllm_to_hf_output(vllm_output: tuple[list[int], str,
-                                         Optional[SampleLogprobs]],
-                      model: str):
-    """Sanitize vllm output to be comparable with hf output."""
-    output_ids, output_str, out_logprobs = vllm_output
-
-    tokenizer = AutoTokenizer.from_pretrained(model)
-    eos_token_id = tokenizer.eos_token_id
-
-    hf_output_ids = output_ids[:]
-    hf_output_str = output_str
-    if hf_output_ids[-1] == eos_token_id:
-        hf_output_str = hf_output_str + tokenizer.decode(eos_token_id)
-
-    return hf_output_ids, hf_output_str, out_logprobs
-
-
-def run_test(
-    hf_runner: type[HfRunner],
-    vllm_runner: type[VllmRunner],
-    prompts_and_audios: list[tuple[str, str, AudioTuple]],
-    model: str,
-    *,
-    dtype: str,
-    max_tokens: int,
-    num_logprobs: int,
-    **kwargs,
-):
-    """Inference result should be the same between hf and vllm."""
-    model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
-    model_info.check_available_online(on_fail="skip")
-    model_info.check_transformers_version(on_fail="skip")
-
-    # NOTE: take care of the order. run vLLM first, and then run HF.
-    # vLLM needs a fresh new process without cuda initialization.
-    # if we run HF first, the cuda initialization will be done and it
-    # will hurt multiprocessing backend with fork method (the default method).
-
-    with vllm_runner(model, dtype=dtype, enforce_eager=True,
-                     **kwargs) as vllm_model:
-        vllm_outputs_per_audio = [
-            vllm_model.generate_greedy_logprobs([vllm_prompt],
-                                                max_tokens,
-                                                num_logprobs=num_logprobs,
-                                                audios=[audio])
-            for vllm_prompt, _, audio in prompts_and_audios
-        ]
-
-    with hf_runner(model, dtype=dtype, auto_cls=AutoModel) as hf_model:
-        hf_outputs_per_audio = [
-            hf_model.generate_greedy_logprobs_limit(
-                [hf_prompt],
-                max_tokens,
-                num_logprobs=num_logprobs,
-                audios=[(resample_audio_librosa(audio[0],
-                                                orig_sr=audio[1],
-                                                target_sr=16000), 16000)])
-            for _, hf_prompt, audio in prompts_and_audios
-        ]
-
-        for hf_outputs, vllm_outputs in zip(hf_outputs_per_audio,
-                                            vllm_outputs_per_audio):
-            check_logprobs_close(
-                outputs_0_lst=hf_outputs,
-                outputs_1_lst=[
-                    vllm_to_hf_output(vllm_output, model)
-                    for vllm_output in vllm_outputs
-                ],
-                name_0="hf",
-                name_1="vllm",
-            )
-
-
 def run_multi_audio_test(
     vllm_runner: type[VllmRunner],
     prompts_and_audios: list[tuple[str, list[AudioTuple]]],
@@ -194,35 +117,6 @@ def run_multi_audio_test(
     assert all(tokens for tokens, *_ in vllm_outputs)
 
 
-@pytest.mark.core_model
-@pytest.mark.parametrize("dtype", ["bfloat16"])
-@pytest.mark.parametrize("max_tokens", [128])
-@pytest.mark.parametrize("num_logprobs", [5])
-@pytest.mark.parametrize("vllm_kwargs", [
-    pytest.param({}, marks=pytest.mark.cpu_model),
-    pytest.param(CHUNKED_PREFILL_KWARGS),
-])
-def test_models(hf_runner, vllm_runner, audio_assets: AudioTestAssets,
-                dtype: str, max_tokens: int, num_logprobs: int,
-                vllm_kwargs: dict) -> None:
-    audio_inputs = [(
-        _get_prompt(1, audio, VLLM_PLACEHOLDER),
-        _get_prompt(1, audio, HF_PLACEHOLDER),
-        audio.audio_and_sample_rate,
-    ) for audio in audio_assets]
-
-    run_test(
-        hf_runner,
-        vllm_runner,
-        audio_inputs,
-        MODEL_NAME,
-        dtype=dtype,
-        max_tokens=max_tokens,
-        num_logprobs=num_logprobs,
-        **vllm_kwargs,
-    )
-
-
 @pytest.mark.core_model
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [128])
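Note: the deleted run_test above resampled each audio clip to 16 kHz with resample_audio_librosa before handing it to the HF runner. In the consolidated tests that step moves into the shared audio input builder (see the AudioResampler usage in the next file). A minimal sketch of the equivalent resampling step using librosa directly, assuming a mono waveform and the 16 kHz rate the model's audio encoder expects:

import librosa
import numpy as np


def to_16khz(audio: np.ndarray, sr: int) -> tuple[np.ndarray, int]:
    # Resample a mono waveform to 16 kHz; pass through if already there.
    if sr == 16000:
        return audio, sr
    return librosa.resample(audio, orig_sr=sr, target_sr=16000), 16000


# Example with a synthetic 1-second 440 Hz tone sampled at 44.1 kHz.
waveform = np.sin(2 * np.pi * 440 * np.linspace(0, 1, 44100, dtype=np.float32))
resampled, sr = to_16khz(waveform, 44100)
print(sr, resampled.shape)  # -> 16000 (16000,)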
@@ -7,18 +7,21 @@ from typing import Callable, Optional, Union
 
 import torch
 
+from vllm.multimodal.audio import AudioResampler
 from vllm.multimodal.image import rescale_image_size
 from vllm.multimodal.video import (rescale_video_size, resize_video,
                                    sample_frames_from_video)
 
-from .....conftest import ImageTestAssets, VideoTestAssets
-from .types import (SINGLE_IMAGE_BASE_PROMPTS, TEST_IMG_PLACEHOLDER,
+from .....conftest import AudioTestAssets, ImageTestAssets, VideoTestAssets
+from .types import (SINGLE_AUDIO_BASE_PROMPT, SINGLE_IMAGE_BASE_PROMPTS,
+                    TEST_AUDIO_PLACEHOLDER, TEST_IMG_PLACEHOLDER,
                     TEST_VIDEO_PLACEHOLDER, VIDEO_BASE_PROMPT,
-                    ImageSizeWrapper, SizeType, VLMTestInfo)
+                    ImageSizeWrapper, PromptWithMultiModalInput, SizeType,
+                    VLMTestInfo)
 
 
-def replace_test_placeholder(prompt: str, img_idx_to_prompt: Callable[[int],
+def replace_test_placeholder(prompt: str, mm_idx_to_prompt: Callable[[int],
                                                                       str],
                              test_placeholder: str) -> str:
     """Given a prompt, replaces each test placeholder with the
     model-specific tag.
@@ -26,7 +29,7 @@ def replace_test_placeholder(prompt: str, img_idx_to_prompt: Callable[[int],
     prompt_segments = prompt.split(test_placeholder)
     img_prompt = prompt_segments[0]
     for placeholder_idx, next_seg in enumerate(prompt_segments[1:], start=1):
-        img_prompt += img_idx_to_prompt(placeholder_idx)
+        img_prompt += mm_idx_to_prompt(placeholder_idx)
         img_prompt += next_seg
     return img_prompt
 
@@ -34,6 +37,7 @@ def replace_test_placeholder(prompt: str, img_idx_to_prompt: Callable[[int],
 def get_model_prompts(base_prompts: Iterable[str],
                       img_idx_to_prompt: Optional[Callable[[int], str]],
                       video_idx_to_prompt: Optional[Callable[[int], str]],
+                      audio_idx_to_prompt: Optional[Callable[[int], str]],
                       prompt_formatter: Callable[[str], str]) -> list[str]:
     """Given a model-agnostic base prompt and test configuration for a model(s)
     to be tested, update the media placeholders and apply the prompt formatting
@@ -60,6 +64,11 @@ def get_model_prompts(base_prompts: Iterable[str],
                                                    video_idx_to_prompt,
                                                    TEST_VIDEO_PLACEHOLDER)
 
+        if audio_idx_to_prompt:
+            base_prompt = replace_test_placeholder(base_prompt,
+                                                   audio_idx_to_prompt,
+                                                   TEST_AUDIO_PLACEHOLDER)
+
         # Apply the prompt formatter to wrap the base prompt with
         # the correct media placeholders to get the model test prompt
         model_prompt = prompt_formatter(base_prompt)
@@ -68,10 +77,11 @@ def get_model_prompts(base_prompts: Iterable[str],
 
 
 def build_single_image_inputs_from_test_info(
     test_info: VLMTestInfo,
     image_assets: ImageTestAssets,
     size_wrapper: ImageSizeWrapper,
-    tmp_path: Optional[PosixPath] = None):
+    tmp_path: Optional[PosixPath] = None,
+) -> list[PromptWithMultiModalInput]:
     if test_info.prompt_formatter is None:
         raise ValueError(
             "Prompt formatter must be set to build single image inputs")
@@ -79,6 +89,7 @@ def build_single_image_inputs_from_test_info(
     model_prompts = get_model_prompts(test_info.single_image_prompts,
                                       test_info.img_idx_to_prompt,
                                       test_info.video_idx_to_prompt,
+                                      test_info.audio_idx_to_prompt,
                                       test_info.prompt_formatter)
 
     # For models that require a local path / URL encoded in the image; export
@@ -97,28 +108,32 @@ def build_single_image_inputs_from_test_info(
     return build_single_image_inputs(images, model_prompts, size_wrapper)
 
 
-def build_single_image_inputs(images, model_prompts,
-                              size_wrapper: ImageSizeWrapper):
+def build_single_image_inputs(
+        images, model_prompts,
+        size_wrapper: ImageSizeWrapper) -> list[PromptWithMultiModalInput]:
     # For every image / prompt pair, get a pair containing two lists of
     # length size_factors, where the first contains duplicates of the model
     # prompt [str], and the second contains copies of the image after being
     # scaled by one of the size factors.
     #
     # NOTE: rescaling preserves the image aspect ratio.
-    return [(
-        [prompt for _ in size_wrapper.data],
-        [
-            apply_image_size_scaling(image, size, size_wrapper.type)
-            for size in size_wrapper.data
-        ],
-    ) for image, prompt in zip(images, model_prompts)]
+    return [
+        PromptWithMultiModalInput(
+            prompts=[prompt for _ in size_wrapper.data],
+            image_data=[
+                apply_image_size_scaling(image, size, size_wrapper.type)
+                for size in size_wrapper.data
+            ],
+        ) for image, prompt in zip(images, model_prompts)
+    ]
 
 
 def build_multi_image_inputs_from_test_info(
     test_info: VLMTestInfo,
     image_assets: ImageTestAssets,
     size_wrapper: ImageSizeWrapper,
-    tmp_path: Optional[PosixPath] = None):
+    tmp_path: Optional[PosixPath] = None,
+) -> list[PromptWithMultiModalInput]:
     if test_info.prompt_formatter is None:
         raise ValueError(
             "Prompt formatter must be set to build multi image inputs")
@@ -126,6 +141,7 @@ def build_multi_image_inputs_from_test_info(
     model_prompts = get_model_prompts([test_info.multi_image_prompt],
                                       test_info.img_idx_to_prompt,
                                       test_info.video_idx_to_prompt,
+                                      test_info.audio_idx_to_prompt,
                                       test_info.prompt_formatter)
 
     if test_info.prompt_path_encoder is not None:
@@ -146,15 +162,18 @@ def build_multi_image_inputs_from_test_info(
     )
 
 
-def build_multi_image_inputs(image_lists, model_prompts,
-                             size_wrapper: ImageSizeWrapper):
-    return [(
-        [prompt for _ in size_wrapper.data],
-        [[
-            apply_image_size_scaling(image, size, size_wrapper.type)
-            for image in images
-        ] for size in size_wrapper.data],
-    ) for images, prompt in zip(image_lists, model_prompts)]
+def build_multi_image_inputs(
+        image_lists, model_prompts,
+        size_wrapper: ImageSizeWrapper) -> list[PromptWithMultiModalInput]:
+    return [
+        PromptWithMultiModalInput(
+            prompts=[prompt for _ in size_wrapper.data],
+            image_data=[[
+                apply_image_size_scaling(image, size, size_wrapper.type)
+                for image in images
+            ] for size in size_wrapper.data],
+        ) for images, prompt in zip(image_lists, model_prompts)
+    ]
 
 
 def build_embedding_inputs_from_test_info(
@@ -177,6 +196,7 @@ def build_embedding_inputs_from_test_info(
         SINGLE_IMAGE_BASE_PROMPTS,
         test_info.img_idx_to_prompt,
         test_info.video_idx_to_prompt,
+        test_info.audio_idx_to_prompt,
         test_info.prompt_formatter,
     )
 
@@ -195,13 +215,14 @@ def build_video_inputs_from_test_info(
     video_assets: VideoTestAssets,
     size_wrapper: ImageSizeWrapper,
     num_frames: int,
-):
+) -> list[PromptWithMultiModalInput]:
    if test_info.prompt_formatter is None:
         raise ValueError("Prompt formatter must be set to build video inputs")
     model_prompts = get_model_prompts(
         [VIDEO_BASE_PROMPT],
         test_info.img_idx_to_prompt,
         test_info.video_idx_to_prompt,
+        test_info.audio_idx_to_prompt,
         test_info.prompt_formatter,
     )
 
@@ -213,10 +234,14 @@ def build_video_inputs_from_test_info(
     video_scaler = (resize_video if size_wrapper.type == SizeType.FIXED_SIZE
                     else rescale_video_size)
 
-    return [(
-        [prompt for _ in size_wrapper.data],
-        [video_scaler(video, size) for size in size_wrapper.data],
-    ) for video, prompt in zip(sampled_vids, model_prompts)]
+    return [
+        PromptWithMultiModalInput(
+            prompts=[prompt for _ in size_wrapper.data],
+            video_data=[
+                video_scaler(video, size) for size in size_wrapper.data
+            ],
+        ) for video, prompt in zip(sampled_vids, model_prompts)
+    ]
 
 
 def apply_image_size_scaling(image, size: Union[float, tuple[int, int]],
@@ -236,3 +261,37 @@ def apply_image_size_scaling(image, size: Union[float, tuple[int, int]],
         # We have a list of fixed sizes
         return image.resize(size)
     raise ValueError("ImageSizeWrapper type must be FIXED_SIZE or SIZE_FACTOR")
+
+
+def build_audio_inputs_from_test_info(
+    test_info: VLMTestInfo,
+    audio_assets: AudioTestAssets,
+) -> list[PromptWithMultiModalInput]:
+    if test_info.prompt_formatter is None:
+        raise ValueError("Prompt formatter must be set to build audio inputs")
+    model_prompts = get_model_prompts(
+        SINGLE_AUDIO_BASE_PROMPT,
+        test_info.img_idx_to_prompt,
+        test_info.video_idx_to_prompt,
+        test_info.audio_idx_to_prompt,
+        test_info.prompt_formatter,
+    )
+    resampler = AudioResampler(
+        target_sr=16000,
+        method="librosa",
+    )
+    audios = [asset.audio_and_sample_rate for asset in audio_assets]
+    resampled_audios = [(
+        resampler.resample(
+            audio,
+            orig_sr=sr,
+        ),
+        int(resampler.target_sr),
+    ) for audio, sr in audios]
+
+    return [
+        PromptWithMultiModalInput(
+            prompts=model_prompts,
+            audio_data=resampled_audios,
+        )
+    ]
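Note: get_model_prompts now substitutes audio placeholders the same way it already handled image and video tags; the generic TEST_AUDIO_PLACEHOLDER in the base prompt is swapped for the model-specific tag before the prompt formatter runs. A small self-contained sketch of that substitution step, where replace_placeholder is a simplified stand-in for replace_test_placeholder:

from typing import Callable


def replace_placeholder(prompt: str, idx_to_tag: Callable[[int], str],
                         placeholder: str) -> str:
    # Replace each occurrence of the generic placeholder with the
    # model-specific tag for that occurrence index (1-based).
    segments = prompt.split(placeholder)
    out = segments[0]
    for idx, segment in enumerate(segments[1:], start=1):
        out += idx_to_tag(idx) + segment
    return out


base = "<lmm_audio>Transcribe this audio into English."
print(replace_placeholder(base, lambda idx: "<|audio|>", "<lmm_audio>"))
# -> <|audio|>Transcribe this audio into English.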
@@ -83,7 +83,7 @@ def get_parametrized_options(test_settings: dict[str, VLMTestInfo],
                                     test_info.num_video_frames)
 
     # No sizes passed for custom inputs, since inputs are directly provided
-    if test_type != VLMTestType.CUSTOM_INPUTS:
+    if test_type not in (VLMTestType.CUSTOM_INPUTS, VLMTestType.AUDIO):
         wrapped_sizes = get_wrapped_test_sizes(test_info, test_type)
         if wrapped_sizes is None:
             raise ValueError(
@@ -91,7 +91,7 @@ def get_parametrized_options(test_settings: dict[str, VLMTestInfo],
         iter_kwargs["size_wrapper"] = wrapped_sizes
 
     #Otherwise expand the custom test options instead
-    else:
+    elif test_type == VLMTestType.CUSTOM_INPUTS:
         if test_info.custom_test_opts is None:
             raise ValueError("Test has type CUSTOM_INPUTS, but none given")
         iter_kwargs["custom_test_opts"] = test_info.custom_test_opts
@@ -136,8 +136,8 @@ def get_wrapped_test_sizes(
             ImageSizeWrapper(type=SizeType.SIZE_FACTOR, data=factor)
             for factor in EMBEDDING_SIZE_FACTORS
         ])
-    # Custom inputs have preprocessed inputs
-    elif test_type == VLMTestType.CUSTOM_INPUTS:
+    # Audio and Custom inputs have preprocessed inputs
+    elif test_type in (VLMTestType.AUDIO, VLMTestType.CUSTOM_INPUTS):
         return tuple()
 
     size_factors = test_info.image_size_factors \
@@ -1,9 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 """Core test implementation to be shared across modalities."""
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Optional
 
 import torch
-from PIL.Image import Image
 from transformers.models.auto.auto_factory import _BaseAutoModelClass
 
 from vllm.config import TaskOption
@@ -11,14 +10,14 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
 
 from .....conftest import HfRunner, VllmRunner
 from ....registry import HF_EXAMPLE_MODELS
-from .types import RunnerOutput
+from .types import PromptWithMultiModalInput, RunnerOutput
 
 
 def run_test(
     *,
     hf_runner: type[HfRunner],
     vllm_runner: type[VllmRunner],
-    inputs: list[tuple[list[str], list[Union[list[Image], Image]]]],
+    inputs: list[PromptWithMultiModalInput],
     model: str,
     dtype: str,
     max_tokens: int,
@@ -38,7 +37,6 @@ def run_test(
     hf_model_kwargs: Optional[dict[str, Any]],
     patch_hf_runner: Optional[Callable[[HfRunner], HfRunner]],
     task: TaskOption = "auto",
-    runner_mm_key: str = "images",
     distributed_executor_backend: Optional[str] = None,
     tensor_parallel_size: int = 1,
     vllm_embeddings: Optional[torch.Tensor] = None,
@@ -94,10 +92,16 @@ def run_test(
         if stop_str:
             vllm_kwargs["stop"] = stop_str
 
-        for prompts, media in vllm_inputs:
-            vllm_kwargs[runner_mm_key] = media
+        for prompts, image_data, video_data, audio_data in vllm_inputs:
+            mm_data = dict(images=image_data,
+                           videos=video_data,
+                           audios=audio_data)
+            vllm_kwargs_with_mm_data = vllm_kwargs | mm_data
             vllm_output = vllm_model.generate_greedy_logprobs(
-                prompts, max_tokens, num_logprobs=num_logprobs, **vllm_kwargs)
+                prompts,
+                max_tokens,
+                num_logprobs=num_logprobs,
+                **vllm_kwargs_with_mm_data)
             vllm_outputs_per_mm.append(vllm_output)
 
         hf_model = hf_runner(model,
@@ -122,14 +126,17 @@ def run_test(
         if stop_str:
             hf_kwargs["stop_strings"] = stop_str
 
-        for prompts, media in inputs:
-            hf_kwargs[runner_mm_key] = media
+        for prompts, image_data, video_data, audio_data in inputs:
+            mm_data = dict(images=image_data,
+                           videos=video_data,
+                           audios=audio_data)
+            hf_kwargs_with_mm_data = hf_kwargs | mm_data
             hf_output = hf_model.generate_greedy_logprobs_limit(
                 prompts,
                 max_tokens,
                 num_logprobs=num_logprobs,
                 tokenizer=tokenizer,
-                **hf_kwargs)
+                **hf_kwargs_with_mm_data)
             hf_outputs_per_mm.append(hf_output)
 
     # Apply output processing / sanitation to the vLLM and HF runner results
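Note: run_test now unpacks every PromptWithMultiModalInput into separate image, video, and audio entries and merges them into the runner kwargs with the dict union operator, so unset modalities are simply passed through as None. A quick standalone reminder of how that merge behaves (Python 3.9+, PEP 584; the values below are placeholders for illustration):

vllm_kwargs = {"stop": ["</s>"]}
mm_data = dict(images=None, videos=None, audios=["<waveform placeholder>"])

# Dict union produces a new dict; keys on the right-hand side win on
# conflicts and neither input is mutated.
merged = vllm_kwargs | mm_data
print(sorted(merged))   # ['audios', 'images', 'stop', 'videos']
print(vllm_kwargs)      # {'stop': ['</s>']} (unchanged)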
@@ -12,7 +12,7 @@ from vllm.multimodal.video import (rescale_video_size, resize_video,
 
 from .....conftest import IMAGE_ASSETS, VIDEO_ASSETS
 from .builders import build_multi_image_inputs, build_single_image_inputs
-from .types import ImageSizeWrapper, SizeType
+from .types import ImageSizeWrapper, PromptWithMultiModalInput, SizeType
 
 
 def multi_image_multi_aspect_ratio_inputs(formatter: Callable[[str], str]):
@@ -32,24 +32,28 @@ def multi_image_multi_aspect_ratio_inputs(formatter: Callable[[str], str]):
         "<image>\nWhat is the season?",
     ]
     formatted_prompts = [formatter(prompt) for prompt in img_prompts]
-    return [(
-        formatted_prompts,
-        [
-            [stop_sign, cherry_blossom],
-            # Images with different sizes and aspect-ratios
-            [
-                rescale_image_size(stop_sign, 0.1),
-                stop_sign,
-            ],
-            [
-                stop_sign,
-                rescale_image_size(stop_sign, 0.25),
-                cherry_blossom.resize((183, 488)),
-                cherry_blossom.resize((488, 183))
-            ],
-            cherry_blossom,
-        ])]
+    aspect_ratio_images = [
+        [stop_sign, cherry_blossom],
+        # Images with different sizes and aspect-ratios
+        [
+            rescale_image_size(stop_sign, 0.1),
+            stop_sign,
+        ],
+        [
+            stop_sign,
+            rescale_image_size(stop_sign, 0.25),
+            cherry_blossom.resize((183, 488)),
+            cherry_blossom.resize((488, 183))
+        ],
+        cherry_blossom,
+    ]
+
+    return [
+        PromptWithMultiModalInput(
+            prompts=formatted_prompts,
+            image_data=aspect_ratio_images,
+        )
+    ]
 
 
 def multi_video_multi_aspect_ratio_inputs(formatter: Callable[[str], str],
@@ -68,24 +72,28 @@ def multi_video_multi_aspect_ratio_inputs(formatter: Callable[[str], str],
         "<video>\nWhy is this video funny?",
     ]
     formatted_prompts = [formatter(prompt) for prompt in video_prompts]
-    return [(
-        formatted_prompts,
-        [
-            [video, video],
-            # Videos with different sizes and aspect-ratios
-            [
-                rescale_video_size(video, 0.1),
-                video,
-            ],
-            [
-                video,
-                rescale_video_size(video, 0.25),
-                resize_video(video, (183, 488)),
-                resize_video(video, (488, 183))
-            ],
-            video,
-        ])]
+    aspect_ratio_videos = [
+        [video, video],
+        # Videos with different sizes and aspect-ratios
+        [
+            rescale_video_size(video, 0.1),
+            video,
+        ],
+        [
+            video,
+            rescale_video_size(video, 0.25),
+            resize_video(video, (183, 488)),
+            resize_video(video, (488, 183))
+        ],
+        video,
+    ]
+
+    return [
+        PromptWithMultiModalInput(
+            prompts=formatted_prompts,
+            video_data=aspect_ratio_videos,
+        )
+    ]
 
 
 def different_patch_input_cases_internvl():
@@ -237,6 +237,18 @@ def minimax_vl_01_hf_output(hf_output: RunnerOutput,
     return output_ids, output_str, out_logprobs
 
 
+def ultravox_trunc_hf_output(hf_output: RunnerOutput,
+                             model: str) -> RunnerOutput:
+    output_ids, output_str, out_logprobs = hf_output
+
+    tokenizer = AutoTokenizer.from_pretrained(model)
+    eos_token_id = tokenizer.eos_token_id
+    eos_token = tokenizer.decode(eos_token_id)
+    if output_str.endswith(eos_token):
+        output_str = output_str.split(eos_token)[0]
+    return output_ids, output_str, out_logprobs
+
+
 ####### Functions for converting image assets to embeddings
 def get_llava_embeddings(image_assets: ImageTestAssets):
     return [asset.image_embeds for asset in image_assets]
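Note: ultravox_trunc_hf_output replaces the old per-file vllm_to_hf_output sanitizer; instead of appending the EOS token to the vLLM string, it trims a trailing EOS token from the HF string so both sides compare cleanly. The core of that trimming is plain string handling, for example (the EOS string below is just an illustrative Llama-3-style chat token, not necessarily what the model's tokenizer returns):

eos_token = "<|eot_id|>"  # illustrative value for tokenizer.decode(eos_token_id)

hf_text = "The first words of Mary Had a Little Lamb.<|eot_id|>"
if hf_text.endswith(eos_token):
    hf_text = hf_text.split(eos_token)[0]
print(hf_text)  # -> The first words of Mary Had a Little Lamb.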
@@ -4,8 +4,8 @@ types / modalities.
 """
 from pathlib import PosixPath
 
-from .....conftest import (HfRunner, ImageTestAssets, VideoTestAssets,
-                           VllmRunner)
+from .....conftest import (AudioTestAssets, HfRunner, ImageTestAssets,
+                           VideoTestAssets, VllmRunner)
 from . import builders, core
 from .types import ExpandableVLMTestArgs, VLMTestInfo
 
@@ -30,7 +30,6 @@ def run_single_image_test(*, tmp_path: PosixPath, model_test_info: VLMTestInfo,
         num_logprobs=test_case.num_logprobs,
         limit_mm_per_prompt={"image": 1},
         distributed_executor_backend=test_case.distributed_executor_backend,
-        runner_mm_key="images",
         **model_test_info.get_non_parametrized_runner_kwargs())
 
 
@@ -53,7 +52,6 @@ def run_multi_image_test(*, tmp_path: PosixPath, model_test_info: VLMTestInfo,
         num_logprobs=test_case.num_logprobs,
         limit_mm_per_prompt={"image": len(image_assets)},
         distributed_executor_backend=test_case.distributed_executor_backend,
-        runner_mm_key="images",
         **model_test_info.get_non_parametrized_runner_kwargs())
 
 
@@ -77,7 +75,6 @@ def run_embedding_test(*, model_test_info: VLMTestInfo,
         limit_mm_per_prompt={"image": 1},
         vllm_embeddings=vllm_embeddings,
         distributed_executor_backend=test_case.distributed_executor_backend,
-        runner_mm_key="images",
         **model_test_info.get_non_parametrized_runner_kwargs())
 
 
@@ -105,7 +102,30 @@ def run_video_test(
         num_logprobs=test_case.num_logprobs,
         limit_mm_per_prompt={"video": len(video_assets)},
         distributed_executor_backend=test_case.distributed_executor_backend,
-        runner_mm_key="videos",
+        **model_test_info.get_non_parametrized_runner_kwargs())
+
+
+def run_audio_test(
+    *,
+    model_test_info: VLMTestInfo,
+    test_case: ExpandableVLMTestArgs,
+    hf_runner: type[HfRunner],
+    vllm_runner: type[VllmRunner],
+    audio_assets: AudioTestAssets,
+):
+    inputs = builders.build_audio_inputs_from_test_info(
+        model_test_info, audio_assets)
+
+    core.run_test(
+        hf_runner=hf_runner,
+        vllm_runner=vllm_runner,
+        inputs=inputs,
+        model=test_case.model,
+        dtype=test_case.dtype,
+        max_tokens=test_case.max_tokens,
+        num_logprobs=test_case.num_logprobs,
+        limit_mm_per_prompt={"audio": 1},
+        distributed_executor_backend=test_case.distributed_executor_backend,
         **model_test_info.get_non_parametrized_runner_kwargs())
 
 
@@ -120,11 +140,9 @@ def run_custom_inputs_test(*, model_test_info: VLMTestInfo,
 
     inputs = test_case.custom_test_opts.inputs
     limit_mm_per_prompt = test_case.custom_test_opts.limit_mm_per_prompt
-    runner_mm_key = test_case.custom_test_opts.runner_mm_key
-    # Inputs, limit_mm_per_prompt, and runner_mm_key should all be set
+    # Inputs and limit_mm_per_prompt should all be set
     assert inputs is not None
     assert limit_mm_per_prompt is not None
-    assert runner_mm_key is not None
 
     core.run_test(
         hf_runner=hf_runner,
@@ -136,5 +154,4 @@ def run_custom_inputs_test(*, model_test_info: VLMTestInfo,
         num_logprobs=test_case.num_logprobs,
         limit_mm_per_prompt=limit_mm_per_prompt,
         distributed_executor_backend=test_case.distributed_executor_backend,
-        runner_mm_key=runner_mm_key,
         **model_test_info.get_non_parametrized_runner_kwargs())
@@ -6,7 +6,6 @@ from pathlib import PosixPath
 from typing import Any, Callable, NamedTuple, Optional, Union
 
 import torch
-from PIL.Image import Image
 from pytest import MarkDecorator
 from transformers import AutoModelForCausalLM
 from transformers.models.auto.auto_factory import _BaseAutoModelClass
@@ -15,18 +14,25 @@ from vllm.config import TaskOption
 from vllm.sequence import SampleLogprobs
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 
-from .....conftest import IMAGE_ASSETS, HfRunner, ImageAsset, ImageTestAssets
+from .....conftest import (AUDIO_ASSETS, IMAGE_ASSETS, HfRunner, ImageAsset,
+                           ImageTestAssets, PromptAudioInput, PromptImageInput,
+                           PromptVideoInput)
 from ....utils import check_logprobs_close
 
 # meta image tag; will be replaced by the appropriate tag for the model
 TEST_IMG_PLACEHOLDER = "<vlm_image>"
 TEST_VIDEO_PLACEHOLDER = "<vlm_video>"
+TEST_AUDIO_PLACEHOLDER = "<lmm_audio>"
 
 # yapf: disable
 SINGLE_IMAGE_BASE_PROMPTS = IMAGE_ASSETS.prompts({
     "stop_sign": f"{TEST_IMG_PLACEHOLDER}What's the content of the image?",
     "cherry_blossom": f"{TEST_IMG_PLACEHOLDER}What is the season?",
 })
+SINGLE_AUDIO_BASE_PROMPT = AUDIO_ASSETS.prompts({
+    "mary_had_lamb": f"{TEST_AUDIO_PLACEHOLDER}Transcribe this audio into English.",  # noqa: E501
+    "winning_call": f"{TEST_AUDIO_PLACEHOLDER}What is happening in this audio clip?",  # noqa: E501
+})
 
 MULTI_IMAGE_BASE_PROMPT = f"Image-1: {TEST_IMG_PLACEHOLDER}Image-2: {TEST_IMG_PLACEHOLDER}Describe the two images in detail.\n"  # noqa: E501
 VIDEO_BASE_PROMPT = f"{TEST_VIDEO_PLACEHOLDER}Why is this video funny?"
@@ -38,12 +44,21 @@ RunnerOutput = tuple[list[int], str, Optional[SampleLogprobs]]
 # yapf: enable
 
 
+class PromptWithMultiModalInput(NamedTuple):
+    """Holds the multimodal input for a single test case."""
+    prompts: list[str]
+    image_data: Optional[PromptImageInput] = None
+    video_data: Optional[PromptVideoInput] = None
+    audio_data: Optional[PromptAudioInput] = None
+
+
 class VLMTestType(Enum):
     IMAGE = 1
     MULTI_IMAGE = 2
     EMBEDDING = 3
     VIDEO = 4
-    CUSTOM_INPUTS = 5
+    AUDIO = 5
+    CUSTOM_INPUTS = 6
 
 
 class SizeType(Enum):
@@ -52,10 +67,8 @@ class SizeType(Enum):
 
 
 class CustomTestOptions(NamedTuple):
-    inputs: list[tuple[list[str], list[Union[list[Image], Image]]]]
+    inputs: list[PromptWithMultiModalInput]
     limit_mm_per_prompt: dict[str, int]
-    # kwarg to pass multimodal data in as to vllm/hf runner instances.
-    runner_mm_key: str = "images"
 
 
 class ImageSizeWrapper(NamedTuple):
@@ -75,6 +88,7 @@ class VLMTestInfo(NamedTuple):
     prompt_formatter: Optional[Callable[[str], str]] = None
     img_idx_to_prompt: Callable[[int], str] = lambda idx: "<image>\n"
     video_idx_to_prompt: Callable[[int], str] = lambda idx: "<video>\n"
+    audio_idx_to_prompt: Callable[[int], str] = lambda idx: "<audio>\n"
 
     # Most models work on the single / multi-image prompts above, but in some
     # cases the log prob check fails, e.g., for paligemma. We allow passing
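Note: PromptWithMultiModalInput is what lets a single harness cover all modalities. Every builder returns a list of these records, and core.run_test unpacks the fields positionally (prompts, image_data, video_data, audio_data), with unused modalities left at their None default. A minimal sketch of constructing and unpacking such a record; the class is re-declared standalone here for illustration, the real definition lives in the types module shown above:

from typing import NamedTuple, Optional


class PromptWithMultiModalInput(NamedTuple):
    prompts: list[str]
    image_data: Optional[list] = None
    video_data: Optional[list] = None
    audio_data: Optional[list] = None


# An audio-only test case: prompts plus (waveform, sample_rate) tuples.
case = PromptWithMultiModalInput(
    prompts=["<|audio|>Transcribe this audio into English."],
    audio_data=[([0.0, 0.1, 0.0], 16000)],
)

# Positional NamedTuple unpacking mirrors the loop in core.run_test.
prompts, image_data, video_data, audio_data = case
assert image_data is None and video_data is None
print(prompts, audio_data)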