[Model] Add HunyuanOCR support (#29327)

Signed-off-by: manayang <jackmanayang@gmail.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: sergeywang <sergeywang@tencent.com>
Co-authored-by: manayang <jackmanayang@gmail.com>
Co-authored-by: manayang <manayang@tencent.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
Committed by Isotr0py on 2025-11-25 11:28:51 +08:00 (via GitHub).
Parent: 87185c88d5, commit: 92effb07a4
18 changed files with 2415 additions and 4 deletions


@ -680,6 +680,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `Glm4vMoeForConditionalGeneration` | GLM-4.5V | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.5V`, etc. | ✅︎ | ✅︎ |
| `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ |
| `H2OVLChatModel` | H2OVL | T + I<sup>E+</sup> | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ |
| `HunYuanVLForConditionalGeneration` | HunyuanOCR | T + I<sup>E+</sup> | `tencent/HunyuanOCR`, etc. | ✅︎ | ✅︎ |
| `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3`, etc. | ✅︎ | |
| `InternS1ForConditionalGeneration` | Intern-S1 | T + I<sup>E+</sup> + V<sup>E+</sup> | `internlm/Intern-S1`, `internlm/Intern-S1-mini`, etc. | ✅︎ | ✅︎ |
| `InternVLChatModel` | InternVL 3.5, InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + I<sup>E+</sup> + (V<sup>E+</sup>) | `OpenGVLab/InternVL3_5-14B`, `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ |


@ -538,6 +538,31 @@ def run_h2ovl(questions: list[str], modality: str) -> ModelRequestData:
)
# HunyuanOCR
def run_hunyuan_vl(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
model_name = "tencent/HunyuanOCR"
engine_args = EngineArgs(
model=model_name,
max_model_len=8192,
limit_mm_per_prompt={modality: 1},
)
placeholder = "<hy_place▁holder▁no▁100><hy_place▁holder▁no▁102><hy_place▁holder▁no▁101>" # noqa: E501
prompts = [
f"<hy_begin▁of▁sentence>{placeholder}{question}<hy_User>"
for question in questions
]
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
stop_token_ids=None,
)
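Below is a minimal end-to-end sketch of running the model offline with the prompt format from `run_hunyuan_vl`; the image path and question are illustrative placeholders, and it assumes a vLLM build that includes this PR.

```python
# Hedged sketch: mirrors run_hunyuan_vl above; "doc.png" is a placeholder path.
from PIL import Image

from vllm import LLM, SamplingParams

llm = LLM(
    model="tencent/HunyuanOCR",
    max_model_len=8192,
    limit_mm_per_prompt={"image": 1},
)
placeholder = (
    "<hy_place▁holder▁no▁100><hy_place▁holder▁no▁102><hy_place▁holder▁no▁101>"
)
prompt = f"<hy_begin▁of▁sentence>{placeholder}What text is in this image?<hy_User>"
outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": Image.open("doc.png")}},
    SamplingParams(max_tokens=256),
)
print(outputs[0].outputs[0].text)
```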
# naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B
def run_hyperclovax_seed_vision(
questions: list[str], modality: str
@ -1820,6 +1845,7 @@ model_example_map = {
"glm4_5v": run_glm4_5v,
"glm4_5v_fp8": run_glm4_5v_fp8,
"h2ovl_chat": run_h2ovl,
"hunyuan_vl": run_hunyuan_vl,
"hyperclovax_seed_vision": run_hyperclovax_seed_vision,
"idefics3": run_idefics3,
"interns1": run_interns1,


@ -626,6 +626,10 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B",
trust_remote_code=True,
),
"HunYuanVLForConditionalGeneration": _HfExamplesInfo(
"tencent/HunyuanOCR",
is_available_online=False,
),
"Idefics3ForConditionalGeneration": _HfExamplesInfo(
"HuggingFaceM4/Idefics3-8B-Llama3",
extras={"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"},


@ -33,6 +33,7 @@ from vllm.transformers_utils.config import (
try_get_safetensors_metadata,
try_get_tokenizer_config,
uses_mrope,
uses_xdrope_dim,
)
from vllm.transformers_utils.gguf_utils import (
maybe_patch_hf_config_from_gguf,
@ -1615,6 +1616,10 @@ class ModelConfig:
def uses_mrope(self) -> bool:
return uses_mrope(self.hf_config)
@property
def uses_xdrope_dim(self) -> int:
return uses_xdrope_dim(self.hf_config)
@property
def is_multimodal_model(self) -> bool:
return self.multimodal_config is not None


@ -17,6 +17,7 @@ from .llama4_vision_rope import Llama4VisionRotaryEmbedding
from .mrope import MRotaryEmbedding
from .ntk_scaling_rope import NTKScalingRotaryEmbedding
from .phi3_long_rope_scaled_rope import Phi3LongRoPEScaledRotaryEmbedding
from .xdrope import XDRotaryEmbedding
from .yarn_scaling_rope import YaRNScalingRotaryEmbedding
_ROPE_DICT: dict[tuple, RotaryEmbedding] = {}
@ -184,6 +185,18 @@ def get_rope(
raise ValueError(
"Dynamic rope scaling must contain either 'alpha' or 'factor' field"
)
elif scaling_type == "xdrope":
scaling_alpha = rope_parameters["alpha"]
rotary_emb = XDRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
scaling_alpha,
dtype,
xdrope_section=rope_parameters["xdrope_section"],
)
elif scaling_type == "yarn":
scaling_factor = rope_parameters["factor"]
original_max_position = rope_parameters["original_max_position_embeddings"]


@ -0,0 +1,102 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch
from .common import apply_rotary_emb_dispatch
from .dynamic_ntk_alpha_rope import DynamicNTKAlphaRotaryEmbedding
class XDRotaryEmbedding(DynamicNTKAlphaRotaryEmbedding):
"""DynamicNTKAlphaRotaryEmbedding extended with multimodal (XD) sections.
Based on the original DynamicNTKAlphaRotaryEmbedding implementation.
"""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: float,
is_neox_style: bool,
scaling_alpha: float,
dtype: torch.dtype,
xdrope_section: list[int],
) -> None:
self.xdrope_section = xdrope_section
super().__init__(
head_size,
rotary_dim,
max_position_embeddings,
base,
is_neox_style,
scaling_alpha,
dtype,
)
def forward(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor | None = None,
offsets: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""PyTorch-native forward pass with XD-RoPE (P/W/H/T) section handling.
Args:
positions:
[4, num_tokens] (P/W/H/T positions with multimodal inputs)
query: [num_tokens, num_heads * head_size]
key: [num_tokens, num_kv_heads * head_size]
"""
assert positions.ndim == 2
assert key is not None
num_tokens = positions.shape[-1]
cos_sin = self.cos_sin_cache[positions]
cos, sin = cos_sin.chunk(2, dim=-1)
cos = torch.cat(
[m[i] for i, m in enumerate(cos.split(self.xdrope_section, dim=-1))], dim=-1
)
sin = torch.cat(
[m[i] for i, m in enumerate(sin.split(self.xdrope_section, dim=-1))], dim=-1
)
query_shape = query.shape
query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., : self.rotary_dim]
query_pass = query[..., self.rotary_dim :]
query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., : self.rotary_dim]
key_pass = key[..., self.rotary_dim :]
key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key
@staticmethod
def get_next_input_positions(
context_len: int,
seq_len: int,
xd_sections: int = 4,
) -> list[list[int]]:
return [list(range(context_len, seq_len)) for _ in range(xd_sections)]
@staticmethod
def get_next_input_positions_tensor(
out: np.ndarray,
out_offset: int,
context_len: int,
num_new_tokens: int,
):
values = np.arange(
context_len,
context_len + num_new_tokens,
dtype=out.dtype,
)
out[:, out_offset : out_offset + num_new_tokens] = values
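A self-contained numpy sketch of what `get_next_input_positions_tensor` writes during decode: all XD sections (4 by default, i.e. P/W/H/T) receive the same monotonically increasing text positions for newly generated tokens. Buffer sizes and offsets below are illustrative.

```python
import numpy as np

# Illustrative buffer: 4 XD sections, room for 8 tokens.
out = np.zeros((4, 8), dtype=np.int64)

# Equivalent of get_next_input_positions_tensor(out, out_offset=0,
# context_len=5, num_new_tokens=3).
context_len, num_new_tokens, out_offset = 5, 3, 0
values = np.arange(context_len, context_len + num_new_tokens, dtype=out.dtype)
out[:, out_offset : out_offset + num_new_tokens] = values

print(out[:, :num_new_tokens])
# [[5 6 7]
#  [5 6 7]
#  [5 6 7]
#  [5 6 7]]
```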


@ -576,7 +576,16 @@ class HunYuanDecoderLayer(nn.Module):
return hidden_states, residual, ori_kv_states
@support_torch_compile
@support_torch_compile(
dynamic_arg_dims={
"input_ids": 0,
# positions is of shape (xd, seq_len) if xdrope is enabled for hunyuan-vl,
# otherwise (seq_len, ).
"positions": -1,
"intermediate_tensors": 0,
"inputs_embeds": 0,
}
)
class HunYuanModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()

File diff suppressed because it is too large.


@ -1047,7 +1047,7 @@ class SupportsMRoPE(Protocol):
supports_mrope: ClassVar[Literal[True]] = True
"""
A flag that indicates this model supports M-RoPE.
Note:
There is no need to redefine this flag if this class is in the
MRO of your model class.
@ -1088,3 +1088,52 @@ def supports_mrope(
model: type[object] | object,
) -> TypeIs[type[SupportsMRoPE]] | TypeIs[SupportsMRoPE]:
return isinstance(model, SupportsMRoPE)
@runtime_checkable
class SupportsXDRoPE(Protocol):
"""The interface required for all models that support XD-RoPE."""
supports_xdrope: ClassVar[Literal[True]] = True
"""
A flag that indicates this model supports XD-RoPE.
Note:
There is no need to redefine this flag if this class is in the
MRO of your model class.
"""
def get_xdrope_input_positions(
self,
input_tokens: list[int],
mm_features: list["MultiModalFeatureSpec"],
) -> torch.Tensor:
"""
Get XD-RoPE input positions and delta value for this specific model.
This method should be implemented by each model that supports XD-RoPE
to provide model-specific logic for computing input positions.
Args:
input_tokens: List of input token IDs
mm_features: Information about each multi-modal data item
Returns:
llm_positions: Tensor of shape `[xdrope_dim, num_tokens]` with
4D (P/W/H/T) or 3D (W/H/T) positions.
"""
...
@overload
def supports_xdrope(model: type[object]) -> TypeIs[type[SupportsXDRoPE]]: ...
@overload
def supports_xdrope(model: object) -> TypeIs[SupportsXDRoPE]: ...
def supports_xdrope(
model: type[object] | object,
) -> TypeIs[type[SupportsXDRoPE]] | TypeIs[SupportsXDRoPE]:
return isinstance(model, SupportsXDRoPE)
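A hedged sketch of how a model class might declare XD-RoPE support via this protocol; `MyVLModel` and its text-only fallback positions are illustrative and not part of the PR.

```python
import torch

from vllm.model_executor.models.interfaces import SupportsXDRoPE, supports_xdrope


class MyVLModel(SupportsXDRoPE):  # hypothetical model class
    def get_xdrope_input_positions(self, input_tokens, mm_features):
        # Text-only fallback: identical P/W/H/T positions for every token.
        num_tokens = len(input_tokens)
        return torch.arange(num_tokens).unsqueeze(0).expand(4, -1)


assert supports_xdrope(MyVLModel)  # runtime_checkable Protocol check
```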


@ -287,6 +287,10 @@ _MULTIMODAL_MODELS = {
"GraniteSpeechForConditionalGeneration",
),
"H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
"HunYuanVLForConditionalGeneration": (
"hunyuan_vision",
"HunYuanVLForConditionalGeneration",
),
"InternVLChatModel": ("internvl", "InternVLChatModel"),
"NemotronH_Nano_VL_V2": ("nano_nemotron_vl", "NemotronH_Nano_VL_V2"),
"OpenCUAForConditionalGeneration": (


@ -86,6 +86,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
deepseek_vl_v2="DeepseekVLV2Config",
deepseek_v32="DeepseekV3Config",
flex_olmo="FlexOlmoConfig",
hunyuan_vl="HunYuanVLConfig",
kimi_linear="KimiLinearConfig",
kimi_vl="KimiVLConfig",
RefinedWeb="RWConfig", # For tiiuae/falcon-40b(-instruct)
@ -549,6 +550,23 @@ def thinker_uses_mrope(config: PretrainedConfig) -> bool:
return uses_mrope(thinker_text_config)
def uses_xdrope_dim(config: PretrainedConfig) -> int:
"""Return the number of XD-RoPE sections for this config (0 if unused)."""
xdrope_section = getattr(config, "xdrope_section", None)
if xdrope_section is not None and isinstance(xdrope_section, list):
return len(xdrope_section)
rope_scaling = getattr(config, "rope_scaling", None)
if rope_scaling is None:
return 0
if isinstance(rope_scaling, dict) and "xdrope_section" in rope_scaling:
xdrope_section = rope_scaling["xdrope_section"]
if xdrope_section is not None and isinstance(xdrope_section, list):
return len(xdrope_section)
return 0
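An illustrative check of the detection logic above; the `rope_scaling` values are made-up placeholders rather than the real HunyuanOCR configuration.

```python
from transformers import PretrainedConfig

from vllm.transformers_utils.config import uses_xdrope_dim

cfg = PretrainedConfig()
# Hypothetical rope_scaling block for illustration only.
cfg.rope_scaling = {
    "type": "xdrope",
    "alpha": 1000.0,
    "xdrope_section": [16, 24, 24, 64],
}
print(uses_xdrope_dim(cfg))  # 4 -> four position dimensions (P/W/H/T)
```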
def is_encoder_decoder(config: PretrainedConfig) -> bool:
"""Detect if the model with this config is used as an encoder/decoder."""


@ -23,6 +23,11 @@ from vllm.transformers_utils.configs.eagle import EAGLEConfig
# `FalconConfig` class from the official HuggingFace transformers library.
from vllm.transformers_utils.configs.falcon import RWConfig
from vllm.transformers_utils.configs.flex_olmo import FlexOlmoConfig
from vllm.transformers_utils.configs.hunyuan_vl import (
HunYuanVLConfig,
HunYuanVLTextConfig,
HunYuanVLVisionConfig,
)
from vllm.transformers_utils.configs.jais import JAISConfig
from vllm.transformers_utils.configs.kimi_linear import KimiLinearConfig
from vllm.transformers_utils.configs.kimi_vl import KimiVLConfig
@ -53,6 +58,9 @@ __all__ = [
"DotsOCRConfig",
"EAGLEConfig",
"FlexOlmoConfig",
"HunYuanVLConfig",
"HunYuanVLTextConfig",
"HunYuanVLVisionConfig",
"RWConfig",
"JAISConfig",
"Lfm2MoeConfig",


@ -0,0 +1,322 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# adapted from https://github.com/ManaEstras/transformers/blob/v4.57.1.hyvl/src/transformers/models/hunyuan_vl/configuration_hunyuan_vl.py
from transformers import PretrainedConfig
class HunYuanVLVisionConfig(PretrainedConfig):
model_type = "hunyuan_vl"
base_config_key = "vision_config"
def __init__(
self,
hidden_act="gelu",
hidden_size=1152,
intermediate_size=4304,
interpolate_mode="bilinear",
rms_norm_eps=1e-05,
learnable_mlp_pooling_size=0,
num_attention_heads=16,
num_key_value_heads=None,
num_channels=3,
num_hidden_layers=27,
out_hidden_size=4096,
patch_size=16,
remove_prenorm=True,
spatial_merge_size=2,
temporal_patch_size=1,
resize_resolution=2048,
img_max_token_num=4096,
max_image_size=2048,
video_max_image_size=768,
video_min_image_size=256,
min_image_size=512,
anyres_vit_max_image_size=2048,
max_vit_seq_len=16384,
text_hidden_size=3072,
**kwargs,
):
super().__init__(**kwargs)
self.hidden_act = hidden_act
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.interpolate_mode = interpolate_mode
self.learnable_mlp_pooling_size = learnable_mlp_pooling_size
self.num_attention_heads = num_attention_heads
if not num_key_value_heads:
self.num_key_value_heads = num_attention_heads
else:
self.num_key_value_heads = num_key_value_heads
self.num_channels = num_channels
self.num_hidden_layers = num_hidden_layers
self.out_hidden_size = out_hidden_size
self.patch_size = patch_size
self.remove_prenorm = remove_prenorm
self.spatial_merge_size = spatial_merge_size
self.temporal_patch_size = temporal_patch_size
self.rms_norm_eps = rms_norm_eps
self.resize_resolution = resize_resolution
self.img_max_token_num = img_max_token_num
self.max_image_size = max_image_size
self.min_image_size = min_image_size
self.video_max_image_size = video_max_image_size
self.video_min_image_size = video_min_image_size
self.anyres_vit_max_image_size = anyres_vit_max_image_size
self.max_vit_seq_len = max_vit_seq_len
self.text_hidden_size = text_hidden_size
class HunYuanVLTextConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`HunYuanVLTextConfig`]. It is used to instantiate a
HunYuan model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a configuration similar to that of
Hunyuan-7B-Instruct [tencent/Hunyuan-7B-Instruct](https://huggingface.co/tencent/Hunyuan-7B-Instruct).
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 290943):
Vocabulary size of the HunYuan model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`HunYuanVLTextConfig`]
hidden_size (`int`, *optional*, defaults to 4096):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 11008):
Dimension of the MLP representations or shared MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (`int`, *optional*):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if
`num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
`num_attention_heads`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 2048):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
pad_token_id (`int`, *optional*, defaults to 0):
Padding token id.
bos_token_id (`int`, *optional*, defaults to 1):
Beginning of stream token id.
eos_token_id (`int`, *optional*, defaults to 2):
End of stream token id.
eod_token_id (int, *optional*, defaults to 3):
Token ID representing the end-of-document marker. Used to indicate the termination of a text sequence.
Example: In multi-document processing, this token helps the model distinguish between separate documents.
pretraining_tp (`int`, *optional*, defaults to 1):
Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
issue](https://github.com/pytorch/pytorch/issues/76232).
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether to tie weight embeddings
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
these scaling strategies behave:
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
experimental feature, subject to breaking API changes in future versions.
attention_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
head_dim (`int`, *optional*):
The attention head dimension.
""" # noqa: E501
model_type = "hunyuan_vl_text"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=290943,
hidden_size=4096,
intermediate_size: int = 11008,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=None,
hidden_act="silu",
max_position_embeddings=2048,
initializer_range=0.02,
rms_norm_eps=1e-5,
use_cache=True,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
eod_token_id=3,
pretraining_tp=1,
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
head_dim=None,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.head_dim = head_dim
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.pretraining_tp = pretraining_tp
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
# self._rope_scaling_validation() # TODO: Need validation?
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
"""
if self.rope_scaling is None:
return
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
raise ValueError(
"`rope_scaling` must be a dictionary with two fields, `type` and "
f"`factor` or `type` and `alpha`, got {self.rope_scaling}"
)
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_factor = self.rope_scaling.get("factor", None)
rope_scaling_alpha = self.rope_scaling.get("alpha", None)
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
raise ValueError(
"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], "
f"got {rope_scaling_type}"
)
if rope_scaling_factor is None and rope_scaling_alpha is None:
raise ValueError(
"`rope_scaling` must have either a `factor` or an `alpha` field, "
"got neither"
)
if rope_scaling_factor is not None and (
not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0
):
raise ValueError(
"`rope_scaling`'s factor field must be a float > 1.0, "
f"got {rope_scaling_factor}"
)
if rope_scaling_alpha is not None and (
not isinstance(rope_scaling_alpha, float) or rope_scaling_alpha <= 1.0
):
raise ValueError(
"`rope_scaling`'s alpha field must be a float > 1.0, "
f"got {rope_scaling_alpha}"
)
class HunYuanVLConfig(PretrainedConfig):
model_type = "hunyuan_vl"
sub_configs = {
"vision_config": HunYuanVLVisionConfig,
"text_config": HunYuanVLTextConfig,
}
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
text_config=None,
vision_config=None,
im_start_id=120118,
im_end_id=120119,
image_token_id=120120,
im_newline_id=120121,
video_start_id=120122,
video_end_id=120123,
**kwargs,
):
# We need to init super() here so that it does not reset values
# that are in text config to the BaseClass defaults. The Base
# config has many text related defaults and not all defaults are
# same as for `HunYuanVLTextConfig`.
super().__init__(**kwargs)
if isinstance(vision_config, dict):
self.vision_config = self.sub_configs["vision_config"](**vision_config)
elif vision_config is None:
self.vision_config = self.sub_configs["vision_config"]()
if isinstance(text_config, dict):
self.text_config = self.sub_configs["text_config"](**text_config)
elif text_config is None:
# For BC use all kwargs to init `TextConfig`
self.text_config = self.sub_configs["text_config"](**kwargs)
self.image_token_id = image_token_id
self.im_start_id = im_start_id
self.im_end_id = im_end_id
self.im_newline_id = im_newline_id
self.video_start_id = video_start_id
self.video_end_id = video_end_id
self.vision_config.text_hidden_size = self.text_config.hidden_size
# Attention implementation to use. It sets it recursively on sub-configs
# so we call it again in the end.
self._attn_implementation = kwargs.pop("attn_implementation", None)
def __setattr__(self, key, value):
if (
(text_config := super().__getattribute__("__dict__").get("text_config"))
is not None
and key not in ["dtype", "_attn_implementation_internal"]
and key in text_config.__dict__
):
setattr(text_config, key, value)
else:
super().__setattr__(key, value)
def __getattribute__(self, key):
if "text_config" in super().__getattribute__("__dict__") and key not in [
"_name_or_path",
"model_type",
"dtype",
"_attn_implementation_internal",
]:
text_config = super().__getattribute__("text_config")
if key in text_config.__dict__:
return getattr(text_config, key)
return super().__getattribute__(key)
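A hedged usage sketch of the attribute delegation implemented by `__setattr__`/`__getattribute__` above, assuming a vLLM build that includes this PR; the `hidden_size` values are arbitrary.

```python
from vllm.transformers_utils.configs import HunYuanVLConfig

cfg = HunYuanVLConfig(text_config={"hidden_size": 3072})
# Reads of text-config keys fall through to the nested text config.
assert cfg.hidden_size == 3072
# Writes to keys owned by the text config are forwarded to it as well.
cfg.hidden_size = 4096
assert cfg.text_config.hidden_size == 4096
```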


@ -9,7 +9,15 @@ reasons:
"""
from vllm.transformers_utils.processors.deepseek_vl2 import DeepseekVLV2Processor
from vllm.transformers_utils.processors.hunyuan_vl import HunYuanVLProcessor
from vllm.transformers_utils.processors.hunyuan_vl_image import HunYuanVLImageProcessor
from vllm.transformers_utils.processors.ovis import OvisProcessor
from vllm.transformers_utils.processors.ovis2_5 import Ovis2_5Processor
__all__ = ["DeepseekVLV2Processor", "OvisProcessor", "Ovis2_5Processor"]
__all__ = [
"DeepseekVLV2Processor",
"HunYuanVLProcessor",
"HunYuanVLImageProcessor",
"OvisProcessor",
"Ovis2_5Processor",
]


@ -0,0 +1,233 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# adapted from https://github.com/ManaEstras/transformers/blob/v4.57.1.hyvl/src/transformers/models/hunyuan_vl/processing_hunyuan_vl.py
import numpy as np
import torch
from transformers import AutoProcessor
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput
from transformers.processing_utils import ProcessorMixin
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from transformers.video_utils import VideoInput
class HunYuanVLProcessor(ProcessorMixin):
attributes = ["image_processor", "tokenizer"]
valid_kwargs = ["chat_template"]
image_processor_class = "AutoImageProcessor"
tokenizer_class = "AutoTokenizer" # ("AutoTokenizer", None)
def __init__(
self,
image_processor=None,
tokenizer=None,
video_processor=None,
chat_template=None,
**kwargs,
):
# TODO Fix the init
self.tokenizer = tokenizer
self.image_token_id = 120120 # self.tokenizer.image_token_id
self.image_token = self.tokenizer.convert_ids_to_tokens(self.image_token_id)
self.im_start_token_id = 120118 # self.tokenizer.im_start_id
self.im_start_token = self.tokenizer.convert_ids_to_tokens(
self.im_start_token_id
)
self.im_end_token_id = 120119 # self.tokenizer.im_end_id
self.im_end_token = self.tokenizer.convert_ids_to_tokens(self.im_end_token_id)
self.placeholder_token = self.tokenizer.convert_ids_to_tokens(
self.tokenizer.vocab_size - 1
)
self.pad_id = 120002 # self.tokenizer.pad_token_id
super().__init__(
image_processor, tokenizer, video_processor, chat_template=chat_template
)
def __call__(
self,
images: ImageInput = None,
text: TextInput
| PreTokenizedInput
| list[TextInput]
| list[PreTokenizedInput] = None,
videos: VideoInput = None,
**kwargs,
) -> BatchFeature:
image_inputs = {}
if images is not None:
image_inputs = self.image_processor(images=images)
image_grid_thw = image_inputs["image_grid_thw"]
if not isinstance(text, list):
text = [text]
text = text.copy() # below lines change text in-place
image_tokens_cumsum = [0]
if images is not None:
index = 0
for i in range(len(text)):
while self.image_token in text[i]:
grid_h, grid_w = image_grid_thw[index][-2:]
patch_h = grid_h // self.image_processor.merge_size
patch_w = grid_w // self.image_processor.merge_size
num_image_tokens = patch_h * (patch_w + 1) + 2
image_tokens_cumsum.append(
image_tokens_cumsum[-1] + num_image_tokens
)
# text[i] = text[i].replace(self.image_token, self.im_start_token + self.placeholder_token * num_image_tokens + self.im_end_token, 1) # noqa: E501
text[i] = text[i].replace(
self.image_token, self.placeholder_token * num_image_tokens, 1
)
index += 1
text[i] = text[i].replace(self.placeholder_token, self.image_token)
# text[i] = self.tokenizer.bos_token + text[i]
text_inputs = self.tokenizer(text, add_special_tokens=False, **kwargs)
self._check_special_mm_tokens(text, text_inputs, modalities=["image"])
input_ids = text_inputs["input_ids"]
position_ids = torch.arange(len(input_ids[0]))
position_ids_w = torch.arange(len(input_ids[0]))
position_ids_h = torch.arange(len(input_ids[0]))
position_ids_t = torch.arange(len(input_ids[0]))
if images is not None:
image_token_pos_indices = torch.where(input_ids[0] == self.image_token_id)[
0
]
for i in range(len(image_grid_thw)):
grid_h, grid_w = image_grid_thw[i][-2:]
patch_h = grid_h // self.image_processor.merge_size
patch_w = grid_w // self.image_processor.merge_size
start_pos = image_token_pos_indices[image_tokens_cumsum[i]].item() + 1
replace_num = (patch_w + 1) * patch_h
position_ids_w[start_pos : start_pos + replace_num] = torch.tensor(
list(range(patch_w + 1)) * patch_h, dtype=torch.int64
)
patch_h_list = []
for h in range(patch_h):
patch_h_list += [h] * (patch_w + 1)
position_ids_h[start_pos : start_pos + replace_num] = torch.tensor(
patch_h_list, dtype=torch.int64
)
position_ids_t[start_pos : start_pos + replace_num] = 0
position_ids = torch.stack(
[position_ids, position_ids_w, position_ids_h, position_ids_t]
).unsqueeze(0)
text_inputs["position_ids"] = position_ids
attention_mask = input_ids.ne(self.pad_id)
text_inputs["attention_mask"] = attention_mask
text_inputs["imgs_pos"] = [self.get_imgs_pos(input_ids)]
# image_inputs["imgs"] = [[image_inputs["pixel_values"]]]
return_tensors = kwargs.pop("return_tensors", None)
return BatchFeature(
data={**text_inputs, **image_inputs},
tensor_type=return_tensors,
)
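A worked example of the per-image token count computed in `__call__` above, with illustrative grid values (a 992x704 image at `patch_size=16`): each merged row contributes `patch_w + 1` tokens (the extra slot presumably being the image-newline token), plus two image boundary tokens.

```python
# Illustrative numbers only: image_grid_thw[-2:] == (62, 44), merge_size == 2.
merge_size = 2
grid_h, grid_w = 62, 44
patch_h, patch_w = grid_h // merge_size, grid_w // merge_size  # 31, 22
num_image_tokens = patch_h * (patch_w + 1) + 2                 # 31 * 23 + 2
print(num_image_tokens)  # 715
```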
def batch_decode(self, *args, **kwargs):
return self.tokenizer.batch_decode(*args, **kwargs)
def decode(self, *args, **kwargs):
return self.tokenizer.decode(*args, **kwargs)
def post_process_image_text_to_text(
self,
generated_outputs,
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
**kwargs,
):
raise NotImplementedError
def apply_chat_template(self, *args, **kwargs):
token_ids = self.tokenizer.apply_chat_template(*args, **kwargs)
return token_ids
def get_imgs_pos(self, doc_ids):
doc_ids = np.array(doc_ids, dtype=np.int64)
img_begin_index = np.where(doc_ids == self.im_start_token_id)[0]
img_end_index = np.where(doc_ids == self.im_end_token_id)[0]
imgs_pos = np.concatenate(
(
np.reshape(img_begin_index + 1, (-1, 1)),
np.reshape(img_end_index, (-1, 1)),
),
axis=-1,
).tolist()
return imgs_pos
@property
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
def split_image_into_patch_blocks(
pixel_values: torch.Tensor, # shape: [batch_size, 3, H, W]
patch_size: int = 16, # e.g. 16
adaptor_patch_div: int = 4, # e.g. 4 --> each patch_size is cut into 4x4 small regions, i.e. patch_size // 4 # noqa: E501
) -> torch.Tensor:
"""
Split the input image tensor (supporting batch) into large patches of size `patch_size`,
and then further divide each large patch into smaller regions of size
(patch_size // adaptor_patch_div) x (patch_size // adaptor_patch_div).
Each small region is extracted as a tensor of shape [3, patch_size, patch_size].
The final output contains all such small region tensors.
Args:
pixel_values: Input image tensor of shape [batch_size, 3, H, W].
patch_size: Size of the large patch, e.g., 16.
adaptor_patch_div: Each large patch is divided into
(patch_size // adaptor_patch_div) x (patch_size // adaptor_patch_div)
smaller regions.
Returns:
patches: A tensor of shape [N, 3, patch_size, patch_size],
where N = batch_size * (H // patch_size) * (W // patch_size) * (patch_size // adaptor_patch_div)^2.
Each element in the batch corresponds to one small image region.
""" # noqa: E501
batch_size, channels, height, width = pixel_values.shape
assert channels == 3, "Pixel values must have 3 channels in dim=1"
assert height % patch_size == 0 and width % patch_size == 0, (
"H and W must be divisible by patch_size"
)
patch_height_num = height // patch_size
patch_width_num = width // patch_size
# Reshape to [B, 3, ph, ps, pw, ps]
img = pixel_values.reshape(
batch_size, 3, patch_height_num, patch_size, patch_width_num, patch_size
)
# Further split each psxps patch into (ps//aps)x(ps//aps) small regions
img = img.reshape(
batch_size,
3,
patch_height_num,
patch_size // adaptor_patch_div, # ps // aps
adaptor_patch_div,
patch_width_num,
patch_size // adaptor_patch_div, # ps // aps
adaptor_patch_div,
)
# Permute to group the small regions: [B, ph, pw, ps//aps, ps//aps, 3, aps, aps]
img = img.permute(0, 2, 5, 3, 6, 1, 4, 7)
# Reshape into [B * ph * pw * (ps//aps)^2, 3, patch_size, patch_size]
patches = img.reshape(-1, 3, patch_size, patch_size)
return patches
AutoProcessor.register("HunYuanVLProcessor", HunYuanVLProcessor)


@ -0,0 +1,477 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# adapted from https://github.com/ManaEstras/transformers/blob/v4.57.1.hyvl/src/transformers/models/hunyuan_vl/image_processing_hunyuan_vl.py
"""Image processor class for HunYuanVL."""
# isort conflicts with ruff for transformers imports
# isort: skip_file
import math
import numpy as np
import torchvision.transforms as transforms
from transformers import AutoImageProcessor
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.image_transforms import (
convert_to_rgb,
)
from transformers.image_utils import (
OPENAI_CLIP_MEAN,
OPENAI_CLIP_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
make_flat_list_of_images,
make_list_of_images,
valid_images,
validate_preprocess_arguments,
)
from transformers.utils import TensorType, logging
from transformers.video_utils import VideoInput, make_batched_videos
logger = logging.get_logger(__name__)
def smart_resize(
height: int,
width: int,
factor: int = 16,
min_pixels: int = 512 * 512,
max_pixels: int = 2048 * 2048,
):
"""Rescales the image so that the following conditions are met:
1. Both dimensions (height and width) are divisible by 'factor'.
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
3. The aspect ratio of the image is maintained as closely as possible.
"""
if max(height, width) / min(height, width) > 200:
raise ValueError(
"absolute aspect ratio must be smaller than 200, got "
f"{max(height, width) / min(height, width)}"
)
h_bar = round(height / factor) * factor
w_bar = round(width / factor) * factor
if h_bar * w_bar > max_pixels:
beta = math.sqrt((height * width) / max_pixels)
h_bar = max(factor, math.floor(height / beta / factor) * factor)
w_bar = max(factor, math.floor(width / beta / factor) * factor)
elif h_bar * w_bar < min_pixels:
beta = math.sqrt(min_pixels / (height * width))
h_bar = math.ceil(height * beta / factor) * factor
w_bar = math.ceil(width * beta / factor) * factor
return h_bar, w_bar
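A quick worked example of `smart_resize`, using `factor = patch_size * merge_size = 16 * 2 = 32` as in `_preprocess` below; the input size is arbitrary.

```python
# Uses the smart_resize defined above: each side is rounded to the nearest
# multiple of 32, and 992 * 704 = 698_368 pixels already lies inside
# [512 * 512, 2048 * 2048], so no further rescaling is applied.
h_bar, w_bar = smart_resize(height=1000, width=700, factor=32)
print(h_bar, w_bar)  # 992 704
```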
class HunYuanVLImageProcessor(BaseImageProcessor):
model_input_names = [
"pixel_values",
"image_grid_thw",
"pixel_values_videos",
"video_grid_thw",
]
def __init__(
self,
do_resize: bool = True,
size: dict[str, int] | None = None,
resample: PILImageResampling = PILImageResampling.BICUBIC,
do_rescale: bool = True,
rescale_factor: int | float = 1 / 255,
do_normalize: bool = True,
image_mean: float | list[float] | None = None,
image_std: float | list[float] | None = None,
do_convert_rgb: bool = True,
min_pixels: int | None = None,
max_pixels: int | None = None,
patch_size: int = 16,
temporal_patch_size: int = 2,
merge_size: int = 2,
**kwargs,
) -> None:
super().__init__(**kwargs)
if size is not None and (
"shortest_edge" not in size or "longest_edge" not in size
):
raise ValueError(
"size must contain 'shortest_edge' and 'longest_edge' keys."
)
elif size is None:
size = {"shortest_edge": 512 * 512, "longest_edge": 2048 * 2048}
# backward compatibility: override size with min_pixels and max_pixels
# if they are provided.
if min_pixels is not None:
size["shortest_edge"] = min_pixels
if max_pixels is not None:
size["longest_edge"] = max_pixels
self.min_pixels = size["shortest_edge"]
self.max_pixels = size["longest_edge"]
self.size = size
self.do_resize = do_resize
self.resample = resample
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
self.patch_size = patch_size
self.temporal_patch_size = temporal_patch_size
self.merge_size = merge_size
self.do_convert_rgb = do_convert_rgb
# hard-code
def _preprocess(
self,
images: ImageInput | VideoInput,
do_resize: bool | None = None,
size: dict[str, int] | None = None,
resample: PILImageResampling = None,
do_rescale: bool | None = None,
rescale_factor: float | None = None,
do_normalize: bool | None = None,
image_mean: float | list[float] | None = None,
image_std: float | list[float] | None = None,
patch_size: int = 16,
temporal_patch_size: int = 2,
merge_size: int = 2,
do_convert_rgb: bool | None = None,
data_format: ChannelDimension | None = ChannelDimension.FIRST,
input_data_format: str | ChannelDimension | None = None,
):
"""
Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`.
Args:
images (`ImageInput`):
Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
size (`dict[str, int]`, *optional*, defaults to `self.size`):
Size of the image after resizing. `shortest_edge` and `longest_edge` keys must be present.
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
Whether to rescale the image.
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
Scale factor to use if rescaling the image.
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
Whether to normalize the image.
image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
patch_size (`int`, *optional*, defaults to `self.patch_size`):
The spatial patch size of the vision encoder.
temporal_patch_size (`int`, *optional*, defaults to `self.temporal_patch_size`):
The temporal patch size of the vision encoder.
merge_size (`int`, *optional*, defaults to `self.merge_size`):
The merge size of the vision encoder to llm encoder.
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
Whether to convert the image to RGB.
data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: Use the channel dimension format of the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
""" # noqa: E501
images = make_list_of_images(images)
if do_convert_rgb:
images = [convert_to_rgb(image) for image in images]
width, height = images[0].width, images[0].height
resized_width, resized_height = width, height
processed_images = []
for image in images:
if do_resize:
resized_width, resized_height = smart_resize(
width,
height,
factor=patch_size * merge_size,
min_pixels=self.min_pixels,
max_pixels=self.max_pixels,
)
image = image.resize((resized_width, resized_height))
if do_normalize:
image = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize(self.image_mean, self.image_std),
]
)(image)
processed_images.append(image)
patches = np.array(processed_images)
channel = patches.shape[1]
grid_t = patches.shape[0] // temporal_patch_size
grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
patches = patches.reshape(
1,
channel,
grid_h // merge_size,
merge_size,
patch_size,
grid_w // merge_size,
merge_size,
patch_size,
)
patches = patches.transpose(0, 2, 3, 5, 6, 1, 4, 7)
flatten_patches = patches.reshape(
1 * grid_h * grid_w, channel * patch_size * patch_size
)
return flatten_patches, (grid_t, grid_h, grid_w)
def preprocess(
self,
images: ImageInput,
videos: VideoInput = None,
do_resize: bool | None = None,
size: dict[str, int] | None = None,
min_pixels: int | None = None,
max_pixels: int | None = None,
resample: PILImageResampling = None,
do_rescale: bool | None = None,
rescale_factor: float | None = None,
do_normalize: bool | None = None,
image_mean: float | list[float] | None = None,
image_std: float | list[float] | None = None,
patch_size: int | None = None,
temporal_patch_size: int | None = None,
merge_size: int | None = None,
do_convert_rgb: bool | None = None,
return_tensors: str | TensorType | None = None,
data_format: ChannelDimension | None = ChannelDimension.FIRST,
input_data_format: str | ChannelDimension | None = None,
):
"""
Args:
images (`ImageInput`):
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
videos (`VideoInput`):
Video to preprocess. Expects a single or batch of videos with pixel values ranging from 0 to 255. If
passing in videos with pixel values between 0 and 1, set `do_rescale=False`.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
size (`dict[str, int]`, *optional*, defaults to `self.size`):
Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
the longest edge resized to keep the input aspect ratio.
resample (`int`, *optional*, defaults to `self.resample`):
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
has an effect if `do_resize` is set to `True`.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
Whether to rescale the image.
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
Whether to normalize the image.
image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
`True`.
min_pixels (`int`, *optional*, defaults to `self.min_pixels`):
The min pixels of the image to resize the image.
max_pixels (`int`, *optional*, defaults to `self.max_pixels`):
The max pixels of the image to resize the image.
patch_size (`int`, *optional*, defaults to `self.patch_size`):
The spatial patch size of the vision encoder.
temporal_patch_size (`int`, *optional*, defaults to `self.temporal_patch_size`):
The temporal patch size of the vision encoder.
merge_size (`int`, *optional*, defaults to `self.merge_size`):
The merge size of the vision encoder to llm encoder.
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
Whether to convert the image to RGB.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: Use the channel dimension format of the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
""" # noqa: E501
min_pixels = min_pixels if min_pixels is not None else self.min_pixels
max_pixels = max_pixels if max_pixels is not None else self.max_pixels
if size is not None:
if "shortest_edge" not in size or "longest_edge" not in size:
raise ValueError(
"size must contain 'shortest_edge' and 'longest_edge' keys."
)
min_pixels = size["shortest_edge"]
elif min_pixels is not None and max_pixels is not None:
# backward compatibility: override size with min_pixels and max_pixels
# if they are provided.
size = {"shortest_edge": min_pixels, "longest_edge": max_pixels}
else:
size = {**self.size}
do_resize = do_resize if do_resize is not None else self.do_resize
resample = resample if resample is not None else self.resample
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = (
rescale_factor if rescale_factor is not None else self.rescale_factor
)
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
patch_size = patch_size if patch_size is not None else self.patch_size
temporal_patch_size = (
temporal_patch_size
if temporal_patch_size is not None
else self.temporal_patch_size
)
merge_size = merge_size if merge_size is not None else self.merge_size
do_convert_rgb = (
do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
)
if images is not None:
images = make_flat_list_of_images(images)
if images is not None and not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
validate_preprocess_arguments(
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_resize=do_resize,
size=size,
resample=resample,
)
data = {}
if images is not None:
pixel_values, vision_grid_thws = [], []
for image in images:
patches, image_grid_thw = self._preprocess(
image,
do_resize=do_resize,
size=size,
resample=resample,
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
patch_size=patch_size,
temporal_patch_size=temporal_patch_size,
merge_size=merge_size,
data_format=data_format,
do_convert_rgb=do_convert_rgb,
input_data_format=input_data_format,
)
pixel_values.extend(patches)
vision_grid_thws.append(image_grid_thw)
pixel_values = np.array(pixel_values)
vision_grid_thws = np.array(vision_grid_thws)
data.update(
{"pixel_values": pixel_values, "image_grid_thw": vision_grid_thws}
)
# kept for BC only and should be removed after v5.0
if videos is not None:
logger.warning(
"`HunYuanVLImageProcessor` works only with image inputs "
"and doesn't process videos anymore. "
"This is a deprecated behavior and will be removed in v5.0. "
"Your videos should be forwarded to `HunYuanVLV1VideoProcessor`. "
)
videos = make_batched_videos(videos)
pixel_values_videos, vision_grid_thws_videos = [], []
for images in videos:
patches, video_grid_thw = self._preprocess(
images,
do_resize=do_resize,
size=size,
resample=resample,
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
patch_size=patch_size,
temporal_patch_size=temporal_patch_size,
merge_size=merge_size,
data_format=data_format,
do_convert_rgb=do_convert_rgb,
input_data_format=input_data_format,
)
pixel_values_videos.extend(patches)
vision_grid_thws_videos.append(video_grid_thw)
data.update(
{
"pixel_values_videos": np.array(pixel_values_videos),
"video_grid_thw": np.array(vision_grid_thws_videos),
}
)
return BatchFeature(data=data, tensor_type=return_tensors)
def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
"""
A utility that returns number of image patches for a given image size.
Args:
height (`int`):
Height of the input image.
width (`int`):
Width of the input image.
images_kwargs (`dict`, *optional*):
Any kwargs to override defaults of the image processor.
Returns:
`int`: Number of image patches per image.
"""
images_kwargs = images_kwargs or {}
min_pixels = (
images_kwargs["min_pixels"]
if "min_pixels" in images_kwargs
else self.size["shortest_edge"]
)
max_pixels = (
images_kwargs["max_pixels"]
if "max_pixels" in images_kwargs
else self.size["longest_edge"]
)
patch_size = images_kwargs.get("patch_size", self.patch_size)
merge_size = images_kwargs.get("merge_size", self.merge_size)
factor = patch_size * merge_size
resized_height, resized_width = smart_resize(
height, width, factor, min_pixels=min_pixels, max_pixels=max_pixels
)
grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
return grid_h * (grid_w + 1) + 2
AutoImageProcessor.register("HunYuanVLImageProcessor", HunYuanVLImageProcessor)


@ -43,6 +43,8 @@ class CachedRequestState:
mrope_positions: torch.Tensor | None = None
mrope_position_delta: int | None = None
xdrope_positions: torch.Tensor | None = None
lora_request: LoRARequest | None = None
prompt_embeds: torch.Tensor | None = None


@ -50,16 +50,21 @@ from vllm.distributed.parallel_state import (
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.layers.rotary_embedding import (
MRotaryEmbedding,
XDRotaryEmbedding,
)
from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
from vllm.model_executor.models.interfaces import (
SupportsMRoPE,
SupportsMultiModal,
SupportsXDRoPE,
is_mixture_of_experts,
supports_eagle3,
supports_mrope,
supports_multimodal_pruning,
supports_transcription,
supports_xdrope,
)
from vllm.model_executor.models.interfaces_base import (
VllmModelForPooling,
@ -324,6 +329,7 @@ class GPUModelRunner(
# Multi-modal data support
self.mm_registry = MULTIMODAL_REGISTRY
self.uses_mrope = model_config.uses_mrope
self.uses_xdrope_dim = model_config.uses_xdrope_dim
self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
model_config
)
@ -512,6 +518,13 @@ class GPUModelRunner(
(3, self.max_num_tokens + 1), dtype=torch.int64
)
# Only relevant for models using XD-RoPE (e.g., HunYuan-VL)
if self.uses_xdrope_dim > 0:
# Similar to M-RoPE, but with a model-specified number of position dimensions (4 by default).
self.xdrope_positions = self._make_buffer(
(self.uses_xdrope_dim, self.max_num_tokens + 1), dtype=torch.int64
)
# None in the first PP rank. The rest are set after load_model.
self.intermediate_tensors: IntermediateTensors | None = None
@ -593,10 +606,14 @@ class GPUModelRunner(
if isinstance(num_tokens, int):
if self.uses_mrope:
return self.mrope_positions.gpu[:, :num_tokens]
if self.uses_xdrope_dim > 0:
return self.xdrope_positions.gpu[:, :num_tokens]
return self.positions.gpu[:num_tokens]
else:
if self.uses_mrope:
return self.mrope_positions.gpu[:, num_tokens]
if self.uses_xdrope_dim > 0:
return self.xdrope_positions.gpu[:, num_tokens]
return self.positions.gpu[num_tokens]
def _make_buffer(
@ -772,6 +789,10 @@ class GPUModelRunner(
if self.uses_mrope:
self._init_mrope_positions(req_state)
# Only relevant for models using XD-RoPE (e.g., HunYuan-VL)
if self.uses_xdrope_dim > 0:
self._init_xdrope_positions(req_state)
reqs_to_add.append(req_state)
# Update the states of the running/resumed requests.
@ -987,6 +1008,19 @@ class GPUModelRunner(
)
)
def _init_xdrope_positions(self, req_state: CachedRequestState):
model = self.get_model()
xdrope_model = cast(SupportsXDRoPE, model)
assert req_state.prompt_token_ids is not None, (
"XD-RoPE requires prompt_token_ids to be available."
)
assert supports_xdrope(model), "XD-RoPE support is not implemented."
req_state.xdrope_positions = xdrope_model.get_xdrope_input_positions(
req_state.prompt_token_ids,
req_state.mm_features,
)
def _extract_mm_kwargs(
self,
scheduler_output: "SchedulerOutput",
@ -1231,6 +1265,11 @@ class GPUModelRunner(
if self.uses_mrope:
self._calc_mrope_positions(scheduler_output)
# Calculate XD-RoPE positions.
# Only relevant for models using XD-RoPE (e.g., HunYuan-VL)
if self.uses_xdrope_dim > 0:
self._calc_xdrope_positions(scheduler_output)
# Get token indices.
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
@ -1364,6 +1403,12 @@ class GPUModelRunner(
self.mrope_positions.cpu[:, :total_num_scheduled_tokens],
non_blocking=True,
)
elif self.uses_xdrope_dim > 0:
# Only relevant for models using XD-RoPE (e.g., HunYuan-VL)
self.xdrope_positions.gpu[:, :total_num_scheduled_tokens].copy_(
self.xdrope_positions.cpu[:, :total_num_scheduled_tokens],
non_blocking=True,
)
else:
# Common case (1D positions)
self.positions.copy_to_gpu(total_num_scheduled_tokens)
@ -1793,6 +1838,53 @@ class GPUModelRunner(
mrope_pos_ptr += completion_part_len
def _calc_xdrope_positions(self, scheduler_output: "SchedulerOutput"):
xdrope_pos_ptr = 0
for index, req_id in enumerate(self.input_batch.req_ids):
req = self.requests[req_id]
assert req.xdrope_positions is not None
num_computed_tokens = self.input_batch.num_computed_tokens_cpu[index]
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
req.prompt_token_ids, req.prompt_embeds
)
if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
prompt_part_len = max(0, num_prompt_tokens - num_computed_tokens)
completion_part_len = max(0, num_scheduled_tokens - prompt_part_len)
else:
prompt_part_len = num_scheduled_tokens
completion_part_len = 0
assert num_scheduled_tokens == prompt_part_len + completion_part_len
if prompt_part_len > 0:
# prompt's xdrope_positions are pre-computed
dst_start = xdrope_pos_ptr
dst_end = xdrope_pos_ptr + prompt_part_len
src_start = num_computed_tokens
src_end = num_computed_tokens + prompt_part_len
self.xdrope_positions.cpu[:, dst_start:dst_end] = req.xdrope_positions[
:, src_start:src_end
]
xdrope_pos_ptr += prompt_part_len
if completion_part_len > 0:
# compute completion's xdrope_positions on-the-fly
dst_start = xdrope_pos_ptr
dst_end = xdrope_pos_ptr + completion_part_len
XDRotaryEmbedding.get_next_input_positions_tensor(
out=self.xdrope_positions.np,
out_offset=dst_start,
context_len=num_computed_tokens + prompt_part_len,
num_new_tokens=completion_part_len,
)
xdrope_pos_ptr += completion_part_len
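A tiny worked example of the prompt/completion split above, with illustrative numbers: a 10-token prompt, 8 tokens already computed, and 5 tokens scheduled in this step.

```python
num_prompt_tokens, num_computed_tokens, num_scheduled_tokens = 10, 8, 5
if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
    prompt_part_len = max(0, num_prompt_tokens - num_computed_tokens)     # 2
    completion_part_len = max(0, num_scheduled_tokens - prompt_part_len)  # 3
else:
    prompt_part_len, completion_part_len = num_scheduled_tokens, 0
print(prompt_part_len, completion_part_len)  # 2 3
# The 2 remaining prompt positions are copied from the precomputed
# xdrope_positions; the 3 completion positions are generated on the fly
# starting at context_len = 8 + 2 = 10.
```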
def _calc_spec_decode_metadata(
self,
num_draft_tokens: np.ndarray,
@ -2037,6 +2129,7 @@ class GPUModelRunner(
req_start_idx = 0
should_sync_mrope_positions = False
should_sync_xdrope_positions = False
for req_id in self.input_batch.req_ids:
mm_embeds_req: list[torch.Tensor] = []
@ -2110,6 +2203,10 @@ class GPUModelRunner(
self._calc_mrope_positions(scheduler_output)
self.mrope_positions.copy_to_gpu(total_num_scheduled_tokens)
if should_sync_xdrope_positions:
self._calc_xdrope_positions(scheduler_output)
self.xdrope_positions.copy_to_gpu(total_num_scheduled_tokens)
return mm_embeds, is_mm_embed
def get_model(self) -> nn.Module:
@ -2384,8 +2481,11 @@ class GPUModelRunner(
input_ids = self.input_ids.gpu[:num_input_tokens]
inputs_embeds = None
model_kwargs = self._init_model_kwargs(num_input_tokens)
if self.uses_mrope:
positions = self.mrope_positions.gpu[:, :num_input_tokens]
elif self.uses_xdrope_dim > 0:
positions = self.xdrope_positions.gpu[:, :num_input_tokens]
else:
positions = self.positions.gpu[:num_input_tokens]
@ -3824,6 +3924,8 @@ class GPUModelRunner(
if self.uses_mrope:
positions = self.mrope_positions.gpu[:, :num_tokens_after_padding]
elif self.uses_xdrope_dim > 0:
positions = self.xdrope_positions.gpu[:, :num_tokens_after_padding]
else:
positions = self.positions.gpu[:num_tokens_after_padding]