[Model] Add HunyuanOCR support (#29327)

Signed-off-by: manayang <jackmanayang@gmail.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: sergeywang <sergeywang@tencent.com>
Co-authored-by: manayang <jackmanayang@gmail.com>
Co-authored-by: manayang <manayang@tencent.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
Isotr0py 2025-11-25 11:28:51 +08:00 committed by GitHub
parent 87185c88d5
commit 92effb07a4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 2415 additions and 4 deletions

View File

@ -680,6 +680,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `Glm4vMoeForConditionalGeneration` | GLM-4.5V | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.5V`, etc. | ✅︎ | ✅︎ | | `Glm4vMoeForConditionalGeneration` | GLM-4.5V | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.5V`, etc. | ✅︎ | ✅︎ |
| `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | | `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ |
| `H2OVLChatModel` | H2OVL | T + I<sup>E+</sup> | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ | | `H2OVLChatModel` | H2OVL | T + I<sup>E+</sup> | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ |
| `HunYuanVLForConditionalGeneration` | HunyuanOCR | T + I<sup>E+</sup> | `tencent/HunyuanOCR`, etc. | ✅︎ | ✅︎ |
| `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3`, etc. | ✅︎ | | | `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3`, etc. | ✅︎ | |
| `InternS1ForConditionalGeneration` | Intern-S1 | T + I<sup>E+</sup> + V<sup>E+</sup> | `internlm/Intern-S1`, `internlm/Intern-S1-mini`, etc. | ✅︎ | ✅︎ | | `InternS1ForConditionalGeneration` | Intern-S1 | T + I<sup>E+</sup> + V<sup>E+</sup> | `internlm/Intern-S1`, `internlm/Intern-S1-mini`, etc. | ✅︎ | ✅︎ |
| `InternVLChatModel` | InternVL 3.5, InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + I<sup>E+</sup> + (V<sup>E+</sup>) | `OpenGVLab/InternVL3_5-14B`, `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ | | `InternVLChatModel` | InternVL 3.5, InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + I<sup>E+</sup> + (V<sup>E+</sup>) | `OpenGVLab/InternVL3_5-14B`, `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ |

View File

@ -538,6 +538,31 @@ def run_h2ovl(questions: list[str], modality: str) -> ModelRequestData:
) )
# HunyuanOCR
def run_hunyuan_vl(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
model_name = "tencent/HunyuanOCR"
engine_args = EngineArgs(
model=model_name,
max_model_len=8192,
limit_mm_per_prompt={modality: 1},
)
placeholder = "<hy_place▁holder▁no▁100><hy_place▁holder▁no▁102><hy_place▁holder▁no▁101>" # noqa: E501
prompts = [
f"<hy_begin▁of▁sentence>{placeholder}{question}<hy_User>"
for question in questions
]
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
stop_token_ids=None,
)
# naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B # naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B
def run_hyperclovax_seed_vision( def run_hyperclovax_seed_vision(
questions: list[str], modality: str questions: list[str], modality: str
@ -1820,6 +1845,7 @@ model_example_map = {
"glm4_5v": run_glm4_5v, "glm4_5v": run_glm4_5v,
"glm4_5v_fp8": run_glm4_5v_fp8, "glm4_5v_fp8": run_glm4_5v_fp8,
"h2ovl_chat": run_h2ovl, "h2ovl_chat": run_h2ovl,
"hunyuan_vl": run_hunyuan_vl,
"hyperclovax_seed_vision": run_hyperclovax_seed_vision, "hyperclovax_seed_vision": run_hyperclovax_seed_vision,
"idefics3": run_idefics3, "idefics3": run_idefics3,
"interns1": run_interns1, "interns1": run_interns1,

View File

@ -626,6 +626,10 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B", "naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B",
trust_remote_code=True, trust_remote_code=True,
), ),
"HunYuanVLForConditionalGeneration": _HfExamplesInfo(
"tencent/HunyuanOCR",
is_available_online=False,
),
"Idefics3ForConditionalGeneration": _HfExamplesInfo( "Idefics3ForConditionalGeneration": _HfExamplesInfo(
"HuggingFaceM4/Idefics3-8B-Llama3", "HuggingFaceM4/Idefics3-8B-Llama3",
extras={"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"}, extras={"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"},

View File

@ -33,6 +33,7 @@ from vllm.transformers_utils.config import (
try_get_safetensors_metadata, try_get_safetensors_metadata,
try_get_tokenizer_config, try_get_tokenizer_config,
uses_mrope, uses_mrope,
uses_xdrope_dim,
) )
from vllm.transformers_utils.gguf_utils import ( from vllm.transformers_utils.gguf_utils import (
maybe_patch_hf_config_from_gguf, maybe_patch_hf_config_from_gguf,
@ -1615,6 +1616,10 @@ class ModelConfig:
def uses_mrope(self) -> bool: def uses_mrope(self) -> bool:
return uses_mrope(self.hf_config) return uses_mrope(self.hf_config)
@property
def uses_xdrope_dim(self) -> int:
return uses_xdrope_dim(self.hf_config)
@property @property
def is_multimodal_model(self) -> bool: def is_multimodal_model(self) -> bool:
return self.multimodal_config is not None return self.multimodal_config is not None

View File

@ -17,6 +17,7 @@ from .llama4_vision_rope import Llama4VisionRotaryEmbedding
from .mrope import MRotaryEmbedding from .mrope import MRotaryEmbedding
from .ntk_scaling_rope import NTKScalingRotaryEmbedding from .ntk_scaling_rope import NTKScalingRotaryEmbedding
from .phi3_long_rope_scaled_rope import Phi3LongRoPEScaledRotaryEmbedding from .phi3_long_rope_scaled_rope import Phi3LongRoPEScaledRotaryEmbedding
from .xdrope import XDRotaryEmbedding
from .yarn_scaling_rope import YaRNScalingRotaryEmbedding from .yarn_scaling_rope import YaRNScalingRotaryEmbedding
_ROPE_DICT: dict[tuple, RotaryEmbedding] = {} _ROPE_DICT: dict[tuple, RotaryEmbedding] = {}
@ -184,6 +185,18 @@ def get_rope(
raise ValueError( raise ValueError(
"Dynamic rope scaling must contain either 'alpha' or 'factor' field" "Dynamic rope scaling must contain either 'alpha' or 'factor' field"
) )
elif scaling_type == "xdrope":
scaling_alpha = rope_parameters["alpha"]
rotary_emb = XDRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
scaling_alpha,
dtype,
xdrope_section=rope_parameters["xdrope_section"],
)
elif scaling_type == "yarn": elif scaling_type == "yarn":
scaling_factor = rope_parameters["factor"] scaling_factor = rope_parameters["factor"]
original_max_position = rope_parameters["original_max_position_embeddings"] original_max_position = rope_parameters["original_max_position_embeddings"]

View File

@ -0,0 +1,102 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch
from .common import apply_rotary_emb_dispatch
from .dynamic_ntk_alpha_rope import DynamicNTKAlphaRotaryEmbedding
class XDRotaryEmbedding(DynamicNTKAlphaRotaryEmbedding):
"""DynamicNTKAlphaRotaryEmbedding extended with MultiModal(XD) Sections.
Based on the original DynamicNTKAlphaRotaryEmbedding implementation.
"""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: float,
is_neox_style: bool,
scaling_alpha: float,
dtype: torch.dtype,
xdrope_section: list[int],
) -> None:
self.xdrope_section = xdrope_section
super().__init__(
head_size,
rotary_dim,
max_position_embeddings,
base,
is_neox_style,
scaling_alpha,
dtype,
)
def forward(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor | None = None,
offsets: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""PyTorch-native implementation equivalent to forward().
Args:
positions:
[4, num_tokens] (P/W/H/T positions with multimodal inputs)
query: [num_tokens, num_heads * head_size]
key: [num_tokens, num_kv_heads * head_size]
"""
assert positions.ndim == 2
assert key is not None
num_tokens = positions.shape[-1]
cos_sin = self.cos_sin_cache[positions]
cos, sin = cos_sin.chunk(2, dim=-1)
cos = torch.cat(
[m[i] for i, m in enumerate(cos.split(self.xdrope_section, dim=-1))], dim=-1
)
sin = torch.cat(
[m[i] for i, m in enumerate(sin.split(self.xdrope_section, dim=-1))], dim=-1
)
query_shape = query.shape
query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., : self.rotary_dim]
query_pass = query[..., self.rotary_dim :]
query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., : self.rotary_dim]
key_pass = key[..., self.rotary_dim :]
key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key
@staticmethod
def get_next_input_positions(
context_len: int,
seq_len: int,
xd_sections: int = 4,
) -> list[list[int]]:
return [list(range(context_len, seq_len)) for _ in range(xd_sections)]
@staticmethod
def get_next_input_positions_tensor(
out: np.ndarray,
out_offset: int,
context_len: int,
num_new_tokens: int,
):
values = np.arange(
context_len,
context_len + num_new_tokens,
dtype=out.dtype,
)
out[:, out_offset : out_offset + num_new_tokens] = values

View File

@ -576,7 +576,16 @@ class HunYuanDecoderLayer(nn.Module):
return hidden_states, residual, ori_kv_states return hidden_states, residual, ori_kv_states
@support_torch_compile @support_torch_compile(
dynamic_arg_dims={
"input_ids": 0,
# positions is of shape (xd, seq_len) if xdrope is enabled for hunyuan-vl,
# otherwise (seq_len, ).
"positions": -1,
"intermediate_tensors": 0,
"inputs_embeds": 0,
}
)
class HunYuanModel(nn.Module): class HunYuanModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()

File diff suppressed because it is too large Load Diff

View File

@ -1047,7 +1047,7 @@ class SupportsMRoPE(Protocol):
supports_mrope: ClassVar[Literal[True]] = True supports_mrope: ClassVar[Literal[True]] = True
""" """
A flag that indicates this model supports M-RoPE. A flag that indicates this model supports M-RoPE.
Note: Note:
There is no need to redefine this flag if this class is in the There is no need to redefine this flag if this class is in the
MRO of your model class. MRO of your model class.
@ -1088,3 +1088,52 @@ def supports_mrope(
model: type[object] | object, model: type[object] | object,
) -> TypeIs[type[SupportsMRoPE]] | TypeIs[SupportsMRoPE]: ) -> TypeIs[type[SupportsMRoPE]] | TypeIs[SupportsMRoPE]:
return isinstance(model, SupportsMRoPE) return isinstance(model, SupportsMRoPE)
@runtime_checkable
class SupportsXDRoPE(Protocol):
"""The interface required for all models that support XD-RoPE."""
supports_xdrope: ClassVar[Literal[True]] = True
"""
A flag that indicates this model supports XD-RoPE.
Note:
There is no need to redefine this flag if this class is in the
XDRope of your model class.
"""
def get_xdrope_input_positions(
self,
input_tokens: list[int],
mm_features: list["MultiModalFeatureSpec"],
) -> torch.Tensor:
"""
Get XD-RoPE input positions and delta value for this specific model.
This method should be implemented by each model that supports XD-RoPE
to provide model-specific logic for computing input positions.
Args:
input_tokens: List of input token IDs
mm_features: Information about each multi-modal data item
Returns:
llm_positions: Tensor of shape `[xdrope_dim, num_tokens]` with
4D(P/W/H/T) or 3D(W/H/T) positions.
"""
...
@overload
def supports_xdrope(model: type[object]) -> TypeIs[type[SupportsXDRoPE]]: ...
@overload
def supports_xdrope(model: object) -> TypeIs[SupportsXDRoPE]: ...
def supports_xdrope(
model: type[object] | object,
) -> TypeIs[type[SupportsXDRoPE]] | TypeIs[SupportsXDRoPE]:
return isinstance(model, SupportsXDRoPE)

View File

@ -287,6 +287,10 @@ _MULTIMODAL_MODELS = {
"GraniteSpeechForConditionalGeneration", "GraniteSpeechForConditionalGeneration",
), ),
"H2OVLChatModel": ("h2ovl", "H2OVLChatModel"), "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
"HunYuanVLForConditionalGeneration": (
"hunyuan_vision",
"HunYuanVLForConditionalGeneration",
),
"InternVLChatModel": ("internvl", "InternVLChatModel"), "InternVLChatModel": ("internvl", "InternVLChatModel"),
"NemotronH_Nano_VL_V2": ("nano_nemotron_vl", "NemotronH_Nano_VL_V2"), "NemotronH_Nano_VL_V2": ("nano_nemotron_vl", "NemotronH_Nano_VL_V2"),
"OpenCUAForConditionalGeneration": ( "OpenCUAForConditionalGeneration": (

View File

@ -86,6 +86,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
deepseek_vl_v2="DeepseekVLV2Config", deepseek_vl_v2="DeepseekVLV2Config",
deepseek_v32="DeepseekV3Config", deepseek_v32="DeepseekV3Config",
flex_olmo="FlexOlmoConfig", flex_olmo="FlexOlmoConfig",
hunyuan_vl="HunYuanVLConfig",
kimi_linear="KimiLinearConfig", kimi_linear="KimiLinearConfig",
kimi_vl="KimiVLConfig", kimi_vl="KimiVLConfig",
RefinedWeb="RWConfig", # For tiiuae/falcon-40b(-instruct) RefinedWeb="RWConfig", # For tiiuae/falcon-40b(-instruct)
@ -549,6 +550,23 @@ def thinker_uses_mrope(config: PretrainedConfig) -> bool:
return uses_mrope(thinker_text_config) return uses_mrope(thinker_text_config)
def uses_xdrope_dim(config: PretrainedConfig) -> int:
"""Detect if the model with this config uses XD-ROPE."""
xdrope_section = getattr(config, "xdrope_section", None)
if xdrope_section is not None and isinstance(xdrope_section, list):
return len(xdrope_section)
rope_scaling = getattr(config, "rope_scaling", None)
if rope_scaling is None:
return 0
if isinstance(rope_scaling, dict) and "xdrope_section" in rope_scaling:
xdrope_section = rope_scaling["xdrope_section"]
if xdrope_section is not None and isinstance(xdrope_section, list):
return len(xdrope_section)
return 0
def is_encoder_decoder(config: PretrainedConfig) -> bool: def is_encoder_decoder(config: PretrainedConfig) -> bool:
"""Detect if the model with this config is used as an encoder/decoder.""" """Detect if the model with this config is used as an encoder/decoder."""

View File

@ -23,6 +23,11 @@ from vllm.transformers_utils.configs.eagle import EAGLEConfig
# `FalconConfig` class from the official HuggingFace transformers library. # `FalconConfig` class from the official HuggingFace transformers library.
from vllm.transformers_utils.configs.falcon import RWConfig from vllm.transformers_utils.configs.falcon import RWConfig
from vllm.transformers_utils.configs.flex_olmo import FlexOlmoConfig from vllm.transformers_utils.configs.flex_olmo import FlexOlmoConfig
from vllm.transformers_utils.configs.hunyuan_vl import (
HunYuanVLConfig,
HunYuanVLTextConfig,
HunYuanVLVisionConfig,
)
from vllm.transformers_utils.configs.jais import JAISConfig from vllm.transformers_utils.configs.jais import JAISConfig
from vllm.transformers_utils.configs.kimi_linear import KimiLinearConfig from vllm.transformers_utils.configs.kimi_linear import KimiLinearConfig
from vllm.transformers_utils.configs.kimi_vl import KimiVLConfig from vllm.transformers_utils.configs.kimi_vl import KimiVLConfig
@ -53,6 +58,9 @@ __all__ = [
"DotsOCRConfig", "DotsOCRConfig",
"EAGLEConfig", "EAGLEConfig",
"FlexOlmoConfig", "FlexOlmoConfig",
"HunYuanVLConfig",
"HunYuanVLTextConfig",
"HunYuanVLVisionConfig",
"RWConfig", "RWConfig",
"JAISConfig", "JAISConfig",
"Lfm2MoeConfig", "Lfm2MoeConfig",

View File

@ -0,0 +1,322 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# adapted from https://github.com/ManaEstras/transformers/blob/v4.57.1.hyvl/src/transformers/models/hunyuan_vl/configuration_hunyuan_vl.py
from transformers import PretrainedConfig
class HunYuanVLVisionConfig(PretrainedConfig):
model_type = "hunyuan_vl"
base_config_key = "vision_config"
def __init__(
self,
hidden_act="gelu",
hidden_size=1152,
intermediate_size=4304,
interpolate_mode="bilinear",
rms_norm_eps=1e-05,
learnable_mlp_pooling_size=0,
num_attention_heads=16,
num_key_value_heads=None,
num_channels=3,
num_hidden_layers=27,
out_hidden_size=4096,
patch_size=16,
remove_prenorm=True,
spatial_merge_size=2,
temporal_patch_size=1,
resize_resolution=2048,
img_max_token_num=4096,
max_image_size=2048,
video_max_image_size=768,
video_min_image_size=256,
min_image_size=512,
anyres_vit_max_image_size=2048,
max_vit_seq_len=16384,
text_hidden_size=3072,
**kwargs,
):
super().__init__(**kwargs)
self.hidden_act = hidden_act
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.interpolate_mode = interpolate_mode
self.learnable_mlp_pooling_size = learnable_mlp_pooling_size
self.num_attention_heads = num_attention_heads
if not num_key_value_heads:
self.num_key_value_heads = num_attention_heads
else:
self.num_key_value_heads = num_key_value_heads
self.num_channels = num_channels
self.num_hidden_layers = num_hidden_layers
self.out_hidden_size = out_hidden_size
self.patch_size = patch_size
self.remove_prenorm = remove_prenorm
self.spatial_merge_size = spatial_merge_size
self.temporal_patch_size = temporal_patch_size
self.rms_norm_eps = rms_norm_eps
self.resize_resolution = resize_resolution
self.img_max_token_num = img_max_token_num
self.max_image_size = max_image_size
self.min_image_size = min_image_size
self.video_max_image_size = video_max_image_size
self.video_min_image_size = video_min_image_size
self.anyres_vit_max_image_size = anyres_vit_max_image_size
self.max_vit_seq_len = max_vit_seq_len
self.text_hidden_size = text_hidden_size
class HunYuanVLTextConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`HunYuanVLTextConfig`]. It is used to instantiate an
HunYuan model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the HunYuan-7B.
Hunyuan-7B-Instruct [tencent/Hunyuan-7B-Instruct](https://huggingface.co/tencent/Hunyuan-7B-Instruct).
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 290943):
Vocabulary size of the HunYuan model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`HunYuanVLTextConfig`]
hidden_size (`int`, *optional*, defaults to 4096):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 11008):
Dimension of the MLP representations or shared MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (`int`, *optional*):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
`num_attention_heads`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 2048):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
pad_token_id (`int`, *optional*, defaults to 0):
Padding token id.
bos_token_id (`int`, *optional*, defaults to 1):
Beginning of stream token id.
eos_token_id (`int`, *optional*, defaults to 2):
End of stream token id.
eod_token_id (int, *optional*, defaults to 3):
Token ID representing the end-of-document marker. Used to indicate the termination of a text sequence.
Example: In multi-document processing, this token helps the model distinguish between separate documents.
pretraining_tp (`int`, *optional*, defaults to 1):
Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
issue](https://github.com/pytorch/pytorch/issues/76232).
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether to tie weight embeddings
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
these scaling strategies behave:
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
experimental feature, subject to breaking API changes in future versions.
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
head_dim (`int`, *optional*, defaults to 128):
The attention head dimension.
""" # noqa: E501
model_type = "hunyuan_vl_text"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=290943,
hidden_size=4096,
intermediate_size: int = 11008,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=None,
hidden_act="silu",
max_position_embeddings=2048,
initializer_range=0.02,
rms_norm_eps=1e-5,
use_cache=True,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
eod_token_id=3,
pretraining_tp=1,
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
head_dim=None,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.head_dim = head_dim
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.pretraining_tp = pretraining_tp
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
# self._rope_scaling_validation() # TODO: Need validation?
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
"""
if self.rope_scaling is None:
return
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
raise ValueError(
"`rope_scaling` must be a dictionary with with two fields, `type` and "
f"`factor` or `type` and `alpha`, got {self.rope_scaling}"
)
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_factor = self.rope_scaling.get("factor", None)
rope_scaling_alpha = self.rope_scaling.get("alpha", None)
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
raise ValueError(
"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], "
f"got {rope_scaling_type}"
)
if rope_scaling_factor is None and rope_scaling_alpha is None:
raise ValueError(
"`rope_scaling`'s factor or alpha field must be have one, "
"got both of none"
)
if rope_scaling_factor is not None and (
not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0
):
raise ValueError(
"`rope_scaling`'s factor field must be a float > 1.0, "
f"got {rope_scaling_factor}"
)
if rope_scaling_alpha is not None and (
not isinstance(rope_scaling_alpha, float) or rope_scaling_alpha <= 1.0
):
raise ValueError(
"`rope_scaling`'s alpha field must be a float > 1.0, "
f"got {rope_scaling_alpha}"
)
class HunYuanVLConfig(PretrainedConfig):
model_type = "hunyuan_vl"
sub_configs = {
"vision_config": HunYuanVLVisionConfig,
"text_config": HunYuanVLTextConfig,
}
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
text_config=None,
vision_config=None,
im_start_id=120118,
im_end_id=120119,
image_token_id=120120,
im_newline_id=120121,
video_start_id=120122,
video_end_id=120123,
**kwargs,
):
# We need to init super() here so that it does not reset values
# that are in text config to the BaseClass defaults. The Base
# config has many text related defaults and not all defaults are
# same as for `HunYuanVLTextConfig`.
super().__init__(**kwargs)
if isinstance(vision_config, dict):
self.vision_config = self.sub_configs["vision_config"](**vision_config)
elif vision_config is None:
self.vision_config = self.sub_configs["vision_config"]()
if isinstance(text_config, dict):
self.text_config = self.sub_configs["text_config"](**text_config)
elif text_config is None:
# For BC use all kwargs to init `TextConfig`
self.text_config = self.sub_configs["text_config"](**kwargs)
self.image_token_id = image_token_id
self.im_start_id = im_start_id
self.im_end_id = im_end_id
self.im_newline_id = im_newline_id
self.video_start_id = video_start_id
self.video_end_id = video_end_id
self.vision_config.text_hidden_size = self.text_config.hidden_size
# Attention implementation to use. It sets it recursively on sub-configs
# so we call it again in the end.
self._attn_implementation = kwargs.pop("attn_implementation", None)
def __setattr__(self, key, value):
if (
(text_config := super().__getattribute__("__dict__").get("text_config"))
is not None
and key not in ["dtype", "_attn_implementation_internal"]
and key in text_config.__dict__
):
setattr(text_config, key, value)
else:
super().__setattr__(key, value)
def __getattribute__(self, key):
if "text_config" in super().__getattribute__("__dict__") and key not in [
"_name_or_path",
"model_type",
"dtype",
"_attn_implementation_internal",
]:
text_config = super().__getattribute__("text_config")
if key in text_config.__dict__:
return getattr(text_config, key)
return super().__getattribute__(key)

View File

@ -9,7 +9,15 @@ reasons:
""" """
from vllm.transformers_utils.processors.deepseek_vl2 import DeepseekVLV2Processor from vllm.transformers_utils.processors.deepseek_vl2 import DeepseekVLV2Processor
from vllm.transformers_utils.processors.hunyuan_vl import HunYuanVLProcessor
from vllm.transformers_utils.processors.hunyuan_vl_image import HunYuanVLImageProcessor
from vllm.transformers_utils.processors.ovis import OvisProcessor from vllm.transformers_utils.processors.ovis import OvisProcessor
from vllm.transformers_utils.processors.ovis2_5 import Ovis2_5Processor from vllm.transformers_utils.processors.ovis2_5 import Ovis2_5Processor
__all__ = ["DeepseekVLV2Processor", "OvisProcessor", "Ovis2_5Processor"] __all__ = [
"DeepseekVLV2Processor",
"HunYuanVLProcessor",
"HunYuanVLImageProcessor",
"OvisProcessor",
"Ovis2_5Processor",
]

View File

@ -0,0 +1,233 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# adapted from https://github.com/ManaEstras/transformers/blob/v4.57.1.hyvl/src/transformers/models/hunyuan_vl/processing_hunyuan_vl.py
import numpy as np
import torch
from transformers import AutoProcessor
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput
from transformers.processing_utils import ProcessorMixin
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from transformers.video_utils import VideoInput
class HunYuanVLProcessor(ProcessorMixin):
attributes = ["image_processor", "tokenizer"]
valid_kwargs = ["chat_template"]
image_processor_class = "AutoImageProcessor"
tokenizer_class = "AutoTokenizer" # ("AutoTokenizer", None)
def __init__(
self,
image_processor=None,
tokenizer=None,
video_processor=None,
chat_template=None,
**kwargs,
):
# TODO Fix the init
self.tokenizer = tokenizer
self.image_token_id = 120120 # self.tokenizer.image_token_id
self.image_token = self.tokenizer.convert_ids_to_tokens(self.image_token_id)
self.im_start_token_id = 120118 # self.tokenizer.im_start_id
self.im_start_token = self.tokenizer.convert_ids_to_tokens(
self.im_start_token_id
)
self.im_end_token_id = 120119 # self.tokenizer.im_end_id
self.im_end_token = self.tokenizer.convert_ids_to_tokens(self.im_end_token_id)
self.placeholder_token = self.tokenizer.convert_ids_to_tokens(
self.tokenizer.vocab_size - 1
)
self.pad_id = 120002 # self.tokenizer.pad_token_id
super().__init__(
image_processor, tokenizer, video_processor, chat_template=chat_template
)
def __call__(
self,
images: ImageInput = None,
text: TextInput
| PreTokenizedInput
| list[TextInput]
| list[PreTokenizedInput] = None,
videos: VideoInput = None,
**kwargs,
) -> BatchFeature:
image_inputs = {}
if images is not None:
image_inputs = self.image_processor(images=images)
image_grid_thw = image_inputs["image_grid_thw"]
if not isinstance(text, list):
text = [text]
text = text.copy() # below lines change text in-place
image_tokens_cumsum = [0]
if images is not None:
index = 0
for i in range(len(text)):
while self.image_token in text[i]:
grid_h, grid_w = image_grid_thw[index][-2:]
patch_h = grid_h // self.image_processor.merge_size
patch_w = grid_w // self.image_processor.merge_size
num_image_tokens = patch_h * (patch_w + 1) + 2
image_tokens_cumsum.append(
image_tokens_cumsum[-1] + num_image_tokens
)
# text[i] = text[i].replace(self.image_token, self.im_start_token + self.placeholder_token * num_image_tokens + self.im_end_token, 1) # noqa: E501
text[i] = text[i].replace(
self.image_token, self.placeholder_token * num_image_tokens, 1
)
index += 1
text[i] = text[i].replace(self.placeholder_token, self.image_token)
# text[i] = self.tokenizer.bos_token + text[i]
text_inputs = self.tokenizer(text, add_special_tokens=False, **kwargs)
self._check_special_mm_tokens(text, text_inputs, modalities=["image"])
input_ids = text_inputs["input_ids"]
position_ids = torch.arange(len(input_ids[0]))
position_ids_w = torch.arange(len(input_ids[0]))
position_ids_h = torch.arange(len(input_ids[0]))
position_ids_t = torch.arange(len(input_ids[0]))
if images is not None:
image_token_pos_indices = torch.where(input_ids[0] == self.image_token_id)[
0
]
for i in range(len(image_grid_thw)):
grid_h, grid_w = image_grid_thw[i][-2:]
patch_h = grid_h // self.image_processor.merge_size
patch_w = grid_w // self.image_processor.merge_size
start_pos = image_token_pos_indices[image_tokens_cumsum[i]].item() + 1
replace_num = (patch_w + 1) * patch_h
position_ids_w[start_pos : start_pos + replace_num] = torch.tensor(
list(range(patch_w + 1)) * patch_h, dtype=torch.int64
)
patch_h_list = []
for h in range(patch_h):
patch_h_list += [h] * (patch_w + 1)
position_ids_h[start_pos : start_pos + replace_num] = torch.tensor(
patch_h_list, dtype=torch.int64
)
position_ids_t[start_pos : start_pos + replace_num] = 0
position_ids = torch.stack(
[position_ids, position_ids_w, position_ids_h, position_ids_t]
).unsqueeze(0)
text_inputs["position_ids"] = position_ids
attention_mask = input_ids.ne(self.pad_id)
text_inputs["attention_mask"] = attention_mask
text_inputs["imgs_pos"] = [self.get_imgs_pos(input_ids)]
# image_inputs["imgs"] = [[image_inputs["pixel_values"]]]
return_tensors = kwargs.pop("return_tensors", None)
return BatchFeature(
data={**text_inputs, **image_inputs},
tensor_type=return_tensors,
)
def batch_decode(self, *args, **kwargs):
return self.tokenizer.batch_decode(*args, **kwargs)
def decode(self, *args, **kwargs):
return self.tokenizer.decode(*args, **kwargs)
def post_process_image_text_to_text(
self,
generated_outputs,
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
**kwargs,
):
assert 0
def apply_chat_template(self, *args, **kwargs):
token_ids = self.tokenizer.apply_chat_template(*args, **kwargs)
return token_ids
def get_imgs_pos(self, doc_ids):
doc_ids = np.array(doc_ids, dtype=np.int64)
img_begin_index = np.where(doc_ids == self.im_start_token_id)[0]
img_end_index = np.where(doc_ids == self.im_end_token_id)[0]
imgs_pos = np.concatenate(
(
np.reshape(img_begin_index + 1, (-1, 1)),
np.reshape(img_end_index, (-1, 1)),
),
axis=-1,
).tolist()
return imgs_pos
@property
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
def split_image_into_patch_blocks(
pixel_values: torch.Tensor, # shape: [batch_size, 3, H, W]
patch_size: int = 16, # e.g. 16
adaptor_patch_div: int = 4, # e.g. 4 --> each patch_size is cut into 4x4 small regions, i.e. patch_size // 4 # noqa: E501
) -> torch.Tensor:
"""
Split the input image tensor (supporting batch) into large patches of size `patch_size`,
and then further divide each large patch into smaller regions of size
(patch_size // adaptor_patch_div) x (patch_size // adaptor_patch_div).
Each small region is extracted as a tensor of shape [3, patch_size, patch_size].
The final output contains all such small region tensors.
Args:
pixel_values: Input image tensor of shape [batch_size, 3, H, W].
patch_size: Size of the large patch, e.g., 16.
adaptor_patch_div: Each large patch is divided into
(patch_size // adaptor_patch_div) x (patch_size // adaptor_patch_div)
smaller regions.
Returns:
patches: A tensor of shape [N, 3, patch_size, patch_size],
where N = batch_size * (H // patch_size) * (W // patch_size) * (patch_size // adaptor_patch_div)^2.
Each element in the batch corresponds to one small image region.
""" # noqa: E501
batch_size, channels, height, width = pixel_values.shape
assert channels == 3, "Pixel values must have 3 channels in dim=1"
assert height % patch_size == 0 and width % patch_size == 0, (
"H and W must be divisible by patch_size"
)
patch_height_num = height // patch_size
patch_width_num = width // patch_size
# Reshape to [B, 3, ph, ps, pw, ps]
img = pixel_values.reshape(
batch_size, 3, patch_height_num, patch_size, patch_width_num, patch_size
)
# Further split each psxps patch into (ps//aps)x(ps//aps) small regions
img = img.reshape(
batch_size,
3,
patch_height_num,
patch_size // adaptor_patch_div, # ps // aps
adaptor_patch_div,
patch_width_num,
patch_size // adaptor_patch_div, # ps // aps
adaptor_patch_div,
)
# Permute to group the small regions: [B, ph, pw, ps//aps, ps//aps, 3, aps, aps]
img = img.permute(0, 2, 5, 3, 6, 1, 4, 7)
# Reshape into [B * ph * pw * (ps//aps)^2, 3, patch_size, patch_size]
patches = img.reshape(-1, 3, patch_size, patch_size)
return patches
AutoProcessor.register("HunYuanVLProcessor", HunYuanVLProcessor)

View File

@ -0,0 +1,477 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# adapted from https://github.com/ManaEstras/transformers/blob/v4.57.1.hyvl/src/transformers/models/hunyuan_vl/image_processing_hunyuan_vl.py
"""Image processor class for HunYuanVL."""
# isort conflicts with ruff for transformers imports
# isort: skip_file
import math
import numpy as np
import torchvision.transforms as transforms
from transformers import AutoImageProcessor
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.image_transforms import (
convert_to_rgb,
)
from transformers.image_utils import (
OPENAI_CLIP_MEAN,
OPENAI_CLIP_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
make_flat_list_of_images,
make_list_of_images,
valid_images,
validate_preprocess_arguments,
)
from transformers.utils import TensorType, logging
from transformers.video_utils import VideoInput, make_batched_videos
logger = logging.get_logger(__name__)
def smart_resize(
height: int,
width: int,
factor: int = 16,
min_pixels: int = 512 * 512,
max_pixels: int = 2048 * 2048,
):
"""Rescales the image so that the following conditions are met:
1. Both dimensions (height and width) are divisible by 'factor'.
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
3. The aspect ratio of the image is maintained as closely as possible.
"""
if max(height, width) / min(height, width) > 200:
raise ValueError(
"absolute aspect ratio must be smaller than 200, got "
f"{max(height, width) / min(height, width)}"
)
h_bar = round(height / factor) * factor
w_bar = round(width / factor) * factor
if h_bar * w_bar > max_pixels:
beta = math.sqrt((height * width) / max_pixels)
h_bar = max(factor, math.floor(height / beta / factor) * factor)
w_bar = max(factor, math.floor(width / beta / factor) * factor)
elif h_bar * w_bar < min_pixels:
beta = math.sqrt(min_pixels / (height * width))
h_bar = math.ceil(height * beta / factor) * factor
w_bar = math.ceil(width * beta / factor) * factor
return h_bar, w_bar
class HunYuanVLImageProcessor(BaseImageProcessor):
model_input_names = [
"pixel_values",
"image_grid_thw",
"pixel_values_videos",
"video_grid_thw",
]
def __init__(
self,
do_resize: bool = True,
size: dict[str, int] | None = None,
resample: PILImageResampling = PILImageResampling.BICUBIC,
do_rescale: bool = True,
rescale_factor: int | float = 1 / 255,
do_normalize: bool = True,
image_mean: float | list[float] | None = None,
image_std: float | list[float] | None = None,
do_convert_rgb: bool = True,
min_pixels: int | None = None,
max_pixels: int | None = None,
patch_size: int = 16,
temporal_patch_size: int = 2,
merge_size: int = 2,
**kwargs,
) -> None:
super().__init__(**kwargs)
if size is not None and (
"shortest_edge" not in size or "longest_edge" not in size
):
raise ValueError(
"size must contain 'shortest_edge' and 'longest_edge' keys."
)
else:
size = {"shortest_edge": 512 * 512, "longest_edge": 2048 * 2048}
# backward compatibility: override size with min_pixels and max_pixels
# if they are provided.
if min_pixels is not None:
size["shortest_edge"] = min_pixels
if max_pixels is not None:
size["longest_edge"] = max_pixels
self.min_pixels = size["shortest_edge"]
self.max_pixels = size["longest_edge"]
self.size = size
self.do_resize = do_resize
self.resample = resample
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
self.patch_size = patch_size
self.temporal_patch_size = temporal_patch_size
self.merge_size = merge_size
self.do_convert_rgb = do_convert_rgb
# hard-code
def _preprocess(
self,
images: ImageInput | VideoInput,
do_resize: bool | None = None,
size: dict[str, int] | None = None,
resample: PILImageResampling = None,
do_rescale: bool | None = None,
rescale_factor: float | None = None,
do_normalize: bool | None = None,
image_mean: float | list[float] | None = None,
image_std: float | list[float] | None = None,
patch_size: int = 16,
temporal_patch_size: int = 2,
merge_size: int = 2,
do_convert_rgb: bool | None = None,
data_format: ChannelDimension | None = ChannelDimension.FIRST,
input_data_format: str | ChannelDimension | None = None,
):
"""
Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`.
Args:
images (`ImageInput`):
Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
size (`dict[str, int]`, *optional*, defaults to `self.size`):
Size of the image after resizing. `shortest_edge` and `longest_edge` keys must be present.
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
Whether to rescale the image.
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
Scale factor to use if rescaling the image.
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
Whether to normalize the image.
image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
patch_size (`int`, *optional*, defaults to `self.patch_size`):
The spatial patch size of the vision encoder.
temporal_patch_size (`int`, *optional*, defaults to `self.temporal_patch_size`):
The temporal patch size of the vision encoder.
merge_size (`int`, *optional*, defaults to `self.merge_size`):
The merge size of the vision encoder to llm encoder.
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
Whether to convert the image to RGB.
data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: Use the channel dimension format of the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
""" # noqa: E501
images = make_list_of_images(images)
if do_convert_rgb:
images = [convert_to_rgb(image) for image in images]
width, height = images[0].width, images[0].height
resized_width, resized_height = width, height
processed_images = []
for image in images:
if do_resize:
resized_width, resized_height = smart_resize(
width,
height,
factor=patch_size * merge_size,
min_pixels=self.min_pixels,
max_pixels=self.max_pixels,
)
image = image.resize((resized_width, resized_height))
if do_normalize:
image = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize(self.image_mean, self.image_std),
]
)(image)
processed_images.append(image)
patches = np.array(processed_images)
channel = patches.shape[1]
grid_t = patches.shape[0] // temporal_patch_size
grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
patches = patches.reshape(
1,
channel,
grid_h // merge_size,
merge_size,
patch_size,
grid_w // merge_size,
merge_size,
patch_size,
)
patches = patches.transpose(0, 2, 3, 5, 6, 1, 4, 7)
flatten_patches = patches.reshape(
1 * grid_h * grid_w, channel * patch_size * patch_size
)
return flatten_patches, (grid_t, grid_h, grid_w)
def preprocess(
self,
images: ImageInput,
videos: VideoInput = None,
do_resize: bool | None = None,
size: dict[str, int] | None = None,
min_pixels: int | None = None,
max_pixels: int | None = None,
resample: PILImageResampling = None,
do_rescale: bool | None = None,
rescale_factor: float | None = None,
do_normalize: bool | None = None,
image_mean: float | list[float] | None = None,
image_std: float | list[float] | None = None,
patch_size: int | None = None,
temporal_patch_size: int | None = None,
merge_size: int | None = None,
do_convert_rgb: bool | None = None,
return_tensors: str | TensorType | None = None,
data_format: ChannelDimension | None = ChannelDimension.FIRST,
input_data_format: str | ChannelDimension | None = None,
):
"""
Args:
images (`ImageInput`):
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
videos (`VideoInput`):
Video to preprocess. Expects a single or batch of videos with pixel values ranging from 0 to 255. If
passing in videos with pixel values between 0 and 1, set `do_rescale=False`.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
size (`dict[str, int]`, *optional*, defaults to `self.size`):
Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
the longest edge resized to keep the input aspect ratio.
resample (`int`, *optional*, defaults to `self.resample`):
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
has an effect if `do_resize` is set to `True`.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
Whether to rescale the image.
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
Whether to normalize the image.
image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
`True`.
min_pixels (`int`, *optional*, defaults to `self.min_pixels`):
The min pixels of the image to resize the image.
max_pixels (`int`, *optional*, defaults to `self.max_pixels`):
The max pixels of the image to resize the image.
patch_size (`int`, *optional*, defaults to `self.patch_size`):
The spatial patch size of the vision encoder.
temporal_patch_size (`int`, *optional*, defaults to `self.temporal_patch_size`):
The temporal patch size of the vision encoder.
merge_size (`int`, *optional*, defaults to `self.merge_size`):
The merge size of the vision encoder to llm encoder.
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
Whether to convert the image to RGB.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: Use the channel dimension format of the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
""" # noqa: E501
min_pixels = min_pixels if min_pixels is not None else self.min_pixels
max_pixels = max_pixels if max_pixels is not None else self.max_pixels
if size is not None:
if "shortest_edge" not in size or "longest_edge" not in size:
raise ValueError(
"size must contain 'shortest_edge' and 'longest_edge' keys."
)
min_pixels = size["shortest_edge"]
elif min_pixels is not None and max_pixels is not None:
# backward compatibility: override size with min_pixels and max_pixels
# if they are provided.
size = {"shortest_edge": min_pixels, "longest_edge": max_pixels}
else:
size = {**self.size}
do_resize = do_resize if do_resize is not None else self.do_resize
resample = resample if resample is not None else self.resample
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = (
rescale_factor if rescale_factor is not None else self.rescale_factor
)
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
patch_size = patch_size if patch_size is not None else self.patch_size
temporal_patch_size = (
temporal_patch_size
if temporal_patch_size is not None
else self.temporal_patch_size
)
merge_size = merge_size if merge_size is not None else self.merge_size
do_convert_rgb = (
do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
)
if images is not None:
images = make_flat_list_of_images(images)
if images is not None and not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
validate_preprocess_arguments(
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_resize=do_resize,
size=size,
resample=resample,
)
data = {}
if images is not None:
pixel_values, vision_grid_thws = [], []
for image in images:
patches, image_grid_thw = self._preprocess(
image,
do_resize=do_resize,
size=size,
resample=resample,
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
patch_size=patch_size,
temporal_patch_size=temporal_patch_size,
merge_size=merge_size,
data_format=data_format,
do_convert_rgb=do_convert_rgb,
input_data_format=input_data_format,
)
pixel_values.extend(patches)
vision_grid_thws.append(image_grid_thw)
pixel_values = np.array(pixel_values)
vision_grid_thws = np.array(vision_grid_thws)
data.update(
{"pixel_values": pixel_values, "image_grid_thw": vision_grid_thws}
)
# kept for BC only and should be removed after v5.0
if videos is not None:
logger.warning(
"`HunYuanVLV1ImageProcessor` works only with image inputs "
"and doesn't process videos anymore. "
"This is a deprecated behavior and will be removed in v5.0. "
"Your videos should be forwarded to `HunYuanVLV1VideoProcessor`. "
)
videos = make_batched_videos(videos)
pixel_values_videos, vision_grid_thws_videos = [], []
for images in videos:
patches, video_grid_thw = self._preprocess(
images,
do_resize=do_resize,
size=size,
resample=resample,
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
patch_size=patch_size,
temporal_patch_size=temporal_patch_size,
merge_size=merge_size,
data_format=data_format,
do_convert_rgb=do_convert_rgb,
input_data_format=input_data_format,
)
pixel_values_videos.extend(patches)
vision_grid_thws_videos.append(video_grid_thw)
data.update(
{
"pixel_values_videos": np.array(pixel_values_videos),
"video_grid_thw": np.array(vision_grid_thws_videos),
}
)
return BatchFeature(data=data, tensor_type=return_tensors)
def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
"""
A utility that returns number of image patches for a given image size.
Args:
height (`int`):
Height of the input image.
width (`int`):
Width of the input image.
images_kwargs (`dict`, *optional*):
Any kwargs to override defaults of the image processor.
Returns:
`int`: Number of image patches per image.
"""
min_pixels = (
images_kwargs["min_pixels"]
if "min_pixels" in images_kwargs
else self.size["shortest_edge"]
)
max_pixels = (
images_kwargs["max_pixels"]
if "max_pixels" in images_kwargs
else self.size["longest_edge"]
)
patch_size = images_kwargs.get("patch_size", self.patch_size)
merge_size = images_kwargs.get("merge_size", self.merge_size)
factor = patch_size * merge_size
resized_height, resized_width = smart_resize(
height, width, factor, min_pixels=min_pixels, max_pixels=max_pixels
)
grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
return grid_h * (grid_w + 1) + 2
AutoImageProcessor.register("HunYuanVLImageProcessor", HunYuanVLImageProcessor)

View File

@ -43,6 +43,8 @@ class CachedRequestState:
mrope_positions: torch.Tensor | None = None mrope_positions: torch.Tensor | None = None
mrope_position_delta: int | None = None mrope_position_delta: int | None = None
xdrope_positions: torch.Tensor | None = None
lora_request: LoRARequest | None = None lora_request: LoRARequest | None = None
prompt_embeds: torch.Tensor | None = None prompt_embeds: torch.Tensor | None = None

View File

@ -50,16 +50,21 @@ from vllm.distributed.parallel_state import (
from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.layers.rotary_embedding import (
MRotaryEmbedding,
XDRotaryEmbedding,
)
from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
from vllm.model_executor.models.interfaces import ( from vllm.model_executor.models.interfaces import (
SupportsMRoPE, SupportsMRoPE,
SupportsMultiModal, SupportsMultiModal,
SupportsXDRoPE,
is_mixture_of_experts, is_mixture_of_experts,
supports_eagle3, supports_eagle3,
supports_mrope, supports_mrope,
supports_multimodal_pruning, supports_multimodal_pruning,
supports_transcription, supports_transcription,
supports_xdrope,
) )
from vllm.model_executor.models.interfaces_base import ( from vllm.model_executor.models.interfaces_base import (
VllmModelForPooling, VllmModelForPooling,
@ -324,6 +329,7 @@ class GPUModelRunner(
# Multi-modal data support # Multi-modal data support
self.mm_registry = MULTIMODAL_REGISTRY self.mm_registry = MULTIMODAL_REGISTRY
self.uses_mrope = model_config.uses_mrope self.uses_mrope = model_config.uses_mrope
self.uses_xdrope_dim = model_config.uses_xdrope_dim
self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs( self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
model_config model_config
) )
@ -512,6 +518,13 @@ class GPUModelRunner(
(3, self.max_num_tokens + 1), dtype=torch.int64 (3, self.max_num_tokens + 1), dtype=torch.int64
) )
# Only relevant for models using XD-RoPE (e.g, HunYuan-VL)
if self.uses_xdrope_dim > 0:
# Similar to mrope but use assigned dimension number for RoPE, 4 as default.
self.xdrope_positions = self._make_buffer(
(self.uses_xdrope_dim, self.max_num_tokens + 1), dtype=torch.int64
)
# None in the first PP rank. The rest are set after load_model. # None in the first PP rank. The rest are set after load_model.
self.intermediate_tensors: IntermediateTensors | None = None self.intermediate_tensors: IntermediateTensors | None = None
@ -593,10 +606,14 @@ class GPUModelRunner(
if isinstance(num_tokens, int): if isinstance(num_tokens, int):
if self.uses_mrope: if self.uses_mrope:
return self.mrope_positions.gpu[:, :num_tokens] return self.mrope_positions.gpu[:, :num_tokens]
if self.uses_xdrope_dim > 0:
return self.xdrope_positions.gpu[:, :num_tokens]
return self.positions.gpu[:num_tokens] return self.positions.gpu[:num_tokens]
else: else:
if self.uses_mrope: if self.uses_mrope:
return self.mrope_positions.gpu[:, num_tokens] return self.mrope_positions.gpu[:, num_tokens]
if self.uses_xdrope_dim > 0:
return self.xdrope_positions.gpu[:, num_tokens]
return self.positions.gpu[num_tokens] return self.positions.gpu[num_tokens]
def _make_buffer( def _make_buffer(
@ -772,6 +789,10 @@ class GPUModelRunner(
if self.uses_mrope: if self.uses_mrope:
self._init_mrope_positions(req_state) self._init_mrope_positions(req_state)
# Only relevant for models using XD-RoPE (e.g, HunYuan-VL)
if self.uses_xdrope_dim > 0:
self._init_xdrope_positions(req_state)
reqs_to_add.append(req_state) reqs_to_add.append(req_state)
# Update the states of the running/resumed requests. # Update the states of the running/resumed requests.
@ -987,6 +1008,19 @@ class GPUModelRunner(
) )
) )
def _init_xdrope_positions(self, req_state: CachedRequestState):
model = self.get_model()
xdrope_model = cast(SupportsXDRoPE, model)
assert req_state.prompt_token_ids is not None, (
"XD-RoPE requires prompt_token_ids to be available."
)
assert supports_xdrope(model), "XD-RoPE support is not implemented."
req_state.xdrope_positions = xdrope_model.get_xdrope_input_positions(
req_state.prompt_token_ids,
req_state.mm_features,
)
def _extract_mm_kwargs( def _extract_mm_kwargs(
self, self,
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
@ -1231,6 +1265,11 @@ class GPUModelRunner(
if self.uses_mrope: if self.uses_mrope:
self._calc_mrope_positions(scheduler_output) self._calc_mrope_positions(scheduler_output)
# Calculate XD-RoPE positions.
# Only relevant for models using XD-RoPE (e.g, HunYuan-VL)
if self.uses_xdrope_dim > 0:
self._calc_xdrope_positions(scheduler_output)
# Get token indices. # Get token indices.
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
@ -1364,6 +1403,12 @@ class GPUModelRunner(
self.mrope_positions.cpu[:, :total_num_scheduled_tokens], self.mrope_positions.cpu[:, :total_num_scheduled_tokens],
non_blocking=True, non_blocking=True,
) )
elif self.uses_xdrope_dim > 0:
# Only relevant for models using XD-RoPE (e.g, HunYuan-VL)
self.xdrope_positions.gpu[:, :total_num_scheduled_tokens].copy_(
self.xdrope_positions.cpu[:, :total_num_scheduled_tokens],
non_blocking=True,
)
else: else:
# Common case (1D positions) # Common case (1D positions)
self.positions.copy_to_gpu(total_num_scheduled_tokens) self.positions.copy_to_gpu(total_num_scheduled_tokens)
@ -1793,6 +1838,53 @@ class GPUModelRunner(
mrope_pos_ptr += completion_part_len mrope_pos_ptr += completion_part_len
def _calc_xdrope_positions(self, scheduler_output: "SchedulerOutput"):
xdrope_pos_ptr = 0
for index, req_id in enumerate(self.input_batch.req_ids):
req = self.requests[req_id]
assert req.xdrope_positions is not None
num_computed_tokens = self.input_batch.num_computed_tokens_cpu[index]
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
req.prompt_token_ids, req.prompt_embeds
)
if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
prompt_part_len = max(0, num_prompt_tokens - num_computed_tokens)
completion_part_len = max(0, num_scheduled_tokens - prompt_part_len)
else:
prompt_part_len = num_scheduled_tokens
completion_part_len = 0
assert num_scheduled_tokens == prompt_part_len + completion_part_len
if prompt_part_len > 0:
# prompt's xdrope_positions are pre-computed
dst_start = xdrope_pos_ptr
dst_end = xdrope_pos_ptr + prompt_part_len
src_start = num_computed_tokens
src_end = num_computed_tokens + prompt_part_len
self.xdrope_positions.cpu[:, dst_start:dst_end] = req.xdrope_positions[
:, src_start:src_end
]
xdrope_pos_ptr += prompt_part_len
if completion_part_len > 0:
# compute completion's xdrope_positions on-the-fly
dst_start = xdrope_pos_ptr
dst_end = xdrope_pos_ptr + completion_part_len
XDRotaryEmbedding.get_next_input_positions_tensor(
out=self.xdrope_positions.np,
out_offset=dst_start,
context_len=num_computed_tokens + prompt_part_len,
num_new_tokens=completion_part_len,
)
xdrope_pos_ptr += completion_part_len
def _calc_spec_decode_metadata( def _calc_spec_decode_metadata(
self, self,
num_draft_tokens: np.ndarray, num_draft_tokens: np.ndarray,
@ -2037,6 +2129,7 @@ class GPUModelRunner(
req_start_idx = 0 req_start_idx = 0
should_sync_mrope_positions = False should_sync_mrope_positions = False
should_sync_xdrope_positions = False
for req_id in self.input_batch.req_ids: for req_id in self.input_batch.req_ids:
mm_embeds_req: list[torch.Tensor] = [] mm_embeds_req: list[torch.Tensor] = []
@ -2110,6 +2203,10 @@ class GPUModelRunner(
self._calc_mrope_positions(scheduler_output) self._calc_mrope_positions(scheduler_output)
self.mrope_positions.copy_to_gpu(total_num_scheduled_tokens) self.mrope_positions.copy_to_gpu(total_num_scheduled_tokens)
if should_sync_xdrope_positions:
self._calc_xdrope_positions(scheduler_output)
self.xdrope_positions.copy_to_gpu(total_num_scheduled_tokens)
return mm_embeds, is_mm_embed return mm_embeds, is_mm_embed
def get_model(self) -> nn.Module: def get_model(self) -> nn.Module:
@ -2384,8 +2481,11 @@ class GPUModelRunner(
input_ids = self.input_ids.gpu[:num_input_tokens] input_ids = self.input_ids.gpu[:num_input_tokens]
inputs_embeds = None inputs_embeds = None
model_kwargs = self._init_model_kwargs(num_input_tokens) model_kwargs = self._init_model_kwargs(num_input_tokens)
if self.uses_mrope: if self.uses_mrope:
positions = self.mrope_positions.gpu[:, :num_input_tokens] positions = self.mrope_positions.gpu[:, :num_input_tokens]
elif self.uses_xdrope_dim > 0:
positions = self.xdrope_positions.gpu[:, :num_input_tokens]
else: else:
positions = self.positions.gpu[:num_input_tokens] positions = self.positions.gpu[:num_input_tokens]
@ -3824,6 +3924,8 @@ class GPUModelRunner(
if self.uses_mrope: if self.uses_mrope:
positions = self.mrope_positions.gpu[:, :num_tokens_after_padding] positions = self.mrope_positions.gpu[:, :num_tokens_after_padding]
elif self.uses_xdrope_dim > 0:
positions = self.xdrope_positions.gpu[:, :num_tokens_after_padding]
else: else:
positions = self.positions.gpu[:num_tokens_after_padding] positions = self.positions.gpu[:num_tokens_after_padding]