From 700a5ad6c616358f42db7d9b55e8bc9caa140ca5 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Fri, 19 Dec 2025 02:04:19 +0800 Subject: [PATCH] [MM Encoder]: Migrate legacy ViT `MultiHeadAttention` to new `MMEncoderAttention` interface (#30684) Signed-off-by: Isotr0py --- tests/kernels/attention/test_attention.py | 5 +- tests/kernels/attention/test_mha_attn.py | 78 +++++++++-- tests/v1/tpu/test_mha_attn.py | 6 +- vllm/attention/layer.py | 132 ------------------ vllm/attention/layers/mm_encoder_attention.py | 90 ++++-------- vllm/attention/ops/vit_attn_wrappers.py | 53 +++++-- vllm/model_executor/models/aimv2.py | 4 +- vllm/model_executor/models/blip.py | 4 +- vllm/model_executor/models/clip.py | 11 +- vllm/model_executor/models/deepencoder.py | 4 +- vllm/model_executor/models/glm4v.py | 4 +- vllm/model_executor/models/hunyuan_vision.py | 4 +- .../models/idefics2_vision_model.py | 8 +- vllm/model_executor/models/intern_vit.py | 4 +- vllm/model_executor/models/interns1_vit.py | 8 +- vllm/model_executor/models/mllama4.py | 4 +- vllm/model_executor/models/molmo.py | 5 +- vllm/model_executor/models/siglip.py | 10 +- vllm/model_executor/models/step3_vl.py | 8 +- vllm/model_executor/models/whisper.py | 6 +- 20 files changed, 182 insertions(+), 266 deletions(-) diff --git a/tests/kernels/attention/test_attention.py b/tests/kernels/attention/test_attention.py index 1a7d5ce0ddc1e..96bdcf16d5689 100644 --- a/tests/kernels/attention/test_attention.py +++ b/tests/kernels/attention/test_attention.py @@ -9,7 +9,8 @@ import torch from tests.kernels.allclose_default import get_default_atol, get_default_rtol from tests.kernels.utils import opcheck from vllm import _custom_ops as ops -from vllm.attention.layer import Attention, MultiHeadAttention +from vllm.attention.layer import Attention +from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention from vllm.platforms import current_platform from vllm.utils.mem_utils import get_max_shared_memory_bytes @@ -442,7 +443,7 @@ def ref_multi_query_kv_attention( return torch.cat(ref_outputs, dim=0) -@pytest.mark.parametrize("attention_cls", [Attention, MultiHeadAttention]) +@pytest.mark.parametrize("attention_cls", [Attention, MMEncoderAttention]) def test_num_heads_not_divisble_by_num_kv_heads(attention_cls: type) -> None: head_size = 64 scale = float(1.0 / (head_size**0.5)) diff --git a/tests/kernels/attention/test_mha_attn.py b/tests/kernels/attention/test_mha_attn.py index 639abdf6f0487..7405e4d41da94 100644 --- a/tests/kernels/attention/test_mha_attn.py +++ b/tests/kernels/attention/test_mha_attn.py @@ -3,16 +3,17 @@ """ Test: -* Tests for MultiHeadAttention layer +* Tests for MMEncoderAttention layer """ +import itertools from unittest.mock import patch import pytest import torch from vllm.attention.backends.registry import AttentionBackendEnum -from vllm.attention.layer import MultiHeadAttention +from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention from vllm.attention.selector import _cached_get_attn_backend from vllm.platforms import current_platform from vllm.platforms.cpu import CpuPlatform @@ -42,35 +43,31 @@ def test_mha_attn_platform(device: str): if device == "cpu": with ( - patch("vllm.attention.layer.current_platform", CpuPlatform()), patch("vllm.model_executor.models.vision.current_platform", CpuPlatform()), ): - attn = MultiHeadAttention(16, 64, scale=1) + attn = MMEncoderAttention(16, 64, scale=1) assert attn.attn_backend == AttentionBackendEnum.TORCH_SDPA elif device == "hip": with ( - 
patch("vllm.attention.layer.current_platform", RocmPlatform()), patch("vllm.model_executor.models.vision.current_platform", RocmPlatform()), ): - attn = MultiHeadAttention(16, 64, scale=1) + attn = MMEncoderAttention(16, 64, scale=1) assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN else: # Test CUDA with head_size=64 (divisible by 32) # - should use vLLM's FlashAttention with ( - patch("vllm.attention.layer.current_platform", CudaPlatform()), patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()), ): - attn = MultiHeadAttention(16, 64, scale=1) + attn = MMEncoderAttention(16, 64, scale=1) assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN # Test CUDA with head_size=72 (not divisible by 32) # - should use vLLM's FlashAttention with ( - patch("vllm.attention.layer.current_platform", CudaPlatform()), patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()), ): - attn = MultiHeadAttention(16, 72, scale=1) + attn = MMEncoderAttention(16, 72, scale=1) assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN @@ -94,6 +91,10 @@ def ref_attention( BATCH_SIZES = [1, 16] SEQ_LENS = [1] +VAR_SEQ_LENS = [ + [2, 2], + [2, 3, 4], +] NUM_HEADS = [1, 16] NUM_KV_HEADS = [1] HEAD_SIZES = [64, 80] @@ -130,7 +131,7 @@ def test_mha_attn_forward( k = torch.randn(batch_size, seq_len, num_kv_heads * head_size) v = torch.randn(batch_size, seq_len, num_kv_heads * head_size) scale = 1.0 / head_size**0.5 - attn = MultiHeadAttention( + attn = MMEncoderAttention( num_heads, head_size, scale=scale, num_kv_heads=num_kv_heads ) output = attn(q, k, v) @@ -151,3 +152,58 @@ def test_mha_attn_forward( scale=scale, ).reshape(batch_size, seq_len, num_heads * head_size) torch.testing.assert_close(output, ref_output) + + +@pytest.mark.parametrize("var_seq_len", VAR_SEQ_LENS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("num_kv_heads", NUM_KV_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_mha_attn_varlen_forward( + var_seq_len: list[int], + num_heads: int, + num_kv_heads: int, + head_size: int, + dtype: torch.dtype, + device: str, +): + current_platform.seed_everything(0) + torch.set_default_device(device) + torch.set_default_dtype(dtype) + + q = torch.randn(1, sum(var_seq_len), num_heads, head_size) + k = torch.randn(1, sum(var_seq_len), num_kv_heads, head_size) + v = torch.randn(1, sum(var_seq_len), num_kv_heads, head_size) + cu_seqlens = torch.tensor( + [0] + list(itertools.accumulate(var_seq_len)), dtype=torch.int32 + ) + scale = 1.0 / head_size**0.5 + attn = MMEncoderAttention( + num_heads, head_size, scale=scale, num_kv_heads=num_kv_heads + ) + output = attn( + q, k, v, cu_seqlens=cu_seqlens, max_seqlen=torch.tensor(max(var_seq_len)) + ) + + assert num_heads % num_kv_heads == 0 + num_queries_per_kv = num_heads // num_kv_heads + if num_queries_per_kv > 1: + k = torch.repeat_interleave(k, num_queries_per_kv, dim=2) + v = torch.repeat_interleave(v, num_queries_per_kv, dim=2) + + ref_output = [] + for q_i, k_i, v_i in zip( + torch.split(q, var_seq_len, dim=1), + torch.split(k, var_seq_len, dim=1), + torch.split(v, var_seq_len, dim=1), + ): + output_i = ref_attention( + q_i, + k_i, + v_i, + scale=scale, + ) + ref_output.append(output_i) + ref_output = torch.cat(ref_output, dim=1) + torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2) diff --git a/tests/v1/tpu/test_mha_attn.py b/tests/v1/tpu/test_mha_attn.py 
index 5debdf85bea8d..84968dee6b60c 100644 --- a/tests/v1/tpu/test_mha_attn.py +++ b/tests/v1/tpu/test_mha_attn.py @@ -3,7 +3,7 @@ """ Test: -* Tests for MultiHeadAttention layer +* Tests for MMEncoderAttention layer """ import pytest @@ -12,7 +12,7 @@ import torch_xla import torch_xla.core import torch_xla.core.xla_model -from vllm.attention.layer import MultiHeadAttention +from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention from vllm.attention.selector import _cached_get_attn_backend from vllm.platforms import current_platform @@ -69,7 +69,7 @@ def test_mha_attn_forward( k = torch.randn(batch_size, seq_len, num_kv_heads * head_size, device=device) v = torch.randn(batch_size, seq_len, num_kv_heads * head_size, device=device) scale = 1.0 / head_size**0.5 - attn = MultiHeadAttention( + attn = MMEncoderAttention( num_heads, head_size, scale=scale, num_kv_heads=num_kv_heads ) output = attn(q, k, v) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 7ef77db8fbb5b..1d882eb87bfde 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -2,12 +2,10 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer.""" -import functools from typing import cast import torch import torch.nn as nn -import torch.nn.functional as F import vllm.envs as envs from vllm.attention.backends.abstract import ( @@ -16,13 +14,10 @@ from vllm.attention.backends.abstract import ( MLAAttentionImpl, ) from vllm.attention.backends.registry import AttentionBackendEnum -from vllm.attention.layers.mm_encoder_attention import maybe_get_vit_flash_attn_backend from vllm.attention.selector import get_attn_backend -from vllm.attention.utils.fa_utils import get_flash_attn_version from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target from vllm.attention.utils.kv_transfer_utils import maybe_transfer_kv_layer from vllm.config import CacheConfig, get_current_vllm_config -from vllm.config.multimodal import MultiModalConfig from vllm.config.vllm import VllmConfig from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import init_logger @@ -36,7 +31,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape -from vllm.model_executor.models.vision import get_vit_attn_backend from vllm.platforms import current_platform from vllm.utils.torch_utils import ( direct_register_custom_op, @@ -412,132 +406,6 @@ class Attention(nn.Module, AttentionLayerBase): ) -class MultiHeadAttention(nn.Module): - """Multi-headed attention without any cache, used for ViT.""" - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int | None = None, - # This has no effect, it is only here to make it easier to swap - # between Attention and MultiHeadAttention - prefix: str = "", - multimodal_config: MultiModalConfig | None = None, - ) -> None: - super().__init__() - self.num_heads = num_heads - self.head_size = head_size - self.scale = scale - self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads - self.layer_name = prefix - - assert self.num_heads % self.num_kv_heads == 0, ( - f"num_heads ({self.num_heads}) is not " - f"divisible by num_kv_heads ({self.num_kv_heads})" - ) - self.num_queries_per_kv = self.num_heads // 
self.num_kv_heads - - # During model initialization, the default dtype is set as the model - # weight and activation dtype. - dtype = torch.get_default_dtype() - - # Determine the attention backend - attn_backend_override = None - if multimodal_config is not None: - attn_backend_override = multimodal_config.mm_encoder_attn_backend - - self.attn_backend = get_vit_attn_backend( - head_size=head_size, - dtype=dtype, - attn_backend_override=attn_backend_override, - ) - - self._flash_attn_varlen_func = maybe_get_vit_flash_attn_backend( - self.attn_backend, - ) - - self.is_flash_attn_backend = self.attn_backend in { - AttentionBackendEnum.FLASH_ATTN, - AttentionBackendEnum.ROCM_AITER_FA, - } - - self.fa_version = None - if ( - self.attn_backend == AttentionBackendEnum.FLASH_ATTN - and current_platform.is_cuda() - ): - self.fa_version = get_flash_attn_version() - assert self._flash_attn_varlen_func is not None - self._flash_attn_varlen_func = functools.partial( - self._flash_attn_varlen_func, fa_version=self.fa_version - ) - - logger.info_once( - f"Using {self.attn_backend} for MultiHeadAttention in multimodal encoder." - ) - - def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - ) -> torch.Tensor: - """Input shape: - (batch_size x seq_len x hidden_size) or - (batch_size x seq_len x num_heads x head_size) - """ - bsz, q_len = query.size()[:2] - kv_len = key.size(1) - - query = query.view(bsz, q_len, self.num_heads, self.head_size) - key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size) - value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size) - - if (num_repeat := self.num_queries_per_kv) > 1: - # Handle MQA and GQA - key = torch.repeat_interleave(key, num_repeat, dim=2) - value = torch.repeat_interleave(value, num_repeat, dim=2) - - if self.is_flash_attn_backend: - assert self._flash_attn_varlen_func is not None - cu_seqlens_q = torch.arange( - 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=query.device - ) - cu_seqlens_k = torch.arange( - 0, (bsz + 1) * kv_len, step=kv_len, dtype=torch.int32, device=key.device - ) - - out = self._flash_attn_varlen_func( - query.flatten(0, 1), - key.flatten(0, 1), - value.flatten(0, 1), - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=q_len, - max_seqlen_k=kv_len, - softmax_scale=self.scale, - ) - elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: - query, key, value = (x.transpose(1, 2) for x in (query, key, value)) - out = F.scaled_dot_product_attention(query, key, value, scale=self.scale) - out = out.transpose(1, 2) - elif self.attn_backend == AttentionBackendEnum.PALLAS: - query, key, value = (x.transpose(1, 2) for x in (query, key, value)) - from torch_xla.experimental.custom_kernel import flash_attention - - out = flash_attention(query, key, value, sm_scale=self.scale) - out = out.transpose(1, 2) - else: - # ViT attention hasn't supported this backend yet - raise NotImplementedError( - f"ViT attention hasn't supported {self.attn_backend} backend yet." - ) - - return out.reshape(bsz, q_len, -1) - - class MLAAttention(nn.Module, AttentionLayerBase): """Multi-Head Latent Attention layer. 
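
Note on the removal above: the legacy `MultiHeadAttention.forward` handled plain batched ViT inputs (no `cu_seqlens`) internally; after this patch that responsibility moves to `MMEncoderAttention` plus the wrappers in `vit_attn_wrappers.py`, which now default `cu_seqlens`/`max_seqlen` when they are not supplied. A minimal call-site sketch follows; the `ToyVisionAttention` module and its sizes are illustrative only (not part of this patch), while the constructor and forward signatures mirror the tests and model call sites in this diff:

```python
import torch

from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention


class ToyVisionAttention(torch.nn.Module):
    """Illustrative ViT-style attention block using the new interface."""

    def __init__(self, hidden_size: int = 1024, num_heads: int = 16) -> None:
        super().__init__()
        head_size = hidden_size // num_heads
        self.qkv = torch.nn.Linear(hidden_size, 3 * hidden_size)
        # Drop-in replacement for the removed
        # `MultiHeadAttention(num_heads, head_size, scale)`.
        self.attn = MMEncoderAttention(num_heads, head_size, scale=head_size**-0.5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch_size, seq_len, hidden_size). The batched path needs no
        # cu_seqlens, matching the old MultiHeadAttention behaviour.
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        return self.attn(q, k, v)
```

Packed variable-length inputs go through the same layer by passing both boundary arguments together, e.g. `attn(q, k, v, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)`, as exercised by `test_mha_attn_varlen_forward` above.
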
diff --git a/vllm/attention/layers/mm_encoder_attention.py b/vllm/attention/layers/mm_encoder_attention.py index 8b3dee1340b9f..25f54cc867b5a 100644 --- a/vllm/attention/layers/mm_encoder_attention.py +++ b/vllm/attention/layers/mm_encoder_attention.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Callable import torch @@ -19,27 +18,6 @@ from vllm.model_executor.models.vision import get_vit_attn_backend logger = init_logger(__name__) -def maybe_get_vit_flash_attn_backend( - attn_backend: AttentionBackendEnum | None, -) -> Callable | None: - # At this point, - # we already have the attn_backend, - # overriding logic is done in the platform-specific implementation. - # so we don't need to override backend here. - # Just return the attn_backend and flash_attn_varlen_func. - - if attn_backend == AttentionBackendEnum.FLASH_ATTN: - from vllm.attention.utils.fa_utils import flash_attn_varlen_func - elif attn_backend == AttentionBackendEnum.ROCM_AITER_FA: - from aiter import flash_attn_varlen_func - else: - flash_attn_varlen_func = None - - # if attn_backend is TORCH_SDPA, - # it will reach here and the flash_attn_varlen_func will be None. - return flash_attn_varlen_func - - @CustomOp.register("mm_encoder_attn") class MMEncoderAttention(CustomOp): """Multi-headed attention without any cache, used for multimodal encoder.""" @@ -98,21 +76,17 @@ class MMEncoderAttention(CustomOp): AttentionBackendEnum.ROCM_AITER_FA, } - self.flash_attn_varlen_func = maybe_get_vit_flash_attn_backend( - self.attn_backend, + self._fa_version = ( + get_flash_attn_version() if self.is_flash_attn_backend else None ) - if self.is_flash_attn_backend: - assert self.flash_attn_varlen_func is not None - self._fa_version = get_flash_attn_version() - logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.") @classmethod def enabled(cls) -> bool: return True - def reshape_qkv_to_4d( + def maybe_reshape_qkv_to_4d( self, query: torch.Tensor, key: torch.Tensor, @@ -136,30 +110,6 @@ class MMEncoderAttention(CustomOp): return query, key, value - def reshape_qkv_to_3d( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - bsz: int, - q_len: int, - kv_len: int, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Reshape query, key, value to 3D tensors: - (batch_size * seq_len, num_heads, head_size) - """ - query = query.view(bsz * q_len, self.num_heads, self.head_size) - key = key.view(bsz * kv_len, self.num_kv_heads, self.head_size) - value = value.view(bsz * kv_len, self.num_kv_heads, self.head_size) - - if (num_repeat := self.num_queries_per_kv) > 1: - # Handle MQA and GQA - key = torch.repeat_interleave(key, num_repeat, dim=1) - value = torch.repeat_interleave(value, num_repeat, dim=1) - - return query, key, value - def _forward_sdpa( self, query: torch.Tensor, @@ -167,13 +117,15 @@ class MMEncoderAttention(CustomOp): value: torch.Tensor, cu_seqlens: torch.Tensor | None = None, ) -> torch.Tensor: - # TODO(Isotr0py): Migrate MultiHeadAttention - assert cu_seqlens is not None - + """Input shape: + (batch_size x seq_len x hidden_size) or + (batch_size x seq_len x num_heads x head_size) + """ bsz, q_len = query.size()[:2] kv_len = key.size(1) + is_reshaped = query.dim() != 4 - query, key, value = self.reshape_qkv_to_4d( + query, key, value = self.maybe_reshape_qkv_to_4d( query, key, value, bsz, q_len, kv_len ) @@ -183,6 +135,8 @@ class MMEncoderAttention(CustomOp): v=value, 
cu_seqlens=cu_seqlens, ) + if is_reshaped: + output = output.view(bsz, q_len, -1) return output def _forward_fa( @@ -193,13 +147,21 @@ class MMEncoderAttention(CustomOp): cu_seqlens: torch.Tensor | None = None, max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention ) -> torch.Tensor: - assert self.flash_attn_varlen_func is not None, ( - "Flash attention function is not set." - ) - # # TODO(Isotr0py): Migrate MultiHeadAttention - assert cu_seqlens is not None and max_seqlen is not None + """Input shape: + (batch_size x seq_len x hidden_size) or + (batch_size x seq_len x num_heads x head_size) + """ + assert (cu_seqlens is not None and max_seqlen is not None) or ( + cu_seqlens is None and max_seqlen is None + ), "cu_seqlens and max_seqlen should be both set or both None." - bsz = query.shape[0] + bsz, q_len = query.size()[:2] + kv_len = key.size(1) + is_reshaped = query.dim() != 4 + + query, key, value = self.maybe_reshape_qkv_to_4d( + query, key, value, bsz, q_len, kv_len + ) output = vit_flash_attn_wrapper( q=query, @@ -211,6 +173,8 @@ class MMEncoderAttention(CustomOp): is_rocm_aiter=(self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA), fa_version=self._fa_version, ) + if is_reshaped: + output = output.view(bsz, q_len, -1) return output def forward_native( diff --git a/vllm/attention/ops/vit_attn_wrappers.py b/vllm/attention/ops/vit_attn_wrappers.py index f555147bc055a..2204382a35e2a 100644 --- a/vllm/attention/ops/vit_attn_wrappers.py +++ b/vllm/attention/ops/vit_attn_wrappers.py @@ -24,11 +24,11 @@ def flash_attn_maxseqlen_wrapper( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - cu_seqlens: torch.Tensor, - max_seqlen: torch.Tensor, batch_size: int, is_rocm_aiter: bool, fa_version: int | None, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: torch.Tensor | None = None, ) -> torch.Tensor: kwargs = {} if is_rocm_aiter: @@ -38,6 +38,14 @@ def flash_attn_maxseqlen_wrapper( if not current_platform.is_rocm() and fa_version is not None: kwargs["fa_version"] = fa_version + + q_len = q.size(1) + if cu_seqlens is None: + cu_seqlens = torch.arange( + 0, (batch_size + 1) * q_len, step=q_len, dtype=torch.int32, device=q.device + ) + max_seqlen = q_len if max_seqlen is None else max_seqlen.item() + q, k, v = (einops.rearrange(x, "b s ... 
-> (b s) ...") for x in [q, k, v]) output = flash_attn_varlen_func( q, @@ -45,8 +53,8 @@ def flash_attn_maxseqlen_wrapper( v, cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen.item(), - max_seqlen_k=max_seqlen.item(), + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, dropout_p=0.0, causal=False, **kwargs, @@ -79,24 +87,42 @@ def vit_flash_attn_wrapper( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - cu_seqlens: torch.Tensor, - max_seqlen: torch.Tensor, batch_size: int, is_rocm_aiter: bool, fa_version: int | None, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: torch.Tensor | None = None, ) -> torch.Tensor: return torch.ops.vllm.flash_attn_maxseqlen_wrapper( - q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter, fa_version + q, + k, + v, + batch_size, + is_rocm_aiter, + fa_version, + cu_seqlens, + max_seqlen, ) +def apply_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + """ + Input shape: + (batch_size x seq_len x num_heads x head_size) + """ + q, k, v = (einops.rearrange(x, "b s h d -> b h s d") for x in [q, k, v]) + output = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0) + output = einops.rearrange(output, "b h s d -> b s h d ") + return output + + # TODO: Once we have a torch 2.10, we can use tensor slices # so we won't need to wrap this in custom ops def torch_sdpa_wrapper( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - cu_seqlens: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, ) -> torch.Tensor: # Never remove the contiguous logic for ROCm # Without it, hallucinations occur with the backend @@ -105,6 +131,9 @@ def torch_sdpa_wrapper( k = k.contiguous() v = v.contiguous() + if cu_seqlens is None: + return apply_sdpa(q, k, v) + outputs = [] lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() @@ -112,11 +141,7 @@ def torch_sdpa_wrapper( k_chunks = torch.split(k, lens, dim=1) v_chunks = torch.split(v, lens, dim=1) for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks): - q_i, k_i, v_i = ( - einops.rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i] - ) - output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0) - output_i = einops.rearrange(output_i, "b h s d -> b s h d ") + output_i = apply_sdpa(q_i, k_i, v_i) outputs.append(output_i) context_layer = torch.cat(outputs, dim=1) return context_layer @@ -142,6 +167,6 @@ def vit_torch_sdpa_wrapper( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - cu_seqlens: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, ) -> torch.Tensor: return torch.ops.vllm.torch_sdpa_wrapper(q, k, v, cu_seqlens) diff --git a/vllm/model_executor/models/aimv2.py b/vllm/model_executor/models/aimv2.py index 3d000f3ac3ab5..96ca27ad02504 100644 --- a/vllm/model_executor/models/aimv2.py +++ b/vllm/model_executor/models/aimv2.py @@ -8,7 +8,7 @@ from collections.abc import Iterable import torch import torch.nn as nn -from vllm.attention.layer import MultiHeadAttention +from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.utils import divide from vllm.model_executor.layers.activation import SiluAndMul @@ -126,7 +126,7 @@ class AIMv2Attention(nn.Module): self.tp_size = get_tensor_model_parallel_world_size() self.num_heads_per_partition = divide(self.num_heads, self.tp_size) - self.attn = MultiHeadAttention( + self.attn = MMEncoderAttention( self.num_heads_per_partition, self.head_dim, self.scale ) diff --git a/vllm/model_executor/models/blip.py 
b/vllm/model_executor/models/blip.py index f31f99c0592b2..7387830b32bdc 100644 --- a/vllm/model_executor/models/blip.py +++ b/vllm/model_executor/models/blip.py @@ -9,7 +9,7 @@ import torch import torch.nn as nn from transformers import Blip2VisionConfig, BlipVisionConfig -from vllm.attention.layer import MultiHeadAttention +from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.conv import Conv2dLayer @@ -122,7 +122,7 @@ class BlipAttention(nn.Module): self.tp_size = get_tensor_model_parallel_world_size() self.num_heads_per_partition = divide(self.num_heads, self.tp_size) - self.attn = MultiHeadAttention( + self.attn = MMEncoderAttention( self.num_heads_per_partition, self.head_dim, self.scale ) diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index 22f3ecad748e6..8e77b36e6feb5 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -14,7 +14,8 @@ from transformers import ( CLIPVisionConfig, ) -from vllm.attention.layer import Attention, MultiHeadAttention +from vllm.attention.layer import Attention +from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import divide, get_tensor_model_parallel_world_size @@ -354,7 +355,7 @@ class CLIPAttention(nn.Module): quant_config: QuantizationConfig | None = None, *, prefix: str = "", - attn_cls: type[Attention] | type[MultiHeadAttention], + attn_cls: type[Attention] | type[MMEncoderAttention], ) -> None: super().__init__() @@ -449,7 +450,7 @@ class CLIPEncoderLayer(nn.Module): quant_config: QuantizationConfig | None = None, *, prefix: str = "", - attn_cls: type[Attention] | type[MultiHeadAttention], + attn_cls: type[Attention] | type[MMEncoderAttention], ) -> None: super().__init__() self.self_attn = CLIPAttention( @@ -493,7 +494,7 @@ class CLIPEncoder(nn.Module): num_hidden_layers_override: int | None = None, *, prefix: str = "", - attn_cls: type[Attention] | type[MultiHeadAttention], + attn_cls: type[Attention] | type[MMEncoderAttention], ) -> None: super().__init__() @@ -638,7 +639,7 @@ class CLIPVisionTransformer(nn.Module): quant_config=quant_config, num_hidden_layers_override=num_hidden_layers_override, prefix=f"{prefix}.encoder", - attn_cls=MultiHeadAttention, + attn_cls=MMEncoderAttention, ) num_hidden_layers = config.num_hidden_layers diff --git a/vllm/model_executor/models/deepencoder.py b/vllm/model_executor/models/deepencoder.py index 8f1660891fcbf..045445d23b8f3 100644 --- a/vllm/model_executor/models/deepencoder.py +++ b/vllm/model_executor/models/deepencoder.py @@ -18,7 +18,7 @@ import torch.nn as nn import torch.nn.functional as F from transformers import CLIPVisionConfig -from vllm.attention.layer import MultiHeadAttention +from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention from vllm.model_executor.layers.conv import Conv2dLayer from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -628,7 +628,7 @@ class DeepCLIPVisionTransformer(nn.Module): quant_config=quant_config, num_hidden_layers_override=num_hidden_layers_override, prefix=f"{prefix}.encoder", - attn_cls=MultiHeadAttention, + attn_cls=MMEncoderAttention, ) num_hidden_layers = 
config.num_hidden_layers diff --git a/vllm/model_executor/models/glm4v.py b/vllm/model_executor/models/glm4v.py index ec5af94e297c1..453a7812a1748 100644 --- a/vllm/model_executor/models/glm4v.py +++ b/vllm/model_executor/models/glm4v.py @@ -19,7 +19,7 @@ from transformers import BatchFeature, PreTrainedTokenizer, TensorType from transformers.image_utils import ImageInput from transformers.tokenization_utils_base import TextInput -from vllm.attention.layer import MultiHeadAttention +from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import get_tensor_model_parallel_world_size @@ -135,7 +135,7 @@ class EVA2CLIPAttention(nn.Module): prefix=f"{prefix}.dense", ) - self.attn = MultiHeadAttention( + self.attn = MMEncoderAttention( self.num_heads_per_rank, self.head_dim, self.scale ) self.output_dropout = torch.nn.Dropout(config.dropout_prob) diff --git a/vllm/model_executor/models/hunyuan_vision.py b/vllm/model_executor/models/hunyuan_vision.py index be084f4ee0f8e..6fc56094af650 100644 --- a/vllm/model_executor/models/hunyuan_vision.py +++ b/vllm/model_executor/models/hunyuan_vision.py @@ -34,7 +34,7 @@ import torch.nn.functional as F from transformers import BatchFeature from vllm.attention.backends.registry import AttentionBackendEnum -from vllm.attention.layer import MultiHeadAttention +from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention from vllm.config import MultiModalConfig, VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import parallel_state @@ -232,7 +232,7 @@ class HunYuanVisionAttention(nn.Module): ) self.scale = self.hidden_size_per_attention_head**-0.5 - self.attn = MultiHeadAttention( + self.attn = MMEncoderAttention( self.num_attention_heads_per_partition, self.hidden_size_per_attention_head, self.scale, diff --git a/vllm/model_executor/models/idefics2_vision_model.py b/vllm/model_executor/models/idefics2_vision_model.py index 06b8468e18db9..ee6ca5eacb176 100644 --- a/vllm/model_executor/models/idefics2_vision_model.py +++ b/vllm/model_executor/models/idefics2_vision_model.py @@ -27,7 +27,7 @@ from transformers.models.idefics2.configuration_idefics2 import ( Idefics2VisionConfig, ) -from vllm.attention.layer import MultiHeadAttention +from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.conv import Conv2dLayer @@ -161,8 +161,8 @@ class Idefics2VisionAttention(nn.Module): prefix=f"{prefix}.out_proj", disable_tp=use_data_parallel, ) - # Use unified MultiHeadAttention with Flash Attention support - self.attn = MultiHeadAttention( + # Use unified MMEncoderAttention with Flash Attention support + self.attn = MMEncoderAttention( self.num_heads_per_partition, self.head_dim, self.scale ) @@ -175,7 +175,7 @@ class Idefics2VisionAttention(nn.Module): ) # batch_size, q_len, 3 * num_heads_per_partition * head_dim query_states, key_states, value_states = qkv.chunk(3, dim=-1) - # Use unified MultiHeadAttention implementation + # Use unified MMEncoderAttention implementation out = self.attn(query_states, key_states, value_states) attn_output, _ = self.out_proj(out) return attn_output diff --git a/vllm/model_executor/models/intern_vit.py b/vllm/model_executor/models/intern_vit.py index 61aeafc2ab436..5f7ba838aa3d9 100644 --- 
a/vllm/model_executor/models/intern_vit.py +++ b/vllm/model_executor/models/intern_vit.py @@ -15,7 +15,7 @@ import torch.nn as nn import torch.nn.functional as F from transformers import PretrainedConfig -from vllm.attention.layer import MultiHeadAttention +from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention from vllm.distributed import ( divide, get_tensor_model_parallel_rank, @@ -207,7 +207,7 @@ class InternParallelAttention(nn.Module): disable_tp=use_data_parallel, ) - self.attn = MultiHeadAttention( + self.attn = MMEncoderAttention( self.num_heads_per_partition, self.head_dim, self.scale ) diff --git a/vllm/model_executor/models/interns1_vit.py b/vllm/model_executor/models/interns1_vit.py index cb0414bbc95a8..a16857d613226 100644 --- a/vllm/model_executor/models/interns1_vit.py +++ b/vllm/model_executor/models/interns1_vit.py @@ -14,7 +14,7 @@ import torch.nn as nn from transformers import PretrainedConfig from transformers.utils import torch_int -from vllm.attention.layer import MultiHeadAttention +from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.conv import Conv2dLayer from vllm.model_executor.layers.layernorm import RMSNorm @@ -214,8 +214,8 @@ class InternSdpaAttention(nn.Module): self.projection_layer = nn.Linear(self.dummy_dim, self.embed_dim) - # Use unified MultiHeadAttention with automatic backend selection - self.attn = MultiHeadAttention(self.num_heads, self.head_dim, self.scale) + # Use unified MMEncoderAttention with automatic backend selection + self.attn = MMEncoderAttention(self.num_heads, self.head_dim, self.scale) def forward(self, x: torch.Tensor) -> torch.Tensor: """x shape: (B, N, C)""" @@ -228,7 +228,7 @@ class InternSdpaAttention(nn.Module): q = self.q_norm(q) k = self.k_norm(k) - # Use unified MultiHeadAttention with automatic backend selection + # Use unified MMEncoderAttention with automatic backend selection x = self.attn(q, k, v) x = self.projection_layer(x) diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py index fe963cc6644fb..886d5151e43ff 100644 --- a/vllm/model_executor/models/mllama4.py +++ b/vllm/model_executor/models/mllama4.py @@ -31,7 +31,7 @@ from transformers.models.llama4.image_processing_llama4_fast import ( get_best_fit, ) -from vllm.attention.layer import MultiHeadAttention +from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import get_tensor_model_parallel_world_size @@ -255,7 +255,7 @@ class Llama4VisionAttention(nn.Module): self.attention_dropout = config.attention_dropout self.scaling = self.head_dim**-0.5 - self.attn = MultiHeadAttention( + self.attn = MMEncoderAttention( self.num_local_heads, self.head_dim, self.scaling ) diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index 71c6b1aa2e814..9c741e1f5071f 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -17,7 +17,8 @@ from transformers import BatchFeature, PretrainedConfig, ProcessorMixin, TensorT from transformers.image_utils import ImageInput from transformers.tokenization_utils_base import TextInput -from vllm.attention.layer import Attention, MultiHeadAttention +from vllm.attention.layer import Attention +from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention from 
vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.config.multimodal import BaseDummyOptions @@ -222,7 +223,7 @@ class MultiHeadDotProductAttention(nn.Module): ) self.scale = self.head_dim**-0.5 - self.attn = MultiHeadAttention( + self.attn = MMEncoderAttention( self.num_heads, self.head_dim, self.scale, num_kv_heads=self.num_kv_heads ) diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index 2600dc1c9f79c..799afc7ca2e51 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -16,8 +16,8 @@ from transformers import ( SiglipVisionConfig, ) -from vllm.attention.layer import MultiHeadAttention from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention +from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import divide, get_tensor_model_parallel_world_size @@ -379,7 +379,7 @@ class SiglipAttention(nn.Module): quant_config: QuantizationConfig | None = None, *, prefix: str = "", - attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention], + attn_cls: type[EncoderOnlyAttention] | type[MMEncoderAttention], ) -> None: super().__init__() @@ -481,7 +481,7 @@ class SiglipEncoderLayer(nn.Module): quant_config: QuantizationConfig | None = None, *, prefix: str = "", - attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention], + attn_cls: type[EncoderOnlyAttention] | type[MMEncoderAttention], ) -> None: super().__init__() @@ -527,7 +527,7 @@ class SiglipEncoder(nn.Module): num_hidden_layers_override: int | None = None, *, prefix: str = "", - attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention], + attn_cls: type[EncoderOnlyAttention] | type[MMEncoderAttention], ) -> None: super().__init__() @@ -700,7 +700,7 @@ class SiglipVisionTransformer(nn.Module): quant_config=quant_config, num_hidden_layers_override=num_hidden_layers_override, prefix=f"{prefix}.encoder", - attn_cls=MultiHeadAttention, + attn_cls=MMEncoderAttention, ) num_hidden_layers = config.num_hidden_layers diff --git a/vllm/model_executor/models/step3_vl.py b/vllm/model_executor/models/step3_vl.py index e5038e56a2708..3c965721b9dae 100644 --- a/vllm/model_executor/models/step3_vl.py +++ b/vllm/model_executor/models/step3_vl.py @@ -15,7 +15,7 @@ from torchvision import transforms from torchvision.transforms.functional import InterpolationMode from transformers import BatchFeature, PretrainedConfig, TensorType -from vllm.attention.layer import MultiHeadAttention +from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import get_tensor_model_parallel_world_size @@ -753,8 +753,8 @@ class Step3VisionAttention(nn.Module): disable_tp=use_data_parallel, ) - # Use unified MultiHeadAttention with automatic backend selection - self.attn = MultiHeadAttention(self.num_heads, self.head_dim, self.scale) + # Use unified MMEncoderAttention with automatic backend selection + self.attn = MMEncoderAttention(self.num_heads, self.head_dim, self.scale) def forward( self, @@ -767,7 +767,7 @@ class Step3VisionAttention(nn.Module): qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) - # Use unified MultiHeadAttention with automatic backend selection + # Use unified MMEncoderAttention with automatic backend 
selection attn_output = self.attn(q, k, v) attn_output, _ = self.out_proj(attn_output) diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index b513e3513b2e2..f5a1e75d99617 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -16,9 +16,9 @@ from transformers import ( ) from transformers.models.whisper.modeling_whisper import sinusoids -from vllm.attention.backends.abstract import AttentionType -from vllm.attention.layer import Attention, MultiHeadAttention +from vllm.attention.layer import Attention, AttentionType from vllm.attention.layers.cross_attention import CrossAttention +from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import get_tensor_model_parallel_world_size @@ -141,7 +141,7 @@ class WhisperAudioInputs(TensorSchema): ] -class WhisperEncoderAttention(MultiHeadAttention): +class WhisperEncoderAttention(MMEncoderAttention): """Multi-headed attention for Whisper encoder with 2D tensor support.""" def forward(