From 92effb07a48e56c531a95b696acd5f699baf16da Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Tue, 25 Nov 2025 11:28:51 +0800 Subject: [PATCH] [Model] Add HunyuanOCR support (#29327) Signed-off-by: manayang Signed-off-by: Isotr0py Signed-off-by: Roger Wang Co-authored-by: sergeywang Co-authored-by: manayang Co-authored-by: manayang Co-authored-by: Roger Wang --- docs/models/supported_models.md | 1 + examples/offline_inference/vision_language.py | 26 + tests/models/registry.py | 4 + vllm/config/model.py | 5 + .../layers/rotary_embedding/__init__.py | 13 + .../layers/rotary_embedding/xdrope.py | 102 ++ vllm/model_executor/models/hunyuan_v1.py | 11 +- vllm/model_executor/models/hunyuan_vision.py | 1028 +++++++++++++++++ vllm/model_executor/models/interfaces.py | 51 +- vllm/model_executor/models/registry.py | 4 + vllm/transformers_utils/config.py | 18 + vllm/transformers_utils/configs/__init__.py | 8 + vllm/transformers_utils/configs/hunyuan_vl.py | 322 ++++++ .../transformers_utils/processors/__init__.py | 10 +- .../processors/hunyuan_vl.py | 233 ++++ .../processors/hunyuan_vl_image.py | 477 ++++++++ vllm/v1/worker/gpu_input_batch.py | 2 + vllm/v1/worker/gpu_model_runner.py | 104 +- 18 files changed, 2415 insertions(+), 4 deletions(-) create mode 100644 vllm/model_executor/layers/rotary_embedding/xdrope.py create mode 100644 vllm/model_executor/models/hunyuan_vision.py create mode 100644 vllm/transformers_utils/configs/hunyuan_vl.py create mode 100644 vllm/transformers_utils/processors/hunyuan_vl.py create mode 100644 vllm/transformers_utils/processors/hunyuan_vl_image.py diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 404519f887dc6..25579835faf63 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -680,6 +680,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `Glm4vMoeForConditionalGeneration` | GLM-4.5V | T + IE+ + VE+ | `zai-org/GLM-4.5V`, etc. | ✅︎ | ✅︎ | | `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | | `H2OVLChatModel` | H2OVL | T + IE+ | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ | +| `HunYuanVLForConditionalGeneration` | HunyuanOCR | T + IE+ | `tencent/HunyuanOCR`, etc. | ✅︎ | ✅︎ | | `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3`, etc. | ✅︎ | | | `InternS1ForConditionalGeneration` | Intern-S1 | T + IE+ + VE+ | `internlm/Intern-S1`, `internlm/Intern-S1-mini`, etc. | ✅︎ | ✅︎ | | `InternVLChatModel` | InternVL 3.5, InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + IE+ + (VE+) | `OpenGVLab/InternVL3_5-14B`, `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. 
| ✅︎ | ✅︎ | diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 624de2a2debc3..65ea4df4a3099 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -538,6 +538,31 @@ def run_h2ovl(questions: list[str], modality: str) -> ModelRequestData: ) +# HunyuanOCR +def run_hunyuan_vl(questions: list[str], modality: str) -> ModelRequestData: + assert modality == "image" + + model_name = "tencent/HunyuanOCR" + + engine_args = EngineArgs( + model=model_name, + max_model_len=8192, + limit_mm_per_prompt={modality: 1}, + ) + + placeholder = "<|hy_place▁holder▁no▁100|><|hy_place▁holder▁no▁102|><|hy_place▁holder▁no▁101|>" # noqa: E501 + prompts = [ + f"<|hy_begin▁of▁sentence|>{placeholder}{question}<|hy_User|>" + for question in questions + ] + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + stop_token_ids=None, + ) + + # naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B def run_hyperclovax_seed_vision( questions: list[str], modality: str @@ -1820,6 +1845,7 @@ model_example_map = { "glm4_5v": run_glm4_5v, "glm4_5v_fp8": run_glm4_5v_fp8, "h2ovl_chat": run_h2ovl, + "hunyuan_vl": run_hunyuan_vl, "hyperclovax_seed_vision": run_hyperclovax_seed_vision, "idefics3": run_idefics3, "interns1": run_interns1, diff --git a/tests/models/registry.py b/tests/models/registry.py index 758ec54493aa3..f8b3470e6d39b 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -626,6 +626,10 @@ _MULTIMODAL_EXAMPLE_MODELS = { "naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B", trust_remote_code=True, ), + "HunYuanVLForConditionalGeneration": _HfExamplesInfo( + "tencent/HunyuanOCR", + is_available_online=False, + ), "Idefics3ForConditionalGeneration": _HfExamplesInfo( "HuggingFaceM4/Idefics3-8B-Llama3", extras={"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"}, diff --git a/vllm/config/model.py b/vllm/config/model.py index c37dd7c15f2a7..caa9a3440c41d 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -33,6 +33,7 @@ from vllm.transformers_utils.config import ( try_get_safetensors_metadata, try_get_tokenizer_config, uses_mrope, + uses_xdrope_dim, ) from vllm.transformers_utils.gguf_utils import ( maybe_patch_hf_config_from_gguf, @@ -1615,6 +1616,10 @@ class ModelConfig: def uses_mrope(self) -> bool: return uses_mrope(self.hf_config) + @property + def uses_xdrope_dim(self) -> int: + return uses_xdrope_dim(self.hf_config) + @property def is_multimodal_model(self) -> bool: return self.multimodal_config is not None diff --git a/vllm/model_executor/layers/rotary_embedding/__init__.py b/vllm/model_executor/layers/rotary_embedding/__init__.py index 152d9401b8e94..0f10bff6ac4f5 100644 --- a/vllm/model_executor/layers/rotary_embedding/__init__.py +++ b/vllm/model_executor/layers/rotary_embedding/__init__.py @@ -17,6 +17,7 @@ from .llama4_vision_rope import Llama4VisionRotaryEmbedding from .mrope import MRotaryEmbedding from .ntk_scaling_rope import NTKScalingRotaryEmbedding from .phi3_long_rope_scaled_rope import Phi3LongRoPEScaledRotaryEmbedding +from .xdrope import XDRotaryEmbedding from .yarn_scaling_rope import YaRNScalingRotaryEmbedding _ROPE_DICT: dict[tuple, RotaryEmbedding] = {} @@ -184,6 +185,18 @@ def get_rope( raise ValueError( "Dynamic rope scaling must contain either 'alpha' or 'factor' field" ) + elif scaling_type == "xdrope": + scaling_alpha = rope_parameters["alpha"] + rotary_emb = XDRotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, 
+ is_neox_style, + scaling_alpha, + dtype, + xdrope_section=rope_parameters["xdrope_section"], + ) elif scaling_type == "yarn": scaling_factor = rope_parameters["factor"] original_max_position = rope_parameters["original_max_position_embeddings"] diff --git a/vllm/model_executor/layers/rotary_embedding/xdrope.py b/vllm/model_executor/layers/rotary_embedding/xdrope.py new file mode 100644 index 0000000000000..2432273faf195 --- /dev/null +++ b/vllm/model_executor/layers/rotary_embedding/xdrope.py @@ -0,0 +1,102 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import numpy as np +import torch + +from .common import apply_rotary_emb_dispatch +from .dynamic_ntk_alpha_rope import DynamicNTKAlphaRotaryEmbedding + + +class XDRotaryEmbedding(DynamicNTKAlphaRotaryEmbedding): + """DynamicNTKAlphaRotaryEmbedding extended with MultiModal(XD) Sections. + + Based on the original DynamicNTKAlphaRotaryEmbedding implementation. + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: float, + is_neox_style: bool, + scaling_alpha: float, + dtype: torch.dtype, + xdrope_section: list[int], + ) -> None: + self.xdrope_section = xdrope_section + super().__init__( + head_size, + rotary_dim, + max_position_embeddings, + base, + is_neox_style, + scaling_alpha, + dtype, + ) + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor | None = None, + offsets: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + """PyTorch-native implementation equivalent to forward(). + + Args: + positions: + [4, num_tokens] (P/W/H/T positions with multimodal inputs) + query: [num_tokens, num_heads * head_size] + key: [num_tokens, num_kv_heads * head_size] + """ + assert positions.ndim == 2 + assert key is not None + + num_tokens = positions.shape[-1] + cos_sin = self.cos_sin_cache[positions] + cos, sin = cos_sin.chunk(2, dim=-1) + cos = torch.cat( + [m[i] for i, m in enumerate(cos.split(self.xdrope_section, dim=-1))], dim=-1 + ) + sin = torch.cat( + [m[i] for i, m in enumerate(sin.split(self.xdrope_section, dim=-1))], dim=-1 + ) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + query_rot = query[..., : self.rotary_dim] + query_pass = query[..., self.rotary_dim :] + query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + key_rot = key[..., : self.rotary_dim] + key_pass = key[..., self.rotary_dim :] + key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + + @staticmethod + def get_next_input_positions( + context_len: int, + seq_len: int, + xd_sections: int = 4, + ) -> list[list[int]]: + return [list(range(context_len, seq_len)) for _ in range(xd_sections)] + + @staticmethod + def get_next_input_positions_tensor( + out: np.ndarray, + out_offset: int, + context_len: int, + num_new_tokens: int, + ): + values = np.arange( + context_len, + context_len + num_new_tokens, + dtype=out.dtype, + ) + out[:, out_offset : out_offset + num_new_tokens] = values diff --git a/vllm/model_executor/models/hunyuan_v1.py b/vllm/model_executor/models/hunyuan_v1.py index 9fa5e2bd33f21..53fb444ed622d 100644 --- a/vllm/model_executor/models/hunyuan_v1.py +++ 
b/vllm/model_executor/models/hunyuan_v1.py @@ -576,7 +576,16 @@ class HunYuanDecoderLayer(nn.Module): return hidden_states, residual, ori_kv_states -@support_torch_compile +@support_torch_compile( + dynamic_arg_dims={ + "input_ids": 0, + # positions is of shape (xd, seq_len) if xdrope is enabled for hunyuan-vl, + # otherwise (seq_len, ). + "positions": -1, + "intermediate_tensors": 0, + "inputs_embeds": 0, + } +) class HunYuanModel(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() diff --git a/vllm/model_executor/models/hunyuan_vision.py b/vllm/model_executor/models/hunyuan_vision.py new file mode 100644 index 0000000000000..e83addd0c092f --- /dev/null +++ b/vllm/model_executor/models/hunyuan_vision.py @@ -0,0 +1,1028 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# coding=utf-8 +# Copyright 2025 The HunYuan team. +# Copyright 2025 The vLLM team. +# Copyright 2025 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
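# ---------------------------------------------------------------------------
# Editor's illustrative sketch (not part of the patch): how
# XDRotaryEmbedding.forward() above combines the per-axis rotary tables.
# Positions carry one row per axis (e.g. P/W/H/T), the cos/sin cache is
# gathered once per axis, and `split(xdrope_section, dim=-1)` lets slice i of
# the rotary dimension take its values from axis i. The section sizes below
# are made-up example values, not HunyuanOCR's actual configuration.
import torch

xdrope_section = [8, 4, 4, 4]       # example split of half the rotary dim
rotary_half = sum(xdrope_section)   # 20
num_tokens = 6

# One position row per axis: shape [4, num_tokens].
positions = torch.arange(num_tokens).repeat(4, 1)

# Toy cos table indexed by position: shape [max_position, rotary_half].
cos_table = torch.randn(64, rotary_half)

# Gather per axis: shape [4, num_tokens, rotary_half].
cos = cos_table[positions]

# Slice i of the rotary dim keeps the values gathered with axis i's positions.
cos_mixed = torch.cat(
    [chunk[i] for i, chunk in enumerate(cos.split(xdrope_section, dim=-1))],
    dim=-1,
)
assert cos_mixed.shape == (num_tokens, rotary_half)
# ---------------------------------------------------------------------------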
+"""Inference-only HunYuan-VL model compatible with HuggingFace weights.""" + +from collections.abc import Callable, Iterable, Mapping, Sequence +from functools import partial +from typing import Annotated, Any, Literal, TypeAlias + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import BatchFeature + +from vllm.attention.backends.registry import AttentionBackendEnum +from vllm.attention.layer import MultiHeadAttention +from vllm.config import MultiModalConfig, VllmConfig +from vllm.config.multimodal import BaseDummyOptions +from vllm.distributed import parallel_state +from vllm.distributed import utils as dist_utils +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import ( + ImageItem, + ModalityData, + MultiModalDataDict, + MultiModalFeatureSpec, + MultiModalFieldConfig, + MultiModalKwargsItems, +) +from vllm.multimodal.parse import ( + DictEmbeddingItems, + ImageSize, + MultiModalDataItems, + MultiModalDataParser, +) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, +) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs.hunyuan_vl import ( + HunYuanVLConfig, + HunYuanVLVisionConfig, +) +from vllm.transformers_utils.processors.hunyuan_vl import HunYuanVLProcessor +from vllm.transformers_utils.processors.hunyuan_vl_image import smart_resize +from vllm.utils.tensor_schema import TensorSchema, TensorShape + +from .interfaces import ( + MultiModalEmbeddings, + SupportsLoRA, + SupportsMultiModal, + SupportsPP, + SupportsQuant, + SupportsXDRoPE, +) +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + init_vllm_registered_model, + maybe_prefix, +) + +logger = init_logger(__name__) + +# === Vision Inputs === # + + +class HunYuanVLImagePixelInputs(TensorSchema): + """ + Dimensions: + - np: Number of patches + - ni: Number of images + - cps: Number of channels * patch_size * patch_size + """ + + type: Literal["pixel_values"] + + pixel_values: Annotated[ + torch.Tensor, + TensorShape("np", "cps"), + ] + + image_grid_thw: Annotated[ + torch.Tensor, + TensorShape("ni", 3), + ] + + +class HunYuanVLImageEmbeddingInputs(TensorSchema): + """ + Dimensions: + - nf: Number of image features + - hs: Hidden size + - ni: Number of images + """ + + type: Literal["image_embeds"] + + image_embeds: Annotated[ + torch.Tensor, + TensorShape("nf", "hs"), + ] + + image_grid_thw: Annotated[ + torch.Tensor, + TensorShape("ni", 3), + ] + + +HunYuanVLImageInputs: TypeAlias = ( + HunYuanVLImagePixelInputs | HunYuanVLImageEmbeddingInputs +) + +# === Vision Encoder === # + + +class HunYuanVisionMLP(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: int, + bias: bool = True, + act_fn: Callable[[torch.Tensor], torch.Tensor] = F.gelu, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + use_data_parallel: bool = False, + ): + super().__init__() + 
self.dense_h_to_4h = ColumnParallelLinear( + in_features, + hidden_features, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.dense_h_to_4h", + disable_tp=use_data_parallel, + ) + self.dense_4h_to_h = RowParallelLinear( + hidden_features, + in_features, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.dense_4h_to_h", + disable_tp=use_data_parallel, + ) + self.act_fn = act_fn + + def forward(self, x: torch.Tensor): + x_up, _ = self.dense_h_to_4h(x) + x_down, _ = self.dense_4h_to_h(self.act_fn(x_up)) + return x_down + + +class HunYuanVisionAttention(nn.Module): + def __init__( + self, + embed_dim: int, + num_heads: int, + projection_size: int, + quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, + prefix: str = "", + use_data_parallel: bool = False, + ) -> None: + super().__init__() + # Per attention head and per partition values. + self.tp_size = ( + 1 + if use_data_parallel + else parallel_state.get_tensor_model_parallel_world_size() + ) + self.hidden_size_per_attention_head = dist_utils.divide( + projection_size, num_heads + ) + self.num_attention_heads_per_partition = dist_utils.divide( + num_heads, self.tp_size + ) + + self.qkv = QKVParallelLinear( + hidden_size=embed_dim, + head_size=self.hidden_size_per_attention_head, + total_num_heads=num_heads, + total_num_kv_heads=num_heads, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.qkv", + disable_tp=use_data_parallel, + ) + + self.o_proj = RowParallelLinear( + input_size=projection_size, + output_size=embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + disable_tp=use_data_parallel, + ) + + self.scale = self.hidden_size_per_attention_head**-0.5 + self.attn = MultiHeadAttention( + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + self.scale, + prefix=f"{prefix}.attn", + multimodal_config=multimodal_config, + ) + + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv(x) + q, k, v = qkv.chunk(3, dim=-1) + out = self.attn(q, k, v) + output, _ = self.o_proj(out) + return output + + +class HunYuanVisionBlock(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_hidden_dim: int, + act_fn: Callable[[torch.Tensor], torch.Tensor] = F.gelu, + norm_layer: Callable[[int], nn.Module] | None = None, + quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, + prefix: str = "", + use_data_parallel: bool = False, + ) -> None: + super().__init__() + if norm_layer is None: + norm_layer = partial(nn.LayerNorm, eps=1e-6) + self.input_layernorm = norm_layer(dim) + self.post_attention_layernorm = norm_layer(dim) + self.self_attn = HunYuanVisionAttention( + embed_dim=dim, + num_heads=num_heads, + projection_size=dim, + quant_config=quant_config, + multimodal_config=multimodal_config, + prefix=f"{prefix}.self_attn", + use_data_parallel=use_data_parallel, + ) + self.mlp = HunYuanVisionMLP( + dim, + mlp_hidden_dim, + act_fn=act_fn, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + use_data_parallel=use_data_parallel, + ) + + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + x = x + self.self_attn(self.input_layernorm(x)) + x = x + self.mlp(self.post_attention_layernorm(x)) + return x + + +class HunYuanVisionPatchEmbed(nn.Module): + def __init__(self, config: HunYuanVLVisionConfig): + super().__init__() + + self.config = config + self.embed_dim = config.hidden_size + self.patch_size = config.patch_size + 
self.num_channels = config.num_channels + self.spatial_merge_size = config.spatial_merge_size + self.interpolate_mode = config.interpolate_mode + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=True, + ) + + self.max_num_patches = (config.max_image_size // self.patch_size) ** 2 + + self.num_positions = self.max_num_patches + 1 + self.position_edge = int(self.num_positions**0.5) + # first token is cls token, skip it + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + + self.patch_pos_embed = None + + def forward( + self, pixel_values: torch.Tensor, grid_thw: list[list[int]] + ) -> torch.Tensor: + num_patches = pixel_values.size(0) + pixel_values = pixel_values.reshape( + num_patches, self.num_channels, self.patch_size, self.patch_size + ) + + patch_embeds = self.patch_embedding(pixel_values) + patch_embeds = patch_embeds.squeeze(-1).squeeze(-1).unsqueeze(0) + + if self.patch_pos_embed is None: + patch_pos_shape = ( + 1, + self.position_edge, + self.position_edge, + self.embed_dim, + ) + self.patch_pos_embed = ( + self.position_embedding.weight[1:, :] + .reshape(patch_pos_shape) + .permute(0, 3, 1, 2) + .float() + ) + + patch_pos_embed_list = [] + for grid in grid_thw: + _, h0, w0 = grid + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + h0, w0 = h0 + 0.1, w0 + 0.1 + patch_pos_embed = nn.functional.interpolate( + self.patch_pos_embed, + scale_factor=(h0 / self.position_edge, w0 / self.position_edge), + mode=self.interpolate_mode, + align_corners=False, + ) + + patch_pos_embed = ( + patch_pos_embed.reshape(self.embed_dim, -1) + .transpose(0, 1) + .unsqueeze(0) + .to(patch_embeds.dtype) + ) + patch_pos_embed_list.append(patch_pos_embed) + + patch_pos_embed = torch.cat(patch_pos_embed_list, dim=1) + embeddings = patch_embeds + patch_pos_embed + + return embeddings + + +class HunYuanVisionPatchMerger(nn.Module): + def __init__( + self, + in_channels, + out_channels, + spatial_merge_size=2, + rms_norm_eps=1e-5, + prefix="", + ): + super().__init__() + self.spatial_merge_size = spatial_merge_size + embed_std = out_channels**-0.5 + + self.proj = nn.Sequential( + nn.Conv2d( + in_channels, + in_channels * 2, + kernel_size=spatial_merge_size, + stride=spatial_merge_size, + ), + nn.GELU(), + nn.Conv2d(in_channels * 2, in_channels * 4, kernel_size=1), + ) + self.mlp = nn.Linear(in_channels * 4, out_channels) + + self.image_newline = nn.Parameter(torch.randn(in_channels * 4) * embed_std) + self.image_begin = nn.Parameter(torch.randn(out_channels) * embed_std) + self.image_end = nn.Parameter(torch.randn(out_channels) * embed_std) + self.image_sep = nn.Parameter(torch.randn(out_channels) * embed_std) + + self.before_rms = RMSNorm(in_channels, eps=rms_norm_eps) + self.after_rms = RMSNorm(out_channels, eps=rms_norm_eps) + + def forward(self, x, size=(16, 16)): + x = self.before_rms(x) + + h, w = size + dtype = x.dtype + x = x.permute(0, 2, 1).reshape(x.shape[0], -1, h, w) + + x = self.proj(x) # b,c,h,w + b, c, h, w = x.shape + x = torch.cat( + [x, self.image_newline.reshape(1, c, 1, 1).expand(b, c, h, 1).to(dtype)], + dim=-1, + ) + x = x.reshape(b, c, -1).permute(0, 2, 1) + x = self.mlp(x) + + begin = self.image_begin.reshape(1, 1, -1).expand(b, 1, x.shape[-1]).to(dtype) + end = self.image_end.reshape(1, 1, -1).expand(b, 1, x.shape[-1]).to(dtype) + x = 
torch.cat([begin, x, end], dim=1) + + return self.after_rms(x) + + +class HunYuanVisionTransformer(nn.Module): + def __init__( + self, + vision_config: HunYuanVLVisionConfig, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + use_data_parallel: bool = False, + multimodal_config: MultiModalConfig | None = None, + attn_backend_override: AttentionBackendEnum | None = None, + ) -> None: + super().__init__() + + num_hidden_layers = vision_config.num_hidden_layers + self.hidden_size = vision_config.hidden_size + self.num_heads = vision_config.num_attention_heads + self.spatial_merge_size = vision_config.spatial_merge_size + + from vllm.compilation.backends import set_model_tag + + with set_model_tag("HunYuanVisionPatchEmbed"): + self.embeddings = HunYuanVisionPatchEmbed(vision_config) + + norm_layer = partial(nn.LayerNorm, eps=vision_config.rms_norm_eps) + + with set_model_tag("HunYuanVisionBlock"): + self.layers = nn.ModuleList( + [ + HunYuanVisionBlock( + dim=vision_config.hidden_size, + num_heads=vision_config.num_attention_heads, + mlp_hidden_dim=vision_config.intermediate_size, + act_fn=get_act_fn(vision_config.hidden_act), + norm_layer=norm_layer, + quant_config=quant_config, + multimodal_config=multimodal_config, + prefix=f"{prefix}.layers.{layer_idx}", + use_data_parallel=use_data_parallel, + ) + for layer_idx in range(num_hidden_layers) + ] + ) + + with set_model_tag("HunYuanVisionPatchMerger"): + self.perceive = HunYuanVisionPatchMerger( + vision_config.hidden_size, + vision_config.out_hidden_size, + spatial_merge_size=vision_config.spatial_merge_size, + rms_norm_eps=vision_config.rms_norm_eps, + prefix=f"{prefix}.perceive", + ) + + @property + def dtype(self) -> torch.dtype: + return self.embeddings.patch_embedding.weight.dtype + + @property + def device(self) -> torch.device: + return self.embeddings.patch_embedding.weight.device + + def forward( + self, + x: torch.Tensor, + grid_thw: list[list[int]], + ) -> torch.Tensor: + # patchify + seq_len = x.size(0) + cu_seqlens: list = [0] + + hidden_states = x.to(device=self.device, dtype=self.dtype) + hidden_states = self.embeddings(hidden_states, grid_thw) + + for t, h, w in grid_thw: + t, h, w = int(t), int(h), int(w) + cu_seqlens.append(h * w) + + cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32) + cu_seqlens = torch.cumsum(cu_seqlens, dim=0, dtype=torch.int32) + + cu_seqlens = cu_seqlens.to(device=self.device, non_blocking=True) + + hidden_states = hidden_states.reshape(seq_len, -1) + hidden_states = hidden_states.unsqueeze(0) + for layer_num, layer in enumerate(self.layers): + hidden_states = layer(hidden_states) + + # adapter + split_lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + split_items = hidden_states.split(split_lengths, dim=1) + image_embeds_list = [] + for grid, split_item in zip(grid_thw, split_items): + image_embeds_list.append( + self.perceive(split_item.contiguous(), size=grid[1:]).squeeze(0) + ) + + return image_embeds_list + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv", ".q_proj", "q"), + (".qkv", ".k_proj", "k"), + (".qkv", ".v_proj", "v"), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + 
weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +def _hunyuan_vl_field_config(hf_inputs: Mapping[str, torch.Tensor]): + image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) + image_grid_sizes = image_grid_thw.prod(-1) + return dict( + pixel_values=MultiModalFieldConfig.flat_from_sizes("image", image_grid_sizes), + image_embeds=MultiModalFieldConfig.flat_from_sizes("image", image_grid_sizes), + image_grid_thw=MultiModalFieldConfig.batched("image"), + ) + + +class HunYuanVLMultiModalDataParser(MultiModalDataParser): + def _parse_image_data( + self, + data: dict[str, torch.Tensor] | ModalityData[ImageItem], + ): + if isinstance(data, dict): + return DictEmbeddingItems( + data, + modality="image", + required_fields={"image_embeds", "image_grid_thw"}, + fields_factory=_hunyuan_vl_field_config, + ) + + return super()._parse_image_data(data) + + +class HunYuanVLProcessingInfo(BaseProcessingInfo): + def get_hf_config(self): + return self.ctx.get_hf_config(HunYuanVLConfig) + + def get_hf_processor( + self, + **kwargs: object, + ) -> HunYuanVLProcessor: + return self.ctx.get_hf_processor( + HunYuanVLProcessor, + use_fast=kwargs.pop("use_fast", True), + **kwargs, + ) + + def get_image_processor( + self, + **kwargs: object, + ) -> HunYuanVLProcessor: + return self.get_hf_processor(**kwargs).image_processor + + def get_supported_mm_limits(self) -> Mapping[str, int | None]: + return {"image": None} + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + max_image_tokens = self.get_max_image_tokens() + # TODO: support video + max_video_tokens = 0 + return {"image": max_image_tokens, "video": max_video_tokens} + + def _get_vision_info( + self, + *, + image_width: int, + image_height: int, + num_frames: int = 1, + do_resize: bool = True, + image_processor: HunYuanVLProcessor | None, + ) -> tuple[ImageSize, int]: + if image_processor is None: + image_processor = self.get_image_processor() + + hf_config = self.get_hf_config() + vision_config = hf_config.vision_config + patch_size = vision_config.patch_size + spatial_merge_size = vision_config.spatial_merge_size + + if do_resize: + resized_height, resized_width = smart_resize( + height=image_height, + width=image_width, + factor=patch_size * spatial_merge_size, + min_pixels=image_processor.min_pixels, + max_pixels=image_processor.max_pixels, + ) + preprocessed_size = ImageSize(width=resized_width, height=resized_height) + else: + preprocessed_size = ImageSize(width=image_width, height=image_height) + + grid_t = 1 + grid_h = preprocessed_size.height // patch_size + grid_w = preprocessed_size.width // patch_size + + num_vision_tokens = ( + grid_t * grid_h // spatial_merge_size * (grid_w // spatial_merge_size + 1) + + 2 + ) + + return preprocessed_size, num_vision_tokens + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + image_processor: HunYuanVLProcessor | None, + ) -> int: + _, num_image_tokens = self._get_vision_info( + image_width=image_width, + image_height=image_height, + image_processor=image_processor, + ) + return num_image_tokens + + def get_image_size_with_most_features(self) -> ImageSize: + max_image_size, _ = self._get_vision_info( + image_width=512, + image_height=8192, + image_processor=None, 
+ ) + return max_image_size + + def get_max_image_tokens(self) -> int: + target_width, target_height = self.get_image_size_with_most_features() + return self.get_num_image_tokens( + image_width=target_width, + image_height=target_height, + image_processor=None, + ) + + +class HunYuanVLDummyInputsBuilder(BaseDummyInputsBuilder[HunYuanVLProcessingInfo]): + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + + hf_processor = self.info.get_hf_processor() + image_token: str = hf_processor.image_token + + return image_token * num_images + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + mm_options: Mapping[str, BaseDummyOptions] | None = None, + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 1) + + target_width, target_height = self.info.get_image_size_with_most_features() + + return { + "image": self._get_dummy_images( + width=target_width, height=target_height, num_images=num_images + ), + } + + +class HunYuanVLMultiModalProcessor(BaseMultiModalProcessor[HunYuanVLProcessingInfo]): + def _get_data_parser(self) -> MultiModalDataParser: + return HunYuanVLMultiModalDataParser() + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + return self.info.ctx.call_hf_processor( + self.info.get_hf_processor(**mm_kwargs), + dict(text=prompt, **mm_data), + dict(**mm_kwargs, **tok_kwargs), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, Any], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs) + + placeholder = { + "image": hf_processor.image_token_id, + } + + merge_size = image_processor.merge_size + + def get_replacement_hunyuan_vl(item_idx: int, modality: str): + out_item = out_mm_kwargs[modality][item_idx] + grid_thw = out_item[f"{modality}_grid_thw"].data + assert isinstance(grid_thw, torch.Tensor) + + _, grid_h, grid_w = grid_thw + num_tokens = (int(grid_h) // merge_size) * ( + int(grid_w) // merge_size + 1 + ) + 2 + return [placeholder[modality]] * num_tokens + + return [ + PromptReplacement( + modality=modality, + target=[placeholder[modality]], + replacement=partial(get_replacement_hunyuan_vl, modality=modality), + ) + for modality in ("image",) + ] + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return _hunyuan_vl_field_config(hf_inputs) + + +@MULTIMODAL_REGISTRY.register_processor( + HunYuanVLMultiModalProcessor, + info=HunYuanVLProcessingInfo, + dummy_inputs=HunYuanVLDummyInputsBuilder, +) +class HunYuanVLForConditionalGeneration( + nn.Module, + SupportsMultiModal, + SupportsLoRA, + SupportsPP, + SupportsQuant, + SupportsXDRoPE, +): + multimodal_cpu_fields = {"image_grid_thw"} + + # To ensure correct weight loading and mapping. 
+ hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + # mapping for new names in checkpoint saved after transformers v4.52 + "vit.vit.": "visual.", + "vit.": "visual.", + "model.": "language_model.model.", + } + ) + + supports_encoder_tp_data = True + + def get_xdrope_input_positions( + self, + input_tokens: list[int], + mm_features: list[MultiModalFeatureSpec], + ) -> torch.Tensor: + kwargs = MultiModalFeatureSpec.gather_kwargs( + mm_features, + {"image_grid_thw"}, + ) + image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])] + + hf_config = self.config + image_start_token_id = hf_config.image_start_token_id + spatial_merge_size = hf_config.vision_config.spatial_merge_size + xd_num = len(hf_config.rope_scaling["xdrope_section"]) + + input_tokens_tensor = torch.tensor(input_tokens) + image_start_indices = torch.argwhere( + input_tokens_tensor == image_start_token_id + ).squeeze(1) + + p_index = torch.arange(len(input_tokens_tensor)) + w_index = torch.arange(len(input_tokens_tensor)) + h_index = torch.arange(len(input_tokens_tensor)) + t_index = torch.arange(len(input_tokens_tensor)) + for image_index in range(len(image_start_indices)): + # +1 : first image_token, +2: for xdrope positions + pos = image_start_indices[image_index] + 2 + t, h, w = image_grid_thw[image_index] + _, llm_grid_h, llm_grid_w = ( + t, + h // spatial_merge_size, + w // spatial_merge_size, + ) + + token_num = (llm_grid_w + 1) * llm_grid_h + w_index[pos : pos + token_num].copy_( + torch.arange(0, llm_grid_w + 1) + .reshape(1, -1) + .expand(llm_grid_h, -1) + .reshape(-1) + ) + h_index[pos : pos + token_num].copy_( + torch.arange(0, llm_grid_h) + .reshape(-1, 1) + .expand(-1, llm_grid_w + 1) + .reshape(-1) + ) + h_index[pos : pos + token_num] = 0 + + if xd_num == 4: + llm_positions = torch.stack([p_index, w_index, h_index, t_index]) + elif xd_num == 3: + llm_positions = torch.stack([w_index, h_index, t_index]) + + return llm_positions + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> str | None: + if modality.startswith("image"): + return "<|hy_place▁holder▁no▁100|><|hy_place▁holder▁no▁102|><|hy_place▁holder▁no▁101|>" # noqa: E501 + + raise ValueError("Only image modality is supported") + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config: HunYuanVLConfig = vllm_config.model_config.hf_config + multimodal_config = vllm_config.model_config.multimodal_config + + self.config = config + self.multimodal_config = multimodal_config + + if multimodal_config.get_limit_per_prompt("image"): + attn_backend_override = ( + multimodal_config.mm_encoder_attn_backend + if multimodal_config is not None + else None + ) + self.visual = HunYuanVisionTransformer( + config.vision_config, + quant_config=self.quant_config, + prefix=maybe_prefix(prefix, "visual"), + multimodal_config=multimodal_config, + attn_backend_override=attn_backend_override, + ) + else: + self.visual = None + + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "language_model.model"), + architectures=[ + "HunYuanDenseV1ForCausalLM", + "HunYuanMoEV1ForCausalLM", + ], + ) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors + ) + + def _parse_and_validate_image_input( + self, **kwargs: object + ) -> HunYuanVLImageInputs | None: + pixel_values = kwargs.pop("pixel_values", None) + image_embeds = kwargs.pop("image_embeds", None) + image_grid_thw = kwargs.pop("image_grid_thw", 
None) + + if pixel_values is None and image_embeds is None: + return None + + # TODO: refine + if isinstance(pixel_values, list): + pixel_values = torch.cat(pixel_values, dim=0) + if len(pixel_values.shape) == 3: + last_dim = pixel_values.shape[-1] + pixel_values = pixel_values.reshape(-1, last_dim) + image_grid_thw = image_grid_thw.reshape(-1, 3) + + if pixel_values is not None: + return HunYuanVLImagePixelInputs( + type="pixel_values", + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + ) + + if image_embeds is not None: + return HunYuanVLImageEmbeddingInputs( + type="image_embeds", + image_embeds=image_embeds, + image_grid_thw=image_grid_thw, + ) + + def _process_image_input( + self, image_input: HunYuanVLImageInputs + ) -> tuple[torch.Tensor, ...]: + grid_thw = image_input["image_grid_thw"] + assert grid_thw.ndim == 2 + grid_thw_list = grid_thw.tolist() + + if image_input["type"] == "image_embeds": + image_embeds = image_input["image_embeds"].type(self.visual.dtype) + else: + pixel_values = image_input["pixel_values"] + + # TODO: use_data_parallel (split image_embeds in visual) + image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list) + + return image_embeds + + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: + mm_input_by_modality = {} + + # Preserve the order of modalities if there are multiple of them + # from the order of kwargs. + for input_key in kwargs: + if ( + input_key in ("pixel_values", "image_embeds") + and "image" not in mm_input_by_modality + ): + mm_input_by_modality["image"] = self._parse_and_validate_image_input( + **kwargs + ) + return mm_input_by_modality + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: + mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs) + if not mm_input_by_modality: + return [] + + # The result multimodal_embeddings is tuple of tensors, with each + # tensor correspoending to a multimodal data item (image or video). + multimodal_embeddings: tuple[torch.Tensor, ...] = () + + # NOTE: It is important to iterate over the keys in this dictionary + # to preserve the order of the modalities. 
+ for modality in mm_input_by_modality: + multimodal_input = mm_input_by_modality[modality] + if modality == "image": + image_embeddings = self._process_image_input(multimodal_input) + multimodal_embeddings += tuple(image_embeddings) + return multimodal_embeddings + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None, + **kwargs: object, + ) -> torch.Tensor | IntermediateTensors: + if intermediate_tensors is not None: + inputs_embeds = None + + hidden_states = self.language_model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor | None: + return self.language_model.compute_logits(hidden_states) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys.from_string_field( + language_model="language_model.model", + connector="visual.perceive", + tower_model="visual", + ) diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 9966498e1b4c9..6f6ce32538b71 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -1047,7 +1047,7 @@ class SupportsMRoPE(Protocol): supports_mrope: ClassVar[Literal[True]] = True """ A flag that indicates this model supports M-RoPE. - + Note: There is no need to redefine this flag if this class is in the MRO of your model class. @@ -1088,3 +1088,52 @@ def supports_mrope( model: type[object] | object, ) -> TypeIs[type[SupportsMRoPE]] | TypeIs[SupportsMRoPE]: return isinstance(model, SupportsMRoPE) + + +@runtime_checkable +class SupportsXDRoPE(Protocol): + """The interface required for all models that support XD-RoPE.""" + + supports_xdrope: ClassVar[Literal[True]] = True + """ + A flag that indicates this model supports XD-RoPE. + + Note: + There is no need to redefine this flag if this class is in the + XDRope of your model class. + """ + + def get_xdrope_input_positions( + self, + input_tokens: list[int], + mm_features: list["MultiModalFeatureSpec"], + ) -> torch.Tensor: + """ + Get XD-RoPE input positions and delta value for this specific model. + + This method should be implemented by each model that supports XD-RoPE + to provide model-specific logic for computing input positions. + + Args: + input_tokens: List of input token IDs + mm_features: Information about each multi-modal data item + + Returns: + llm_positions: Tensor of shape `[xdrope_dim, num_tokens]` with + 4D(P/W/H/T) or 3D(W/H/T) positions. + """ + ... + + +@overload +def supports_xdrope(model: type[object]) -> TypeIs[type[SupportsXDRoPE]]: ... + + +@overload +def supports_xdrope(model: object) -> TypeIs[SupportsXDRoPE]: ... 
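# ---------------------------------------------------------------------------
# Editor's illustrative sketch (not part of the patch): what the SupportsXDRoPE
# contract amounts to at runtime. Because the protocol is @runtime_checkable,
# supports_xdrope() only checks that the object exposes the flag and the
# method; the [xdrope_dim, num_tokens] shape is a convention of the interface.
# The toy class below is hypothetical and only mirrors the text-only fallback,
# where every axis (P/W/H/T) is simply the token index.
import torch


class _ToyXDRoPEModel:
    supports_xdrope = True

    def get_xdrope_input_positions(self, input_tokens, mm_features):
        num_tokens = len(input_tokens)
        # One identical row per axis, matching
        # XDRotaryEmbedding.get_next_input_positions for text-only prompts.
        return torch.arange(num_tokens).repeat(4, 1)


positions = _ToyXDRoPEModel().get_xdrope_input_positions([1, 2, 3, 4, 5], [])
assert positions.shape == (4, 5)  # [xdrope_dim, num_tokens]
# ---------------------------------------------------------------------------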
+ + +def supports_xdrope( + model: type[object] | object, +) -> TypeIs[type[SupportsXDRoPE]] | TypeIs[SupportsXDRoPE]: + return isinstance(model, SupportsXDRoPE) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index b3da64af750c7..a0d8a78a2ae76 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -287,6 +287,10 @@ _MULTIMODAL_MODELS = { "GraniteSpeechForConditionalGeneration", ), "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"), + "HunYuanVLForConditionalGeneration": ( + "hunyuan_vision", + "HunYuanVLForConditionalGeneration", + ), "InternVLChatModel": ("internvl", "InternVLChatModel"), "NemotronH_Nano_VL_V2": ("nano_nemotron_vl", "NemotronH_Nano_VL_V2"), "OpenCUAForConditionalGeneration": ( diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 3d282da8c6112..c1880a3fba0ee 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -86,6 +86,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict( deepseek_vl_v2="DeepseekVLV2Config", deepseek_v32="DeepseekV3Config", flex_olmo="FlexOlmoConfig", + hunyuan_vl="HunYuanVLConfig", kimi_linear="KimiLinearConfig", kimi_vl="KimiVLConfig", RefinedWeb="RWConfig", # For tiiuae/falcon-40b(-instruct) @@ -549,6 +550,23 @@ def thinker_uses_mrope(config: PretrainedConfig) -> bool: return uses_mrope(thinker_text_config) +def uses_xdrope_dim(config: PretrainedConfig) -> int: + """Detect if the model with this config uses XD-ROPE.""" + xdrope_section = getattr(config, "xdrope_section", None) + if xdrope_section is not None and isinstance(xdrope_section, list): + return len(xdrope_section) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is None: + return 0 + + if isinstance(rope_scaling, dict) and "xdrope_section" in rope_scaling: + xdrope_section = rope_scaling["xdrope_section"] + if xdrope_section is not None and isinstance(xdrope_section, list): + return len(xdrope_section) + + return 0 + + def is_encoder_decoder(config: PretrainedConfig) -> bool: """Detect if the model with this config is used as an encoder/decoder.""" diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index d28fd8d033373..109f2b6986514 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -23,6 +23,11 @@ from vllm.transformers_utils.configs.eagle import EAGLEConfig # `FalconConfig` class from the official HuggingFace transformers library. 
from vllm.transformers_utils.configs.falcon import RWConfig from vllm.transformers_utils.configs.flex_olmo import FlexOlmoConfig +from vllm.transformers_utils.configs.hunyuan_vl import ( + HunYuanVLConfig, + HunYuanVLTextConfig, + HunYuanVLVisionConfig, +) from vllm.transformers_utils.configs.jais import JAISConfig from vllm.transformers_utils.configs.kimi_linear import KimiLinearConfig from vllm.transformers_utils.configs.kimi_vl import KimiVLConfig @@ -53,6 +58,9 @@ __all__ = [ "DotsOCRConfig", "EAGLEConfig", "FlexOlmoConfig", + "HunYuanVLConfig", + "HunYuanVLTextConfig", + "HunYuanVLVisionConfig", "RWConfig", "JAISConfig", "Lfm2MoeConfig", diff --git a/vllm/transformers_utils/configs/hunyuan_vl.py b/vllm/transformers_utils/configs/hunyuan_vl.py new file mode 100644 index 0000000000000..a826ed9b5155d --- /dev/null +++ b/vllm/transformers_utils/configs/hunyuan_vl.py @@ -0,0 +1,322 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# adapted from https://github.com/ManaEstras/transformers/blob/v4.57.1.hyvl/src/transformers/models/hunyuan_vl/configuration_hunyuan_vl.py + +from transformers import PretrainedConfig + + +class HunYuanVLVisionConfig(PretrainedConfig): + model_type = "hunyuan_vl" + base_config_key = "vision_config" + + def __init__( + self, + hidden_act="gelu", + hidden_size=1152, + intermediate_size=4304, + interpolate_mode="bilinear", + rms_norm_eps=1e-05, + learnable_mlp_pooling_size=0, + num_attention_heads=16, + num_key_value_heads=None, + num_channels=3, + num_hidden_layers=27, + out_hidden_size=4096, + patch_size=16, + remove_prenorm=True, + spatial_merge_size=2, + temporal_patch_size=1, + resize_resolution=2048, + img_max_token_num=4096, + max_image_size=2048, + video_max_image_size=768, + video_min_image_size=256, + min_image_size=512, + anyres_vit_max_image_size=2048, + max_vit_seq_len=16384, + text_hidden_size=3072, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_act = hidden_act + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.interpolate_mode = interpolate_mode + self.learnable_mlp_pooling_size = learnable_mlp_pooling_size + self.num_attention_heads = num_attention_heads + if not num_key_value_heads: + self.num_key_value_heads = num_attention_heads + else: + self.num_key_value_heads = num_key_value_heads + self.num_channels = num_channels + self.num_hidden_layers = num_hidden_layers + self.out_hidden_size = out_hidden_size + self.patch_size = patch_size + self.remove_prenorm = remove_prenorm + self.spatial_merge_size = spatial_merge_size + self.temporal_patch_size = temporal_patch_size + self.rms_norm_eps = rms_norm_eps + + self.resize_resolution = resize_resolution + self.img_max_token_num = img_max_token_num + self.max_image_size = max_image_size + self.min_image_size = min_image_size + self.video_max_image_size = video_max_image_size + self.video_min_image_size = video_min_image_size + self.anyres_vit_max_image_size = anyres_vit_max_image_size + self.max_vit_seq_len = max_vit_seq_len + self.text_hidden_size = text_hidden_size + + +class HunYuanVLTextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`HunYuanVLTextConfig`]. It is used to instantiate an + HunYuan model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the HunYuan-7B. 
+ Hunyuan-7B-Instruct [tencent/Hunyuan-7B-Instruct](https://huggingface.co/tencent/Hunyuan-7B-Instruct). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 290943): + Vocabulary size of the HunYuan model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`HunYuanVLTextConfig`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations or shared MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*, defaults to 0): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + eod_token_id (int, *optional*, defaults to 3): + Token ID representing the end-of-document marker. Used to indicate the termination of a text sequence. + Example: In multi-document processing, this token helps the model distinguish between separate documents. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this + document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is + necessary to ensure exact reproducibility of the pretraining results. Please refer to [this + issue](https://github.com/pytorch/pytorch/issues/76232). + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. 
+ rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is + `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. See the following thread for more information on how + these scaling strategies behave: + https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an + experimental feature, subject to breaking API changes in future versions. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + head_dim (`int`, *optional*, defaults to 128): + The attention head dimension. + """ # noqa: E501 + + model_type = "hunyuan_vl_text" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=290943, + hidden_size=4096, + intermediate_size: int = 11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + eod_token_id=3, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + head_dim=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.head_dim = head_dim + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + # self._rope_scaling_validation() # TODO: Need validation? + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + def _rope_scaling_validation(self): + """ + Validate the `rope_scaling` configuration. 
+ """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: + raise ValueError( + "`rope_scaling` must be a dictionary with with two fields, `type` and " + f"`factor` or `type` and `alpha`, got {self.rope_scaling}" + ) + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_factor = self.rope_scaling.get("factor", None) + rope_scaling_alpha = self.rope_scaling.get("alpha", None) + if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: + raise ValueError( + "`rope_scaling`'s type field must be one of ['linear', 'dynamic'], " + f"got {rope_scaling_type}" + ) + if rope_scaling_factor is None and rope_scaling_alpha is None: + raise ValueError( + "`rope_scaling`'s factor or alpha field must be have one, " + "got both of none" + ) + if rope_scaling_factor is not None and ( + not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0 + ): + raise ValueError( + "`rope_scaling`'s factor field must be a float > 1.0, " + f"got {rope_scaling_factor}" + ) + if rope_scaling_alpha is not None and ( + not isinstance(rope_scaling_alpha, float) or rope_scaling_alpha <= 1.0 + ): + raise ValueError( + "`rope_scaling`'s alpha field must be a float > 1.0, " + f"got {rope_scaling_alpha}" + ) + + +class HunYuanVLConfig(PretrainedConfig): + model_type = "hunyuan_vl" + sub_configs = { + "vision_config": HunYuanVLVisionConfig, + "text_config": HunYuanVLTextConfig, + } + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + text_config=None, + vision_config=None, + im_start_id=120118, + im_end_id=120119, + image_token_id=120120, + im_newline_id=120121, + video_start_id=120122, + video_end_id=120123, + **kwargs, + ): + # We need to init super() here so that it does not reset values + # that are in text config to the BaseClass defaults. The Base + # config has many text related defaults and not all defaults are + # same as for `HunYuanVLTextConfig`. + super().__init__(**kwargs) + + if isinstance(vision_config, dict): + self.vision_config = self.sub_configs["vision_config"](**vision_config) + elif vision_config is None: + self.vision_config = self.sub_configs["vision_config"]() + + if isinstance(text_config, dict): + self.text_config = self.sub_configs["text_config"](**text_config) + elif text_config is None: + # For BC use all kwargs to init `TextConfig` + self.text_config = self.sub_configs["text_config"](**kwargs) + + self.image_token_id = image_token_id + self.im_start_id = im_start_id + self.im_end_id = im_end_id + self.im_newline_id = im_newline_id + self.video_start_id = video_start_id + self.video_end_id = video_end_id + + self.vision_config.text_hidden_size = self.text_config.hidden_size + + # Attention implementation to use. It sets it recursively on sub-configs + # so we call it again in the end. 
+ self._attn_implementation = kwargs.pop("attn_implementation", None) + + def __setattr__(self, key, value): + if ( + (text_config := super().__getattribute__("__dict__").get("text_config")) + is not None + and key not in ["dtype", "_attn_implementation_internal"] + and key in text_config.__dict__ + ): + setattr(text_config, key, value) + else: + super().__setattr__(key, value) + + def __getattribute__(self, key): + if "text_config" in super().__getattribute__("__dict__") and key not in [ + "_name_or_path", + "model_type", + "dtype", + "_attn_implementation_internal", + ]: + text_config = super().__getattribute__("text_config") + if key in text_config.__dict__: + return getattr(text_config, key) + + return super().__getattribute__(key) diff --git a/vllm/transformers_utils/processors/__init__.py b/vllm/transformers_utils/processors/__init__.py index 76b6d3dc9c99a..b49fdbe9ce776 100644 --- a/vllm/transformers_utils/processors/__init__.py +++ b/vllm/transformers_utils/processors/__init__.py @@ -9,7 +9,15 @@ reasons: """ from vllm.transformers_utils.processors.deepseek_vl2 import DeepseekVLV2Processor +from vllm.transformers_utils.processors.hunyuan_vl import HunYuanVLProcessor +from vllm.transformers_utils.processors.hunyuan_vl_image import HunYuanVLImageProcessor from vllm.transformers_utils.processors.ovis import OvisProcessor from vllm.transformers_utils.processors.ovis2_5 import Ovis2_5Processor -__all__ = ["DeepseekVLV2Processor", "OvisProcessor", "Ovis2_5Processor"] +__all__ = [ + "DeepseekVLV2Processor", + "HunYuanVLProcessor", + "HunYuanVLImageProcessor", + "OvisProcessor", + "Ovis2_5Processor", +] diff --git a/vllm/transformers_utils/processors/hunyuan_vl.py b/vllm/transformers_utils/processors/hunyuan_vl.py new file mode 100644 index 0000000000000..615a8bff85912 --- /dev/null +++ b/vllm/transformers_utils/processors/hunyuan_vl.py @@ -0,0 +1,233 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# adapted from https://github.com/ManaEstras/transformers/blob/v4.57.1.hyvl/src/transformers/models/hunyuan_vl/processing_hunyuan_vl.py + +import numpy as np +import torch +from transformers import AutoProcessor +from transformers.feature_extraction_utils import BatchFeature +from transformers.image_utils import ImageInput +from transformers.processing_utils import ProcessorMixin +from transformers.tokenization_utils_base import PreTokenizedInput, TextInput +from transformers.video_utils import VideoInput + + +class HunYuanVLProcessor(ProcessorMixin): + attributes = ["image_processor", "tokenizer"] + valid_kwargs = ["chat_template"] + image_processor_class = "AutoImageProcessor" + tokenizer_class = "AutoTokenizer" # ("AutoTokenizer", None) + + def __init__( + self, + image_processor=None, + tokenizer=None, + video_processor=None, + chat_template=None, + **kwargs, + ): + # TODO Fix the init + self.tokenizer = tokenizer + self.image_token_id = 120120 # self.tokenizer.image_token_id + self.image_token = self.tokenizer.convert_ids_to_tokens(self.image_token_id) + self.im_start_token_id = 120118 # self.tokenizer.im_start_id + self.im_start_token = self.tokenizer.convert_ids_to_tokens( + self.im_start_token_id + ) + self.im_end_token_id = 120119 # self.tokenizer.im_end_id + self.im_end_token = self.tokenizer.convert_ids_to_tokens(self.im_end_token_id) + self.placeholder_token = self.tokenizer.convert_ids_to_tokens( + self.tokenizer.vocab_size - 1 + ) + self.pad_id = 120002 # self.tokenizer.pad_token_id + + super().__init__( + 
image_processor, tokenizer, video_processor, chat_template=chat_template + ) + + def __call__( + self, + images: ImageInput = None, + text: TextInput + | PreTokenizedInput + | list[TextInput] + | list[PreTokenizedInput] = None, + videos: VideoInput = None, + **kwargs, + ) -> BatchFeature: + image_inputs = {} + if images is not None: + image_inputs = self.image_processor(images=images) + image_grid_thw = image_inputs["image_grid_thw"] + + if not isinstance(text, list): + text = [text] + + text = text.copy() # below lines change text in-place + + image_tokens_cumsum = [0] + if images is not None: + index = 0 + for i in range(len(text)): + while self.image_token in text[i]: + grid_h, grid_w = image_grid_thw[index][-2:] + patch_h = grid_h // self.image_processor.merge_size + patch_w = grid_w // self.image_processor.merge_size + num_image_tokens = patch_h * (patch_w + 1) + 2 + image_tokens_cumsum.append( + image_tokens_cumsum[-1] + num_image_tokens + ) + # text[i] = text[i].replace(self.image_token, self.im_start_token + self.placeholder_token * num_image_tokens + self.im_end_token, 1) # noqa: E501 + text[i] = text[i].replace( + self.image_token, self.placeholder_token * num_image_tokens, 1 + ) + index += 1 + text[i] = text[i].replace(self.placeholder_token, self.image_token) + # text[i] = self.tokenizer.bos_token + text[i] + + text_inputs = self.tokenizer(text, add_special_tokens=False, **kwargs) + self._check_special_mm_tokens(text, text_inputs, modalities=["image"]) + + input_ids = text_inputs["input_ids"] + position_ids = torch.arange(len(input_ids[0])) + position_ids_w = torch.arange(len(input_ids[0])) + position_ids_h = torch.arange(len(input_ids[0])) + position_ids_t = torch.arange(len(input_ids[0])) + + if images is not None: + image_token_pos_indices = torch.where(input_ids[0] == self.image_token_id)[ + 0 + ] + for i in range(len(image_grid_thw)): + grid_h, grid_w = image_grid_thw[i][-2:] + patch_h = grid_h // self.image_processor.merge_size + patch_w = grid_w // self.image_processor.merge_size + start_pos = image_token_pos_indices[image_tokens_cumsum[i]].item() + 1 + replace_num = (patch_w + 1) * patch_h + position_ids_w[start_pos : start_pos + replace_num] = torch.tensor( + list(range(patch_w + 1)) * patch_h, dtype=torch.int64 + ) + patch_h_list = [] + for h in range(patch_h): + patch_h_list += [h] * (patch_w + 1) + position_ids_h[start_pos : start_pos + replace_num] = torch.tensor( + patch_h_list, dtype=torch.int64 + ) + position_ids_t[start_pos : start_pos + replace_num] = 0 + + position_ids = torch.stack( + [position_ids, position_ids_w, position_ids_h, position_ids_t] + ).unsqueeze(0) + text_inputs["position_ids"] = position_ids + + attention_mask = input_ids.ne(self.pad_id) + text_inputs["attention_mask"] = attention_mask + text_inputs["imgs_pos"] = [self.get_imgs_pos(input_ids)] + # image_inputs["imgs"] = [[image_inputs["pixel_values"]]] + + return_tensors = kwargs.pop("return_tensors", None) + return BatchFeature( + data={**text_inputs, **image_inputs}, + tensor_type=return_tensors, + ) + + def batch_decode(self, *args, **kwargs): + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + return self.tokenizer.decode(*args, **kwargs) + + def post_process_image_text_to_text( + self, + generated_outputs, + skip_special_tokens=True, + clean_up_tokenization_spaces=False, + **kwargs, + ): + assert 0 + + def apply_chat_template(self, *args, **kwargs): + token_ids = self.tokenizer.apply_chat_template(*args, **kwargs) + return token_ids + + def 
get_imgs_pos(self, doc_ids): + doc_ids = np.array(doc_ids, dtype=np.int64) + img_begin_index = np.where(doc_ids == self.im_start_token_id)[0] + img_end_index = np.where(doc_ids == self.im_end_token_id)[0] + imgs_pos = np.concatenate( + ( + np.reshape(img_begin_index + 1, (-1, 1)), + np.reshape(img_end_index, (-1, 1)), + ), + axis=-1, + ).tolist() + return imgs_pos + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + + +def split_image_into_patch_blocks( + pixel_values: torch.Tensor, # shape: [batch_size, 3, H, W] + patch_size: int = 16, # e.g. 16 + adaptor_patch_div: int = 4, # e.g. 4 --> each patch_size is cut into 4x4 small regions, i.e. patch_size // 4 # noqa: E501 +) -> torch.Tensor: + """ + Split the input image tensor (supporting batch) into large patches of size `patch_size`, + and then further divide each large patch into smaller regions of size + (patch_size // adaptor_patch_div) x (patch_size // adaptor_patch_div). + Each small region is extracted as a tensor of shape [3, patch_size, patch_size]. + The final output contains all such small region tensors. + + Args: + pixel_values: Input image tensor of shape [batch_size, 3, H, W]. + patch_size: Size of the large patch, e.g., 16. + adaptor_patch_div: Each large patch is divided into + (patch_size // adaptor_patch_div) x (patch_size // adaptor_patch_div) + smaller regions. + + Returns: + patches: A tensor of shape [N, 3, patch_size, patch_size], + where N = batch_size * (H // patch_size) * (W // patch_size) * (patch_size // adaptor_patch_div)^2. + Each element in the batch corresponds to one small image region. 
+ """ # noqa: E501 + batch_size, channels, height, width = pixel_values.shape + assert channels == 3, "Pixel values must have 3 channels in dim=1" + assert height % patch_size == 0 and width % patch_size == 0, ( + "H and W must be divisible by patch_size" + ) + + patch_height_num = height // patch_size + patch_width_num = width // patch_size + + # Reshape to [B, 3, ph, ps, pw, ps] + img = pixel_values.reshape( + batch_size, 3, patch_height_num, patch_size, patch_width_num, patch_size + ) + + # Further split each psxps patch into (ps//aps)x(ps//aps) small regions + img = img.reshape( + batch_size, + 3, + patch_height_num, + patch_size // adaptor_patch_div, # ps // aps + adaptor_patch_div, + patch_width_num, + patch_size // adaptor_patch_div, # ps // aps + adaptor_patch_div, + ) + + # Permute to group the small regions: [B, ph, pw, ps//aps, ps//aps, 3, aps, aps] + img = img.permute(0, 2, 5, 3, 6, 1, 4, 7) + + # Reshape into [B * ph * pw * (ps//aps)^2, 3, patch_size, patch_size] + patches = img.reshape(-1, 3, patch_size, patch_size) + + return patches + + +AutoProcessor.register("HunYuanVLProcessor", HunYuanVLProcessor) diff --git a/vllm/transformers_utils/processors/hunyuan_vl_image.py b/vllm/transformers_utils/processors/hunyuan_vl_image.py new file mode 100644 index 0000000000000..0a7e7865c783a --- /dev/null +++ b/vllm/transformers_utils/processors/hunyuan_vl_image.py @@ -0,0 +1,477 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# adapted from https://github.com/ManaEstras/transformers/blob/v4.57.1.hyvl/src/transformers/models/hunyuan_vl/image_processing_hunyuan_vl.py +"""Image processor class for HunYuanVL.""" + +# isort conflicts with ruff for transformers imports +# isort: skip_file +import math + +import numpy as np +import torchvision.transforms as transforms +from transformers import AutoImageProcessor +from transformers.image_processing_utils import BaseImageProcessor, BatchFeature +from transformers.image_transforms import ( + convert_to_rgb, +) +from transformers.image_utils import ( + OPENAI_CLIP_MEAN, + OPENAI_CLIP_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + make_flat_list_of_images, + make_list_of_images, + valid_images, + validate_preprocess_arguments, +) +from transformers.utils import TensorType, logging +from transformers.video_utils import VideoInput, make_batched_videos + +logger = logging.get_logger(__name__) + + +def smart_resize( + height: int, + width: int, + factor: int = 16, + min_pixels: int = 512 * 512, + max_pixels: int = 2048 * 2048, +): + """Rescales the image so that the following conditions are met: + + 1. Both dimensions (height and width) are divisible by 'factor'. + + 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. + + 3. The aspect ratio of the image is maintained as closely as possible. 
+ + """ + if max(height, width) / min(height, width) > 200: + raise ValueError( + "absolute aspect ratio must be smaller than 200, got " + f"{max(height, width) / min(height, width)}" + ) + h_bar = round(height / factor) * factor + w_bar = round(width / factor) * factor + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = max(factor, math.floor(height / beta / factor) * factor) + w_bar = max(factor, math.floor(width / beta / factor) * factor) + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = math.ceil(height * beta / factor) * factor + w_bar = math.ceil(width * beta / factor) * factor + return h_bar, w_bar + + +class HunYuanVLImageProcessor(BaseImageProcessor): + model_input_names = [ + "pixel_values", + "image_grid_thw", + "pixel_values_videos", + "video_grid_thw", + ] + + def __init__( + self, + do_resize: bool = True, + size: dict[str, int] | None = None, + resample: PILImageResampling = PILImageResampling.BICUBIC, + do_rescale: bool = True, + rescale_factor: int | float = 1 / 255, + do_normalize: bool = True, + image_mean: float | list[float] | None = None, + image_std: float | list[float] | None = None, + do_convert_rgb: bool = True, + min_pixels: int | None = None, + max_pixels: int | None = None, + patch_size: int = 16, + temporal_patch_size: int = 2, + merge_size: int = 2, + **kwargs, + ) -> None: + super().__init__(**kwargs) + if size is not None and ( + "shortest_edge" not in size or "longest_edge" not in size + ): + raise ValueError( + "size must contain 'shortest_edge' and 'longest_edge' keys." + ) + else: + size = {"shortest_edge": 512 * 512, "longest_edge": 2048 * 2048} + # backward compatibility: override size with min_pixels and max_pixels + # if they are provided. + if min_pixels is not None: + size["shortest_edge"] = min_pixels + if max_pixels is not None: + size["longest_edge"] = max_pixels + self.min_pixels = size["shortest_edge"] + self.max_pixels = size["longest_edge"] + self.size = size + + self.do_resize = do_resize + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN + self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD + + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.merge_size = merge_size + self.do_convert_rgb = do_convert_rgb + + # hard-code + + def _preprocess( + self, + images: ImageInput | VideoInput, + do_resize: bool | None = None, + size: dict[str, int] | None = None, + resample: PILImageResampling = None, + do_rescale: bool | None = None, + rescale_factor: float | None = None, + do_normalize: bool | None = None, + image_mean: float | list[float] | None = None, + image_std: float | list[float] | None = None, + patch_size: int = 16, + temporal_patch_size: int = 2, + merge_size: int = 2, + do_convert_rgb: bool | None = None, + data_format: ChannelDimension | None = ChannelDimension.FIRST, + input_data_format: str | ChannelDimension | None = None, + ): + """ + Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`. + + Args: + images (`ImageInput`): + Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. 
+            size (`dict[str, int]`, *optional*, defaults to `self.size`):
+                Size of the image after resizing. `shortest_edge` and `longest_edge` keys must be present.
+            resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
+                Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums.
+            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+                Whether to rescale the image.
+            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+                Scale factor to use if rescaling the image.
+            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+                Whether to normalize the image.
+            image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
+                Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
+            image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
+                Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
+            patch_size (`int`, *optional*, defaults to `self.patch_size`):
+                The spatial patch size of the vision encoder.
+            temporal_patch_size (`int`, *optional*, defaults to `self.temporal_patch_size`):
+                The temporal patch size of the vision encoder.
+            merge_size (`int`, *optional*, defaults to `self.merge_size`):
+                The merge size of the vision encoder to llm encoder.
+            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
+                Whether to convert the image to RGB.
+            data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
+                The channel dimension format for the output image. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - Unset: Use the channel dimension format of the input image.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format for the input image. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
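+
+        Returns:
+            A tuple `(flatten_patches, (grid_t, grid_h, grid_w))`, where `flatten_patches`
+            has shape `(grid_h * grid_w, num_channels * patch_size * patch_size)`.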
+ """ # noqa: E501 + images = make_list_of_images(images) + + if do_convert_rgb: + images = [convert_to_rgb(image) for image in images] + + width, height = images[0].width, images[0].height + resized_width, resized_height = width, height + processed_images = [] + for image in images: + if do_resize: + resized_width, resized_height = smart_resize( + width, + height, + factor=patch_size * merge_size, + min_pixels=self.min_pixels, + max_pixels=self.max_pixels, + ) + image = image.resize((resized_width, resized_height)) + + if do_normalize: + image = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize(self.image_mean, self.image_std), + ] + )(image) + processed_images.append(image) + + patches = np.array(processed_images) + channel = patches.shape[1] + grid_t = patches.shape[0] // temporal_patch_size + grid_h, grid_w = resized_height // patch_size, resized_width // patch_size + patches = patches.reshape( + 1, + channel, + grid_h // merge_size, + merge_size, + patch_size, + grid_w // merge_size, + merge_size, + patch_size, + ) + patches = patches.transpose(0, 2, 3, 5, 6, 1, 4, 7) + flatten_patches = patches.reshape( + 1 * grid_h * grid_w, channel * patch_size * patch_size + ) + + return flatten_patches, (grid_t, grid_h, grid_w) + + def preprocess( + self, + images: ImageInput, + videos: VideoInput = None, + do_resize: bool | None = None, + size: dict[str, int] | None = None, + min_pixels: int | None = None, + max_pixels: int | None = None, + resample: PILImageResampling = None, + do_rescale: bool | None = None, + rescale_factor: float | None = None, + do_normalize: bool | None = None, + image_mean: float | list[float] | None = None, + image_std: float | list[float] | None = None, + patch_size: int | None = None, + temporal_patch_size: int | None = None, + merge_size: int | None = None, + do_convert_rgb: bool | None = None, + return_tensors: str | TensorType | None = None, + data_format: ChannelDimension | None = ChannelDimension.FIRST, + input_data_format: str | ChannelDimension | None = None, + ): + """ + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + videos (`VideoInput`): + Video to preprocess. Expects a single or batch of videos with pixel values ranging from 0 to 255. If + passing in videos with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with + the longest edge resized to keep the input aspect ratio. + resample (`int`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only + has an effect if `do_resize` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. 
+ image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to + `True`. + min_pixels (`int`, *optional*, defaults to `self.min_pixels`): + The min pixels of the image to resize the image. + max_pixels (`int`, *optional*, defaults to `self.max_pixels`): + The max pixels of the image to resize the image. + patch_size (`int`, *optional*, defaults to `self.patch_size`): + The spatial patch size of the vision encoder. + temporal_patch_size (`int`, *optional*, defaults to `self.temporal_patch_size`): + The temporal patch size of the vision encoder. + merge_size (`int`, *optional*, defaults to `self.merge_size`): + The merge size of the vision encoder to llm encoder. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + + """ # noqa: E501 + min_pixels = min_pixels if min_pixels is not None else self.min_pixels + max_pixels = max_pixels if max_pixels is not None else self.max_pixels + + if size is not None: + if "shortest_edge" not in size or "longest_edge" not in size: + raise ValueError( + "size must contain 'shortest_edge' and 'longest_edge' keys." + ) + min_pixels = size["shortest_edge"] + elif min_pixels is not None and max_pixels is not None: + # backward compatibility: override size with min_pixels and max_pixels + # if they are provided. 
+ size = {"shortest_edge": min_pixels, "longest_edge": max_pixels} + else: + size = {**self.size} + + do_resize = do_resize if do_resize is not None else self.do_resize + + resample = resample if resample is not None else self.resample + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = ( + rescale_factor if rescale_factor is not None else self.rescale_factor + ) + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + patch_size = patch_size if patch_size is not None else self.patch_size + temporal_patch_size = ( + temporal_patch_size + if temporal_patch_size is not None + else self.temporal_patch_size + ) + merge_size = merge_size if merge_size is not None else self.merge_size + do_convert_rgb = ( + do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + ) + + if images is not None: + images = make_flat_list_of_images(images) + + if images is not None and not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + validate_preprocess_arguments( + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + ) + + data = {} + if images is not None: + pixel_values, vision_grid_thws = [], [] + for image in images: + patches, image_grid_thw = self._preprocess( + image, + do_resize=do_resize, + size=size, + resample=resample, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + patch_size=patch_size, + temporal_patch_size=temporal_patch_size, + merge_size=merge_size, + data_format=data_format, + do_convert_rgb=do_convert_rgb, + input_data_format=input_data_format, + ) + pixel_values.extend(patches) + vision_grid_thws.append(image_grid_thw) + pixel_values = np.array(pixel_values) + vision_grid_thws = np.array(vision_grid_thws) + data.update( + {"pixel_values": pixel_values, "image_grid_thw": vision_grid_thws} + ) + + # kept for BC only and should be removed after v5.0 + if videos is not None: + logger.warning( + "`HunYuanVLV1ImageProcessor` works only with image inputs " + "and doesn't process videos anymore. " + "This is a deprecated behavior and will be removed in v5.0. " + "Your videos should be forwarded to `HunYuanVLV1VideoProcessor`. 
" + ) + videos = make_batched_videos(videos) + pixel_values_videos, vision_grid_thws_videos = [], [] + for images in videos: + patches, video_grid_thw = self._preprocess( + images, + do_resize=do_resize, + size=size, + resample=resample, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + patch_size=patch_size, + temporal_patch_size=temporal_patch_size, + merge_size=merge_size, + data_format=data_format, + do_convert_rgb=do_convert_rgb, + input_data_format=input_data_format, + ) + pixel_values_videos.extend(patches) + vision_grid_thws_videos.append(video_grid_thw) + data.update( + { + "pixel_values_videos": np.array(pixel_values_videos), + "video_grid_thw": np.array(vision_grid_thws_videos), + } + ) + + return BatchFeature(data=data, tensor_type=return_tensors) + + def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None): + """ + A utility that returns number of image patches for a given image size. + + Args: + height (`int`): + Height of the input image. + width (`int`): + Width of the input image. + images_kwargs (`dict`, *optional*): + Any kwargs to override defaults of the image processor. + Returns: + `int`: Number of image patches per image. + """ + min_pixels = ( + images_kwargs["min_pixels"] + if "min_pixels" in images_kwargs + else self.size["shortest_edge"] + ) + max_pixels = ( + images_kwargs["max_pixels"] + if "max_pixels" in images_kwargs + else self.size["longest_edge"] + ) + patch_size = images_kwargs.get("patch_size", self.patch_size) + merge_size = images_kwargs.get("merge_size", self.merge_size) + + factor = patch_size * merge_size + resized_height, resized_width = smart_resize( + height, width, factor, min_pixels=min_pixels, max_pixels=max_pixels + ) + grid_h, grid_w = resized_height // patch_size, resized_width // patch_size + return grid_h * (grid_w + 1) + 2 + + +AutoImageProcessor.register("HunYuanVLImageProcessor", HunYuanVLImageProcessor) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 4a2818ab1bfd8..e7991baeaa1b8 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -43,6 +43,8 @@ class CachedRequestState: mrope_positions: torch.Tensor | None = None mrope_position_delta: int | None = None + xdrope_positions: torch.Tensor | None = None + lora_request: LoRARequest | None = None prompt_embeds: torch.Tensor | None = None diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 6a83ac14e0b3f..6413be66b141c 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -50,16 +50,21 @@ from vllm.distributed.parallel_state import ( from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase -from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding +from vllm.model_executor.layers.rotary_embedding import ( + MRotaryEmbedding, + XDRotaryEmbedding, +) from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader from vllm.model_executor.models.interfaces import ( SupportsMRoPE, SupportsMultiModal, + SupportsXDRoPE, is_mixture_of_experts, supports_eagle3, supports_mrope, supports_multimodal_pruning, supports_transcription, + supports_xdrope, ) from vllm.model_executor.models.interfaces_base import ( VllmModelForPooling, @@ -324,6 +329,7 @@ class GPUModelRunner( # 
Multi-modal data support self.mm_registry = MULTIMODAL_REGISTRY self.uses_mrope = model_config.uses_mrope + self.uses_xdrope_dim = model_config.uses_xdrope_dim self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs( model_config ) @@ -512,6 +518,13 @@ class GPUModelRunner( (3, self.max_num_tokens + 1), dtype=torch.int64 ) + # Only relevant for models using XD-RoPE (e.g, HunYuan-VL) + if self.uses_xdrope_dim > 0: + # Similar to mrope but use assigned dimension number for RoPE, 4 as default. + self.xdrope_positions = self._make_buffer( + (self.uses_xdrope_dim, self.max_num_tokens + 1), dtype=torch.int64 + ) + # None in the first PP rank. The rest are set after load_model. self.intermediate_tensors: IntermediateTensors | None = None @@ -593,10 +606,14 @@ class GPUModelRunner( if isinstance(num_tokens, int): if self.uses_mrope: return self.mrope_positions.gpu[:, :num_tokens] + if self.uses_xdrope_dim > 0: + return self.xdrope_positions.gpu[:, :num_tokens] return self.positions.gpu[:num_tokens] else: if self.uses_mrope: return self.mrope_positions.gpu[:, num_tokens] + if self.uses_xdrope_dim > 0: + return self.xdrope_positions.gpu[:, num_tokens] return self.positions.gpu[num_tokens] def _make_buffer( @@ -772,6 +789,10 @@ class GPUModelRunner( if self.uses_mrope: self._init_mrope_positions(req_state) + # Only relevant for models using XD-RoPE (e.g, HunYuan-VL) + if self.uses_xdrope_dim > 0: + self._init_xdrope_positions(req_state) + reqs_to_add.append(req_state) # Update the states of the running/resumed requests. @@ -987,6 +1008,19 @@ class GPUModelRunner( ) ) + def _init_xdrope_positions(self, req_state: CachedRequestState): + model = self.get_model() + xdrope_model = cast(SupportsXDRoPE, model) + assert req_state.prompt_token_ids is not None, ( + "XD-RoPE requires prompt_token_ids to be available." + ) + assert supports_xdrope(model), "XD-RoPE support is not implemented." + + req_state.xdrope_positions = xdrope_model.get_xdrope_input_positions( + req_state.prompt_token_ids, + req_state.mm_features, + ) + def _extract_mm_kwargs( self, scheduler_output: "SchedulerOutput", @@ -1231,6 +1265,11 @@ class GPUModelRunner( if self.uses_mrope: self._calc_mrope_positions(scheduler_output) + # Calculate XD-RoPE positions. + # Only relevant for models using XD-RoPE (e.g, HunYuan-VL) + if self.uses_xdrope_dim > 0: + self._calc_xdrope_positions(scheduler_output) + # Get token indices. 
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] @@ -1364,6 +1403,12 @@ class GPUModelRunner( self.mrope_positions.cpu[:, :total_num_scheduled_tokens], non_blocking=True, ) + elif self.uses_xdrope_dim > 0: + # Only relevant for models using XD-RoPE (e.g, HunYuan-VL) + self.xdrope_positions.gpu[:, :total_num_scheduled_tokens].copy_( + self.xdrope_positions.cpu[:, :total_num_scheduled_tokens], + non_blocking=True, + ) else: # Common case (1D positions) self.positions.copy_to_gpu(total_num_scheduled_tokens) @@ -1793,6 +1838,53 @@ class GPUModelRunner( mrope_pos_ptr += completion_part_len + def _calc_xdrope_positions(self, scheduler_output: "SchedulerOutput"): + xdrope_pos_ptr = 0 + for index, req_id in enumerate(self.input_batch.req_ids): + req = self.requests[req_id] + assert req.xdrope_positions is not None + + num_computed_tokens = self.input_batch.num_computed_tokens_cpu[index] + num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id] + num_prompt_tokens = length_from_prompt_token_ids_or_embeds( + req.prompt_token_ids, req.prompt_embeds + ) + + if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens: + prompt_part_len = max(0, num_prompt_tokens - num_computed_tokens) + completion_part_len = max(0, num_scheduled_tokens - prompt_part_len) + else: + prompt_part_len = num_scheduled_tokens + completion_part_len = 0 + + assert num_scheduled_tokens == prompt_part_len + completion_part_len + + if prompt_part_len > 0: + # prompt's xdrope_positions are pre-computed + dst_start = xdrope_pos_ptr + dst_end = xdrope_pos_ptr + prompt_part_len + src_start = num_computed_tokens + src_end = num_computed_tokens + prompt_part_len + + self.xdrope_positions.cpu[:, dst_start:dst_end] = req.xdrope_positions[ + :, src_start:src_end + ] + xdrope_pos_ptr += prompt_part_len + + if completion_part_len > 0: + # compute completion's xdrope_positions on-the-fly + dst_start = xdrope_pos_ptr + dst_end = xdrope_pos_ptr + completion_part_len + + XDRotaryEmbedding.get_next_input_positions_tensor( + out=self.xdrope_positions.np, + out_offset=dst_start, + context_len=num_computed_tokens + prompt_part_len, + num_new_tokens=completion_part_len, + ) + + xdrope_pos_ptr += completion_part_len + def _calc_spec_decode_metadata( self, num_draft_tokens: np.ndarray, @@ -2037,6 +2129,7 @@ class GPUModelRunner( req_start_idx = 0 should_sync_mrope_positions = False + should_sync_xdrope_positions = False for req_id in self.input_batch.req_ids: mm_embeds_req: list[torch.Tensor] = [] @@ -2110,6 +2203,10 @@ class GPUModelRunner( self._calc_mrope_positions(scheduler_output) self.mrope_positions.copy_to_gpu(total_num_scheduled_tokens) + if should_sync_xdrope_positions: + self._calc_xdrope_positions(scheduler_output) + self.xdrope_positions.copy_to_gpu(total_num_scheduled_tokens) + return mm_embeds, is_mm_embed def get_model(self) -> nn.Module: @@ -2384,8 +2481,11 @@ class GPUModelRunner( input_ids = self.input_ids.gpu[:num_input_tokens] inputs_embeds = None model_kwargs = self._init_model_kwargs(num_input_tokens) + if self.uses_mrope: positions = self.mrope_positions.gpu[:, :num_input_tokens] + elif self.uses_xdrope_dim > 0: + positions = self.xdrope_positions.gpu[:, :num_input_tokens] else: positions = self.positions.gpu[:num_input_tokens] @@ -3824,6 +3924,8 @@ class GPUModelRunner( if self.uses_mrope: positions = self.mrope_positions.gpu[:, :num_tokens_after_padding] + elif self.uses_xdrope_dim > 0: + positions = self.xdrope_positions.gpu[:, 
:num_tokens_after_padding] else: positions = self.positions.gpu[:num_tokens_after_padding]
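
Reviewer note (not part of the patch): for readers new to XD-RoPE, below is a minimal
standalone sketch of how the 4-row [text, w, h, t] position ids are laid out for a
prompt containing one image, mirroring the logic in `HunYuanVLProcessor.__call__` above.
The helper name and the toy numbers are illustrative assumptions only, using
`merge_size=2` and a single image whose placeholder span starts right after the
image-start token.

import torch

def toy_xdrope_positions(
    seq_len: int, image_start: int, grid_h: int, grid_w: int, merge_size: int = 2
) -> torch.Tensor:
    """Return a (4, seq_len) tensor of [text, w, h, t] position ids."""
    patch_h, patch_w = grid_h // merge_size, grid_w // merge_size
    pos = torch.arange(seq_len)
    pos_w, pos_h, pos_t = pos.clone(), pos.clone(), pos.clone()
    span = patch_h * (patch_w + 1)  # image tokens, incl. one newline slot per row
    start = image_start + 1         # overrides begin right after the image-start token
    pos_w[start : start + span] = torch.tensor(list(range(patch_w + 1)) * patch_h)
    pos_h[start : start + span] = torch.arange(patch_h).repeat_interleave(patch_w + 1)
    pos_t[start : start + span] = 0  # a single image occupies one temporal step
    return torch.stack([pos, pos_w, pos_h, pos_t])

# A 12-token toy prompt with a 4x4 patch grid (merge_size=2 -> 2 x (2 + 1) = 6 image slots):
print(toy_xdrope_positions(seq_len=12, image_start=2, grid_h=4, grid_w=4))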