[VLM] Add Qwen3-VL generation test (#25185)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.io>
Authored by Isotr0py on 2025-10-29 20:19:37 +08:00; committed by GitHub
parent 3481e40743
commit ad3ec89532
7 changed files with 108 additions and 5 deletions


@@ -159,6 +159,28 @@ VLM_TEST_SETTINGS = {
image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
marks=[pytest.mark.core_model, pytest.mark.cpu_model],
),
"qwen3_vl": VLMTestInfo(
models=["Qwen/Qwen3-VL-4B-Instruct"],
test_type=(
VLMTestType.IMAGE,
VLMTestType.MULTI_IMAGE,
VLMTestType.VIDEO,
),
needs_video_metadata=True,
prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501
img_idx_to_prompt=lambda idx: "<|vision_start|><|image_pad|><|vision_end|>", # noqa: E501
video_idx_to_prompt=lambda idx: "<|vision_start|><|video_pad|><|vision_end|>", # noqa: E501
max_model_len=4096,
max_num_seqs=2,
num_logprobs=20,
auto_cls=AutoModelForImageTextToText,
vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output,
patch_hf_runner=model_utils.qwen3_vl_patch_hf_runner,
image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
marks=[
pytest.mark.core_model,
],
),
"ultravox": VLMTestInfo(
models=["fixie-ai/ultravox-v0_5-llama-3_2-1b"],
test_type=VLMTestType.AUDIO,
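
For orientation, the two lambdas in the new `qwen3_vl` entry compose like this when the harness builds a single-image prompt; the question text below is a made-up placeholder, not part of the test settings:

```python
# Sketch of how img_idx_to_prompt and prompt_formatter combine for one image.
img_prompt = "<|vision_start|><|image_pad|><|vision_end|>" + "What is in this image?"
prompt = f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n"
print(prompt)
# <|im_start|>User
# <|vision_start|><|image_pad|><|vision_end|>What is in this image?<|im_end|>
# <|im_start|>assistant
```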


@@ -4,7 +4,9 @@
from collections.abc import Callable, Iterable
from pathlib import PosixPath
from typing import Any
import numpy.typing as npt
import torch
from vllm.multimodal.audio import AudioResampler
@@ -236,6 +238,7 @@ def build_video_inputs_from_test_info(
video_assets: VideoTestAssets,
size_wrapper: ImageSizeWrapper,
num_frames: int,
needs_video_metadata: bool,
) -> list[PromptWithMultiModalInput]:
if test_info.prompt_formatter is None:
raise ValueError("Prompt formatter must be set to build video inputs")
@@ -248,7 +251,10 @@
)
sampled_vids = [
-        sample_frames_from_video(asset.np_ndarrays, num_frames)
+        sample_frames_with_video_metadata(
+            (asset.np_ndarrays, asset.metadata),
+            num_frames,
+        )
for asset in video_assets
]
@@ -259,12 +265,33 @@
return [
PromptWithMultiModalInput(
prompts=[prompt for _ in size_wrapper.data],
-            video_data=[video_scaler(video, size) for size in size_wrapper.data],
+            video_data=[
+                (
+                    video_scaler(video, size)
+                    if not needs_video_metadata
+                    else (video_scaler(video, size), meta)
+                )
+                for size in size_wrapper.data
+            ],
        )
-        for video, prompt in zip(sampled_vids, model_prompts)
+        for (video, meta), prompt in zip(sampled_vids, model_prompts)
]

def sample_frames_with_video_metadata(
    video_with_meta: tuple[npt.NDArray, dict[str, Any]],
    num_frames: int,
) -> tuple[npt.NDArray, dict[str, Any]]:
    video, meta = video_with_meta
    video = sample_frames_from_video(video, num_frames)

    # Keep the metadata consistent with the frames that were actually kept.
    meta["do_sample_frames"] = meta["total_num_frames"] == num_frames
    meta["total_num_frames"] = num_frames
    meta["fps"] = num_frames / meta["duration"]
    meta["frames_indices"] = list(range(num_frames))
    return video, meta
def apply_image_size_scaling(image, size: float | tuple[int, int], size_type: SizeType):
"""Applies a size scaler to one image; this can be an image size factor,
which scales the image while maintaining the aspect ratio"""
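
To see what the new helper does end to end, here is a minimal sketch with dummy values shaped like `video_get_metadata`'s output (last file in this commit); the clip dimensions are illustrative only:

```python
import numpy as np

# Dummy 10-second clip: 30 frames with metadata as produced by
# video_get_metadata.
video = np.zeros((30, 64, 64, 3), dtype=np.uint8)
meta = {
    "total_num_frames": 30,
    "fps": 3.0,
    "duration": 10.0,
    "video_backend": "opencv",
    "frames_indices": list(range(30)),
}

sampled, meta = sample_frames_with_video_metadata((video, meta), num_frames=16)
assert sampled.shape[0] == 16
assert meta["do_sample_frames"] is False   # 30 != 16: frames were pre-sampled
assert meta["fps"] == 16 / 10.0            # effective fps of the kept frames
assert meta["frames_indices"] == list(range(16))
```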


@@ -100,6 +100,9 @@ def get_parametrized_options(
# num_frames is video only
if test_type == VLMTestType.VIDEO:
iter_kwargs["num_video_frames"] = ensure_wrapped(test_info.num_video_frames)
iter_kwargs["needs_video_metadata"] = ensure_wrapped(
test_info.needs_video_metadata
)
# No sizes passed for custom inputs, since inputs are directly provided
if test_type not in (VLMTestType.CUSTOM_INPUTS, VLMTestType.AUDIO):
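
For context, `ensure_wrapped` turns scalar settings into iterables so every axis can go through the same combinatorial expansion; a hypothetical stand-in illustrating the effect (not the actual helper):

```python
import itertools

def ensure_wrapped(value):
    # Hypothetical stand-in for the real helper: wrap scalars in a
    # 1-tuple so itertools.product can iterate every axis uniformly.
    return value if isinstance(value, (list, tuple)) else (value,)

iter_kwargs = {
    "num_video_frames": ensure_wrapped(16),
    "needs_video_metadata": ensure_wrapped(True),
}
cases = [
    dict(zip(iter_kwargs, combo))
    for combo in itertools.product(*iter_kwargs.values())
]
# -> [{'num_video_frames': 16, 'needs_video_metadata': True}]
```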


@@ -905,6 +905,54 @@ def qwen2_5_omni_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
return hf_model
def qwen3_vl_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
"""Patches and returns an instance of the HfRunner to use for GLM4.1V."""
hf_processor = hf_model.processor
def processor(*args, videos=None, **kwargs):
if videos is not None and is_list_of(videos, tuple):
# batched multi videos
do_sample_frames = {video[1]["do_sample_frames"] for video in videos}
assert len(do_sample_frames) == 1
if kwargs.get("do_sample_frames") is None:
kwargs["do_sample_frames"] = do_sample_frames
video_metadata = [
[
VideoMetadata(
**{k: v for k, v in video[1].items() if k != "do_sample_frames"}
)
]
for video in videos
]
videos = [[video[0]] for video in videos]
elif videos is not None and isinstance(videos, tuple):
# single video
do_sample_frames = videos[1]["do_sample_frames"]
if kwargs.get("do_sample_frames") is None:
kwargs["do_sample_frames"] = do_sample_frames
video_metadata = [
[
VideoMetadata(
**{
k: v
for k, v in videos[1].items()
if k != "do_sample_frames"
}
)
]
]
videos = [[videos[0]]]
else:
video_metadata = None
return hf_processor(
*args, videos=videos, video_metadata=video_metadata, **kwargs
)
hf_model.processor = processor
return hf_model
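
In short, the patched processor unpacks each `(frames, metadata)` tuple the builders now produce, forwards the shared `do_sample_frames` flag, and rebuilds the remaining keys as a transformers `VideoMetadata`. A minimal sketch of that last step (the import path is my assumption; check your transformers version):

```python
# Assumed import path for recent transformers releases.
from transformers.video_utils import VideoMetadata

meta = {
    "total_num_frames": 16,
    "fps": 1.6,
    "duration": 10.0,
    "video_backend": "opencv",
    "frames_indices": list(range(16)),
    "do_sample_frames": False,  # test-only flag, stripped below
}
# Same filtering the patched processor applies before calling HF:
hf_meta = VideoMetadata(
    **{k: v for k, v in meta.items() if k != "do_sample_frames"}
)
```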
def tarsier_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
from vllm.model_executor.models.tarsier import get_vision_encoder_info


@@ -117,6 +117,7 @@ def run_video_test(
video_assets,
test_case.size_wrapper,
test_case.num_video_frames,
test_case.needs_video_metadata,
)
core.run_test(


@@ -154,7 +154,8 @@ class VLMTestInfo(NamedTuple):
dtype: str = "auto"
distributed_executor_backend: str | None = None
# Only expanded in video tests
-    num_video_frames: int = 16
+    num_video_frames: int | tuple[int, ...] = 16
+    needs_video_metadata: bool = False
# Fixed image sizes / image size factors; most tests use image_size_factors
# The values provided for these two fields will be stacked and expanded
@@ -212,5 +213,6 @@ class ExpandableVLMTestArgs(NamedTuple):
size_wrapper: ImageSizeWrapper | None = None
# Video only
num_video_frames: int | None = None
needs_video_metadata: bool = False
# Custom inputs only
custom_test_opts: CustomTestOptions | None = None
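
The widened `num_video_frames: int | tuple[int, ...]` annotation lets a model entry sweep several frame counts, with each element becoming its own expanded case. A hypothetical sketch of the expansion (names mirror the fields above):

```python
# Hypothetical: a model entry asking for 16- and 32-frame runs.
num_video_frames = (16, 32)

expanded = [
    {"num_video_frames": n, "needs_video_metadata": True}
    for n in num_video_frames
]
# -> two video test cases, both carrying metadata alongside the frames
```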


@@ -94,7 +94,7 @@ def video_get_metadata(path: str, num_frames: int = -1) -> dict[str, Any]:
metadata = {
"total_num_frames": num_frames,
"fps": fps,
"fps": duration / num_frames,
"duration": duration,
"video_backend": "opencv",
"frames_indices": list(range(num_frames)),