mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-22 05:25:02 +08:00
[Hardware] Update the flash attn tag to support Blackwell (#14244)
This commit is contained in:
parent
5ee10e990d
commit
ed6ea06577
@ -38,7 +38,7 @@ else()
|
|||||||
FetchContent_Declare(
|
FetchContent_Declare(
|
||||||
vllm-flash-attn
|
vllm-flash-attn
|
||||||
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
|
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
|
||||||
GIT_TAG 720c94869cf2e0ff5a706e9c7f1dce0939686ade
|
GIT_TAG 9bfa9869829d8c593527eb34c5271d0090f7ccc9
|
||||||
GIT_PROGRESS TRUE
|
GIT_PROGRESS TRUE
|
||||||
# Don't share the vllm-flash-attn build between build types
|
# Don't share the vllm-flash-attn build between build types
|
||||||
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
|
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
|
||||||
|
|||||||
@ -595,7 +595,7 @@ def get_flash_attn_version():
|
|||||||
# if hopper default to FA3, otherwise stick to FA2 for now
|
# if hopper default to FA3, otherwise stick to FA2 for now
|
||||||
# TODO(lucas): profile FA3 on ampere to see if it makes sense to
|
# TODO(lucas): profile FA3 on ampere to see if it makes sense to
|
||||||
# use FA3 as default for both
|
# use FA3 as default for both
|
||||||
if current_platform.get_device_capability()[0] >= 9:
|
if current_platform.get_device_capability()[0] == 9:
|
||||||
fa_version = 3 if is_fa_version_supported(3) else 2
|
fa_version = 3 if is_fa_version_supported(3) else 2
|
||||||
else:
|
else:
|
||||||
fa_version = 2
|
fa_version = 2
|
||||||
@ -603,6 +603,11 @@ def get_flash_attn_version():
|
|||||||
if envs.VLLM_FLASH_ATTN_VERSION is not None:
|
if envs.VLLM_FLASH_ATTN_VERSION is not None:
|
||||||
assert envs.VLLM_FLASH_ATTN_VERSION in [2, 3]
|
assert envs.VLLM_FLASH_ATTN_VERSION in [2, 3]
|
||||||
fa_version = envs.VLLM_FLASH_ATTN_VERSION
|
fa_version = envs.VLLM_FLASH_ATTN_VERSION
|
||||||
|
if (current_platform.get_device_capability()[0] == 10
|
||||||
|
and envs.VLLM_FLASH_ATTN_VERSION == 3):
|
||||||
|
logger.warning("Cannot use FA version 3 on Blackwell platform",
|
||||||
|
"defaulting to FA version 2.")
|
||||||
|
fa_version = 2
|
||||||
|
|
||||||
if not is_fa_version_supported(fa_version):
|
if not is_fa_version_supported(fa_version):
|
||||||
logger.error("Cannot use FA version %d is not supported due to %s",
|
logger.error("Cannot use FA version %d is not supported due to %s",
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user