mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-30 06:07:19 +08:00
[Bugfix] Pass FA version in MultiHeadAttention (#30575)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
parent
08f8a5627e
commit
86a3261525
@ -2,6 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Attention layer."""
|
||||
|
||||
import functools
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
@ -17,6 +18,7 @@ from vllm.attention.backends.abstract import (
|
||||
)
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.attention.selector import get_attn_backend
|
||||
from vllm.attention.utils.fa_utils import get_flash_attn_version
|
||||
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
|
||||
from vllm.attention.utils.kv_transfer_utils import maybe_transfer_kv_layer
|
||||
from vllm.config import CacheConfig, get_current_vllm_config
|
||||
@ -524,6 +526,14 @@ class MultiHeadAttention(nn.Module):
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
}
|
||||
|
||||
self.fa_version = None
|
||||
if self.attn_backend == AttentionBackendEnum.FLASH_ATTN:
|
||||
self.fa_version = get_flash_attn_version()
|
||||
assert self._flash_attn_varlen_func is not None
|
||||
self._flash_attn_varlen_func = functools.partial(
|
||||
self._flash_attn_varlen_func, fa_version=self.fa_version
|
||||
)
|
||||
|
||||
logger.info_once(
|
||||
f"Using {self.attn_backend} for MultiHeadAttention in multimodal encoder."
|
||||
)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user