diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 404519f887dc6..25579835faf63 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -680,6 +680,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `Glm4vMoeForConditionalGeneration` | GLM-4.5V | T + IE+ + VE+ | `zai-org/GLM-4.5V`, etc. | ✅︎ | ✅︎ |
| `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ |
| `H2OVLChatModel` | H2OVL | T + IE+ | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ |
+| `HunYuanVLForConditionalGeneration` | HunyuanOCR | T + IE+ | `tencent/HunyuanOCR`, etc. | ✅︎ | ✅︎ |
| `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3`, etc. | ✅︎ | |
| `InternS1ForConditionalGeneration` | Intern-S1 | T + IE+ + VE+ | `internlm/Intern-S1`, `internlm/Intern-S1-mini`, etc. | ✅︎ | ✅︎ |
| `InternVLChatModel` | InternVL 3.5, InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + IE+ + (VE+) | `OpenGVLab/InternVL3_5-14B`, `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ |
diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py
index 624de2a2debc3..65ea4df4a3099 100644
--- a/examples/offline_inference/vision_language.py
+++ b/examples/offline_inference/vision_language.py
@@ -538,6 +538,31 @@ def run_h2ovl(questions: list[str], modality: str) -> ModelRequestData:
)
+# HunyuanOCR
+def run_hunyuan_vl(questions: list[str], modality: str) -> ModelRequestData:
+ assert modality == "image"
+
+ model_name = "tencent/HunyuanOCR"
+
+ engine_args = EngineArgs(
+ model=model_name,
+ max_model_len=8192,
+ limit_mm_per_prompt={modality: 1},
+ )
+
+ placeholder = "<|hy_place▁holder▁no▁100|><|hy_place▁holder▁no▁102|><|hy_place▁holder▁no▁101|>" # noqa: E501
+ prompts = [
+ f"<|hy_begin▁of▁sentence|>{placeholder}{question}<|hy_User|>"
+ for question in questions
+ ]
+
+ return ModelRequestData(
+ engine_args=engine_args,
+ prompts=prompts,
+ stop_token_ids=None,
+ )
+
+
# naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B
def run_hyperclovax_seed_vision(
questions: list[str], modality: str
@@ -1820,6 +1845,7 @@ model_example_map = {
"glm4_5v": run_glm4_5v,
"glm4_5v_fp8": run_glm4_5v_fp8,
"h2ovl_chat": run_h2ovl,
+ "hunyuan_vl": run_hunyuan_vl,
"hyperclovax_seed_vision": run_hyperclovax_seed_vision,
"idefics3": run_idefics3,
"interns1": run_interns1,
diff --git a/tests/models/registry.py b/tests/models/registry.py
index 758ec54493aa3..f8b3470e6d39b 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -626,6 +626,10 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B",
trust_remote_code=True,
),
+ "HunYuanVLForConditionalGeneration": _HfExamplesInfo(
+ "tencent/HunyuanOCR",
+ is_available_online=False,
+ ),
"Idefics3ForConditionalGeneration": _HfExamplesInfo(
"HuggingFaceM4/Idefics3-8B-Llama3",
extras={"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"},
diff --git a/vllm/config/model.py b/vllm/config/model.py
index c37dd7c15f2a7..caa9a3440c41d 100644
--- a/vllm/config/model.py
+++ b/vllm/config/model.py
@@ -33,6 +33,7 @@ from vllm.transformers_utils.config import (
try_get_safetensors_metadata,
try_get_tokenizer_config,
uses_mrope,
+ uses_xdrope_dim,
)
from vllm.transformers_utils.gguf_utils import (
maybe_patch_hf_config_from_gguf,
@@ -1615,6 +1616,10 @@ class ModelConfig:
def uses_mrope(self) -> bool:
return uses_mrope(self.hf_config)
+ @property
+ def uses_xdrope_dim(self) -> int:
+ return uses_xdrope_dim(self.hf_config)
+
@property
def is_multimodal_model(self) -> bool:
return self.multimodal_config is not None
diff --git a/vllm/model_executor/layers/rotary_embedding/__init__.py b/vllm/model_executor/layers/rotary_embedding/__init__.py
index 152d9401b8e94..0f10bff6ac4f5 100644
--- a/vllm/model_executor/layers/rotary_embedding/__init__.py
+++ b/vllm/model_executor/layers/rotary_embedding/__init__.py
@@ -17,6 +17,7 @@ from .llama4_vision_rope import Llama4VisionRotaryEmbedding
from .mrope import MRotaryEmbedding
from .ntk_scaling_rope import NTKScalingRotaryEmbedding
from .phi3_long_rope_scaled_rope import Phi3LongRoPEScaledRotaryEmbedding
+from .xdrope import XDRotaryEmbedding
from .yarn_scaling_rope import YaRNScalingRotaryEmbedding
_ROPE_DICT: dict[tuple, RotaryEmbedding] = {}
@@ -184,6 +185,18 @@ def get_rope(
raise ValueError(
"Dynamic rope scaling must contain either 'alpha' or 'factor' field"
)
+ elif scaling_type == "xdrope":
+ scaling_alpha = rope_parameters["alpha"]
+ rotary_emb = XDRotaryEmbedding(
+ head_size,
+ rotary_dim,
+ max_position,
+ base,
+ is_neox_style,
+ scaling_alpha,
+ dtype,
+ xdrope_section=rope_parameters["xdrope_section"],
+ )
elif scaling_type == "yarn":
scaling_factor = rope_parameters["factor"]
original_max_position = rope_parameters["original_max_position_embeddings"]
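The new branch expects the rope scaling dict to carry both an `alpha` (shared with the dynamic-NTK path) and an `xdrope_section`. A sketch of such an entry, with illustrative values only:

```python
# Illustrative only: a rope scaling entry that would reach the new "xdrope"
# branch above. Field names follow the code; the values are made up, and the
# key that carries the scaling type ("type" / "rope_type") follows whatever
# the model config already uses. xdrope_section is split over the cos/sin
# channels, so it should sum to rotary_dim // 2.
rope_scaling = {
    "type": "xdrope",
    "alpha": 1000.0,                      # dynamic-NTK alpha
    "xdrope_section": [16, 16, 16, 16],   # one section per P/W/H/T position dim
}
```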
diff --git a/vllm/model_executor/layers/rotary_embedding/xdrope.py b/vllm/model_executor/layers/rotary_embedding/xdrope.py
new file mode 100644
index 0000000000000..2432273faf195
--- /dev/null
+++ b/vllm/model_executor/layers/rotary_embedding/xdrope.py
@@ -0,0 +1,102 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import numpy as np
+import torch
+
+from .common import apply_rotary_emb_dispatch
+from .dynamic_ntk_alpha_rope import DynamicNTKAlphaRotaryEmbedding
+
+
+class XDRotaryEmbedding(DynamicNTKAlphaRotaryEmbedding):
+ """DynamicNTKAlphaRotaryEmbedding extended with MultiModal(XD) Sections.
+
+ Based on the original DynamicNTKAlphaRotaryEmbedding implementation.
+ """
+
+ def __init__(
+ self,
+ head_size: int,
+ rotary_dim: int,
+ max_position_embeddings: int,
+ base: float,
+ is_neox_style: bool,
+ scaling_alpha: float,
+ dtype: torch.dtype,
+ xdrope_section: list[int],
+ ) -> None:
+ self.xdrope_section = xdrope_section
+ super().__init__(
+ head_size,
+ rotary_dim,
+ max_position_embeddings,
+ base,
+ is_neox_style,
+ scaling_alpha,
+ dtype,
+ )
+
+ def forward(
+ self,
+ positions: torch.Tensor,
+ query: torch.Tensor,
+ key: torch.Tensor | None = None,
+ offsets: torch.Tensor | None = None,
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
+ """PyTorch-native implementation equivalent to forward().
+
+ Args:
+ positions:
+ [4, num_tokens] (P/W/H/T positions with multimodal inputs)
+ query: [num_tokens, num_heads * head_size]
+ key: [num_tokens, num_kv_heads * head_size]
+ """
+ assert positions.ndim == 2
+ assert key is not None
+
+ num_tokens = positions.shape[-1]
+ cos_sin = self.cos_sin_cache[positions]
+ cos, sin = cos_sin.chunk(2, dim=-1)
+ cos = torch.cat(
+ [m[i] for i, m in enumerate(cos.split(self.xdrope_section, dim=-1))], dim=-1
+ )
+ sin = torch.cat(
+ [m[i] for i, m in enumerate(sin.split(self.xdrope_section, dim=-1))], dim=-1
+ )
+
+ query_shape = query.shape
+ query = query.view(num_tokens, -1, self.head_size)
+ query_rot = query[..., : self.rotary_dim]
+ query_pass = query[..., self.rotary_dim :]
+ query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style)
+ query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
+
+ key_shape = key.shape
+ key = key.view(num_tokens, -1, self.head_size)
+ key_rot = key[..., : self.rotary_dim]
+ key_pass = key[..., self.rotary_dim :]
+ key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style)
+ key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
+ return query, key
+
+ @staticmethod
+ def get_next_input_positions(
+ context_len: int,
+ seq_len: int,
+ xd_sections: int = 4,
+ ) -> list[list[int]]:
+ return [list(range(context_len, seq_len)) for _ in range(xd_sections)]
+
+ @staticmethod
+ def get_next_input_positions_tensor(
+ out: np.ndarray,
+ out_offset: int,
+ context_len: int,
+ num_new_tokens: int,
+ ):
+ values = np.arange(
+ context_len,
+ context_len + num_new_tokens,
+ dtype=out.dtype,
+ )
+ out[:, out_offset : out_offset + num_new_tokens] = values
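A small self-contained sketch of how the new class behaves (hyperparameters are illustrative; shapes follow the docstring above). Each row of `positions` drives its own slice of the rotary channels, as selected by `xdrope_section`:

```python
# Sketch of XDRotaryEmbedding usage (illustrative hyperparameters). Each of
# the four position rows (P/W/H/T) rotates its own slice of the rotary
# channels, as selected by xdrope_section (which sums to rotary_dim // 2).
import torch

from vllm.model_executor.layers.rotary_embedding.xdrope import XDRotaryEmbedding

rope = XDRotaryEmbedding(
    head_size=128,
    rotary_dim=128,
    max_position_embeddings=8192,
    base=10000.0,
    is_neox_style=True,
    scaling_alpha=1000.0,             # illustrative NTK alpha
    dtype=torch.float32,
    xdrope_section=[16, 16, 16, 16],  # 4 sections covering 64 = 128 // 2 channels
)

num_tokens, num_heads = 5, 4
positions = torch.arange(num_tokens).repeat(4, 1)  # [4, num_tokens]
query = torch.randn(num_tokens, num_heads * 128)
key = torch.randn(num_tokens, num_heads * 128)
query, key = rope(positions, query, key)           # same shapes out
```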
diff --git a/vllm/model_executor/models/hunyuan_v1.py b/vllm/model_executor/models/hunyuan_v1.py
index 9fa5e2bd33f21..53fb444ed622d 100644
--- a/vllm/model_executor/models/hunyuan_v1.py
+++ b/vllm/model_executor/models/hunyuan_v1.py
@@ -576,7 +576,16 @@ class HunYuanDecoderLayer(nn.Module):
return hidden_states, residual, ori_kv_states
-@support_torch_compile
+@support_torch_compile(
+ dynamic_arg_dims={
+ "input_ids": 0,
+ # positions is of shape (xd, seq_len) if xdrope is enabled for hunyuan-vl,
+ # otherwise (seq_len, ).
+ "positions": -1,
+ "intermediate_tensors": 0,
+ "inputs_embeds": 0,
+ }
+)
class HunYuanModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
diff --git a/vllm/model_executor/models/hunyuan_vision.py b/vllm/model_executor/models/hunyuan_vision.py
new file mode 100644
index 0000000000000..e83addd0c092f
--- /dev/null
+++ b/vllm/model_executor/models/hunyuan_vision.py
@@ -0,0 +1,1028 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# coding=utf-8
+# Copyright 2025 The HunYuan team.
+# Copyright 2025 The vLLM team.
+# Copyright 2025 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only HunYuan-VL model compatible with HuggingFace weights."""
+
+from collections.abc import Callable, Iterable, Mapping, Sequence
+from functools import partial
+from typing import Annotated, Any, Literal, TypeAlias
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers import BatchFeature
+
+from vllm.attention.backends.registry import AttentionBackendEnum
+from vllm.attention.layer import MultiHeadAttention
+from vllm.config import MultiModalConfig, VllmConfig
+from vllm.config.multimodal import BaseDummyOptions
+from vllm.distributed import parallel_state
+from vllm.distributed import utils as dist_utils
+from vllm.logger import init_logger
+from vllm.model_executor.layers.activation import get_act_fn
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (
+ ColumnParallelLinear,
+ QKVParallelLinear,
+ RowParallelLinear,
+)
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.module_mapping import MultiModelKeys
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import (
+ ImageItem,
+ ModalityData,
+ MultiModalDataDict,
+ MultiModalFeatureSpec,
+ MultiModalFieldConfig,
+ MultiModalKwargsItems,
+)
+from vllm.multimodal.parse import (
+ DictEmbeddingItems,
+ ImageSize,
+ MultiModalDataItems,
+ MultiModalDataParser,
+)
+from vllm.multimodal.processing import (
+ BaseMultiModalProcessor,
+ BaseProcessingInfo,
+ PromptReplacement,
+ PromptUpdate,
+)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.sequence import IntermediateTensors
+from vllm.transformers_utils.configs.hunyuan_vl import (
+ HunYuanVLConfig,
+ HunYuanVLVisionConfig,
+)
+from vllm.transformers_utils.processors.hunyuan_vl import HunYuanVLProcessor
+from vllm.transformers_utils.processors.hunyuan_vl_image import smart_resize
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
+
+from .interfaces import (
+ MultiModalEmbeddings,
+ SupportsLoRA,
+ SupportsMultiModal,
+ SupportsPP,
+ SupportsQuant,
+ SupportsXDRoPE,
+)
+from .utils import (
+ AutoWeightsLoader,
+ WeightsMapper,
+ init_vllm_registered_model,
+ maybe_prefix,
+)
+
+logger = init_logger(__name__)
+
+# === Vision Inputs === #
+
+
+class HunYuanVLImagePixelInputs(TensorSchema):
+ """
+ Dimensions:
+ - np: Number of patches
+ - ni: Number of images
+ - cps: Number of channels * patch_size * patch_size
+ """
+
+ type: Literal["pixel_values"]
+
+ pixel_values: Annotated[
+ torch.Tensor,
+ TensorShape("np", "cps"),
+ ]
+
+ image_grid_thw: Annotated[
+ torch.Tensor,
+ TensorShape("ni", 3),
+ ]
+
+
+class HunYuanVLImageEmbeddingInputs(TensorSchema):
+ """
+ Dimensions:
+ - nf: Number of image features
+ - hs: Hidden size
+ - ni: Number of images
+ """
+
+ type: Literal["image_embeds"]
+
+ image_embeds: Annotated[
+ torch.Tensor,
+ TensorShape("nf", "hs"),
+ ]
+
+ image_grid_thw: Annotated[
+ torch.Tensor,
+ TensorShape("ni", 3),
+ ]
+
+
+HunYuanVLImageInputs: TypeAlias = (
+ HunYuanVLImagePixelInputs | HunYuanVLImageEmbeddingInputs
+)
+
+# === Vision Encoder === #
+
+
+class HunYuanVisionMLP(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: int,
+ bias: bool = True,
+ act_fn: Callable[[torch.Tensor], torch.Tensor] = F.gelu,
+ quant_config: QuantizationConfig | None = None,
+ prefix: str = "",
+ use_data_parallel: bool = False,
+ ):
+ super().__init__()
+ self.dense_h_to_4h = ColumnParallelLinear(
+ in_features,
+ hidden_features,
+ bias=bias,
+ quant_config=quant_config,
+ prefix=f"{prefix}.dense_h_to_4h",
+ disable_tp=use_data_parallel,
+ )
+ self.dense_4h_to_h = RowParallelLinear(
+ hidden_features,
+ in_features,
+ bias=bias,
+ quant_config=quant_config,
+ prefix=f"{prefix}.dense_4h_to_h",
+ disable_tp=use_data_parallel,
+ )
+ self.act_fn = act_fn
+
+ def forward(self, x: torch.Tensor):
+ x_up, _ = self.dense_h_to_4h(x)
+ x_down, _ = self.dense_4h_to_h(self.act_fn(x_up))
+ return x_down
+
+
+class HunYuanVisionAttention(nn.Module):
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ projection_size: int,
+ quant_config: QuantizationConfig | None = None,
+ multimodal_config: MultiModalConfig | None = None,
+ prefix: str = "",
+ use_data_parallel: bool = False,
+ ) -> None:
+ super().__init__()
+ # Per attention head and per partition values.
+ self.tp_size = (
+ 1
+ if use_data_parallel
+ else parallel_state.get_tensor_model_parallel_world_size()
+ )
+ self.hidden_size_per_attention_head = dist_utils.divide(
+ projection_size, num_heads
+ )
+ self.num_attention_heads_per_partition = dist_utils.divide(
+ num_heads, self.tp_size
+ )
+
+ self.qkv = QKVParallelLinear(
+ hidden_size=embed_dim,
+ head_size=self.hidden_size_per_attention_head,
+ total_num_heads=num_heads,
+ total_num_kv_heads=num_heads,
+ bias=True,
+ quant_config=quant_config,
+ prefix=f"{prefix}.qkv",
+ disable_tp=use_data_parallel,
+ )
+
+ self.o_proj = RowParallelLinear(
+ input_size=projection_size,
+ output_size=embed_dim,
+ quant_config=quant_config,
+ prefix=f"{prefix}.o_proj",
+ disable_tp=use_data_parallel,
+ )
+
+ self.scale = self.hidden_size_per_attention_head**-0.5
+ self.attn = MultiHeadAttention(
+ self.num_attention_heads_per_partition,
+ self.hidden_size_per_attention_head,
+ self.scale,
+ prefix=f"{prefix}.attn",
+ multimodal_config=multimodal_config,
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ ) -> torch.Tensor:
+ qkv, _ = self.qkv(x)
+ q, k, v = qkv.chunk(3, dim=-1)
+ out = self.attn(q, k, v)
+ output, _ = self.o_proj(out)
+ return output
+
+
+class HunYuanVisionBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_hidden_dim: int,
+ act_fn: Callable[[torch.Tensor], torch.Tensor] = F.gelu,
+ norm_layer: Callable[[int], nn.Module] | None = None,
+ quant_config: QuantizationConfig | None = None,
+ multimodal_config: MultiModalConfig | None = None,
+ prefix: str = "",
+ use_data_parallel: bool = False,
+ ) -> None:
+ super().__init__()
+ if norm_layer is None:
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
+ self.input_layernorm = norm_layer(dim)
+ self.post_attention_layernorm = norm_layer(dim)
+ self.self_attn = HunYuanVisionAttention(
+ embed_dim=dim,
+ num_heads=num_heads,
+ projection_size=dim,
+ quant_config=quant_config,
+ multimodal_config=multimodal_config,
+ prefix=f"{prefix}.self_attn",
+ use_data_parallel=use_data_parallel,
+ )
+ self.mlp = HunYuanVisionMLP(
+ dim,
+ mlp_hidden_dim,
+ act_fn=act_fn,
+ bias=True,
+ quant_config=quant_config,
+ prefix=f"{prefix}.mlp",
+ use_data_parallel=use_data_parallel,
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ ) -> torch.Tensor:
+ x = x + self.self_attn(self.input_layernorm(x))
+ x = x + self.mlp(self.post_attention_layernorm(x))
+ return x
+
+
+class HunYuanVisionPatchEmbed(nn.Module):
+ def __init__(self, config: HunYuanVLVisionConfig):
+ super().__init__()
+
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.patch_size = config.patch_size
+ self.num_channels = config.num_channels
+ self.spatial_merge_size = config.spatial_merge_size
+ self.interpolate_mode = config.interpolate_mode
+
+ self.patch_embedding = nn.Conv2d(
+ in_channels=config.num_channels,
+ out_channels=self.embed_dim,
+ kernel_size=self.patch_size,
+ stride=self.patch_size,
+ bias=True,
+ )
+
+ self.max_num_patches = (config.max_image_size // self.patch_size) ** 2
+
+ self.num_positions = self.max_num_patches + 1
+ self.position_edge = int(self.num_positions**0.5)
+ # first token is cls token, skip it
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+
+ self.patch_pos_embed = None
+
+ def forward(
+ self, pixel_values: torch.Tensor, grid_thw: list[list[int]]
+ ) -> torch.Tensor:
+ num_patches = pixel_values.size(0)
+ pixel_values = pixel_values.reshape(
+ num_patches, self.num_channels, self.patch_size, self.patch_size
+ )
+
+ patch_embeds = self.patch_embedding(pixel_values)
+ patch_embeds = patch_embeds.squeeze(-1).squeeze(-1).unsqueeze(0)
+
+ if self.patch_pos_embed is None:
+ patch_pos_shape = (
+ 1,
+ self.position_edge,
+ self.position_edge,
+ self.embed_dim,
+ )
+ self.patch_pos_embed = (
+ self.position_embedding.weight[1:, :]
+ .reshape(patch_pos_shape)
+ .permute(0, 3, 1, 2)
+ .float()
+ )
+
+ patch_pos_embed_list = []
+ for grid in grid_thw:
+ _, h0, w0 = grid
+ # we add a small number to avoid floating point error in the interpolation
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
+ h0, w0 = h0 + 0.1, w0 + 0.1
+ patch_pos_embed = nn.functional.interpolate(
+ self.patch_pos_embed,
+ scale_factor=(h0 / self.position_edge, w0 / self.position_edge),
+ mode=self.interpolate_mode,
+ align_corners=False,
+ )
+
+ patch_pos_embed = (
+ patch_pos_embed.reshape(self.embed_dim, -1)
+ .transpose(0, 1)
+ .unsqueeze(0)
+ .to(patch_embeds.dtype)
+ )
+ patch_pos_embed_list.append(patch_pos_embed)
+
+ patch_pos_embed = torch.cat(patch_pos_embed_list, dim=1)
+ embeddings = patch_embeds + patch_pos_embed
+
+ return embeddings
+
+
+class HunYuanVisionPatchMerger(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ spatial_merge_size=2,
+ rms_norm_eps=1e-5,
+ prefix="",
+ ):
+ super().__init__()
+ self.spatial_merge_size = spatial_merge_size
+ embed_std = out_channels**-0.5
+
+ self.proj = nn.Sequential(
+ nn.Conv2d(
+ in_channels,
+ in_channels * 2,
+ kernel_size=spatial_merge_size,
+ stride=spatial_merge_size,
+ ),
+ nn.GELU(),
+ nn.Conv2d(in_channels * 2, in_channels * 4, kernel_size=1),
+ )
+ self.mlp = nn.Linear(in_channels * 4, out_channels)
+
+ self.image_newline = nn.Parameter(torch.randn(in_channels * 4) * embed_std)
+ self.image_begin = nn.Parameter(torch.randn(out_channels) * embed_std)
+ self.image_end = nn.Parameter(torch.randn(out_channels) * embed_std)
+ self.image_sep = nn.Parameter(torch.randn(out_channels) * embed_std)
+
+ self.before_rms = RMSNorm(in_channels, eps=rms_norm_eps)
+ self.after_rms = RMSNorm(out_channels, eps=rms_norm_eps)
+
+ def forward(self, x, size=(16, 16)):
+ x = self.before_rms(x)
+
+ h, w = size
+ dtype = x.dtype
+ x = x.permute(0, 2, 1).reshape(x.shape[0], -1, h, w)
+
+ x = self.proj(x) # b,c,h,w
+ b, c, h, w = x.shape
+ x = torch.cat(
+ [x, self.image_newline.reshape(1, c, 1, 1).expand(b, c, h, 1).to(dtype)],
+ dim=-1,
+ )
+ x = x.reshape(b, c, -1).permute(0, 2, 1)
+ x = self.mlp(x)
+
+ begin = self.image_begin.reshape(1, 1, -1).expand(b, 1, x.shape[-1]).to(dtype)
+ end = self.image_end.reshape(1, 1, -1).expand(b, 1, x.shape[-1]).to(dtype)
+ x = torch.cat([begin, x, end], dim=1)
+
+ return self.after_rms(x)
+
+
+class HunYuanVisionTransformer(nn.Module):
+ def __init__(
+ self,
+ vision_config: HunYuanVLVisionConfig,
+ quant_config: QuantizationConfig | None = None,
+ prefix: str = "",
+ use_data_parallel: bool = False,
+ multimodal_config: MultiModalConfig | None = None,
+ attn_backend_override: AttentionBackendEnum | None = None,
+ ) -> None:
+ super().__init__()
+
+ num_hidden_layers = vision_config.num_hidden_layers
+ self.hidden_size = vision_config.hidden_size
+ self.num_heads = vision_config.num_attention_heads
+ self.spatial_merge_size = vision_config.spatial_merge_size
+
+ from vllm.compilation.backends import set_model_tag
+
+ with set_model_tag("HunYuanVisionPatchEmbed"):
+ self.embeddings = HunYuanVisionPatchEmbed(vision_config)
+
+ norm_layer = partial(nn.LayerNorm, eps=vision_config.rms_norm_eps)
+
+ with set_model_tag("HunYuanVisionBlock"):
+ self.layers = nn.ModuleList(
+ [
+ HunYuanVisionBlock(
+ dim=vision_config.hidden_size,
+ num_heads=vision_config.num_attention_heads,
+ mlp_hidden_dim=vision_config.intermediate_size,
+ act_fn=get_act_fn(vision_config.hidden_act),
+ norm_layer=norm_layer,
+ quant_config=quant_config,
+ multimodal_config=multimodal_config,
+ prefix=f"{prefix}.layers.{layer_idx}",
+ use_data_parallel=use_data_parallel,
+ )
+ for layer_idx in range(num_hidden_layers)
+ ]
+ )
+
+ with set_model_tag("HunYuanVisionPatchMerger"):
+ self.perceive = HunYuanVisionPatchMerger(
+ vision_config.hidden_size,
+ vision_config.out_hidden_size,
+ spatial_merge_size=vision_config.spatial_merge_size,
+ rms_norm_eps=vision_config.rms_norm_eps,
+ prefix=f"{prefix}.perceive",
+ )
+
+ @property
+ def dtype(self) -> torch.dtype:
+ return self.embeddings.patch_embedding.weight.dtype
+
+ @property
+ def device(self) -> torch.device:
+ return self.embeddings.patch_embedding.weight.device
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ grid_thw: list[list[int]],
+ ) -> list[torch.Tensor]:
+ # patchify
+ seq_len = x.size(0)
+ cu_seqlens: list = [0]
+
+ hidden_states = x.to(device=self.device, dtype=self.dtype)
+ hidden_states = self.embeddings(hidden_states, grid_thw)
+
+ for t, h, w in grid_thw:
+ t, h, w = int(t), int(h), int(w)
+ cu_seqlens.append(h * w)
+
+ cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32)
+ cu_seqlens = torch.cumsum(cu_seqlens, dim=0, dtype=torch.int32)
+
+ cu_seqlens = cu_seqlens.to(device=self.device, non_blocking=True)
+
+ hidden_states = hidden_states.reshape(seq_len, -1)
+ hidden_states = hidden_states.unsqueeze(0)
+ for layer in self.layers:
+ hidden_states = layer(hidden_states)
+
+ # adapter
+ split_lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+ split_items = hidden_states.split(split_lengths, dim=1)
+ image_embeds_list = []
+ for grid, split_item in zip(grid_thw, split_items):
+ image_embeds_list.append(
+ self.perceive(split_item.contiguous(), size=grid[1:]).squeeze(0)
+ )
+
+ return image_embeds_list
+
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ stacked_params_mapping = [
+ # (param_name, shard_name, shard_id)
+ (".qkv", ".q_proj", "q"),
+ (".qkv", ".k_proj", "k"),
+ (".qkv", ".v_proj", "v"),
+ ]
+ params_dict = dict(self.named_parameters(remove_duplicate=False))
+ loaded_params: set[str] = set()
+
+ for name, loaded_weight in weights:
+ for param_name, weight_name, shard_id in stacked_params_mapping:
+ if weight_name not in name:
+ continue
+ name = name.replace(weight_name, param_name)
+
+ param = params_dict[name]
+ weight_loader = param.weight_loader
+ weight_loader(param, loaded_weight, shard_id)
+ break
+ else:
+ param = params_dict[name]
+ weight_loader = getattr(param, "weight_loader", default_weight_loader)
+ weight_loader(param, loaded_weight)
+ loaded_params.add(name)
+ return loaded_params
+
+
+def _hunyuan_vl_field_config(hf_inputs: Mapping[str, torch.Tensor]):
+ image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
+ image_grid_sizes = image_grid_thw.prod(-1)
+ return dict(
+ pixel_values=MultiModalFieldConfig.flat_from_sizes("image", image_grid_sizes),
+ image_embeds=MultiModalFieldConfig.flat_from_sizes("image", image_grid_sizes),
+ image_grid_thw=MultiModalFieldConfig.batched("image"),
+ )
+
+
+class HunYuanVLMultiModalDataParser(MultiModalDataParser):
+ def _parse_image_data(
+ self,
+ data: dict[str, torch.Tensor] | ModalityData[ImageItem],
+ ):
+ if isinstance(data, dict):
+ return DictEmbeddingItems(
+ data,
+ modality="image",
+ required_fields={"image_embeds", "image_grid_thw"},
+ fields_factory=_hunyuan_vl_field_config,
+ )
+
+ return super()._parse_image_data(data)
+
+
+class HunYuanVLProcessingInfo(BaseProcessingInfo):
+ def get_hf_config(self):
+ return self.ctx.get_hf_config(HunYuanVLConfig)
+
+ def get_hf_processor(
+ self,
+ **kwargs: object,
+ ) -> HunYuanVLProcessor:
+ return self.ctx.get_hf_processor(
+ HunYuanVLProcessor,
+ use_fast=kwargs.pop("use_fast", True),
+ **kwargs,
+ )
+
+ def get_image_processor(
+ self,
+ **kwargs: object,
+ ) -> HunYuanVLProcessor:
+ return self.get_hf_processor(**kwargs).image_processor
+
+ def get_supported_mm_limits(self) -> Mapping[str, int | None]:
+ return {"image": None}
+
+ def get_mm_max_tokens_per_item(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> Mapping[str, int]:
+ max_image_tokens = self.get_max_image_tokens()
+ # TODO: support video
+ max_video_tokens = 0
+ return {"image": max_image_tokens, "video": max_video_tokens}
+
+ def _get_vision_info(
+ self,
+ *,
+ image_width: int,
+ image_height: int,
+ num_frames: int = 1,
+ do_resize: bool = True,
+ image_processor: HunYuanVLProcessor | None,
+ ) -> tuple[ImageSize, int]:
+ if image_processor is None:
+ image_processor = self.get_image_processor()
+
+ hf_config = self.get_hf_config()
+ vision_config = hf_config.vision_config
+ patch_size = vision_config.patch_size
+ spatial_merge_size = vision_config.spatial_merge_size
+
+ if do_resize:
+ resized_height, resized_width = smart_resize(
+ height=image_height,
+ width=image_width,
+ factor=patch_size * spatial_merge_size,
+ min_pixels=image_processor.min_pixels,
+ max_pixels=image_processor.max_pixels,
+ )
+ preprocessed_size = ImageSize(width=resized_width, height=resized_height)
+ else:
+ preprocessed_size = ImageSize(width=image_width, height=image_height)
+
+ grid_t = 1
+ grid_h = preprocessed_size.height // patch_size
+ grid_w = preprocessed_size.width // patch_size
+
+ num_vision_tokens = (
+ grid_t * grid_h // spatial_merge_size * (grid_w // spatial_merge_size + 1)
+ + 2
+ )
+
+ return preprocessed_size, num_vision_tokens
+
+ def get_num_image_tokens(
+ self,
+ *,
+ image_width: int,
+ image_height: int,
+ image_processor: HunYuanVLProcessor | None,
+ ) -> int:
+ _, num_image_tokens = self._get_vision_info(
+ image_width=image_width,
+ image_height=image_height,
+ image_processor=image_processor,
+ )
+ return num_image_tokens
+
+ def get_image_size_with_most_features(self) -> ImageSize:
+ max_image_size, _ = self._get_vision_info(
+ image_width=512,
+ image_height=8192,
+ image_processor=None,
+ )
+ return max_image_size
+
+ def get_max_image_tokens(self) -> int:
+ target_width, target_height = self.get_image_size_with_most_features()
+ return self.get_num_image_tokens(
+ image_width=target_width,
+ image_height=target_height,
+ image_processor=None,
+ )
+
+
+class HunYuanVLDummyInputsBuilder(BaseDummyInputsBuilder[HunYuanVLProcessingInfo]):
+ def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
+ num_images = mm_counts.get("image", 0)
+
+ hf_processor = self.info.get_hf_processor()
+ image_token: str = hf_processor.image_token
+
+ return image_token * num_images
+
+ def get_dummy_mm_data(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ mm_options: Mapping[str, BaseDummyOptions] | None = None,
+ ) -> MultiModalDataDict:
+ num_images = mm_counts.get("image", 1)
+
+ target_width, target_height = self.info.get_image_size_with_most_features()
+
+ return {
+ "image": self._get_dummy_images(
+ width=target_width, height=target_height, num_images=num_images
+ ),
+ }
+
+
+class HunYuanVLMultiModalProcessor(BaseMultiModalProcessor[HunYuanVLProcessingInfo]):
+ def _get_data_parser(self) -> MultiModalDataParser:
+ return HunYuanVLMultiModalDataParser()
+
+ def _call_hf_processor(
+ self,
+ prompt: str,
+ mm_data: Mapping[str, object],
+ mm_kwargs: Mapping[str, object],
+ tok_kwargs: Mapping[str, object],
+ ) -> BatchFeature:
+ return self.info.ctx.call_hf_processor(
+ self.info.get_hf_processor(**mm_kwargs),
+ dict(text=prompt, **mm_data),
+ dict(**mm_kwargs, **tok_kwargs),
+ )
+
+ def _get_prompt_updates(
+ self,
+ mm_items: MultiModalDataItems,
+ hf_processor_mm_kwargs: Mapping[str, Any],
+ out_mm_kwargs: MultiModalKwargsItems,
+ ) -> Sequence[PromptUpdate]:
+ hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
+ image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)
+
+ placeholder = {
+ "image": hf_processor.image_token_id,
+ }
+
+ merge_size = image_processor.merge_size
+
+ def get_replacement_hunyuan_vl(item_idx: int, modality: str):
+ out_item = out_mm_kwargs[modality][item_idx]
+ grid_thw = out_item[f"{modality}_grid_thw"].data
+ assert isinstance(grid_thw, torch.Tensor)
+
+ _, grid_h, grid_w = grid_thw
+ num_tokens = (int(grid_h) // merge_size) * (
+ int(grid_w) // merge_size + 1
+ ) + 2
+ return [placeholder[modality]] * num_tokens
+
+ return [
+ PromptReplacement(
+ modality=modality,
+ target=[placeholder[modality]],
+ replacement=partial(get_replacement_hunyuan_vl, modality=modality),
+ )
+ for modality in ("image",)
+ ]
+
+ def _get_mm_fields_config(
+ self,
+ hf_inputs: BatchFeature,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ ) -> Mapping[str, MultiModalFieldConfig]:
+ return _hunyuan_vl_field_config(hf_inputs)
+
+
+@MULTIMODAL_REGISTRY.register_processor(
+ HunYuanVLMultiModalProcessor,
+ info=HunYuanVLProcessingInfo,
+ dummy_inputs=HunYuanVLDummyInputsBuilder,
+)
+class HunYuanVLForConditionalGeneration(
+ nn.Module,
+ SupportsMultiModal,
+ SupportsLoRA,
+ SupportsPP,
+ SupportsQuant,
+ SupportsXDRoPE,
+):
+ multimodal_cpu_fields = {"image_grid_thw"}
+
+ # To ensure correct weight loading and mapping.
+ hf_to_vllm_mapper = WeightsMapper(
+ orig_to_new_prefix={
+ # mapping for new names in checkpoint saved after transformers v4.52
+ "vit.vit.": "visual.",
+ "vit.": "visual.",
+ "model.": "language_model.model.",
+ }
+ )
+
+ supports_encoder_tp_data = True
+
+ def get_xdrope_input_positions(
+ self,
+ input_tokens: list[int],
+ mm_features: list[MultiModalFeatureSpec],
+ ) -> torch.Tensor:
+ kwargs = MultiModalFeatureSpec.gather_kwargs(
+ mm_features,
+ {"image_grid_thw"},
+ )
+ image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])]
+
+ hf_config = self.config
+ image_start_token_id = hf_config.image_start_token_id
+ spatial_merge_size = hf_config.vision_config.spatial_merge_size
+ xd_num = len(hf_config.rope_scaling["xdrope_section"])
+
+ input_tokens_tensor = torch.tensor(input_tokens)
+ image_start_indices = torch.argwhere(
+ input_tokens_tensor == image_start_token_id
+ ).squeeze(1)
+
+ p_index = torch.arange(len(input_tokens_tensor))
+ w_index = torch.arange(len(input_tokens_tensor))
+ h_index = torch.arange(len(input_tokens_tensor))
+ t_index = torch.arange(len(input_tokens_tensor))
+ for image_index in range(len(image_start_indices)):
+ # +1 : first image_token, +2: for xdrope positions
+ pos = image_start_indices[image_index] + 2
+ t, h, w = image_grid_thw[image_index]
+ _, llm_grid_h, llm_grid_w = (
+ t,
+ h // spatial_merge_size,
+ w // spatial_merge_size,
+ )
+
+ token_num = (llm_grid_w + 1) * llm_grid_h
+ w_index[pos : pos + token_num].copy_(
+ torch.arange(0, llm_grid_w + 1)
+ .reshape(1, -1)
+ .expand(llm_grid_h, -1)
+ .reshape(-1)
+ )
+ h_index[pos : pos + token_num].copy_(
+ torch.arange(0, llm_grid_h)
+ .reshape(-1, 1)
+ .expand(-1, llm_grid_w + 1)
+ .reshape(-1)
+ )
+ h_index[pos : pos + token_num] = 0
+
+ if xd_num == 4:
+ llm_positions = torch.stack([p_index, w_index, h_index, t_index])
+ elif xd_num == 3:
+ llm_positions = torch.stack([w_index, h_index, t_index])
+ else:
+ raise ValueError(f"Unsupported xdrope_section length: {xd_num}")
+
+ return llm_positions
+
+ @classmethod
+ def get_placeholder_str(cls, modality: str, i: int) -> str | None:
+ if modality.startswith("image"):
+ return "<|hy_place▁holder▁no▁100|><|hy_place▁holder▁no▁102|><|hy_place▁holder▁no▁101|>" # noqa: E501
+
+ raise ValueError("Only image modality is supported")
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+ config: HunYuanVLConfig = vllm_config.model_config.hf_config
+ multimodal_config = vllm_config.model_config.multimodal_config
+
+ self.config = config
+ self.multimodal_config = multimodal_config
+
+ if multimodal_config.get_limit_per_prompt("image"):
+ attn_backend_override = multimodal_config.mm_encoder_attn_backend
+ self.visual = HunYuanVisionTransformer(
+ config.vision_config,
+ quant_config=self.quant_config,
+ prefix=maybe_prefix(prefix, "visual"),
+ multimodal_config=multimodal_config,
+ attn_backend_override=attn_backend_override,
+ )
+ else:
+ self.visual = None
+
+ self.language_model = init_vllm_registered_model(
+ vllm_config=vllm_config,
+ prefix=maybe_prefix(prefix, "language_model.model"),
+ architectures=[
+ "HunYuanDenseV1ForCausalLM",
+ "HunYuanMoEV1ForCausalLM",
+ ],
+ )
+
+ self.make_empty_intermediate_tensors = (
+ self.language_model.make_empty_intermediate_tensors
+ )
+
+ def _parse_and_validate_image_input(
+ self, **kwargs: object
+ ) -> HunYuanVLImageInputs | None:
+ pixel_values = kwargs.pop("pixel_values", None)
+ image_embeds = kwargs.pop("image_embeds", None)
+ image_grid_thw = kwargs.pop("image_grid_thw", None)
+
+ if pixel_values is None and image_embeds is None:
+ return None
+
+ # TODO: refine
+ if isinstance(pixel_values, list):
+ pixel_values = torch.cat(pixel_values, dim=0)
+ if len(pixel_values.shape) == 3:
+ last_dim = pixel_values.shape[-1]
+ pixel_values = pixel_values.reshape(-1, last_dim)
+ image_grid_thw = image_grid_thw.reshape(-1, 3)
+
+ if pixel_values is not None:
+ return HunYuanVLImagePixelInputs(
+ type="pixel_values",
+ pixel_values=pixel_values,
+ image_grid_thw=image_grid_thw,
+ )
+
+ if image_embeds is not None:
+ return HunYuanVLImageEmbeddingInputs(
+ type="image_embeds",
+ image_embeds=image_embeds,
+ image_grid_thw=image_grid_thw,
+ )
+
+ def _process_image_input(
+ self, image_input: HunYuanVLImageInputs
+ ) -> tuple[torch.Tensor, ...]:
+ grid_thw = image_input["image_grid_thw"]
+ assert grid_thw.ndim == 2
+ grid_thw_list = grid_thw.tolist()
+
+ if image_input["type"] == "image_embeds":
+ image_embeds = image_input["image_embeds"].type(self.visual.dtype)
+ else:
+ pixel_values = image_input["pixel_values"]
+
+ # TODO: use_data_parallel (split image_embeds in visual)
+ image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)
+
+ return image_embeds
+
+ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
+ mm_input_by_modality = {}
+
+ # Preserve the order of modalities if there are multiple of them
+ # from the order of kwargs.
+ for input_key in kwargs:
+ if (
+ input_key in ("pixel_values", "image_embeds")
+ and "image" not in mm_input_by_modality
+ ):
+ mm_input_by_modality["image"] = self._parse_and_validate_image_input(
+ **kwargs
+ )
+ return mm_input_by_modality
+
+ def get_language_model(self) -> torch.nn.Module:
+ return self.language_model
+
+ def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
+ mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
+ if not mm_input_by_modality:
+ return []
+
+ # The resulting multimodal_embeddings is a tuple of tensors, with each
+ # tensor corresponding to a multimodal data item (image or video).
+ multimodal_embeddings: tuple[torch.Tensor, ...] = ()
+
+ # NOTE: It is important to iterate over the keys in this dictionary
+ # to preserve the order of the modalities.
+ for modality in mm_input_by_modality:
+ multimodal_input = mm_input_by_modality[modality]
+ if modality == "image":
+ image_embeddings = self._process_image_input(multimodal_input)
+ multimodal_embeddings += tuple(image_embeddings)
+ return multimodal_embeddings
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: IntermediateTensors | None,
+ inputs_embeds: torch.Tensor | None,
+ **kwargs: object,
+ ) -> torch.Tensor | IntermediateTensors:
+ if intermediate_tensors is not None:
+ inputs_embeds = None
+
+ hidden_states = self.language_model(
+ input_ids=input_ids,
+ positions=positions,
+ intermediate_tensors=intermediate_tensors,
+ inputs_embeds=inputs_embeds,
+ )
+ return hidden_states
+
+ def compute_logits(
+ self,
+ hidden_states: torch.Tensor,
+ ) -> torch.Tensor | None:
+ return self.language_model.compute_logits(hidden_states)
+
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ loader = AutoWeightsLoader(
+ self,
+ skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
+ )
+ return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
+
+ def get_mm_mapping(self) -> MultiModelKeys:
+ """
+ Get the module prefix in multimodal models
+ """
+ return MultiModelKeys.from_string_field(
+ language_model="language_model.model",
+ connector="visual.perceive",
+ tower_model="visual",
+ )
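One note on the token accounting: `_get_vision_info`, `get_replacement_hunyuan_vl`, and the HF processor all use the same per-image count, where each merged row gains one slot for the `image_newline` embedding and the begin/end embeddings add two more. A quick worked check (helper name is ours, not from the diff):

```python
# Worked check of the per-image token count used in _get_vision_info,
# get_replacement_hunyuan_vl and HunYuanVLProcessor.__call__: one extra slot
# per merged row (image_newline) plus the begin/end embeddings.
def hunyuan_vl_num_image_tokens(grid_h: int, grid_w: int, merge_size: int = 2) -> int:
    rows = grid_h // merge_size
    cols = grid_w // merge_size
    return rows * (cols + 1) + 2

# A 512x512 image at patch_size=16 gives a 32x32 patch grid; with
# merge_size=2 that is 16 * (16 + 1) + 2 = 274 placeholder tokens.
assert hunyuan_vl_num_image_tokens(32, 32) == 274
```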
diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py
index 9966498e1b4c9..6f6ce32538b71 100644
--- a/vllm/model_executor/models/interfaces.py
+++ b/vllm/model_executor/models/interfaces.py
@@ -1047,7 +1047,7 @@ class SupportsMRoPE(Protocol):
supports_mrope: ClassVar[Literal[True]] = True
"""
A flag that indicates this model supports M-RoPE.
-
+
Note:
There is no need to redefine this flag if this class is in the
MRO of your model class.
@@ -1088,3 +1088,52 @@ def supports_mrope(
model: type[object] | object,
) -> TypeIs[type[SupportsMRoPE]] | TypeIs[SupportsMRoPE]:
return isinstance(model, SupportsMRoPE)
+
+
+@runtime_checkable
+class SupportsXDRoPE(Protocol):
+ """The interface required for all models that support XD-RoPE."""
+
+ supports_xdrope: ClassVar[Literal[True]] = True
+ """
+ A flag that indicates this model supports XD-RoPE.
+
+ Note:
+ There is no need to redefine this flag if this class is in the
+ MRO of your model class.
+ """
+
+ def get_xdrope_input_positions(
+ self,
+ input_tokens: list[int],
+ mm_features: list["MultiModalFeatureSpec"],
+ ) -> torch.Tensor:
+ """
+ Get XD-RoPE input positions for this specific model.
+
+ This method should be implemented by each model that supports XD-RoPE
+ to provide model-specific logic for computing input positions.
+
+ Args:
+ input_tokens: List of input token IDs
+ mm_features: Information about each multi-modal data item
+
+ Returns:
+ llm_positions: Tensor of shape `[xdrope_dim, num_tokens]` with
+ 4D(P/W/H/T) or 3D(W/H/T) positions.
+ """
+ ...
+
+
+@overload
+def supports_xdrope(model: type[object]) -> TypeIs[type[SupportsXDRoPE]]: ...
+
+
+@overload
+def supports_xdrope(model: object) -> TypeIs[SupportsXDRoPE]: ...
+
+
+def supports_xdrope(
+ model: type[object] | object,
+) -> TypeIs[type[SupportsXDRoPE]] | TypeIs[SupportsXDRoPE]:
+ return isinstance(model, SupportsXDRoPE)
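The new protocol mirrors `SupportsMRoPE`. A hypothetical call site (not part of this diff; `token_ids` and `mm_features` are placeholders for whatever the runner has for the current request):

```python
# Hypothetical call site (not part of this diff) showing how the new type
# guard is meant to be used.
import torch

from vllm.model_executor.models.interfaces import supports_xdrope

if supports_xdrope(model):
    # [xdrope_dim, num_tokens]: 4 rows (P/W/H/T) or 3 rows (W/H/T).
    positions = model.get_xdrope_input_positions(
        input_tokens=token_ids,
        mm_features=mm_features,
    )
else:
    positions = torch.arange(len(token_ids))  # plain 1D RoPE positions
```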
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index b3da64af750c7..a0d8a78a2ae76 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -287,6 +287,10 @@ _MULTIMODAL_MODELS = {
"GraniteSpeechForConditionalGeneration",
),
"H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
+ "HunYuanVLForConditionalGeneration": (
+ "hunyuan_vision",
+ "HunYuanVLForConditionalGeneration",
+ ),
"InternVLChatModel": ("internvl", "InternVLChatModel"),
"NemotronH_Nano_VL_V2": ("nano_nemotron_vl", "NemotronH_Nano_VL_V2"),
"OpenCUAForConditionalGeneration": (
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index 3d282da8c6112..c1880a3fba0ee 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -86,6 +86,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
deepseek_vl_v2="DeepseekVLV2Config",
deepseek_v32="DeepseekV3Config",
flex_olmo="FlexOlmoConfig",
+ hunyuan_vl="HunYuanVLConfig",
kimi_linear="KimiLinearConfig",
kimi_vl="KimiVLConfig",
RefinedWeb="RWConfig", # For tiiuae/falcon-40b(-instruct)
@@ -549,6 +550,23 @@ def thinker_uses_mrope(config: PretrainedConfig) -> bool:
return uses_mrope(thinker_text_config)
+def uses_xdrope_dim(config: PretrainedConfig) -> int:
+ """Detect if the model with this config uses XD-ROPE."""
+ xdrope_section = getattr(config, "xdrope_section", None)
+ if xdrope_section is not None and isinstance(xdrope_section, list):
+ return len(xdrope_section)
+ rope_scaling = getattr(config, "rope_scaling", None)
+ if rope_scaling is None:
+ return 0
+
+ if isinstance(rope_scaling, dict) and "xdrope_section" in rope_scaling:
+ xdrope_section = rope_scaling["xdrope_section"]
+ if xdrope_section is not None and isinstance(xdrope_section, list):
+ return len(xdrope_section)
+
+ return 0
+
+
def is_encoder_decoder(config: PretrainedConfig) -> bool:
"""Detect if the model with this config is used as an encoder/decoder."""
diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py
index d28fd8d033373..109f2b6986514 100644
--- a/vllm/transformers_utils/configs/__init__.py
+++ b/vllm/transformers_utils/configs/__init__.py
@@ -23,6 +23,11 @@ from vllm.transformers_utils.configs.eagle import EAGLEConfig
# `FalconConfig` class from the official HuggingFace transformers library.
from vllm.transformers_utils.configs.falcon import RWConfig
from vllm.transformers_utils.configs.flex_olmo import FlexOlmoConfig
+from vllm.transformers_utils.configs.hunyuan_vl import (
+ HunYuanVLConfig,
+ HunYuanVLTextConfig,
+ HunYuanVLVisionConfig,
+)
from vllm.transformers_utils.configs.jais import JAISConfig
from vllm.transformers_utils.configs.kimi_linear import KimiLinearConfig
from vllm.transformers_utils.configs.kimi_vl import KimiVLConfig
@@ -53,6 +58,9 @@ __all__ = [
"DotsOCRConfig",
"EAGLEConfig",
"FlexOlmoConfig",
+ "HunYuanVLConfig",
+ "HunYuanVLTextConfig",
+ "HunYuanVLVisionConfig",
"RWConfig",
"JAISConfig",
"Lfm2MoeConfig",
diff --git a/vllm/transformers_utils/configs/hunyuan_vl.py b/vllm/transformers_utils/configs/hunyuan_vl.py
new file mode 100644
index 0000000000000..a826ed9b5155d
--- /dev/null
+++ b/vllm/transformers_utils/configs/hunyuan_vl.py
@@ -0,0 +1,322 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# adapted from https://github.com/ManaEstras/transformers/blob/v4.57.1.hyvl/src/transformers/models/hunyuan_vl/configuration_hunyuan_vl.py
+
+from transformers import PretrainedConfig
+
+
+class HunYuanVLVisionConfig(PretrainedConfig):
+ model_type = "hunyuan_vl"
+ base_config_key = "vision_config"
+
+ def __init__(
+ self,
+ hidden_act="gelu",
+ hidden_size=1152,
+ intermediate_size=4304,
+ interpolate_mode="bilinear",
+ rms_norm_eps=1e-05,
+ learnable_mlp_pooling_size=0,
+ num_attention_heads=16,
+ num_key_value_heads=None,
+ num_channels=3,
+ num_hidden_layers=27,
+ out_hidden_size=4096,
+ patch_size=16,
+ remove_prenorm=True,
+ spatial_merge_size=2,
+ temporal_patch_size=1,
+ resize_resolution=2048,
+ img_max_token_num=4096,
+ max_image_size=2048,
+ video_max_image_size=768,
+ video_min_image_size=256,
+ min_image_size=512,
+ anyres_vit_max_image_size=2048,
+ max_vit_seq_len=16384,
+ text_hidden_size=3072,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.hidden_act = hidden_act
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.interpolate_mode = interpolate_mode
+ self.learnable_mlp_pooling_size = learnable_mlp_pooling_size
+ self.num_attention_heads = num_attention_heads
+ if not num_key_value_heads:
+ self.num_key_value_heads = num_attention_heads
+ else:
+ self.num_key_value_heads = num_key_value_heads
+ self.num_channels = num_channels
+ self.num_hidden_layers = num_hidden_layers
+ self.out_hidden_size = out_hidden_size
+ self.patch_size = patch_size
+ self.remove_prenorm = remove_prenorm
+ self.spatial_merge_size = spatial_merge_size
+ self.temporal_patch_size = temporal_patch_size
+ self.rms_norm_eps = rms_norm_eps
+
+ self.resize_resolution = resize_resolution
+ self.img_max_token_num = img_max_token_num
+ self.max_image_size = max_image_size
+ self.min_image_size = min_image_size
+ self.video_max_image_size = video_max_image_size
+ self.video_min_image_size = video_min_image_size
+ self.anyres_vit_max_image_size = anyres_vit_max_image_size
+ self.max_vit_seq_len = max_vit_seq_len
+ self.text_hidden_size = text_hidden_size
+
+
+class HunYuanVLTextConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`HunYuanVLTextConfig`]. It is used to instantiate a
+ HunYuan model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a configuration similar to that of
+ Hunyuan-7B-Instruct [tencent/Hunyuan-7B-Instruct](https://huggingface.co/tencent/Hunyuan-7B-Instruct).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 290943):
+ Vocabulary size of the HunYuan model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`HunYuanVLTextConfig`]
+ hidden_size (`int`, *optional*, defaults to 4096):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 11008):
+ Dimension of the MLP representations or shared MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 32):
+ Number of hidden layers in the Transformer decoder.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ num_key_value_heads (`int`, *optional*):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details checkout [this
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
+ `num_attention_heads`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ pad_token_id (`int`, *optional*, defaults to 0):
+ Padding token id.
+ bos_token_id (`int`, *optional*, defaults to 1):
+ Beginning of stream token id.
+ eos_token_id (`int`, *optional*, defaults to 2):
+ End of stream token id.
+ eod_token_id (int, *optional*, defaults to 3):
+ Token ID representing the end-of-document marker. Used to indicate the termination of a text sequence.
+ Example: In multi-document processing, this token helps the model distinguish between separate documents.
+ pretraining_tp (`int`, *optional*, defaults to 1):
+ Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
+ document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
+ necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
+ issue](https://github.com/pytorch/pytorch/issues/76232).
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether to tie weight embeddings
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
+ `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
+ these scaling strategies behave:
+ https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
+ experimental feature, subject to breaking API changes in future versions.
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ head_dim (`int`, *optional*, defaults to `None`):
+ The attention head dimension.
+ """ # noqa: E501
+
+ model_type = "hunyuan_vl_text"
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ vocab_size=290943,
+ hidden_size=4096,
+ intermediate_size: int = 11008,
+ num_hidden_layers=32,
+ num_attention_heads=32,
+ num_key_value_heads=None,
+ hidden_act="silu",
+ max_position_embeddings=2048,
+ initializer_range=0.02,
+ rms_norm_eps=1e-5,
+ use_cache=True,
+ pad_token_id=0,
+ bos_token_id=1,
+ eos_token_id=2,
+ eod_token_id=3,
+ pretraining_tp=1,
+ tie_word_embeddings=False,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ attention_bias=False,
+ attention_dropout=0.0,
+ head_dim=None,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.head_dim = head_dim
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.pretraining_tp = pretraining_tp
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ # self._rope_scaling_validation() # TODO: Need validation?
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+ def _rope_scaling_validation(self):
+ """
+ Validate the `rope_scaling` configuration.
+ """
+ if self.rope_scaling is None:
+ return
+
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
+ raise ValueError(
+ "`rope_scaling` must be a dictionary with with two fields, `type` and "
+ f"`factor` or `type` and `alpha`, got {self.rope_scaling}"
+ )
+ rope_scaling_type = self.rope_scaling.get("type", None)
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
+ rope_scaling_alpha = self.rope_scaling.get("alpha", None)
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
+ raise ValueError(
+ "`rope_scaling`'s type field must be one of ['linear', 'dynamic'], "
+ f"got {rope_scaling_type}"
+ )
+ if rope_scaling_factor is None and rope_scaling_alpha is None:
+ raise ValueError(
+ "`rope_scaling`'s factor or alpha field must be have one, "
+ "got both of none"
+ )
+ if rope_scaling_factor is not None and (
+ not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0
+ ):
+ raise ValueError(
+ "`rope_scaling`'s factor field must be a float > 1.0, "
+ f"got {rope_scaling_factor}"
+ )
+ if rope_scaling_alpha is not None and (
+ not isinstance(rope_scaling_alpha, float) or rope_scaling_alpha <= 1.0
+ ):
+ raise ValueError(
+ "`rope_scaling`'s alpha field must be a float > 1.0, "
+ f"got {rope_scaling_alpha}"
+ )
+
+
+class HunYuanVLConfig(PretrainedConfig):
+ model_type = "hunyuan_vl"
+ sub_configs = {
+ "vision_config": HunYuanVLVisionConfig,
+ "text_config": HunYuanVLTextConfig,
+ }
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ text_config=None,
+ vision_config=None,
+ im_start_id=120118,
+ im_end_id=120119,
+ image_token_id=120120,
+ im_newline_id=120121,
+ video_start_id=120122,
+ video_end_id=120123,
+ **kwargs,
+ ):
+ # We need to init super() here so that it does not reset values
+ # that are in text config to the BaseClass defaults. The Base
+ # config has many text related defaults and not all defaults are
+ # same as for `HunYuanVLTextConfig`.
+ super().__init__(**kwargs)
+
+ if isinstance(vision_config, dict):
+ self.vision_config = self.sub_configs["vision_config"](**vision_config)
+ elif vision_config is None:
+ self.vision_config = self.sub_configs["vision_config"]()
+
+ if isinstance(text_config, dict):
+ self.text_config = self.sub_configs["text_config"](**text_config)
+ elif text_config is None:
+ # For BC use all kwargs to init `TextConfig`
+ self.text_config = self.sub_configs["text_config"](**kwargs)
+
+ self.image_token_id = image_token_id
+ self.im_start_id = im_start_id
+ self.im_end_id = im_end_id
+ self.im_newline_id = im_newline_id
+ self.video_start_id = video_start_id
+ self.video_end_id = video_end_id
+
+ self.vision_config.text_hidden_size = self.text_config.hidden_size
+
+ # Attention implementation to use. It sets it recursively on sub-configs
+ # so we call it again in the end.
+ self._attn_implementation = kwargs.pop("attn_implementation", None)
+
+ def __setattr__(self, key, value):
+ if (
+ (text_config := super().__getattribute__("__dict__").get("text_config"))
+ is not None
+ and key not in ["dtype", "_attn_implementation_internal"]
+ and key in text_config.__dict__
+ ):
+ setattr(text_config, key, value)
+ else:
+ super().__setattr__(key, value)
+
+ def __getattribute__(self, key):
+ if "text_config" in super().__getattribute__("__dict__") and key not in [
+ "_name_or_path",
+ "model_type",
+ "dtype",
+ "_attn_implementation_internal",
+ ]:
+ text_config = super().__getattribute__("text_config")
+ if key in text_config.__dict__:
+ return getattr(text_config, key)
+
+ return super().__getattribute__(key)
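For reviewers, the `__getattribute__` / `__setattr__` overrides make the wrapper config transparently proxy text-model fields, which is what lets the rest of vLLM keep reading e.g. `hf_config.hidden_size`. A small sketch of that behaviour:

```python
# Sketch of the text-config delegation implemented by __getattribute__ /
# __setattr__ above: fields stored on the nested text config are read and
# written through the top-level HunYuanVLConfig.
from vllm.transformers_utils.configs.hunyuan_vl import HunYuanVLConfig

cfg = HunYuanVLConfig()
assert cfg.hidden_size == cfg.text_config.hidden_size  # read falls through

cfg.hidden_size = 3072                                  # write is forwarded
assert cfg.text_config.hidden_size == 3072
```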
diff --git a/vllm/transformers_utils/processors/__init__.py b/vllm/transformers_utils/processors/__init__.py
index 76b6d3dc9c99a..b49fdbe9ce776 100644
--- a/vllm/transformers_utils/processors/__init__.py
+++ b/vllm/transformers_utils/processors/__init__.py
@@ -9,7 +9,15 @@ reasons:
"""
from vllm.transformers_utils.processors.deepseek_vl2 import DeepseekVLV2Processor
+from vllm.transformers_utils.processors.hunyuan_vl import HunYuanVLProcessor
+from vllm.transformers_utils.processors.hunyuan_vl_image import HunYuanVLImageProcessor
from vllm.transformers_utils.processors.ovis import OvisProcessor
from vllm.transformers_utils.processors.ovis2_5 import Ovis2_5Processor
-__all__ = ["DeepseekVLV2Processor", "OvisProcessor", "Ovis2_5Processor"]
+__all__ = [
+ "DeepseekVLV2Processor",
+ "HunYuanVLProcessor",
+ "HunYuanVLImageProcessor",
+ "OvisProcessor",
+ "Ovis2_5Processor",
+]
diff --git a/vllm/transformers_utils/processors/hunyuan_vl.py b/vllm/transformers_utils/processors/hunyuan_vl.py
new file mode 100644
index 0000000000000..615a8bff85912
--- /dev/null
+++ b/vllm/transformers_utils/processors/hunyuan_vl.py
@@ -0,0 +1,233 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# adapted from https://github.com/ManaEstras/transformers/blob/v4.57.1.hyvl/src/transformers/models/hunyuan_vl/processing_hunyuan_vl.py
+
+import numpy as np
+import torch
+from transformers import AutoProcessor
+from transformers.feature_extraction_utils import BatchFeature
+from transformers.image_utils import ImageInput
+from transformers.processing_utils import ProcessorMixin
+from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
+from transformers.video_utils import VideoInput
+
+
+class HunYuanVLProcessor(ProcessorMixin):
+ attributes = ["image_processor", "tokenizer"]
+ valid_kwargs = ["chat_template"]
+ image_processor_class = "AutoImageProcessor"
+ tokenizer_class = "AutoTokenizer" # ("AutoTokenizer", None)
+
+ def __init__(
+ self,
+ image_processor=None,
+ tokenizer=None,
+ video_processor=None,
+ chat_template=None,
+ **kwargs,
+ ):
+ # TODO Fix the init
+ self.tokenizer = tokenizer
+ self.image_token_id = 120120 # self.tokenizer.image_token_id
+ self.image_token = self.tokenizer.convert_ids_to_tokens(self.image_token_id)
+ self.im_start_token_id = 120118 # self.tokenizer.im_start_id
+ self.im_start_token = self.tokenizer.convert_ids_to_tokens(
+ self.im_start_token_id
+ )
+ self.im_end_token_id = 120119 # self.tokenizer.im_end_id
+ self.im_end_token = self.tokenizer.convert_ids_to_tokens(self.im_end_token_id)
+ self.placeholder_token = self.tokenizer.convert_ids_to_tokens(
+ self.tokenizer.vocab_size - 1
+ )
+ self.pad_id = 120002 # self.tokenizer.pad_token_id
+
+ super().__init__(
+ image_processor, tokenizer, video_processor, chat_template=chat_template
+ )
+
+ def __call__(
+ self,
+ images: ImageInput = None,
+ text: TextInput
+ | PreTokenizedInput
+ | list[TextInput]
+ | list[PreTokenizedInput] = None,
+ videos: VideoInput = None,
+ **kwargs,
+ ) -> BatchFeature:
+ image_inputs = {}
+ if images is not None:
+ image_inputs = self.image_processor(images=images)
+ image_grid_thw = image_inputs["image_grid_thw"]
+
+ if not isinstance(text, list):
+ text = [text]
+
+ text = text.copy() # below lines change text in-place
+
+ image_tokens_cumsum = [0]
+ if images is not None:
+ index = 0
+ for i in range(len(text)):
+ while self.image_token in text[i]:
+ grid_h, grid_w = image_grid_thw[index][-2:]
+ patch_h = grid_h // self.image_processor.merge_size
+ patch_w = grid_w // self.image_processor.merge_size
+ num_image_tokens = patch_h * (patch_w + 1) + 2
+ image_tokens_cumsum.append(
+ image_tokens_cumsum[-1] + num_image_tokens
+ )
+ # text[i] = text[i].replace(self.image_token, self.im_start_token + self.placeholder_token * num_image_tokens + self.im_end_token, 1) # noqa: E501
+ text[i] = text[i].replace(
+ self.image_token, self.placeholder_token * num_image_tokens, 1
+ )
+ index += 1
+ text[i] = text[i].replace(self.placeholder_token, self.image_token)
+ # text[i] = self.tokenizer.bos_token + text[i]
+
+ text_inputs = self.tokenizer(text, add_special_tokens=False, **kwargs)
+ self._check_special_mm_tokens(text, text_inputs, modalities=["image"])
+
+ input_ids = text_inputs["input_ids"]
+ position_ids = torch.arange(len(input_ids[0]))
+ position_ids_w = torch.arange(len(input_ids[0]))
+ position_ids_h = torch.arange(len(input_ids[0]))
+ position_ids_t = torch.arange(len(input_ids[0]))
+
+ if images is not None:
+ image_token_pos_indices = torch.where(input_ids[0] == self.image_token_id)[
+ 0
+ ]
+ for i in range(len(image_grid_thw)):
+ grid_h, grid_w = image_grid_thw[i][-2:]
+ patch_h = grid_h // self.image_processor.merge_size
+ patch_w = grid_w // self.image_processor.merge_size
+ start_pos = image_token_pos_indices[image_tokens_cumsum[i]].item() + 1
+ replace_num = (patch_w + 1) * patch_h
+ position_ids_w[start_pos : start_pos + replace_num] = torch.tensor(
+ list(range(patch_w + 1)) * patch_h, dtype=torch.int64
+ )
+ patch_h_list = []
+ for h in range(patch_h):
+ patch_h_list += [h] * (patch_w + 1)
+ position_ids_h[start_pos : start_pos + replace_num] = torch.tensor(
+ patch_h_list, dtype=torch.int64
+ )
+ position_ids_t[start_pos : start_pos + replace_num] = 0
+
+ position_ids = torch.stack(
+ [position_ids, position_ids_w, position_ids_h, position_ids_t]
+ ).unsqueeze(0)
+ text_inputs["position_ids"] = position_ids
+
+ attention_mask = input_ids.ne(self.pad_id)
+ text_inputs["attention_mask"] = attention_mask
+ text_inputs["imgs_pos"] = [self.get_imgs_pos(input_ids)]
+ # image_inputs["imgs"] = [[image_inputs["pixel_values"]]]
+
+ return_tensors = kwargs.pop("return_tensors", None)
+ return BatchFeature(
+ data={**text_inputs, **image_inputs},
+ tensor_type=return_tensors,
+ )
+
+ def batch_decode(self, *args, **kwargs):
+ return self.tokenizer.batch_decode(*args, **kwargs)
+
+ def decode(self, *args, **kwargs):
+ return self.tokenizer.decode(*args, **kwargs)
+
+ def post_process_image_text_to_text(
+ self,
+ generated_outputs,
+ skip_special_tokens=True,
+ clean_up_tokenization_spaces=False,
+ **kwargs,
+ ):
+        raise NotImplementedError(
+            "post_process_image_text_to_text is not supported by "
+            "HunYuanVLProcessor"
+        )
+
+ def apply_chat_template(self, *args, **kwargs):
+ token_ids = self.tokenizer.apply_chat_template(*args, **kwargs)
+ return token_ids
+
+ def get_imgs_pos(self, doc_ids):
+ doc_ids = np.array(doc_ids, dtype=np.int64)
+ img_begin_index = np.where(doc_ids == self.im_start_token_id)[0]
+ img_end_index = np.where(doc_ids == self.im_end_token_id)[0]
+ imgs_pos = np.concatenate(
+ (
+ np.reshape(img_begin_index + 1, (-1, 1)),
+ np.reshape(img_end_index, (-1, 1)),
+ ),
+ axis=-1,
+ ).tolist()
+ return imgs_pos
+
+ @property
+ def model_input_names(self):
+ tokenizer_input_names = self.tokenizer.model_input_names
+ image_processor_input_names = self.image_processor.model_input_names
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
+
+
+def split_image_into_patch_blocks(
+ pixel_values: torch.Tensor, # shape: [batch_size, 3, H, W]
+ patch_size: int = 16, # e.g. 16
+    adaptor_patch_div: int = 4,  # e.g. 4 -> regroup each patch into 4x4 sub-blocks
+) -> torch.Tensor:
+ """
+ Split the input image tensor (supporting batch) into large patches of size `patch_size`,
+ and then further divide each large patch into smaller regions of size
+ (patch_size // adaptor_patch_div) x (patch_size // adaptor_patch_div).
+ Each small region is extracted as a tensor of shape [3, patch_size, patch_size].
+ The final output contains all such small region tensors.
+
+ Args:
+ pixel_values: Input image tensor of shape [batch_size, 3, H, W].
+ patch_size: Size of the large patch, e.g., 16.
+ adaptor_patch_div: Each large patch is divided into
+ (patch_size // adaptor_patch_div) x (patch_size // adaptor_patch_div)
+ smaller regions.
+
+ Returns:
+ patches: A tensor of shape [N, 3, patch_size, patch_size],
+ where N = batch_size * (H // patch_size) * (W // patch_size) * (patch_size // adaptor_patch_div)^2.
+ Each element in the batch corresponds to one small image region.
+ """ # noqa: E501
+ batch_size, channels, height, width = pixel_values.shape
+ assert channels == 3, "Pixel values must have 3 channels in dim=1"
+ assert height % patch_size == 0 and width % patch_size == 0, (
+ "H and W must be divisible by patch_size"
+ )
+
+ patch_height_num = height // patch_size
+ patch_width_num = width // patch_size
+
+ # Reshape to [B, 3, ph, ps, pw, ps]
+ img = pixel_values.reshape(
+ batch_size, 3, patch_height_num, patch_size, patch_width_num, patch_size
+ )
+
+ # Further split each psxps patch into (ps//aps)x(ps//aps) small regions
+ img = img.reshape(
+ batch_size,
+ 3,
+ patch_height_num,
+ patch_size // adaptor_patch_div, # ps // aps
+ adaptor_patch_div,
+ patch_width_num,
+ patch_size // adaptor_patch_div, # ps // aps
+ adaptor_patch_div,
+ )
+
+ # Permute to group the small regions: [B, ph, pw, ps//aps, ps//aps, 3, aps, aps]
+ img = img.permute(0, 2, 5, 3, 6, 1, 4, 7)
+
+ # Reshape into [B * ph * pw * (ps//aps)^2, 3, patch_size, patch_size]
+ patches = img.reshape(-1, 3, patch_size, patch_size)
+
+ return patches
+
+
+AutoProcessor.register("HunYuanVLProcessor", HunYuanVLProcessor)
diff --git a/vllm/transformers_utils/processors/hunyuan_vl_image.py b/vllm/transformers_utils/processors/hunyuan_vl_image.py
new file mode 100644
index 0000000000000..0a7e7865c783a
--- /dev/null
+++ b/vllm/transformers_utils/processors/hunyuan_vl_image.py
@@ -0,0 +1,477 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# adapted from https://github.com/ManaEstras/transformers/blob/v4.57.1.hyvl/src/transformers/models/hunyuan_vl/image_processing_hunyuan_vl.py
+"""Image processor class for HunYuanVL."""
+
+# isort conflicts with ruff for transformers imports
+# isort: skip_file
+import math
+
+import numpy as np
+import torchvision.transforms as transforms
+from transformers import AutoImageProcessor
+from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
+from transformers.image_transforms import (
+ convert_to_rgb,
+)
+from transformers.image_utils import (
+ OPENAI_CLIP_MEAN,
+ OPENAI_CLIP_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ make_flat_list_of_images,
+ make_list_of_images,
+ valid_images,
+ validate_preprocess_arguments,
+)
+from transformers.utils import TensorType, logging
+from transformers.video_utils import VideoInput, make_batched_videos
+
+logger = logging.get_logger(__name__)
+
+
+def smart_resize(
+ height: int,
+ width: int,
+ factor: int = 16,
+ min_pixels: int = 512 * 512,
+ max_pixels: int = 2048 * 2048,
+):
+ """Rescales the image so that the following conditions are met:
+
+ 1. Both dimensions (height and width) are divisible by 'factor'.
+
+ 2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
+
+ 3. The aspect ratio of the image is maintained as closely as possible.
+
+ """
+ if max(height, width) / min(height, width) > 200:
+ raise ValueError(
+ "absolute aspect ratio must be smaller than 200, got "
+ f"{max(height, width) / min(height, width)}"
+ )
+ h_bar = round(height / factor) * factor
+ w_bar = round(width / factor) * factor
+ if h_bar * w_bar > max_pixels:
+ beta = math.sqrt((height * width) / max_pixels)
+ h_bar = max(factor, math.floor(height / beta / factor) * factor)
+ w_bar = max(factor, math.floor(width / beta / factor) * factor)
+ elif h_bar * w_bar < min_pixels:
+ beta = math.sqrt(min_pixels / (height * width))
+ h_bar = math.ceil(height * beta / factor) * factor
+ w_bar = math.ceil(width * beta / factor) * factor
+ return h_bar, w_bar
+
+
+class HunYuanVLImageProcessor(BaseImageProcessor):
+ model_input_names = [
+ "pixel_values",
+ "image_grid_thw",
+ "pixel_values_videos",
+ "video_grid_thw",
+ ]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: dict[str, int] | None = None,
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ do_rescale: bool = True,
+ rescale_factor: int | float = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: float | list[float] | None = None,
+ image_std: float | list[float] | None = None,
+ do_convert_rgb: bool = True,
+ min_pixels: int | None = None,
+ max_pixels: int | None = None,
+ patch_size: int = 16,
+ temporal_patch_size: int = 2,
+ merge_size: int = 2,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+        if size is not None:
+            if "shortest_edge" not in size or "longest_edge" not in size:
+                raise ValueError(
+                    "size must contain 'shortest_edge' and 'longest_edge' keys."
+                )
+        else:
+            size = {"shortest_edge": 512 * 512, "longest_edge": 2048 * 2048}
+ # backward compatibility: override size with min_pixels and max_pixels
+ # if they are provided.
+ if min_pixels is not None:
+ size["shortest_edge"] = min_pixels
+ if max_pixels is not None:
+ size["longest_edge"] = max_pixels
+ self.min_pixels = size["shortest_edge"]
+ self.max_pixels = size["longest_edge"]
+ self.size = size
+
+ self.do_resize = do_resize
+ self.resample = resample
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
+ self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
+
+ self.patch_size = patch_size
+ self.temporal_patch_size = temporal_patch_size
+ self.merge_size = merge_size
+ self.do_convert_rgb = do_convert_rgb
+
+
+ def _preprocess(
+ self,
+ images: ImageInput | VideoInput,
+ do_resize: bool | None = None,
+ size: dict[str, int] | None = None,
+ resample: PILImageResampling = None,
+ do_rescale: bool | None = None,
+ rescale_factor: float | None = None,
+ do_normalize: bool | None = None,
+ image_mean: float | list[float] | None = None,
+ image_std: float | list[float] | None = None,
+ patch_size: int = 16,
+ temporal_patch_size: int = 2,
+ merge_size: int = 2,
+ do_convert_rgb: bool | None = None,
+ data_format: ChannelDimension | None = ChannelDimension.FIRST,
+ input_data_format: str | ChannelDimension | None = None,
+ ):
+ """
+ Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`.
+
+ Args:
+ images (`ImageInput`):
+ Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`dict[str, int]`, *optional*, defaults to `self.size`):
+ Size of the image after resizing. `shortest_edge` and `longest_edge` keys must be present.
+ resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
+ Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image.
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Scale factor to use if rescaling the image.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
+ Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
+ image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
+ Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
+ patch_size (`int`, *optional*, defaults to `self.patch_size`):
+ The spatial patch size of the vision encoder.
+ temporal_patch_size (`int`, *optional*, defaults to `self.temporal_patch_size`):
+ The temporal patch size of the vision encoder.
+ merge_size (`int`, *optional*, defaults to `self.merge_size`):
+ The merge size of the vision encoder to llm encoder.
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
+ Whether to convert the image to RGB.
+ data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - Unset: Use the channel dimension format of the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ """ # noqa: E501
+ images = make_list_of_images(images)
+
+ if do_convert_rgb:
+ images = [convert_to_rgb(image) for image in images]
+
+ width, height = images[0].width, images[0].height
+ resized_width, resized_height = width, height
+ processed_images = []
+ for image in images:
+ if do_resize:
+ resized_width, resized_height = smart_resize(
+ width,
+ height,
+ factor=patch_size * merge_size,
+                    min_pixels=size["shortest_edge"] if size else self.min_pixels,
+                    max_pixels=size["longest_edge"] if size else self.max_pixels,
+ )
+ image = image.resize((resized_width, resized_height))
+
+ if do_normalize:
+ image = transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Normalize(self.image_mean, self.image_std),
+ ]
+ )(image)
+ processed_images.append(image)
+
+ patches = np.array(processed_images)
+ channel = patches.shape[1]
+ grid_t = patches.shape[0] // temporal_patch_size
+ grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
+ patches = patches.reshape(
+ 1,
+ channel,
+ grid_h // merge_size,
+ merge_size,
+ patch_size,
+ grid_w // merge_size,
+ merge_size,
+ patch_size,
+ )
+ patches = patches.transpose(0, 2, 3, 5, 6, 1, 4, 7)
+ flatten_patches = patches.reshape(
+ 1 * grid_h * grid_w, channel * patch_size * patch_size
+ )
+
+ return flatten_patches, (grid_t, grid_h, grid_w)
+
+ def preprocess(
+ self,
+ images: ImageInput,
+ videos: VideoInput = None,
+ do_resize: bool | None = None,
+ size: dict[str, int] | None = None,
+ min_pixels: int | None = None,
+ max_pixels: int | None = None,
+ resample: PILImageResampling = None,
+ do_rescale: bool | None = None,
+ rescale_factor: float | None = None,
+ do_normalize: bool | None = None,
+ image_mean: float | list[float] | None = None,
+ image_std: float | list[float] | None = None,
+ patch_size: int | None = None,
+ temporal_patch_size: int | None = None,
+ merge_size: int | None = None,
+ do_convert_rgb: bool | None = None,
+ return_tensors: str | TensorType | None = None,
+ data_format: ChannelDimension | None = ChannelDimension.FIRST,
+ input_data_format: str | ChannelDimension | None = None,
+ ):
+ """
+ Args:
+ images (`ImageInput`):
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+ videos (`VideoInput`):
+ Video to preprocess. Expects a single or batch of videos with pixel values ranging from 0 to 255. If
+ passing in videos with pixel values between 0 and 1, set `do_rescale=False`.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`dict[str, int]`, *optional*, defaults to `self.size`):
+ Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
+ the longest edge resized to keep the input aspect ratio.
+ resample (`int`, *optional*, defaults to `self.resample`):
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
+ has an effect if `do_resize` is set to `True`.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image.
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
+ Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
+ image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
+ Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
+ `True`.
+ min_pixels (`int`, *optional*, defaults to `self.min_pixels`):
+ The min pixels of the image to resize the image.
+ max_pixels (`int`, *optional*, defaults to `self.max_pixels`):
+ The max pixels of the image to resize the image.
+ patch_size (`int`, *optional*, defaults to `self.patch_size`):
+ The spatial patch size of the vision encoder.
+ temporal_patch_size (`int`, *optional*, defaults to `self.temporal_patch_size`):
+ The temporal patch size of the vision encoder.
+ merge_size (`int`, *optional*, defaults to `self.merge_size`):
+ The merge size of the vision encoder to llm encoder.
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
+ Whether to convert the image to RGB.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - Unset: Use the channel dimension format of the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+
+ """ # noqa: E501
+ min_pixels = min_pixels if min_pixels is not None else self.min_pixels
+ max_pixels = max_pixels if max_pixels is not None else self.max_pixels
+
+ if size is not None:
+ if "shortest_edge" not in size or "longest_edge" not in size:
+ raise ValueError(
+ "size must contain 'shortest_edge' and 'longest_edge' keys."
+ )
+ min_pixels = size["shortest_edge"]
+ elif min_pixels is not None and max_pixels is not None:
+ # backward compatibility: override size with min_pixels and max_pixels
+ # if they are provided.
+ size = {"shortest_edge": min_pixels, "longest_edge": max_pixels}
+ else:
+ size = {**self.size}
+
+ do_resize = do_resize if do_resize is not None else self.do_resize
+
+ resample = resample if resample is not None else self.resample
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = (
+ rescale_factor if rescale_factor is not None else self.rescale_factor
+ )
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+ patch_size = patch_size if patch_size is not None else self.patch_size
+ temporal_patch_size = (
+ temporal_patch_size
+ if temporal_patch_size is not None
+ else self.temporal_patch_size
+ )
+ merge_size = merge_size if merge_size is not None else self.merge_size
+ do_convert_rgb = (
+ do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
+ )
+
+ if images is not None:
+ images = make_flat_list_of_images(images)
+
+ if images is not None and not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+
+ validate_preprocess_arguments(
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ )
+
+ data = {}
+ if images is not None:
+ pixel_values, vision_grid_thws = [], []
+ for image in images:
+ patches, image_grid_thw = self._preprocess(
+ image,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ patch_size=patch_size,
+ temporal_patch_size=temporal_patch_size,
+ merge_size=merge_size,
+ data_format=data_format,
+ do_convert_rgb=do_convert_rgb,
+ input_data_format=input_data_format,
+ )
+ pixel_values.extend(patches)
+ vision_grid_thws.append(image_grid_thw)
+ pixel_values = np.array(pixel_values)
+ vision_grid_thws = np.array(vision_grid_thws)
+ data.update(
+ {"pixel_values": pixel_values, "image_grid_thw": vision_grid_thws}
+ )
+
+ # kept for BC only and should be removed after v5.0
+ if videos is not None:
+ logger.warning(
+ "`HunYuanVLV1ImageProcessor` works only with image inputs "
+ "and doesn't process videos anymore. "
+ "This is a deprecated behavior and will be removed in v5.0. "
+ "Your videos should be forwarded to `HunYuanVLV1VideoProcessor`. "
+ )
+ videos = make_batched_videos(videos)
+ pixel_values_videos, vision_grid_thws_videos = [], []
+ for images in videos:
+ patches, video_grid_thw = self._preprocess(
+ images,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ patch_size=patch_size,
+ temporal_patch_size=temporal_patch_size,
+ merge_size=merge_size,
+ data_format=data_format,
+ do_convert_rgb=do_convert_rgb,
+ input_data_format=input_data_format,
+ )
+ pixel_values_videos.extend(patches)
+ vision_grid_thws_videos.append(video_grid_thw)
+ data.update(
+ {
+ "pixel_values_videos": np.array(pixel_values_videos),
+ "video_grid_thw": np.array(vision_grid_thws_videos),
+ }
+ )
+
+ return BatchFeature(data=data, tensor_type=return_tensors)
+
+ def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
+ """
+        A utility that returns the number of image patches for a given image size.
+
+ Args:
+ height (`int`):
+ Height of the input image.
+ width (`int`):
+ Width of the input image.
+ images_kwargs (`dict`, *optional*):
+ Any kwargs to override defaults of the image processor.
+ Returns:
+ `int`: Number of image patches per image.
+ """
+        images_kwargs = images_kwargs or {}
+        min_pixels = images_kwargs.get("min_pixels", self.size["shortest_edge"])
+        max_pixels = images_kwargs.get("max_pixels", self.size["longest_edge"])
+ patch_size = images_kwargs.get("patch_size", self.patch_size)
+ merge_size = images_kwargs.get("merge_size", self.merge_size)
+
+ factor = patch_size * merge_size
+ resized_height, resized_width = smart_resize(
+ height, width, factor, min_pixels=min_pixels, max_pixels=max_pixels
+ )
+ grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
+ return grid_h * (grid_w + 1) + 2
+
+
+AutoImageProcessor.register("HunYuanVLImageProcessor", HunYuanVLImageProcessor)
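
To see the resize constraints in action, here is a small check of `smart_resize` with the processor defaults (`factor = patch_size * merge_size = 32`, pixel budget `512 * 512` to `2048 * 2048`); the input resolution is arbitrary:

```python
from vllm.transformers_utils.processors.hunyuan_vl_image import smart_resize

# A 12 MP input, above the default max_pixels budget of 2048 * 2048.
h, w = 3000, 4000
h_bar, w_bar = smart_resize(
    h, w, factor=32, min_pixels=512 * 512, max_pixels=2048 * 2048
)

assert h_bar % 32 == 0 and w_bar % 32 == 0  # divisible by patch_size * merge_size
assert h_bar * w_bar <= 2048 * 2048         # within the pixel budget
print(h_bar, w_bar)                         # 1760 2336 for this input
```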
diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index 4a2818ab1bfd8..e7991baeaa1b8 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -43,6 +43,8 @@ class CachedRequestState:
mrope_positions: torch.Tensor | None = None
mrope_position_delta: int | None = None
+ xdrope_positions: torch.Tensor | None = None
+
lora_request: LoRARequest | None = None
prompt_embeds: torch.Tensor | None = None
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 6a83ac14e0b3f..6413be66b141c 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -50,16 +50,21 @@ from vllm.distributed.parallel_state import (
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
-from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
+from vllm.model_executor.layers.rotary_embedding import (
+ MRotaryEmbedding,
+ XDRotaryEmbedding,
+)
from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
from vllm.model_executor.models.interfaces import (
SupportsMRoPE,
SupportsMultiModal,
+ SupportsXDRoPE,
is_mixture_of_experts,
supports_eagle3,
supports_mrope,
supports_multimodal_pruning,
supports_transcription,
+ supports_xdrope,
)
from vllm.model_executor.models.interfaces_base import (
VllmModelForPooling,
@@ -324,6 +329,7 @@ class GPUModelRunner(
# Multi-modal data support
self.mm_registry = MULTIMODAL_REGISTRY
self.uses_mrope = model_config.uses_mrope
+ self.uses_xdrope_dim = model_config.uses_xdrope_dim
self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
model_config
)
@@ -512,6 +518,13 @@ class GPUModelRunner(
(3, self.max_num_tokens + 1), dtype=torch.int64
)
+ # Only relevant for models using XD-RoPE (e.g., HunYuan-VL)
+ if self.uses_xdrope_dim > 0:
+ # Like M-RoPE, but with a model-defined number of RoPE dimensions (4 by default).
+ self.xdrope_positions = self._make_buffer(
+ (self.uses_xdrope_dim, self.max_num_tokens + 1), dtype=torch.int64
+ )
+
# None in the first PP rank. The rest are set after load_model.
self.intermediate_tensors: IntermediateTensors | None = None
@@ -593,10 +606,14 @@ class GPUModelRunner(
if isinstance(num_tokens, int):
if self.uses_mrope:
return self.mrope_positions.gpu[:, :num_tokens]
+ if self.uses_xdrope_dim > 0:
+ return self.xdrope_positions.gpu[:, :num_tokens]
return self.positions.gpu[:num_tokens]
else:
if self.uses_mrope:
return self.mrope_positions.gpu[:, num_tokens]
+ if self.uses_xdrope_dim > 0:
+ return self.xdrope_positions.gpu[:, num_tokens]
return self.positions.gpu[num_tokens]
def _make_buffer(
@@ -772,6 +789,10 @@ class GPUModelRunner(
if self.uses_mrope:
self._init_mrope_positions(req_state)
+ # Only relevant for models using XD-RoPE (e.g., HunYuan-VL)
+ if self.uses_xdrope_dim > 0:
+ self._init_xdrope_positions(req_state)
+
reqs_to_add.append(req_state)
# Update the states of the running/resumed requests.
@@ -987,6 +1008,19 @@ class GPUModelRunner(
)
)
+ def _init_xdrope_positions(self, req_state: CachedRequestState):
+ model = self.get_model()
+ xdrope_model = cast(SupportsXDRoPE, model)
+ assert req_state.prompt_token_ids is not None, (
+ "XD-RoPE requires prompt_token_ids to be available."
+ )
+ assert supports_xdrope(model), "XD-RoPE support is not implemented."
+
+ req_state.xdrope_positions = xdrope_model.get_xdrope_input_positions(
+ req_state.prompt_token_ids,
+ req_state.mm_features,
+ )
+
def _extract_mm_kwargs(
self,
scheduler_output: "SchedulerOutput",
@@ -1231,6 +1265,11 @@ class GPUModelRunner(
if self.uses_mrope:
self._calc_mrope_positions(scheduler_output)
+ # Calculate XD-RoPE positions.
+ # Only relevant for models using XD-RoPE (e.g., HunYuan-VL)
+ if self.uses_xdrope_dim > 0:
+ self._calc_xdrope_positions(scheduler_output)
+
# Get token indices.
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
@@ -1364,6 +1403,12 @@ class GPUModelRunner(
self.mrope_positions.cpu[:, :total_num_scheduled_tokens],
non_blocking=True,
)
+ elif self.uses_xdrope_dim > 0:
+ # Only relevant for models using XD-RoPE (e.g., HunYuan-VL)
+ self.xdrope_positions.gpu[:, :total_num_scheduled_tokens].copy_(
+ self.xdrope_positions.cpu[:, :total_num_scheduled_tokens],
+ non_blocking=True,
+ )
else:
# Common case (1D positions)
self.positions.copy_to_gpu(total_num_scheduled_tokens)
@@ -1793,6 +1838,53 @@ class GPUModelRunner(
mrope_pos_ptr += completion_part_len
+ def _calc_xdrope_positions(self, scheduler_output: "SchedulerOutput"):
+ xdrope_pos_ptr = 0
+ for index, req_id in enumerate(self.input_batch.req_ids):
+ req = self.requests[req_id]
+ assert req.xdrope_positions is not None
+
+ num_computed_tokens = self.input_batch.num_computed_tokens_cpu[index]
+ num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
+ num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
+ req.prompt_token_ids, req.prompt_embeds
+ )
+
+ if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
+ prompt_part_len = max(0, num_prompt_tokens - num_computed_tokens)
+ completion_part_len = max(0, num_scheduled_tokens - prompt_part_len)
+ else:
+ prompt_part_len = num_scheduled_tokens
+ completion_part_len = 0
+
+ assert num_scheduled_tokens == prompt_part_len + completion_part_len
+
+ if prompt_part_len > 0:
+ # prompt's xdrope_positions are pre-computed
+ dst_start = xdrope_pos_ptr
+ dst_end = xdrope_pos_ptr + prompt_part_len
+ src_start = num_computed_tokens
+ src_end = num_computed_tokens + prompt_part_len
+
+ self.xdrope_positions.cpu[:, dst_start:dst_end] = req.xdrope_positions[
+ :, src_start:src_end
+ ]
+ xdrope_pos_ptr += prompt_part_len
+
+ if completion_part_len > 0:
+ # compute completion's xdrope_positions on-the-fly
+ dst_start = xdrope_pos_ptr
+ dst_end = xdrope_pos_ptr + completion_part_len
+
+ XDRotaryEmbedding.get_next_input_positions_tensor(
+ out=self.xdrope_positions.np,
+ out_offset=dst_start,
+ context_len=num_computed_tokens + prompt_part_len,
+ num_new_tokens=completion_part_len,
+ )
+
+ xdrope_pos_ptr += completion_part_len
+
def _calc_spec_decode_metadata(
self,
num_draft_tokens: np.ndarray,
@@ -2037,6 +2129,7 @@ class GPUModelRunner(
req_start_idx = 0
should_sync_mrope_positions = False
+ should_sync_xdrope_positions = False
for req_id in self.input_batch.req_ids:
mm_embeds_req: list[torch.Tensor] = []
@@ -2110,6 +2203,10 @@ class GPUModelRunner(
self._calc_mrope_positions(scheduler_output)
self.mrope_positions.copy_to_gpu(total_num_scheduled_tokens)
+ if should_sync_xdrope_positions:
+ self._calc_xdrope_positions(scheduler_output)
+ self.xdrope_positions.copy_to_gpu(total_num_scheduled_tokens)
+
return mm_embeds, is_mm_embed
def get_model(self) -> nn.Module:
@@ -2384,8 +2481,11 @@ class GPUModelRunner(
input_ids = self.input_ids.gpu[:num_input_tokens]
inputs_embeds = None
model_kwargs = self._init_model_kwargs(num_input_tokens)
+
if self.uses_mrope:
positions = self.mrope_positions.gpu[:, :num_input_tokens]
+ elif self.uses_xdrope_dim > 0:
+ positions = self.xdrope_positions.gpu[:, :num_input_tokens]
else:
positions = self.positions.gpu[:num_input_tokens]
@@ -3824,6 +3924,8 @@ class GPUModelRunner(
if self.uses_mrope:
positions = self.mrope_positions.gpu[:, :num_tokens_after_padding]
+ elif self.uses_xdrope_dim > 0:
+ positions = self.xdrope_positions.gpu[:, :num_tokens_after_padding]
else:
positions = self.positions.gpu[:num_tokens_after_padding]
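
The XD-RoPE plumbing mirrors the existing M-RoPE path, except that the leading dimension of the positions buffer is model-defined (`uses_xdrope_dim`, 4 for HunYuan-VL) rather than fixed at 3, and prompt positions come pre-computed from `get_xdrope_input_positions`. Below is a rough numpy sketch of the decode-time fill; the behavior of `XDRotaryEmbedding.get_next_input_positions_tensor` is an assumption modeled on its M-RoPE counterpart, not a quote of its implementation:

```python
import numpy as np


# Rough sketch only: assumed behavior for text-only decode steps, where every
# XD-RoPE dimension advances in lockstep with the 1D position.
def fill_next_xdrope_positions(
    out: np.ndarray,      # shape [xdrope_dim, max_num_tokens]
    out_offset: int,      # write cursor along the token axis
    context_len: int,     # tokens already processed for this request
    num_new_tokens: int,  # tokens scheduled in this step
) -> None:
    out[:, out_offset : out_offset + num_new_tokens] = np.arange(
        context_len, context_len + num_new_tokens
    )


xdrope_positions = np.zeros((4, 8), dtype=np.int64)  # uses_xdrope_dim = 4
fill_next_xdrope_positions(
    xdrope_positions, out_offset=0, context_len=100, num_new_tokens=3
)
assert (xdrope_positions[:, :3] == [100, 101, 102]).all()
```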