[MM Encoder]: Migrate legacy ViT MultiHeadAttention to new MMEncoderAttention interface (#30684)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Isotr0py 2025-12-19 02:04:19 +08:00 committed by GitHub
parent 62be3670cb
commit 700a5ad6c6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 182 additions and 266 deletions

View File

@ -9,7 +9,8 @@ import torch
from tests.kernels.allclose_default import get_default_atol, get_default_rtol from tests.kernels.allclose_default import get_default_atol, get_default_rtol
from tests.kernels.utils import opcheck from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.attention.layer import Attention, MultiHeadAttention from vllm.attention.layer import Attention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.mem_utils import get_max_shared_memory_bytes from vllm.utils.mem_utils import get_max_shared_memory_bytes
@ -442,7 +443,7 @@ def ref_multi_query_kv_attention(
return torch.cat(ref_outputs, dim=0) return torch.cat(ref_outputs, dim=0)
@pytest.mark.parametrize("attention_cls", [Attention, MultiHeadAttention]) @pytest.mark.parametrize("attention_cls", [Attention, MMEncoderAttention])
def test_num_heads_not_divisble_by_num_kv_heads(attention_cls: type) -> None: def test_num_heads_not_divisble_by_num_kv_heads(attention_cls: type) -> None:
head_size = 64 head_size = 64
scale = float(1.0 / (head_size**0.5)) scale = float(1.0 / (head_size**0.5))

View File

@ -3,16 +3,17 @@
""" """
Test: Test:
* Tests for MultiHeadAttention layer * Tests for MMEncoderAttention layer
""" """
import itertools
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
import torch import torch
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import MultiHeadAttention from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.attention.selector import _cached_get_attn_backend from vllm.attention.selector import _cached_get_attn_backend
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.cpu import CpuPlatform from vllm.platforms.cpu import CpuPlatform
@ -42,35 +43,31 @@ def test_mha_attn_platform(device: str):
if device == "cpu": if device == "cpu":
with ( with (
patch("vllm.attention.layer.current_platform", CpuPlatform()),
patch("vllm.model_executor.models.vision.current_platform", CpuPlatform()), patch("vllm.model_executor.models.vision.current_platform", CpuPlatform()),
): ):
attn = MultiHeadAttention(16, 64, scale=1) attn = MMEncoderAttention(16, 64, scale=1)
assert attn.attn_backend == AttentionBackendEnum.TORCH_SDPA assert attn.attn_backend == AttentionBackendEnum.TORCH_SDPA
elif device == "hip": elif device == "hip":
with ( with (
patch("vllm.attention.layer.current_platform", RocmPlatform()),
patch("vllm.model_executor.models.vision.current_platform", RocmPlatform()), patch("vllm.model_executor.models.vision.current_platform", RocmPlatform()),
): ):
attn = MultiHeadAttention(16, 64, scale=1) attn = MMEncoderAttention(16, 64, scale=1)
assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN
else: else:
# Test CUDA with head_size=64 (divisible by 32) # Test CUDA with head_size=64 (divisible by 32)
# - should use vLLM's FlashAttention # - should use vLLM's FlashAttention
with ( with (
patch("vllm.attention.layer.current_platform", CudaPlatform()),
patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()), patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()),
): ):
attn = MultiHeadAttention(16, 64, scale=1) attn = MMEncoderAttention(16, 64, scale=1)
assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN
# Test CUDA with head_size=72 (not divisible by 32) # Test CUDA with head_size=72 (not divisible by 32)
# - should use vLLM's FlashAttention # - should use vLLM's FlashAttention
with ( with (
patch("vllm.attention.layer.current_platform", CudaPlatform()),
patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()), patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()),
): ):
attn = MultiHeadAttention(16, 72, scale=1) attn = MMEncoderAttention(16, 72, scale=1)
assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN
@ -94,6 +91,10 @@ def ref_attention(
BATCH_SIZES = [1, 16] BATCH_SIZES = [1, 16]
SEQ_LENS = [1] SEQ_LENS = [1]
VAR_SEQ_LENS = [
[2, 2],
[2, 3, 4],
]
NUM_HEADS = [1, 16] NUM_HEADS = [1, 16]
NUM_KV_HEADS = [1] NUM_KV_HEADS = [1]
HEAD_SIZES = [64, 80] HEAD_SIZES = [64, 80]
@ -130,7 +131,7 @@ def test_mha_attn_forward(
k = torch.randn(batch_size, seq_len, num_kv_heads * head_size) k = torch.randn(batch_size, seq_len, num_kv_heads * head_size)
v = torch.randn(batch_size, seq_len, num_kv_heads * head_size) v = torch.randn(batch_size, seq_len, num_kv_heads * head_size)
scale = 1.0 / head_size**0.5 scale = 1.0 / head_size**0.5
attn = MultiHeadAttention( attn = MMEncoderAttention(
num_heads, head_size, scale=scale, num_kv_heads=num_kv_heads num_heads, head_size, scale=scale, num_kv_heads=num_kv_heads
) )
output = attn(q, k, v) output = attn(q, k, v)
@ -151,3 +152,58 @@ def test_mha_attn_forward(
scale=scale, scale=scale,
).reshape(batch_size, seq_len, num_heads * head_size) ).reshape(batch_size, seq_len, num_heads * head_size)
torch.testing.assert_close(output, ref_output) torch.testing.assert_close(output, ref_output)
@pytest.mark.parametrize("var_seq_len", VAR_SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_kv_heads", NUM_KV_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_mha_attn_varlen_forward(
    var_seq_len: list[int],
    num_heads: int,
    num_kv_heads: int,
    head_size: int,
    dtype: torch.dtype,
    device: str,
):
    """Check MMEncoderAttention on packed variable-length sequences.

    All sequences in ``var_seq_len`` are packed along dim 1 of a single
    batch entry and handed to the layer together with their cumulative
    offsets. The result is compared with ``ref_attention`` applied to each
    sequence on its own and concatenated back along dim 1.
    """
    current_platform.seed_everything(0)
    torch.set_default_device(device)
    torch.set_default_dtype(dtype)

    total_tokens = sum(var_seq_len)
    # Keep the q -> k -> v sampling order so the seeded values are unchanged.
    q = torch.randn(1, total_tokens, num_heads, head_size)
    k = torch.randn(1, total_tokens, num_kv_heads, head_size)
    v = torch.randn(1, total_tokens, num_kv_heads, head_size)

    # Offsets of each sequence boundary in the packed layout: [0, l0, l0+l1, ...]
    boundaries = list(itertools.accumulate(var_seq_len, initial=0))
    cu_seqlens = torch.tensor(boundaries, dtype=torch.int32)

    scale = 1.0 / head_size**0.5
    attn = MMEncoderAttention(
        num_heads, head_size, scale=scale, num_kv_heads=num_kv_heads
    )
    output = attn(
        q,
        k,
        v,
        cu_seqlens=cu_seqlens,
        max_seqlen=torch.tensor(max(var_seq_len)),
    )

    # Expand the KV heads so the reference sees one KV head per query head.
    assert num_heads % num_kv_heads == 0
    kv_repeat = num_heads // num_kv_heads
    if kv_repeat > 1:
        k = torch.repeat_interleave(k, kv_repeat, dim=2)
        v = torch.repeat_interleave(v, kv_repeat, dim=2)

    per_sequence = zip(
        torch.split(q, var_seq_len, dim=1),
        torch.split(k, var_seq_len, dim=1),
        torch.split(v, var_seq_len, dim=1),
    )
    ref_output = torch.cat(
        [ref_attention(q_i, k_i, v_i, scale=scale) for q_i, k_i, v_i in per_sequence],
        dim=1,
    )
    torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2)

View File

@ -3,7 +3,7 @@
""" """
Test: Test:
* Tests for MultiHeadAttention layer * Tests for MMEncoderAttention layer
""" """
import pytest import pytest
@ -12,7 +12,7 @@ import torch_xla
import torch_xla.core import torch_xla.core
import torch_xla.core.xla_model import torch_xla.core.xla_model
from vllm.attention.layer import MultiHeadAttention from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.attention.selector import _cached_get_attn_backend from vllm.attention.selector import _cached_get_attn_backend
from vllm.platforms import current_platform from vllm.platforms import current_platform
@ -69,7 +69,7 @@ def test_mha_attn_forward(
k = torch.randn(batch_size, seq_len, num_kv_heads * head_size, device=device) k = torch.randn(batch_size, seq_len, num_kv_heads * head_size, device=device)
v = torch.randn(batch_size, seq_len, num_kv_heads * head_size, device=device) v = torch.randn(batch_size, seq_len, num_kv_heads * head_size, device=device)
scale = 1.0 / head_size**0.5 scale = 1.0 / head_size**0.5
attn = MultiHeadAttention( attn = MMEncoderAttention(
num_heads, head_size, scale=scale, num_kv_heads=num_kv_heads num_heads, head_size, scale=scale, num_kv_heads=num_kv_heads
) )
output = attn(q, k, v) output = attn(q, k, v)

View File

@ -2,12 +2,10 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer.""" """Attention layer."""
import functools
from typing import cast from typing import cast
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
import vllm.envs as envs import vllm.envs as envs
from vllm.attention.backends.abstract import ( from vllm.attention.backends.abstract import (
@ -16,13 +14,10 @@ from vllm.attention.backends.abstract import (
MLAAttentionImpl, MLAAttentionImpl,
) )
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layers.mm_encoder_attention import maybe_get_vit_flash_attn_backend
from vllm.attention.selector import get_attn_backend from vllm.attention.selector import get_attn_backend
from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
from vllm.attention.utils.kv_transfer_utils import maybe_transfer_kv_layer from vllm.attention.utils.kv_transfer_utils import maybe_transfer_kv_layer
from vllm.config import CacheConfig, get_current_vllm_config from vllm.config import CacheConfig, get_current_vllm_config
from vllm.config.multimodal import MultiModalConfig
from vllm.config.vllm import VllmConfig from vllm.config.vllm import VllmConfig
from vllm.forward_context import ForwardContext, get_forward_context from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
@ -36,7 +31,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import ( from vllm.utils.torch_utils import (
direct_register_custom_op, direct_register_custom_op,
@ -412,132 +406,6 @@ class Attention(nn.Module, AttentionLayerBase):
) )
class MultiHeadAttention(nn.Module):
    """Cache-free multi-head attention layer, used for ViT.

    The attention backend is resolved once in ``__init__`` through
    ``get_vit_attn_backend``. ``forward`` then dispatches to a flash-attention
    varlen kernel, ``F.scaled_dot_product_attention`` or the torch_xla Pallas
    ``flash_attention`` kernel depending on that backend.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int | None = None,
        # Only stored as ``layer_name``; accepted so this class can be
        # swapped with Attention without changing call sites.
        prefix: str = "",
        multimodal_config: MultiModalConfig | None = None,
    ) -> None:
        """Store head geometry and select the attention backend.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each head.
            scale: Softmax scale forwarded to the attention kernel.
            num_kv_heads: Number of key/value heads; defaults to ``num_heads``.
            prefix: Stored as ``layer_name``.
            multimodal_config: When given, its ``mm_encoder_attn_backend`` is
                passed to ``get_vit_attn_backend`` as the backend override.
        """
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
        self.layer_name = prefix

        assert self.num_heads % self.num_kv_heads == 0, (
            f"num_heads ({self.num_heads}) is not "
            f"divisible by num_kv_heads ({self.num_kv_heads})"
        )
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        # The default dtype at model-initialization time is the model's
        # weight and activation dtype.
        model_dtype = torch.get_default_dtype()

        backend_override = (
            None
            if multimodal_config is None
            else multimodal_config.mm_encoder_attn_backend
        )
        self.attn_backend = get_vit_attn_backend(
            head_size=head_size,
            dtype=model_dtype,
            attn_backend_override=backend_override,
        )
        self._flash_attn_varlen_func = maybe_get_vit_flash_attn_backend(
            self.attn_backend,
        )
        self.is_flash_attn_backend = self.attn_backend in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
        }

        self.fa_version = None
        uses_cuda_flash_attn = (
            self.attn_backend == AttentionBackendEnum.FLASH_ATTN
            and current_platform.is_cuda()
        )
        if uses_cuda_flash_attn:
            # Bind the detected FlashAttention version into the kernel call.
            self.fa_version = get_flash_attn_version()
            assert self._flash_attn_varlen_func is not None
            self._flash_attn_varlen_func = functools.partial(
                self._flash_attn_varlen_func, fa_version=self.fa_version
            )

        logger.info_once(
            f"Using {self.attn_backend} for MultiHeadAttention in multimodal encoder."
        )

    def _run_varlen_flash_attn(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        batch: int,
        num_q: int,
        num_kv: int,
    ) -> torch.Tensor:
        """Run the varlen flash-attention kernel on equal-length sequences.

        The batch and sequence dims are flattened and described to the kernel
        by evenly spaced cumulative sequence offsets.
        """
        assert self._flash_attn_varlen_func is not None
        q_offsets = torch.arange(
            0, (batch + 1) * num_q, step=num_q, dtype=torch.int32, device=query.device
        )
        kv_offsets = torch.arange(
            0, (batch + 1) * num_kv, step=num_kv, dtype=torch.int32, device=key.device
        )
        return self._flash_attn_varlen_func(
            query.flatten(0, 1),
            key.flatten(0, 1),
            value.flatten(0, 1),
            cu_seqlens_q=q_offsets,
            cu_seqlens_k=kv_offsets,
            max_seqlen_q=num_q,
            max_seqlen_k=num_kv,
            softmax_scale=self.scale,
        )

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """Input shape:
        (batch_size x seq_len x hidden_size) or
        (batch_size x seq_len x num_heads x head_size)

        Returns the attention output reshaped to (batch_size, q_len, -1).
        Raises NotImplementedError for backends other than the flash-attention
        ones, TORCH_SDPA and PALLAS.
        """
        batch, num_q = query.size()[:2]
        num_kv = key.size(1)

        query = query.view(batch, num_q, self.num_heads, self.head_size)
        key, value = (
            t.view(batch, num_kv, self.num_kv_heads, self.head_size)
            for t in (key, value)
        )

        kv_repeat = self.num_queries_per_kv
        if kv_repeat > 1:
            # MQA / GQA: replicate KV heads up to the number of query heads.
            key, value = (
                torch.repeat_interleave(t, kv_repeat, dim=2) for t in (key, value)
            )

        if self.is_flash_attn_backend:
            out = self._run_varlen_flash_attn(
                query, key, value, batch, num_q, num_kv
            )
        elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
            # These kernels take (batch, heads, seq, head_size).
            q_t, k_t, v_t = (t.transpose(1, 2) for t in (query, key, value))
            out = F.scaled_dot_product_attention(q_t, k_t, v_t, scale=self.scale)
            out = out.transpose(1, 2)
        elif self.attn_backend == AttentionBackendEnum.PALLAS:
            q_t, k_t, v_t = (t.transpose(1, 2) for t in (query, key, value))
            from torch_xla.experimental.custom_kernel import flash_attention

            out = flash_attention(q_t, k_t, v_t, sm_scale=self.scale)
            out = out.transpose(1, 2)
        else:
            # No ViT attention path exists for this backend.
            raise NotImplementedError(
                f"ViT attention hasn't supported {self.attn_backend} backend yet."
            )

        return out.reshape(batch, num_q, -1)
class MLAAttention(nn.Module, AttentionLayerBase): class MLAAttention(nn.Module, AttentionLayerBase):
"""Multi-Head Latent Attention layer. """Multi-Head Latent Attention layer.

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch import torch
@ -19,27 +18,6 @@ from vllm.model_executor.models.vision import get_vit_attn_backend
logger = init_logger(__name__) logger = init_logger(__name__)
def maybe_get_vit_flash_attn_backend(
    attn_backend: AttentionBackendEnum | None,
) -> Callable | None:
    """Return the varlen flash-attention function for ``attn_backend``.

    No backend overriding happens here; the given backend is taken as-is.
    The kernel modules are imported lazily, only for the backend that is
    actually selected.

    Returns:
        ``flash_attn_varlen_func`` from ``vllm.attention.utils.fa_utils`` for
        FLASH_ATTN, the one from ``aiter`` for ROCM_AITER_FA, and ``None`` for
        every other backend (for example TORCH_SDPA).
    """
    if attn_backend == AttentionBackendEnum.FLASH_ATTN:
        from vllm.attention.utils.fa_utils import (
            flash_attn_varlen_func as varlen_fn,
        )

        return varlen_fn

    if attn_backend == AttentionBackendEnum.ROCM_AITER_FA:
        from aiter import flash_attn_varlen_func as varlen_fn

        return varlen_fn

    # Every other backend has no flash-attention entry point.
    return None
@CustomOp.register("mm_encoder_attn") @CustomOp.register("mm_encoder_attn")
class MMEncoderAttention(CustomOp): class MMEncoderAttention(CustomOp):
"""Multi-headed attention without any cache, used for multimodal encoder.""" """Multi-headed attention without any cache, used for multimodal encoder."""
@ -98,21 +76,17 @@ class MMEncoderAttention(CustomOp):
AttentionBackendEnum.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
} }
self.flash_attn_varlen_func = maybe_get_vit_flash_attn_backend( self._fa_version = (
self.attn_backend, get_flash_attn_version() if self.is_flash_attn_backend else None
) )
if self.is_flash_attn_backend:
assert self.flash_attn_varlen_func is not None
self._fa_version = get_flash_attn_version()
logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.") logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.")
@classmethod @classmethod
def enabled(cls) -> bool: def enabled(cls) -> bool:
return True return True
def reshape_qkv_to_4d( def maybe_reshape_qkv_to_4d(
self, self,
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
@ -136,30 +110,6 @@ class MMEncoderAttention(CustomOp):
return query, key, value return query, key, value
def reshape_qkv_to_3d(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    bsz: int,
    q_len: int,
    kv_len: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Reshape query, key, value to 3D tensors:
    (batch_size * seq_len, num_heads, head_size)

    Key and value are first shaped with ``num_kv_heads`` heads and then, when
    ``num_queries_per_kv`` is greater than 1, repeated along the head dim so
    they end up with ``num_heads`` heads as well.

    Args:
        query: Tensor holding ``bsz * q_len * num_heads * head_size`` elements.
        key: Tensor holding ``bsz * kv_len * num_kv_heads * head_size`` elements.
        value: Same element count as ``key``.
        bsz: Batch size.
        q_len: Query sequence length.
        kv_len: Key/value sequence length.

    Returns:
        The reshaped ``(query, key, value)`` tuple.
    """
    # reshape() instead of view(): identical result for contiguous inputs,
    # but also accepts non-contiguous ones (e.g. transposed tensors) by
    # copying when the strides do not allow a plain view.
    query = query.reshape(bsz * q_len, self.num_heads, self.head_size)
    key = key.reshape(bsz * kv_len, self.num_kv_heads, self.head_size)
    value = value.reshape(bsz * kv_len, self.num_kv_heads, self.head_size)

    if (num_repeat := self.num_queries_per_kv) > 1:
        # Handle MQA and GQA
        key = torch.repeat_interleave(key, num_repeat, dim=1)
        value = torch.repeat_interleave(value, num_repeat, dim=1)

    return query, key, value
def _forward_sdpa( def _forward_sdpa(
self, self,
query: torch.Tensor, query: torch.Tensor,
@ -167,13 +117,15 @@ class MMEncoderAttention(CustomOp):
value: torch.Tensor, value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None, cu_seqlens: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
# TODO(Isotr0py): Migrate MultiHeadAttention """Input shape:
assert cu_seqlens is not None (batch_size x seq_len x hidden_size) or
(batch_size x seq_len x num_heads x head_size)
"""
bsz, q_len = query.size()[:2] bsz, q_len = query.size()[:2]
kv_len = key.size(1) kv_len = key.size(1)
is_reshaped = query.dim() != 4
query, key, value = self.reshape_qkv_to_4d( query, key, value = self.maybe_reshape_qkv_to_4d(
query, key, value, bsz, q_len, kv_len query, key, value, bsz, q_len, kv_len
) )
@ -183,6 +135,8 @@ class MMEncoderAttention(CustomOp):
v=value, v=value,
cu_seqlens=cu_seqlens, cu_seqlens=cu_seqlens,
) )
if is_reshaped:
output = output.view(bsz, q_len, -1)
return output return output
def _forward_fa( def _forward_fa(
@ -193,13 +147,21 @@ class MMEncoderAttention(CustomOp):
cu_seqlens: torch.Tensor | None = None, cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
) -> torch.Tensor: ) -> torch.Tensor:
assert self.flash_attn_varlen_func is not None, ( """Input shape:
"Flash attention function is not set." (batch_size x seq_len x hidden_size) or
) (batch_size x seq_len x num_heads x head_size)
# # TODO(Isotr0py): Migrate MultiHeadAttention """
assert cu_seqlens is not None and max_seqlen is not None assert (cu_seqlens is not None and max_seqlen is not None) or (
cu_seqlens is None and max_seqlen is None
), "cu_seqlens and max_seqlen should be both set or both None."
bsz = query.shape[0] bsz, q_len = query.size()[:2]
kv_len = key.size(1)
is_reshaped = query.dim() != 4
query, key, value = self.maybe_reshape_qkv_to_4d(
query, key, value, bsz, q_len, kv_len
)
output = vit_flash_attn_wrapper( output = vit_flash_attn_wrapper(
q=query, q=query,
@ -211,6 +173,8 @@ class MMEncoderAttention(CustomOp):
is_rocm_aiter=(self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA), is_rocm_aiter=(self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA),
fa_version=self._fa_version, fa_version=self._fa_version,
) )
if is_reshaped:
output = output.view(bsz, q_len, -1)
return output return output
def forward_native( def forward_native(

View File

@ -24,11 +24,11 @@ def flash_attn_maxseqlen_wrapper(
q: torch.Tensor, q: torch.Tensor,
k: torch.Tensor, k: torch.Tensor,
v: torch.Tensor, v: torch.Tensor,
cu_seqlens: torch.Tensor,
max_seqlen: torch.Tensor,
batch_size: int, batch_size: int,
is_rocm_aiter: bool, is_rocm_aiter: bool,
fa_version: int | None, fa_version: int | None,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
kwargs = {} kwargs = {}
if is_rocm_aiter: if is_rocm_aiter:
@ -38,6 +38,14 @@ def flash_attn_maxseqlen_wrapper(
if not current_platform.is_rocm() and fa_version is not None: if not current_platform.is_rocm() and fa_version is not None:
kwargs["fa_version"] = fa_version kwargs["fa_version"] = fa_version
q_len = q.size(1)
if cu_seqlens is None:
cu_seqlens = torch.arange(
0, (batch_size + 1) * q_len, step=q_len, dtype=torch.int32, device=q.device
)
max_seqlen = q_len if max_seqlen is None else max_seqlen.item()
q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
output = flash_attn_varlen_func( output = flash_attn_varlen_func(
q, q,
@ -45,8 +53,8 @@ def flash_attn_maxseqlen_wrapper(
v, v,
cu_seqlens_q=cu_seqlens, cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens, cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen.item(), max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen.item(), max_seqlen_k=max_seqlen,
dropout_p=0.0, dropout_p=0.0,
causal=False, causal=False,
**kwargs, **kwargs,
@ -79,24 +87,42 @@ def vit_flash_attn_wrapper(
q: torch.Tensor, q: torch.Tensor,
k: torch.Tensor, k: torch.Tensor,
v: torch.Tensor, v: torch.Tensor,
cu_seqlens: torch.Tensor,
max_seqlen: torch.Tensor,
batch_size: int, batch_size: int,
is_rocm_aiter: bool, is_rocm_aiter: bool,
fa_version: int | None, fa_version: int | None,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
return torch.ops.vllm.flash_attn_maxseqlen_wrapper( return torch.ops.vllm.flash_attn_maxseqlen_wrapper(
q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter, fa_version q,
k,
v,
batch_size,
is_rocm_aiter,
fa_version,
cu_seqlens,
max_seqlen,
) )
def apply_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
"""
Input shape:
(batch_size x seq_len x num_heads x head_size)
"""
q, k, v = (einops.rearrange(x, "b s h d -> b h s d") for x in [q, k, v])
output = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)
output = einops.rearrange(output, "b h s d -> b s h d ")
return output
# TODO: Once we have a torch 2.10, we can use tensor slices # TODO: Once we have a torch 2.10, we can use tensor slices
# so we won't need to wrap this in custom ops # so we won't need to wrap this in custom ops
def torch_sdpa_wrapper( def torch_sdpa_wrapper(
q: torch.Tensor, q: torch.Tensor,
k: torch.Tensor, k: torch.Tensor,
v: torch.Tensor, v: torch.Tensor,
cu_seqlens: torch.Tensor, cu_seqlens: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
# Never remove the contiguous logic for ROCm # Never remove the contiguous logic for ROCm
# Without it, hallucinations occur with the backend # Without it, hallucinations occur with the backend
@ -105,6 +131,9 @@ def torch_sdpa_wrapper(
k = k.contiguous() k = k.contiguous()
v = v.contiguous() v = v.contiguous()
if cu_seqlens is None:
return apply_sdpa(q, k, v)
outputs = [] outputs = []
lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
@ -112,11 +141,7 @@ def torch_sdpa_wrapper(
k_chunks = torch.split(k, lens, dim=1) k_chunks = torch.split(k, lens, dim=1)
v_chunks = torch.split(v, lens, dim=1) v_chunks = torch.split(v, lens, dim=1)
for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks): for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks):
q_i, k_i, v_i = ( output_i = apply_sdpa(q_i, k_i, v_i)
einops.rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
)
output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
output_i = einops.rearrange(output_i, "b h s d -> b s h d ")
outputs.append(output_i) outputs.append(output_i)
context_layer = torch.cat(outputs, dim=1) context_layer = torch.cat(outputs, dim=1)
return context_layer return context_layer
@ -142,6 +167,6 @@ def vit_torch_sdpa_wrapper(
q: torch.Tensor, q: torch.Tensor,
k: torch.Tensor, k: torch.Tensor,
v: torch.Tensor, v: torch.Tensor,
cu_seqlens: torch.Tensor, cu_seqlens: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
return torch.ops.vllm.torch_sdpa_wrapper(q, k, v, cu_seqlens) return torch.ops.vllm.torch_sdpa_wrapper(q, k, v, cu_seqlens)

View File

@ -8,7 +8,7 @@ from collections.abc import Iterable
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.attention.layer import MultiHeadAttention from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.utils import divide from vllm.distributed.utils import divide
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
@ -126,7 +126,7 @@ class AIMv2Attention(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size) self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
self.attn = MultiHeadAttention( self.attn = MMEncoderAttention(
self.num_heads_per_partition, self.head_dim, self.scale self.num_heads_per_partition, self.head_dim, self.scale
) )

View File

@ -9,7 +9,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from transformers import Blip2VisionConfig, BlipVisionConfig from transformers import Blip2VisionConfig, BlipVisionConfig
from vllm.attention.layer import MultiHeadAttention from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.conv import Conv2dLayer from vllm.model_executor.layers.conv import Conv2dLayer
@ -122,7 +122,7 @@ class BlipAttention(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size) self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
self.attn = MultiHeadAttention( self.attn = MMEncoderAttention(
self.num_heads_per_partition, self.head_dim, self.scale self.num_heads_per_partition, self.head_dim, self.scale
) )

View File

@ -14,7 +14,8 @@ from transformers import (
CLIPVisionConfig, CLIPVisionConfig,
) )
from vllm.attention.layer import Attention, MultiHeadAttention from vllm.attention.layer import Attention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed import divide, get_tensor_model_parallel_world_size
@ -354,7 +355,7 @@ class CLIPAttention(nn.Module):
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
*, *,
prefix: str = "", prefix: str = "",
attn_cls: type[Attention] | type[MultiHeadAttention], attn_cls: type[Attention] | type[MMEncoderAttention],
) -> None: ) -> None:
super().__init__() super().__init__()
@ -449,7 +450,7 @@ class CLIPEncoderLayer(nn.Module):
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
*, *,
prefix: str = "", prefix: str = "",
attn_cls: type[Attention] | type[MultiHeadAttention], attn_cls: type[Attention] | type[MMEncoderAttention],
) -> None: ) -> None:
super().__init__() super().__init__()
self.self_attn = CLIPAttention( self.self_attn = CLIPAttention(
@ -493,7 +494,7 @@ class CLIPEncoder(nn.Module):
num_hidden_layers_override: int | None = None, num_hidden_layers_override: int | None = None,
*, *,
prefix: str = "", prefix: str = "",
attn_cls: type[Attention] | type[MultiHeadAttention], attn_cls: type[Attention] | type[MMEncoderAttention],
) -> None: ) -> None:
super().__init__() super().__init__()
@ -638,7 +639,7 @@ class CLIPVisionTransformer(nn.Module):
quant_config=quant_config, quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override, num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.encoder", prefix=f"{prefix}.encoder",
attn_cls=MultiHeadAttention, attn_cls=MMEncoderAttention,
) )
num_hidden_layers = config.num_hidden_layers num_hidden_layers = config.num_hidden_layers

View File

@ -18,7 +18,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from transformers import CLIPVisionConfig from transformers import CLIPVisionConfig
from vllm.attention.layer import MultiHeadAttention from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.model_executor.layers.conv import Conv2dLayer from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@ -628,7 +628,7 @@ class DeepCLIPVisionTransformer(nn.Module):
quant_config=quant_config, quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override, num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.encoder", prefix=f"{prefix}.encoder",
attn_cls=MultiHeadAttention, attn_cls=MMEncoderAttention,
) )
num_hidden_layers = config.num_hidden_layers num_hidden_layers = config.num_hidden_layers

View File

@ -19,7 +19,7 @@ from transformers import BatchFeature, PreTrainedTokenizer, TensorType
from transformers.image_utils import ImageInput from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import TextInput from transformers.tokenization_utils_base import TextInput
from vllm.attention.layer import MultiHeadAttention from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
@ -135,7 +135,7 @@ class EVA2CLIPAttention(nn.Module):
prefix=f"{prefix}.dense", prefix=f"{prefix}.dense",
) )
self.attn = MultiHeadAttention( self.attn = MMEncoderAttention(
self.num_heads_per_rank, self.head_dim, self.scale self.num_heads_per_rank, self.head_dim, self.scale
) )
self.output_dropout = torch.nn.Dropout(config.dropout_prob) self.output_dropout = torch.nn.Dropout(config.dropout_prob)

View File

@ -34,7 +34,7 @@ import torch.nn.functional as F
from transformers import BatchFeature from transformers import BatchFeature
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import MultiHeadAttention from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import MultiModalConfig, VllmConfig from vllm.config import MultiModalConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import parallel_state from vllm.distributed import parallel_state
@ -232,7 +232,7 @@ class HunYuanVisionAttention(nn.Module):
) )
self.scale = self.hidden_size_per_attention_head**-0.5 self.scale = self.hidden_size_per_attention_head**-0.5
self.attn = MultiHeadAttention( self.attn = MMEncoderAttention(
self.num_attention_heads_per_partition, self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head, self.hidden_size_per_attention_head,
self.scale, self.scale,

View File

@ -27,7 +27,7 @@ from transformers.models.idefics2.configuration_idefics2 import (
Idefics2VisionConfig, Idefics2VisionConfig,
) )
from vllm.attention.layer import MultiHeadAttention from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.conv import Conv2dLayer from vllm.model_executor.layers.conv import Conv2dLayer
@ -161,8 +161,8 @@ class Idefics2VisionAttention(nn.Module):
prefix=f"{prefix}.out_proj", prefix=f"{prefix}.out_proj",
disable_tp=use_data_parallel, disable_tp=use_data_parallel,
) )
# Use unified MultiHeadAttention with Flash Attention support # Use unified MMEncoderAttention with Flash Attention support
self.attn = MultiHeadAttention( self.attn = MMEncoderAttention(
self.num_heads_per_partition, self.head_dim, self.scale self.num_heads_per_partition, self.head_dim, self.scale
) )
@ -175,7 +175,7 @@ class Idefics2VisionAttention(nn.Module):
) # batch_size, q_len, 3 * num_heads_per_partition * head_dim ) # batch_size, q_len, 3 * num_heads_per_partition * head_dim
query_states, key_states, value_states = qkv.chunk(3, dim=-1) query_states, key_states, value_states = qkv.chunk(3, dim=-1)
# Use unified MultiHeadAttention implementation # Use unified MMEncoderAttention implementation
out = self.attn(query_states, key_states, value_states) out = self.attn(query_states, key_states, value_states)
attn_output, _ = self.out_proj(out) attn_output, _ = self.out_proj(out)
return attn_output return attn_output

View File

@ -15,7 +15,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention.layer import MultiHeadAttention from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.distributed import ( from vllm.distributed import (
divide, divide,
get_tensor_model_parallel_rank, get_tensor_model_parallel_rank,
@ -207,7 +207,7 @@ class InternParallelAttention(nn.Module):
disable_tp=use_data_parallel, disable_tp=use_data_parallel,
) )
self.attn = MultiHeadAttention( self.attn = MMEncoderAttention(
self.num_heads_per_partition, self.head_dim, self.scale self.num_heads_per_partition, self.head_dim, self.scale
) )

View File

@ -14,7 +14,7 @@ import torch.nn as nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from transformers.utils import torch_int from transformers.utils import torch_int
from vllm.attention.layer import MultiHeadAttention from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.conv import Conv2dLayer from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
@ -214,8 +214,8 @@ class InternSdpaAttention(nn.Module):
self.projection_layer = nn.Linear(self.dummy_dim, self.embed_dim) self.projection_layer = nn.Linear(self.dummy_dim, self.embed_dim)
# Use unified MultiHeadAttention with automatic backend selection # Use unified MMEncoderAttention with automatic backend selection
self.attn = MultiHeadAttention(self.num_heads, self.head_dim, self.scale) self.attn = MMEncoderAttention(self.num_heads, self.head_dim, self.scale)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
"""x shape: (B, N, C)""" """x shape: (B, N, C)"""
@ -228,7 +228,7 @@ class InternSdpaAttention(nn.Module):
q = self.q_norm(q) q = self.q_norm(q)
k = self.k_norm(k) k = self.k_norm(k)
# Use unified MultiHeadAttention with automatic backend selection # Use unified MMEncoderAttention with automatic backend selection
x = self.attn(q, k, v) x = self.attn(q, k, v)
x = self.projection_layer(x) x = self.projection_layer(x)

View File

@ -31,7 +31,7 @@ from transformers.models.llama4.image_processing_llama4_fast import (
get_best_fit, get_best_fit,
) )
from vllm.attention.layer import MultiHeadAttention from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
@ -255,7 +255,7 @@ class Llama4VisionAttention(nn.Module):
self.attention_dropout = config.attention_dropout self.attention_dropout = config.attention_dropout
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
self.attn = MultiHeadAttention( self.attn = MMEncoderAttention(
self.num_local_heads, self.head_dim, self.scaling self.num_local_heads, self.head_dim, self.scaling
) )

View File

@ -17,7 +17,8 @@ from transformers import BatchFeature, PretrainedConfig, ProcessorMixin, TensorT
from transformers.image_utils import ImageInput from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import TextInput from transformers.tokenization_utils_base import TextInput
from vllm.attention.layer import Attention, MultiHeadAttention from vllm.attention.layer import Attention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
@ -222,7 +223,7 @@ class MultiHeadDotProductAttention(nn.Module):
) )
self.scale = self.head_dim**-0.5 self.scale = self.head_dim**-0.5
self.attn = MultiHeadAttention( self.attn = MMEncoderAttention(
self.num_heads, self.head_dim, self.scale, num_kv_heads=self.num_kv_heads self.num_heads, self.head_dim, self.scale, num_kv_heads=self.num_kv_heads
) )

View File

@ -16,8 +16,8 @@ from transformers import (
SiglipVisionConfig, SiglipVisionConfig,
) )
from vllm.attention.layer import MultiHeadAttention
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed import divide, get_tensor_model_parallel_world_size
@ -379,7 +379,7 @@ class SiglipAttention(nn.Module):
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
*, *,
prefix: str = "", prefix: str = "",
attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention], attn_cls: type[EncoderOnlyAttention] | type[MMEncoderAttention],
) -> None: ) -> None:
super().__init__() super().__init__()
@ -481,7 +481,7 @@ class SiglipEncoderLayer(nn.Module):
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
*, *,
prefix: str = "", prefix: str = "",
attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention], attn_cls: type[EncoderOnlyAttention] | type[MMEncoderAttention],
) -> None: ) -> None:
super().__init__() super().__init__()
@ -527,7 +527,7 @@ class SiglipEncoder(nn.Module):
num_hidden_layers_override: int | None = None, num_hidden_layers_override: int | None = None,
*, *,
prefix: str = "", prefix: str = "",
attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention], attn_cls: type[EncoderOnlyAttention] | type[MMEncoderAttention],
) -> None: ) -> None:
super().__init__() super().__init__()
@ -700,7 +700,7 @@ class SiglipVisionTransformer(nn.Module):
quant_config=quant_config, quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override, num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.encoder", prefix=f"{prefix}.encoder",
attn_cls=MultiHeadAttention, attn_cls=MMEncoderAttention,
) )
num_hidden_layers = config.num_hidden_layers num_hidden_layers = config.num_hidden_layers

View File

@ -15,7 +15,7 @@ from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from transformers import BatchFeature, PretrainedConfig, TensorType from transformers import BatchFeature, PretrainedConfig, TensorType
from vllm.attention.layer import MultiHeadAttention from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
@ -753,8 +753,8 @@ class Step3VisionAttention(nn.Module):
disable_tp=use_data_parallel, disable_tp=use_data_parallel,
) )
# Use unified MultiHeadAttention with automatic backend selection # Use unified MMEncoderAttention with automatic backend selection
self.attn = MultiHeadAttention(self.num_heads, self.head_dim, self.scale) self.attn = MMEncoderAttention(self.num_heads, self.head_dim, self.scale)
def forward( def forward(
self, self,
@ -767,7 +767,7 @@ class Step3VisionAttention(nn.Module):
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
# Use unified MultiHeadAttention with automatic backend selection # Use unified MMEncoderAttention with automatic backend selection
attn_output = self.attn(q, k, v) attn_output = self.attn(q, k, v)
attn_output, _ = self.out_proj(attn_output) attn_output, _ = self.out_proj(attn_output)

View File

@ -16,9 +16,9 @@ from transformers import (
) )
from transformers.models.whisper.modeling_whisper import sinusoids from transformers.models.whisper.modeling_whisper import sinusoids
from vllm.attention.backends.abstract import AttentionType from vllm.attention.layer import Attention, AttentionType
from vllm.attention.layer import Attention, MultiHeadAttention
from vllm.attention.layers.cross_attention import CrossAttention from vllm.attention.layers.cross_attention import CrossAttention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
@ -141,7 +141,7 @@ class WhisperAudioInputs(TensorSchema):
] ]
class WhisperEncoderAttention(MultiHeadAttention): class WhisperEncoderAttention(MMEncoderAttention):
"""Multi-headed attention for Whisper encoder with 2D tensor support.""" """Multi-headed attention for Whisper encoder with 2D tensor support."""
def forward( def forward(