From 86a3261525858bca5d4a234691fae0496d6fed99 Mon Sep 17 00:00:00 2001
From: Matthew Bonanni
Date: Fri, 12 Dec 2025 19:02:11 -0500
Subject: [PATCH] [Bugfix] Pass FA version in `MultiHeadAttention` (#30575)

Signed-off-by: Matthew Bonanni
---
 vllm/attention/layer.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index c77fc0fad0038..c095b94518143 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Attention layer."""
 
+import functools
 from collections.abc import Callable
 from typing import cast
 
@@ -17,6 +18,7 @@ from vllm.attention.backends.abstract import (
 )
 from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.attention.selector import get_attn_backend
+from vllm.attention.utils.fa_utils import get_flash_attn_version
 from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
 from vllm.attention.utils.kv_transfer_utils import maybe_transfer_kv_layer
 from vllm.config import CacheConfig, get_current_vllm_config
@@ -524,6 +526,14 @@ class MultiHeadAttention(nn.Module):
             AttentionBackendEnum.ROCM_AITER_FA,
         }
 
+        self.fa_version = None
+        if self.attn_backend == AttentionBackendEnum.FLASH_ATTN:
+            self.fa_version = get_flash_attn_version()
+            assert self._flash_attn_varlen_func is not None
+            self._flash_attn_varlen_func = functools.partial(
+                self._flash_attn_varlen_func, fa_version=self.fa_version
+            )
+
         logger.info_once(
             f"Using {self.attn_backend} for MultiHeadAttention in multimodal encoder."
         )