diff --git a/tests/v1/attention/test_rocm_attention_backends_selection.py b/tests/v1/attention/test_rocm_attention_backends_selection.py index 4ec79e9eb6ba4..80158d4b7278c 100644 --- a/tests/v1/attention/test_rocm_attention_backends_selection.py +++ b/tests/v1/attention/test_rocm_attention_backends_selection.py @@ -36,6 +36,12 @@ def mock_on_gfx9(): @pytest.mark.parametrize( "env_vars, selected_backend, expected_backend_path", [ + # Test Case: Explicit FLEX_ATTENTION backend + ( + {}, + "FLEX_ATTENTION", + AttentionBackendEnum.FLEX_ATTENTION.get_path(), + ), # Test Case 1: Default (no env vars, no explicit backend) ( {}, diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index f3ec965bd0881..b0434b9642f07 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -262,6 +262,10 @@ class RocmPlatform(Platform): f"is not MLA type while requested for MLA backend." ) + if selected_backend == AttentionBackendEnum.FLEX_ATTENTION: + logger.info("Using FlexAttention backend.") + return AttentionBackendEnum.FLEX_ATTENTION.get_path() + if selected_backend == AttentionBackendEnum.TRITON_ATTN: logger.info("Using Triton Attention backend on V1 engine.") return AttentionBackendEnum.TRITON_ATTN.get_path()