[VLM] Add Qwen3-VL generation test (#25185)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.io>
parent 3481e40743
commit ad3ec89532
@@ -159,6 +159,28 @@ VLM_TEST_SETTINGS = {
         image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
         marks=[pytest.mark.core_model, pytest.mark.cpu_model],
     ),
+    "qwen3_vl": VLMTestInfo(
+        models=["Qwen/Qwen3-VL-4B-Instruct"],
+        test_type=(
+            VLMTestType.IMAGE,
+            VLMTestType.MULTI_IMAGE,
+            VLMTestType.VIDEO,
+        ),
+        needs_video_metadata=True,
+        prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n",  # noqa: E501
+        img_idx_to_prompt=lambda idx: "<|vision_start|><|image_pad|><|vision_end|>",  # noqa: E501
+        video_idx_to_prompt=lambda idx: "<|vision_start|><|video_pad|><|vision_end|>",  # noqa: E501
+        max_model_len=4096,
+        max_num_seqs=2,
+        num_logprobs=20,
+        auto_cls=AutoModelForImageTextToText,
+        vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output,
+        patch_hf_runner=model_utils.qwen3_vl_patch_hf_runner,
+        image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
+        marks=[
+            pytest.mark.core_model,
+        ],
+    ),
     "ultravox": VLMTestInfo(
         models=["fixie-ai/ultravox-v0_5-llama-3_2-1b"],
         test_type=VLMTestType.AUDIO,
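A minimal, self-contained sketch (not part of the commit; the two-image question is an invented placeholder) of how the harness-style lambdas above compose into the final Qwen3-VL prompt:

def prompt_formatter(img_prompt: str) -> str:
    return f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n"


def img_idx_to_prompt(idx: int) -> str:
    return "<|vision_start|><|image_pad|><|vision_end|>"


# Multi-image case: one vision placeholder block per image, then the question.
img_prompt = "".join(img_idx_to_prompt(i) for i in range(2)) + "What is in these images?"
print(prompt_formatter(img_prompt))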
@@ -4,7 +4,9 @@
 
 from collections.abc import Callable, Iterable
 from pathlib import PosixPath
+from typing import Any
 
+import numpy.typing as npt
 import torch
 
 from vllm.multimodal.audio import AudioResampler
@@ -236,6 +238,7 @@ def build_video_inputs_from_test_info(
     video_assets: VideoTestAssets,
     size_wrapper: ImageSizeWrapper,
     num_frames: int,
+    needs_video_metadata: bool,
 ) -> list[PromptWithMultiModalInput]:
     if test_info.prompt_formatter is None:
         raise ValueError("Prompt formatter must be set to build video inputs")
@@ -248,7 +251,10 @@
     )
 
     sampled_vids = [
-        sample_frames_from_video(asset.np_ndarrays, num_frames)
+        sample_frames_with_video_metadata(
+            (asset.np_ndarrays, asset.metadata),
+            num_frames,
+        )
         for asset in video_assets
     ]
 
@@ -259,12 +265,33 @@
     return [
         PromptWithMultiModalInput(
             prompts=[prompt for _ in size_wrapper.data],
-            video_data=[video_scaler(video, size) for size in size_wrapper.data],
+            video_data=[
+                (
+                    video_scaler(video, size)
+                    if not needs_video_metadata
+                    else (video_scaler(video, size), meta)
+                )
+                for size in size_wrapper.data
+            ],
         )
-        for video, prompt in zip(sampled_vids, model_prompts)
+        for (video, meta), prompt in zip(sampled_vids, model_prompts)
     ]
 
 
+def sample_frames_with_video_metadata(
+    video_with_meta: tuple[npt.NDArray, dict[str, Any]],
+    num_frames: int,
+) -> tuple[npt.NDArray, dict[str, Any]]:
+    video, meta = video_with_meta
+    video = sample_frames_from_video(video, num_frames)
+
+    meta["do_sample_frames"] = meta["total_num_frames"] == num_frames
+    meta["total_num_frames"] = num_frames
+    meta["fps"] = meta["duration"] / num_frames
+    meta["frames_indices"] = list(range(num_frames))
+    return video, meta
+
+
 def apply_image_size_scaling(image, size: float | tuple[int, int], size_type: SizeType):
     """Applies a size scaler to one image; this can be an image size factor,
     which scales the image while maintaining the aspect ratio"""
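A self-contained sketch of the metadata rewrite that sample_frames_with_video_metadata performs. Assumptions: evenly_sample is a hypothetical stand-in for sample_frames_from_video, and the dummy clip and metadata values are invented.

from typing import Any

import numpy as np
import numpy.typing as npt


def evenly_sample(video: npt.NDArray, num_frames: int) -> npt.NDArray:
    # Pick num_frames evenly spaced frames from the clip.
    idx = np.linspace(0, len(video) - 1, num_frames).round().astype(int)
    return video[idx]


video = np.zeros((32, 4, 4, 3), dtype=np.uint8)  # 32 dummy RGB frames
meta: dict[str, Any] = {"total_num_frames": 32, "duration": 8.0, "fps": 4.0}

sampled = evenly_sample(video, 16)
# True only when no frames were dropped here; otherwise the HF processor
# must not re-sample the already-sampled clip.
meta["do_sample_frames"] = meta["total_num_frames"] == 16  # False: 32 -> 16
meta["total_num_frames"] = 16
meta["fps"] = meta["duration"] / 16  # same expression as the new helper
meta["frames_indices"] = list(range(16))
print(sampled.shape, meta)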
@@ -100,6 +100,9 @@ def get_parametrized_options(
     # num_frames is video only
     if test_type == VLMTestType.VIDEO:
         iter_kwargs["num_video_frames"] = ensure_wrapped(test_info.num_video_frames)
+        iter_kwargs["needs_video_metadata"] = ensure_wrapped(
+            test_info.needs_video_metadata
+        )
 
     # No sizes passed for custom inputs, since inputs are directly provided
     if test_type not in (VLMTestType.CUSTOM_INPUTS, VLMTestType.AUDIO):
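A hypothetical sketch of the ensure_wrapped behavior this hunk relies on: scalar options (including the new needs_video_metadata bool) are wrapped into one-element tuples so every option can be iterated uniformly when the test matrix is expanded.

def ensure_wrapped(value):
    # Wrap scalars; leave already-iterable option tuples/lists alone.
    return value if isinstance(value, (list, tuple)) else (value,)


assert ensure_wrapped(16) == (16,)
assert ensure_wrapped((8, 16)) == (8, 16)
assert ensure_wrapped(False) == (False,)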
@@ -905,6 +905,54 @@ def qwen2_5_omni_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
     return hf_model
 
 
+def qwen3_vl_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
+    """Patches and returns an instance of the HfRunner to use for Qwen3-VL."""
+    hf_processor = hf_model.processor
+
+    def processor(*args, videos=None, **kwargs):
+        if videos is not None and is_list_of(videos, tuple):
+            # batched multi videos
+            do_sample_frames = {video[1]["do_sample_frames"] for video in videos}
+            assert len(do_sample_frames) == 1
+            if kwargs.get("do_sample_frames") is None:
+                kwargs["do_sample_frames"] = do_sample_frames
+            video_metadata = [
+                [
+                    VideoMetadata(
+                        **{k: v for k, v in video[1].items() if k != "do_sample_frames"}
+                    )
+                ]
+                for video in videos
+            ]
+            videos = [[video[0]] for video in videos]
+        elif videos is not None and isinstance(videos, tuple):
+            # single video
+            do_sample_frames = videos[1]["do_sample_frames"]
+            if kwargs.get("do_sample_frames") is None:
+                kwargs["do_sample_frames"] = do_sample_frames
+            video_metadata = [
+                [
+                    VideoMetadata(
+                        **{
+                            k: v
+                            for k, v in videos[1].items()
+                            if k != "do_sample_frames"
+                        }
+                    )
+                ]
+            ]
+            videos = [[videos[0]]]
+        else:
+            video_metadata = None
+
+        return hf_processor(
+            *args, videos=videos, video_metadata=video_metadata, **kwargs
+        )
+
+    hf_model.processor = processor
+    return hf_model
+
+
 def tarsier_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
     from vllm.model_executor.models.tarsier import get_vision_encoder_info
 
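A minimal sketch of the single-video normalization the patched processor performs: the (frames, meta) tuple is split into the nested videos list and a parallel video_metadata list, with do_sample_frames routed into a kwarg instead of the metadata object. Assumptions: StubMetadata stands in for transformers' VideoMetadata, and the frame/metadata values are invented.

from dataclasses import dataclass, field

import numpy as np


@dataclass
class StubMetadata:
    total_num_frames: int
    fps: float
    duration: float
    video_backend: str
    frames_indices: list = field(default_factory=list)


frames = np.zeros((16, 4, 4, 3), dtype=np.uint8)
meta = {
    "total_num_frames": 16,
    "fps": 0.5,
    "duration": 8.0,
    "video_backend": "opencv",
    "frames_indices": list(range(16)),
    "do_sample_frames": False,
}

do_sample_frames = meta["do_sample_frames"]  # becomes kwargs["do_sample_frames"]
video_metadata = [
    [StubMetadata(**{k: v for k, v in meta.items() if k != "do_sample_frames"})]
]
videos = [[frames]]  # one prompt, one video
print(do_sample_frames, video_metadata[0][0].total_num_frames, videos[0][0].shape)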
@@ -117,6 +117,7 @@ def run_video_test(
         video_assets,
         test_case.size_wrapper,
         test_case.num_video_frames,
+        test_case.needs_video_metadata,
     )
 
     core.run_test(
@@ -154,7 +154,8 @@ class VLMTestInfo(NamedTuple):
     dtype: str = "auto"
     distributed_executor_backend: str | None = None
     # Only expanded in video tests
-    num_video_frames: int = 16
+    num_video_frames: int | tuple[int] = 16
+    needs_video_metadata: bool = False
 
     # Fixed image sizes / image size factors; most tests use image_size_factors
     # The values provided for these two fields will be stacked and expanded
@@ -212,5 +213,6 @@ class ExpandableVLMTestArgs(NamedTuple):
     size_wrapper: ImageSizeWrapper | None = None
     # Video only
     num_video_frames: int | None = None
+    needs_video_metadata: bool = False
     # Custom inputs only
     custom_test_opts: CustomTestOptions | None = None
@@ -94,7 +94,7 @@ def video_get_metadata(path: str, num_frames: int = -1) -> dict[str, Any]:
 
     metadata = {
         "total_num_frames": num_frames,
-        "fps": fps,
+        "fps": duration / num_frames,
         "duration": duration,
         "video_backend": "opencv",
         "frames_indices": list(range(num_frames)),
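Worked check (values invented) that the changed "fps" field now agrees with the expression sample_frames_with_video_metadata uses, so the two code paths report the same number for a sampled clip:

duration, num_frames = 8.0, 16
fps = duration / num_frames  # 0.5, identical to meta["fps"] in the new helper
print(fps)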