mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 08:15:56 +08:00
[VLM] Add Qwen3-VL generation test (#25185)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn> Signed-off-by: Roger Wang <hey@rogerw.io> Co-authored-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
parent
3481e40743
commit
ad3ec89532
@ -159,6 +159,28 @@ VLM_TEST_SETTINGS = {
|
||||
image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
|
||||
marks=[pytest.mark.core_model, pytest.mark.cpu_model],
|
||||
),
|
||||
"qwen3_vl": VLMTestInfo(
|
||||
models=["Qwen/Qwen3-VL-4B-Instruct"],
|
||||
test_type=(
|
||||
VLMTestType.IMAGE,
|
||||
VLMTestType.MULTI_IMAGE,
|
||||
VLMTestType.VIDEO,
|
||||
),
|
||||
needs_video_metadata=True,
|
||||
prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501
|
||||
img_idx_to_prompt=lambda idx: "<|vision_start|><|image_pad|><|vision_end|>", # noqa: E501
|
||||
video_idx_to_prompt=lambda idx: "<|vision_start|><|video_pad|><|vision_end|>", # noqa: E501
|
||||
max_model_len=4096,
|
||||
max_num_seqs=2,
|
||||
num_logprobs=20,
|
||||
auto_cls=AutoModelForImageTextToText,
|
||||
vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output,
|
||||
patch_hf_runner=model_utils.qwen3_vl_patch_hf_runner,
|
||||
image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
|
||||
marks=[
|
||||
pytest.mark.core_model,
|
||||
],
|
||||
),
|
||||
"ultravox": VLMTestInfo(
|
||||
models=["fixie-ai/ultravox-v0_5-llama-3_2-1b"],
|
||||
test_type=VLMTestType.AUDIO,
|
||||
|
||||
@ -4,7 +4,9 @@
|
||||
|
||||
from collections.abc import Callable, Iterable
|
||||
from pathlib import PosixPath
|
||||
from typing import Any
|
||||
|
||||
import numpy.typing as npt
|
||||
import torch
|
||||
|
||||
from vllm.multimodal.audio import AudioResampler
|
||||
@ -236,6 +238,7 @@ def build_video_inputs_from_test_info(
|
||||
video_assets: VideoTestAssets,
|
||||
size_wrapper: ImageSizeWrapper,
|
||||
num_frames: int,
|
||||
needs_video_metadata: bool,
|
||||
) -> list[PromptWithMultiModalInput]:
|
||||
if test_info.prompt_formatter is None:
|
||||
raise ValueError("Prompt formatter must be set to build video inputs")
|
||||
@ -248,7 +251,10 @@ def build_video_inputs_from_test_info(
|
||||
)
|
||||
|
||||
sampled_vids = [
|
||||
sample_frames_from_video(asset.np_ndarrays, num_frames)
|
||||
sample_frames_with_video_metadata(
|
||||
(asset.np_ndarrays, asset.metadata),
|
||||
num_frames,
|
||||
)
|
||||
for asset in video_assets
|
||||
]
|
||||
|
||||
@ -259,12 +265,33 @@ def build_video_inputs_from_test_info(
|
||||
return [
|
||||
PromptWithMultiModalInput(
|
||||
prompts=[prompt for _ in size_wrapper.data],
|
||||
video_data=[video_scaler(video, size) for size in size_wrapper.data],
|
||||
video_data=[
|
||||
(
|
||||
video_scaler(video, size)
|
||||
if not needs_video_metadata
|
||||
else (video_scaler(video, size), meta)
|
||||
)
|
||||
for size in size_wrapper.data
|
||||
],
|
||||
)
|
||||
for video, prompt in zip(sampled_vids, model_prompts)
|
||||
for (video, meta), prompt in zip(sampled_vids, model_prompts)
|
||||
]
|
||||
|
||||
|
||||
def sample_frames_with_video_metadata(
|
||||
video_with_meta: tuple[npt.NDArray, dict[str, Any]],
|
||||
num_frames: int,
|
||||
) -> tuple[npt.NDArray, dict[str, Any]]:
|
||||
video, meta = video_with_meta
|
||||
video = sample_frames_from_video(video, num_frames)
|
||||
|
||||
meta["do_sample_frames"] = meta["total_num_frames"] == num_frames
|
||||
meta["total_num_frames"] = num_frames
|
||||
meta["fps"] = meta["duration"] / num_frames
|
||||
meta["frames_indices"] = list(range(num_frames))
|
||||
return video, meta
|
||||
|
||||
|
||||
def apply_image_size_scaling(image, size: float | tuple[int, int], size_type: SizeType):
|
||||
"""Applies a size scaler to one image; this can be an image size factor,
|
||||
which scales the image while maintaining the aspect ratio"""
|
||||
|
||||
@ -100,6 +100,9 @@ def get_parametrized_options(
|
||||
# num_frames is video only
|
||||
if test_type == VLMTestType.VIDEO:
|
||||
iter_kwargs["num_video_frames"] = ensure_wrapped(test_info.num_video_frames)
|
||||
iter_kwargs["needs_video_metadata"] = ensure_wrapped(
|
||||
test_info.needs_video_metadata
|
||||
)
|
||||
|
||||
# No sizes passed for custom inputs, since inputs are directly provided
|
||||
if test_type not in (VLMTestType.CUSTOM_INPUTS, VLMTestType.AUDIO):
|
||||
|
||||
@ -905,6 +905,54 @@ def qwen2_5_omni_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
|
||||
return hf_model
|
||||
|
||||
|
||||
def qwen3_vl_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
|
||||
"""Patches and returns an instance of the HfRunner to use for GLM4.1V."""
|
||||
hf_processor = hf_model.processor
|
||||
|
||||
def processor(*args, videos=None, **kwargs):
|
||||
if videos is not None and is_list_of(videos, tuple):
|
||||
# batched multi videos
|
||||
do_sample_frames = {video[1]["do_sample_frames"] for video in videos}
|
||||
assert len(do_sample_frames) == 1
|
||||
if kwargs.get("do_sample_frames") is None:
|
||||
kwargs["do_sample_frames"] = do_sample_frames
|
||||
video_metadata = [
|
||||
[
|
||||
VideoMetadata(
|
||||
**{k: v for k, v in video[1].items() if k != "do_sample_frames"}
|
||||
)
|
||||
]
|
||||
for video in videos
|
||||
]
|
||||
videos = [[video[0]] for video in videos]
|
||||
elif videos is not None and isinstance(videos, tuple):
|
||||
# single video
|
||||
do_sample_frames = videos[1]["do_sample_frames"]
|
||||
if kwargs.get("do_sample_frames") is None:
|
||||
kwargs["do_sample_frames"] = do_sample_frames
|
||||
video_metadata = [
|
||||
[
|
||||
VideoMetadata(
|
||||
**{
|
||||
k: v
|
||||
for k, v in videos[1].items()
|
||||
if k != "do_sample_frames"
|
||||
}
|
||||
)
|
||||
]
|
||||
]
|
||||
videos = [[videos[0]]]
|
||||
else:
|
||||
video_metadata = None
|
||||
|
||||
return hf_processor(
|
||||
*args, videos=videos, video_metadata=video_metadata, **kwargs
|
||||
)
|
||||
|
||||
hf_model.processor = processor
|
||||
return hf_model
|
||||
|
||||
|
||||
def tarsier_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
|
||||
from vllm.model_executor.models.tarsier import get_vision_encoder_info
|
||||
|
||||
|
||||
@ -117,6 +117,7 @@ def run_video_test(
|
||||
video_assets,
|
||||
test_case.size_wrapper,
|
||||
test_case.num_video_frames,
|
||||
test_case.needs_video_metadata,
|
||||
)
|
||||
|
||||
core.run_test(
|
||||
|
||||
@ -154,7 +154,8 @@ class VLMTestInfo(NamedTuple):
|
||||
dtype: str = "auto"
|
||||
distributed_executor_backend: str | None = None
|
||||
# Only expanded in video tests
|
||||
num_video_frames: int = 16
|
||||
num_video_frames: int | tuple[int] = 16
|
||||
needs_video_metadata: bool = False
|
||||
|
||||
# Fixed image sizes / image size factors; most tests use image_size_factors
|
||||
# The values provided for these two fields will be stacked and expanded
|
||||
@ -212,5 +213,6 @@ class ExpandableVLMTestArgs(NamedTuple):
|
||||
size_wrapper: ImageSizeWrapper | None = None
|
||||
# Video only
|
||||
num_video_frames: int | None = None
|
||||
needs_video_metadata: bool = False
|
||||
# Custom inputs only
|
||||
custom_test_opts: CustomTestOptions | None = None
|
||||
|
||||
@ -94,7 +94,7 @@ def video_get_metadata(path: str, num_frames: int = -1) -> dict[str, Any]:
|
||||
|
||||
metadata = {
|
||||
"total_num_frames": num_frames,
|
||||
"fps": fps,
|
||||
"fps": duration / num_frames,
|
||||
"duration": duration,
|
||||
"video_backend": "opencv",
|
||||
"frames_indices": list(range(num_frames)),
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user