diff --git a/tests/v1/e2e/test_async_scheduling.py b/tests/v1/e2e/test_async_scheduling.py index 61e56c079a3b5..6447a33838d75 100644 --- a/tests/v1/e2e/test_async_scheduling.py +++ b/tests/v1/e2e/test_async_scheduling.py @@ -148,7 +148,7 @@ def run_tests( # Use TRITON_ATTN for spec decoding test for consistency attention_config = {"backend": "TRITON_ATTN"} else: - attention_config = {"backend": "ROCM_AITER_FA"} + attention_config = {"backend": "ROCM_ATTN"} else: attention_config = {"backend": "FLEX_ATTENTION"} @@ -284,14 +284,6 @@ def run_test( print(f"---- TESTING {test_str}: {test_config}") print("-" * 80) - # On ROCm: use float16 for first test (ROCM_AITER_FA), but float32 for - # spec decoding test (TRITON_ATTN) for better precision. - # On others: always use float32. - if current_platform.is_rocm() and not is_testing_with_spec_decoding: - dtype = "float16" - else: - dtype = "float32" - with VllmRunner( model, max_model_len=512, @@ -301,7 +293,7 @@ def run_test( # enforce_eager=True, async_scheduling=async_scheduling, distributed_executor_backend=executor, - dtype=dtype, + dtype="float32", speculative_config=spec_config, disable_log_stats=False, attention_config=attention_config, diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py index e231c600cba7a..3701373f33315 100644 --- a/vllm/v1/attention/backends/rocm_attn.py +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -152,7 +152,11 @@ class RocmAttentionMetadataBuilder(AttentionMetadataBuilder[RocmAttentionMetadat class RocmAttentionBackend(AttentionBackend): accept_output_buffer: bool = True - supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] + supported_dtypes: ClassVar[list[torch.dtype]] = [ + torch.float16, + torch.bfloat16, + torch.float32, + ] @classmethod def get_supported_head_sizes(cls) -> list[int]: