diff --git a/tests/kernels/core/test_apply_rotary_emb.py b/tests/kernels/core/test_apply_rotary_emb.py new file mode 100644 index 0000000000000..23c722fa5e638 --- /dev/null +++ b/tests/kernels/core/test_apply_rotary_emb.py @@ -0,0 +1,203 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Tests for ApplyRotaryEmb CustomOp dispatch behavior. + +This test ensures that RotaryEmbedding classes correctly call the appropriate +ApplyRotaryEmb methods based on the calling context: + +1. RotaryEmbedding.forward_native() -> ApplyRotaryEmb.forward_native() +2. RotaryEmbedding.forward_cuda() -> ApplyRotaryEmb.forward() (auto-dispatch) +3. RotaryEmbedding.forward_hip() -> ApplyRotaryEmb.forward() (auto-dispatch) +""" + +from dataclasses import dataclass + +import pytest +import torch + +from vllm.config import ( + CompilationConfig, + VllmConfig, + get_cached_compilation_config, + set_current_vllm_config, +) +from vllm.platforms import current_platform + +CUDA_DEVICES = ["cuda:0"] + + +@dataclass +class RotaryEmbeddingTestCase: + """Test case configuration for RotaryEmbedding dispatch tests.""" + + name: str + rope_class: type + rope_kwargs: dict + method_name: str # forward_native, forward_cuda, forward + positions_shape: tuple # (num_tokens,) or (3, num_tokens) or (4, num_tokens) + expect_forward_native: bool # Should call ApplyRotaryEmb.forward_native() + expect_forward: bool # Should call ApplyRotaryEmb.forward() + + +def get_test_cases() -> list[RotaryEmbeddingTestCase]: + """Generate test cases for all RotaryEmbedding classes.""" + from vllm.model_executor.layers.rotary_embedding.ernie45_vl_rope import ( + Ernie4_5_VLRotaryEmbedding, + ) + from vllm.model_executor.layers.rotary_embedding.mrope import MRotaryEmbedding + from vllm.model_executor.layers.rotary_embedding.xdrope import XDRotaryEmbedding + + common_kwargs = { + "head_size": 128, + "rotary_dim": 128, + "max_position_embeddings": 4096, + "base": 10000, + "is_neox_style": True, + "dtype": torch.bfloat16, + } + + return [ + # MRotaryEmbedding tests + RotaryEmbeddingTestCase( + name="MRotaryEmbedding.forward_native", + rope_class=MRotaryEmbedding, + rope_kwargs={**common_kwargs, "mrope_section": [16, 24, 24]}, + method_name="forward_native", + positions_shape=(3, 32), # 2D for multimodal + expect_forward_native=True, + expect_forward=False, + ), + RotaryEmbeddingTestCase( + name="MRotaryEmbedding.forward_cuda_1d", + rope_class=MRotaryEmbedding, + rope_kwargs={**common_kwargs, "mrope_section": [16, 24, 24]}, + method_name="forward_cuda", + positions_shape=(32,), # 1D triggers apply_rotary_emb path + expect_forward_native=False, + expect_forward=True, + ), + # XDRotaryEmbedding tests + RotaryEmbeddingTestCase( + name="XDRotaryEmbedding.forward", + rope_class=XDRotaryEmbedding, + rope_kwargs={ + **common_kwargs, + "scaling_alpha": 1.0, + "xdrope_section": [16, 16, 16, 16], + }, + method_name="forward", + positions_shape=(4, 32), # 4D for P/W/H/T + expect_forward_native=False, + expect_forward=True, + ), + # Ernie4_5_VLRotaryEmbedding tests + RotaryEmbeddingTestCase( + name="Ernie4_5_VLRotaryEmbedding.forward_native", + rope_class=Ernie4_5_VLRotaryEmbedding, + rope_kwargs={**common_kwargs, "mrope_section": [22, 22, 20]}, + method_name="forward_native", + positions_shape=(3, 32), # 2D for multimodal + expect_forward_native=True, + expect_forward=False, + ), + ] + + +def run_dispatch_test( + test_case: RotaryEmbeddingTestCase, + device: str, +): + """Run a dispatch test for a 
RotaryEmbedding class.""" + vllm_config = VllmConfig( + compilation_config=CompilationConfig(custom_ops=["all", "+apply_rotary_emb"]) + ) + get_cached_compilation_config.cache_clear() + + with set_current_vllm_config(vllm_config): + rope = test_case.rope_class(**test_case.rope_kwargs).to(device=device) + + apply_rotary_emb = rope.apply_rotary_emb + + # Verify custom op is enabled + if test_case.expect_forward_native: + assert ( + apply_rotary_emb._forward_method != apply_rotary_emb.forward_native + ), "Test setup error: ApplyRotaryEmb custom op should be enabled" + + # Setup call tracking + call_tracker = {"forward_native_called": False, "forward_called": False} + original_forward_native = apply_rotary_emb.forward_native + original_forward = apply_rotary_emb.forward + + def tracked_forward_native(*args, **kwargs): + call_tracker["forward_native_called"] = True + return original_forward_native(*args, **kwargs) + + def tracked_forward(*args, **kwargs): + call_tracker["forward_called"] = True + return original_forward(*args, **kwargs) + + apply_rotary_emb.forward_native = tracked_forward_native + apply_rotary_emb.forward = tracked_forward + + try: + num_tokens = test_case.positions_shape[-1] + num_q_heads = 8 + num_kv_heads = 2 + head_size = test_case.rope_kwargs["head_size"] + max_position = test_case.rope_kwargs["max_position_embeddings"] + + positions = torch.randint( + 0, max_position // 4, test_case.positions_shape, device=device + ) + query = torch.randn( + num_tokens, num_q_heads * head_size, dtype=torch.bfloat16, device=device + ) + key = torch.randn( + num_tokens, + num_kv_heads * head_size, + dtype=torch.bfloat16, + device=device, + ) + + # Call the method under test + method = getattr(rope, test_case.method_name) + method(positions, query.clone(), key.clone()) + + # Verify expectations + if test_case.expect_forward_native: + assert call_tracker["forward_native_called"], ( + f"{test_case.name} should call ApplyRotaryEmb.forward_native()" + ) + if not test_case.expect_forward: + assert not call_tracker["forward_called"], ( + f"{test_case.name} should NOT call ApplyRotaryEmb.forward(). " + "Bug: when +apply_rotary_emb is enabled, forward_native() " + "incorrectly dispatches to CUDA/HIP kernels." + ) + if test_case.expect_forward: + assert call_tracker["forward_called"], ( + f"{test_case.name} should call ApplyRotaryEmb.forward()" + ) + finally: + apply_rotary_emb.forward_native = original_forward_native + apply_rotary_emb.forward = original_forward + + +@pytest.mark.skipif( + not current_platform.is_cuda_alike(), reason="Skipping CUDA/ROCm only tests." +) +@pytest.mark.parametrize("test_case", get_test_cases(), ids=lambda tc: tc.name) +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_rotary_embedding_dispatch( + test_case: RotaryEmbeddingTestCase, + device: str, +): + """ + Test that RotaryEmbedding classes dispatch to the correct ApplyRotaryEmb method. 
+ + - forward_native methods should call ApplyRotaryEmb.forward_native() + - forward_cuda/forward methods should call ApplyRotaryEmb.forward() + """ + run_dispatch_test(test_case, device) diff --git a/vllm/model_executor/layers/rotary_embedding/base.py b/vllm/model_executor/layers/rotary_embedding/base.py index 4114b21168cc8..afa69324c4e2e 100644 --- a/vllm/model_executor/layers/rotary_embedding/base.py +++ b/vllm/model_executor/layers/rotary_embedding/base.py @@ -7,7 +7,7 @@ import torch from vllm._aiter_ops import rocm_aiter_ops from vllm.model_executor.custom_op import CustomOp -from .common import apply_rotary_emb_torch +from .common import ApplyRotaryEmb @CustomOp.register("rotary_embedding") @@ -49,6 +49,10 @@ class RotaryEmbeddingBase(CustomOp): rocm_aiter_ops.is_triton_rotary_embed_enabled() ) + self.apply_rotary_emb = ApplyRotaryEmb( + is_neox_style=self.is_neox_style, + ) + def _compute_inv_freq(self, base: float) -> torch.Tensor: """Compute the inverse frequency.""" # NOTE(woosuk): To exactly match the HF implementation, we need to @@ -123,7 +127,12 @@ class RotaryEmbedding(RotaryEmbeddingBase): query = query.view(num_tokens, -1, head_size) query_rot = query[..., :rotary_dim] query_pass = query[..., rotary_dim:] - query_rot = apply_rotary_emb_torch(query_rot, cos, sin, is_neox_style) + query_rot = ApplyRotaryEmb.forward_static( + query_rot, + cos, + sin, + is_neox_style, + ) query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) # key may be None in some cases, e.g. cross-layer KV sharing @@ -132,7 +141,12 @@ class RotaryEmbedding(RotaryEmbeddingBase): key = key.view(num_tokens, -1, head_size) key_rot = key[..., :rotary_dim] key_pass = key[..., rotary_dim:] - key_rot = apply_rotary_emb_torch(key_rot, cos, sin, is_neox_style) + key_rot = ApplyRotaryEmb.forward_static( + key_rot, + cos, + sin, + is_neox_style, + ) key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) return query, key diff --git a/vllm/model_executor/layers/rotary_embedding/common.py b/vllm/model_executor/layers/rotary_embedding/common.py index 13f8d15cc0f72..3e6584dbc3da0 100644 --- a/vllm/model_executor/layers/rotary_embedding/common.py +++ b/vllm/model_executor/layers/rotary_embedding/common.py @@ -2,19 +2,14 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math -from collections.abc import Callable -from functools import cache from importlib.util import find_spec import torch from vllm.logger import init_logger -from vllm.platforms import current_platform +from vllm.model_executor.custom_op import CustomOp from vllm.utils.torch_utils import direct_register_custom_op -if current_platform.is_cuda(): - from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb - logger = init_logger(__name__) @@ -32,71 +27,6 @@ def rotate_gptj(x: torch.Tensor) -> torch.Tensor: return x.flatten(-2) -def apply_rotary_emb_torch( - x: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - is_neox_style: bool, -) -> torch.Tensor: - cos = cos.unsqueeze(-2).to(x.dtype) - sin = sin.unsqueeze(-2).to(x.dtype) - if is_neox_style: - x1, x2 = torch.chunk(x, 2, dim=-1) - else: - x1 = x[..., ::2] - x2 = x[..., 1::2] - o1 = x1 * cos - x2 * sin - o2 = x2 * cos + x1 * sin - if is_neox_style: - return torch.cat((o1, o2), dim=-1) - else: - return torch.stack((o1, o2), dim=-1).flatten(-2) - - -def apply_rotary_emb_dispatch( - x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, is_neox_style: bool -) -> torch.Tensor: - """ - Args: - x: [num_tokens, num_heads, head_size] - cos: 
[num_tokens, head_size // 2] - sin: [num_tokens, head_size // 2] - is_neox_style: Whether to use the Neox-style or GPT-J-style rotary - positional embeddings. - """ - if current_platform.is_cuda(): - return apply_rotary_emb(x.unsqueeze(0), cos, sin, not is_neox_style).squeeze(0) - else: - return apply_rotary_emb_torch(x, cos, sin, is_neox_style) - - -@cache -def dispatch_rotary_emb_function( - default: Callable[..., torch.Tensor] | None = None, -) -> Callable[..., torch.Tensor]: - if current_platform.is_cuda(): - return apply_rotary_emb - - # if torch compile is not enabled - # use rotary embedding function from flash_attn package - # otherwise use the naive pytorch embedding implementation - # is faster when torch compile is enabled. - if current_platform.is_rocm() and not torch.compiler.is_compiling(): - if find_spec("flash_attn") is not None: - from flash_attn.ops.triton.rotary import apply_rotary - - return apply_rotary - else: - logger.warning( - "flash_attn is not installed. Falling back to PyTorch " - "implementation for rotary embeddings." - ) - if default is not None: - return default - - return apply_rotary_emb_torch - - # yarn functions # Inverse dim formula to find dim based on number of rotations def yarn_find_correction_dim( @@ -186,3 +116,155 @@ direct_register_custom_op( mutates_args=["query", "key"], # These tensors are modified in-place fake_impl=_flashinfer_rotary_embedding_fake, ) + + +@CustomOp.register("apply_rotary_emb") +class ApplyRotaryEmb(CustomOp): + def __init__( + self, + enforce_enable: bool = False, + is_neox_style: bool = True, + enable_fp32_compute: bool = False, + ) -> None: + super().__init__(enforce_enable) + self.is_neox_style = is_neox_style + self.enable_fp32_compute = enable_fp32_compute + + self.apply_rotary_emb_flash_attn = None + if find_spec("flash_attn") is not None: + from flash_attn.ops.triton.rotary import apply_rotary + + self.apply_rotary_emb_flash_attn = apply_rotary + + @staticmethod + def forward_static( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + is_neox_style: bool = True, + enable_fp32_compute: bool = False, + ) -> torch.Tensor: + """ + Args: + x: [batch_size (optional), seq_len, num_heads, head_size] + cos: [seq_len, head_size // 2] + sin: [seq_len, head_size // 2] + is_neox_style: Whether to use the Neox-style or GPT-J-style. + enable_fp32_compute: Temporarily convert x, cos, sin to FP32 dtype + for higher accuracy. 
+ """ + origin_dtype = x.dtype + if enable_fp32_compute: + x = x.float() + + cos = cos.unsqueeze(-2).to(x.dtype) + sin = sin.unsqueeze(-2).to(x.dtype) + + if is_neox_style: + x1, x2 = torch.chunk(x, 2, dim=-1) + else: + x1 = x[..., ::2] + x2 = x[..., 1::2] + + o1 = x1 * cos - x2 * sin + o2 = x2 * cos + x1 * sin + + if is_neox_style: + output = torch.cat((o1, o2), dim=-1) + else: + output = torch.stack((o1, o2), dim=-1).flatten(-2) + + if enable_fp32_compute: + output = output.to(origin_dtype) + return output + + def forward_native( + self, + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + ) -> torch.Tensor: + output = self.forward_static( + x, cos, sin, self.is_neox_style, self.enable_fp32_compute + ) + return output + + def forward_cuda( + self, + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + ) -> torch.Tensor: + from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb + + origin_dtype = x.dtype + if self.enable_fp32_compute: + x = x.float() + cos = cos.float() + sin = sin.float() + + origin_shape = x.shape + if len(origin_shape) == 3: + # x: [seq_len, num_heads, head_size] + x = x.unsqueeze(0) + + """ + Arguments of apply_rotary_emb() in vllm_flash_attn: + x: [batch_size, seq_len, nheads, headdim] + cos, sin: [seqlen_rotary, rotary_dim / 2] + interleaved: defalut as False (Neox-style). + ... + """ + interleaved = not self.is_neox_style + output = apply_rotary_emb(x, cos, sin, interleaved) + + if len(origin_shape) == 3: + output = output.squeeze(0) + if self.enable_fp32_compute: + output = output.to(origin_dtype) + return output + + def forward_hip( + self, + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + ) -> torch.Tensor: + if self.apply_rotary_emb_flash_attn is not None: + origin_dtype = x.dtype + if self.enable_fp32_compute: + x = x.float() + cos = cos.float() + sin = sin.float() + + origin_shape = x.shape + if len(origin_shape) == 3: + # x: [seq_len, num_heads, head_size] + x = x.unsqueeze(0) + + """ + Arguments of apply_rotary() in flash_attn: + x: [batch_size, seq_len, nheads, headdim] + cos, sin: [seqlen_rotary, rotary_dim / 2] + interleaved: defalut as False (Neox-style). + ... + """ + interleaved = not self.is_neox_style + output = self.apply_rotary_emb_flash_attn( + x, cos, sin, interleaved=interleaved + ).type_as(x) + + if len(origin_shape) == 3: + output = output.squeeze(0) + if self.enable_fp32_compute: + output = output.to(origin_dtype) + else: + # Falling back to PyTorch native implementation. 
output = self.forward_native(x, cos, sin) + + return output + + def extra_repr(self) -> str: + s = f"is_neox_style={self.is_neox_style}" + s += f", enable_fp32_compute={self.enable_fp32_compute}" + return s diff --git a/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py b/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py index 749cdbe88a62e..2eda63a34ac44 100644 --- a/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py @@ -4,7 +4,6 @@ import torch -from .common import apply_rotary_emb_dispatch from .mrope import MRotaryEmbedding @@ -55,14 +54,22 @@ class Ernie4_5_VLRotaryEmbedding(MRotaryEmbedding): query = query.view(num_tokens, -1, self.head_size) query_rot = query[..., : self.rotary_dim] query_pass = query[..., self.rotary_dim :] - query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style) + query_rot = self.apply_rotary_emb.forward_native( + query_rot, + cos, + sin, + ) query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) key_shape = key.shape key = key.view(num_tokens, -1, self.head_size) key_rot = key[..., : self.rotary_dim] key_pass = key[..., self.rotary_dim :] - key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style) + key_rot = self.apply_rotary_emb.forward_native( + key_rot, + cos, + sin, + ) key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) return query, key diff --git a/vllm/model_executor/layers/rotary_embedding/mrope.py b/vllm/model_executor/layers/rotary_embedding/mrope.py index 0592aa8f967a6..a74bf092b182b 100644 --- a/vllm/model_executor/layers/rotary_embedding/mrope.py +++ b/vllm/model_executor/layers/rotary_embedding/mrope.py @@ -8,7 +8,6 @@ import torch from vllm.triton_utils import tl, triton from .base import RotaryEmbeddingBase -from .common import apply_rotary_emb_dispatch from .yarn_scaling_rope import YaRNScalingRotaryEmbedding, yarn_get_mscale @@ -301,14 +300,22 @@ class MRotaryEmbedding(RotaryEmbeddingBase): query = query.view(num_tokens, -1, self.head_size) query_rot = query[..., : self.rotary_dim] query_pass = query[..., self.rotary_dim :] - query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style) + query_rot = self.apply_rotary_emb.forward_native( + query_rot, + cos, + sin, + ) query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) key_shape = key.shape key = key.view(num_tokens, -1, self.head_size) key_rot = key[..., : self.rotary_dim] key_pass = key[..., self.rotary_dim :] - key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style) + key_rot = self.apply_rotary_emb.forward_native( + key_rot, + cos, + sin, + ) key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) return query, key @@ -347,13 +354,21 @@ class MRotaryEmbedding(RotaryEmbeddingBase): query = query.view(num_tokens, -1, self.head_size) query_rot = query[..., : self.rotary_dim] query_pass = query[..., self.rotary_dim :] - query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style) + query_rot = self.apply_rotary_emb( + query_rot, + cos, + sin, + ) query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) key = key.view(num_tokens, -1, self.head_size) key_rot = key[..., : self.rotary_dim] key_pass = key[..., self.rotary_dim :] - key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style) + key_rot = self.apply_rotary_emb( + key_rot, + cos, + sin, + ) key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) return query, key
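The call-site split above is the crux of the refactor: forward_native() implementations stay on the pure-PyTorch path by calling apply_rotary_emb.forward_native() explicitly, while forward_cuda() calls the op itself so that CustomOp dispatch can pick a fused kernel. A condensed sketch of that pattern, not part of the diff, with MyRope as a hypothetical stand-in for the RotaryEmbeddingBase subclasses:

```python
import torch
from torch import nn

from vllm.model_executor.layers.rotary_embedding.common import ApplyRotaryEmb


class MyRope(nn.Module):  # hypothetical stand-in for the real RoPE classes
    def __init__(self, is_neox_style: bool = True) -> None:
        super().__init__()
        # Built once per layer, as RotaryEmbeddingBase.__init__ now does.
        self.apply_rotary_emb = ApplyRotaryEmb(is_neox_style=is_neox_style)

    def forward_native(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
        # Reference path: always pure PyTorch, never a fused kernel.
        return self.apply_rotary_emb.forward_native(x, cos, sin)

    def forward_cuda(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
        # __call__ goes through CustomOp dispatch: forward_cuda/forward_hip when
        # "+apply_rotary_emb" is enabled, otherwise the native fallback.
        return self.apply_rotary_emb(x, cos, sin)
```

This is also exactly what the new test at the top of the diff asserts: forward_native() must never reach ApplyRotaryEmb.forward().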
diff --git a/vllm/model_executor/layers/rotary_embedding/xdrope.py b/vllm/model_executor/layers/rotary_embedding/xdrope.py index 2432273faf195..dab7aad9759a2 100644 --- a/vllm/model_executor/layers/rotary_embedding/xdrope.py +++ b/vllm/model_executor/layers/rotary_embedding/xdrope.py @@ -4,7 +4,6 @@ import numpy as np import torch -from .common import apply_rotary_emb_dispatch from .dynamic_ntk_alpha_rope import DynamicNTKAlphaRotaryEmbedding @@ -36,7 +35,7 @@ class XDRotaryEmbedding(DynamicNTKAlphaRotaryEmbedding): dtype, ) - def forward( + def forward_native( self, positions: torch.Tensor, query: torch.Tensor, @@ -68,14 +67,73 @@ class XDRotaryEmbedding(DynamicNTKAlphaRotaryEmbedding): query = query.view(num_tokens, -1, self.head_size) query_rot = query[..., : self.rotary_dim] query_pass = query[..., self.rotary_dim :] - query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style) + query_rot = self.apply_rotary_emb.forward_native( + query_rot, + cos, + sin, + ) query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) key_shape = key.shape key = key.view(num_tokens, -1, self.head_size) key_rot = key[..., : self.rotary_dim] key_pass = key[..., self.rotary_dim :] - key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style) + key_rot = self.apply_rotary_emb.forward_native( + key_rot, + cos, + sin, + ) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + + def forward_cuda( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor | None = None, + offsets: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + """Apply XD-RoPE via ApplyRotaryEmb dispatch (fused kernels when enabled). + + Args: + positions: + [4, num_tokens] (P/W/H/T positions with multimodal inputs) + query: [num_tokens, num_heads * head_size] + key: [num_tokens, num_kv_heads * head_size] + """ + assert positions.ndim == 2 + assert key is not None + + num_tokens = positions.shape[-1] + cos_sin = self.cos_sin_cache[positions] + cos, sin = cos_sin.chunk(2, dim=-1) + cos = torch.cat( + [m[i] for i, m in enumerate(cos.split(self.xdrope_section, dim=-1))], dim=-1 + ) + sin = torch.cat( + [m[i] for i, m in enumerate(sin.split(self.xdrope_section, dim=-1))], dim=-1 + ) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + query_rot = query[..., : self.rotary_dim] + query_pass = query[..., self.rotary_dim :] + query_rot = self.apply_rotary_emb( + query_rot, + cos, + sin, + ) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + key_rot = key[..., : self.rotary_dim] + key_pass = key[..., self.rotary_dim :] + key_rot = self.apply_rotary_emb( + key_rot, + cos, + sin, + ) key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) return query, key diff --git a/vllm/model_executor/models/dots_ocr.py b/vllm/model_executor/models/dots_ocr.py index 9b61cd9503073..6d8dbec9236c9 100644 --- a/vllm/model_executor/models/dots_ocr.py +++ b/vllm/model_executor/models/dots_ocr.py @@ -29,6 +29,9 @@ from vllm.model_executor.layers.linear import ( RowParallelLinear, ) from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding.common import ( + ApplyRotaryEmb, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import ( MultiModalEmbeddings, @@ -158,32 +161,6 @@
class DotsOCRProcessingInfo(Qwen2VLProcessingInfo): return processor -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb_vision( - tensor: torch.Tensor, freqs: torch.Tensor -) -> torch.Tensor: - orig_dtype = tensor.dtype - tensor = tensor.float() - - cos = freqs.cos() - sin = freqs.sin() - - cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() - sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() - - output = (tensor * cos) + (rotate_half(tensor) * sin) - - output = output.to(orig_dtype) - - return output - - class VisionRotaryEmbedding(nn.Module): def __init__(self, dim: int, theta: float = 10000.0) -> None: super().__init__() @@ -298,6 +275,11 @@ class DotsVisionAttention(nn.Module): prefix=f"{prefix}.attn", ) + self.apply_rotary_emb = ApplyRotaryEmb( + enforce_enable=True, + enable_fp32_compute=True, + ) + def forward( self, hidden_states: torch.Tensor, @@ -318,7 +300,11 @@ class DotsVisionAttention(nn.Module): if rotary_pos_emb is not None: qk_concat = torch.cat([q, k], dim=0) - qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb) + qk_rotated = self.apply_rotary_emb( + qk_concat, + rotary_pos_emb.cos(), + rotary_pos_emb.sin(), + ) q, k = torch.chunk(qk_rotated, 2, dim=0) context_layer = self.attn( diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py index dd2b74736bcac..61cf78fdb5a67 100644 --- a/vllm/model_executor/models/ernie45_vl.py +++ b/vllm/model_executor/models/ernie45_vl.py @@ -33,7 +33,7 @@ import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -from einops import rearrange, repeat +from einops import rearrange from transformers import BatchFeature from vllm.attention.backends.registry import AttentionBackendEnum @@ -53,6 +53,9 @@ from vllm.model_executor.layers.linear import ( RowParallelLinear, ) from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding.common import ( + ApplyRotaryEmb, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( @@ -69,7 +72,6 @@ from vllm.multimodal.processing import ( PromptUpdate, ) from vllm.multimodal.profiling import BaseDummyInputsBuilder -from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape @@ -89,52 +91,6 @@ logger = init_logger(__name__) # === Vision Transformer === # -def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor: - if not interleaved: - x1, x2 = x.chunk(2, dim=-1) - return torch.cat((-x2, x1), dim=-1) - else: - x1, x2 = x[..., ::2], x[..., 1::2] - return rearrange( - torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2 - ) - - -def apply_rotary_emb_torch( - x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False -) -> torch.Tensor: - """ - x: (batch_size, seqlen, nheads, headdim) - cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) - """ - ro_dim = cos.shape[-1] * 2 - assert ro_dim <= x.shape[-1] - cos = repeat( - cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)" - ) - sin = repeat( - sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 
1 (d 2)" - ) - return torch.cat( - [ - x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, - x[..., ro_dim:], - ], - dim=-1, - ) - - -def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: - t_ = t.float() - cos = freqs.cos() - sin = freqs.sin() - apply_rotary_emb = apply_rotary_emb_torch - if current_platform.is_cuda(): - from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb - output = apply_rotary_emb(t_, cos, sin).type_as(t) - return output - - def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int): """All-gather the input tensor interleavely across model parallel group.""" import torch.distributed as dist @@ -200,6 +156,11 @@ class Ernie4_5_VisionAttention(nn.Module): prefix=f"{prefix}.attn", ) + self.apply_rotary_emb = ApplyRotaryEmb( + enforce_enable=True, + enable_fp32_compute=True, + ) + def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: # [s, b, 3 * head * head_dim] seq_len, bs, _ = qkv.shape @@ -244,7 +205,11 @@ class Ernie4_5_VisionAttention(nn.Module): q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)) if rotary_pos_emb is not None: qk_concat = torch.cat([q, k], dim=0) - qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb) + qk_rotated = self.apply_rotary_emb( + qk_concat, + rotary_pos_emb.cos(), + rotary_pos_emb.sin(), + ) q, k = torch.chunk(qk_rotated, 2, dim=0) output = self.attn( diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py index 10e5261a30485..84989537da6e2 100644 --- a/vllm/model_executor/models/glm4_1v.py +++ b/vllm/model_executor/models/glm4_1v.py @@ -65,6 +65,9 @@ from vllm.model_executor.layers.linear import ( ) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.rotary_embedding.common import ( + ApplyRotaryEmb, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY @@ -95,7 +98,7 @@ from .interfaces import ( SupportsMultiModal, SupportsPP, ) -from .qwen2_vl import _create_qwen2vl_field_factory, apply_rotary_pos_emb_vision +from .qwen2_vl import _create_qwen2vl_field_factory from .utils import ( AutoWeightsLoader, WeightsMapper, @@ -304,6 +307,8 @@ class Glm4vVisionAttention(nn.Module): multimodal_config=multimodal_config, ) + self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True) + def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: # [s, b, 3 * head * head_dim] seq_len, bs, _ = qkv.shape @@ -339,8 +344,10 @@ class Glm4vVisionAttention(nn.Module): if rotary_pos_emb_cos is not None and rotary_pos_emb_sin is not None: # [2 * b, s, heads, head_dim] qk_concat = torch.cat([q, k], dim=0) - qk_rotated = apply_rotary_pos_emb_vision( - qk_concat, rotary_pos_emb_cos, rotary_pos_emb_sin + qk_rotated = self.apply_rotary_emb( + qk_concat, + rotary_pos_emb_cos, + rotary_pos_emb_sin, ) q, k = torch.chunk(qk_rotated, 2, dim=0) diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py index 52e4413690619..fcf88953ba20f 100644 --- a/vllm/model_executor/models/keye.py +++ b/vllm/model_executor/models/keye.py @@ -30,6 +30,9 @@ from vllm.model_executor.layers.linear import ( RowParallelLinear, ) from vllm.model_executor.layers.quantization import QuantizationConfig +from 
vllm.model_executor.layers.rotary_embedding.common import ( + ApplyRotaryEmb, +) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name, @@ -59,7 +62,6 @@ from vllm.multimodal.processing import ( PromptUpdate, ) from vllm.multimodal.profiling import BaseDummyInputsBuilder -from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape @@ -341,20 +343,14 @@ def apply_rotary_pos_emb_flashatt( cos = cos.chunk(2, dim=-1)[0].contiguous() sin = sin.chunk(2, dim=-1)[0].contiguous() - if current_platform.is_cuda(): - from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb - elif current_platform.is_rocm(): - from flash_attn.ops.triton.rotary import apply_rotary as apply_rotary_emb - else: - # For other platforms, use PyTorch fallback - from vllm.model_executor.layers.rotary_embedding.common import ( - apply_rotary_emb_torch, - ) + apply_rotary_emb = ApplyRotaryEmb( + enforce_enable=True, + enable_fp32_compute=True, + ) - apply_rotary_emb = partial(apply_rotary_emb_torch, is_neox_style=True) + q_embed = apply_rotary_emb(q, cos, sin) + k_embed = apply_rotary_emb(k, cos, sin) - q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q) - k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k) return q_embed, k_embed diff --git a/vllm/model_executor/models/paddleocr_vl.py b/vllm/model_executor/models/paddleocr_vl.py index 66acc0432d125..56565266c0dcc 100644 --- a/vllm/model_executor/models/paddleocr_vl.py +++ b/vllm/model_executor/models/paddleocr_vl.py @@ -22,7 +22,7 @@ from typing import Annotated, Literal import numpy as np import torch import torch.nn as nn -from einops import rearrange, repeat +from einops import rearrange from transformers import BatchFeature, PretrainedConfig from transformers.activations import GELUActivation from transformers.modeling_outputs import ( @@ -47,7 +47,7 @@ from vllm.model_executor.layers.linear import ( ) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding.common import ( - dispatch_rotary_emb_function, + ApplyRotaryEmb, ) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, @@ -130,47 +130,6 @@ def smart_resize( return h_bar, w_bar -def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor: - if not interleaved: - x1, x2 = x.chunk(2, dim=-1) - return torch.cat((-x2, x1), dim=-1) - x1, x2 = x[..., ::2], x[..., 1::2] - return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2) - - -def apply_rotary_emb_torch( - x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False -) -> torch.Tensor: - """ - x: (batch_size, seqlen, nheads, headdim) - cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) - """ - ro_dim = cos.shape[-1] * 2 - assert ro_dim <= x.shape[-1] - cos = repeat( - cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)" - ) - sin = repeat( - sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 
1 (d 2)" - ) - return torch.cat( - [ - x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, - x[..., ro_dim:], - ], - dim=-1, - ) - - -def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: - rotary_emb_function = dispatch_rotary_emb_function(default=apply_rotary_emb_torch) - t_ = t.float() - cos = freqs.cos() - sin = freqs.sin() - output = rotary_emb_function(t_, cos, sin).type_as(t) - return output - - class PaddleOCRVLProcessingInfo(BaseProcessingInfo): def get_hf_config(self): return self.ctx.get_hf_config() @@ -609,6 +568,10 @@ class SiglipAttention(nn.Module): multimodal_config=multimodal_config, prefix=f"{prefix}.attn", ) + self.apply_rotary_emb = ApplyRotaryEmb( + enforce_enable=True, + enable_fp32_compute=True, + ) def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: seq_len, bs, _ = qkv.shape @@ -651,7 +614,11 @@ class SiglipAttention(nn.Module): if rotary_pos_emb is not None: qk_concat = torch.cat([q, k], dim=0) - qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb) + qk_rotated = self.apply_rotary_emb( + qk_concat, + rotary_pos_emb.cos(), + rotary_pos_emb.sin(), + ) q, k = torch.chunk(qk_rotated, 2, dim=0) context_layer = self.attn( diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index a5a47f81ba24d..b730ac0315893 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -60,6 +60,9 @@ from vllm.model_executor.layers.linear import ( ) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.rotary_embedding.common import ( + ApplyRotaryEmb, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.vision import should_torch_compile_mm_vit @@ -95,7 +98,6 @@ from .qwen2_vl import Qwen2VLDummyInputsBuilder as Qwen2_5_VLDummyInputsBuilder from .qwen2_vl import ( Qwen2VLMultiModalProcessor, Qwen2VLProcessingInfo, - apply_rotary_pos_emb_vision, ) from .utils import ( AutoWeightsLoader, @@ -353,6 +355,8 @@ class Qwen2_5_VisionAttention(nn.Module): multimodal_config=multimodal_config, ) + self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True) + def forward( self, x: torch.Tensor, @@ -378,8 +382,10 @@ class Qwen2_5_VisionAttention(nn.Module): qk_reshaped = einops.rearrange( qk, "b s two head head_dim -> (two b) s head head_dim", two=2 ) - qk_rotated = apply_rotary_pos_emb_vision( - qk_reshaped, cos=rotary_pos_emb_cos, sin=rotary_pos_emb_sin + qk_rotated = self.apply_rotary_emb( + qk_reshaped, + rotary_pos_emb_cos, + rotary_pos_emb_sin, ) qk_rotated = qk_rotated.view( 2, diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 192a54c3ec839..321fbd764c0f5 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -59,8 +59,7 @@ from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding.common import ( - apply_rotary_emb_torch, - dispatch_rotary_emb_function, + ApplyRotaryEmb, ) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys 
@@ -280,16 +279,6 @@ class Qwen2VisionMLP(nn.Module): return x -def apply_rotary_pos_emb_vision( - t: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor -) -> torch.Tensor: - rotary_emb_function = dispatch_rotary_emb_function( - default=partial(apply_rotary_emb_torch, is_neox_style=True) - ) - output = rotary_emb_function(t, cos, sin).type_as(t) - return output - - class Qwen2VisionAttention(nn.Module): def __init__( self, @@ -341,6 +330,8 @@ class Qwen2VisionAttention(nn.Module): multimodal_config=multimodal_config, ) + self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True) + def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: # [s, b, 3 * head * head_dim] seq_len, bs, _ = qkv.shape @@ -387,8 +378,10 @@ class Qwen2VisionAttention(nn.Module): # [2 * b, s, heads, head_dim] qk_concat = torch.cat([q, k], dim=0) - qk_rotated = apply_rotary_pos_emb_vision( - qk_concat, rotary_pos_emb_cos, rotary_pos_emb_sin + qk_rotated = self.apply_rotary_emb( + qk_concat, + rotary_pos_emb_cos, + rotary_pos_emb_sin, ) q, k = torch.chunk(qk_rotated, 2, dim=0) diff --git a/vllm/model_executor/models/siglip2navit.py b/vllm/model_executor/models/siglip2navit.py index 2ee21fc06846c..efdee255ab5eb 100644 --- a/vllm/model_executor/models/siglip2navit.py +++ b/vllm/model_executor/models/siglip2navit.py @@ -6,7 +6,6 @@ within a vision language model.""" from collections.abc import Iterable import torch -from einops import rearrange, repeat from torch import nn from torch.nn import functional as F from transformers import Siglip2VisionConfig @@ -26,6 +25,9 @@ from vllm.model_executor.layers.linear import ( RowParallelLinear, ) from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding.common import ( + ApplyRotaryEmb, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.platforms import current_platform @@ -146,40 +148,6 @@ class Siglip2VisionEmbeddings(nn.Module): return patch_embeds -# copy from flash_attn/layers/rotary.py -def rotate_half(x, interleaved=False): - if not interleaved: - x1, x2 = x.chunk(2, dim=-1) - return torch.cat((-x2, x1), dim=-1) - else: - x1, x2 = x[..., ::2], x[..., 1::2] - return rearrange( - torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2 - ) - - -def apply_rotary_emb_torch(x, cos, sin, interleaved=False): - """ - x: (batch_size, seqlen, nheads, headdim) - cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) - """ - ro_dim = cos.shape[-1] * 2 - assert ro_dim <= x.shape[-1] - cos = repeat( - cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)" - ) - sin = repeat( - sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 
1 (d 2)" - ) - return torch.cat( - [ - x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, - x[..., ro_dim:], - ], - dim=-1, - ) - - def apply_rotary_pos_emb( q: torch.Tensor, k: torch.Tensor, @@ -189,14 +157,20 @@ def apply_rotary_pos_emb( ) -> tuple[torch.Tensor, torch.Tensor]: cos = cos.chunk(2, dim=-1)[0].contiguous() sin = sin.chunk(2, dim=-1)[0].contiguous() - if is_flash_attn_backend and current_platform.is_cuda(): - from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb - apply_rotary_emb_func = apply_rotary_emb + apply_rotary_emb = ApplyRotaryEmb( + enforce_enable=True, + enable_fp32_compute=True, + ) + + if is_flash_attn_backend and not current_platform.is_cuda(): + apply_rotary_emb_func = apply_rotary_emb.forward_cuda else: - apply_rotary_emb_func = apply_rotary_emb_torch - q_embed = apply_rotary_emb_func(q.float(), cos.float(), sin.float()).type_as(q) - k_embed = apply_rotary_emb_func(k.float(), cos.float(), sin.float()).type_as(k) + apply_rotary_emb_func = apply_rotary_emb.forward_native + + q_embed = apply_rotary_emb_func(q, cos, sin) + k_embed = apply_rotary_emb_func(k, cos, sin) + return q_embed, k_embed