[MM Encoder]: Migrate legacy ViT MultiHeadAttention to new MMEncoderAttention interface (#30684)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Isotr0py 2025-12-19 02:04:19 +08:00 committed by GitHub
parent 62be3670cb
commit 700a5ad6c6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 182 additions and 266 deletions

View File

@ -9,7 +9,8 @@ import torch
from tests.kernels.allclose_default import get_default_atol, get_default_rtol
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops
from vllm.attention.layer import Attention, MultiHeadAttention
from vllm.attention.layer import Attention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.platforms import current_platform
from vllm.utils.mem_utils import get_max_shared_memory_bytes
@ -442,7 +443,7 @@ def ref_multi_query_kv_attention(
return torch.cat(ref_outputs, dim=0)
@pytest.mark.parametrize("attention_cls", [Attention, MultiHeadAttention])
@pytest.mark.parametrize("attention_cls", [Attention, MMEncoderAttention])
def test_num_heads_not_divisble_by_num_kv_heads(attention_cls: type) -> None:
head_size = 64
scale = float(1.0 / (head_size**0.5))

View File

@ -3,16 +3,17 @@
"""
Test:
* Tests for MultiHeadAttention layer
* Tests for MMEncoderAttention layer
"""
import itertools
from unittest.mock import patch
import pytest
import torch
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import MultiHeadAttention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.attention.selector import _cached_get_attn_backend
from vllm.platforms import current_platform
from vllm.platforms.cpu import CpuPlatform
@ -42,35 +43,31 @@ def test_mha_attn_platform(device: str):
if device == "cpu":
with (
patch("vllm.attention.layer.current_platform", CpuPlatform()),
patch("vllm.model_executor.models.vision.current_platform", CpuPlatform()),
):
attn = MultiHeadAttention(16, 64, scale=1)
attn = MMEncoderAttention(16, 64, scale=1)
assert attn.attn_backend == AttentionBackendEnum.TORCH_SDPA
elif device == "hip":
with (
patch("vllm.attention.layer.current_platform", RocmPlatform()),
patch("vllm.model_executor.models.vision.current_platform", RocmPlatform()),
):
attn = MultiHeadAttention(16, 64, scale=1)
attn = MMEncoderAttention(16, 64, scale=1)
assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN
else:
# Test CUDA with head_size=64 (divisible by 32)
# - should use vLLM's FlashAttention
with (
patch("vllm.attention.layer.current_platform", CudaPlatform()),
patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()),
):
attn = MultiHeadAttention(16, 64, scale=1)
attn = MMEncoderAttention(16, 64, scale=1)
assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN
# Test CUDA with head_size=72 (not divisible by 32)
# - should use vLLM's FlashAttention
with (
patch("vllm.attention.layer.current_platform", CudaPlatform()),
patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()),
):
attn = MultiHeadAttention(16, 72, scale=1)
attn = MMEncoderAttention(16, 72, scale=1)
assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN
@ -94,6 +91,10 @@ def ref_attention(
BATCH_SIZES = [1, 16]
SEQ_LENS = [1]
VAR_SEQ_LENS = [
[2, 2],
[2, 3, 4],
]
NUM_HEADS = [1, 16]
NUM_KV_HEADS = [1]
HEAD_SIZES = [64, 80]
@ -130,7 +131,7 @@ def test_mha_attn_forward(
k = torch.randn(batch_size, seq_len, num_kv_heads * head_size)
v = torch.randn(batch_size, seq_len, num_kv_heads * head_size)
scale = 1.0 / head_size**0.5
attn = MultiHeadAttention(
attn = MMEncoderAttention(
num_heads, head_size, scale=scale, num_kv_heads=num_kv_heads
)
output = attn(q, k, v)
@ -151,3 +152,58 @@ def test_mha_attn_forward(
scale=scale,
).reshape(batch_size, seq_len, num_heads * head_size)
torch.testing.assert_close(output, ref_output)
@pytest.mark.parametrize("var_seq_len", VAR_SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_kv_heads", NUM_KV_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_mha_attn_varlen_forward(
var_seq_len: list[int],
num_heads: int,
num_kv_heads: int,
head_size: int,
dtype: torch.dtype,
device: str,
):
current_platform.seed_everything(0)
torch.set_default_device(device)
torch.set_default_dtype(dtype)
q = torch.randn(1, sum(var_seq_len), num_heads, head_size)
k = torch.randn(1, sum(var_seq_len), num_kv_heads, head_size)
v = torch.randn(1, sum(var_seq_len), num_kv_heads, head_size)
cu_seqlens = torch.tensor(
[0] + list(itertools.accumulate(var_seq_len)), dtype=torch.int32
)
scale = 1.0 / head_size**0.5
attn = MMEncoderAttention(
num_heads, head_size, scale=scale, num_kv_heads=num_kv_heads
)
output = attn(
q, k, v, cu_seqlens=cu_seqlens, max_seqlen=torch.tensor(max(var_seq_len))
)
assert num_heads % num_kv_heads == 0
num_queries_per_kv = num_heads // num_kv_heads
if num_queries_per_kv > 1:
k = torch.repeat_interleave(k, num_queries_per_kv, dim=2)
v = torch.repeat_interleave(v, num_queries_per_kv, dim=2)
ref_output = []
for q_i, k_i, v_i in zip(
torch.split(q, var_seq_len, dim=1),
torch.split(k, var_seq_len, dim=1),
torch.split(v, var_seq_len, dim=1),
):
output_i = ref_attention(
q_i,
k_i,
v_i,
scale=scale,
)
ref_output.append(output_i)
ref_output = torch.cat(ref_output, dim=1)
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2)

View File

@ -3,7 +3,7 @@
"""
Test:
* Tests for MultiHeadAttention layer
* Tests for MMEncoderAttention layer
"""
import pytest
@ -12,7 +12,7 @@ import torch_xla
import torch_xla.core
import torch_xla.core.xla_model
from vllm.attention.layer import MultiHeadAttention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.attention.selector import _cached_get_attn_backend
from vllm.platforms import current_platform
@ -69,7 +69,7 @@ def test_mha_attn_forward(
k = torch.randn(batch_size, seq_len, num_kv_heads * head_size, device=device)
v = torch.randn(batch_size, seq_len, num_kv_heads * head_size, device=device)
scale = 1.0 / head_size**0.5
attn = MultiHeadAttention(
attn = MMEncoderAttention(
num_heads, head_size, scale=scale, num_kv_heads=num_kv_heads
)
output = attn(q, k, v)

View File

@ -2,12 +2,10 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer."""
import functools
from typing import cast
import torch
import torch.nn as nn
import torch.nn.functional as F
import vllm.envs as envs
from vllm.attention.backends.abstract import (
@ -16,13 +14,10 @@ from vllm.attention.backends.abstract import (
MLAAttentionImpl,
)
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layers.mm_encoder_attention import maybe_get_vit_flash_attn_backend
from vllm.attention.selector import get_attn_backend
from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
from vllm.attention.utils.kv_transfer_utils import maybe_transfer_kv_layer
from vllm.config import CacheConfig, get_current_vllm_config
from vllm.config.multimodal import MultiModalConfig
from vllm.config.vllm import VllmConfig
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
@ -36,7 +31,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.platforms import current_platform
from vllm.utils.torch_utils import (
direct_register_custom_op,
@ -412,132 +406,6 @@ class Attention(nn.Module, AttentionLayerBase):
)
class MultiHeadAttention(nn.Module):
"""Multi-headed attention without any cache, used for ViT."""
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int | None = None,
# This has no effect, it is only here to make it easier to swap
# between Attention and MultiHeadAttention
prefix: str = "",
multimodal_config: MultiModalConfig | None = None,
) -> None:
super().__init__()
self.num_heads = num_heads
self.head_size = head_size
self.scale = scale
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.layer_name = prefix
assert self.num_heads % self.num_kv_heads == 0, (
f"num_heads ({self.num_heads}) is not "
f"divisible by num_kv_heads ({self.num_kv_heads})"
)
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
# During model initialization, the default dtype is set as the model
# weight and activation dtype.
dtype = torch.get_default_dtype()
# Determine the attention backend
attn_backend_override = None
if multimodal_config is not None:
attn_backend_override = multimodal_config.mm_encoder_attn_backend
self.attn_backend = get_vit_attn_backend(
head_size=head_size,
dtype=dtype,
attn_backend_override=attn_backend_override,
)
self._flash_attn_varlen_func = maybe_get_vit_flash_attn_backend(
self.attn_backend,
)
self.is_flash_attn_backend = self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}
self.fa_version = None
if (
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
and current_platform.is_cuda()
):
self.fa_version = get_flash_attn_version()
assert self._flash_attn_varlen_func is not None
self._flash_attn_varlen_func = functools.partial(
self._flash_attn_varlen_func, fa_version=self.fa_version
)
logger.info_once(
f"Using {self.attn_backend} for MultiHeadAttention in multimodal encoder."
)
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
) -> torch.Tensor:
"""Input shape:
(batch_size x seq_len x hidden_size) or
(batch_size x seq_len x num_heads x head_size)
"""
bsz, q_len = query.size()[:2]
kv_len = key.size(1)
query = query.view(bsz, q_len, self.num_heads, self.head_size)
key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)
if (num_repeat := self.num_queries_per_kv) > 1:
# Handle MQA and GQA
key = torch.repeat_interleave(key, num_repeat, dim=2)
value = torch.repeat_interleave(value, num_repeat, dim=2)
if self.is_flash_attn_backend:
assert self._flash_attn_varlen_func is not None
cu_seqlens_q = torch.arange(
0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=query.device
)
cu_seqlens_k = torch.arange(
0, (bsz + 1) * kv_len, step=kv_len, dtype=torch.int32, device=key.device
)
out = self._flash_attn_varlen_func(
query.flatten(0, 1),
key.flatten(0, 1),
value.flatten(0, 1),
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=q_len,
max_seqlen_k=kv_len,
softmax_scale=self.scale,
)
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
out = F.scaled_dot_product_attention(query, key, value, scale=self.scale)
out = out.transpose(1, 2)
elif self.attn_backend == AttentionBackendEnum.PALLAS:
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
from torch_xla.experimental.custom_kernel import flash_attention
out = flash_attention(query, key, value, sm_scale=self.scale)
out = out.transpose(1, 2)
else:
# ViT attention hasn't supported this backend yet
raise NotImplementedError(
f"ViT attention hasn't supported {self.attn_backend} backend yet."
)
return out.reshape(bsz, q_len, -1)
class MLAAttention(nn.Module, AttentionLayerBase):
"""Multi-Head Latent Attention layer.

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
@ -19,27 +18,6 @@ from vllm.model_executor.models.vision import get_vit_attn_backend
logger = init_logger(__name__)
def maybe_get_vit_flash_attn_backend(
attn_backend: AttentionBackendEnum | None,
) -> Callable | None:
# At this point,
# we already have the attn_backend,
# overriding logic is done in the platform-specific implementation.
# so we don't need to override backend here.
# Just return the attn_backend and flash_attn_varlen_func.
if attn_backend == AttentionBackendEnum.FLASH_ATTN:
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
elif attn_backend == AttentionBackendEnum.ROCM_AITER_FA:
from aiter import flash_attn_varlen_func
else:
flash_attn_varlen_func = None
# if attn_backend is TORCH_SDPA,
# it will reach here and the flash_attn_varlen_func will be None.
return flash_attn_varlen_func
@CustomOp.register("mm_encoder_attn")
class MMEncoderAttention(CustomOp):
"""Multi-headed attention without any cache, used for multimodal encoder."""
@ -98,21 +76,17 @@ class MMEncoderAttention(CustomOp):
AttentionBackendEnum.ROCM_AITER_FA,
}
self.flash_attn_varlen_func = maybe_get_vit_flash_attn_backend(
self.attn_backend,
self._fa_version = (
get_flash_attn_version() if self.is_flash_attn_backend else None
)
if self.is_flash_attn_backend:
assert self.flash_attn_varlen_func is not None
self._fa_version = get_flash_attn_version()
logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.")
@classmethod
def enabled(cls) -> bool:
return True
def reshape_qkv_to_4d(
def maybe_reshape_qkv_to_4d(
self,
query: torch.Tensor,
key: torch.Tensor,
@ -136,30 +110,6 @@ class MMEncoderAttention(CustomOp):
return query, key, value
def reshape_qkv_to_3d(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
bsz: int,
q_len: int,
kv_len: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Reshape query, key, value to 3D tensors:
(batch_size * seq_len, num_heads, head_size)
"""
query = query.view(bsz * q_len, self.num_heads, self.head_size)
key = key.view(bsz * kv_len, self.num_kv_heads, self.head_size)
value = value.view(bsz * kv_len, self.num_kv_heads, self.head_size)
if (num_repeat := self.num_queries_per_kv) > 1:
# Handle MQA and GQA
key = torch.repeat_interleave(key, num_repeat, dim=1)
value = torch.repeat_interleave(value, num_repeat, dim=1)
return query, key, value
def _forward_sdpa(
self,
query: torch.Tensor,
@ -167,13 +117,15 @@ class MMEncoderAttention(CustomOp):
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
) -> torch.Tensor:
# TODO(Isotr0py): Migrate MultiHeadAttention
assert cu_seqlens is not None
"""Input shape:
(batch_size x seq_len x hidden_size) or
(batch_size x seq_len x num_heads x head_size)
"""
bsz, q_len = query.size()[:2]
kv_len = key.size(1)
is_reshaped = query.dim() != 4
query, key, value = self.reshape_qkv_to_4d(
query, key, value = self.maybe_reshape_qkv_to_4d(
query, key, value, bsz, q_len, kv_len
)
@ -183,6 +135,8 @@ class MMEncoderAttention(CustomOp):
v=value,
cu_seqlens=cu_seqlens,
)
if is_reshaped:
output = output.view(bsz, q_len, -1)
return output
def _forward_fa(
@ -193,13 +147,21 @@ class MMEncoderAttention(CustomOp):
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
) -> torch.Tensor:
assert self.flash_attn_varlen_func is not None, (
"Flash attention function is not set."
)
# # TODO(Isotr0py): Migrate MultiHeadAttention
assert cu_seqlens is not None and max_seqlen is not None
"""Input shape:
(batch_size x seq_len x hidden_size) or
(batch_size x seq_len x num_heads x head_size)
"""
assert (cu_seqlens is not None and max_seqlen is not None) or (
cu_seqlens is None and max_seqlen is None
), "cu_seqlens and max_seqlen should be both set or both None."
bsz = query.shape[0]
bsz, q_len = query.size()[:2]
kv_len = key.size(1)
is_reshaped = query.dim() != 4
query, key, value = self.maybe_reshape_qkv_to_4d(
query, key, value, bsz, q_len, kv_len
)
output = vit_flash_attn_wrapper(
q=query,
@ -211,6 +173,8 @@ class MMEncoderAttention(CustomOp):
is_rocm_aiter=(self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA),
fa_version=self._fa_version,
)
if is_reshaped:
output = output.view(bsz, q_len, -1)
return output
def forward_native(

View File

@ -24,11 +24,11 @@ def flash_attn_maxseqlen_wrapper(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens: torch.Tensor,
max_seqlen: torch.Tensor,
batch_size: int,
is_rocm_aiter: bool,
fa_version: int | None,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None,
) -> torch.Tensor:
kwargs = {}
if is_rocm_aiter:
@ -38,6 +38,14 @@ def flash_attn_maxseqlen_wrapper(
if not current_platform.is_rocm() and fa_version is not None:
kwargs["fa_version"] = fa_version
q_len = q.size(1)
if cu_seqlens is None:
cu_seqlens = torch.arange(
0, (batch_size + 1) * q_len, step=q_len, dtype=torch.int32, device=q.device
)
max_seqlen = q_len if max_seqlen is None else max_seqlen.item()
q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
output = flash_attn_varlen_func(
q,
@ -45,8 +53,8 @@ def flash_attn_maxseqlen_wrapper(
v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen.item(),
max_seqlen_k=max_seqlen.item(),
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
dropout_p=0.0,
causal=False,
**kwargs,
@ -79,24 +87,42 @@ def vit_flash_attn_wrapper(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens: torch.Tensor,
max_seqlen: torch.Tensor,
batch_size: int,
is_rocm_aiter: bool,
fa_version: int | None,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None,
) -> torch.Tensor:
return torch.ops.vllm.flash_attn_maxseqlen_wrapper(
q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter, fa_version
q,
k,
v,
batch_size,
is_rocm_aiter,
fa_version,
cu_seqlens,
max_seqlen,
)
def apply_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
"""
Input shape:
(batch_size x seq_len x num_heads x head_size)
"""
q, k, v = (einops.rearrange(x, "b s h d -> b h s d") for x in [q, k, v])
output = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)
output = einops.rearrange(output, "b h s d -> b s h d ")
return output
# TODO: Once we have a torch 2.10, we can use tensor slices
# so we won't need to wrap this in custom ops
def torch_sdpa_wrapper(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
) -> torch.Tensor:
# Never remove the contiguous logic for ROCm
# Without it, hallucinations occur with the backend
@ -105,6 +131,9 @@ def torch_sdpa_wrapper(
k = k.contiguous()
v = v.contiguous()
if cu_seqlens is None:
return apply_sdpa(q, k, v)
outputs = []
lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
@ -112,11 +141,7 @@ def torch_sdpa_wrapper(
k_chunks = torch.split(k, lens, dim=1)
v_chunks = torch.split(v, lens, dim=1)
for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks):
q_i, k_i, v_i = (
einops.rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
)
output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
output_i = einops.rearrange(output_i, "b h s d -> b s h d ")
output_i = apply_sdpa(q_i, k_i, v_i)
outputs.append(output_i)
context_layer = torch.cat(outputs, dim=1)
return context_layer
@ -142,6 +167,6 @@ def vit_torch_sdpa_wrapper(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
) -> torch.Tensor:
return torch.ops.vllm.torch_sdpa_wrapper(q, k, v, cu_seqlens)

View File

@ -8,7 +8,7 @@ from collections.abc import Iterable
import torch
import torch.nn as nn
from vllm.attention.layer import MultiHeadAttention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.utils import divide
from vllm.model_executor.layers.activation import SiluAndMul
@ -126,7 +126,7 @@ class AIMv2Attention(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
self.attn = MultiHeadAttention(
self.attn = MMEncoderAttention(
self.num_heads_per_partition, self.head_dim, self.scale
)

View File

@ -9,7 +9,7 @@ import torch
import torch.nn as nn
from transformers import Blip2VisionConfig, BlipVisionConfig
from vllm.attention.layer import MultiHeadAttention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.conv import Conv2dLayer
@ -122,7 +122,7 @@ class BlipAttention(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
self.attn = MultiHeadAttention(
self.attn = MMEncoderAttention(
self.num_heads_per_partition, self.head_dim, self.scale
)

View File

@ -14,7 +14,8 @@ from transformers import (
CLIPVisionConfig,
)
from vllm.attention.layer import Attention, MultiHeadAttention
from vllm.attention.layer import Attention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import divide, get_tensor_model_parallel_world_size
@ -354,7 +355,7 @@ class CLIPAttention(nn.Module):
quant_config: QuantizationConfig | None = None,
*,
prefix: str = "",
attn_cls: type[Attention] | type[MultiHeadAttention],
attn_cls: type[Attention] | type[MMEncoderAttention],
) -> None:
super().__init__()
@ -449,7 +450,7 @@ class CLIPEncoderLayer(nn.Module):
quant_config: QuantizationConfig | None = None,
*,
prefix: str = "",
attn_cls: type[Attention] | type[MultiHeadAttention],
attn_cls: type[Attention] | type[MMEncoderAttention],
) -> None:
super().__init__()
self.self_attn = CLIPAttention(
@ -493,7 +494,7 @@ class CLIPEncoder(nn.Module):
num_hidden_layers_override: int | None = None,
*,
prefix: str = "",
attn_cls: type[Attention] | type[MultiHeadAttention],
attn_cls: type[Attention] | type[MMEncoderAttention],
) -> None:
super().__init__()
@ -638,7 +639,7 @@ class CLIPVisionTransformer(nn.Module):
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.encoder",
attn_cls=MultiHeadAttention,
attn_cls=MMEncoderAttention,
)
num_hidden_layers = config.num_hidden_layers

View File

@ -18,7 +18,7 @@ import torch.nn as nn
import torch.nn.functional as F
from transformers import CLIPVisionConfig
from vllm.attention.layer import MultiHeadAttention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@ -628,7 +628,7 @@ class DeepCLIPVisionTransformer(nn.Module):
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.encoder",
attn_cls=MultiHeadAttention,
attn_cls=MMEncoderAttention,
)
num_hidden_layers = config.num_hidden_layers

View File

@ -19,7 +19,7 @@ from transformers import BatchFeature, PreTrainedTokenizer, TensorType
from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import TextInput
from vllm.attention.layer import MultiHeadAttention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
@ -135,7 +135,7 @@ class EVA2CLIPAttention(nn.Module):
prefix=f"{prefix}.dense",
)
self.attn = MultiHeadAttention(
self.attn = MMEncoderAttention(
self.num_heads_per_rank, self.head_dim, self.scale
)
self.output_dropout = torch.nn.Dropout(config.dropout_prob)

View File

@ -34,7 +34,7 @@ import torch.nn.functional as F
from transformers import BatchFeature
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import MultiHeadAttention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import MultiModalConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import parallel_state
@ -232,7 +232,7 @@ class HunYuanVisionAttention(nn.Module):
)
self.scale = self.hidden_size_per_attention_head**-0.5
self.attn = MultiHeadAttention(
self.attn = MMEncoderAttention(
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
self.scale,

View File

@ -27,7 +27,7 @@ from transformers.models.idefics2.configuration_idefics2 import (
Idefics2VisionConfig,
)
from vllm.attention.layer import MultiHeadAttention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.conv import Conv2dLayer
@ -161,8 +161,8 @@ class Idefics2VisionAttention(nn.Module):
prefix=f"{prefix}.out_proj",
disable_tp=use_data_parallel,
)
# Use unified MultiHeadAttention with Flash Attention support
self.attn = MultiHeadAttention(
# Use unified MMEncoderAttention with Flash Attention support
self.attn = MMEncoderAttention(
self.num_heads_per_partition, self.head_dim, self.scale
)
@ -175,7 +175,7 @@ class Idefics2VisionAttention(nn.Module):
) # batch_size, q_len, 3 * num_heads_per_partition * head_dim
query_states, key_states, value_states = qkv.chunk(3, dim=-1)
# Use unified MultiHeadAttention implementation
# Use unified MMEncoderAttention implementation
out = self.attn(query_states, key_states, value_states)
attn_output, _ = self.out_proj(out)
return attn_output

View File

@ -15,7 +15,7 @@ import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig
from vllm.attention.layer import MultiHeadAttention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.distributed import (
divide,
get_tensor_model_parallel_rank,
@ -207,7 +207,7 @@ class InternParallelAttention(nn.Module):
disable_tp=use_data_parallel,
)
self.attn = MultiHeadAttention(
self.attn = MMEncoderAttention(
self.num_heads_per_partition, self.head_dim, self.scale
)

View File

@ -14,7 +14,7 @@ import torch.nn as nn
from transformers import PretrainedConfig
from transformers.utils import torch_int
from vllm.attention.layer import MultiHeadAttention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.layernorm import RMSNorm
@ -214,8 +214,8 @@ class InternSdpaAttention(nn.Module):
self.projection_layer = nn.Linear(self.dummy_dim, self.embed_dim)
# Use unified MultiHeadAttention with automatic backend selection
self.attn = MultiHeadAttention(self.num_heads, self.head_dim, self.scale)
# Use unified MMEncoderAttention with automatic backend selection
self.attn = MMEncoderAttention(self.num_heads, self.head_dim, self.scale)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""x shape: (B, N, C)"""
@ -228,7 +228,7 @@ class InternSdpaAttention(nn.Module):
q = self.q_norm(q)
k = self.k_norm(k)
# Use unified MultiHeadAttention with automatic backend selection
# Use unified MMEncoderAttention with automatic backend selection
x = self.attn(q, k, v)
x = self.projection_layer(x)

View File

@ -31,7 +31,7 @@ from transformers.models.llama4.image_processing_llama4_fast import (
get_best_fit,
)
from vllm.attention.layer import MultiHeadAttention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
@ -255,7 +255,7 @@ class Llama4VisionAttention(nn.Module):
self.attention_dropout = config.attention_dropout
self.scaling = self.head_dim**-0.5
self.attn = MultiHeadAttention(
self.attn = MMEncoderAttention(
self.num_local_heads, self.head_dim, self.scaling
)

View File

@ -17,7 +17,8 @@ from transformers import BatchFeature, PretrainedConfig, ProcessorMixin, TensorT
from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import TextInput
from vllm.attention.layer import Attention, MultiHeadAttention
from vllm.attention.layer import Attention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
@ -222,7 +223,7 @@ class MultiHeadDotProductAttention(nn.Module):
)
self.scale = self.head_dim**-0.5
self.attn = MultiHeadAttention(
self.attn = MMEncoderAttention(
self.num_heads, self.head_dim, self.scale, num_kv_heads=self.num_kv_heads
)

View File

@ -16,8 +16,8 @@ from transformers import (
SiglipVisionConfig,
)
from vllm.attention.layer import MultiHeadAttention
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import divide, get_tensor_model_parallel_world_size
@ -379,7 +379,7 @@ class SiglipAttention(nn.Module):
quant_config: QuantizationConfig | None = None,
*,
prefix: str = "",
attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention],
attn_cls: type[EncoderOnlyAttention] | type[MMEncoderAttention],
) -> None:
super().__init__()
@ -481,7 +481,7 @@ class SiglipEncoderLayer(nn.Module):
quant_config: QuantizationConfig | None = None,
*,
prefix: str = "",
attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention],
attn_cls: type[EncoderOnlyAttention] | type[MMEncoderAttention],
) -> None:
super().__init__()
@ -527,7 +527,7 @@ class SiglipEncoder(nn.Module):
num_hidden_layers_override: int | None = None,
*,
prefix: str = "",
attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention],
attn_cls: type[EncoderOnlyAttention] | type[MMEncoderAttention],
) -> None:
super().__init__()
@ -700,7 +700,7 @@ class SiglipVisionTransformer(nn.Module):
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.encoder",
attn_cls=MultiHeadAttention,
attn_cls=MMEncoderAttention,
)
num_hidden_layers = config.num_hidden_layers

View File

@ -15,7 +15,7 @@ from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from transformers import BatchFeature, PretrainedConfig, TensorType
from vllm.attention.layer import MultiHeadAttention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
@ -753,8 +753,8 @@ class Step3VisionAttention(nn.Module):
disable_tp=use_data_parallel,
)
# Use unified MultiHeadAttention with automatic backend selection
self.attn = MultiHeadAttention(self.num_heads, self.head_dim, self.scale)
# Use unified MMEncoderAttention with automatic backend selection
self.attn = MMEncoderAttention(self.num_heads, self.head_dim, self.scale)
def forward(
self,
@ -767,7 +767,7 @@ class Step3VisionAttention(nn.Module):
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
# Use unified MultiHeadAttention with automatic backend selection
# Use unified MMEncoderAttention with automatic backend selection
attn_output = self.attn(q, k, v)
attn_output, _ = self.out_proj(attn_output)

View File

@ -16,9 +16,9 @@ from transformers import (
)
from transformers.models.whisper.modeling_whisper import sinusoids
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention, MultiHeadAttention
from vllm.attention.layer import Attention, AttentionType
from vllm.attention.layers.cross_attention import CrossAttention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
@ -141,7 +141,7 @@ class WhisperAudioInputs(TensorSchema):
]
class WhisperEncoderAttention(MultiHeadAttention):
class WhisperEncoderAttention(MMEncoderAttention):
"""Multi-headed attention for Whisper encoder with 2D tensor support."""
def forward(