From e087fbc393055fb69e9acf71fa124be0190498ec Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nicol=C3=B2=20Lucchesi?=
Date: Wed, 17 Dec 2025 00:54:45 +0100
Subject: [PATCH] [MM] Pass FA version in ViT Attn (#30756)

Signed-off-by: NickLucche
Co-authored-by: Cyrus Leung
---
 vllm/attention/layers/mm_encoder_attention.py | 6 ++++++
 vllm/attention/ops/vit_attn_wrappers.py       | 9 ++++++++-
 2 files changed, 14 insertions(+), 1 deletion(-)

diff --git a/vllm/attention/layers/mm_encoder_attention.py b/vllm/attention/layers/mm_encoder_attention.py
index c9107ebcab856..8b3dee1340b9f 100644
--- a/vllm/attention/layers/mm_encoder_attention.py
+++ b/vllm/attention/layers/mm_encoder_attention.py
@@ -10,6 +10,7 @@ from vllm.attention.ops.vit_attn_wrappers import (
     vit_flash_attn_wrapper,
     vit_torch_sdpa_wrapper,
 )
+from vllm.attention.utils.fa_utils import get_flash_attn_version
 from vllm.config import MultiModalConfig
 from vllm.logger import init_logger
 from vllm.model_executor.custom_op import CustomOp
@@ -101,6 +102,10 @@ class MMEncoderAttention(CustomOp):
             self.attn_backend,
         )
 
+        if self.is_flash_attn_backend:
+            assert self.flash_attn_varlen_func is not None
+            self._fa_version = get_flash_attn_version()
+
         logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.")
 
     @classmethod
@@ -204,6 +209,7 @@ class MMEncoderAttention(CustomOp):
             max_seqlen=max_seqlen,
             batch_size=bsz,
             is_rocm_aiter=(self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA),
+            fa_version=self._fa_version,
         )
         return output
 
diff --git a/vllm/attention/ops/vit_attn_wrappers.py b/vllm/attention/ops/vit_attn_wrappers.py
index 892c4209c01e0..5a74e1310133d 100644
--- a/vllm/attention/ops/vit_attn_wrappers.py
+++ b/vllm/attention/ops/vit_attn_wrappers.py
@@ -28,11 +28,15 @@ def flash_attn_maxseqlen_wrapper(
     max_seqlen: torch.Tensor,
     batch_size: int,
     is_rocm_aiter: bool,
+    fa_version: int,
 ) -> torch.Tensor:
+    kwargs = {}
     if is_rocm_aiter:
         from aiter import flash_attn_varlen_func
     else:
         from vllm.attention.utils.fa_utils import flash_attn_varlen_func
+
+        kwargs["fa_version"] = fa_version
     q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
     output = flash_attn_varlen_func(
         q,
@@ -44,6 +48,7 @@
         max_seqlen_k=max_seqlen.item(),
         dropout_p=0.0,
         causal=False,
+        **kwargs,
     )
     context_layer = einops.rearrange(output, "(b s) h d -> b s h d", b=batch_size)
     return context_layer
@@ -57,6 +62,7 @@
     max_seqlen: torch.Tensor,
     batch_size: int,
     is_rocm_aiter: bool,
+    fa_version: int,
 ) -> torch.Tensor:
     return torch.empty_like(q)
 
@@ -76,9 +82,10 @@
     max_seqlen: torch.Tensor,
     batch_size: int,
     is_rocm_aiter: bool,
+    fa_version: int,
 ) -> torch.Tensor:
     return torch.ops.vllm.flash_attn_maxseqlen_wrapper(
-        q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter
+        q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter, fa_version
     )
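
Below is a minimal usage sketch, not part of the patch, of the fa_version plumbing
added above: get_flash_attn_version() resolves the FlashAttention version once at
init time, and the wrapper now forwards it through the custom op into
flash_attn_varlen_func. It assumes a CUDA machine with vLLM's FlashAttention
backend available; the tensor shapes and values are illustrative only.

import torch

from vllm.attention.ops.vit_attn_wrappers import vit_flash_attn_wrapper
from vllm.attention.utils.fa_utils import get_flash_attn_version

# Toy ViT-style inputs: (batch, seq, heads, head_dim), all sequences equal length.
bsz, seq, heads, head_dim = 2, 16, 8, 64
q = torch.randn(bsz, seq, heads, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Cumulative sequence lengths (one entry per image plus a leading zero),
# in the int32 layout flash_attn_varlen_func expects.
cu_seqlens = torch.arange(0, (bsz + 1) * seq, seq, device="cuda", dtype=torch.int32)
max_seqlen = torch.tensor(seq)

out = vit_flash_attn_wrapper(
    q, k, v, cu_seqlens, max_seqlen,
    batch_size=bsz,
    is_rocm_aiter=False,
    # Resolved from the installed FlashAttention build and forwarded through the
    # custom op; only used on the non-ROCm-aiter path, per the patch.
    fa_version=get_flash_attn_version(),
)
assert out.shape == (bsz, seq, heads, head_dim)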