mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-19 21:45:02 +08:00
[Bugfix] detect alibi and revert to FA2 (#15231)
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
This commit is contained in:
parent
47195057e9
commit
0032903a5b
@@ -630,7 +630,8 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
self.sliding_window = ((sliding_window - 1,
|
self.sliding_window = ((sliding_window - 1,
|
||||||
0) if sliding_window is not None else (-1, -1))
|
0) if sliding_window is not None else (-1, -1))
|
||||||
self.kv_cache_dtype = kv_cache_dtype
|
self.kv_cache_dtype = kv_cache_dtype
|
||||||
self.vllm_flash_attn_version = get_flash_attn_version()
|
self.vllm_flash_attn_version = get_flash_attn_version(
|
||||||
|
requires_alibi=self.alibi_slopes is not None)
|
||||||
if (is_quantized_kv_cache(self.kv_cache_dtype)
|
if (is_quantized_kv_cache(self.kv_cache_dtype)
|
||||||
and self.vllm_flash_attn_version != 3):
|
and self.vllm_flash_attn_version != 3):
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from vllm.logger import init_logger
|
|||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def get_flash_attn_version() -> Optional[int]:
|
def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]:
|
||||||
# import here to avoid circular dependencies
|
# import here to avoid circular dependencies
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
try:
|
try:
|
||||||
@@ -28,7 +28,13 @@ def get_flash_attn_version() -> Optional[int]:
|
|||||||
|
|
||||||
# 3. fallback for unsupported combinations
|
# 3. fallback for unsupported combinations
|
||||||
if device_capability.major == 10 and fa_version == 3:
|
if device_capability.major == 10 and fa_version == 3:
|
||||||
logger.warning("Cannot use FA version 3 on Blackwell platform",
|
logger.warning_once(
|
||||||
|
"Cannot use FA version 3 on Blackwell platform "
|
||||||
|
"defaulting to FA version 2.")
|
||||||
|
fa_version = 2
|
||||||
|
|
||||||
|
if requires_alibi and fa_version == 3:
|
||||||
|
logger.warning_once("Cannot use FA version 3 with ALiBi, "
|
||||||
"defaulting to FA version 2.")
|
"defaulting to FA version 2.")
|
||||||
fa_version = 2
|
fa_version = 2
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user