[CustomOp] Extract ApplyRotaryEmb as CustomOp and unify the dispatch logic (#29873)

Signed-off-by: shen-shanshan <467638484@qq.com>
Co-authored-by: gcanlin <canlinguosdu@gmail.com>
Co-authored-by: TJian <tunjian.tan@embeddedllm.com>
This commit is contained in:
Shanshan Shen 2025-12-16 11:08:16 +08:00 committed by GitHub
parent ff21a0fc85
commit 3bd9c49158
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 553 additions and 280 deletions

View File

@ -0,0 +1,203 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for ApplyRotaryEmb CustomOp dispatch behavior.
This test ensures that RotaryEmbedding classes correctly call the appropriate
ApplyRotaryEmb methods based on the calling context:
1. RotaryEmbedding.forward_native() -> ApplyRotaryEmb.forward_native()
2. RotaryEmbedding.forward_cuda() -> ApplyRotaryEmb.forward() (auto-dispatch)
3. RotaryEmbedding.forward_hip() -> ApplyRotaryEmb.forward() (auto-dispatch)
"""
from dataclasses import dataclass
import pytest
import torch
from vllm.config import (
CompilationConfig,
VllmConfig,
get_cached_compilation_config,
set_current_vllm_config,
)
from vllm.platforms import current_platform
CUDA_DEVICES = ["cuda:0"]
@dataclass
class RotaryEmbeddingTestCase:
"""Test case configuration for RotaryEmbedding dispatch tests."""
name: str
rope_class: type
rope_kwargs: dict
method_name: str # forward_native, forward_cuda, forward
positions_shape: tuple # (num_tokens,) or (3, num_tokens) or (4, num_tokens)
expect_forward_native: bool # Should call ApplyRotaryEmb.forward_native()
expect_forward: bool # Should call ApplyRotaryEmb.forward()
def get_test_cases() -> list[RotaryEmbeddingTestCase]:
"""Generate test cases for all RotaryEmbedding classes."""
from vllm.model_executor.layers.rotary_embedding.ernie45_vl_rope import (
Ernie4_5_VLRotaryEmbedding,
)
from vllm.model_executor.layers.rotary_embedding.mrope import MRotaryEmbedding
from vllm.model_executor.layers.rotary_embedding.xdrope import XDRotaryEmbedding
common_kwargs = {
"head_size": 128,
"rotary_dim": 128,
"max_position_embeddings": 4096,
"base": 10000,
"is_neox_style": True,
"dtype": torch.bfloat16,
}
return [
# MRotaryEmbedding tests
RotaryEmbeddingTestCase(
name="MRotaryEmbedding.forward_native",
rope_class=MRotaryEmbedding,
rope_kwargs={**common_kwargs, "mrope_section": [16, 24, 24]},
method_name="forward_native",
positions_shape=(3, 32), # 2D for multimodal
expect_forward_native=True,
expect_forward=False,
),
RotaryEmbeddingTestCase(
name="MRotaryEmbedding.forward_cuda_1d",
rope_class=MRotaryEmbedding,
rope_kwargs={**common_kwargs, "mrope_section": [16, 24, 24]},
method_name="forward_cuda",
positions_shape=(32,), # 1D triggers apply_rotary_emb path
expect_forward_native=False,
expect_forward=True,
),
# XDRotaryEmbedding tests
RotaryEmbeddingTestCase(
name="XDRotaryEmbedding.forward",
rope_class=XDRotaryEmbedding,
rope_kwargs={
**common_kwargs,
"scaling_alpha": 1.0,
"xdrope_section": [16, 16, 16, 16],
},
method_name="forward",
positions_shape=(4, 32), # 4D for P/W/H/T
expect_forward_native=False,
expect_forward=True,
),
# Ernie4_5_VLRotaryEmbedding tests
RotaryEmbeddingTestCase(
name="Ernie4_5_VLRotaryEmbedding.forward_native",
rope_class=Ernie4_5_VLRotaryEmbedding,
rope_kwargs={**common_kwargs, "mrope_section": [22, 22, 20]},
method_name="forward_native",
positions_shape=(3, 32), # 2D for multimodal
expect_forward_native=True,
expect_forward=False,
),
]
def run_dispatch_test(
test_case: RotaryEmbeddingTestCase,
device: str,
):
"""Run a dispatch test for a RotaryEmbedding class."""
vllm_config = VllmConfig(
compilation_config=CompilationConfig(custom_ops=["all", "+apply_rotary_emb"])
)
get_cached_compilation_config.cache_clear()
with set_current_vllm_config(vllm_config):
rope = test_case.rope_class(**test_case.rope_kwargs).to(device=device)
apply_rotary_emb = rope.apply_rotary_emb
# Verify custom op is enabled
if test_case.expect_forward_native:
assert (
apply_rotary_emb._forward_method != apply_rotary_emb.forward_native
), "Test setup error: ApplyRotaryEmb custom op should be enabled"
# Setup call tracking
call_tracker = {"forward_native_called": False, "forward_called": False}
original_forward_native = apply_rotary_emb.forward_native
original_forward = apply_rotary_emb.forward
def tracked_forward_native(*args, **kwargs):
call_tracker["forward_native_called"] = True
return original_forward_native(*args, **kwargs)
def tracked_forward(*args, **kwargs):
call_tracker["forward_called"] = True
return original_forward(*args, **kwargs)
apply_rotary_emb.forward_native = tracked_forward_native
apply_rotary_emb.forward = tracked_forward
try:
num_tokens = test_case.positions_shape[-1]
num_q_heads = 8
num_kv_heads = 2
head_size = test_case.rope_kwargs["head_size"]
max_position = test_case.rope_kwargs["max_position_embeddings"]
positions = torch.randint(
0, max_position // 4, test_case.positions_shape, device=device
)
query = torch.randn(
num_tokens, num_q_heads * head_size, dtype=torch.bfloat16, device=device
)
key = torch.randn(
num_tokens,
num_kv_heads * head_size,
dtype=torch.bfloat16,
device=device,
)
# Call the method under test
method = getattr(rope, test_case.method_name)
method(positions, query.clone(), key.clone())
# Verify expectations
if test_case.expect_forward_native:
assert call_tracker["forward_native_called"], (
f"{test_case.name} should call ApplyRotaryEmb.forward_native()"
)
if not test_case.expect_forward:
assert not call_tracker["forward_called"], (
f"{test_case.name} should NOT call ApplyRotaryEmb.forward(). "
"Bug: when +apply_rotary_emb is enabled, forward_native() "
"incorrectly dispatches to CUDA/HIP kernels."
)
if test_case.expect_forward:
assert call_tracker["forward_called"], (
f"{test_case.name} should call ApplyRotaryEmb.forward()"
)
finally:
apply_rotary_emb.forward_native = original_forward_native
apply_rotary_emb.forward = original_forward
@pytest.mark.skipif(
not current_platform.is_cuda_alike(), reason="Skipping CUDA/ROCm only tests."
)
@pytest.mark.parametrize("test_case", get_test_cases(), ids=lambda tc: tc.name)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_rotary_embedding_dispatch(
test_case: RotaryEmbeddingTestCase,
device: str,
):
"""
Test that RotaryEmbedding classes dispatch to the correct ApplyRotaryEmb method.
- forward_native methods should call ApplyRotaryEmb.forward_native()
- forward_cuda/forward methods should call ApplyRotaryEmb.forward()
"""
run_dispatch_test(test_case, device)

View File

@ -7,7 +7,7 @@ import torch
from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.custom_op import CustomOp
from .common import apply_rotary_emb_torch
from .common import ApplyRotaryEmb
@CustomOp.register("rotary_embedding")
@ -49,6 +49,10 @@ class RotaryEmbeddingBase(CustomOp):
rocm_aiter_ops.is_triton_rotary_embed_enabled()
)
self.apply_rotary_emb = ApplyRotaryEmb(
is_neox_style=self.is_neox_style,
)
def _compute_inv_freq(self, base: float) -> torch.Tensor:
"""Compute the inverse frequency."""
# NOTE(woosuk): To exactly match the HF implementation, we need to
@ -123,7 +127,12 @@ class RotaryEmbedding(RotaryEmbeddingBase):
query = query.view(num_tokens, -1, head_size)
query_rot = query[..., :rotary_dim]
query_pass = query[..., rotary_dim:]
query_rot = apply_rotary_emb_torch(query_rot, cos, sin, is_neox_style)
query_rot = ApplyRotaryEmb.forward_static(
query_rot,
cos,
sin,
is_neox_style,
)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
# key may be None in some cases, e.g. cross-layer KV sharing
@ -132,7 +141,12 @@ class RotaryEmbedding(RotaryEmbeddingBase):
key = key.view(num_tokens, -1, head_size)
key_rot = key[..., :rotary_dim]
key_pass = key[..., rotary_dim:]
key_rot = apply_rotary_emb_torch(key_rot, cos, sin, is_neox_style)
key_rot = ApplyRotaryEmb.forward_static(
key_rot,
cos,
sin,
is_neox_style,
)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key

View File

@ -2,19 +2,14 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections.abc import Callable
from functools import cache
from importlib.util import find_spec
import torch
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.model_executor.custom_op import CustomOp
from vllm.utils.torch_utils import direct_register_custom_op
if current_platform.is_cuda():
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
logger = init_logger(__name__)
@ -32,71 +27,6 @@ def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
return x.flatten(-2)
def apply_rotary_emb_torch(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
is_neox_style: bool,
) -> torch.Tensor:
cos = cos.unsqueeze(-2).to(x.dtype)
sin = sin.unsqueeze(-2).to(x.dtype)
if is_neox_style:
x1, x2 = torch.chunk(x, 2, dim=-1)
else:
x1 = x[..., ::2]
x2 = x[..., 1::2]
o1 = x1 * cos - x2 * sin
o2 = x2 * cos + x1 * sin
if is_neox_style:
return torch.cat((o1, o2), dim=-1)
else:
return torch.stack((o1, o2), dim=-1).flatten(-2)
def apply_rotary_emb_dispatch(
x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, is_neox_style: bool
) -> torch.Tensor:
"""
Args:
x: [num_tokens, num_heads, head_size]
cos: [num_tokens, head_size // 2]
sin: [num_tokens, head_size // 2]
is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
positional embeddings.
"""
if current_platform.is_cuda():
return apply_rotary_emb(x.unsqueeze(0), cos, sin, not is_neox_style).squeeze(0)
else:
return apply_rotary_emb_torch(x, cos, sin, is_neox_style)
@cache
def dispatch_rotary_emb_function(
default: Callable[..., torch.Tensor] | None = None,
) -> Callable[..., torch.Tensor]:
if current_platform.is_cuda():
return apply_rotary_emb
# if torch compile is not enabled
# use rotary embedding function from flash_attn package
# otherwise use the naive pytorch embedding implementation
# is faster when torch compile is enabled.
if current_platform.is_rocm() and not torch.compiler.is_compiling():
if find_spec("flash_attn") is not None:
from flash_attn.ops.triton.rotary import apply_rotary
return apply_rotary
else:
logger.warning(
"flash_attn is not installed. Falling back to PyTorch "
"implementation for rotary embeddings."
)
if default is not None:
return default
return apply_rotary_emb_torch
# yarn functions
# Inverse dim formula to find dim based on number of rotations
def yarn_find_correction_dim(
@ -186,3 +116,155 @@ direct_register_custom_op(
mutates_args=["query", "key"], # These tensors are modified in-place
fake_impl=_flashinfer_rotary_embedding_fake,
)
@CustomOp.register("apply_rotary_emb")
class ApplyRotaryEmb(CustomOp):
def __init__(
self,
enforce_enable: bool = False,
is_neox_style: bool = True,
enable_fp32_compute: bool = False,
) -> None:
super().__init__(enforce_enable)
self.is_neox_style = is_neox_style
self.enable_fp32_compute = enable_fp32_compute
self.apply_rotary_emb_flash_attn = None
if find_spec("flash_attn") is not None:
from flash_attn.ops.triton.rotary import apply_rotary
self.apply_rotary_emb_flash_attn = apply_rotary
@staticmethod
def forward_static(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
is_neox_style: bool = True,
enable_fp32_compute: bool = False,
) -> torch.Tensor:
"""
Args:
x: [batch_size (optional), seq_len, num_heads, head_size]
cos: [seq_len, head_size // 2]
sin: [seq_len, head_size // 2]
is_neox_style: Whether to use the Neox-style or GPT-J-style.
enable_fp32_compute: Temporarily convert x, cos, sin to FP32 dtype
for higher accuracy.
"""
origin_dtype = x.dtype
if enable_fp32_compute:
x = x.float()
cos = cos.unsqueeze(-2).to(x.dtype)
sin = sin.unsqueeze(-2).to(x.dtype)
if is_neox_style:
x1, x2 = torch.chunk(x, 2, dim=-1)
else:
x1 = x[..., ::2]
x2 = x[..., 1::2]
o1 = x1 * cos - x2 * sin
o2 = x2 * cos + x1 * sin
if is_neox_style:
output = torch.cat((o1, o2), dim=-1)
else:
output = torch.stack((o1, o2), dim=-1).flatten(-2)
if enable_fp32_compute:
output = output.to(origin_dtype)
return output
def forward_native(
self,
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> torch.Tensor:
output = self.forward_static(
x, cos, sin, self.is_neox_style, self.enable_fp32_compute
)
return output
def forward_cuda(
self,
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> torch.Tensor:
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
origin_dtype = x.dtype
if self.enable_fp32_compute:
x = x.float()
cos = cos.float()
sin = sin.float()
origin_shape = x.shape
if len(origin_shape) == 3:
# x: [seq_len, num_heads, head_size]
x = x.unsqueeze(0)
"""
Arguments of apply_rotary_emb() in vllm_flash_attn:
x: [batch_size, seq_len, nheads, headdim]
cos, sin: [seqlen_rotary, rotary_dim / 2]
interleaved: defalut as False (Neox-style).
...
"""
interleaved = not self.is_neox_style
output = apply_rotary_emb(x, cos, sin, interleaved)
if len(origin_shape) == 3:
output = output.squeeze(0)
if self.enable_fp32_compute:
output = output.to(origin_dtype)
return output
def forward_hip(
self,
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> torch.Tensor:
if self.apply_rotary_emb_flash_attn is not None:
origin_dtype = x.dtype
if self.enable_fp32_compute:
x = x.float()
cos = cos.float()
sin = sin.float()
origin_shape = x.shape
if len(origin_shape) == 3:
# x: [seq_len, num_heads, head_size]
x = x.unsqueeze(0)
"""
Arguments of apply_rotary() in flash_attn:
x: [batch_size, seq_len, nheads, headdim]
cos, sin: [seqlen_rotary, rotary_dim / 2]
interleaved: defalut as False (Neox-style).
...
"""
interleaved = not self.is_neox_style
output = self.apply_rotary_emb_flash_attn(
x, cos, sin, interleaved=interleaved
).type_as(x)
if len(origin_shape) == 3:
output = output.squeeze(0)
if self.enable_fp32_compute:
output = output.to(origin_dtype)
else:
# Falling back to PyTorch native implementation.
output = self.forward_native(x, cos, sin)
return output
def extra_repr(self) -> str:
s = f"is_neox_style={self.is_neox_style}"
s += f"enable_fp32_compute={self.enable_fp32_compute}"
return s

View File

@ -4,7 +4,6 @@
import torch
from .common import apply_rotary_emb_dispatch
from .mrope import MRotaryEmbedding
@ -55,14 +54,22 @@ class Ernie4_5_VLRotaryEmbedding(MRotaryEmbedding):
query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., : self.rotary_dim]
query_pass = query[..., self.rotary_dim :]
query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style)
query_rot = self.apply_rotary_emb.forward_native(
query_rot,
cos,
sin,
)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., : self.rotary_dim]
key_pass = key[..., self.rotary_dim :]
key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style)
key_rot = self.apply_rotary_emb.forward_native(
key_rot,
cos,
sin,
)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key

View File

@ -8,7 +8,6 @@ import torch
from vllm.triton_utils import tl, triton
from .base import RotaryEmbeddingBase
from .common import apply_rotary_emb_dispatch
from .yarn_scaling_rope import YaRNScalingRotaryEmbedding, yarn_get_mscale
@ -301,14 +300,22 @@ class MRotaryEmbedding(RotaryEmbeddingBase):
query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., : self.rotary_dim]
query_pass = query[..., self.rotary_dim :]
query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style)
query_rot = self.apply_rotary_emb.forward_native(
query_rot,
cos,
sin,
)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., : self.rotary_dim]
key_pass = key[..., self.rotary_dim :]
key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style)
key_rot = self.apply_rotary_emb.forward_native(
key_rot,
cos,
sin,
)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key
@ -347,13 +354,21 @@ class MRotaryEmbedding(RotaryEmbeddingBase):
query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., : self.rotary_dim]
query_pass = query[..., self.rotary_dim :]
query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style)
query_rot = self.apply_rotary_emb(
query_rot,
cos,
sin,
)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., : self.rotary_dim]
key_pass = key[..., self.rotary_dim :]
key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style)
key_rot = self.apply_rotary_emb(
key_rot,
cos,
sin,
)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key

View File

@ -4,7 +4,6 @@
import numpy as np
import torch
from .common import apply_rotary_emb_dispatch
from .dynamic_ntk_alpha_rope import DynamicNTKAlphaRotaryEmbedding
@ -36,7 +35,7 @@ class XDRotaryEmbedding(DynamicNTKAlphaRotaryEmbedding):
dtype,
)
def forward(
def forward_native(
self,
positions: torch.Tensor,
query: torch.Tensor,
@ -68,14 +67,73 @@ class XDRotaryEmbedding(DynamicNTKAlphaRotaryEmbedding):
query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., : self.rotary_dim]
query_pass = query[..., self.rotary_dim :]
query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style)
query_rot = self.apply_rotary_emb.forward_native(
query_rot,
cos,
sin,
)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., : self.rotary_dim]
key_pass = key[..., self.rotary_dim :]
key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style)
key_rot = self.apply_rotary_emb.forward_native(
key_rot,
cos,
sin,
)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key
def forward_cuda(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor | None = None,
offsets: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""PyTorch-native implementation equivalent to forward().
Args:
positions:
[4, num_tokens] (P/W/H/T positions with multimodal inputs)
query: [num_tokens, num_heads * head_size]
key: [num_tokens, num_kv_heads * head_size]
"""
assert positions.ndim == 2
assert key is not None
num_tokens = positions.shape[-1]
cos_sin = self.cos_sin_cache[positions]
cos, sin = cos_sin.chunk(2, dim=-1)
cos = torch.cat(
[m[i] for i, m in enumerate(cos.split(self.xdrope_section, dim=-1))], dim=-1
)
sin = torch.cat(
[m[i] for i, m in enumerate(sin.split(self.xdrope_section, dim=-1))], dim=-1
)
query_shape = query.shape
query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., : self.rotary_dim]
query_pass = query[..., self.rotary_dim :]
query_rot = self.apply_rotary_emb(
query_rot,
cos,
sin,
)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., : self.rotary_dim]
key_pass = key[..., self.rotary_dim :]
key_rot = self.apply_rotary_emb(
key_rot,
cos,
sin,
)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key

View File

@ -29,6 +29,9 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding.common import (
ApplyRotaryEmb,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import (
MultiModalEmbeddings,
@ -158,32 +161,6 @@ class DotsOCRProcessingInfo(Qwen2VLProcessingInfo):
return processor
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb_vision(
tensor: torch.Tensor, freqs: torch.Tensor
) -> torch.Tensor:
orig_dtype = tensor.dtype
tensor = tensor.float()
cos = freqs.cos()
sin = freqs.sin()
cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
output = (tensor * cos) + (rotate_half(tensor) * sin)
output = output.to(orig_dtype)
return output
class VisionRotaryEmbedding(nn.Module):
def __init__(self, dim: int, theta: float = 10000.0) -> None:
super().__init__()
@ -298,6 +275,11 @@ class DotsVisionAttention(nn.Module):
prefix=f"{prefix}.attn",
)
self.apply_rotary_emb = ApplyRotaryEmb(
enforce_enable=True,
enable_fp32_compute=True,
)
def forward(
self,
hidden_states: torch.Tensor,
@ -318,7 +300,11 @@ class DotsVisionAttention(nn.Module):
if rotary_pos_emb is not None:
qk_concat = torch.cat([q, k], dim=0)
qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
qk_rotated = self.apply_rotary_emb(
qk_concat,
rotary_pos_emb.cos(),
rotary_pos_emb.sin(),
)
q, k = torch.chunk(qk_rotated, 2, dim=0)
context_layer = self.attn(

View File

@ -33,7 +33,7 @@ import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from einops import rearrange
from transformers import BatchFeature
from vllm.attention.backends.registry import AttentionBackendEnum
@ -53,6 +53,9 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding.common import (
ApplyRotaryEmb,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
@ -69,7 +72,6 @@ from vllm.multimodal.processing import (
PromptUpdate,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
@ -89,52 +91,6 @@ logger = init_logger(__name__)
# === Vision Transformer === #
def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
if not interleaved:
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
else:
x1, x2 = x[..., ::2], x[..., 1::2]
return rearrange(
torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
)
def apply_rotary_emb_torch(
x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
) -> torch.Tensor:
"""
x: (batch_size, seqlen, nheads, headdim)
cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
"""
ro_dim = cos.shape[-1] * 2
assert ro_dim <= x.shape[-1]
cos = repeat(
cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
)
sin = repeat(
sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
)
return torch.cat(
[
x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
x[..., ro_dim:],
],
dim=-1,
)
def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
t_ = t.float()
cos = freqs.cos()
sin = freqs.sin()
apply_rotary_emb = apply_rotary_emb_torch
if current_platform.is_cuda():
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
output = apply_rotary_emb(t_, cos, sin).type_as(t)
return output
def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int):
"""All-gather the input tensor interleavely across model parallel group."""
import torch.distributed as dist
@ -200,6 +156,11 @@ class Ernie4_5_VisionAttention(nn.Module):
prefix=f"{prefix}.attn",
)
self.apply_rotary_emb = ApplyRotaryEmb(
enforce_enable=True,
enable_fp32_compute=True,
)
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
# [s, b, 3 * head * head_dim]
seq_len, bs, _ = qkv.shape
@ -244,7 +205,11 @@ class Ernie4_5_VisionAttention(nn.Module):
q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v))
if rotary_pos_emb is not None:
qk_concat = torch.cat([q, k], dim=0)
qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
qk_rotated = self.apply_rotary_emb(
qk_concat,
rotary_pos_emb.cos(),
rotary_pos_emb.sin(),
)
q, k = torch.chunk(qk_rotated, 2, dim=0)
output = self.attn(

View File

@ -65,6 +65,9 @@ from vllm.model_executor.layers.linear import (
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.rotary_embedding.common import (
ApplyRotaryEmb,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
@ -95,7 +98,7 @@ from .interfaces import (
SupportsMultiModal,
SupportsPP,
)
from .qwen2_vl import _create_qwen2vl_field_factory, apply_rotary_pos_emb_vision
from .qwen2_vl import _create_qwen2vl_field_factory
from .utils import (
AutoWeightsLoader,
WeightsMapper,
@ -304,6 +307,8 @@ class Glm4vVisionAttention(nn.Module):
multimodal_config=multimodal_config,
)
self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
# [s, b, 3 * head * head_dim]
seq_len, bs, _ = qkv.shape
@ -339,8 +344,10 @@ class Glm4vVisionAttention(nn.Module):
if rotary_pos_emb_cos is not None and rotary_pos_emb_sin is not None:
# [2 * b, s, heads, head_dim]
qk_concat = torch.cat([q, k], dim=0)
qk_rotated = apply_rotary_pos_emb_vision(
qk_concat, rotary_pos_emb_cos, rotary_pos_emb_sin
qk_rotated = self.apply_rotary_emb(
qk_concat,
rotary_pos_emb_cos,
rotary_pos_emb_sin,
)
q, k = torch.chunk(qk_rotated, 2, dim=0)

View File

@ -30,6 +30,9 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding.common import (
ApplyRotaryEmb,
)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
@ -59,7 +62,6 @@ from vllm.multimodal.processing import (
PromptUpdate,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
@ -341,20 +343,14 @@ def apply_rotary_pos_emb_flashatt(
cos = cos.chunk(2, dim=-1)[0].contiguous()
sin = sin.chunk(2, dim=-1)[0].contiguous()
if current_platform.is_cuda():
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
elif current_platform.is_rocm():
from flash_attn.ops.triton.rotary import apply_rotary as apply_rotary_emb
else:
# For other platforms, use PyTorch fallback
from vllm.model_executor.layers.rotary_embedding.common import (
apply_rotary_emb_torch,
)
apply_rotary_emb = ApplyRotaryEmb(
enforce_enable=True,
enable_fp32_compute=True,
)
apply_rotary_emb = partial(apply_rotary_emb_torch, is_neox_style=True)
q_embed = apply_rotary_emb(q, cos, sin)
k_embed = apply_rotary_emb(k, cos, sin)
q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)
k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k)
return q_embed, k_embed

View File

@ -22,7 +22,7 @@ from typing import Annotated, Literal
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange, repeat
from einops import rearrange
from transformers import BatchFeature, PretrainedConfig
from transformers.activations import GELUActivation
from transformers.modeling_outputs import (
@ -47,7 +47,7 @@ from vllm.model_executor.layers.linear import (
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding.common import (
dispatch_rotary_emb_function,
ApplyRotaryEmb,
)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
@ -130,47 +130,6 @@ def smart_resize(
return h_bar, w_bar
def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
if not interleaved:
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
x1, x2 = x[..., ::2], x[..., 1::2]
return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2)
def apply_rotary_emb_torch(
x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
) -> torch.Tensor:
"""
x: (batch_size, seqlen, nheads, headdim)
cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
"""
ro_dim = cos.shape[-1] * 2
assert ro_dim <= x.shape[-1]
cos = repeat(
cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
)
sin = repeat(
sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
)
return torch.cat(
[
x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
x[..., ro_dim:],
],
dim=-1,
)
def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
rotary_emb_function = dispatch_rotary_emb_function(default=apply_rotary_emb_torch)
t_ = t.float()
cos = freqs.cos()
sin = freqs.sin()
output = rotary_emb_function(t_, cos, sin).type_as(t)
return output
class PaddleOCRVLProcessingInfo(BaseProcessingInfo):
def get_hf_config(self):
return self.ctx.get_hf_config()
@ -609,6 +568,10 @@ class SiglipAttention(nn.Module):
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn",
)
self.apply_rotary_emb = ApplyRotaryEmb(
enforce_enable=True,
enable_fp32_compute=True,
)
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
seq_len, bs, _ = qkv.shape
@ -651,7 +614,11 @@ class SiglipAttention(nn.Module):
if rotary_pos_emb is not None:
qk_concat = torch.cat([q, k], dim=0)
qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
qk_rotated = self.apply_rotary_emb(
qk_concat,
rotary_pos_emb.cos(),
rotary_pos_emb.sin(),
)
q, k = torch.chunk(qk_rotated, 2, dim=0)
context_layer = self.attn(

View File

@ -60,6 +60,9 @@ from vllm.model_executor.layers.linear import (
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.rotary_embedding.common import (
ApplyRotaryEmb,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.vision import should_torch_compile_mm_vit
@ -95,7 +98,6 @@ from .qwen2_vl import Qwen2VLDummyInputsBuilder as Qwen2_5_VLDummyInputsBuilder
from .qwen2_vl import (
Qwen2VLMultiModalProcessor,
Qwen2VLProcessingInfo,
apply_rotary_pos_emb_vision,
)
from .utils import (
AutoWeightsLoader,
@ -353,6 +355,8 @@ class Qwen2_5_VisionAttention(nn.Module):
multimodal_config=multimodal_config,
)
self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)
def forward(
self,
x: torch.Tensor,
@ -378,8 +382,10 @@ class Qwen2_5_VisionAttention(nn.Module):
qk_reshaped = einops.rearrange(
qk, "b s two head head_dim -> (two b) s head head_dim", two=2
)
qk_rotated = apply_rotary_pos_emb_vision(
qk_reshaped, cos=rotary_pos_emb_cos, sin=rotary_pos_emb_sin
qk_rotated = self.apply_rotary_emb(
qk_reshaped,
rotary_pos_emb_cos,
rotary_pos_emb_sin,
)
qk_rotated = qk_rotated.view(
2,

View File

@ -59,8 +59,7 @@ from vllm.model_executor.layers.linear import (
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.rotary_embedding.common import (
apply_rotary_emb_torch,
dispatch_rotary_emb_function,
ApplyRotaryEmb,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
@ -280,16 +279,6 @@ class Qwen2VisionMLP(nn.Module):
return x
def apply_rotary_pos_emb_vision(
t: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> torch.Tensor:
rotary_emb_function = dispatch_rotary_emb_function(
default=partial(apply_rotary_emb_torch, is_neox_style=True)
)
output = rotary_emb_function(t, cos, sin).type_as(t)
return output
class Qwen2VisionAttention(nn.Module):
def __init__(
self,
@ -341,6 +330,8 @@ class Qwen2VisionAttention(nn.Module):
multimodal_config=multimodal_config,
)
self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
# [s, b, 3 * head * head_dim]
seq_len, bs, _ = qkv.shape
@ -387,8 +378,10 @@ class Qwen2VisionAttention(nn.Module):
# [2 * b, s, heads, head_dim]
qk_concat = torch.cat([q, k], dim=0)
qk_rotated = apply_rotary_pos_emb_vision(
qk_concat, rotary_pos_emb_cos, rotary_pos_emb_sin
qk_rotated = self.apply_rotary_emb(
qk_concat,
rotary_pos_emb_cos,
rotary_pos_emb_sin,
)
q, k = torch.chunk(qk_rotated, 2, dim=0)

View File

@ -6,7 +6,6 @@ within a vision language model."""
from collections.abc import Iterable
import torch
from einops import rearrange, repeat
from torch import nn
from torch.nn import functional as F
from transformers import Siglip2VisionConfig
@ -26,6 +25,9 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding.common import (
ApplyRotaryEmb,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.platforms import current_platform
@ -146,40 +148,6 @@ class Siglip2VisionEmbeddings(nn.Module):
return patch_embeds
# copy from flash_attn/layers/rotary.py
def rotate_half(x, interleaved=False):
if not interleaved:
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
else:
x1, x2 = x[..., ::2], x[..., 1::2]
return rearrange(
torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
)
def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
"""
x: (batch_size, seqlen, nheads, headdim)
cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
"""
ro_dim = cos.shape[-1] * 2
assert ro_dim <= x.shape[-1]
cos = repeat(
cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
)
sin = repeat(
sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
)
return torch.cat(
[
x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
x[..., ro_dim:],
],
dim=-1,
)
def apply_rotary_pos_emb(
q: torch.Tensor,
k: torch.Tensor,
@ -189,14 +157,20 @@ def apply_rotary_pos_emb(
) -> tuple[torch.Tensor, torch.Tensor]:
cos = cos.chunk(2, dim=-1)[0].contiguous()
sin = sin.chunk(2, dim=-1)[0].contiguous()
if is_flash_attn_backend and current_platform.is_cuda():
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
apply_rotary_emb_func = apply_rotary_emb
apply_rotary_emb = ApplyRotaryEmb(
enforce_enable=True,
enable_fp32_compute=True,
)
if is_flash_attn_backend and not current_platform.is_cuda():
apply_rotary_emb_func = apply_rotary_emb.forward_cuda
else:
apply_rotary_emb_func = apply_rotary_emb_torch
q_embed = apply_rotary_emb_func(q.float(), cos.float(), sin.float()).type_as(q)
k_embed = apply_rotary_emb_func(k.float(), cos.float(), sin.float()).type_as(k)
apply_rotary_emb_func = apply_rotary_emb.forward_native
q_embed = apply_rotary_emb_func(q, cos, sin)
k_embed = apply_rotary_emb_func(k, cos, sin)
return q_embed, k_embed