Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-01-05 12:44:04 +08:00)
[MM Encoder]: Migrate legacy ViT MultiHeadAttention to new MMEncoderAttention interface (#30684)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Parent: 62be3670cb
Commit: 700a5ad6c6
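For the call sites touched by this diff, the migration is a drop-in swap of the layer class and its import: the constructor arguments (num_heads, head_size, scale, optional num_kv_heads) and the forward signature are unchanged. A minimal sketch of the swap, assuming an environment with this branch of vLLM installed; the head count, head size, batch and sequence sizes below are illustrative only:

    import torch

    # Before this commit (legacy ViT attention layer):
    # from vllm.attention.layer import MultiHeadAttention
    # attn = MultiHeadAttention(16, 64, scale=64**-0.5)

    # After this commit:
    from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention

    attn = MMEncoderAttention(16, 64, scale=64**-0.5)

    # Batched encoder input: (batch_size, seq_len, num_heads * head_size),
    # same call shape the tests below exercise.
    q = torch.randn(2, 4, 16 * 64)
    k = torch.randn(2, 4, 16 * 64)
    v = torch.randn(2, 4, 16 * 64)
    out = attn(q, k, v)  # same call signature as the old layer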
@@ -9,7 +9,8 @@ import torch
 from tests.kernels.allclose_default import get_default_atol, get_default_rtol
 from tests.kernels.utils import opcheck
 from vllm import _custom_ops as ops
-from vllm.attention.layer import Attention, MultiHeadAttention
+from vllm.attention.layer import Attention
+from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
 from vllm.platforms import current_platform
 from vllm.utils.mem_utils import get_max_shared_memory_bytes

@@ -442,7 +443,7 @@ def ref_multi_query_kv_attention(
     return torch.cat(ref_outputs, dim=0)


-@pytest.mark.parametrize("attention_cls", [Attention, MultiHeadAttention])
+@pytest.mark.parametrize("attention_cls", [Attention, MMEncoderAttention])
 def test_num_heads_not_divisble_by_num_kv_heads(attention_cls: type) -> None:
     head_size = 64
     scale = float(1.0 / (head_size**0.5))
@@ -3,16 +3,17 @@
 """
 Test:

-* Tests for MultiHeadAttention layer
+* Tests for MMEncoderAttention layer
 """

+import itertools
 from unittest.mock import patch

 import pytest
 import torch

 from vllm.attention.backends.registry import AttentionBackendEnum
-from vllm.attention.layer import MultiHeadAttention
+from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
 from vllm.attention.selector import _cached_get_attn_backend
 from vllm.platforms import current_platform
 from vllm.platforms.cpu import CpuPlatform
@@ -42,35 +43,31 @@ def test_mha_attn_platform(device: str):

     if device == "cpu":
         with (
-            patch("vllm.attention.layer.current_platform", CpuPlatform()),
             patch("vllm.model_executor.models.vision.current_platform", CpuPlatform()),
         ):
-            attn = MultiHeadAttention(16, 64, scale=1)
+            attn = MMEncoderAttention(16, 64, scale=1)
         assert attn.attn_backend == AttentionBackendEnum.TORCH_SDPA
     elif device == "hip":
         with (
-            patch("vllm.attention.layer.current_platform", RocmPlatform()),
             patch("vllm.model_executor.models.vision.current_platform", RocmPlatform()),
         ):
-            attn = MultiHeadAttention(16, 64, scale=1)
+            attn = MMEncoderAttention(16, 64, scale=1)
         assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN
     else:
         # Test CUDA with head_size=64 (divisible by 32)
         # - should use vLLM's FlashAttention
         with (
-            patch("vllm.attention.layer.current_platform", CudaPlatform()),
             patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()),
         ):
-            attn = MultiHeadAttention(16, 64, scale=1)
+            attn = MMEncoderAttention(16, 64, scale=1)
         assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN

         # Test CUDA with head_size=72 (not divisible by 32)
        # - should use vLLM's FlashAttention
         with (
-            patch("vllm.attention.layer.current_platform", CudaPlatform()),
             patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()),
         ):
-            attn = MultiHeadAttention(16, 72, scale=1)
+            attn = MMEncoderAttention(16, 72, scale=1)
         assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN


@@ -94,6 +91,10 @@ def ref_attention(

 BATCH_SIZES = [1, 16]
 SEQ_LENS = [1]
+VAR_SEQ_LENS = [
+    [2, 2],
+    [2, 3, 4],
+]
 NUM_HEADS = [1, 16]
 NUM_KV_HEADS = [1]
 HEAD_SIZES = [64, 80]
@@ -130,7 +131,7 @@ def test_mha_attn_forward(
     k = torch.randn(batch_size, seq_len, num_kv_heads * head_size)
     v = torch.randn(batch_size, seq_len, num_kv_heads * head_size)
     scale = 1.0 / head_size**0.5
-    attn = MultiHeadAttention(
+    attn = MMEncoderAttention(
         num_heads, head_size, scale=scale, num_kv_heads=num_kv_heads
     )
     output = attn(q, k, v)
@@ -151,3 +152,58 @@ def test_mha_attn_forward(
         scale=scale,
     ).reshape(batch_size, seq_len, num_heads * head_size)
     torch.testing.assert_close(output, ref_output)
+
+
+@pytest.mark.parametrize("var_seq_len", VAR_SEQ_LENS)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("num_kv_heads", NUM_KV_HEADS)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_mha_attn_varlen_forward(
+    var_seq_len: list[int],
+    num_heads: int,
+    num_kv_heads: int,
+    head_size: int,
+    dtype: torch.dtype,
+    device: str,
+):
+    current_platform.seed_everything(0)
+    torch.set_default_device(device)
+    torch.set_default_dtype(dtype)
+
+    q = torch.randn(1, sum(var_seq_len), num_heads, head_size)
+    k = torch.randn(1, sum(var_seq_len), num_kv_heads, head_size)
+    v = torch.randn(1, sum(var_seq_len), num_kv_heads, head_size)
+    cu_seqlens = torch.tensor(
+        [0] + list(itertools.accumulate(var_seq_len)), dtype=torch.int32
+    )
+    scale = 1.0 / head_size**0.5
+    attn = MMEncoderAttention(
+        num_heads, head_size, scale=scale, num_kv_heads=num_kv_heads
+    )
+    output = attn(
+        q, k, v, cu_seqlens=cu_seqlens, max_seqlen=torch.tensor(max(var_seq_len))
+    )
+
+    assert num_heads % num_kv_heads == 0
+    num_queries_per_kv = num_heads // num_kv_heads
+    if num_queries_per_kv > 1:
+        k = torch.repeat_interleave(k, num_queries_per_kv, dim=2)
+        v = torch.repeat_interleave(v, num_queries_per_kv, dim=2)
+
+    ref_output = []
+    for q_i, k_i, v_i in zip(
+        torch.split(q, var_seq_len, dim=1),
+        torch.split(k, var_seq_len, dim=1),
+        torch.split(v, var_seq_len, dim=1),
+    ):
+        output_i = ref_attention(
+            q_i,
+            k_i,
+            v_i,
+            scale=scale,
+        )
+        ref_output.append(output_i)
+    ref_output = torch.cat(ref_output, dim=1)
+    torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2)
@@ -3,7 +3,7 @@
 """
 Test:

-* Tests for MultiHeadAttention layer
+* Tests for MMEncoderAttention layer
 """

 import pytest
@@ -12,7 +12,7 @@ import torch_xla
 import torch_xla.core
 import torch_xla.core.xla_model

-from vllm.attention.layer import MultiHeadAttention
+from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
 from vllm.attention.selector import _cached_get_attn_backend
 from vllm.platforms import current_platform

@@ -69,7 +69,7 @@ def test_mha_attn_forward(
     k = torch.randn(batch_size, seq_len, num_kv_heads * head_size, device=device)
     v = torch.randn(batch_size, seq_len, num_kv_heads * head_size, device=device)
     scale = 1.0 / head_size**0.5
-    attn = MultiHeadAttention(
+    attn = MMEncoderAttention(
         num_heads, head_size, scale=scale, num_kv_heads=num_kv_heads
     )
     output = attn(q, k, v)
@@ -2,12 +2,10 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Attention layer."""

-import functools
 from typing import cast

 import torch
 import torch.nn as nn
-import torch.nn.functional as F

 import vllm.envs as envs
 from vllm.attention.backends.abstract import (
@@ -16,13 +14,10 @@ from vllm.attention.backends.abstract import (
     MLAAttentionImpl,
 )
 from vllm.attention.backends.registry import AttentionBackendEnum
-from vllm.attention.layers.mm_encoder_attention import maybe_get_vit_flash_attn_backend
 from vllm.attention.selector import get_attn_backend
-from vllm.attention.utils.fa_utils import get_flash_attn_version
 from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
 from vllm.attention.utils.kv_transfer_utils import maybe_transfer_kv_layer
 from vllm.config import CacheConfig, get_current_vllm_config
-from vllm.config.multimodal import MultiModalConfig
 from vllm.config.vllm import VllmConfig
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.logger import init_logger
@@ -36,7 +31,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
-from vllm.model_executor.models.vision import get_vit_attn_backend
 from vllm.platforms import current_platform
 from vllm.utils.torch_utils import (
     direct_register_custom_op,
@@ -412,132 +406,6 @@ class Attention(nn.Module, AttentionLayerBase):
         )


-class MultiHeadAttention(nn.Module):
-    """Multi-headed attention without any cache, used for ViT."""
-
-    def __init__(
-        self,
-        num_heads: int,
-        head_size: int,
-        scale: float,
-        num_kv_heads: int | None = None,
-        # This has no effect, it is only here to make it easier to swap
-        # between Attention and MultiHeadAttention
-        prefix: str = "",
-        multimodal_config: MultiModalConfig | None = None,
-    ) -> None:
-        super().__init__()
-        self.num_heads = num_heads
-        self.head_size = head_size
-        self.scale = scale
-        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
-        self.layer_name = prefix
-
-        assert self.num_heads % self.num_kv_heads == 0, (
-            f"num_heads ({self.num_heads}) is not "
-            f"divisible by num_kv_heads ({self.num_kv_heads})"
-        )
-        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
-
-        # During model initialization, the default dtype is set as the model
-        # weight and activation dtype.
-        dtype = torch.get_default_dtype()
-
-        # Determine the attention backend
-        attn_backend_override = None
-        if multimodal_config is not None:
-            attn_backend_override = multimodal_config.mm_encoder_attn_backend
-
-        self.attn_backend = get_vit_attn_backend(
-            head_size=head_size,
-            dtype=dtype,
-            attn_backend_override=attn_backend_override,
-        )
-
-        self._flash_attn_varlen_func = maybe_get_vit_flash_attn_backend(
-            self.attn_backend,
-        )
-
-        self.is_flash_attn_backend = self.attn_backend in {
-            AttentionBackendEnum.FLASH_ATTN,
-            AttentionBackendEnum.ROCM_AITER_FA,
-        }
-
-        self.fa_version = None
-        if (
-            self.attn_backend == AttentionBackendEnum.FLASH_ATTN
-            and current_platform.is_cuda()
-        ):
-            self.fa_version = get_flash_attn_version()
-            assert self._flash_attn_varlen_func is not None
-            self._flash_attn_varlen_func = functools.partial(
-                self._flash_attn_varlen_func, fa_version=self.fa_version
-            )
-
-        logger.info_once(
-            f"Using {self.attn_backend} for MultiHeadAttention in multimodal encoder."
-        )
-
-    def forward(
-        self,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-    ) -> torch.Tensor:
-        """Input shape:
-        (batch_size x seq_len x hidden_size) or
-        (batch_size x seq_len x num_heads x head_size)
-        """
-        bsz, q_len = query.size()[:2]
-        kv_len = key.size(1)
-
-        query = query.view(bsz, q_len, self.num_heads, self.head_size)
-        key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
-        value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)
-
-        if (num_repeat := self.num_queries_per_kv) > 1:
-            # Handle MQA and GQA
-            key = torch.repeat_interleave(key, num_repeat, dim=2)
-            value = torch.repeat_interleave(value, num_repeat, dim=2)
-
-        if self.is_flash_attn_backend:
-            assert self._flash_attn_varlen_func is not None
-            cu_seqlens_q = torch.arange(
-                0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=query.device
-            )
-            cu_seqlens_k = torch.arange(
-                0, (bsz + 1) * kv_len, step=kv_len, dtype=torch.int32, device=key.device
-            )
-
-            out = self._flash_attn_varlen_func(
-                query.flatten(0, 1),
-                key.flatten(0, 1),
-                value.flatten(0, 1),
-                cu_seqlens_q=cu_seqlens_q,
-                cu_seqlens_k=cu_seqlens_k,
-                max_seqlen_q=q_len,
-                max_seqlen_k=kv_len,
-                softmax_scale=self.scale,
-            )
-        elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
-            query, key, value = (x.transpose(1, 2) for x in (query, key, value))
-            out = F.scaled_dot_product_attention(query, key, value, scale=self.scale)
-            out = out.transpose(1, 2)
-        elif self.attn_backend == AttentionBackendEnum.PALLAS:
-            query, key, value = (x.transpose(1, 2) for x in (query, key, value))
-            from torch_xla.experimental.custom_kernel import flash_attention
-
-            out = flash_attention(query, key, value, sm_scale=self.scale)
-            out = out.transpose(1, 2)
-        else:
-            # ViT attention hasn't supported this backend yet
-            raise NotImplementedError(
-                f"ViT attention hasn't supported {self.attn_backend} backend yet."
-            )
-
-        return out.reshape(bsz, q_len, -1)
-
-
 class MLAAttention(nn.Module, AttentionLayerBase):
     """Multi-Head Latent Attention layer.
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from collections.abc import Callable

 import torch

@@ -19,27 +18,6 @@ from vllm.model_executor.models.vision import get_vit_attn_backend
 logger = init_logger(__name__)


-def maybe_get_vit_flash_attn_backend(
-    attn_backend: AttentionBackendEnum | None,
-) -> Callable | None:
-    # At this point,
-    # we already have the attn_backend,
-    # overriding logic is done in the platform-specific implementation.
-    # so we don't need to override backend here.
-    # Just return the attn_backend and flash_attn_varlen_func.
-
-    if attn_backend == AttentionBackendEnum.FLASH_ATTN:
-        from vllm.attention.utils.fa_utils import flash_attn_varlen_func
-    elif attn_backend == AttentionBackendEnum.ROCM_AITER_FA:
-        from aiter import flash_attn_varlen_func
-    else:
-        flash_attn_varlen_func = None
-
-    # if attn_backend is TORCH_SDPA,
-    # it will reach here and the flash_attn_varlen_func will be None.
-    return flash_attn_varlen_func
-
-
 @CustomOp.register("mm_encoder_attn")
 class MMEncoderAttention(CustomOp):
     """Multi-headed attention without any cache, used for multimodal encoder."""
@@ -98,21 +76,17 @@ class MMEncoderAttention(CustomOp):
             AttentionBackendEnum.ROCM_AITER_FA,
         }

-        self.flash_attn_varlen_func = maybe_get_vit_flash_attn_backend(
-            self.attn_backend,
+        self._fa_version = (
+            get_flash_attn_version() if self.is_flash_attn_backend else None
         )

-        if self.is_flash_attn_backend:
-            assert self.flash_attn_varlen_func is not None
-            self._fa_version = get_flash_attn_version()
-
         logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.")

     @classmethod
     def enabled(cls) -> bool:
         return True

-    def reshape_qkv_to_4d(
+    def maybe_reshape_qkv_to_4d(
         self,
         query: torch.Tensor,
         key: torch.Tensor,
@@ -136,30 +110,6 @@ class MMEncoderAttention(CustomOp):

         return query, key, value

-    def reshape_qkv_to_3d(
-        self,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        bsz: int,
-        q_len: int,
-        kv_len: int,
-    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        """
-        Reshape query, key, value to 3D tensors:
-        (batch_size * seq_len, num_heads, head_size)
-        """
-        query = query.view(bsz * q_len, self.num_heads, self.head_size)
-        key = key.view(bsz * kv_len, self.num_kv_heads, self.head_size)
-        value = value.view(bsz * kv_len, self.num_kv_heads, self.head_size)
-
-        if (num_repeat := self.num_queries_per_kv) > 1:
-            # Handle MQA and GQA
-            key = torch.repeat_interleave(key, num_repeat, dim=1)
-            value = torch.repeat_interleave(value, num_repeat, dim=1)
-
-        return query, key, value
-
     def _forward_sdpa(
         self,
         query: torch.Tensor,
@@ -167,13 +117,15 @@ class MMEncoderAttention(CustomOp):
         value: torch.Tensor,
         cu_seqlens: torch.Tensor | None = None,
     ) -> torch.Tensor:
-        # TODO(Isotr0py): Migrate MultiHeadAttention
-        assert cu_seqlens is not None
-
+        """Input shape:
+        (batch_size x seq_len x hidden_size) or
+        (batch_size x seq_len x num_heads x head_size)
+        """
         bsz, q_len = query.size()[:2]
         kv_len = key.size(1)
+        is_reshaped = query.dim() != 4

-        query, key, value = self.reshape_qkv_to_4d(
+        query, key, value = self.maybe_reshape_qkv_to_4d(
             query, key, value, bsz, q_len, kv_len
         )

@@ -183,6 +135,8 @@ class MMEncoderAttention(CustomOp):
             v=value,
             cu_seqlens=cu_seqlens,
         )
+        if is_reshaped:
+            output = output.view(bsz, q_len, -1)
         return output

     def _forward_fa(
@@ -193,13 +147,21 @@ class MMEncoderAttention(CustomOp):
         cu_seqlens: torch.Tensor | None = None,
         max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
     ) -> torch.Tensor:
-        assert self.flash_attn_varlen_func is not None, (
-            "Flash attention function is not set."
-        )
-        # # TODO(Isotr0py): Migrate MultiHeadAttention
-        assert cu_seqlens is not None and max_seqlen is not None
+        """Input shape:
+        (batch_size x seq_len x hidden_size) or
+        (batch_size x seq_len x num_heads x head_size)
+        """
+        assert (cu_seqlens is not None and max_seqlen is not None) or (
+            cu_seqlens is None and max_seqlen is None
+        ), "cu_seqlens and max_seqlen should be both set or both None."

-        bsz = query.shape[0]
+        bsz, q_len = query.size()[:2]
+        kv_len = key.size(1)
+        is_reshaped = query.dim() != 4
+
+        query, key, value = self.maybe_reshape_qkv_to_4d(
+            query, key, value, bsz, q_len, kv_len
+        )

         output = vit_flash_attn_wrapper(
             q=query,
@@ -211,6 +173,8 @@ class MMEncoderAttention(CustomOp):
             is_rocm_aiter=(self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA),
             fa_version=self._fa_version,
         )
+        if is_reshaped:
+            output = output.view(bsz, q_len, -1)
         return output

     def forward_native(
@@ -24,11 +24,11 @@ def flash_attn_maxseqlen_wrapper(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
-    cu_seqlens: torch.Tensor,
-    max_seqlen: torch.Tensor,
     batch_size: int,
     is_rocm_aiter: bool,
     fa_version: int | None,
+    cu_seqlens: torch.Tensor | None = None,
+    max_seqlen: torch.Tensor | None = None,
 ) -> torch.Tensor:
     kwargs = {}
     if is_rocm_aiter:
@@ -38,6 +38,14 @@ def flash_attn_maxseqlen_wrapper(

     if not current_platform.is_rocm() and fa_version is not None:
         kwargs["fa_version"] = fa_version

+    q_len = q.size(1)
+    if cu_seqlens is None:
+        cu_seqlens = torch.arange(
+            0, (batch_size + 1) * q_len, step=q_len, dtype=torch.int32, device=q.device
+        )
+    max_seqlen = q_len if max_seqlen is None else max_seqlen.item()
+
     q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
     output = flash_attn_varlen_func(
         q,
@@ -45,8 +53,8 @@ def flash_attn_maxseqlen_wrapper(
         v,
         cu_seqlens_q=cu_seqlens,
         cu_seqlens_k=cu_seqlens,
-        max_seqlen_q=max_seqlen.item(),
-        max_seqlen_k=max_seqlen.item(),
+        max_seqlen_q=max_seqlen,
+        max_seqlen_k=max_seqlen,
         dropout_p=0.0,
         causal=False,
         **kwargs,
@@ -79,24 +87,42 @@ def vit_flash_attn_wrapper(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
-    cu_seqlens: torch.Tensor,
-    max_seqlen: torch.Tensor,
     batch_size: int,
     is_rocm_aiter: bool,
     fa_version: int | None,
+    cu_seqlens: torch.Tensor | None = None,
+    max_seqlen: torch.Tensor | None = None,
 ) -> torch.Tensor:
     return torch.ops.vllm.flash_attn_maxseqlen_wrapper(
-        q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter, fa_version
+        q,
+        k,
+        v,
+        batch_size,
+        is_rocm_aiter,
+        fa_version,
+        cu_seqlens,
+        max_seqlen,
     )


+def apply_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
+    """
+    Input shape:
+    (batch_size x seq_len x num_heads x head_size)
+    """
+    q, k, v = (einops.rearrange(x, "b s h d -> b h s d") for x in [q, k, v])
+    output = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)
+    output = einops.rearrange(output, "b h s d -> b s h d ")
+    return output
+
+
 # TODO: Once we have a torch 2.10, we can use tensor slices
 # so we won't need to wrap this in custom ops
 def torch_sdpa_wrapper(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
-    cu_seqlens: torch.Tensor,
+    cu_seqlens: torch.Tensor | None = None,
 ) -> torch.Tensor:
     # Never remove the contiguous logic for ROCm
     # Without it, hallucinations occur with the backend
@@ -105,6 +131,9 @@ def torch_sdpa_wrapper(
     k = k.contiguous()
     v = v.contiguous()

+    if cu_seqlens is None:
+        return apply_sdpa(q, k, v)
+
     outputs = []

     lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
@@ -112,11 +141,7 @@ def torch_sdpa_wrapper(
     k_chunks = torch.split(k, lens, dim=1)
     v_chunks = torch.split(v, lens, dim=1)
     for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks):
-        q_i, k_i, v_i = (
-            einops.rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
-        )
-        output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
-        output_i = einops.rearrange(output_i, "b h s d -> b s h d ")
+        output_i = apply_sdpa(q_i, k_i, v_i)
         outputs.append(output_i)
     context_layer = torch.cat(outputs, dim=1)
     return context_layer
@@ -142,6 +167,6 @@ def vit_torch_sdpa_wrapper(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
-    cu_seqlens: torch.Tensor,
+    cu_seqlens: torch.Tensor | None = None,
 ) -> torch.Tensor:
     return torch.ops.vllm.torch_sdpa_wrapper(q, k, v, cu_seqlens)
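With cu_seqlens and max_seqlen now optional in these wrappers, the batched (non-varlen) path derives its own cumulative sequence lengths from batch_size and the per-sample length, and the SDPA wrapper skips per-sequence splitting entirely. A small sketch of what those defaults amount to, using illustrative sizes only (device argument omitted):

    import torch

    batch_size, q_len = 2, 3  # illustrative values
    # Equivalent to the default built inside flash_attn_maxseqlen_wrapper
    # when cu_seqlens is None: one entry per sequence boundary.
    cu_seqlens = torch.arange(
        0, (batch_size + 1) * q_len, step=q_len, dtype=torch.int32
    )
    print(cu_seqlens)  # tensor([0, 3, 6], dtype=torch.int32)

    # torch_sdpa_wrapper takes the analogous shortcut: with cu_seqlens=None it
    # returns apply_sdpa(q, k, v), i.e. plain batched SDPA with no splitting.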
@@ -8,7 +8,7 @@ from collections.abc import Iterable
 import torch
 import torch.nn as nn

-from vllm.attention.layer import MultiHeadAttention
+from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.utils import divide
 from vllm.model_executor.layers.activation import SiluAndMul
@@ -126,7 +126,7 @@ class AIMv2Attention(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()
         self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

-        self.attn = MultiHeadAttention(
+        self.attn = MMEncoderAttention(
             self.num_heads_per_partition, self.head_dim, self.scale
         )
@@ -9,7 +9,7 @@ import torch
 import torch.nn as nn
 from transformers import Blip2VisionConfig, BlipVisionConfig

-from vllm.attention.layer import MultiHeadAttention
+from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.conv import Conv2dLayer
@@ -122,7 +122,7 @@ class BlipAttention(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()
         self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

-        self.attn = MultiHeadAttention(
+        self.attn = MMEncoderAttention(
             self.num_heads_per_partition, self.head_dim, self.scale
         )
@@ -14,7 +14,8 @@ from transformers import (
     CLIPVisionConfig,
 )

-from vllm.attention.layer import Attention, MultiHeadAttention
+from vllm.attention.layer import Attention
+from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
 from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
@@ -354,7 +355,7 @@ class CLIPAttention(nn.Module):
         quant_config: QuantizationConfig | None = None,
         *,
         prefix: str = "",
-        attn_cls: type[Attention] | type[MultiHeadAttention],
+        attn_cls: type[Attention] | type[MMEncoderAttention],
     ) -> None:
         super().__init__()

@@ -449,7 +450,7 @@ class CLIPEncoderLayer(nn.Module):
         quant_config: QuantizationConfig | None = None,
         *,
         prefix: str = "",
-        attn_cls: type[Attention] | type[MultiHeadAttention],
+        attn_cls: type[Attention] | type[MMEncoderAttention],
     ) -> None:
         super().__init__()
         self.self_attn = CLIPAttention(
@@ -493,7 +494,7 @@ class CLIPEncoder(nn.Module):
         num_hidden_layers_override: int | None = None,
         *,
         prefix: str = "",
-        attn_cls: type[Attention] | type[MultiHeadAttention],
+        attn_cls: type[Attention] | type[MMEncoderAttention],
     ) -> None:
         super().__init__()

@@ -638,7 +639,7 @@ class CLIPVisionTransformer(nn.Module):
             quant_config=quant_config,
             num_hidden_layers_override=num_hidden_layers_override,
             prefix=f"{prefix}.encoder",
-            attn_cls=MultiHeadAttention,
+            attn_cls=MMEncoderAttention,
         )

         num_hidden_layers = config.num_hidden_layers
@@ -18,7 +18,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from transformers import CLIPVisionConfig

-from vllm.attention.layer import MultiHeadAttention
+from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
 from vllm.model_executor.layers.conv import Conv2dLayer
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -628,7 +628,7 @@ class DeepCLIPVisionTransformer(nn.Module):
             quant_config=quant_config,
             num_hidden_layers_override=num_hidden_layers_override,
             prefix=f"{prefix}.encoder",
-            attn_cls=MultiHeadAttention,
+            attn_cls=MMEncoderAttention,
         )

         num_hidden_layers = config.num_hidden_layers
@@ -19,7 +19,7 @@ from transformers import BatchFeature, PreTrainedTokenizer, TensorType
 from transformers.image_utils import ImageInput
 from transformers.tokenization_utils_base import TextInput

-from vllm.attention.layer import MultiHeadAttention
+from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
 from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.distributed import get_tensor_model_parallel_world_size
@@ -135,7 +135,7 @@ class EVA2CLIPAttention(nn.Module):
             prefix=f"{prefix}.dense",
         )

-        self.attn = MultiHeadAttention(
+        self.attn = MMEncoderAttention(
             self.num_heads_per_rank, self.head_dim, self.scale
         )
         self.output_dropout = torch.nn.Dropout(config.dropout_prob)
@@ -34,7 +34,7 @@ import torch.nn.functional as F
 from transformers import BatchFeature

 from vllm.attention.backends.registry import AttentionBackendEnum
-from vllm.attention.layer import MultiHeadAttention
+from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
 from vllm.config import MultiModalConfig, VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.distributed import parallel_state
@@ -232,7 +232,7 @@ class HunYuanVisionAttention(nn.Module):
         )

         self.scale = self.hidden_size_per_attention_head**-0.5
-        self.attn = MultiHeadAttention(
+        self.attn = MMEncoderAttention(
             self.num_attention_heads_per_partition,
             self.hidden_size_per_attention_head,
             self.scale,
@@ -27,7 +27,7 @@ from transformers.models.idefics2.configuration_idefics2 import (
     Idefics2VisionConfig,
 )

-from vllm.attention.layer import MultiHeadAttention
+from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.conv import Conv2dLayer
@@ -161,8 +161,8 @@ class Idefics2VisionAttention(nn.Module):
             prefix=f"{prefix}.out_proj",
             disable_tp=use_data_parallel,
         )
-        # Use unified MultiHeadAttention with Flash Attention support
-        self.attn = MultiHeadAttention(
+        # Use unified MMEncoderAttention with Flash Attention support
+        self.attn = MMEncoderAttention(
             self.num_heads_per_partition, self.head_dim, self.scale
         )

@@ -175,7 +175,7 @@ class Idefics2VisionAttention(nn.Module):
         )  # batch_size, q_len, 3 * num_heads_per_partition * head_dim
         query_states, key_states, value_states = qkv.chunk(3, dim=-1)

-        # Use unified MultiHeadAttention implementation
+        # Use unified MMEncoderAttention implementation
         out = self.attn(query_states, key_states, value_states)
         attn_output, _ = self.out_proj(out)
         return attn_output
@@ -15,7 +15,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from transformers import PretrainedConfig

-from vllm.attention.layer import MultiHeadAttention
+from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
 from vllm.distributed import (
     divide,
     get_tensor_model_parallel_rank,
@@ -207,7 +207,7 @@ class InternParallelAttention(nn.Module):
             disable_tp=use_data_parallel,
         )

-        self.attn = MultiHeadAttention(
+        self.attn = MMEncoderAttention(
             self.num_heads_per_partition, self.head_dim, self.scale
         )
@@ -14,7 +14,7 @@ import torch.nn as nn
 from transformers import PretrainedConfig
 from transformers.utils import torch_int

-from vllm.attention.layer import MultiHeadAttention
+from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.conv import Conv2dLayer
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -214,8 +214,8 @@ class InternSdpaAttention(nn.Module):

         self.projection_layer = nn.Linear(self.dummy_dim, self.embed_dim)

-        # Use unified MultiHeadAttention with automatic backend selection
-        self.attn = MultiHeadAttention(self.num_heads, self.head_dim, self.scale)
+        # Use unified MMEncoderAttention with automatic backend selection
+        self.attn = MMEncoderAttention(self.num_heads, self.head_dim, self.scale)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """x shape: (B, N, C)"""
@@ -228,7 +228,7 @@ class InternSdpaAttention(nn.Module):
         q = self.q_norm(q)
         k = self.k_norm(k)

-        # Use unified MultiHeadAttention with automatic backend selection
+        # Use unified MMEncoderAttention with automatic backend selection
         x = self.attn(q, k, v)

         x = self.projection_layer(x)
@@ -31,7 +31,7 @@ from transformers.models.llama4.image_processing_llama4_fast import (
     get_best_fit,
 )

-from vllm.attention.layer import MultiHeadAttention
+from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
 from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.distributed import get_tensor_model_parallel_world_size
@@ -255,7 +255,7 @@ class Llama4VisionAttention(nn.Module):
         self.attention_dropout = config.attention_dropout
         self.scaling = self.head_dim**-0.5

-        self.attn = MultiHeadAttention(
+        self.attn = MMEncoderAttention(
             self.num_local_heads, self.head_dim, self.scaling
         )
@@ -17,7 +17,8 @@ from transformers import BatchFeature, PretrainedConfig, ProcessorMixin, TensorT
 from transformers.image_utils import ImageInput
 from transformers.tokenization_utils_base import TextInput

-from vllm.attention.layer import Attention, MultiHeadAttention
+from vllm.attention.layer import Attention
+from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
@@ -222,7 +223,7 @@ class MultiHeadDotProductAttention(nn.Module):
         )

         self.scale = self.head_dim**-0.5
-        self.attn = MultiHeadAttention(
+        self.attn = MMEncoderAttention(
             self.num_heads, self.head_dim, self.scale, num_kv_heads=self.num_kv_heads
         )
@@ -16,8 +16,8 @@ from transformers import (
     SiglipVisionConfig,
 )

-from vllm.attention.layer import MultiHeadAttention
 from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
+from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
 from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
@@ -379,7 +379,7 @@ class SiglipAttention(nn.Module):
         quant_config: QuantizationConfig | None = None,
         *,
         prefix: str = "",
-        attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention],
+        attn_cls: type[EncoderOnlyAttention] | type[MMEncoderAttention],
     ) -> None:
         super().__init__()

@@ -481,7 +481,7 @@ class SiglipEncoderLayer(nn.Module):
         quant_config: QuantizationConfig | None = None,
         *,
         prefix: str = "",
-        attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention],
+        attn_cls: type[EncoderOnlyAttention] | type[MMEncoderAttention],
     ) -> None:
         super().__init__()

@@ -527,7 +527,7 @@ class SiglipEncoder(nn.Module):
         num_hidden_layers_override: int | None = None,
         *,
         prefix: str = "",
-        attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention],
+        attn_cls: type[EncoderOnlyAttention] | type[MMEncoderAttention],
     ) -> None:
         super().__init__()

@@ -700,7 +700,7 @@ class SiglipVisionTransformer(nn.Module):
             quant_config=quant_config,
             num_hidden_layers_override=num_hidden_layers_override,
             prefix=f"{prefix}.encoder",
-            attn_cls=MultiHeadAttention,
+            attn_cls=MMEncoderAttention,
         )

         num_hidden_layers = config.num_hidden_layers
@@ -15,7 +15,7 @@ from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode
 from transformers import BatchFeature, PretrainedConfig, TensorType

-from vllm.attention.layer import MultiHeadAttention
+from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
 from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.distributed import get_tensor_model_parallel_world_size
@@ -753,8 +753,8 @@ class Step3VisionAttention(nn.Module):
             disable_tp=use_data_parallel,
         )

-        # Use unified MultiHeadAttention with automatic backend selection
-        self.attn = MultiHeadAttention(self.num_heads, self.head_dim, self.scale)
+        # Use unified MMEncoderAttention with automatic backend selection
+        self.attn = MMEncoderAttention(self.num_heads, self.head_dim, self.scale)

     def forward(
         self,
@@ -767,7 +767,7 @@ class Step3VisionAttention(nn.Module):
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)

-        # Use unified MultiHeadAttention with automatic backend selection
+        # Use unified MMEncoderAttention with automatic backend selection
         attn_output = self.attn(q, k, v)

         attn_output, _ = self.out_proj(attn_output)
@@ -16,9 +16,9 @@ from transformers import (
 )
 from transformers.models.whisper.modeling_whisper import sinusoids

-from vllm.attention.backends.abstract import AttentionType
-from vllm.attention.layer import Attention, MultiHeadAttention
+from vllm.attention.layer import Attention, AttentionType
 from vllm.attention.layers.cross_attention import CrossAttention
+from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
 from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.distributed import get_tensor_model_parallel_world_size
@@ -141,7 +141,7 @@ class WhisperAudioInputs(TensorSchema):
     ]


-class WhisperEncoderAttention(MultiHeadAttention):
+class WhisperEncoderAttention(MMEncoderAttention):
     """Multi-headed attention for Whisper encoder with 2D tensor support."""

     def forward(
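Beyond the plain batched call, the new test_mha_attn_varlen_forward test added earlier in this diff exercises the variable-length path the migrated layer exposes: sequences packed into one batch row, with boundaries supplied via cu_seqlens. A condensed sketch mirroring that test; the sequence lengths, head count, and head size are illustrative only:

    import itertools
    import torch

    from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention

    seq_lens = [2, 3, 4]  # illustrative per-sequence lengths
    num_heads, head_size = 16, 64

    # Packed layout: (1, sum(seq_lens), num_heads, head_size)
    q = torch.randn(1, sum(seq_lens), num_heads, head_size)
    k = torch.randn(1, sum(seq_lens), num_heads, head_size)
    v = torch.randn(1, sum(seq_lens), num_heads, head_size)
    cu_seqlens = torch.tensor(
        [0] + list(itertools.accumulate(seq_lens)), dtype=torch.int32
    )

    attn = MMEncoderAttention(num_heads, head_size, scale=head_size**-0.5)
    out = attn(
        q, k, v, cu_seqlens=cu_seqlens, max_seqlen=torch.tensor(max(seq_lens))
    )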