[VLM] Add Qwen3-VL generation test (#25185)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.io>
Isotr0py authored on 2025-10-29 20:19:37 +08:00; committed by GitHub
parent 3481e40743
commit ad3ec89532
7 changed files with 108 additions and 5 deletions

View File

@@ -159,6 +159,28 @@ VLM_TEST_SETTINGS = {
        image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
        marks=[pytest.mark.core_model, pytest.mark.cpu_model],
    ),
    "qwen3_vl": VLMTestInfo(
        models=["Qwen/Qwen3-VL-4B-Instruct"],
        test_type=(
            VLMTestType.IMAGE,
            VLMTestType.MULTI_IMAGE,
            VLMTestType.VIDEO,
        ),
        needs_video_metadata=True,
        prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n",  # noqa: E501
        img_idx_to_prompt=lambda idx: "<|vision_start|><|image_pad|><|vision_end|>",  # noqa: E501
        video_idx_to_prompt=lambda idx: "<|vision_start|><|video_pad|><|vision_end|>",  # noqa: E501
        max_model_len=4096,
        max_num_seqs=2,
        num_logprobs=20,
        auto_cls=AutoModelForImageTextToText,
        vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output,
        patch_hf_runner=model_utils.qwen3_vl_patch_hf_runner,
        image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
        marks=[
            pytest.mark.core_model,
        ],
    ),
    "ultravox": VLMTestInfo(
        models=["fixie-ai/ultravox-v0_5-llama-3_2-1b"],
        test_type=VLMTestType.AUDIO,
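For reference, the prompt_formatter and img_idx_to_prompt lambdas in the new "qwen3_vl" entry compose the final HF prompt roughly as follows. This is a minimal sketch: the actual assembly happens inside the test harness, and the question text is an illustrative placeholder.

    # Copied from the "qwen3_vl" entry above; the composition order is assumed.
    img_idx_to_prompt = lambda idx: "<|vision_start|><|image_pad|><|vision_end|>"
    prompt_formatter = lambda img_prompt: (
        f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n"
    )

    question = "What is in the image?"  # illustrative placeholder
    print(prompt_formatter(img_idx_to_prompt(0) + question))
    # <|im_start|>User
    # <|vision_start|><|image_pad|><|vision_end|>What is in the image?<|im_end|>
    # <|im_start|>assistant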

View File

@@ -4,7 +4,9 @@
from collections.abc import Callable, Iterable
from pathlib import PosixPath
from typing import Any

import numpy.typing as npt
import torch

from vllm.multimodal.audio import AudioResampler
@@ -236,6 +238,7 @@ def build_video_inputs_from_test_info(
    video_assets: VideoTestAssets,
    size_wrapper: ImageSizeWrapper,
    num_frames: int,
    needs_video_metadata: bool,
) -> list[PromptWithMultiModalInput]:
    if test_info.prompt_formatter is None:
        raise ValueError("Prompt formatter must be set to build video inputs")
@@ -248,7 +251,10 @@
    )

    sampled_vids = [
        sample_frames_with_video_metadata(
            (asset.np_ndarrays, asset.metadata),
            num_frames,
        )
        for asset in video_assets
    ]
@@ -259,12 +265,33 @@
    return [
        PromptWithMultiModalInput(
            prompts=[prompt for _ in size_wrapper.data],
            video_data=[
                (
                    video_scaler(video, size)
                    if not needs_video_metadata
                    else (video_scaler(video, size), meta)
                )
                for size in size_wrapper.data
            ],
        )
        for (video, meta), prompt in zip(sampled_vids, model_prompts)
    ]


def sample_frames_with_video_metadata(
    video_with_meta: tuple[npt.NDArray, dict[str, Any]],
    num_frames: int,
) -> tuple[npt.NDArray, dict[str, Any]]:
    video, meta = video_with_meta
    video = sample_frames_from_video(video, num_frames)
    meta["do_sample_frames"] = meta["total_num_frames"] == num_frames
    meta["total_num_frames"] = num_frames
meta["fps"] = meta["duration"] / num_frames
meta["frames_indices"] = list(range(num_frames))
return video, meta
def apply_image_size_scaling(image, size: float | tuple[int, int], size_type: SizeType): def apply_image_size_scaling(image, size: float | tuple[int, int], size_type: SizeType):
"""Applies a size scaler to one image; this can be an image size factor, """Applies a size scaler to one image; this can be an image size factor,
which scales the image while maintaining the aspect ratio""" which scales the image while maintaining the aspect ratio"""
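A minimal usage sketch for sample_frames_with_video_metadata above, with a synthetic clip; the metadata keys mirror those produced by video_get_metadata (see the last hunk in this commit), and all values are illustrative.

    import numpy as np

    # Hypothetical 32-frame clip recorded at 2 fps (16 s).
    video = np.zeros((32, 224, 224, 3), dtype=np.uint8)
    meta = {
        "total_num_frames": 32,
        "fps": 2.0,
        "duration": 16.0,
        "video_backend": "opencv",
        "frames_indices": list(range(32)),
    }

    sampled, meta = sample_frames_with_video_metadata((video, meta), num_frames=16)
    assert sampled.shape[0] == 16
    assert meta["do_sample_frames"] is False  # 32 != 16: frames were pre-sampled here
    assert meta["fps"] == 1.0  # effective rate after sampling: 16 frames / 16 s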

View File

@@ -100,6 +100,9 @@ def get_parametrized_options(
    # num_frames is video only
    if test_type == VLMTestType.VIDEO:
        iter_kwargs["num_video_frames"] = ensure_wrapped(test_info.num_video_frames)
        iter_kwargs["needs_video_metadata"] = ensure_wrapped(
            test_info.needs_video_metadata
        )

    # No sizes passed for custom inputs, since inputs are directly provided
    if test_type not in (VLMTestType.CUSTOM_INPUTS, VLMTestType.AUDIO):
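The harness expands iter_kwargs into one pytest case per combination of options. ensure_wrapped's implementation is not shown in this diff, so the sketch below assumes its behavior: wrapping scalars so every option is iterable before the Cartesian product.

    from itertools import product

    def ensure_wrapped(value):
        # Assumed behavior: make every option iterable for expansion.
        return value if isinstance(value, (list, tuple)) else (value,)

    iter_kwargs = {
        "num_video_frames": ensure_wrapped(16),
        "needs_video_metadata": ensure_wrapped(True),
    }
    # One test case per combination of expanded options.
    cases = [dict(zip(iter_kwargs, combo)) for combo in product(*iter_kwargs.values())]
    # -> [{'num_video_frames': 16, 'needs_video_metadata': True}]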

View File

@@ -905,6 +905,54 @@ def qwen2_5_omni_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
    return hf_model


def qwen3_vl_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
"""Patches and returns an instance of the HfRunner to use for GLM4.1V."""
    hf_processor = hf_model.processor

    def processor(*args, videos=None, **kwargs):
        if videos is not None and is_list_of(videos, tuple):
            # batched multi videos
            do_sample_frames = {video[1]["do_sample_frames"] for video in videos}
            assert len(do_sample_frames) == 1
            if kwargs.get("do_sample_frames") is None:
kwargs["do_sample_frames"] = do_sample_frames
            video_metadata = [
                [
                    VideoMetadata(
                        **{k: v for k, v in video[1].items() if k != "do_sample_frames"}
                    )
                ]
                for video in videos
            ]
            videos = [[video[0]] for video in videos]
        elif videos is not None and isinstance(videos, tuple):
            # single video
            do_sample_frames = videos[1]["do_sample_frames"]
            if kwargs.get("do_sample_frames") is None:
                kwargs["do_sample_frames"] = do_sample_frames
            video_metadata = [
                [
                    VideoMetadata(
                        **{
                            k: v
                            for k, v in videos[1].items()
                            if k != "do_sample_frames"
                        }
                    )
                ]
            ]
            videos = [[videos[0]]]
        else:
            video_metadata = None
        return hf_processor(
            *args, videos=videos, video_metadata=video_metadata, **kwargs
        )

    hf_model.processor = processor

    return hf_model
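Once patched, the runner's processor accepts (frames, metadata) tuples and forwards the metadata to the HF processor as transformers VideoMetadata instances. A hedged usage sketch, assuming hf_model is an existing HfRunner and with illustrative prompt and frame contents:

    import numpy as np

    frames = np.zeros((16, 224, 224, 3), dtype=np.uint8)
    meta = {
        "total_num_frames": 16,
        "fps": 1.0,
        "duration": 16.0,
        "video_backend": "opencv",
        "frames_indices": list(range(16)),
        "do_sample_frames": False,  # frames were already sampled upstream
    }

    hf_model = qwen3_vl_patch_hf_runner(hf_model)
    # Single-video branch: a bare (frames, meta) tuple triggers the isinstance check.
    inputs = hf_model.processor(
        text="<|im_start|>User\n<|vision_start|><|video_pad|><|vision_end|>"
        "Describe the video.<|im_end|>\n<|im_start|>assistant\n",
        videos=(frames, meta),
        return_tensors="pt",
    )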

def tarsier_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
    from vllm.model_executor.models.tarsier import get_vision_encoder_info

View File

@@ -117,6 +117,7 @@ def run_video_test(
        video_assets,
        test_case.size_wrapper,
        test_case.num_video_frames,
        test_case.needs_video_metadata,
    )

    core.run_test(

View File

@@ -154,7 +154,8 @@ class VLMTestInfo(NamedTuple):
    dtype: str = "auto"
    distributed_executor_backend: str | None = None
    # Only expanded in video tests
    num_video_frames: int | tuple[int] = 16
    needs_video_metadata: bool = False

    # Fixed image sizes / image size factors; most tests use image_size_factors
    # The values provided for these two fields will be stacked and expanded

@@ -212,5 +213,6 @@ class ExpandableVLMTestArgs(NamedTuple):
    size_wrapper: ImageSizeWrapper | None = None
    # Video only
    num_video_frames: int | None = None
    needs_video_metadata: bool = False
    # Custom inputs only
    custom_test_opts: CustomTestOptions | None = None

View File

@@ -94,7 +94,7 @@ def video_get_metadata(path: str, num_frames: int = -1) -> dict[str, Any]:
    metadata = {
        "total_num_frames": num_frames,
"fps": fps, "fps": duration / num_frames,
"duration": duration, "duration": duration,
"video_backend": "opencv", "video_backend": "opencv",
"frames_indices": list(range(num_frames)), "frames_indices": list(range(num_frames)),