[Bugfix] Pass FA version in MultiHeadAttention (#30575)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>


@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Attention layer."""
+import functools
 from collections.abc import Callable
 from typing import cast
@@ -17,6 +18,7 @@ from vllm.attention.backends.abstract import (
 )
 from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.attention.selector import get_attn_backend
+from vllm.attention.utils.fa_utils import get_flash_attn_version
 from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
 from vllm.attention.utils.kv_transfer_utils import maybe_transfer_kv_layer
 from vllm.config import CacheConfig, get_current_vllm_config
@@ -524,6 +526,14 @@ class MultiHeadAttention(nn.Module):
             AttentionBackendEnum.ROCM_AITER_FA,
         }
+        self.fa_version = None
+        if self.attn_backend == AttentionBackendEnum.FLASH_ATTN:
+            self.fa_version = get_flash_attn_version()
+            assert self._flash_attn_varlen_func is not None
+            self._flash_attn_varlen_func = functools.partial(
+                self._flash_attn_varlen_func, fa_version=self.fa_version
+            )
         logger.info_once(
             f"Using {self.attn_backend} for MultiHeadAttention in multimodal encoder."
         )
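
Note on the fix: MultiHeadAttention (used for multimodal encoders) calls the varlen flash-attention kernel directly, and on the FLASH_ATTN backend it now pre-binds the detected FA version with functools.partial, so every later call dispatches to the right kernel without each call site passing fa_version. Below is a minimal, self-contained sketch of that binding pattern; detect_fa_version and varlen_attn are hypothetical stand-ins for get_flash_attn_version and flash_attn_varlen_func, not vLLM's actual signatures.

    import functools

    def detect_fa_version() -> int:
        # Stand-in for get_flash_attn_version(), which picks FA2 or FA3
        # based on what the current platform supports.
        return 3

    def varlen_attn(q, k, v, fa_version: int = 2) -> str:
        # Stand-in for flash_attn_varlen_func; just reports which FA
        # version it was invoked with.
        return f"flash-attn v{fa_version} on ({q}, {k}, {v})"

    # Bind the detected version once at init time, as the diff does with
    # self._flash_attn_varlen_func; callers never pass fa_version again.
    varlen_attn_bound = functools.partial(
        varlen_attn, fa_version=detect_fa_version()
    )

    print(varlen_attn_bound("q", "k", "v"))  # -> flash-attn v3 on (q, k, v)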