diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index b95287906c1fe..8a19a5c6ff9ba 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -294,8 +294,13 @@ class RocmPlatform(Platform): attn_selector_config.attn_type is not None and attn_selector_config.attn_type == AttentionType.ENCODER_ONLY ): - logger.info("Using FlexAttention backend.") - return AttentionBackendEnum.FLEX_ATTENTION.get_path() + # Use generic FlashAttention for encoder-only models + # ROCM_AITER_FA doesn't support encoder-only (causal-only limitation) + # Generic FLASH_ATTN supports all attention types including ENCODER_ONLY + logger.info( + "Using FlashAttention backend for encoder-only model on ROCm." + ) + return AttentionBackendEnum.FLASH_ATTN.get_path() # Default: Triton Unified Attention logger.info("Using Triton Attention backend.")