vllm/vllm/fa_utils.py
Mickaël Seznec a597a57595
[Attention] Flash Attention 3 - fp8 (#14570)
Signed-off-by: Mickael Seznec <mickael@mistral.ai>
2025-03-20 01:14:20 -04:00

43 lines
1.5 KiB
Python

# SPDX-License-Identifier: Apache-2.0
from typing import Optional
from vllm import envs
from vllm.logger import init_logger
# Module-level logger, named after this module via ``__name__``.
logger = init_logger(__name__)
def get_flash_attn_version() -> Optional[int]:
    """Pick the FlashAttention (FA) version to use on the current platform.

    Resolution order:

    1. Default to FA3 when the device's compute-capability major version is
       9 and FA3 reports itself as supported; otherwise FA2.
    2. Override with ``envs.VLLM_FLASH_ATTN_VERSION`` when it is set; only
       2 and 3 are accepted.
    3. Fall back from FA3 to FA2 when the compute-capability major version
       is 10 (Blackwell), where FA3 cannot be used.

    Returns:
        The selected FA version (2 or 3), or ``None`` when the
        ``vllm_flash_attn`` extension cannot be imported, the device
        capability is unknown, the environment override is not 2 or 3, or
        the selected version is reported as unsupported.
    """
    # import here to avoid circular dependencies
    from vllm.platforms import current_platform

    try:
        from vllm.vllm_flash_attn.flash_attn_interface import (
            fa_version_unsupported_reason, is_fa_version_supported)

        # Explicit checks instead of ``assert``: asserts are stripped under
        # ``python -O``, which would let invalid versions be returned.
        device_capability = current_platform.get_device_capability()
        if device_capability is None:
            return None

        # 1. default version depending on platform
        fa_version = 3 if (device_capability.major == 9
                           and is_fa_version_supported(3)) else 2

        # 2. override if passed by environment
        requested_version = envs.VLLM_FLASH_ATTN_VERSION
        if requested_version is not None:
            if requested_version not in (2, 3):
                logger.error(
                    "VLLM_FLASH_ATTN_VERSION must be 2 or 3, got %s",
                    requested_version)
                return None
            fa_version = requested_version

        # 3. fallback for unsupported combinations
        if device_capability.major == 10 and fa_version == 3:
            # Keep this as ONE string: an extra positional argument would be
            # treated as a %-format arg and break the log record.
            logger.warning("Cannot use FA version 3 on Blackwell platform, "
                           "defaulting to FA version 2.")
            fa_version = 2

        if not is_fa_version_supported(fa_version):
            logger.error(
                "Cannot use FA version %d as it is not supported due to %s",
                fa_version, fa_version_unsupported_reason(fa_version))
            return None

        return fa_version
    except (ImportError, AssertionError):
        # Best-effort detection: a missing extension, or an AssertionError
        # raised by one of the called helpers, means "no FA version".
        return None