diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index e1dcd9870b6ca..b5e742c65c9f8 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -194,8 +194,9 @@ class RocmPlatform(Platform): f" The selected backend, {selected_backend.name}," f"is not MLA type while requested for MLA backend.") - selected_backend = (_Backend.ROCM_FLASH if selected_backend - == _Backend.FLASH_ATTN else selected_backend) + if selected_backend is None or selected_backend == _Backend.FLASH_ATTN: + selected_backend = _Backend.ROCM_FLASH + if envs.VLLM_USE_V1: logger.info("Using Triton Attention backend on V1 engine.") return ("vllm.v1.attention.backends."