diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 0248700292ae2..db650b37a38db 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -553,6 +553,7 @@ Specified using `--task generate`.
| `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b` etc. | | ✅︎ | ✅︎ |
| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I+ | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | ⚠️ |
| `GLM4VForCausalLM`^ | GLM-4V | T + I | `THUDM/glm-4v-9b`, `THUDM/cogagent-9b-20241220` etc. | ✅︎ | ✅︎ | ✅︎ |
+| `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + IE+ + VE+ | `THUDM/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | ✅︎ |
| `H2OVLChatModel` | H2OVL | T + IE+ | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ | ✅︎\* |
| `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3` etc. | ✅︎ | | ✅︎ |
@@ -583,7 +584,7 @@ Specified using `--task generate`.
| `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | | ✅︎ | ✅︎ |
| `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `SmolVLM2-2.2B-Instruct` | ✅︎ | | ✅︎ |
| `TarsierForConditionalGeneration` | Tarsier | T + IE+ | `omni-search/Tarsier-7b`,`omni-search/Tarsier-34b` | | ✅︎ | ✅︎ |
-| `Tarsier2ForConditionalGeneration`^ | Tarsier2 | T + IE+ + VE+ | `omni-research/Tarsier2-Recap-7b`,`omni-research/Tarsier2-7b-0115` | | ✅︎ | ✅︎ |
+| `Tarsier2ForConditionalGeneration`^ | Tarsier2 | T + IE+ + VE+ | `omni-research/Tarsier2-Recap-7b`,`omni-research/Tarsier2-7b-0115` | | ✅︎ | ✅︎ |
^ You need to set the architecture name via `--hf-overrides` to match the one in vLLM.
• For example, to use DeepSeek-VL2 series models:
diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py
index 57b042ed013b1..b9e8bef26eb24 100644
--- a/examples/offline_inference/vision_language.py
+++ b/examples/offline_inference/vision_language.py
@@ -248,6 +248,42 @@ def run_glm4v(questions: list[str], modality: str) -> ModelRequestData:
)
+# GLM-4.1V
+def run_glm4_1v(questions: list[str], modality: str) -> ModelRequestData:
+ model_name = "THUDM/GLM-4.1V-9B-Thinking"
+
+ engine_args = EngineArgs(
+ model=model_name,
+ max_model_len=4096,
+ max_num_seqs=2,
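+        # Processor kwargs: bound the image size and set the video sampling fps
+        # passed to the HF processor.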
+ mm_processor_kwargs={
+ "size": {"shortest_edge": 12544, "longest_edge": 47040000},
+ "fps": 1,
+ },
+ limit_mm_per_prompt={modality: 1},
+ enforce_eager=True,
+ )
+
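+    # Placeholder tokens used by the GLM-4.1V chat template for each modality.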
+ if modality == "image":
+ placeholder = "<|begin_of_image|><|image|><|end_of_image|>"
+ elif modality == "video":
+ placeholder = "<|begin_of_video|><|video|><|end_of_video|>"
+
+ prompts = [
+ (
+ "[gMASK]<|system|>\nYou are a helpful assistant.<|user|>\n"
+ f"{placeholder}"
+ f"{question}<|assistant|>assistant\n"
+ )
+ for question in questions
+ ]
+
+ return ModelRequestData(
+ engine_args=engine_args,
+ prompts=prompts,
+ )
+
+
# H2OVL-Mississippi
def run_h2ovl(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
@@ -1114,6 +1150,7 @@ model_example_map = {
"fuyu": run_fuyu,
"gemma3": run_gemma3,
"glm4v": run_glm4v,
+ "glm4_1v": run_glm4_1v,
"h2ovl_chat": run_h2ovl,
"idefics3": run_idefics3,
"internvl_chat": run_internvl,
@@ -1172,10 +1209,11 @@ def get_multi_modal_input(args):
if args.modality == "video":
# Input video and question
video = VideoAsset(name="baby_reading", num_frames=args.num_frames).np_ndarrays
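+        # GLM-4.1V also consumes the video's metadata (frame count, fps,
+        # duration) alongside the frames.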
+ metadata = VideoAsset(name="baby_reading", num_frames=args.num_frames).metadata
vid_questions = ["Why is this video funny?"]
return {
- "data": video,
+ "data": [(video, metadata)] if args.model_type == "glm4_1v" else video,
"questions": vid_questions,
}
diff --git a/tests/entrypoints/openai/test_video.py b/tests/entrypoints/openai/test_video.py
index 990ea3579291d..b68e08556ee96 100644
--- a/tests/entrypoints/openai/test_video.py
+++ b/tests/entrypoints/openai/test_video.py
@@ -50,7 +50,7 @@ async def client(server):
@pytest.fixture(scope="session")
def base64_encoded_video() -> dict[str, str]:
return {
- video_url: encode_video_base64(fetch_video(video_url))
+ video_url: encode_video_base64(fetch_video(video_url)[0])
for video_url in TEST_VIDEO_URLS
}
diff --git a/tests/models/multimodal/generation/test_common.py b/tests/models/multimodal/generation/test_common.py
index 9d63339737ce6..6ecf6db56cb39 100644
--- a/tests/models/multimodal/generation/test_common.py
+++ b/tests/models/multimodal/generation/test_common.py
@@ -309,6 +309,34 @@ VLM_TEST_SETTINGS = {
num_logprobs=10,
marks=[large_gpu_mark(min_gb=32)],
),
+ "glm4_1v": VLMTestInfo(
+ models=["THUDM/GLM-4.1V-9B-Thinking"],
+ test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
+ prompt_formatter=lambda img_prompt: f"<|user|>\n{img_prompt}<|assistant|>", # noqa: E501
+ img_idx_to_prompt=lambda idx: "<|begin_of_image|><|image|><|end_of_image|>", # noqa: E501
+ video_idx_to_prompt=lambda idx: "<|begin_of_video|><|video|><|end_of_video|>", # noqa: E501
+ max_model_len=2048,
+ max_num_seqs=2,
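+        # Stop decoding at GLM-4's end-of-turn special token ids.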
+ get_stop_token_ids=lambda tok: [151329, 151336, 151338],
+ num_logprobs=10,
+ image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
+ auto_cls=AutoModelForImageTextToText,
+ ),
+ "glm4_1v-video": VLMTestInfo(
+ models=["THUDM/GLM-4.1V-9B-Thinking"],
+        # GLM-4.1V requires video metadata to be included with the video input
+ test_type=VLMTestType.CUSTOM_INPUTS,
+ max_model_len=4096,
+ max_num_seqs=2,
+ auto_cls=AutoModelForImageTextToText,
+ patch_hf_runner=model_utils.glm4_1v_patch_hf_runner,
+ custom_test_opts=[CustomTestOptions(
+ inputs=custom_inputs.video_with_metadata_glm4_1v(),
+ limit_mm_per_prompt={"video": 1},
+ )],
+        # This is needed to run on a machine with 24GB of VRAM
+ vllm_runner_kwargs={"gpu_memory_utilization": 0.95},
+ ),
"h2ovl": VLMTestInfo(
models = [
"h2oai/h2ovl-mississippi-800m",
diff --git a/tests/models/multimodal/generation/vlm_utils/custom_inputs.py b/tests/models/multimodal/generation/vlm_utils/custom_inputs.py
index aa5835243e042..c53243b42e384 100644
--- a/tests/models/multimodal/generation/vlm_utils/custom_inputs.py
+++ b/tests/models/multimodal/generation/vlm_utils/custom_inputs.py
@@ -129,3 +129,23 @@ def windows_attention_image_qwen2_5_vl():
wrapped_sf = ImageSizeWrapper(type=SizeType.SIZE_FACTOR, data=[0.5])
return build_single_image_inputs([image], [prompt], wrapped_sf)
+
+
+def video_with_metadata_glm4_1v():
+ video_array = VIDEO_ASSETS[0].np_ndarrays
+ metadata = VIDEO_ASSETS[0].metadata
+ question = "Describe the video."
+ video_prompt = "<|begin_of_video|><|video|><|end_of_video|>"
+ formatted_prompt = f"<|user|>\n{video_prompt}{question}<|assistant|>\n"
+
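+    # Build several downscaled copies of the clip, each paired with the
+    # original metadata.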
+ scales = [0.1, 0.2, 0.25]
+ video_input = [[(rescale_video_size(video_array, scale), metadata)]
+ for scale in scales]
+ prompts = [formatted_prompt] * len(video_input)
+
+ return [
+ PromptWithMultiModalInput(
+ prompts=prompts,
+ video_data=video_input,
+ )
+ ]
diff --git a/tests/models/multimodal/generation/vlm_utils/model_utils.py b/tests/models/multimodal/generation/vlm_utils/model_utils.py
index af4c72f44b676..c1a2aa0dcafbb 100644
--- a/tests/models/multimodal/generation/vlm_utils/model_utils.py
+++ b/tests/models/multimodal/generation/vlm_utils/model_utils.py
@@ -16,9 +16,11 @@ import torch
from PIL.Image import Image
from transformers import (AutoConfig, AutoTokenizer, BatchFeature,
GenerationConfig, GenerationMixin)
+from transformers.video_utils import VideoMetadata
from vllm.sequence import SampleLogprobs
from vllm.transformers_utils.tokenizer import patch_padding_side
+from vllm.utils import is_list_of
from .....conftest import HfRunner, ImageAsset, ImageTestAssets
from .types import RunnerOutput
@@ -373,6 +375,28 @@ def glm4v_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
return hf_model
+def glm4_1v_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
+ """Patches and returns an instance of the HfRunner to use for GLM4.1V."""
+ hf_processor = hf_model.processor
+
+ def processor(*args, videos=None, **kwargs):
+ if videos is not None and is_list_of(videos, tuple):
+ # If videos is a list of tuples, we assume each tuple contains
+ # (video_array, metadata) as in the case of GLM4.1V.
+ video_metadata = [[VideoMetadata(**video[1])] for video in videos]
+ videos = [[video[0]] for video in videos]
+ else:
+ video_metadata = None
+
+ return hf_processor(*args,
+ videos=videos,
+ video_metadata=video_metadata,
+ **kwargs)
+
+ hf_model.processor = processor
+ return hf_model
+
+
def h2ovl_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
"""Patches and returns an instance of the HfRunner to use for H2OVL."""
diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py
index 1ba60178c13db..0f33225eda2da 100644
--- a/tests/models/multimodal/processing/test_common.py
+++ b/tests/models/multimodal/processing/test_common.py
@@ -24,6 +24,22 @@ from ....multimodal.utils import random_audio, random_image, random_video
from ...registry import HF_EXAMPLE_MODELS
+def glm4_1v_patch_mm_data(mm_data: MultiModalDataDict) -> MultiModalDataDict:
+ """
+    Patch the multimodal data for the GLM-4.1V model.
+ """
+ # Ensure video metadata is included
+ if "video" in mm_data:
+ video = mm_data["video"]
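+        # Minimal synthetic metadata: treat the sampled frames as a
+        # one-second clip.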
+ mm_data["video"] = (video, {
+ "total_num_frames": len(video),
+ "fps": len(video),
+ "duration": 1,
+ "video_backend": "opencv"
+ })
+ return mm_data
+
+
def _test_processing_correctness(
model_id: str,
hit_rate: float,
@@ -154,6 +170,11 @@ _IGNORE_MM_KEYS = {
"ultravox": {"audio_features"},
}
+MM_DATA_PATCHES = {
+ # GLM4.1V requires video metadata to be included in the input
+ "glm4v": glm4_1v_patch_mm_data,
+}
+
def _test_processing_correctness_one(
model_config: ModelConfig,
@@ -166,6 +187,8 @@ def _test_processing_correctness_one(
):
model_type = model_config.hf_config.model_type
ignore_mm_keys = _IGNORE_MM_KEYS.get(model_type, set[str]())
+ if model_type in MM_DATA_PATCHES:
+ mm_data = MM_DATA_PATCHES[model_type](mm_data)
if isinstance(prompt, str):
text_prompt = prompt
@@ -245,6 +268,7 @@ def _test_processing_correctness_one(
"adept/fuyu-8b",
"google/gemma-3-4b-it",
"THUDM/glm-4v-9b",
+ "THUDM/GLM-4.1V-9B-Thinking",
"ibm-granite/granite-speech-3.3-2b",
"h2oai/h2ovl-mississippi-800m",
"OpenGVLab/InternVL2-1B",
diff --git a/tests/models/registry.py b/tests/models/registry.py
index e56dd19bec670..affe2e88b2b94 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -338,6 +338,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"GLM4VForCausalLM": _HfExamplesInfo("THUDM/glm-4v-9b",
trust_remote_code=True,
hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501
+ "Glm4vForConditionalGeneration": _HfExamplesInfo("THUDM/GLM-4.1V-9B-Thinking", min_transformers_version="4.53"), # noqa: E501
"H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m",
extras={"2b": "h2oai/h2ovl-mississippi-2b"}, # noqa: E501
max_transformers_version="4.48", # noqa: E501
diff --git a/tests/multimodal/test_utils.py b/tests/multimodal/test_utils.py
index 5ac0a90f50473..a48542cec3f87 100644
--- a/tests/multimodal/test_utils.py
+++ b/tests/multimodal/test_utils.py
@@ -172,7 +172,9 @@ async def test_fetch_video_http(video_url: str, num_frames: int):
video_sync = connector.fetch_video(video_url, num_frames=num_frames)
video_async = await connector.fetch_video_async(video_url,
num_frames=num_frames)
- assert np.array_equal(video_sync, video_async)
+    # Check that the video frames are equal and the metadata matches
+ assert np.array_equal(video_sync[0], video_async[0])
+ assert video_sync[1] == video_async[1]
# Used for the next two tests related to `merge_and_sort_multimodal_metadata`.
diff --git a/vllm/assets/video.py b/vllm/assets/video.py
index 01834aeeb6c12..16412121cf0a8 100644
--- a/vllm/assets/video.py
+++ b/vllm/assets/video.py
@@ -3,7 +3,7 @@
from dataclasses import dataclass
from functools import lru_cache
-from typing import ClassVar, Literal, Optional
+from typing import Any, ClassVar, Literal, Optional
import cv2
import numpy as np
@@ -77,6 +77,24 @@ def video_to_pil_images_list(path: str,
]
+def video_get_metadata(path: str) -> dict[str, Any]:
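+    """Read frame count, fps, and duration for a video file using OpenCV."""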
+ cap = cv2.VideoCapture(path)
+ if not cap.isOpened():
+ raise ValueError(f"Could not open video file {path}")
+
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+ fps = cap.get(cv2.CAP_PROP_FPS)
+ duration = total_frames / fps if fps > 0 else 0
+
+ metadata = {
+ "total_num_frames": total_frames,
+ "fps": fps,
+ "duration": duration,
+ "video_backend": "opencv"
+ }
+ return metadata
+
+
VideoAssetName = Literal["baby_reading"]
@@ -105,6 +123,12 @@ class VideoAsset:
ret = video_to_ndarrays(video_path, self.num_frames)
return ret
+ @property
+ def metadata(self) -> dict[str, Any]:
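+        """Metadata (frame count, fps, duration, backend) of the original video file."""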
+ video_path = download_video_asset(self.filename)
+ ret = video_get_metadata(video_path)
+ return ret
+
def get_audio(self, sampling_rate: Optional[float] = None) -> npt.NDArray:
"""
Read audio data from the video asset, used in Qwen2.5-Omni examples.
diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py
index 35ee52ab4601d..45f1894d022b3 100644
--- a/vllm/entrypoints/chat_utils.py
+++ b/vllm/entrypoints/chat_utils.py
@@ -515,6 +515,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
if modality in ("image", "image_embeds"):
if model_type == "chatglm":
return "<|begin_of_image|><|endoftext|><|end_of_image|>"
+ if model_type == "glm4v":
+ return "<|begin_of_image|><|image|><|end_of_image|>"
if model_type in ("phi3_v", "phi4mm"):
return f"<|image_{current_count}|>"
if model_type in ("minicpmo", "minicpmv"):
@@ -563,6 +565,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
elif modality == "video":
if model_type == "internvl_chat":
return "