Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-22 07:45:01 +08:00)
[CustomOp] Extract ApplyRotaryEmb as CustomOp and unify the dispatch logic (#29873)
Signed-off-by: shen-shanshan <467638484@qq.com>
Co-authored-by: gcanlin <canlinguosdu@gmail.com>
Co-authored-by: TJian <tunjian.tan@embeddedllm.com>
This commit is contained in:
parent ff21a0fc85
commit 3bd9c49158

tests/kernels/core/test_apply_rotary_emb.py (new file, 203 lines)
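For orientation before the diff, here is a brief sketch (not part of the commit) of how the new ApplyRotaryEmb custom op can be exercised. The shapes follow the forward_static() docstring added in this commit; the assumption that the static entry point needs no vLLM config context is mine, so the sketch avoids instance construction and platform dispatch.

# Hedged sketch, not from the commit: exercise the new ApplyRotaryEmb op
# added to vllm/model_executor/layers/rotary_embedding/common.py in this diff.
# Only the static, config-free entry point is used; instance construction and
# forward_cuda()/forward_hip() dispatch depend on the CustomOp machinery and
# the platform, which this sketch deliberately avoids.
import torch

from vllm.model_executor.layers.rotary_embedding.common import ApplyRotaryEmb

seq_len, num_heads, head_size = 16, 8, 128
x = torch.randn(seq_len, num_heads, head_size, dtype=torch.bfloat16)
# cos/sin: [seq_len, head_size // 2], as documented on forward_static().
cos = torch.randn(seq_len, head_size // 2)
sin = torch.randn(seq_len, head_size // 2)

# Pure-PyTorch rotary application (Neox-style split into two halves).
out = ApplyRotaryEmb.forward_static(x, cos, sin, is_neox_style=True)
print(out.shape)  # torch.Size([16, 8, 128])

In the diff below, RotaryEmbedding layers hold an ApplyRotaryEmb instance (self.apply_rotary_emb) and either call forward_native() explicitly or call the instance directly so that CustomOp auto-dispatch can pick forward_cuda()/forward_hip(); the new test file checks exactly that routing.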
@@ -0,0 +1,203 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for ApplyRotaryEmb CustomOp dispatch behavior.

This test ensures that RotaryEmbedding classes correctly call the appropriate
ApplyRotaryEmb methods based on the calling context:

1. RotaryEmbedding.forward_native() -> ApplyRotaryEmb.forward_native()
2. RotaryEmbedding.forward_cuda() -> ApplyRotaryEmb.forward() (auto-dispatch)
3. RotaryEmbedding.forward_hip() -> ApplyRotaryEmb.forward() (auto-dispatch)
"""

from dataclasses import dataclass

import pytest
import torch

from vllm.config import (
    CompilationConfig,
    VllmConfig,
    get_cached_compilation_config,
    set_current_vllm_config,
)
from vllm.platforms import current_platform

CUDA_DEVICES = ["cuda:0"]


@dataclass
class RotaryEmbeddingTestCase:
    """Test case configuration for RotaryEmbedding dispatch tests."""

    name: str
    rope_class: type
    rope_kwargs: dict
    method_name: str  # forward_native, forward_cuda, forward
    positions_shape: tuple  # (num_tokens,) or (3, num_tokens) or (4, num_tokens)
    expect_forward_native: bool  # Should call ApplyRotaryEmb.forward_native()
    expect_forward: bool  # Should call ApplyRotaryEmb.forward()


def get_test_cases() -> list[RotaryEmbeddingTestCase]:
    """Generate test cases for all RotaryEmbedding classes."""
    from vllm.model_executor.layers.rotary_embedding.ernie45_vl_rope import (
        Ernie4_5_VLRotaryEmbedding,
    )
    from vllm.model_executor.layers.rotary_embedding.mrope import MRotaryEmbedding
    from vllm.model_executor.layers.rotary_embedding.xdrope import XDRotaryEmbedding

    common_kwargs = {
        "head_size": 128,
        "rotary_dim": 128,
        "max_position_embeddings": 4096,
        "base": 10000,
        "is_neox_style": True,
        "dtype": torch.bfloat16,
    }

    return [
        # MRotaryEmbedding tests
        RotaryEmbeddingTestCase(
            name="MRotaryEmbedding.forward_native",
            rope_class=MRotaryEmbedding,
            rope_kwargs={**common_kwargs, "mrope_section": [16, 24, 24]},
            method_name="forward_native",
            positions_shape=(3, 32),  # 2D for multimodal
            expect_forward_native=True,
            expect_forward=False,
        ),
        RotaryEmbeddingTestCase(
            name="MRotaryEmbedding.forward_cuda_1d",
            rope_class=MRotaryEmbedding,
            rope_kwargs={**common_kwargs, "mrope_section": [16, 24, 24]},
            method_name="forward_cuda",
            positions_shape=(32,),  # 1D triggers apply_rotary_emb path
            expect_forward_native=False,
            expect_forward=True,
        ),
        # XDRotaryEmbedding tests
        RotaryEmbeddingTestCase(
            name="XDRotaryEmbedding.forward",
            rope_class=XDRotaryEmbedding,
            rope_kwargs={
                **common_kwargs,
                "scaling_alpha": 1.0,
                "xdrope_section": [16, 16, 16, 16],
            },
            method_name="forward",
            positions_shape=(4, 32),  # 4D for P/W/H/T
            expect_forward_native=False,
            expect_forward=True,
        ),
        # Ernie4_5_VLRotaryEmbedding tests
        RotaryEmbeddingTestCase(
            name="Ernie4_5_VLRotaryEmbedding.forward_native",
            rope_class=Ernie4_5_VLRotaryEmbedding,
            rope_kwargs={**common_kwargs, "mrope_section": [22, 22, 20]},
            method_name="forward_native",
            positions_shape=(3, 32),  # 2D for multimodal
            expect_forward_native=True,
            expect_forward=False,
        ),
    ]


def run_dispatch_test(
    test_case: RotaryEmbeddingTestCase,
    device: str,
):
    """Run a dispatch test for a RotaryEmbedding class."""
    vllm_config = VllmConfig(
        compilation_config=CompilationConfig(custom_ops=["all", "+apply_rotary_emb"])
    )
    get_cached_compilation_config.cache_clear()

    with set_current_vllm_config(vllm_config):
        rope = test_case.rope_class(**test_case.rope_kwargs).to(device=device)

    apply_rotary_emb = rope.apply_rotary_emb

    # Verify custom op is enabled
    if test_case.expect_forward_native:
        assert (
            apply_rotary_emb._forward_method != apply_rotary_emb.forward_native
        ), "Test setup error: ApplyRotaryEmb custom op should be enabled"

    # Set up call tracking
    call_tracker = {"forward_native_called": False, "forward_called": False}
    original_forward_native = apply_rotary_emb.forward_native
    original_forward = apply_rotary_emb.forward

    def tracked_forward_native(*args, **kwargs):
        call_tracker["forward_native_called"] = True
        return original_forward_native(*args, **kwargs)

    def tracked_forward(*args, **kwargs):
        call_tracker["forward_called"] = True
        return original_forward(*args, **kwargs)

    apply_rotary_emb.forward_native = tracked_forward_native
    apply_rotary_emb.forward = tracked_forward

    try:
        num_tokens = test_case.positions_shape[-1]
        num_q_heads = 8
        num_kv_heads = 2
        head_size = test_case.rope_kwargs["head_size"]
        max_position = test_case.rope_kwargs["max_position_embeddings"]

        positions = torch.randint(
            0, max_position // 4, test_case.positions_shape, device=device
        )
        query = torch.randn(
            num_tokens, num_q_heads * head_size, dtype=torch.bfloat16, device=device
        )
        key = torch.randn(
            num_tokens,
            num_kv_heads * head_size,
            dtype=torch.bfloat16,
            device=device,
        )

        # Call the method under test
        method = getattr(rope, test_case.method_name)
        method(positions, query.clone(), key.clone())

        # Verify expectations
        if test_case.expect_forward_native:
            assert call_tracker["forward_native_called"], (
                f"{test_case.name} should call ApplyRotaryEmb.forward_native()"
            )
        if not test_case.expect_forward:
            assert not call_tracker["forward_called"], (
                f"{test_case.name} should NOT call ApplyRotaryEmb.forward(). "
                "Bug: when +apply_rotary_emb is enabled, forward_native() "
                "incorrectly dispatches to CUDA/HIP kernels."
            )
        if test_case.expect_forward:
            assert call_tracker["forward_called"], (
                f"{test_case.name} should call ApplyRotaryEmb.forward()"
            )
    finally:
        apply_rotary_emb.forward_native = original_forward_native
        apply_rotary_emb.forward = original_forward


@pytest.mark.skipif(
    not current_platform.is_cuda_alike(), reason="Skipping CUDA/ROCm only tests."
)
@pytest.mark.parametrize("test_case", get_test_cases(), ids=lambda tc: tc.name)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_rotary_embedding_dispatch(
    test_case: RotaryEmbeddingTestCase,
    device: str,
):
    """
    Test that RotaryEmbedding classes dispatch to the correct ApplyRotaryEmb method.

    - forward_native methods should call ApplyRotaryEmb.forward_native()
    - forward_cuda/forward methods should call ApplyRotaryEmb.forward()
    """
    run_dispatch_test(test_case, device)

@@ -7,7 +7,7 @@ import torch
from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.custom_op import CustomOp

from .common import apply_rotary_emb_torch
from .common import ApplyRotaryEmb


@CustomOp.register("rotary_embedding")
@@ -49,6 +49,10 @@ class RotaryEmbeddingBase(CustomOp):
            rocm_aiter_ops.is_triton_rotary_embed_enabled()
        )

        self.apply_rotary_emb = ApplyRotaryEmb(
            is_neox_style=self.is_neox_style,
        )

    def _compute_inv_freq(self, base: float) -> torch.Tensor:
        """Compute the inverse frequency."""
        # NOTE(woosuk): To exactly match the HF implementation, we need to
@@ -123,7 +127,12 @@ class RotaryEmbedding(RotaryEmbeddingBase):
        query = query.view(num_tokens, -1, head_size)
        query_rot = query[..., :rotary_dim]
        query_pass = query[..., rotary_dim:]
        query_rot = apply_rotary_emb_torch(query_rot, cos, sin, is_neox_style)
        query_rot = ApplyRotaryEmb.forward_static(
            query_rot,
            cos,
            sin,
            is_neox_style,
        )
        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

        # key may be None in some cases, e.g. cross-layer KV sharing
@@ -132,7 +141,12 @@ class RotaryEmbedding(RotaryEmbeddingBase):
        key = key.view(num_tokens, -1, head_size)
        key_rot = key[..., :rotary_dim]
        key_pass = key[..., rotary_dim:]
        key_rot = apply_rotary_emb_torch(key_rot, cos, sin, is_neox_style)
        key_rot = ApplyRotaryEmb.forward_static(
            key_rot,
            cos,
            sin,
            is_neox_style,
        )
        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
        return query, key

@@ -2,19 +2,14 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import math
from collections.abc import Callable
from functools import cache
from importlib.util import find_spec

import torch

from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.model_executor.custom_op import CustomOp
from vllm.utils.torch_utils import direct_register_custom_op

if current_platform.is_cuda():
    from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb

logger = init_logger(__name__)


@@ -32,71 +27,6 @@ def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
    return x.flatten(-2)


def apply_rotary_emb_torch(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    is_neox_style: bool,
) -> torch.Tensor:
    cos = cos.unsqueeze(-2).to(x.dtype)
    sin = sin.unsqueeze(-2).to(x.dtype)
    if is_neox_style:
        x1, x2 = torch.chunk(x, 2, dim=-1)
    else:
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
    o1 = x1 * cos - x2 * sin
    o2 = x2 * cos + x1 * sin
    if is_neox_style:
        return torch.cat((o1, o2), dim=-1)
    else:
        return torch.stack((o1, o2), dim=-1).flatten(-2)


def apply_rotary_emb_dispatch(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, is_neox_style: bool
) -> torch.Tensor:
    """
    Args:
        x: [num_tokens, num_heads, head_size]
        cos: [num_tokens, head_size // 2]
        sin: [num_tokens, head_size // 2]
        is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
            positional embeddings.
    """
    if current_platform.is_cuda():
        return apply_rotary_emb(x.unsqueeze(0), cos, sin, not is_neox_style).squeeze(0)
    else:
        return apply_rotary_emb_torch(x, cos, sin, is_neox_style)


@cache
def dispatch_rotary_emb_function(
    default: Callable[..., torch.Tensor] | None = None,
) -> Callable[..., torch.Tensor]:
    if current_platform.is_cuda():
        return apply_rotary_emb

    # if torch compile is not enabled
    # use rotary embedding function from flash_attn package
    # otherwise use the naive pytorch embedding implementation
    # is faster when torch compile is enabled.
    if current_platform.is_rocm() and not torch.compiler.is_compiling():
        if find_spec("flash_attn") is not None:
            from flash_attn.ops.triton.rotary import apply_rotary

            return apply_rotary
        else:
            logger.warning(
                "flash_attn is not installed. Falling back to PyTorch "
                "implementation for rotary embeddings."
            )
    if default is not None:
        return default

    return apply_rotary_emb_torch


# yarn functions
# Inverse dim formula to find dim based on number of rotations
def yarn_find_correction_dim(

@@ -186,3 +116,155 @@ direct_register_custom_op(
    mutates_args=["query", "key"],  # These tensors are modified in-place
    fake_impl=_flashinfer_rotary_embedding_fake,
)


@CustomOp.register("apply_rotary_emb")
class ApplyRotaryEmb(CustomOp):
    def __init__(
        self,
        enforce_enable: bool = False,
        is_neox_style: bool = True,
        enable_fp32_compute: bool = False,
    ) -> None:
        super().__init__(enforce_enable)
        self.is_neox_style = is_neox_style
        self.enable_fp32_compute = enable_fp32_compute

        self.apply_rotary_emb_flash_attn = None
        if find_spec("flash_attn") is not None:
            from flash_attn.ops.triton.rotary import apply_rotary

            self.apply_rotary_emb_flash_attn = apply_rotary

    @staticmethod
    def forward_static(
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        is_neox_style: bool = True,
        enable_fp32_compute: bool = False,
    ) -> torch.Tensor:
        """
        Args:
            x: [batch_size (optional), seq_len, num_heads, head_size]
            cos: [seq_len, head_size // 2]
            sin: [seq_len, head_size // 2]
            is_neox_style: Whether to use the Neox-style or GPT-J-style.
            enable_fp32_compute: Temporarily convert x, cos, sin to FP32 dtype
                for higher accuracy.
        """
        origin_dtype = x.dtype
        if enable_fp32_compute:
            x = x.float()

        cos = cos.unsqueeze(-2).to(x.dtype)
        sin = sin.unsqueeze(-2).to(x.dtype)

        if is_neox_style:
            x1, x2 = torch.chunk(x, 2, dim=-1)
        else:
            x1 = x[..., ::2]
            x2 = x[..., 1::2]

        o1 = x1 * cos - x2 * sin
        o2 = x2 * cos + x1 * sin

        if is_neox_style:
            output = torch.cat((o1, o2), dim=-1)
        else:
            output = torch.stack((o1, o2), dim=-1).flatten(-2)

        if enable_fp32_compute:
            output = output.to(origin_dtype)
        return output

    def forward_native(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        output = self.forward_static(
            x, cos, sin, self.is_neox_style, self.enable_fp32_compute
        )
        return output

    def forward_cuda(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb

        origin_dtype = x.dtype
        if self.enable_fp32_compute:
            x = x.float()
            cos = cos.float()
            sin = sin.float()

        origin_shape = x.shape
        if len(origin_shape) == 3:
            # x: [seq_len, num_heads, head_size]
            x = x.unsqueeze(0)

        """
        Arguments of apply_rotary_emb() in vllm_flash_attn:
            x: [batch_size, seq_len, nheads, headdim]
            cos, sin: [seqlen_rotary, rotary_dim / 2]
            interleaved: default is False (Neox-style).
            ...
        """
        interleaved = not self.is_neox_style
        output = apply_rotary_emb(x, cos, sin, interleaved)

        if len(origin_shape) == 3:
            output = output.squeeze(0)
        if self.enable_fp32_compute:
            output = output.to(origin_dtype)
        return output

    def forward_hip(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        if self.apply_rotary_emb_flash_attn is not None:
            origin_dtype = x.dtype
            if self.enable_fp32_compute:
                x = x.float()
                cos = cos.float()
                sin = sin.float()

            origin_shape = x.shape
            if len(origin_shape) == 3:
                # x: [seq_len, num_heads, head_size]
                x = x.unsqueeze(0)

            """
            Arguments of apply_rotary() in flash_attn:
                x: [batch_size, seq_len, nheads, headdim]
                cos, sin: [seqlen_rotary, rotary_dim / 2]
                interleaved: default is False (Neox-style).
                ...
            """
            interleaved = not self.is_neox_style
            output = self.apply_rotary_emb_flash_attn(
                x, cos, sin, interleaved=interleaved
            ).type_as(x)

            if len(origin_shape) == 3:
                output = output.squeeze(0)
            if self.enable_fp32_compute:
                output = output.to(origin_dtype)
        else:
            # Fall back to the PyTorch-native implementation.
            output = self.forward_native(x, cos, sin)

        return output

    def extra_repr(self) -> str:
        s = f"is_neox_style={self.is_neox_style}"
        s += f", enable_fp32_compute={self.enable_fp32_compute}"
        return s

@@ -4,7 +4,6 @@

import torch

from .common import apply_rotary_emb_dispatch
from .mrope import MRotaryEmbedding


@@ -55,14 +54,22 @@ class Ernie4_5_VLRotaryEmbedding(MRotaryEmbedding):
        query = query.view(num_tokens, -1, self.head_size)
        query_rot = query[..., : self.rotary_dim]
        query_pass = query[..., self.rotary_dim :]
        query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style)
        query_rot = self.apply_rotary_emb.forward_native(
            query_rot,
            cos,
            sin,
        )
        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

        key_shape = key.shape
        key = key.view(num_tokens, -1, self.head_size)
        key_rot = key[..., : self.rotary_dim]
        key_pass = key[..., self.rotary_dim :]
        key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style)
        key_rot = self.apply_rotary_emb.forward_native(
            key_rot,
            cos,
            sin,
        )
        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
        return query, key

@@ -8,7 +8,6 @@ import torch
from vllm.triton_utils import tl, triton

from .base import RotaryEmbeddingBase
from .common import apply_rotary_emb_dispatch
from .yarn_scaling_rope import YaRNScalingRotaryEmbedding, yarn_get_mscale


@@ -301,14 +300,22 @@ class MRotaryEmbedding(RotaryEmbeddingBase):
        query = query.view(num_tokens, -1, self.head_size)
        query_rot = query[..., : self.rotary_dim]
        query_pass = query[..., self.rotary_dim :]
        query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style)
        query_rot = self.apply_rotary_emb.forward_native(
            query_rot,
            cos,
            sin,
        )
        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

        key_shape = key.shape
        key = key.view(num_tokens, -1, self.head_size)
        key_rot = key[..., : self.rotary_dim]
        key_pass = key[..., self.rotary_dim :]
        key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style)
        key_rot = self.apply_rotary_emb.forward_native(
            key_rot,
            cos,
            sin,
        )
        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
        return query, key

@@ -347,13 +354,21 @@ class MRotaryEmbedding(RotaryEmbeddingBase):
        query = query.view(num_tokens, -1, self.head_size)
        query_rot = query[..., : self.rotary_dim]
        query_pass = query[..., self.rotary_dim :]
        query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style)
        query_rot = self.apply_rotary_emb(
            query_rot,
            cos,
            sin,
        )
        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

        key = key.view(num_tokens, -1, self.head_size)
        key_rot = key[..., : self.rotary_dim]
        key_pass = key[..., self.rotary_dim :]
        key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style)
        key_rot = self.apply_rotary_emb(
            key_rot,
            cos,
            sin,
        )
        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
        return query, key

@@ -4,7 +4,6 @@
import numpy as np
import torch

from .common import apply_rotary_emb_dispatch
from .dynamic_ntk_alpha_rope import DynamicNTKAlphaRotaryEmbedding


@@ -36,7 +35,7 @@ class XDRotaryEmbedding(DynamicNTKAlphaRotaryEmbedding):
            dtype,
        )

    def forward(
    def forward_native(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
@@ -68,14 +67,73 @@ class XDRotaryEmbedding(DynamicNTKAlphaRotaryEmbedding):
        query = query.view(num_tokens, -1, self.head_size)
        query_rot = query[..., : self.rotary_dim]
        query_pass = query[..., self.rotary_dim :]
        query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style)
        query_rot = self.apply_rotary_emb.forward_native(
            query_rot,
            cos,
            sin,
        )
        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

        key_shape = key.shape
        key = key.view(num_tokens, -1, self.head_size)
        key_rot = key[..., : self.rotary_dim]
        key_pass = key[..., self.rotary_dim :]
        key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style)
        key_rot = self.apply_rotary_emb.forward_native(
            key_rot,
            cos,
            sin,
        )
        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
        return query, key

    def forward_cuda(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor | None = None,
        offsets: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """PyTorch-native implementation equivalent to forward().

        Args:
            positions:
                [4, num_tokens] (P/W/H/T positions with multimodal inputs)
            query: [num_tokens, num_heads * head_size]
            key: [num_tokens, num_kv_heads * head_size]
        """
        assert positions.ndim == 2
        assert key is not None

        num_tokens = positions.shape[-1]
        cos_sin = self.cos_sin_cache[positions]
        cos, sin = cos_sin.chunk(2, dim=-1)
        cos = torch.cat(
            [m[i] for i, m in enumerate(cos.split(self.xdrope_section, dim=-1))], dim=-1
        )
        sin = torch.cat(
            [m[i] for i, m in enumerate(sin.split(self.xdrope_section, dim=-1))], dim=-1
        )

        query_shape = query.shape
        query = query.view(num_tokens, -1, self.head_size)
        query_rot = query[..., : self.rotary_dim]
        query_pass = query[..., self.rotary_dim :]
        query_rot = self.apply_rotary_emb(
            query_rot,
            cos,
            sin,
        )
        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

        key_shape = key.shape
        key = key.view(num_tokens, -1, self.head_size)
        key_rot = key[..., : self.rotary_dim]
        key_pass = key[..., self.rotary_dim :]
        key_rot = self.apply_rotary_emb(
            key_rot,
            cos,
            sin,
        )
        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
        return query, key

@@ -29,6 +29,9 @@ from vllm.model_executor.layers.linear import (
    RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding.common import (
    ApplyRotaryEmb,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import (
    MultiModalEmbeddings,
@@ -158,32 +161,6 @@ class DotsOCRProcessingInfo(Qwen2VLProcessingInfo):
        return processor


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb_vision(
    tensor: torch.Tensor, freqs: torch.Tensor
) -> torch.Tensor:
    orig_dtype = tensor.dtype
    tensor = tensor.float()

    cos = freqs.cos()
    sin = freqs.sin()

    cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
    sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()

    output = (tensor * cos) + (rotate_half(tensor) * sin)

    output = output.to(orig_dtype)

    return output


class VisionRotaryEmbedding(nn.Module):
    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
@@ -298,6 +275,11 @@ class DotsVisionAttention(nn.Module):
            prefix=f"{prefix}.attn",
        )

        self.apply_rotary_emb = ApplyRotaryEmb(
            enforce_enable=True,
            enable_fp32_compute=True,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
@@ -318,7 +300,11 @@

        if rotary_pos_emb is not None:
            qk_concat = torch.cat([q, k], dim=0)
            qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
            qk_rotated = self.apply_rotary_emb(
                qk_concat,
                rotary_pos_emb.cos(),
                rotary_pos_emb.sin(),
            )
            q, k = torch.chunk(qk_rotated, 2, dim=0)

        context_layer = self.attn(

@@ -33,7 +33,7 @@ import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from einops import rearrange
from transformers import BatchFeature

from vllm.attention.backends.registry import AttentionBackendEnum
@@ -53,6 +53,9 @@ from vllm.model_executor.layers.linear import (
    RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding.common import (
    ApplyRotaryEmb,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
@@ -69,7 +72,6 @@ from vllm.multimodal.processing import (
    PromptUpdate,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape

@@ -89,52 +91,6 @@ logger = init_logger(__name__)
# === Vision Transformer === #


def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
    if not interleaved:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
    else:
        x1, x2 = x[..., ::2], x[..., 1::2]
        return rearrange(
            torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
        )


def apply_rotary_emb_torch(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
) -> torch.Tensor:
    """
    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
    """
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    cos = repeat(
        cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
    )
    sin = repeat(
        sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
    )
    return torch.cat(
        [
            x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
            x[..., ro_dim:],
        ],
        dim=-1,
    )


def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    t_ = t.float()
    cos = freqs.cos()
    sin = freqs.sin()
    apply_rotary_emb = apply_rotary_emb_torch
    if current_platform.is_cuda():
        from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
    output = apply_rotary_emb(t_, cos, sin).type_as(t)
    return output


def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int):
    """All-gather the input tensor interleavely across model parallel group."""
    import torch.distributed as dist
@@ -200,6 +156,11 @@ class Ernie4_5_VisionAttention(nn.Module):
            prefix=f"{prefix}.attn",
        )

        self.apply_rotary_emb = ApplyRotaryEmb(
            enforce_enable=True,
            enable_fp32_compute=True,
        )

    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
        # [s, b, 3 * head * head_dim]
        seq_len, bs, _ = qkv.shape
@@ -244,7 +205,11 @@ class Ernie4_5_VisionAttention(nn.Module):
        q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v))
        if rotary_pos_emb is not None:
            qk_concat = torch.cat([q, k], dim=0)
            qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
            qk_rotated = self.apply_rotary_emb(
                qk_concat,
                rotary_pos_emb.cos(),
                rotary_pos_emb.sin(),
            )
            q, k = torch.chunk(qk_rotated, 2, dim=0)

        output = self.attn(

@@ -65,6 +65,9 @@ from vllm.model_executor.layers.linear import (
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.rotary_embedding.common import (
    ApplyRotaryEmb,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -95,7 +98,7 @@ from .interfaces import (
    SupportsMultiModal,
    SupportsPP,
)
from .qwen2_vl import _create_qwen2vl_field_factory, apply_rotary_pos_emb_vision
from .qwen2_vl import _create_qwen2vl_field_factory
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
@@ -304,6 +307,8 @@ class Glm4vVisionAttention(nn.Module):
            multimodal_config=multimodal_config,
        )

        self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)

    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
        # [s, b, 3 * head * head_dim]
        seq_len, bs, _ = qkv.shape
@@ -339,8 +344,10 @@ class Glm4vVisionAttention(nn.Module):
        if rotary_pos_emb_cos is not None and rotary_pos_emb_sin is not None:
            # [2 * b, s, heads, head_dim]
            qk_concat = torch.cat([q, k], dim=0)
            qk_rotated = apply_rotary_pos_emb_vision(
                qk_concat, rotary_pos_emb_cos, rotary_pos_emb_sin
            qk_rotated = self.apply_rotary_emb(
                qk_concat,
                rotary_pos_emb_cos,
                rotary_pos_emb_sin,
            )
            q, k = torch.chunk(qk_rotated, 2, dim=0)

@@ -30,6 +30,9 @@ from vllm.model_executor.layers.linear import (
    RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding.common import (
    ApplyRotaryEmb,
)
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader,
    maybe_remap_kv_scale_name,
@@ -59,7 +62,6 @@ from vllm.multimodal.processing import (
    PromptUpdate,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape

@@ -341,20 +343,14 @@ def apply_rotary_pos_emb_flashatt(
    cos = cos.chunk(2, dim=-1)[0].contiguous()
    sin = sin.chunk(2, dim=-1)[0].contiguous()

    if current_platform.is_cuda():
        from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
    elif current_platform.is_rocm():
        from flash_attn.ops.triton.rotary import apply_rotary as apply_rotary_emb
    else:
        # For other platforms, use PyTorch fallback
        from vllm.model_executor.layers.rotary_embedding.common import (
            apply_rotary_emb_torch,
        )
    apply_rotary_emb = ApplyRotaryEmb(
        enforce_enable=True,
        enable_fp32_compute=True,
    )

        apply_rotary_emb = partial(apply_rotary_emb_torch, is_neox_style=True)
    q_embed = apply_rotary_emb(q, cos, sin)
    k_embed = apply_rotary_emb(k, cos, sin)

    q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)
    k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k)
    return q_embed, k_embed

@@ -22,7 +22,7 @@ from typing import Annotated, Literal
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange, repeat
from einops import rearrange
from transformers import BatchFeature, PretrainedConfig
from transformers.activations import GELUActivation
from transformers.modeling_outputs import (
@@ -47,7 +47,7 @@ from vllm.model_executor.layers.linear import (
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding.common import (
    dispatch_rotary_emb_function,
    ApplyRotaryEmb,
)
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader,
@@ -130,47 +130,6 @@ def smart_resize(
    return h_bar, w_bar


def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
    if not interleaved:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
    x1, x2 = x[..., ::2], x[..., 1::2]
    return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2)


def apply_rotary_emb_torch(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
) -> torch.Tensor:
    """
    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
    """
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    cos = repeat(
        cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
    )
    sin = repeat(
        sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
    )
    return torch.cat(
        [
            x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
            x[..., ro_dim:],
        ],
        dim=-1,
    )


def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    rotary_emb_function = dispatch_rotary_emb_function(default=apply_rotary_emb_torch)
    t_ = t.float()
    cos = freqs.cos()
    sin = freqs.sin()
    output = rotary_emb_function(t_, cos, sin).type_as(t)
    return output


class PaddleOCRVLProcessingInfo(BaseProcessingInfo):
    def get_hf_config(self):
        return self.ctx.get_hf_config()
@@ -609,6 +568,10 @@ class SiglipAttention(nn.Module):
            multimodal_config=multimodal_config,
            prefix=f"{prefix}.attn",
        )
        self.apply_rotary_emb = ApplyRotaryEmb(
            enforce_enable=True,
            enable_fp32_compute=True,
        )

    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
        seq_len, bs, _ = qkv.shape
@@ -651,7 +614,11 @@ class SiglipAttention(nn.Module):

        if rotary_pos_emb is not None:
            qk_concat = torch.cat([q, k], dim=0)
            qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
            qk_rotated = self.apply_rotary_emb(
                qk_concat,
                rotary_pos_emb.cos(),
                rotary_pos_emb.sin(),
            )
            q, k = torch.chunk(qk_rotated, 2, dim=0)

        context_layer = self.attn(

@@ -60,6 +60,9 @@ from vllm.model_executor.layers.linear import (
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.rotary_embedding.common import (
    ApplyRotaryEmb,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.vision import should_torch_compile_mm_vit
@@ -95,7 +98,6 @@ from .qwen2_vl import Qwen2VLDummyInputsBuilder as Qwen2_5_VLDummyInputsBuilder
from .qwen2_vl import (
    Qwen2VLMultiModalProcessor,
    Qwen2VLProcessingInfo,
    apply_rotary_pos_emb_vision,
)
from .utils import (
    AutoWeightsLoader,
@@ -353,6 +355,8 @@ class Qwen2_5_VisionAttention(nn.Module):
            multimodal_config=multimodal_config,
        )

        self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)

    def forward(
        self,
        x: torch.Tensor,
@@ -378,8 +382,10 @@ class Qwen2_5_VisionAttention(nn.Module):
        qk_reshaped = einops.rearrange(
            qk, "b s two head head_dim -> (two b) s head head_dim", two=2
        )
        qk_rotated = apply_rotary_pos_emb_vision(
            qk_reshaped, cos=rotary_pos_emb_cos, sin=rotary_pos_emb_sin
        qk_rotated = self.apply_rotary_emb(
            qk_reshaped,
            rotary_pos_emb_cos,
            rotary_pos_emb_sin,
        )
        qk_rotated = qk_rotated.view(
            2,

@@ -59,8 +59,7 @@ from vllm.model_executor.layers.linear import (
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.rotary_embedding.common import (
    apply_rotary_emb_torch,
    dispatch_rotary_emb_function,
    ApplyRotaryEmb,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
@@ -280,16 +279,6 @@ class Qwen2VisionMLP(nn.Module):
        return x


def apply_rotary_pos_emb_vision(
    t: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> torch.Tensor:
    rotary_emb_function = dispatch_rotary_emb_function(
        default=partial(apply_rotary_emb_torch, is_neox_style=True)
    )
    output = rotary_emb_function(t, cos, sin).type_as(t)
    return output


class Qwen2VisionAttention(nn.Module):
    def __init__(
        self,
@@ -341,6 +330,8 @@ class Qwen2VisionAttention(nn.Module):
            multimodal_config=multimodal_config,
        )

        self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)

    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
        # [s, b, 3 * head * head_dim]
        seq_len, bs, _ = qkv.shape
@@ -387,8 +378,10 @@ class Qwen2VisionAttention(nn.Module):

        # [2 * b, s, heads, head_dim]
        qk_concat = torch.cat([q, k], dim=0)
        qk_rotated = apply_rotary_pos_emb_vision(
            qk_concat, rotary_pos_emb_cos, rotary_pos_emb_sin
        qk_rotated = self.apply_rotary_emb(
            qk_concat,
            rotary_pos_emb_cos,
            rotary_pos_emb_sin,
        )
        q, k = torch.chunk(qk_rotated, 2, dim=0)

@@ -6,7 +6,6 @@ within a vision language model.
from collections.abc import Iterable

import torch
from einops import rearrange, repeat
from torch import nn
from torch.nn import functional as F
from transformers import Siglip2VisionConfig
@@ -26,6 +25,9 @@ from vllm.model_executor.layers.linear import (
    RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding.common import (
    ApplyRotaryEmb,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.platforms import current_platform

@@ -146,40 +148,6 @@ class Siglip2VisionEmbeddings(nn.Module):
        return patch_embeds


# copy from flash_attn/layers/rotary.py
def rotate_half(x, interleaved=False):
    if not interleaved:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
    else:
        x1, x2 = x[..., ::2], x[..., 1::2]
        return rearrange(
            torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
        )


def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
    """
    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
    """
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    cos = repeat(
        cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
    )
    sin = repeat(
        sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
    )
    return torch.cat(
        [
            x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
            x[..., ro_dim:],
        ],
        dim=-1,
    )


def apply_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
@@ -189,14 +157,20 @@ def apply_rotary_pos_emb(
) -> tuple[torch.Tensor, torch.Tensor]:
    cos = cos.chunk(2, dim=-1)[0].contiguous()
    sin = sin.chunk(2, dim=-1)[0].contiguous()
    if is_flash_attn_backend and current_platform.is_cuda():
        from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb

        apply_rotary_emb_func = apply_rotary_emb
    apply_rotary_emb = ApplyRotaryEmb(
        enforce_enable=True,
        enable_fp32_compute=True,
    )

    if is_flash_attn_backend and not current_platform.is_cuda():
        apply_rotary_emb_func = apply_rotary_emb.forward_cuda
    else:
        apply_rotary_emb_func = apply_rotary_emb_torch
    q_embed = apply_rotary_emb_func(q.float(), cos.float(), sin.float()).type_as(q)
    k_embed = apply_rotary_emb_func(k.float(), cos.float(), sin.float()).type_as(k)
        apply_rotary_emb_func = apply_rotary_emb.forward_native

    q_embed = apply_rotary_emb_func(q, cos, sin)
    k_embed = apply_rotary_emb_func(k, cos, sin)

    return q_embed, k_embed