diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 0248700292ae2..db650b37a38db 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -553,6 +553,7 @@ Specified using `--task generate`.
 | `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b` etc. | | ✅︎ | ✅︎ |
 | `Gemma3ForConditionalGeneration` | Gemma 3 | T + I+ | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | ⚠️ |
 | `GLM4VForCausalLM`^ | GLM-4V | T + I | `THUDM/glm-4v-9b`, `THUDM/cogagent-9b-20241220` etc. | ✅︎ | ✅︎ | ✅︎ |
+| `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + IE+ + VE+ | `THUDM/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | ✅︎ |
 | `H2OVLChatModel` | H2OVL | T + IE+ | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ | ✅︎\* |
 | `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3` etc. | ✅︎ | | ✅︎ |
@@ -583,7 +584,7 @@ Specified using `--task generate`.
 | `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | | ✅︎ | ✅︎ |
 | `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `SmolVLM2-2.2B-Instruct` | ✅︎ | | ✅︎ |
 | `TarsierForConditionalGeneration` | Tarsier | T + IE+ | `omni-search/Tarsier-7b`,`omni-search/Tarsier-34b` | | ✅︎ | ✅︎ |
-| `Tarsier2ForConditionalGeneration`^ | Tarsier2 | T + IE+ + VE+ | `omni-research/Tarsier2-Recap-7b`,`omni-research/Tarsier2-7b-0115` | | ✅︎ | ✅︎ |
+| `Tarsier2ForConditionalGeneration`^ | Tarsier2 | T + IE+ + VE+ | `omni-research/Tarsier2-Recap-7b`,`omni-research/Tarsier2-7b-0115` | | ✅︎ | ✅︎ |

 ^ You need to set the architecture name via `--hf-overrides` to match the one in vLLM.
    • For example, to use DeepSeek-VL2 series models:
diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py
index 57b042ed013b1..b9e8bef26eb24 100644
--- a/examples/offline_inference/vision_language.py
+++ b/examples/offline_inference/vision_language.py
@@ -248,6 +248,42 @@ def run_glm4v(questions: list[str], modality: str) -> ModelRequestData:
     )


+# GLM-4.1V
+def run_glm4_1v(questions: list[str], modality: str) -> ModelRequestData:
+    model_name = "THUDM/GLM-4.1V-9B-Thinking"
+
+    engine_args = EngineArgs(
+        model=model_name,
+        max_model_len=4096,
+        max_num_seqs=2,
+        mm_processor_kwargs={
+            "size": {"shortest_edge": 12544, "longest_edge": 47040000},
+            "fps": 1,
+        },
+        limit_mm_per_prompt={modality: 1},
+        enforce_eager=True,
+    )
+
+    if modality == "image":
+        placeholder = "<|begin_of_image|><|image|><|end_of_image|>"
+    elif modality == "video":
+        placeholder = "<|begin_of_video|><|video|><|end_of_video|>"
+
+    prompts = [
+        (
+            "[gMASK]<|system|>\nYou are a helpful assistant.<|user|>\n"
+            f"{placeholder}"
+            f"{question}<|assistant|>assistant\n"
+        )
+        for question in questions
+    ]
+
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompts=prompts,
+    )
+
+
 # H2OVL-Mississippi
 def run_h2ovl(questions: list[str], modality: str) -> ModelRequestData:
     assert modality == "image"
@@ -1114,6 +1150,7 @@ model_example_map = {
     "fuyu": run_fuyu,
     "gemma3": run_gemma3,
     "glm4v": run_glm4v,
+    "glm4_1v": run_glm4_1v,
     "h2ovl_chat": run_h2ovl,
     "idefics3": run_idefics3,
     "internvl_chat": run_internvl,
@@ -1172,10 +1209,11 @@ def get_multi_modal_input(args):
     if args.modality == "video":
         # Input video and question
         video = VideoAsset(name="baby_reading", num_frames=args.num_frames).np_ndarrays
+        metadata = VideoAsset(name="baby_reading", num_frames=args.num_frames).metadata
         vid_questions = ["Why is this video funny?"]

         return {
-            "data": video,
+            "data": [(video, metadata)] if args.model_type == "glm4_1v" else video,
             "questions": vid_questions,
         }

diff --git a/tests/entrypoints/openai/test_video.py b/tests/entrypoints/openai/test_video.py
index 990ea3579291d..b68e08556ee96 100644
--- a/tests/entrypoints/openai/test_video.py
+++ b/tests/entrypoints/openai/test_video.py
@@ -50,7 +50,7 @@ async def client(server):
 @pytest.fixture(scope="session")
 def base64_encoded_video() -> dict[str, str]:
     return {
-        video_url: encode_video_base64(fetch_video(video_url))
+        video_url: encode_video_base64(fetch_video(video_url)[0])
         for video_url in TEST_VIDEO_URLS
     }

diff --git a/tests/models/multimodal/generation/test_common.py b/tests/models/multimodal/generation/test_common.py
index 9d63339737ce6..6ecf6db56cb39 100644
--- a/tests/models/multimodal/generation/test_common.py
+++ b/tests/models/multimodal/generation/test_common.py
@@ -309,6 +309,34 @@ VLM_TEST_SETTINGS = {
         num_logprobs=10,
         marks=[large_gpu_mark(min_gb=32)],
     ),
+    "glm4_1v": VLMTestInfo(
+        models=["THUDM/GLM-4.1V-9B-Thinking"],
+        test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
+        prompt_formatter=lambda img_prompt: f"<|user|>\n{img_prompt}<|assistant|>",  # noqa: E501
+        img_idx_to_prompt=lambda idx: "<|begin_of_image|><|image|><|end_of_image|>",  # noqa: E501
+        video_idx_to_prompt=lambda idx: "<|begin_of_video|><|video|><|end_of_video|>",  # noqa: E501
+        max_model_len=2048,
+        max_num_seqs=2,
+        get_stop_token_ids=lambda tok: [151329, 151336, 151338],
+        num_logprobs=10,
+        image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
+        auto_cls=AutoModelForImageTextToText,
+    ),
+    "glm4_1v-video": VLMTestInfo(
+ models=["THUDM/GLM-4.1V-9B-Thinking"], + # GLM4.1V require include video metadata for input + test_type=VLMTestType.CUSTOM_INPUTS, + max_model_len=4096, + max_num_seqs=2, + auto_cls=AutoModelForImageTextToText, + patch_hf_runner=model_utils.glm4_1v_patch_hf_runner, + custom_test_opts=[CustomTestOptions( + inputs=custom_inputs.video_with_metadata_glm4_1v(), + limit_mm_per_prompt={"video": 1}, + )], + # This is needed to run on machine with 24GB VRAM + vllm_runner_kwargs={"gpu_memory_utilization": 0.95}, + ), "h2ovl": VLMTestInfo( models = [ "h2oai/h2ovl-mississippi-800m", diff --git a/tests/models/multimodal/generation/vlm_utils/custom_inputs.py b/tests/models/multimodal/generation/vlm_utils/custom_inputs.py index aa5835243e042..c53243b42e384 100644 --- a/tests/models/multimodal/generation/vlm_utils/custom_inputs.py +++ b/tests/models/multimodal/generation/vlm_utils/custom_inputs.py @@ -129,3 +129,23 @@ def windows_attention_image_qwen2_5_vl(): wrapped_sf = ImageSizeWrapper(type=SizeType.SIZE_FACTOR, data=[0.5]) return build_single_image_inputs([image], [prompt], wrapped_sf) + + +def video_with_metadata_glm4_1v(): + video_array = VIDEO_ASSETS[0].np_ndarrays + metadata = VIDEO_ASSETS[0].metadata + question = "Describe the video." + video_prompt = "<|begin_of_video|><|video|><|end_of_video|>" + formatted_prompt = f"<|user|>\n{video_prompt}{question}<|assistant|>\n" + + scales = [0.1, 0.2, 0.25] + video_input = [[(rescale_video_size(video_array, scale), metadata)] + for scale in scales] + prompts = [formatted_prompt] * len(video_input) + + return [ + PromptWithMultiModalInput( + prompts=prompts, + video_data=video_input, + ) + ] diff --git a/tests/models/multimodal/generation/vlm_utils/model_utils.py b/tests/models/multimodal/generation/vlm_utils/model_utils.py index af4c72f44b676..c1a2aa0dcafbb 100644 --- a/tests/models/multimodal/generation/vlm_utils/model_utils.py +++ b/tests/models/multimodal/generation/vlm_utils/model_utils.py @@ -16,9 +16,11 @@ import torch from PIL.Image import Image from transformers import (AutoConfig, AutoTokenizer, BatchFeature, GenerationConfig, GenerationMixin) +from transformers.video_utils import VideoMetadata from vllm.sequence import SampleLogprobs from vllm.transformers_utils.tokenizer import patch_padding_side +from vllm.utils import is_list_of from .....conftest import HfRunner, ImageAsset, ImageTestAssets from .types import RunnerOutput @@ -373,6 +375,28 @@ def glm4v_patch_hf_runner(hf_model: HfRunner) -> HfRunner: return hf_model +def glm4_1v_patch_hf_runner(hf_model: HfRunner) -> HfRunner: + """Patches and returns an instance of the HfRunner to use for GLM4.1V.""" + hf_processor = hf_model.processor + + def processor(*args, videos=None, **kwargs): + if videos is not None and is_list_of(videos, tuple): + # If videos is a list of tuples, we assume each tuple contains + # (video_array, metadata) as in the case of GLM4.1V. 
+            video_metadata = [[VideoMetadata(**video[1])] for video in videos]
+            videos = [[video[0]] for video in videos]
+        else:
+            video_metadata = None
+
+        return hf_processor(*args,
+                            videos=videos,
+                            video_metadata=video_metadata,
+                            **kwargs)
+
+    hf_model.processor = processor
+    return hf_model
+
+
 def h2ovl_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
     """Patches and returns an instance of the HfRunner to use for H2OVL."""

diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py
index 1ba60178c13db..0f33225eda2da 100644
--- a/tests/models/multimodal/processing/test_common.py
+++ b/tests/models/multimodal/processing/test_common.py
@@ -24,6 +24,22 @@ from ....multimodal.utils import random_audio, random_image, random_video
 from ...registry import HF_EXAMPLE_MODELS


+def glm4_1v_patch_mm_data(mm_data: MultiModalDataDict) -> MultiModalDataDict:
+    """
+    Patch the multimodal data for the GLM4.1V model.
+    """
+    # Ensure video metadata is included
+    if "video" in mm_data:
+        video = mm_data["video"]
+        mm_data["video"] = (video, {
+            "total_num_frames": len(video),
+            "fps": len(video),
+            "duration": 1,
+            "video_backend": "opencv"
+        })
+    return mm_data
+
+
 def _test_processing_correctness(
     model_id: str,
     hit_rate: float,
@@ -154,6 +170,11 @@ _IGNORE_MM_KEYS = {
     "ultravox": {"audio_features"},
 }

+MM_DATA_PATCHES = {
+    # GLM4.1V requires video metadata to be included in the input
+    "glm4v": glm4_1v_patch_mm_data,
+}
+

 def _test_processing_correctness_one(
     model_config: ModelConfig,
@@ -166,6 +187,8 @@ def _test_processing_correctness_one(
 ):
     model_type = model_config.hf_config.model_type
     ignore_mm_keys = _IGNORE_MM_KEYS.get(model_type, set[str]())
+    if model_type in MM_DATA_PATCHES:
+        mm_data = MM_DATA_PATCHES[model_type](mm_data)

     if isinstance(prompt, str):
         text_prompt = prompt
@@ -245,6 +268,7 @@ def _test_processing_correctness_one(
     "adept/fuyu-8b",
     "google/gemma-3-4b-it",
     "THUDM/glm-4v-9b",
+    "THUDM/GLM-4.1V-9B-Thinking",
     "ibm-granite/granite-speech-3.3-2b",
     "h2oai/h2ovl-mississippi-800m",
     "OpenGVLab/InternVL2-1B",
diff --git a/tests/models/registry.py b/tests/models/registry.py
index e56dd19bec670..affe2e88b2b94 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -338,6 +338,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
     "GLM4VForCausalLM": _HfExamplesInfo("THUDM/glm-4v-9b",
                                         trust_remote_code=True,
                                         hf_overrides={"architectures": ["GLM4VForCausalLM"]}),  # noqa: E501
+    "Glm4vForConditionalGeneration": _HfExamplesInfo("THUDM/GLM-4.1V-9B-Thinking", min_transformers_version="4.53"),  # noqa: E501
     "H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m",
                                       extras={"2b": "h2oai/h2ovl-mississippi-2b"},  # noqa: E501
                                       max_transformers_version="4.48",  # noqa: E501
diff --git a/tests/multimodal/test_utils.py b/tests/multimodal/test_utils.py
index 5ac0a90f50473..a48542cec3f87 100644
--- a/tests/multimodal/test_utils.py
+++ b/tests/multimodal/test_utils.py
@@ -172,7 +172,9 @@ async def test_fetch_video_http(video_url: str, num_frames: int):
     video_sync = connector.fetch_video(video_url, num_frames=num_frames)
     video_async = await connector.fetch_video_async(video_url,
                                                     num_frames=num_frames)
-    assert np.array_equal(video_sync, video_async)
+    # Check that the video frames are equal and the metadata is the same
+    assert np.array_equal(video_sync[0], video_async[0])
+    assert video_sync[1] == video_async[1]


 # Used for the next two tests related to `merge_and_sort_multimodal_metadata`.
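For anyone who wants to exercise the new video-metadata path end to end, here is a minimal offline-inference sketch (not part of the patch): it assumes the `VideoAsset.metadata` property added below in `vllm/assets/video.py` and the `(frames, metadata)` tuple format accepted for GLM-4.1V, reuses the prompt string from `video_with_metadata_glm4_1v`, and uses illustrative engine and sampling settings.

    from vllm import LLM, SamplingParams
    from vllm.assets.video import VideoAsset

    # Load a short test clip plus the metadata dict built by video_get_metadata().
    asset = VideoAsset(name="baby_reading", num_frames=16)
    video, metadata = asset.np_ndarrays, asset.metadata

    # Prompt format mirrors the GLM-4.1V video prompt used in the tests above.
    prompt = ("<|user|>\n<|begin_of_video|><|video|><|end_of_video|>"
              "Describe the video.<|assistant|>\n")

    llm = LLM(model="THUDM/GLM-4.1V-9B-Thinking",
              max_model_len=4096,
              limit_mm_per_prompt={"video": 1})
    outputs = llm.generate(
        {
            "prompt": prompt,
            # GLM-4.1V consumes the video frames together with their metadata.
            "multi_modal_data": {"video": [(video, metadata)]},
        },
        SamplingParams(max_tokens=128),
    )
    print(outputs[0].outputs[0].text)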
diff --git a/vllm/assets/video.py b/vllm/assets/video.py
index 01834aeeb6c12..16412121cf0a8 100644
--- a/vllm/assets/video.py
+++ b/vllm/assets/video.py
@@ -3,7 +3,7 @@

 from dataclasses import dataclass
 from functools import lru_cache
-from typing import ClassVar, Literal, Optional
+from typing import Any, ClassVar, Literal, Optional

 import cv2
 import numpy as np
@@ -77,6 +77,25 @@ def video_to_pil_images_list(path: str,
     ]


+def video_get_metadata(path: str) -> dict[str, Any]:
+    cap = cv2.VideoCapture(path)
+    if not cap.isOpened():
+        raise ValueError(f"Could not open video file {path}")
+
+    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+    fps = cap.get(cv2.CAP_PROP_FPS)
+    duration = total_frames / fps if fps > 0 else 0
+    cap.release()
+
+    metadata = {
+        "total_num_frames": total_frames,
+        "fps": fps,
+        "duration": duration,
+        "video_backend": "opencv"
+    }
+    return metadata
+
+
 VideoAssetName = Literal["baby_reading"]


@@ -105,6 +123,12 @@ class VideoAsset:
         ret = video_to_ndarrays(video_path, self.num_frames)
         return ret

+    @property
+    def metadata(self) -> dict[str, Any]:
+        video_path = download_video_asset(self.filename)
+        ret = video_get_metadata(video_path)
+        return ret
+
     def get_audio(self, sampling_rate: Optional[float] = None) -> npt.NDArray:
         """
         Read audio data from the video asset, used in Qwen2.5-Omni examples.
diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py
index 35ee52ab4601d..45f1894d022b3 100644
--- a/vllm/entrypoints/chat_utils.py
+++ b/vllm/entrypoints/chat_utils.py
@@ -515,6 +515,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
         if modality in ("image", "image_embeds"):
             if model_type == "chatglm":
                 return "<|begin_of_image|><|endoftext|><|end_of_image|>"
+            if model_type == "glm4v":
+                return "<|begin_of_image|><|image|><|end_of_image|>"
             if model_type in ("phi3_v", "phi4mm"):
                 return f"<|image_{current_count}|>"
             if model_type in ("minicpmo", "minicpmv"):
@@ -563,6 +565,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
         elif modality == "video":
             if model_type == "internvl_chat":
                 return "