Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-01-26 15:44:30 +08:00)
[Model] Add HunyuanOCR support (#29327)
Signed-off-by: manayang <jackmanayang@gmail.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: sergeywang <sergeywang@tencent.com>
Co-authored-by: manayang <jackmanayang@gmail.com>
Co-authored-by: manayang <manayang@tencent.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
parent 87185c88d5
commit 92effb07a4
@@ -680,6 +680,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `Glm4vMoeForConditionalGeneration` | GLM-4.5V | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.5V`, etc. | ✅︎ | ✅︎ |
| `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ |
| `H2OVLChatModel` | H2OVL | T + I<sup>E+</sup> | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ |
| `HunYuanVLForConditionalGeneration` | HunyuanOCR | T + I<sup>E+</sup> | `tencent/HunyuanOCR`, etc. | ✅︎ | ✅︎ |
| `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3`, etc. | ✅︎ | |
| `InternS1ForConditionalGeneration` | Intern-S1 | T + I<sup>E+</sup> + V<sup>E+</sup> | `internlm/Intern-S1`, `internlm/Intern-S1-mini`, etc. | ✅︎ | ✅︎ |
| `InternVLChatModel` | InternVL 3.5, InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + I<sup>E+</sup> + (V<sup>E+</sup>) | `OpenGVLab/InternVL3_5-14B`, `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ |
@@ -538,6 +538,31 @@ def run_h2ovl(questions: list[str], modality: str) -> ModelRequestData:
    )


# HunyuanOCR
def run_hunyuan_vl(questions: list[str], modality: str) -> ModelRequestData:
    assert modality == "image"

    model_name = "tencent/HunyuanOCR"

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=8192,
        limit_mm_per_prompt={modality: 1},
    )

    placeholder = "<|hy_place▁holder▁no▁100|><|hy_place▁holder▁no▁102|><|hy_place▁holder▁no▁101|>"  # noqa: E501
    prompts = [
        f"<|hy_begin▁of▁sentence|>{placeholder}{question}<|hy_User|>"
        for question in questions
    ]

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
        stop_token_ids=None,
    )


# naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B
def run_hyperclovax_seed_vision(
    questions: list[str], modality: str
@@ -1820,6 +1845,7 @@ model_example_map = {
    "glm4_5v": run_glm4_5v,
    "glm4_5v_fp8": run_glm4_5v_fp8,
    "h2ovl_chat": run_h2ovl,
    "hunyuan_vl": run_hunyuan_vl,
    "hyperclovax_seed_vision": run_hyperclovax_seed_vision,
    "idefics3": run_idefics3,
    "interns1": run_interns1,
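For reference, the same engine arguments and prompt format used by `run_hunyuan_vl` above can be driven directly through the `LLM` API. This is a minimal illustrative sketch, not part of the diff; the question text and image path are placeholders:

```python
from PIL import Image

from vllm import LLM, SamplingParams

llm = LLM(
    model="tencent/HunyuanOCR",
    max_model_len=8192,
    limit_mm_per_prompt={"image": 1},
)

placeholder = "<|hy_place▁holder▁no▁100|><|hy_place▁holder▁no▁102|><|hy_place▁holder▁no▁101|>"
prompt = f"<|hy_begin▁of▁sentence|>{placeholder}Extract all text in the image.<|hy_User|>"

# "sample_receipt.png" is a placeholder for any local test image.
outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": Image.open("sample_receipt.png")}},
    SamplingParams(temperature=0.0, max_tokens=512),
)
print(outputs[0].outputs[0].text)
```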
@@ -626,6 +626,10 @@ _MULTIMODAL_EXAMPLE_MODELS = {
        "naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B",
        trust_remote_code=True,
    ),
    "HunYuanVLForConditionalGeneration": _HfExamplesInfo(
        "tencent/HunyuanOCR",
        is_available_online=False,
    ),
    "Idefics3ForConditionalGeneration": _HfExamplesInfo(
        "HuggingFaceM4/Idefics3-8B-Llama3",
        extras={"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"},
@@ -33,6 +33,7 @@ from vllm.transformers_utils.config import (
    try_get_safetensors_metadata,
    try_get_tokenizer_config,
    uses_mrope,
    uses_xdrope_dim,
)
from vllm.transformers_utils.gguf_utils import (
    maybe_patch_hf_config_from_gguf,
@@ -1615,6 +1616,10 @@ class ModelConfig:
    def uses_mrope(self) -> bool:
        return uses_mrope(self.hf_config)

    @property
    def uses_xdrope_dim(self) -> int:
        return uses_xdrope_dim(self.hf_config)

    @property
    def is_multimodal_model(self) -> bool:
        return self.multimodal_config is not None
@@ -17,6 +17,7 @@ from .llama4_vision_rope import Llama4VisionRotaryEmbedding
from .mrope import MRotaryEmbedding
from .ntk_scaling_rope import NTKScalingRotaryEmbedding
from .phi3_long_rope_scaled_rope import Phi3LongRoPEScaledRotaryEmbedding
from .xdrope import XDRotaryEmbedding
from .yarn_scaling_rope import YaRNScalingRotaryEmbedding

_ROPE_DICT: dict[tuple, RotaryEmbedding] = {}
@@ -184,6 +185,18 @@ def get_rope(
                raise ValueError(
                    "Dynamic rope scaling must contain either 'alpha' or 'factor' field"
                )
        elif scaling_type == "xdrope":
            scaling_alpha = rope_parameters["alpha"]
            rotary_emb = XDRotaryEmbedding(
                head_size,
                rotary_dim,
                max_position,
                base,
                is_neox_style,
                scaling_alpha,
                dtype,
                xdrope_section=rope_parameters["xdrope_section"],
            )
        elif scaling_type == "yarn":
            scaling_factor = rope_parameters["factor"]
            original_max_position = rope_parameters["original_max_position_embeddings"]
vllm/model_executor/layers/rotary_embedding/xdrope.py (new file, 102 lines)
@@ -0,0 +1,102 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import numpy as np
import torch

from .common import apply_rotary_emb_dispatch
from .dynamic_ntk_alpha_rope import DynamicNTKAlphaRotaryEmbedding


class XDRotaryEmbedding(DynamicNTKAlphaRotaryEmbedding):
    """DynamicNTKAlphaRotaryEmbedding extended with MultiModal(XD) Sections.

    Based on the original DynamicNTKAlphaRotaryEmbedding implementation.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        scaling_alpha: float,
        dtype: torch.dtype,
        xdrope_section: list[int],
    ) -> None:
        self.xdrope_section = xdrope_section
        super().__init__(
            head_size,
            rotary_dim,
            max_position_embeddings,
            base,
            is_neox_style,
            scaling_alpha,
            dtype,
        )

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor | None = None,
        offsets: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """PyTorch-native implementation equivalent to forward().

        Args:
            positions:
                [4, num_tokens] (P/W/H/T positions with multimodal inputs)
            query: [num_tokens, num_heads * head_size]
            key: [num_tokens, num_kv_heads * head_size]
        """
        assert positions.ndim == 2
        assert key is not None

        num_tokens = positions.shape[-1]
        cos_sin = self.cos_sin_cache[positions]
        cos, sin = cos_sin.chunk(2, dim=-1)
        cos = torch.cat(
            [m[i] for i, m in enumerate(cos.split(self.xdrope_section, dim=-1))], dim=-1
        )
        sin = torch.cat(
            [m[i] for i, m in enumerate(sin.split(self.xdrope_section, dim=-1))], dim=-1
        )

        query_shape = query.shape
        query = query.view(num_tokens, -1, self.head_size)
        query_rot = query[..., : self.rotary_dim]
        query_pass = query[..., self.rotary_dim :]
        query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style)
        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

        key_shape = key.shape
        key = key.view(num_tokens, -1, self.head_size)
        key_rot = key[..., : self.rotary_dim]
        key_pass = key[..., self.rotary_dim :]
        key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style)
        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
        return query, key

    @staticmethod
    def get_next_input_positions(
        context_len: int,
        seq_len: int,
        xd_sections: int = 4,
    ) -> list[list[int]]:
        return [list(range(context_len, seq_len)) for _ in range(xd_sections)]

    @staticmethod
    def get_next_input_positions_tensor(
        out: np.ndarray,
        out_offset: int,
        context_len: int,
        num_new_tokens: int,
    ):
        values = np.arange(
            context_len,
            context_len + num_new_tokens,
            dtype=out.dtype,
        )
        out[:, out_offset : out_offset + num_new_tokens] = values
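The key difference from a plain rotary embedding is the per-stream slicing of the cached cos/sin tables driven by `xdrope_section`. A standalone illustration of that slicing (the section sizes and shapes below are made up for clarity, not taken from the model config):

```python
import torch

xdrope_section = [8, 8, 8, 8]         # must sum to the cached rotary half-dim (32 here)
num_tokens = 5
cos = torch.randn(4, num_tokens, 32)  # cos table gathered with [P, W, H, T] positions

# Stream i contributes only its own slice of the rotary dimensions,
# mirroring the torch.cat(...split(...)) pattern in XDRotaryEmbedding.forward().
merged = torch.cat(
    [chunk[i] for i, chunk in enumerate(cos.split(xdrope_section, dim=-1))],
    dim=-1,
)
assert merged.shape == (num_tokens, 32)
```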
@@ -576,7 +576,16 @@ class HunYuanDecoderLayer(nn.Module):
        return hidden_states, residual, ori_kv_states


-@support_torch_compile
+@support_torch_compile(
+    dynamic_arg_dims={
+        "input_ids": 0,
+        # positions is of shape (xd, seq_len) if xdrope is enabled for hunyuan-vl,
+        # otherwise (seq_len, ).
+        "positions": -1,
+        "intermediate_tensors": 0,
+        "inputs_embeds": 0,
+    }
+)
class HunYuanModel(nn.Module):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
vllm/model_executor/models/hunyuan_vision.py (new file, 1028 lines)
File diff suppressed because it is too large.
@@ -1047,7 +1047,7 @@ class SupportsMRoPE(Protocol):
    supports_mrope: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports M-RoPE.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
@@ -1088,3 +1088,52 @@ def supports_mrope(
    model: type[object] | object,
) -> TypeIs[type[SupportsMRoPE]] | TypeIs[SupportsMRoPE]:
    return isinstance(model, SupportsMRoPE)


@runtime_checkable
class SupportsXDRoPE(Protocol):
    """The interface required for all models that support XD-RoPE."""

    supports_xdrope: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports XD-RoPE.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """

    def get_xdrope_input_positions(
        self,
        input_tokens: list[int],
        mm_features: list["MultiModalFeatureSpec"],
    ) -> torch.Tensor:
        """
        Get XD-RoPE input positions for this specific model.

        This method should be implemented by each model that supports XD-RoPE
        to provide model-specific logic for computing input positions.

        Args:
            input_tokens: List of input token IDs
            mm_features: Information about each multi-modal data item

        Returns:
            llm_positions: Tensor of shape `[xdrope_dim, num_tokens]` with
                4D(P/W/H/T) or 3D(W/H/T) positions.
        """
        ...


@overload
def supports_xdrope(model: type[object]) -> TypeIs[type[SupportsXDRoPE]]: ...


@overload
def supports_xdrope(model: object) -> TypeIs[SupportsXDRoPE]: ...


def supports_xdrope(
    model: type[object] | object,
) -> TypeIs[type[SupportsXDRoPE]] | TypeIs[SupportsXDRoPE]:
    return isinstance(model, SupportsXDRoPE)
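A sketch of how a model opts into this interface and how the helper is then used. The class below and its text-only fallback are illustrative only, not part of the diff:

```python
import torch

from vllm.model_executor.models.interfaces import SupportsXDRoPE, supports_xdrope


class MyVLForConditionalGeneration(SupportsXDRoPE):
    def get_xdrope_input_positions(self, input_tokens, mm_features) -> torch.Tensor:
        # Text-only fallback: identical P/W/H/T position streams.
        n = len(input_tokens)
        return torch.arange(n).unsqueeze(0).expand(4, n)


# The runtime check works on both the class and an instance.
assert supports_xdrope(MyVLForConditionalGeneration)
assert supports_xdrope(MyVLForConditionalGeneration())
```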
@@ -287,6 +287,10 @@ _MULTIMODAL_MODELS = {
        "GraniteSpeechForConditionalGeneration",
    ),
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "HunYuanVLForConditionalGeneration": (
        "hunyuan_vision",
        "HunYuanVLForConditionalGeneration",
    ),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "NemotronH_Nano_VL_V2": ("nano_nemotron_vl", "NemotronH_Nano_VL_V2"),
    "OpenCUAForConditionalGeneration": (
@@ -86,6 +86,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
    deepseek_vl_v2="DeepseekVLV2Config",
    deepseek_v32="DeepseekV3Config",
    flex_olmo="FlexOlmoConfig",
    hunyuan_vl="HunYuanVLConfig",
    kimi_linear="KimiLinearConfig",
    kimi_vl="KimiVLConfig",
    RefinedWeb="RWConfig",  # For tiiuae/falcon-40b(-instruct)
@@ -549,6 +550,23 @@ def thinker_uses_mrope(config: PretrainedConfig) -> bool:
    return uses_mrope(thinker_text_config)


def uses_xdrope_dim(config: PretrainedConfig) -> int:
    """Detect if the model with this config uses XD-RoPE."""
    xdrope_section = getattr(config, "xdrope_section", None)
    if xdrope_section is not None and isinstance(xdrope_section, list):
        return len(xdrope_section)
    rope_scaling = getattr(config, "rope_scaling", None)
    if rope_scaling is None:
        return 0

    if isinstance(rope_scaling, dict) and "xdrope_section" in rope_scaling:
        xdrope_section = rope_scaling["xdrope_section"]
        if xdrope_section is not None and isinstance(xdrope_section, list):
            return len(xdrope_section)

    return 0


def is_encoder_decoder(config: PretrainedConfig) -> bool:
    """Detect if the model with this config is used as an encoder/decoder."""
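A quick illustration of the detection logic (the `rope_scaling` values are hypothetical; the import path follows the diff above):

```python
from types import SimpleNamespace

from vllm.transformers_utils.config import uses_xdrope_dim

cfg = SimpleNamespace(
    rope_scaling={"type": "dynamic", "alpha": 1000.0, "xdrope_section": [24, 20, 20, 0]}
)
assert uses_xdrope_dim(cfg) == 4  # one dimension per P/W/H/T section
assert uses_xdrope_dim(SimpleNamespace(rope_scaling=None)) == 0
```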
@@ -23,6 +23,11 @@ from vllm.transformers_utils.configs.eagle import EAGLEConfig
# `FalconConfig` class from the official HuggingFace transformers library.
from vllm.transformers_utils.configs.falcon import RWConfig
from vllm.transformers_utils.configs.flex_olmo import FlexOlmoConfig
from vllm.transformers_utils.configs.hunyuan_vl import (
    HunYuanVLConfig,
    HunYuanVLTextConfig,
    HunYuanVLVisionConfig,
)
from vllm.transformers_utils.configs.jais import JAISConfig
from vllm.transformers_utils.configs.kimi_linear import KimiLinearConfig
from vllm.transformers_utils.configs.kimi_vl import KimiVLConfig
@@ -53,6 +58,9 @@ __all__ = [
    "DotsOCRConfig",
    "EAGLEConfig",
    "FlexOlmoConfig",
    "HunYuanVLConfig",
    "HunYuanVLTextConfig",
    "HunYuanVLVisionConfig",
    "RWConfig",
    "JAISConfig",
    "Lfm2MoeConfig",
vllm/transformers_utils/configs/hunyuan_vl.py (new file, 322 lines)
@@ -0,0 +1,322 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# adapted from https://github.com/ManaEstras/transformers/blob/v4.57.1.hyvl/src/transformers/models/hunyuan_vl/configuration_hunyuan_vl.py

from transformers import PretrainedConfig


class HunYuanVLVisionConfig(PretrainedConfig):
    model_type = "hunyuan_vl"
    base_config_key = "vision_config"

    def __init__(
        self,
        hidden_act="gelu",
        hidden_size=1152,
        intermediate_size=4304,
        interpolate_mode="bilinear",
        rms_norm_eps=1e-05,
        learnable_mlp_pooling_size=0,
        num_attention_heads=16,
        num_key_value_heads=None,
        num_channels=3,
        num_hidden_layers=27,
        out_hidden_size=4096,
        patch_size=16,
        remove_prenorm=True,
        spatial_merge_size=2,
        temporal_patch_size=1,
        resize_resolution=2048,
        img_max_token_num=4096,
        max_image_size=2048,
        video_max_image_size=768,
        video_min_image_size=256,
        min_image_size=512,
        anyres_vit_max_image_size=2048,
        max_vit_seq_len=16384,
        text_hidden_size=3072,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.hidden_act = hidden_act
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.interpolate_mode = interpolate_mode
        self.learnable_mlp_pooling_size = learnable_mlp_pooling_size
        self.num_attention_heads = num_attention_heads
        if not num_key_value_heads:
            self.num_key_value_heads = num_attention_heads
        else:
            self.num_key_value_heads = num_key_value_heads
        self.num_channels = num_channels
        self.num_hidden_layers = num_hidden_layers
        self.out_hidden_size = out_hidden_size
        self.patch_size = patch_size
        self.remove_prenorm = remove_prenorm
        self.spatial_merge_size = spatial_merge_size
        self.temporal_patch_size = temporal_patch_size
        self.rms_norm_eps = rms_norm_eps

        self.resize_resolution = resize_resolution
        self.img_max_token_num = img_max_token_num
        self.max_image_size = max_image_size
        self.min_image_size = min_image_size
        self.video_max_image_size = video_max_image_size
        self.video_min_image_size = video_min_image_size
        self.anyres_vit_max_image_size = anyres_vit_max_image_size
        self.max_vit_seq_len = max_vit_seq_len
        self.text_hidden_size = text_hidden_size


class HunYuanVLTextConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`HunYuanVLTextConfig`]. It is used to instantiate a
    HunYuan model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the HunYuan-7B.
    Hunyuan-7B-Instruct [tencent/Hunyuan-7B-Instruct](https://huggingface.co/tencent/Hunyuan-7B-Instruct).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 290943):
            Vocabulary size of the HunYuan model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`HunYuanVLTextConfig`]
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 11008):
            Dimension of the MLP representations or shared MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA); otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
            `num_attention_heads`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 2048):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*, defaults to 0):
            Padding token id.
        bos_token_id (`int`, *optional*, defaults to 1):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 2):
            End of stream token id.
        eod_token_id (int, *optional*, defaults to 3):
            Token ID representing the end-of-document marker. Used to indicate the termination of a text sequence.
            Example: In multi-document processing, this token helps the model distinguish between separate documents.
        pretraining_tp (`int`, *optional*, defaults to 1):
            Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
            document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
            necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
            issue](https://github.com/pytorch/pytorch/issues/76232).
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
            strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
            `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
            `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
            these scaling strategies behave:
            https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
            experimental feature, subject to breaking API changes in future versions.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        head_dim (`int`, *optional*, defaults to 128):
            The attention head dimension.
    """  # noqa: E501

    model_type = "hunyuan_vl_text"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=290943,
        hidden_size=4096,
        intermediate_size: int = 11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-5,
        use_cache=True,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        eod_token_id=3,
        pretraining_tp=1,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        head_dim=None,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.head_dim = head_dim
        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        # self._rope_scaling_validation()  # TODO: Need validation?
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    def _rope_scaling_validation(self):
        """
        Validate the `rope_scaling` configuration.
        """
        if self.rope_scaling is None:
            return

        if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
            raise ValueError(
                "`rope_scaling` must be a dictionary with two fields, `type` and "
                f"`factor` or `type` and `alpha`, got {self.rope_scaling}"
            )
        rope_scaling_type = self.rope_scaling.get("type", None)
        rope_scaling_factor = self.rope_scaling.get("factor", None)
        rope_scaling_alpha = self.rope_scaling.get("alpha", None)
        if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
            raise ValueError(
                "`rope_scaling`'s type field must be one of ['linear', 'dynamic'], "
                f"got {rope_scaling_type}"
            )
        if rope_scaling_factor is None and rope_scaling_alpha is None:
            raise ValueError(
                "`rope_scaling` must contain either a `factor` or an `alpha` field, "
                "got neither"
            )
        if rope_scaling_factor is not None and (
            not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0
        ):
            raise ValueError(
                "`rope_scaling`'s factor field must be a float > 1.0, "
                f"got {rope_scaling_factor}"
            )
        if rope_scaling_alpha is not None and (
            not isinstance(rope_scaling_alpha, float) or rope_scaling_alpha <= 1.0
        ):
            raise ValueError(
                "`rope_scaling`'s alpha field must be a float > 1.0, "
                f"got {rope_scaling_alpha}"
            )


class HunYuanVLConfig(PretrainedConfig):
    model_type = "hunyuan_vl"
    sub_configs = {
        "vision_config": HunYuanVLVisionConfig,
        "text_config": HunYuanVLTextConfig,
    }
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        text_config=None,
        vision_config=None,
        im_start_id=120118,
        im_end_id=120119,
        image_token_id=120120,
        im_newline_id=120121,
        video_start_id=120122,
        video_end_id=120123,
        **kwargs,
    ):
        # We need to init super() here so that it does not reset values
        # that are in text config to the BaseClass defaults. The Base
        # config has many text related defaults and not all defaults are
        # same as for `HunYuanVLTextConfig`.
        super().__init__(**kwargs)

        if isinstance(vision_config, dict):
            self.vision_config = self.sub_configs["vision_config"](**vision_config)
        elif vision_config is None:
            self.vision_config = self.sub_configs["vision_config"]()

        if isinstance(text_config, dict):
            self.text_config = self.sub_configs["text_config"](**text_config)
        elif text_config is None:
            # For BC use all kwargs to init `TextConfig`
            self.text_config = self.sub_configs["text_config"](**kwargs)

        self.image_token_id = image_token_id
        self.im_start_id = im_start_id
        self.im_end_id = im_end_id
        self.im_newline_id = im_newline_id
        self.video_start_id = video_start_id
        self.video_end_id = video_end_id

        self.vision_config.text_hidden_size = self.text_config.hidden_size

        # Attention implementation to use. It sets it recursively on sub-configs
        # so we call it again in the end.
        self._attn_implementation = kwargs.pop("attn_implementation", None)

    def __setattr__(self, key, value):
        if (
            (text_config := super().__getattribute__("__dict__").get("text_config"))
            is not None
            and key not in ["dtype", "_attn_implementation_internal"]
            and key in text_config.__dict__
        ):
            setattr(text_config, key, value)
        else:
            super().__setattr__(key, value)

    def __getattribute__(self, key):
        if "text_config" in super().__getattribute__("__dict__") and key not in [
            "_name_or_path",
            "model_type",
            "dtype",
            "_attn_implementation_internal",
        ]:
            text_config = super().__getattribute__("text_config")
            if key in text_config.__dict__:
                return getattr(text_config, key)

        return super().__getattribute__(key)
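The `__setattr__`/`__getattribute__` overrides make text-model fields resolve through `text_config`, so code that reads attributes such as `hidden_size` off the top-level config still works. A small sketch of that behaviour (the override values are arbitrary and purely illustrative):

```python
from vllm.transformers_utils.configs.hunyuan_vl import HunYuanVLConfig

cfg = HunYuanVLConfig(text_config={"hidden_size": 3072, "num_hidden_layers": 28})

# Reads fall through to the text sub-config...
assert cfg.hidden_size == cfg.text_config.hidden_size == 3072

# ...and writes to known text fields are routed there as well.
cfg.rope_theta = 5e6
assert cfg.text_config.rope_theta == 5e6
```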
@@ -9,7 +9,15 @@ reasons:
"""

from vllm.transformers_utils.processors.deepseek_vl2 import DeepseekVLV2Processor
from vllm.transformers_utils.processors.hunyuan_vl import HunYuanVLProcessor
from vllm.transformers_utils.processors.hunyuan_vl_image import HunYuanVLImageProcessor
from vllm.transformers_utils.processors.ovis import OvisProcessor
from vllm.transformers_utils.processors.ovis2_5 import Ovis2_5Processor

-__all__ = ["DeepseekVLV2Processor", "OvisProcessor", "Ovis2_5Processor"]
+__all__ = [
+    "DeepseekVLV2Processor",
+    "HunYuanVLProcessor",
+    "HunYuanVLImageProcessor",
+    "OvisProcessor",
+    "Ovis2_5Processor",
+]
vllm/transformers_utils/processors/hunyuan_vl.py (new file, 233 lines)
@@ -0,0 +1,233 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# adapted from https://github.com/ManaEstras/transformers/blob/v4.57.1.hyvl/src/transformers/models/hunyuan_vl/processing_hunyuan_vl.py

import numpy as np
import torch
from transformers import AutoProcessor
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput
from transformers.processing_utils import ProcessorMixin
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from transformers.video_utils import VideoInput


class HunYuanVLProcessor(ProcessorMixin):
    attributes = ["image_processor", "tokenizer"]
    valid_kwargs = ["chat_template"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"  # ("AutoTokenizer", None)

    def __init__(
        self,
        image_processor=None,
        tokenizer=None,
        video_processor=None,
        chat_template=None,
        **kwargs,
    ):
        # TODO Fix the init
        self.tokenizer = tokenizer
        self.image_token_id = 120120  # self.tokenizer.image_token_id
        self.image_token = self.tokenizer.convert_ids_to_tokens(self.image_token_id)
        self.im_start_token_id = 120118  # self.tokenizer.im_start_id
        self.im_start_token = self.tokenizer.convert_ids_to_tokens(
            self.im_start_token_id
        )
        self.im_end_token_id = 120119  # self.tokenizer.im_end_id
        self.im_end_token = self.tokenizer.convert_ids_to_tokens(self.im_end_token_id)
        self.placeholder_token = self.tokenizer.convert_ids_to_tokens(
            self.tokenizer.vocab_size - 1
        )
        self.pad_id = 120002  # self.tokenizer.pad_token_id

        super().__init__(
            image_processor, tokenizer, video_processor, chat_template=chat_template
        )

    def __call__(
        self,
        images: ImageInput = None,
        text: TextInput
        | PreTokenizedInput
        | list[TextInput]
        | list[PreTokenizedInput] = None,
        videos: VideoInput = None,
        **kwargs,
    ) -> BatchFeature:
        image_inputs = {}
        if images is not None:
            image_inputs = self.image_processor(images=images)
            image_grid_thw = image_inputs["image_grid_thw"]

        if not isinstance(text, list):
            text = [text]

        text = text.copy()  # below lines change text in-place

        image_tokens_cumsum = [0]
        if images is not None:
            index = 0
            for i in range(len(text)):
                while self.image_token in text[i]:
                    grid_h, grid_w = image_grid_thw[index][-2:]
                    patch_h = grid_h // self.image_processor.merge_size
                    patch_w = grid_w // self.image_processor.merge_size
                    num_image_tokens = patch_h * (patch_w + 1) + 2
                    image_tokens_cumsum.append(
                        image_tokens_cumsum[-1] + num_image_tokens
                    )
                    # text[i] = text[i].replace(self.image_token, self.im_start_token + self.placeholder_token * num_image_tokens + self.im_end_token, 1)  # noqa: E501
                    text[i] = text[i].replace(
                        self.image_token, self.placeholder_token * num_image_tokens, 1
                    )
                    index += 1
                text[i] = text[i].replace(self.placeholder_token, self.image_token)
                # text[i] = self.tokenizer.bos_token + text[i]

        text_inputs = self.tokenizer(text, add_special_tokens=False, **kwargs)
        self._check_special_mm_tokens(text, text_inputs, modalities=["image"])

        input_ids = text_inputs["input_ids"]
        position_ids = torch.arange(len(input_ids[0]))
        position_ids_w = torch.arange(len(input_ids[0]))
        position_ids_h = torch.arange(len(input_ids[0]))
        position_ids_t = torch.arange(len(input_ids[0]))

        if images is not None:
            image_token_pos_indices = torch.where(input_ids[0] == self.image_token_id)[
                0
            ]
            for i in range(len(image_grid_thw)):
                grid_h, grid_w = image_grid_thw[i][-2:]
                patch_h = grid_h // self.image_processor.merge_size
                patch_w = grid_w // self.image_processor.merge_size
                start_pos = image_token_pos_indices[image_tokens_cumsum[i]].item() + 1
                replace_num = (patch_w + 1) * patch_h
                position_ids_w[start_pos : start_pos + replace_num] = torch.tensor(
                    list(range(patch_w + 1)) * patch_h, dtype=torch.int64
                )
                patch_h_list = []
                for h in range(patch_h):
                    patch_h_list += [h] * (patch_w + 1)
                position_ids_h[start_pos : start_pos + replace_num] = torch.tensor(
                    patch_h_list, dtype=torch.int64
                )
                position_ids_t[start_pos : start_pos + replace_num] = 0

        position_ids = torch.stack(
            [position_ids, position_ids_w, position_ids_h, position_ids_t]
        ).unsqueeze(0)
        text_inputs["position_ids"] = position_ids

        attention_mask = input_ids.ne(self.pad_id)
        text_inputs["attention_mask"] = attention_mask
        text_inputs["imgs_pos"] = [self.get_imgs_pos(input_ids)]
        # image_inputs["imgs"] = [[image_inputs["pixel_values"]]]

        return_tensors = kwargs.pop("return_tensors", None)
        return BatchFeature(
            data={**text_inputs, **image_inputs},
            tensor_type=return_tensors,
        )

    def batch_decode(self, *args, **kwargs):
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        return self.tokenizer.decode(*args, **kwargs)

    def post_process_image_text_to_text(
        self,
        generated_outputs,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
        **kwargs,
    ):
        assert 0

    def apply_chat_template(self, *args, **kwargs):
        token_ids = self.tokenizer.apply_chat_template(*args, **kwargs)
        return token_ids

    def get_imgs_pos(self, doc_ids):
        doc_ids = np.array(doc_ids, dtype=np.int64)
        img_begin_index = np.where(doc_ids == self.im_start_token_id)[0]
        img_end_index = np.where(doc_ids == self.im_end_token_id)[0]
        imgs_pos = np.concatenate(
            (
                np.reshape(img_begin_index + 1, (-1, 1)),
                np.reshape(img_end_index, (-1, 1)),
            ),
            axis=-1,
        ).tolist()
        return imgs_pos

    @property
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))


def split_image_into_patch_blocks(
    pixel_values: torch.Tensor,  # shape: [batch_size, 3, H, W]
    patch_size: int = 16,  # e.g. 16
    adaptor_patch_div: int = 4,  # e.g. 4 --> each patch_size is cut into 4x4 small regions, i.e. patch_size // 4  # noqa: E501
) -> torch.Tensor:
    """
    Split the input image tensor (supporting batch) into large patches of size `patch_size`,
    and then further divide each large patch into smaller regions of size
    (patch_size // adaptor_patch_div) x (patch_size // adaptor_patch_div).
    Each small region is extracted as a tensor of shape [3, patch_size, patch_size].
    The final output contains all such small region tensors.

    Args:
        pixel_values: Input image tensor of shape [batch_size, 3, H, W].
        patch_size: Size of the large patch, e.g., 16.
        adaptor_patch_div: Each large patch is divided into
            (patch_size // adaptor_patch_div) x (patch_size // adaptor_patch_div)
            smaller regions.

    Returns:
        patches: A tensor of shape [N, 3, patch_size, patch_size],
            where N = batch_size * (H // patch_size) * (W // patch_size) * (patch_size // adaptor_patch_div)^2.
            Each element in the batch corresponds to one small image region.
    """  # noqa: E501
    batch_size, channels, height, width = pixel_values.shape
    assert channels == 3, "Pixel values must have 3 channels in dim=1"
    assert height % patch_size == 0 and width % patch_size == 0, (
        "H and W must be divisible by patch_size"
    )

    patch_height_num = height // patch_size
    patch_width_num = width // patch_size

    # Reshape to [B, 3, ph, ps, pw, ps]
    img = pixel_values.reshape(
        batch_size, 3, patch_height_num, patch_size, patch_width_num, patch_size
    )

    # Further split each psxps patch into (ps//aps)x(ps//aps) small regions
    img = img.reshape(
        batch_size,
        3,
        patch_height_num,
        patch_size // adaptor_patch_div,  # ps // aps
        adaptor_patch_div,
        patch_width_num,
        patch_size // adaptor_patch_div,  # ps // aps
        adaptor_patch_div,
    )

    # Permute to group the small regions: [B, ph, pw, ps//aps, ps//aps, 3, aps, aps]
    img = img.permute(0, 2, 5, 3, 6, 1, 4, 7)

    # Reshape into [B * ph * pw * (ps//aps)^2, 3, patch_size, patch_size]
    patches = img.reshape(-1, 3, patch_size, patch_size)

    return patches


AutoProcessor.register("HunYuanVLProcessor", HunYuanVLProcessor)
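The placeholder expansion in `HunYuanVLProcessor.__call__` is driven by a simple per-image token count; a worked example with an illustrative grid and the default `merge_size = 2`:

```python
grid_h, grid_w = 62, 44  # one (h, w) entry of image_grid_thw[..., -2:]
merge_size = 2
patch_h, patch_w = grid_h // merge_size, grid_w // merge_size  # 31, 22
num_image_tokens = patch_h * (patch_w + 1) + 2                 # 31 * 23 + 2 = 715
# The single image placeholder in the prompt expands to 715 image tokens;
# the inner 713 = (patch_w + 1) * patch_h positions are the ones that receive
# the 2-D W/H position ids assigned later in __call__.
```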
vllm/transformers_utils/processors/hunyuan_vl_image.py (new file, 477 lines)
@@ -0,0 +1,477 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# adapted from https://github.com/ManaEstras/transformers/blob/v4.57.1.hyvl/src/transformers/models/hunyuan_vl/image_processing_hunyuan_vl.py
|
||||
"""Image processor class for HunYuanVL."""
|
||||
|
||||
# isort conflicts with ruff for transformers imports
|
||||
# isort: skip_file
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import torchvision.transforms as transforms
|
||||
from transformers import AutoImageProcessor
|
||||
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
|
||||
from transformers.image_transforms import (
|
||||
convert_to_rgb,
|
||||
)
|
||||
from transformers.image_utils import (
|
||||
OPENAI_CLIP_MEAN,
|
||||
OPENAI_CLIP_STD,
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
make_flat_list_of_images,
|
||||
make_list_of_images,
|
||||
valid_images,
|
||||
validate_preprocess_arguments,
|
||||
)
|
||||
from transformers.utils import TensorType, logging
|
||||
from transformers.video_utils import VideoInput, make_batched_videos
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def smart_resize(
|
||||
height: int,
|
||||
width: int,
|
||||
factor: int = 16,
|
||||
min_pixels: int = 512 * 512,
|
||||
max_pixels: int = 2048 * 2048,
|
||||
):
|
||||
"""Rescales the image so that the following conditions are met:
|
||||
|
||||
1. Both dimensions (height and width) are divisible by 'factor'.
|
||||
|
||||
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
|
||||
|
||||
3. The aspect ratio of the image is maintained as closely as possible.
|
||||
|
||||
"""
|
||||
if max(height, width) / min(height, width) > 200:
|
||||
raise ValueError(
|
||||
"absolute aspect ratio must be smaller than 200, got "
|
||||
f"{max(height, width) / min(height, width)}"
|
||||
)
|
||||
h_bar = round(height / factor) * factor
|
||||
w_bar = round(width / factor) * factor
|
||||
if h_bar * w_bar > max_pixels:
|
||||
beta = math.sqrt((height * width) / max_pixels)
|
||||
h_bar = max(factor, math.floor(height / beta / factor) * factor)
|
||||
w_bar = max(factor, math.floor(width / beta / factor) * factor)
|
||||
elif h_bar * w_bar < min_pixels:
|
||||
beta = math.sqrt(min_pixels / (height * width))
|
||||
h_bar = math.ceil(height * beta / factor) * factor
|
||||
w_bar = math.ceil(width * beta / factor) * factor
|
||||
return h_bar, w_bar
|
||||
|
||||
|
||||
class HunYuanVLImageProcessor(BaseImageProcessor):
|
||||
model_input_names = [
|
||||
"pixel_values",
|
||||
"image_grid_thw",
|
||||
"pixel_values_videos",
|
||||
"video_grid_thw",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
do_resize: bool = True,
|
||||
size: dict[str, int] | None = None,
|
||||
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
||||
do_rescale: bool = True,
|
||||
rescale_factor: int | float = 1 / 255,
|
||||
do_normalize: bool = True,
|
||||
image_mean: float | list[float] | None = None,
|
||||
image_std: float | list[float] | None = None,
|
||||
do_convert_rgb: bool = True,
|
||||
min_pixels: int | None = None,
|
||||
max_pixels: int | None = None,
|
||||
patch_size: int = 16,
|
||||
temporal_patch_size: int = 2,
|
||||
merge_size: int = 2,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
if size is not None and (
|
||||
"shortest_edge" not in size or "longest_edge" not in size
|
||||
):
|
||||
raise ValueError(
|
||||
"size must contain 'shortest_edge' and 'longest_edge' keys."
|
||||
)
|
||||
else:
|
||||
size = {"shortest_edge": 512 * 512, "longest_edge": 2048 * 2048}
|
||||
# backward compatibility: override size with min_pixels and max_pixels
|
||||
# if they are provided.
|
||||
if min_pixels is not None:
|
||||
size["shortest_edge"] = min_pixels
|
||||
if max_pixels is not None:
|
||||
size["longest_edge"] = max_pixels
|
||||
self.min_pixels = size["shortest_edge"]
|
||||
self.max_pixels = size["longest_edge"]
|
||||
self.size = size
|
||||
|
||||
self.do_resize = do_resize
|
||||
self.resample = resample
|
||||
self.do_rescale = do_rescale
|
||||
self.rescale_factor = rescale_factor
|
||||
self.do_normalize = do_normalize
|
||||
self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
|
||||
self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
|
||||
|
||||
self.patch_size = patch_size
|
||||
self.temporal_patch_size = temporal_patch_size
|
||||
self.merge_size = merge_size
|
||||
self.do_convert_rgb = do_convert_rgb
|
||||
|
||||
# hard-code
|
||||
|
||||
def _preprocess(
|
||||
self,
|
||||
images: ImageInput | VideoInput,
|
||||
do_resize: bool | None = None,
|
||||
size: dict[str, int] | None = None,
|
||||
resample: PILImageResampling = None,
|
||||
do_rescale: bool | None = None,
|
||||
rescale_factor: float | None = None,
|
||||
do_normalize: bool | None = None,
|
||||
image_mean: float | list[float] | None = None,
|
||||
image_std: float | list[float] | None = None,
|
||||
patch_size: int = 16,
|
||||
temporal_patch_size: int = 2,
|
||||
merge_size: int = 2,
|
||||
do_convert_rgb: bool | None = None,
|
||||
data_format: ChannelDimension | None = ChannelDimension.FIRST,
|
||||
input_data_format: str | ChannelDimension | None = None,
|
||||
):
|
||||
"""
|
||||
Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`.
|
||||
|
||||
Args:
|
||||
images (`ImageInput`):
|
||||
Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`.
|
||||
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
||||
Whether to resize the image.
|
||||
size (`dict[str, int]`, *optional*, defaults to `self.size`):
|
||||
Size of the image after resizing. `shortest_edge` and `longest_edge` keys must be present.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
|
||||
Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums.
|
||||
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
||||
Whether to rescale the image.
|
||||
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
||||
Scale factor to use if rescaling the image.
|
||||
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
||||
Whether to normalize the image.
|
||||
image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
|
||||
Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
|
||||
image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
|
||||
Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
|
||||
patch_size (`int`, *optional*, defaults to `self.patch_size`):
|
||||
The spatial patch size of the vision encoder.
|
||||
temporal_patch_size (`int`, *optional*, defaults to `self.temporal_patch_size`):
|
||||
The temporal patch size of the vision encoder.
|
||||
merge_size (`int`, *optional*, defaults to `self.merge_size`):
|
||||
The merge size of the vision encoder to llm encoder.
|
||||
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
||||
Whether to convert the image to RGB.
|
||||
data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
|
||||
The channel dimension format for the output image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- Unset: Use the channel dimension format of the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
""" # noqa: E501
|
||||
images = make_list_of_images(images)
|
||||
|
||||
if do_convert_rgb:
|
||||
images = [convert_to_rgb(image) for image in images]
|
||||
|
||||
width, height = images[0].width, images[0].height
|
||||
resized_width, resized_height = width, height
|
||||
processed_images = []
|
||||
for image in images:
|
||||
if do_resize:
|
||||
resized_width, resized_height = smart_resize(
|
||||
width,
|
||||
height,
|
||||
factor=patch_size * merge_size,
|
||||
min_pixels=self.min_pixels,
|
||||
max_pixels=self.max_pixels,
|
||||
)
|
||||
image = image.resize((resized_width, resized_height))
|
||||
|
||||
if do_normalize:
|
||||
image = transforms.Compose(
|
||||
[
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(self.image_mean, self.image_std),
|
||||
]
|
||||
)(image)
|
||||
processed_images.append(image)
|
||||
|
||||
patches = np.array(processed_images)
|
||||
channel = patches.shape[1]
|
||||
grid_t = patches.shape[0] // temporal_patch_size
|
||||
grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
|
||||
patches = patches.reshape(
|
||||
1,
|
||||
channel,
|
||||
grid_h // merge_size,
|
||||
merge_size,
|
||||
patch_size,
|
||||
grid_w // merge_size,
|
||||
merge_size,
|
||||
patch_size,
|
||||
)
|
||||
patches = patches.transpose(0, 2, 3, 5, 6, 1, 4, 7)
|
||||
flatten_patches = patches.reshape(
|
||||
1 * grid_h * grid_w, channel * patch_size * patch_size
|
||||
)
|
||||
|
||||
return flatten_patches, (grid_t, grid_h, grid_w)
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
images: ImageInput,
|
||||
videos: VideoInput = None,
|
||||
do_resize: bool | None = None,
|
||||
size: dict[str, int] | None = None,
|
||||
min_pixels: int | None = None,
|
||||
max_pixels: int | None = None,
|
||||
resample: PILImageResampling = None,
|
||||
do_rescale: bool | None = None,
|
||||
rescale_factor: float | None = None,
|
||||
do_normalize: bool | None = None,
|
||||
image_mean: float | list[float] | None = None,
|
||||
image_std: float | list[float] | None = None,
|
||||
patch_size: int | None = None,
|
||||
temporal_patch_size: int | None = None,
|
||||
merge_size: int | None = None,
|
||||
do_convert_rgb: bool | None = None,
|
||||
return_tensors: str | TensorType | None = None,
|
||||
data_format: ChannelDimension | None = ChannelDimension.FIRST,
|
||||
input_data_format: str | ChannelDimension | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
images (`ImageInput`):
|
||||
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
|
||||
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
|
||||
videos (`VideoInput`):
|
||||
Video to preprocess. Expects a single or batch of videos with pixel values ranging from 0 to 255. If
|
||||
passing in videos with pixel values between 0 and 1, set `do_rescale=False`.
|
||||
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
||||
Whether to resize the image.
|
||||
size (`dict[str, int]`, *optional*, defaults to `self.size`):
|
||||
Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
|
||||
the longest edge resized to keep the input aspect ratio.
|
||||
resample (`int`, *optional*, defaults to `self.resample`):
|
||||
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
|
||||
has an effect if `do_resize` is set to `True`.
|
||||
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
||||
Whether to rescale the image.
|
||||
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
||||
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
||||
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
||||
Whether to normalize the image.
|
||||
image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
|
||||
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
|
||||
image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
|
||||
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
|
||||
`True`.
|
||||
min_pixels (`int`, *optional*, defaults to `self.min_pixels`):
|
||||
The min pixels of the image to resize the image.
|
||||
max_pixels (`int`, *optional*, defaults to `self.max_pixels`):
|
||||
The max pixels of the image to resize the image.
|
||||
patch_size (`int`, *optional*, defaults to `self.patch_size`):
|
||||
The spatial patch size of the vision encoder.
|
||||
temporal_patch_size (`int`, *optional*, defaults to `self.temporal_patch_size`):
|
||||
The temporal patch size of the vision encoder.
|
||||
merge_size (`int`, *optional*, defaults to `self.merge_size`):
|
||||
The merge size of the vision encoder to llm encoder.
|
||||
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
||||
Whether to convert the image to RGB.
|
||||
return_tensors (`str` or `TensorType`, *optional*):
|
||||
The type of tensors to return. Can be one of:
|
||||
- Unset: Return a list of `np.ndarray`.
|
||||
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
||||
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
||||
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
||||
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
||||
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
||||
The channel dimension format for the output image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - Unset: Use the channel dimension format of the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.

        """  # noqa: E501
        min_pixels = min_pixels if min_pixels is not None else self.min_pixels
        max_pixels = max_pixels if max_pixels is not None else self.max_pixels

        if size is not None:
            if "shortest_edge" not in size or "longest_edge" not in size:
                raise ValueError(
                    "size must contain 'shortest_edge' and 'longest_edge' keys."
                )
            min_pixels = size["shortest_edge"]
        elif min_pixels is not None and max_pixels is not None:
            # backward compatibility: override size with min_pixels and max_pixels
            # if they are provided.
            size = {"shortest_edge": min_pixels, "longest_edge": max_pixels}
        else:
            size = {**self.size}

        do_resize = do_resize if do_resize is not None else self.do_resize

        resample = resample if resample is not None else self.resample
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = (
            rescale_factor if rescale_factor is not None else self.rescale_factor
        )
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std
        patch_size = patch_size if patch_size is not None else self.patch_size
        temporal_patch_size = (
            temporal_patch_size
            if temporal_patch_size is not None
            else self.temporal_patch_size
        )
        merge_size = merge_size if merge_size is not None else self.merge_size
        do_convert_rgb = (
            do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
        )

        if images is not None:
            images = make_flat_list_of_images(images)

        if images is not None and not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        validate_preprocess_arguments(
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            do_resize=do_resize,
            size=size,
            resample=resample,
        )

        data = {}
        if images is not None:
            pixel_values, vision_grid_thws = [], []
            for image in images:
                patches, image_grid_thw = self._preprocess(
                    image,
                    do_resize=do_resize,
                    size=size,
                    resample=resample,
                    do_rescale=do_rescale,
                    rescale_factor=rescale_factor,
                    do_normalize=do_normalize,
                    image_mean=image_mean,
                    image_std=image_std,
                    patch_size=patch_size,
                    temporal_patch_size=temporal_patch_size,
                    merge_size=merge_size,
                    data_format=data_format,
                    do_convert_rgb=do_convert_rgb,
                    input_data_format=input_data_format,
                )
                pixel_values.extend(patches)
                vision_grid_thws.append(image_grid_thw)
            pixel_values = np.array(pixel_values)
            vision_grid_thws = np.array(vision_grid_thws)
            data.update(
                {"pixel_values": pixel_values, "image_grid_thw": vision_grid_thws}
            )

        # kept for BC only and should be removed after v5.0
        if videos is not None:
            logger.warning(
                "`HunYuanVLV1ImageProcessor` works only with image inputs "
                "and doesn't process videos anymore. "
                "This is a deprecated behavior and will be removed in v5.0. "
                "Your videos should be forwarded to `HunYuanVLV1VideoProcessor`. "
            )
            videos = make_batched_videos(videos)
            pixel_values_videos, vision_grid_thws_videos = [], []
            for images in videos:
                patches, video_grid_thw = self._preprocess(
                    images,
                    do_resize=do_resize,
                    size=size,
                    resample=resample,
                    do_rescale=do_rescale,
                    rescale_factor=rescale_factor,
                    do_normalize=do_normalize,
                    image_mean=image_mean,
                    image_std=image_std,
                    patch_size=patch_size,
                    temporal_patch_size=temporal_patch_size,
                    merge_size=merge_size,
                    data_format=data_format,
                    do_convert_rgb=do_convert_rgb,
                    input_data_format=input_data_format,
                )
                pixel_values_videos.extend(patches)
                vision_grid_thws_videos.append(video_grid_thw)
            data.update(
                {
                    "pixel_values_videos": np.array(pixel_values_videos),
                    "video_grid_thw": np.array(vision_grid_thws_videos),
                }
            )

        return BatchFeature(data=data, tensor_type=return_tensors)
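    # Editor's note (not part of the diff): a minimal usage sketch of the method
    # above, assuming the standard `BaseImageProcessor.__call__` -> `preprocess`
    # dispatch from `transformers`. The checkpoint name comes from this PR; the
    # file name and output shapes are illustrative only.
    #
    #   >>> from PIL import Image
    #   >>> processor = HunYuanVLImageProcessor.from_pretrained("tencent/HunyuanOCR")
    #   >>> image = Image.open("receipt.png").convert("RGB")
    #   >>> out = processor(images=[image], return_tensors="pt")
    #   >>> out["pixel_values"].shape   # flattened patches for all images
    #   >>> out["image_grid_thw"]       # one (t, h, w) grid per image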

    def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
        """
        A utility that returns number of image patches for a given image size.

        Args:
            height (`int`):
                Height of the input image.
            width (`int`):
                Width of the input image.
            images_kwargs (`dict`, *optional*):
                Any kwargs to override defaults of the image processor.

        Returns:
            `int`: Number of image patches per image.
        """
        # Guard against images_kwargs=None so the membership checks below don't raise.
        images_kwargs = images_kwargs or {}
        min_pixels = (
            images_kwargs["min_pixels"]
            if "min_pixels" in images_kwargs
            else self.size["shortest_edge"]
        )
        max_pixels = (
            images_kwargs["max_pixels"]
            if "max_pixels" in images_kwargs
            else self.size["longest_edge"]
        )
        patch_size = images_kwargs.get("patch_size", self.patch_size)
        merge_size = images_kwargs.get("merge_size", self.merge_size)

        factor = patch_size * merge_size
        resized_height, resized_width = smart_resize(
            height, width, factor, min_pixels=min_pixels, max_pixels=max_pixels
        )
        grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
        return grid_h * (grid_w + 1) + 2
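    # Editor's note (not part of the diff): a worked example of the patch-count
    # formula above under assumed values. patch_size=16 and merge_size=2 are not
    # read from the HunyuanOCR config, and smart_resize is assumed to return
    # multiples of patch_size * merge_size.
    #
    #   factor = 16 * 2                            # 32
    #   resized_height, resized_width = 768, 1024
    #   grid_h, grid_w = 768 // 16, 1024 // 16     # 48, 64
    #   num_patches = 48 * (64 + 1) + 2            # 3122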


AutoImageProcessor.register("HunYuanVLImageProcessor", HunYuanVLImageProcessor)

@ -43,6 +43,8 @@ class CachedRequestState:
    mrope_positions: torch.Tensor | None = None
    mrope_position_delta: int | None = None

    xdrope_positions: torch.Tensor | None = None

    lora_request: LoRARequest | None = None
    prompt_embeds: torch.Tensor | None = None

@ -50,16 +50,21 @@ from vllm.distributed.parallel_state import (
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.layers.rotary_embedding import (
    MRotaryEmbedding,
    XDRotaryEmbedding,
)
from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
from vllm.model_executor.models.interfaces import (
    SupportsMRoPE,
    SupportsMultiModal,
    SupportsXDRoPE,
    is_mixture_of_experts,
    supports_eagle3,
    supports_mrope,
    supports_multimodal_pruning,
    supports_transcription,
    supports_xdrope,
)
from vllm.model_executor.models.interfaces_base import (
    VllmModelForPooling,

@ -324,6 +329,7 @@ class GPUModelRunner(
        # Multi-modal data support
        self.mm_registry = MULTIMODAL_REGISTRY
        self.uses_mrope = model_config.uses_mrope
        self.uses_xdrope_dim = model_config.uses_xdrope_dim
        self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
            model_config
        )

@ -512,6 +518,13 @@ class GPUModelRunner(
                (3, self.max_num_tokens + 1), dtype=torch.int64
            )

        # Only relevant for models using XD-RoPE (e.g., HunYuan-VL)
        if self.uses_xdrope_dim > 0:
            # Similar to M-RoPE, but uses the model's configured number of RoPE
            # dimensions (4 by default).
            self.xdrope_positions = self._make_buffer(
                (self.uses_xdrope_dim, self.max_num_tokens + 1), dtype=torch.int64
            )

        # None in the first PP rank. The rest are set after load_model.
        self.intermediate_tensors: IntermediateTensors | None = None

@ -593,10 +606,14 @@ class GPUModelRunner(
        if isinstance(num_tokens, int):
            if self.uses_mrope:
                return self.mrope_positions.gpu[:, :num_tokens]
            if self.uses_xdrope_dim > 0:
                return self.xdrope_positions.gpu[:, :num_tokens]
            return self.positions.gpu[:num_tokens]
        else:
            if self.uses_mrope:
                return self.mrope_positions.gpu[:, num_tokens]
            if self.uses_xdrope_dim > 0:
                return self.xdrope_positions.gpu[:, num_tokens]
            return self.positions.gpu[num_tokens]

    def _make_buffer(

@ -772,6 +789,10 @@ class GPUModelRunner(
            if self.uses_mrope:
                self._init_mrope_positions(req_state)

            # Only relevant for models using XD-RoPE (e.g., HunYuan-VL)
            if self.uses_xdrope_dim > 0:
                self._init_xdrope_positions(req_state)

            reqs_to_add.append(req_state)

        # Update the states of the running/resumed requests.

@ -987,6 +1008,19 @@ class GPUModelRunner(
            )
        )

    def _init_xdrope_positions(self, req_state: CachedRequestState):
        model = self.get_model()
        xdrope_model = cast(SupportsXDRoPE, model)
        assert req_state.prompt_token_ids is not None, (
            "XD-RoPE requires prompt_token_ids to be available."
        )
        assert supports_xdrope(model), "XD-RoPE support is not implemented."

        req_state.xdrope_positions = xdrope_model.get_xdrope_input_positions(
            req_state.prompt_token_ids,
            req_state.mm_features,
        )
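    # Editor's note (not part of the diff): the SupportsXDRoPE interface is not
    # shown in this hunk. Inferred from the call above, a model opting in would
    # roughly look like the sketch below; the method name and parameters are taken
    # from this call site, everything else is an assumption.
    #
    #   class MyXDRoPEModel(nn.Module, SupportsXDRoPE):
    #       def get_xdrope_input_positions(
    #           self,
    #           prompt_token_ids: list[int],
    #           mm_features,
    #       ) -> torch.Tensor:
    #           # expected to return a (xdrope_dim, num_prompt_tokens) position tensor
    #           ...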

    def _extract_mm_kwargs(
        self,
        scheduler_output: "SchedulerOutput",

@ -1231,6 +1265,11 @@ class GPUModelRunner(
        if self.uses_mrope:
            self._calc_mrope_positions(scheduler_output)

        # Calculate XD-RoPE positions.
        # Only relevant for models using XD-RoPE (e.g., HunYuan-VL)
        if self.uses_xdrope_dim > 0:
            self._calc_xdrope_positions(scheduler_output)

        # Get token indices.
        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]

@ -1364,6 +1403,12 @@ class GPUModelRunner(
                self.mrope_positions.cpu[:, :total_num_scheduled_tokens],
                non_blocking=True,
            )
        elif self.uses_xdrope_dim > 0:
            # Only relevant for models using XD-RoPE (e.g., HunYuan-VL)
            self.xdrope_positions.gpu[:, :total_num_scheduled_tokens].copy_(
                self.xdrope_positions.cpu[:, :total_num_scheduled_tokens],
                non_blocking=True,
            )
        else:
            # Common case (1D positions)
            self.positions.copy_to_gpu(total_num_scheduled_tokens)

@ -1793,6 +1838,53 @@ class GPUModelRunner(

                mrope_pos_ptr += completion_part_len

    def _calc_xdrope_positions(self, scheduler_output: "SchedulerOutput"):
        xdrope_pos_ptr = 0
        for index, req_id in enumerate(self.input_batch.req_ids):
            req = self.requests[req_id]
            assert req.xdrope_positions is not None

            num_computed_tokens = self.input_batch.num_computed_tokens_cpu[index]
            num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
            num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
                req.prompt_token_ids, req.prompt_embeds
            )

            if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
                prompt_part_len = max(0, num_prompt_tokens - num_computed_tokens)
                completion_part_len = max(0, num_scheduled_tokens - prompt_part_len)
            else:
                prompt_part_len = num_scheduled_tokens
                completion_part_len = 0

            assert num_scheduled_tokens == prompt_part_len + completion_part_len

            if prompt_part_len > 0:
                # The prompt's xdrope_positions are pre-computed.
                dst_start = xdrope_pos_ptr
                dst_end = xdrope_pos_ptr + prompt_part_len
                src_start = num_computed_tokens
                src_end = num_computed_tokens + prompt_part_len

                self.xdrope_positions.cpu[:, dst_start:dst_end] = req.xdrope_positions[
                    :, src_start:src_end
                ]
                xdrope_pos_ptr += prompt_part_len

            if completion_part_len > 0:
                # Compute the completion's xdrope_positions on the fly.
                dst_start = xdrope_pos_ptr
                dst_end = xdrope_pos_ptr + completion_part_len

                XDRotaryEmbedding.get_next_input_positions_tensor(
                    out=self.xdrope_positions.np,
                    out_offset=dst_start,
                    context_len=num_computed_tokens + prompt_part_len,
                    num_new_tokens=completion_part_len,
                )

                xdrope_pos_ptr += completion_part_len
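    # Editor's note (not part of the diff): a worked example of the prompt /
    # completion split above, with made-up numbers. For a request with
    # num_prompt_tokens=10, num_computed_tokens=8, and num_scheduled_tokens=5:
    #
    #   8 + 5 > 10, so
    #   prompt_part_len     = max(0, 10 - 8) = 2   # copied from the precomputed table
    #   completion_part_len = max(0, 5 - 2)  = 3   # generated on the fly via
    #                                              # XDRotaryEmbedding.get_next_input_positions_tensor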

    def _calc_spec_decode_metadata(
        self,
        num_draft_tokens: np.ndarray,

@ -2037,6 +2129,7 @@ class GPUModelRunner(

        req_start_idx = 0
        should_sync_mrope_positions = False
        should_sync_xdrope_positions = False

        for req_id in self.input_batch.req_ids:
            mm_embeds_req: list[torch.Tensor] = []

@ -2110,6 +2203,10 @@ class GPUModelRunner(
            self._calc_mrope_positions(scheduler_output)
            self.mrope_positions.copy_to_gpu(total_num_scheduled_tokens)

        if should_sync_xdrope_positions:
            self._calc_xdrope_positions(scheduler_output)
            self.xdrope_positions.copy_to_gpu(total_num_scheduled_tokens)

        return mm_embeds, is_mm_embed

    def get_model(self) -> nn.Module:

@ -2384,8 +2481,11 @@ class GPUModelRunner(
            input_ids = self.input_ids.gpu[:num_input_tokens]
            inputs_embeds = None
            model_kwargs = self._init_model_kwargs(num_input_tokens)

        if self.uses_mrope:
            positions = self.mrope_positions.gpu[:, :num_input_tokens]
        elif self.uses_xdrope_dim > 0:
            positions = self.xdrope_positions.gpu[:, :num_input_tokens]
        else:
            positions = self.positions.gpu[:num_input_tokens]

@ -3824,6 +3924,8 @@ class GPUModelRunner(

            if self.uses_mrope:
                positions = self.mrope_positions.gpu[:, :num_tokens_after_padding]
            elif self.uses_xdrope_dim > 0:
                positions = self.xdrope_positions.gpu[:, :num_tokens_after_padding]
            else:
                positions = self.positions.gpu[:num_tokens_after_padding]