diff --git a/cmake/external_projects/vllm_flash_attn.cmake b/cmake/external_projects/vllm_flash_attn.cmake index ef6261fa6d9b4..f2d01099097a5 100644 --- a/cmake/external_projects/vllm_flash_attn.cmake +++ b/cmake/external_projects/vllm_flash_attn.cmake @@ -38,7 +38,7 @@ else() FetchContent_Declare( vllm-flash-attn GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git - GIT_TAG 720c94869cf2e0ff5a706e9c7f1dce0939686ade + GIT_TAG 9bfa9869829d8c593527eb34c5271d0090f7ccc9 GIT_PROGRESS TRUE # Don't share the vllm-flash-attn build between build types BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn @@ -64,4 +64,4 @@ install( DESTINATION vllm_flash_attn COMPONENT _vllm_fa3_C FILES_MATCHING PATTERN "*.py" -) \ No newline at end of file +) diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index baf01c9263d4f..4374b54222540 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -595,7 +595,7 @@ def get_flash_attn_version(): # if hopper default to FA3, otherwise stick to FA2 for now # TODO(lucas): profile FA3 on ampere to see if it makes sense to # use FA3 as default for both - if current_platform.get_device_capability()[0] >= 9: + if current_platform.get_device_capability()[0] == 9: fa_version = 3 if is_fa_version_supported(3) else 2 else: fa_version = 2 @@ -603,6 +603,11 @@ def get_flash_attn_version(): if envs.VLLM_FLASH_ATTN_VERSION is not None: assert envs.VLLM_FLASH_ATTN_VERSION in [2, 3] fa_version = envs.VLLM_FLASH_ATTN_VERSION + if (current_platform.get_device_capability()[0] == 10 + and envs.VLLM_FLASH_ATTN_VERSION == 3): + logger.warning("Cannot use FA version 3 on Blackwell platform", + "defaulting to FA version 2.") + fa_version = 2 if not is_fa_version_supported(fa_version): logger.error("Cannot use FA version %d is not supported due to %s",