# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 The vLLM team.
# Copyright 2025 The Qwen Team.
# Copyright 2025 The HuggingFace Inc. team.
# All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen3VL model compatible with HuggingFace weights."""
from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
from functools import lru_cache, partial
from itertools import islice
from typing import Any
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BatchFeature
from transformers.models.qwen2_vl import Qwen2VLImageProcessorFast
from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
smart_resize as image_smart_resize,
)
from transformers.models.qwen3_vl import Qwen3VLProcessor, Qwen3VLVideoProcessor
from transformers.models.qwen3_vl.configuration_qwen3_vl import (
Qwen3VLConfig,
Qwen3VLVisionConfig,
)
from transformers.models.qwen3_vl.video_processing_qwen3_vl import (
smart_resize as video_smart_resize,
)
from transformers.video_utils import VideoMetadata
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.compilation.decorators import support_torch_compile
from vllm.config import MultiModalConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
from vllm.distributed import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
from vllm.model_executor.layers.conv import Conv3dLayer
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.evs import (
compute_mrope_for_media,
compute_retained_tokens_count,
compute_retention_mask,
recompute_mrope_positions,
)
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFeatureSpec,
MultiModalFieldConfig,
MultiModalKwargsItem,
MultiModalKwargsItems,
PlaceholderRange,
VideoItem,
)
from vllm.multimodal.parse import ImageSize, MultiModalDataItems, MultiModalDataParser
from vllm.multimodal.processing import (
BaseMultiModalProcessor,
PromptReplacement,
PromptUpdate,
PromptUpdateDetails,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.collection_utils import is_list_of
from .interfaces import (
MultiModalEmbeddings,
SupportsEagle3,
SupportsLoRA,
SupportsMRoPE,
SupportsMultiModal,
SupportsMultiModalPruning,
SupportsPP,
_require_is_multimodal,
)
from .qwen2_5_vl import (
Qwen2_5_VisionAttention,
Qwen2_5_VLImageEmbeddingInputs,
Qwen2_5_VLImageInputs,
Qwen2_5_VLImagePixelInputs,
Qwen2_5_VLVideoEmbeddingInputs,
Qwen2_5_VLVideoInputs,
Qwen2_5_VLVideoPixelInputs,
)
from .qwen2_vl import Qwen2VLMultiModalDataParser, Qwen2VLProcessingInfo
from .qwen3 import Qwen3ForCausalLM, Qwen3Model
from .utils import (
AutoWeightsLoader,
PPMissingLayer,
WeightsMapper,
_merge_multimodal_embeddings,
maybe_prefix,
)
from .vision import (
get_vit_attn_backend,
run_dp_sharded_mrope_vision_model,
)
logger = init_logger(__name__)
# Official recommended max pixels is 24576 * 32 * 32
_MAX_FRAMES_PER_VIDEO = 24576
class Qwen3_VisionPatchEmbed(nn.Module):
def __init__(
self,
patch_size: int = 14,
temporal_patch_size: int = 2,
in_channels: int = 3,
hidden_size: int = 1152,
) -> None:
super().__init__()
self.patch_size = patch_size
self.temporal_patch_size = temporal_patch_size
self.hidden_size = hidden_size
kernel_size = (temporal_patch_size, patch_size, patch_size)
self.proj = Conv3dLayer(
in_channels,
hidden_size,
kernel_size=kernel_size,
stride=kernel_size,
bias=True,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
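        # x holds flattened patches of shape
        # (num_patches, in_channels * temporal_patch_size * patch_size**2);
        # restore the per-patch 3D volume so Conv3d can project each patch.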
L, C = x.shape
x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
x = self.proj(x).view(L, self.hidden_size)
return x
class Qwen3_VisionMLP(nn.Module):
def __init__(
self,
in_features: int,
hidden_features: int,
bias: bool = False,
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
):
super().__init__()
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.linear_fc1 = ColumnParallelLinear(
in_features,
hidden_features,
bias=bias,
quant_config=quant_config,
return_bias=False,
prefix=f"{prefix}.linear_fc1",
disable_tp=use_data_parallel,
)
self.linear_fc2 = RowParallelLinear(
hidden_features,
in_features,
bias=bias,
quant_config=quant_config,
return_bias=False,
prefix=f"{prefix}.linear_fc2",
disable_tp=use_data_parallel,
)
self.act_fn = act_fn
def forward(self, x: torch.Tensor):
mlp_output = self.linear_fc2(self.act_fn(self.linear_fc1(x)))
return mlp_output
class Qwen3_VisionBlock(nn.Module):
def __init__(
self,
dim: int,
num_heads: int,
mlp_hidden_dim: int,
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
norm_layer: Callable[[int], nn.Module] | None = None,
multimodal_config: MultiModalConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
if norm_layer is None:
norm_layer = partial(nn.LayerNorm, eps=1e-6)
self.norm1 = norm_layer(dim)
self.norm2 = norm_layer(dim)
self.attn = Qwen2_5_VisionAttention(
embed_dim=dim,
num_heads=num_heads,
projection_size=dim,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn",
)
self.mlp = Qwen3_VisionMLP(
dim,
mlp_hidden_dim,
act_fn=act_fn,
bias=True,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp",
)
def forward(
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: torch.Tensor, # Only used for Flash Attention
) -> torch.Tensor:
x = x + self.attn(
self.norm1(x),
cu_seqlens=cu_seqlens,
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
)
x = x + self.mlp(self.norm2(x))
return x
class Qwen3_VisionPatchMerger(nn.Module):
def __init__(
self,
d_model: int,
context_dim: int,
norm_layer: Callable[[int], nn.Module] | None = None,
spatial_merge_size: int = 2,
use_postshuffle_norm: bool = False,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.hidden_size = context_dim * (spatial_merge_size**2)
self.use_postshuffle_norm = use_postshuffle_norm
if self.use_postshuffle_norm:
context_dim = self.hidden_size
if norm_layer is None:
norm_layer = partial(nn.LayerNorm, eps=1e-6)
self.norm = norm_layer(context_dim)
self.linear_fc1 = ColumnParallelLinear(
self.hidden_size,
self.hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.linear_fc1",
disable_tp=use_data_parallel,
)
self.act_fn = nn.GELU()
self.linear_fc2 = RowParallelLinear(
self.hidden_size,
d_model,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.linear_fc2",
disable_tp=use_data_parallel,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.use_postshuffle_norm:
x = self.norm(x.view(-1, self.hidden_size))
else:
x = self.norm(x).view(-1, self.hidden_size)
x_parallel, _ = self.linear_fc1(x)
x_parallel = self.act_fn(x_parallel)
out, _ = self.linear_fc2(x_parallel)
return out
class Qwen3_VisionTransformer(nn.Module):
def __init__(
self,
vision_config: Qwen3VLVisionConfig,
norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = vision_config.hidden_size
self.num_heads = vision_config.num_heads
self.num_position_embeddings = vision_config.num_position_embeddings
self.patch_size = vision_config.patch_size
self.spatial_merge_size = vision_config.spatial_merge_size
self.spatial_merge_unit = self.spatial_merge_size**2
self.temporal_patch_size = vision_config.temporal_patch_size
self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes
self.num_grid_per_side = int(self.num_position_embeddings**0.5)
        # NOTE: This is used for creating the empty tensor for all_gather in
        # DP ViT. Here out_hidden_size is enlarged because of deepstack.
self.out_hidden_size = vision_config.out_hidden_size * (
1 + len(self.deepstack_visual_indexes)
)
self.patch_embed = Qwen3_VisionPatchEmbed(
patch_size=self.patch_size,
temporal_patch_size=self.temporal_patch_size,
in_channels=vision_config.in_channels,
hidden_size=self.hidden_size,
)
self.pos_embed = nn.Embedding(self.num_position_embeddings, self.hidden_size)
norm_layer = partial(nn.LayerNorm, eps=norm_eps)
head_dim = self.hidden_size // self.num_heads
self.rotary_pos_emb = get_rope(
head_size=head_dim,
max_position=8192,
is_neox_style=True,
rope_parameters={"partial_rotary_factor": 0.5},
)
self.merger = Qwen3_VisionPatchMerger(
d_model=vision_config.out_hidden_size,
context_dim=self.hidden_size,
norm_layer=norm_layer,
spatial_merge_size=self.spatial_merge_size,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.merger",
)
self.deepstack_merger_list = nn.ModuleList(
[
Qwen3_VisionPatchMerger(
d_model=vision_config.out_hidden_size,
context_dim=self.hidden_size,
spatial_merge_size=self.spatial_merge_size,
use_postshuffle_norm=True,
norm_layer=norm_layer,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.deepstack_merger_list.{layer_idx}",
)
for layer_idx in range(len(self.deepstack_visual_indexes))
]
)
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend if multimodal_config else None
)
self.attn_backend = get_vit_attn_backend(
head_size=head_dim,
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.ROCM_AITER_FA,
}:
            raise RuntimeError(
                f"Qwen3-VL does not currently support the "
                f"{self.attn_backend} backend."
            )
self.blocks = nn.ModuleList(
[
Qwen3_VisionBlock(
dim=self.hidden_size,
num_heads=self.num_heads,
mlp_hidden_dim=vision_config.intermediate_size,
act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
norm_layer=norm_layer,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.blocks.{layer_idx}",
)
for layer_idx in range(vision_config.depth)
]
)
@property
def dtype(self) -> torch.dtype:
return self.patch_embed.proj.weight.dtype
@property
def device(self) -> torch.device:
return self.patch_embed.proj.weight.device
@staticmethod
@lru_cache(maxsize=1024)
def rot_pos_ids(h: int, w: int, spatial_merge_size: int) -> torch.Tensor:
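        # Build (h * w, 2) pairs of [row, col] patch positions, reordered so
        # that patches inside each spatial_merge_size x spatial_merge_size
        # window are contiguous, matching the order used by the patch merger.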
hpos_ids = np.broadcast_to(np.arange(h).reshape(h, 1), (h, w))
h_div = h // spatial_merge_size
w_div = w // spatial_merge_size
hpos_ids = hpos_ids.reshape(
h_div,
spatial_merge_size,
w_div,
spatial_merge_size,
)
hpos_ids = hpos_ids.transpose(0, 2, 1, 3)
hpos_ids = hpos_ids.flatten()
wpos_ids = np.broadcast_to(np.arange(w).reshape(1, w), (h, w))
wpos_ids = wpos_ids.reshape(
h_div,
spatial_merge_size,
w_div,
spatial_merge_size,
)
wpos_ids = wpos_ids.transpose(0, 2, 1, 3)
wpos_ids = wpos_ids.flatten()
return torch.from_numpy(np.stack([hpos_ids, wpos_ids], axis=-1))
def rot_pos_emb(self, grid_thw: list[list[int]]):
max_grid_size = max(max(h, w) for _, h, w in grid_thw)
pos_ids = [
self.rot_pos_ids(h, w, self.spatial_merge_size)
if t == 1
else self.rot_pos_ids(h, w, self.spatial_merge_size).repeat(t, 1)
for t, h, w in grid_thw
]
pos_ids = torch.cat(pos_ids, dim=0).to(self.device, non_blocking=True)
# Use pre-computed cos_sin_cache from RotaryEmbedding
cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)
cos_combined = cos[pos_ids].flatten(1)
sin_combined = sin[pos_ids].flatten(1)
return cos_combined, sin_combined
def fast_pos_embed_interpolate(self, grid_thw: list[list[int]]) -> torch.Tensor:
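        # Bilinearly interpolate the learned (num_grid_per_side x
        # num_grid_per_side) position-embedding grid to each item's (h, w)
        # patch grid, reorder to the spatial-merge layout, then repeat the
        # result across the t temporal frames.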
num_grid_per_side = self.num_grid_per_side
m_size = self.spatial_merge_size
hidden_dim = self.pos_embed.embedding_dim
outputs = []
for t, h, w in grid_thw:
h_idxs = torch.linspace(
0, num_grid_per_side - 1, h, dtype=torch.float32, device=self.device
)
w_idxs = torch.linspace(
0, num_grid_per_side - 1, w, dtype=torch.float32, device=self.device
)
h_floor = h_idxs.to(torch.long)
w_floor = w_idxs.to(torch.long)
h_ceil = torch.clamp(h_floor + 1, max=num_grid_per_side - 1)
w_ceil = torch.clamp(w_floor + 1, max=num_grid_per_side - 1)
dh = h_idxs - h_floor
dw = w_idxs - w_floor
# Create meshgrid view for all h, w vars
dh_grid, dw_grid = torch.meshgrid(dh, dw, indexing="ij")
h_floor_grid, w_floor_grid = torch.meshgrid(h_floor, w_floor, indexing="ij")
h_ceil_grid, w_ceil_grid = torch.meshgrid(h_ceil, w_ceil, indexing="ij")
# original computation of weights
# w00 = (1 - dh_grid) * (1 - dw_grid)
# w01 = (1 - dh_grid) * dw_grid
# w10 = dh_grid * (1 - dw_grid)
# w11 = dh_grid * dw_grid
# we reuse w11 here to avoid duplicate
# dh_grid * dw_grid computation
w11 = dh_grid * dw_grid
w10 = dh_grid - w11
w01 = dw_grid - w11
w00 = 1 - dh_grid - w01
h_grid = torch.stack([h_floor_grid, h_floor_grid, h_ceil_grid, h_ceil_grid])
w_grid = torch.stack([w_floor_grid, w_ceil_grid, w_floor_grid, w_ceil_grid])
h_grid_idx = h_grid * num_grid_per_side
indices = (h_grid_idx + w_grid).reshape(4, -1)
weights = torch.stack([w00, w01, w10, w11], dim=0).reshape(4, -1, 1)
weights = weights.to(dtype=self.dtype)
embeds = self.pos_embed(indices)
embeds *= weights
combined = embeds.sum(dim=0)
combined = combined.reshape(
h // m_size, m_size, w // m_size, m_size, hidden_dim
)
combined = combined.permute(0, 2, 1, 3, 4).reshape(1, -1, hidden_dim)
repeated = combined.expand(t, -1, -1).reshape(-1, hidden_dim)
outputs.append(repeated)
return torch.cat(outputs, dim=0)
def compute_attn_mask_seqlen(
self,
cu_seqlens: torch.Tensor,
) -> torch.Tensor:
max_seqlen = torch.zeros([], device=cu_seqlens.device)
if (
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
return max_seqlen
def forward(
self,
x: torch.Tensor,
grid_thw: torch.Tensor | list[list[int]],
) -> torch.Tensor:
hidden_states = x.to(device=self.device, dtype=self.dtype, non_blocking=True)
hidden_states = self.patch_embed(hidden_states)
if isinstance(grid_thw, list):
grid_thw_list = grid_thw
grid_thw = np.array(grid_thw, dtype=np.int32)
else:
grid_thw_list = grid_thw.tolist()
grid_thw = grid_thw.numpy()
pos_embeds = self.fast_pos_embed_interpolate(grid_thw_list)
hidden_states = hidden_states + pos_embeds
rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list)
cu_seqlens = np.repeat(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
axis=0, dtype=np.int32
)
cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
cu_seqlens = torch.from_numpy(cu_seqlens)
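        # cu_seqlens holds cumulative token counts with one entry per frame
        # (h * w tokens, repeated t times per item), so attention stays
        # restricted to tokens of the same frame.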
hidden_states = hidden_states.unsqueeze(1)
max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
deepstack_feature_lists = []
for layer_num, blk in enumerate(self.blocks):
hidden_states = blk(
hidden_states,
cu_seqlens=cu_seqlens,
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
)
if layer_num in self.deepstack_visual_indexes:
deepstack_merger_idx = self.deepstack_visual_indexes.index(layer_num)
deepstack_feature = self.deepstack_merger_list[deepstack_merger_idx](
hidden_states
)
deepstack_feature_lists.append(deepstack_feature)
hidden_states = self.merger(hidden_states)
hidden_states = torch.cat(
[hidden_states] + deepstack_feature_lists, dim=1
) # [seq_len, hidden_size * (1 + depth_of_deepstack)]
return hidden_states
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("attn.qkv.", "attn.q.", "q"),
("attn.qkv.", "attn.k.", "k"),
("attn.qkv.", "attn.v.", "v"),
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: set[str] = set()
for name, loaded_weight in weights:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
def get_hf_config(self):
return self.ctx.get_hf_config(Qwen3VLConfig)
def get_hf_processor(self, **kwargs: object) -> Qwen3VLProcessor:
return self.ctx.get_hf_processor(
Qwen3VLProcessor,
use_fast=kwargs.pop("use_fast", True),
**kwargs,
)
def get_image_processor(self, **kwargs: object) -> Qwen2VLImageProcessorFast:
return self.get_hf_processor(**kwargs).image_processor
def get_video_processor(self, **kwargs: object) -> Qwen3VLVideoProcessor:
return self.get_hf_processor(**kwargs).video_processor
def _get_vision_info(
self,
*,
image_width: int,
image_height: int,
num_frames: int = 2,
do_resize: bool = True,
image_processor: Qwen2VLImageProcessorFast | Qwen3VLVideoProcessor | None,
) -> tuple[ImageSize, int]:
if image_processor is None and num_frames > 1:
image_processor = self.get_video_processor()
elif image_processor is None:
image_processor = self.get_image_processor()
is_video = isinstance(image_processor, Qwen3VLVideoProcessor)
hf_config = self.get_hf_config()
vision_config = hf_config.vision_config
patch_size = vision_config.patch_size
merge_size = vision_config.spatial_merge_size
temporal_patch_size = vision_config.temporal_patch_size
if do_resize:
if is_video:
smart_resize = video_smart_resize
extra_kwargs = {
"num_frames": num_frames,
"temporal_factor": temporal_patch_size,
}
else:
smart_resize = image_smart_resize
extra_kwargs = {}
resized_height, resized_width = smart_resize(
height=image_height,
width=image_width,
factor=patch_size * merge_size,
min_pixels=image_processor.size["shortest_edge"],
max_pixels=image_processor.size["longest_edge"],
**extra_kwargs,
)
preprocessed_size = ImageSize(width=resized_width, height=resized_height)
else:
preprocessed_size = ImageSize(width=image_width, height=image_height)
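        # Pad the frame count for temporal patching, then derive the patch
        # grid; every merge_size**2 spatial patches collapse into one
        # vision token.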
padded_num_frames = num_frames + num_frames % temporal_patch_size
grid_t = max(padded_num_frames // temporal_patch_size, 1)
grid_h = preprocessed_size.height // patch_size
grid_w = preprocessed_size.width // patch_size
num_patches = grid_t * grid_h * grid_w
num_vision_tokens = num_patches // (merge_size**2)
return preprocessed_size, num_vision_tokens
def _get_max_video_frames(self, max_tokens: int, start_num_frames: int = 2) -> int:
return super()._get_max_video_frames(
max_tokens, start_num_frames=start_num_frames
)
def get_num_frames_with_most_features(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> int:
return super().get_num_frames_with_most_features(
seq_len, mm_counts, max_frames_per_video=_MAX_FRAMES_PER_VIDEO
)
def get_max_video_tokens(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> int:
target_width, target_height = self.get_image_size_with_most_features()
video_soft_tokens = self.get_num_video_tokens(
image_width=target_width,
image_height=target_height,
num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts),
image_processor=None,
)
# NOTE: By default in Qwen3-VL, one video token is converted to
# "<{timestamp} seconds>" (on average 9.5 tokens) + vision_start_token + video_token + vision_end_token # noqa: E501
formatted_video_soft_tokens = video_soft_tokens * 12.5
return int(formatted_video_soft_tokens)
def _calculate_timestamps(
self, indices: list[int] | torch.Tensor, video_fps: float, merge_size: int
):
if not isinstance(indices, list):
indices = indices.tolist()
if len(indices) % merge_size != 0:
# don't update metadata's frames_indices directly
indices = indices + [indices[-1]] * (merge_size - len(indices) % merge_size)
timestamps = [idx / video_fps for idx in indices]
timestamps = [
(timestamps[i] + timestamps[i + merge_size - 1]) / 2
for i in range(0, len(timestamps), merge_size)
]
return timestamps
def _get_video_second_idx(
self,
metadata: dict[str, Any],
out_item: MultiModalKwargsItem,
do_sample_frames: bool | None = None,
sampled_fps: float | None = None,
) -> list[int]:
video_processor = self.get_video_processor()
merge_size = video_processor.merge_size
indices = metadata["frames_indices"]
# metadata["fps"] refers to the true fps of the input video.
video_fps = metadata["fps"]
if do_sample_frames is None:
do_sample_frames = metadata.get("do_sample_frames", False)
# If video frames are sampled in HF processor (instead of vLLM
# video loader), we need to re-calculate the indices from original
# metadata.
if do_sample_frames:
# here video_fps is the fps of the sampled video, and
# metadata["fps"] refers to the fps of the original video.
sampled_fps = sampled_fps if sampled_fps else video_processor.fps
total_num_frames = metadata["total_num_frames"]
num_frames = int(total_num_frames / metadata["fps"] * sampled_fps)
num_frames = min(
min(
max(num_frames, video_processor.min_frames),
video_processor.max_frames,
),
total_num_frames,
)
indices = (
np.linspace(0, total_num_frames - 1, num_frames)
.round()
.astype(int)
.tolist()
)
timestamps = self._calculate_timestamps(indices, video_fps, merge_size)
return timestamps
class Qwen3VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3VLProcessingInfo]):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0)
image_token = "<|vision_start|><|image_pad|><|vision_end|>"
video_token = "<|vision_start|><|video_pad|><|vision_end|>"
return image_token * num_images + video_token * num_videos
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None,
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0)
image_overrides = mm_options.get("image") if mm_options else None
video_overrides = mm_options.get("video") if mm_options else None
target_width, target_height = self.info.get_image_size_with_most_features()
target_num_frames = self.info.get_num_frames_with_most_features(
seq_len, mm_counts
)
if video_overrides:
assert isinstance(video_overrides, VideoDummyOptions)
num_frames_override = video_overrides.num_frames
if num_frames_override:
if num_frames_override > target_num_frames:
logger.warning(
"video.num_frames override (%d) exceeds model's "
"maximum number of frames (%d), will be ignored",
num_frames_override,
target_num_frames,
)
if num_frames_override < 2:
logger.warning(
"video.num_frames override (%d) cannot be less "
"than 2, will be ignored",
num_frames_override,
)
target_num_frames = min(target_num_frames, num_frames_override)
target_num_frames = max(target_num_frames, 2)
target_video_size, _ = self.info._get_vision_info(
image_width=target_width,
image_height=target_height,
num_frames=target_num_frames,
image_processor=self.info.get_video_processor(),
)
# NOTE: we need to do this check here since Qwen3-VL resizes video
# frames depending on how many frames there are.
width, height = target_video_size.width, target_video_size.height
if video_overrides:
assert isinstance(video_overrides, VideoDummyOptions)
width_override = video_overrides.width
if width_override:
if width_override > width:
logger.warning(
"video.width override (%d) exceeds model's "
"maximum width (%d), will be ignored",
width_override,
width,
)
width = min(width, width_override)
height_override = video_overrides.height
if height_override:
if height_override > height:
logger.warning(
"video.height override (%d) exceeds model's "
"maximum height (%d), will be ignored",
height_override,
height,
)
height = min(height, height_override)
return {
"image": self._get_dummy_images(
width=target_width,
height=target_height,
num_images=num_images,
overrides=image_overrides,
),
"video": self._get_dummy_videos(
width=width,
height=height,
num_frames=target_num_frames,
num_videos=num_videos,
),
}
def _get_dummy_videos(
self,
*,
width: int,
height: int,
num_frames: int,
num_videos: int,
) -> list[VideoItem]:
video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8)
video_items = []
for i in range(num_videos):
video_metadata = {
"fps": 2.0,
"duration": num_frames / 2.0,
"total_num_frames": num_frames,
"frames_indices": [i for i in range(num_frames)],
"video_backend": "opencv",
"do_sample_frames": False,
}
video_item = (video.copy(), video_metadata)
video_items.append(video_item)
return video_items
class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo]):
def _get_data_parser(self) -> MultiModalDataParser:
return Qwen2VLMultiModalDataParser(
self.info.get_hf_config().vision_config.spatial_merge_size,
video_needs_metadata=True,
)
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
mm_data = dict(mm_data)
processor = self.info.get_hf_processor(**mm_kwargs)
        # Handle videos separately from images, because each video is
        # processed into several image patches.
if videos := mm_data.pop("videos", []):
video_grid_thw_lst = []
pixel_values_videos_lst = []
for item in videos:
video_array, metadata = item
# NOTE: @JJJYmmm new attr metadata.frames_indices indicates
# the sampled frames indices of pre-sampled videos, which is
# used to calculate the timestamps. Make sure that
# do_sample_frames in mm_kwargs is false for presampled videos.
                # NOTE: a copy of mm_kwargs is created to update do_sample_frames,
# otherwise mm_hash for the object will be incorrect.
video_mm_kwargs = dict(**mm_kwargs)
if "do_sample_frames" not in video_mm_kwargs:
# qwen_vl_utils already has "do_sample_frames" in
# mm_kwargs, don't overwrite it.
video_mm_kwargs["do_sample_frames"] = metadata.get(
"do_sample_frames", False
)
metadata = VideoMetadata(
**{k: metadata[k] for k in metadata if k != "do_sample_frames"}
)
video_mm_data = dict()
video_mm_data["videos"] = [[video_array]]
video_mm_data["video_metadata"] = [[metadata]]
video_outputs = super()._call_hf_processor(
prompt="<|vision_start|><|video_pad|><|vision_end|>",
mm_data=video_mm_data,
mm_kwargs=video_mm_kwargs,
tok_kwargs=tok_kwargs,
)
input_ids = video_outputs.pop("input_ids")
video_placeholder = processor.tokenizer.batch_decode(input_ids)[0]
prompt = prompt.replace(
"<|vision_start|><|video_pad|><|vision_end|>",
video_placeholder,
1,
)
video_grid_thw_lst.append(video_outputs["video_grid_thw"])
pixel_values_videos_lst.append(video_outputs["pixel_values_videos"])
video_outputs = dict(
pixel_values_videos=torch.cat(pixel_values_videos_lst),
video_grid_thw=torch.cat(video_grid_thw_lst),
)
else:
video_outputs = dict()
processed_outputs = super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
)
combined_outputs = dict(
processed_outputs,
**video_outputs,
)
return BatchFeature(combined_outputs)
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
image_grid_sizes = image_grid_thw.prod(-1)
video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
video_grid_sizes = video_grid_thw.prod(-1)
return dict(
pixel_values=MultiModalFieldConfig.flat_from_sizes(
"image", image_grid_sizes
),
image_embeds=MultiModalFieldConfig.flat_from_sizes(
"image", image_grid_sizes
),
image_grid_thw=MultiModalFieldConfig.batched("image", keep_on_cpu=True),
pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
"video", video_grid_sizes
),
video_embeds=MultiModalFieldConfig.flat_from_sizes(
"video", video_grid_sizes
),
video_grid_thw=MultiModalFieldConfig.batched("video", keep_on_cpu=True),
)
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)
tokenizer = self.info.get_tokenizer()
hf_config = self.info.get_hf_config()
video_token_id = hf_config.video_token_id
vision_start_token_id = hf_config.vision_start_token_id
vision_end_token_id = hf_config.vision_end_token_id
merge_length = image_processor.merge_size**2
def get_image_replacement_qwen3vl(item_idx: int):
out_item = out_mm_kwargs["image"][item_idx]
grid_thw = out_item["image_grid_thw"].data
assert isinstance(grid_thw, torch.Tensor)
num_tokens = int(grid_thw.prod()) // merge_length
return [hf_processor.image_token_id] * num_tokens
def get_video_replacement_qwen3vl(item_idx: int):
out_item = out_mm_kwargs["video"][item_idx]
grid_thw = out_item["video_grid_thw"].data
assert isinstance(grid_thw, torch.Tensor)
video, metadata = mm_items["video"][item_idx]
do_sample_frames = hf_processor_mm_kwargs.get("do_sample_frames")
sampled_fps = hf_processor_mm_kwargs.get("fps")
if is_list_of(sampled_fps, float):
sampled_fps = sampled_fps[item_idx]
timestamps = self.info._get_video_second_idx(
metadata, out_item, do_sample_frames, sampled_fps
)
assert len(timestamps) == grid_thw[0], (
f"The timestamps length({len(timestamps)}) should be equal "
f"video length ({grid_thw[0]})."
)
frames_idx_token = [
tokenizer.encode(f"<{curr_time:.1f} seconds>", add_special_tokens=False)
for curr_time in timestamps
]
tokens_per_frame = int(grid_thw[1:].prod()) // merge_length
per_frame_token_counts = [tokens_per_frame for _ in frames_idx_token]
video_pruning_rate = self.info.ctx.get_mm_config().video_pruning_rate
if video_pruning_rate is not None and video_pruning_rate > 0.0:
total_retained = compute_retained_tokens_count(
tokens_per_frame,
len(frames_idx_token),
video_pruning_rate,
)
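                # Distribute the retained token budget over the frames: the
                # first frame keeps its full token count, and the remainder is
                # split as evenly as possible across the later frames.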
if len(frames_idx_token) == 0:
per_frame_token_counts = []
elif len(frames_idx_token) == 1:
per_frame_token_counts = [tokens_per_frame]
else:
first_frame_tokens = tokens_per_frame
remaining_tokens = max(total_retained - first_frame_tokens, 0)
base = remaining_tokens // (len(frames_idx_token) - 1)
remainder = remaining_tokens % (len(frames_idx_token) - 1)
per_frame_token_counts = [first_frame_tokens]
for frame_idx in range(1, len(frames_idx_token)):
extra = base + (1 if (frame_idx - 1) < remainder else 0)
per_frame_token_counts.append(extra)
placeholder = []
for frame_idx, timestamp_tokens in enumerate(frames_idx_token):
placeholder.extend(timestamp_tokens)
tokens_this_frame = per_frame_token_counts[
frame_idx if frame_idx < len(per_frame_token_counts) else -1
]
placeholder.extend(
[vision_start_token_id]
+ [video_token_id] * tokens_this_frame
+ [vision_end_token_id]
)
return PromptUpdateDetails.select_token_id(placeholder, video_token_id)
return [
PromptReplacement(
modality="image",
target=hf_processor.image_token,
replacement=get_image_replacement_qwen3vl,
),
            # NOTE: We match the string on purpose, since searching for a
            # sequence of token ids takes more time.
PromptReplacement(
modality="video",
target="<|vision_start|><|video_pad|><|vision_end|>",
replacement=get_video_replacement_qwen3vl,
),
]
@support_torch_compile(
dynamic_arg_dims={
"input_ids": 0,
        # positions is of shape (3, seq_len) if mrope is enabled for Qwen3-VL,
        # otherwise (seq_len,).
"positions": -1,
"intermediate_tensors": 0,
"inputs_embeds": 0,
        # the same shape as inputs_embeds
"deepstack_input_embeds": 0,
}
)
class Qwen3LLMModel(Qwen3Model):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
if not get_pp_group().is_first_rank:
assert self.start_layer >= len(
vllm_config.model_config.hf_config.vision_config.deepstack_visual_indexes
), (
"start_layer should be greater than or equal to "
"len(deepstack_visual_indexes)"
)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
# args for deepstack
deepstack_input_embeds: IntermediateTensors | None = None,
) -> torch.Tensor | IntermediateTensors:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.embed_input_ids(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
aux_hidden_states = []
for layer_idx, layer in islice(
enumerate(self.layers), self.start_layer, self.end_layer
):
if layer_idx in self.aux_hidden_state_layers:
aux_hidden_states.append(hidden_states + residual)
hidden_states, residual = layer(
positions,
hidden_states,
residual,
)
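            # Deepstack: add the visual features of level `layer_idx` onto the
            # hidden states of the first len(deepstack_input_embeds) layers.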
if deepstack_input_embeds is not None and layer_idx in range(
0, len(deepstack_input_embeds)
):
hidden_states = (
hidden_states
+ deepstack_input_embeds[f"deepstack_input_embeds_{layer_idx}"]
)
if not get_pp_group().is_last_rank:
return IntermediateTensors(
{"hidden_states": hidden_states, "residual": residual}
)
hidden_states, _ = self.norm(hidden_states, residual)
if len(aux_hidden_states) > 0:
return hidden_states, aux_hidden_states
return hidden_states
class Qwen3LLMForCausalLM(Qwen3ForCausalLM):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super(Qwen3ForCausalLM, self).__init__()
config = vllm_config.model_config.hf_config.text_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.model = Qwen3LLMModel(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)
if get_pp_group().is_last_rank:
if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix="lm_head",
)
else:
self.lm_head = PPMissingLayer()
self.logits_processor = LogitsProcessor(config.vocab_size)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors
)
@MULTIMODAL_REGISTRY.register_processor(
Qwen3VLMultiModalProcessor,
info=Qwen3VLProcessingInfo,
dummy_inputs=Qwen3VLDummyInputsBuilder,
)
class Qwen3VLForConditionalGeneration(
nn.Module,
SupportsMultiModal,
SupportsLoRA,
SupportsPP,
SupportsMRoPE,
SupportsEagle3,
SupportsMultiModalPruning,
):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
supports_encoder_tp_data = True
# To ensure correct weight loading and mapping.
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"model.visual.": "visual.",
"lm_head.": "language_model.lm_head.",
"model.language_model.": "language_model.model.",
}
)
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
if modality.startswith("image"):
return "<|vision_start|><|image_pad|><|vision_end|>"
if modality.startswith("video"):
return "<|vision_start|><|video_pad|><|vision_end|>"
raise ValueError("Only image or video modality is supported")
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"):
super().__init__()
config: Qwen3VLConfig = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
self.video_pruning_rate = multimodal_config.video_pruning_rate
self.is_multimodal_pruning_enabled = (
multimodal_config.is_multimodal_pruning_enabled()
)
if not multimodal_config.get_limit_per_prompt(
"image"
) and not multimodal_config.get_limit_per_prompt("video"):
self.visual = None
else:
self.visual = Qwen3_VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "visual"),
)
self.language_model = Qwen3LLMForCausalLM(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "language_model")
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
)
self.use_deepstack = hasattr(config.vision_config, "deepstack_visual_indexes")
self.deepstack_num_level = (
len(config.vision_config.deepstack_visual_indexes)
if self.use_deepstack
else 0
)
# register buffer for deepstack
if self.use_deepstack and self.visual is not None:
self.deepstack_input_embeds = [
torch.zeros(
vllm_config.scheduler_config.max_num_batched_tokens,
config.text_config.hidden_size,
)
for _ in range(self.deepstack_num_level)
]
else:
self.deepstack_input_embeds = None
self.visual_dim = config.vision_config.out_hidden_size
self.multiscale_dim = self.visual_dim * self.deepstack_num_level
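        # With deepstack, the vision tower emits visual_dim main channels plus
        # visual_dim * deepstack_num_level auxiliary channels per token; the
        # auxiliary slices are buffered and added to the early LLM layers.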
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self.language_model.model.aux_hidden_state_layers = layers
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
num_layers = len(self.language_model.model.layers)
return (2, num_layers // 2, num_layers - 3)
def _get_deepstack_input_embeds(self, num_tokens: int) -> IntermediateTensors:
        # get the first num_tokens rows of deepstack_input_embeds from the buffer
return IntermediateTensors(
{
f"deepstack_input_embeds_{idx}": self.deepstack_input_embeds[idx][
:num_tokens
]
for idx in range(self.deepstack_num_level)
}
)
def _set_deepstack_input_embeds(self, deepstack_input_embeds: torch.Tensor) -> None:
# set deepstack_input_embeds to buffer
num_tokens = deepstack_input_embeds.size(1)
if num_tokens > self.deepstack_input_embeds[0].size(0):
self.deepstack_input_embeds = [
torch.zeros(
num_tokens,
self.config.text_config.hidden_size,
device=self.deepstack_input_embeds[0].device,
dtype=self.deepstack_input_embeds[0].dtype,
)
for _ in range(self.deepstack_num_level)
]
for idx in range(self.deepstack_num_level):
self.deepstack_input_embeds[idx][:num_tokens].copy_(
deepstack_input_embeds[idx]
)
def _clear_deepstack_input_embeds(self, num_tokens: int) -> None:
# clear deepstack_input_embeds in buffer
if num_tokens > 0:
for idx in range(self.deepstack_num_level):
self.deepstack_input_embeds[idx][:num_tokens].zero_()
def _parse_and_validate_image_input(
self, **kwargs: object
) -> Qwen2_5_VLImageInputs | None:
pixel_values = kwargs.pop("pixel_values", None)
image_embeds = kwargs.pop("image_embeds", None)
image_grid_thw = kwargs.pop("image_grid_thw", None)
if pixel_values is None and image_embeds is None:
return None
if pixel_values is not None:
return Qwen2_5_VLImagePixelInputs(
type="pixel_values",
pixel_values=pixel_values,
image_grid_thw=image_grid_thw,
)
if image_embeds is not None:
return Qwen2_5_VLImageEmbeddingInputs(
type="image_embeds",
image_embeds=image_embeds,
image_grid_thw=image_grid_thw,
)
def _parse_and_validate_video_input(
self, **kwargs: object
) -> Qwen2_5_VLVideoInputs | None:
pixel_values_videos = kwargs.pop("pixel_values_videos", None)
video_embeds = kwargs.pop("video_embeds", None)
video_grid_thw = kwargs.pop("video_grid_thw", None)
second_per_grid_ts = kwargs.pop("second_per_grid_ts", None)
if pixel_values_videos is None and video_embeds is None:
return None
if pixel_values_videos is not None:
return Qwen2_5_VLVideoPixelInputs(
type="pixel_values_videos",
pixel_values_videos=pixel_values_videos,
video_grid_thw=video_grid_thw,
second_per_grid_ts=second_per_grid_ts,
)
if video_embeds is not None:
return Qwen2_5_VLVideoEmbeddingInputs(
type="video_embeds",
video_embeds=video_embeds,
video_grid_thw=video_grid_thw,
)
def _process_image_input(
self, image_input: Qwen2_5_VLImageInputs
) -> tuple[torch.Tensor, ...]:
grid_thw = image_input["image_grid_thw"]
assert grid_thw.ndim == 2
if image_input["type"] == "image_embeds":
image_embeds = image_input["image_embeds"].type(self.visual.dtype)
else:
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
if self.use_data_parallel:
return run_dp_sharded_mrope_vision_model(
self.visual, pixel_values, grid_thw.tolist(), rope_type="rope_3d"
)
else:
image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
# Split concatenated embeddings for each image item.
merge_size = self.visual.spatial_merge_size
sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
return image_embeds.split(sizes)
def _process_video_input(
self, video_input: Qwen2_5_VLVideoInputs
) -> tuple[torch.Tensor, ...]:
grid_thw = video_input["video_grid_thw"]
assert grid_thw.ndim == 2
if video_input["type"] == "video_embeds":
video_embeds = video_input["video_embeds"].type(self.visual.dtype)
else:
pixel_values_videos = video_input["pixel_values_videos"].type(
self.visual.dtype
)
if self.use_data_parallel:
grid_thw_list = grid_thw.tolist()
return run_dp_sharded_mrope_vision_model(
self.visual, pixel_values_videos, grid_thw_list, rope_type="rope_3d"
)
else:
video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
# Split concatenated embeddings for each video item.
merge_size = self.visual.spatial_merge_size
sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
return video_embeds.split(sizes)
def _postprocess_image_embeds_evs(
self,
image_embeds_split: tuple[torch.Tensor, ...],
image_input: Qwen2_5_VLImageInputs,
) -> tuple[torch.Tensor, ...]:
"""
        Append mrope positions to the embeddings of each image.
        This is necessary to recover the correct mrope
        positions after video pruning.
Args:
image_embeds_split: Tuple of image embeddings for
each image item.
image_input: Image input data.
Returns:
Tuple of image embeddings for each image item.
Resulting embeddings will have extra 4 channels for
computed mrope positions.
"""
merge_size = self.visual.spatial_merge_size
grid_thw = image_input["image_grid_thw"]
grid_thw_list = grid_thw.tolist()
image_embeds_out = []
for emb, size in zip(image_embeds_split, grid_thw_list):
positions = compute_mrope_for_media(size, merge_size).to(emb.device)
emb = torch.cat([emb, positions], dim=1)
image_embeds_out.append(emb)
image_embeds_split = image_embeds_out
return tuple(image_embeds_split)
def _postprocess_video_embeds_evs(
self,
video_embeds_split: tuple[torch.Tensor, ...],
video_input: Qwen2_5_VLVideoInputs,
) -> tuple[torch.Tensor, ...]:
"""
        Prunes video embeddings via Efficient Video Sampling (EVS)
        and then appends mrope positions to each retained embedding.
Args:
video_embeds_split: Tuple of video embeddings for each video item.
video_input: Video input data.
Returns:
Tuple of video embeddings for each video item.
Resulting embeddings will have extra 4 channels for
computed mrope positions.
"""
grid_thw = video_input["video_grid_thw"]
assert grid_thw.ndim == 2
grid_thw_list = grid_thw.tolist()
merge_size = self.visual.spatial_merge_size
# Cast to long to match the original code
# https://github.com/huggingface/transformers/blob/41980ce93e775f6c88500c51c8db7946fc6a2add/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py#L491 # noqa
second_per_grid_ts = video_input.get("second_per_grid_ts")
if second_per_grid_ts is None:
# For Qwen3-VL, second_per_grid_ts might not be available
# Use default value of 1.0 for each video
second_per_grid_ts = torch.ones(len(grid_thw_list), dtype=torch.long)
else:
second_per_grid_ts = second_per_grid_ts.long()
tokens_per_second = getattr(self.config.vision_config, "tokens_per_second", 1.0)
video_embeds_out = []
for emb, size, video_second_per_grid_t in zip(
video_embeds_split, grid_thw_list, second_per_grid_ts
):
# For each video, we compute retention mask using EVS
retention_mask = compute_retention_mask(
emb,
size,
spatial_merge_size=self.visual.spatial_merge_size,
q=self.video_pruning_rate,
)
# Debug logging for EVS pruning
logger.debug(
"EVS: Video tokens pruned from %d to %d (T=%d,H=%d,W=%d, "
"pruning_rate=%.2f, reduction=%.1f%%)",
emb.shape[0],
retention_mask.sum().item(),
size[0],
size[1],
size[2],
self.video_pruning_rate,
(1 - retention_mask.float().mean().item()) * 100,
)
positions = compute_mrope_for_media(
size,
merge_size,
tokens_per_second=tokens_per_second,
video_second_per_grid=video_second_per_grid_t.item(),
).to(emb.device)
emb = emb[retention_mask]
positions = positions[retention_mask]
emb = torch.cat([emb, positions], dim=1)
video_embeds_out.append(emb)
return tuple(video_embeds_out)
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
mm_input_by_modality = {}
for input_key in kwargs:
if (
input_key in ("pixel_values", "image_embeds")
and "image" not in mm_input_by_modality
):
mm_input_by_modality["image"] = self._parse_and_validate_image_input(
**kwargs
)
if (
input_key in ("pixel_values_videos", "video_embeds")
and "video" not in mm_input_by_modality
):
mm_input_by_modality["video"] = self._parse_and_validate_video_input(
**kwargs
)
return mm_input_by_modality
def iter_mm_grid_hw(
self, input_tokens: list[int], mm_features: list[MultiModalFeatureSpec]
) -> Iterator[tuple[int, int, int]]:
"""
Iterate over multimodal features and yield grid information.
For videos with EVS (Efficient Video Sampling) enabled, this function
computes the offset based on the pruned token count rather than relying
on input_tokens.index(), which would fail when tokens are pruned.
Args:
input_tokens: List of token IDs in the prompt
mm_features: List of multimodal feature specifications
Yields:
Tuple of (offset, grid_h, grid_w) for each frame/image
"""
video_token_id = self.config.video_token_id
spatial_merge_size = self.config.vision_config.spatial_merge_size
for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset):
offset = mm_feature.mm_position.offset
if mm_feature.modality == "image":
t, h, w = mm_feature.data["image_grid_thw"].data.tolist()
assert t == 1, f"Image must have 1 frame, got {t}"
yield offset, h // spatial_merge_size, w // spatial_merge_size
elif mm_feature.modality == "video":
t, h, w = mm_feature.data["video_grid_thw"].data.tolist()
llm_grid_h = h // spatial_merge_size
llm_grid_w = w // spatial_merge_size
# Check if EVS (Efficient Video Sampling) is enabled
is_evs_enabled = (
hasattr(self, "video_pruning_rate")
and self.video_pruning_rate is not None
and self.video_pruning_rate > 0.0
)
if is_evs_enabled:
frame_offsets = self._extract_frame_offsets_from_mask(
mm_feature.mm_position, t
)
if frame_offsets is not None:
for rel_offset in frame_offsets:
yield offset + rel_offset, llm_grid_h, llm_grid_w
continue
# If EVS is enabled but mask is missing, this indicates a bug
# in the prompt processing pipeline. The is_embed mask should
# always be present when video_pruning_rate > 0.
raise RuntimeError(
f"EVS is enabled (pruning_rate={self.video_pruning_rate}) "
"but is_embed mask is missing from mm_position. "
"This indicates a bug in prompt processing."
)
else:
# Non-EVS mode: Use original logic with input_tokens.index()
for _ in range(t):
offset = input_tokens.index(video_token_id, offset)
yield offset, llm_grid_h, llm_grid_w
offset += llm_grid_h * llm_grid_w
else:
raise ValueError(f"Unsupported modality: {mm_feature.modality}")
def _get_evs_mask_segments(
self, mm_position: PlaceholderRange, expected_frames: int
) -> list[torch.Tensor] | None:
"""Extract contiguous segments from EVS is_embed mask.
The EVS (Efficient Video Sampling) mask marks which placeholder
positions should be filled with video embeddings. This method splits
the mask into contiguous segments, where each segment represents one
retained frame.
        This is a pure function: it does not modify any state and always
        returns the same output for the same input.
Args:
mm_position: MultiModal position containing the is_embed mask
expected_frames: Expected number of frame segments
Returns:
List of tensors, each containing indices for one frame segment,
or None if EVS is not enabled or validation fails.
"""
is_embed_mask = getattr(mm_position, "is_embed", None)
if is_embed_mask is None:
return None
# Find all True positions in the mask
mask_tensor = torch.as_tensor(is_embed_mask, dtype=torch.bool).view(-1)
true_indices = torch.nonzero(mask_tensor, as_tuple=False).flatten()
if true_indices.numel() == 0:
return None
# Split into contiguous segments (where diff > 1 indicates a gap)
if true_indices.numel() == 1:
segments = [true_indices]
else:
diffs = torch.diff(true_indices)
split_points = torch.nonzero(diffs != 1, as_tuple=False).flatten()
if split_points.numel() == 0:
segments = [true_indices]
else:
segments = torch.tensor_split(
true_indices, split_points.add(1).tolist()
)
# Validate segment count matches expected frames
if len(segments) < expected_frames:
logger.debug(
"EVS mask segments (%d) do not match expected frames (%d)",
len(segments),
expected_frames,
)
return None
return segments[:expected_frames]
def _extract_frame_offsets_from_mask(
self, mm_position: PlaceholderRange, expected_frames: int
) -> list[int] | None:
"""Return relative offsets for each EVS-retained frame.
The prompt processor stores a boolean mask inside ``mm_position`` that
marks which placeholder locations should be populated with video
embeddings. By splitting that mask into contiguous runs we can recover
the start of every retained frame without probing ``input_tokens``.
Args:
mm_position: MultiModal position containing the is_embed mask
expected_frames: Expected number of frames
Returns:
List of starting offsets (relative to mm_position) for each frame,
or None if EVS is not enabled.
"""
segments = self._get_evs_mask_segments(mm_position, expected_frames)
if segments is None:
return None
return [int(segment[0].item()) for segment in segments]
def _get_actual_frame_token_counts(
self, mm_position: PlaceholderRange, expected_frames: int
) -> list[int] | None:
"""Return actual token count for each EVS-retained frame.
This function calculates the actual number of tokens per frame by
analyzing the is_embed mask, accounting for EVS pruning. Each frame
may have a different token count due to content-aware pruning.
Args:
mm_position: MultiModal position containing the is_embed mask
expected_frames: Expected number of frames
Returns:
List of token counts for each frame, or None if EVS is not enabled.
"""
segments = self._get_evs_mask_segments(mm_position, expected_frames)
if segments is None:
return None
return [len(seg) for seg in segments]
def recompute_mrope_positions(
self,
input_ids: list[int],
multimodal_embeddings: tuple[torch.Tensor, ...],
mrope_positions: torch.LongTensor,
num_computed_tokens: int,
) -> tuple[tuple[torch.Tensor, ...], torch.Tensor, int]:
"""
        Update part of the input mrope positions (starting at index
        num_computed_tokens). The original mrope_positions are computed
        for the unpruned sequence and become incorrect once pruning
        occurs, so after pruning media tokens we must reflect this in
        the mrope_positions before feeding them to the LLM.
        Args:
            input_ids: (N,) All input tokens of the prompt (containing
                the entire sequence).
            multimodal_embeddings: Tuple of multimodal embeddings.
            mrope_positions: Existing mrope positions (3, N) for the
                entire sequence.
            num_computed_tokens: The number of tokens computed so far.
Returns:
Tuple of (multimodal_embeddings, mrope_positions,
mrope_position_delta).
"""
image_token_id = self.config.image_token_id
video_token_id = self.config.video_token_id
vision_start_token_id = self.config.vision_start_token_id
# Device
device = (
multimodal_embeddings[0].device
if len(multimodal_embeddings)
else mrope_positions.device
)
# Tensors
input_ids_t = torch.as_tensor(input_ids, device=device, dtype=torch.long)
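        # EVS post-processing appended 4 extra channels of precomputed mrope
        # positions to every multimodal embedding; split them off the features
        # and pass them to the position recomputation separately.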
mm_embeddings_out = [mm[:, :-4] for mm in multimodal_embeddings]
mm_embeddings_pos = [
mm[:, -4:].permute(1, 0).long() for mm in multimodal_embeddings
]
positions, mrope_positions_delta = recompute_mrope_positions(
input_ids_t,
mm_embeddings_pos,
mrope_positions,
num_computed_tokens,
vision_start_token_id,
image_token_id,
video_token_id,
)
return tuple(mm_embeddings_out), positions, mrope_positions_delta
def get_mrope_input_positions(
self,
input_tokens: list[int],
mm_features: list[MultiModalFeatureSpec],
) -> tuple[torch.Tensor, int]:
# Pre-collect actual frame token counts for EVS mode
frame_token_counts_map = {}
for mm_feature in mm_features:
if mm_feature.modality == "video":
is_evs_enabled = (
hasattr(self, "video_pruning_rate")
and self.video_pruning_rate is not None
and self.video_pruning_rate > 0.0
)
if is_evs_enabled:
t = mm_feature.data["video_grid_thw"].data.tolist()[0]
token_counts = self._get_actual_frame_token_counts(
mm_feature.mm_position, t
)
assert token_counts is not None, (
"EVS enabled but failed to extract frame token counts "
"from is_embed mask"
)
frame_token_counts_map[mm_feature.mm_position.offset] = token_counts
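        # Walk through the prompt: for every image/frame yielded below, first
        # emit 1D text positions for the preceding text span, then 3D
        # (t, h, w) grid positions for the frame's vision tokens.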
llm_pos_ids_list = []
st = 0
frame_counts_idx = {}
for offset, llm_grid_h, llm_grid_w in self.iter_mm_grid_hw(
input_tokens, mm_features
):
text_len = offset - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
# Determine actual token count for this frame
base_offset = None
for feat_offset in frame_token_counts_map:
if offset >= feat_offset:
base_offset = feat_offset
if base_offset is not None:
# EVS mode: use actual token count from is_embed mask
assert base_offset in frame_token_counts_map, (
f"Found base_offset {base_offset} but not in frame_token_counts_map"
)
if base_offset not in frame_counts_idx:
frame_counts_idx[base_offset] = 0
counts = frame_token_counts_map[base_offset]
idx = frame_counts_idx[base_offset]
assert idx < len(counts), (
f"EVS frame index {idx} out of range (total frames: {len(counts)})"
)
actual_frame_tokens = counts[idx]
frame_counts_idx[base_offset] += 1
else:
# Non-EVS mode (or image): use theoretical grid size
actual_frame_tokens = llm_grid_h * llm_grid_w
# Add text segment
text_positions = (
np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
)
llm_pos_ids_list.append(text_positions)
st_idx += text_len
# Add frame segment with actual token count (not theoretical)
grid_indices = np.indices((1, llm_grid_h, llm_grid_w)).reshape(3, -1)
# Only take the first actual_frame_tokens positions
frame_positions = grid_indices[:, :actual_frame_tokens] + st_idx
llm_pos_ids_list.append(frame_positions)
# Update st using actual token count
st = offset + actual_frame_tokens
# Handle final text segment
if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st
final_text_positions = (
np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
)
llm_pos_ids_list.append(final_text_positions)
llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
return torch.from_numpy(llm_positions), mrope_position_delta
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
if not mm_input_by_modality:
return None
        # The resulting multimodal_embeddings is a tuple of tensors, with each
        # tensor corresponding to a multimodal data item (image or video).
multimodal_embeddings: tuple[torch.Tensor, ...] = ()
# NOTE: It is important to iterate over the keys in this dictionary
# to preserve the order of the modalities.
for modality in mm_input_by_modality:
multimodal_input = mm_input_by_modality[modality]
if modality == "image":
image_embeddings = self._process_image_input(multimodal_input)
if self.is_multimodal_pruning_enabled:
image_embeddings = self._postprocess_image_embeds_evs(
image_embeddings, multimodal_input
)
multimodal_embeddings += tuple(image_embeddings)
if modality == "video":
video_embeddings = self._process_video_input(multimodal_input)
if self.is_multimodal_pruning_enabled:
video_embeddings = self._postprocess_video_embeds_evs(
video_embeddings, multimodal_input
)
multimodal_embeddings += tuple(video_embeddings)
return multimodal_embeddings
def _compute_deepstack_embeds(
self,
inputs_embeds: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings,
is_multimodal: torch.Tensor,
) -> tuple[torch.Tensor, MultiModalEmbeddings]:
visual_lens = [len(x) for x in multimodal_embeddings]
multimodal_embeddings_cat = torch.cat(multimodal_embeddings, dim=0)
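        # Each embedding carries [main | multiscale] channels: the first
        # visual_dim channels feed the LLM input embeddings, while the
        # remaining multiscale_dim channels go to the deepstack buffer.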
(
multimodal_embeddings_main,
multimodal_embeddings_multiscale,
) = torch.split(
multimodal_embeddings_cat,
[self.visual_dim, self.multiscale_dim],
dim=-1,
)
multimodal_embeddings = torch.split(
multimodal_embeddings_main, visual_lens, dim=0
)
multimodal_embeddings_multiscale = torch.split(
multimodal_embeddings_multiscale, visual_lens, dim=0
)
deepstack_input_embeds = inputs_embeds.new_zeros(
inputs_embeds.size(0), self.deepstack_num_level * inputs_embeds.size(1)
)
deepstack_input_embeds = _merge_multimodal_embeddings(
inputs_embeds=deepstack_input_embeds,
multimodal_embeddings=multimodal_embeddings_multiscale,
is_multimodal=is_multimodal,
)
deepstack_input_embeds = deepstack_input_embeds.view(
inputs_embeds.shape[0], self.deepstack_num_level, self.visual_dim
)
deepstack_input_embeds = deepstack_input_embeds.permute(1, 0, 2)
return deepstack_input_embeds, multimodal_embeddings
def embed_input_ids(
self,
input_ids: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
inputs_embeds = self._embed_text_input_ids(
input_ids,
self.language_model.embed_input_ids,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
return inputs_embeds
is_multimodal = _require_is_multimodal(is_multimodal)
if self.use_deepstack:
(
deepstack_input_embeds,
multimodal_embeddings,
) = self._compute_deepstack_embeds(
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
)
else:
deepstack_input_embeds = None
inputs_embeds = _merge_multimodal_embeddings(
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
)
if deepstack_input_embeds is not None:
self._set_deepstack_input_embeds(deepstack_input_embeds)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
**kwargs: object,
) -> torch.Tensor | IntermediateTensors:
"""Run forward pass for Qwen3VL.
Args:
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
positions: Flattened (concatenated) position ids corresponding to a
batch.
                **NOTE**: If mrope is enabled (the default setting for Qwen3VL
                open-source models), the shape will be `(3, seq_len)`,
                otherwise it will be `(seq_len,)`.
intermediate_tensors: Intermediate tensors from previous pipeline
stages.
inputs_embeds: Pre-computed input embeddings.
**kwargs: Additional keyword arguments including:
- pixel_values: Pixel values to be fed to a model.
`None` if no images are passed.
- image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in
LLM. `None` if no images are passed.
- pixel_values_videos: Pixel values of videos to be fed to a
model. `None` if no videos are passed.
- video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in
LLM. `None` if no videos are passed.
"""
if intermediate_tensors is not None:
inputs_embeds = None
if (
self.use_deepstack
and inputs_embeds is not None
and get_pp_group().is_first_rank
):
deepstack_input_embeds = self._get_deepstack_input_embeds(
inputs_embeds.size(0)
)
else:
deepstack_input_embeds = None
hidden_states = self.language_model.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
# args for deepstack
deepstack_input_embeds=deepstack_input_embeds,
)
if inputs_embeds is not None and get_pp_group().is_first_rank:
self._clear_deepstack_input_embeds(inputs_embeds.size(0))
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor | None:
return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
skip_prefixes = []
if self.visual is None:
skip_prefixes.extend(["visual."])
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def get_mm_mapping(self) -> MultiModelKeys:
"""
Get the module prefix in multimodal models
"""
return MultiModelKeys.from_string_field(
language_model="language_model",
connector="visual.merger",
tower_model="visual.",
)