diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 25601011491f..95d3fa74e325 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -276,6 +276,9 @@ class RocmPlatform(Platform): ) if envs.VLLM_USE_V1: + if selected_backend == _Backend.FLEX_ATTENTION: + logger.info("Using FlexAttention backend on V1 engine.") + return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" if ( envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9() ) or selected_backend == _Backend.ROCM_AITER_FA: