From fd8afdf38dad8bf7ccc4e7fcc3d4aaa4d6d9e0d8 Mon Sep 17 00:00:00 2001
From: Micah Williamson
Date: Wed, 17 Dec 2025 20:27:37 -0600
Subject: [PATCH] [ROCm][CI] Reduce Flakiness For test_async_scheduling Using ROCM_ATTN With FP32 (#30811)

Signed-off-by: Micah Williamson
---
 tests/v1/e2e/test_async_scheduling.py   | 12 ++----------
 vllm/v1/attention/backends/rocm_attn.py |  6 +++++-
 2 files changed, 7 insertions(+), 11 deletions(-)

diff --git a/tests/v1/e2e/test_async_scheduling.py b/tests/v1/e2e/test_async_scheduling.py
index 61e56c079a3b5..6447a33838d75 100644
--- a/tests/v1/e2e/test_async_scheduling.py
+++ b/tests/v1/e2e/test_async_scheduling.py
@@ -148,7 +148,7 @@ def run_tests(
             # Use TRITON_ATTN for spec decoding test for consistency
             attention_config = {"backend": "TRITON_ATTN"}
         else:
-            attention_config = {"backend": "ROCM_AITER_FA"}
+            attention_config = {"backend": "ROCM_ATTN"}
     else:
         attention_config = {"backend": "FLEX_ATTENTION"}
 
@@ -284,14 +284,6 @@ def run_test(
     print(f"---- TESTING {test_str}: {test_config}")
     print("-" * 80)
 
-    # On ROCm: use float16 for first test (ROCM_AITER_FA), but float32 for
-    # spec decoding test (TRITON_ATTN) for better precision.
-    # On others: always use float32.
-    if current_platform.is_rocm() and not is_testing_with_spec_decoding:
-        dtype = "float16"
-    else:
-        dtype = "float32"
-
     with VllmRunner(
         model,
         max_model_len=512,
@@ -301,7 +293,7 @@ def run_test(
         # enforce_eager=True,
         async_scheduling=async_scheduling,
         distributed_executor_backend=executor,
-        dtype=dtype,
+        dtype="float32",
         speculative_config=spec_config,
         disable_log_stats=False,
         attention_config=attention_config,
diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py
index e231c600cba7a..3701373f33315 100644
--- a/vllm/v1/attention/backends/rocm_attn.py
+++ b/vllm/v1/attention/backends/rocm_attn.py
@@ -152,7 +152,11 @@ class RocmAttentionMetadataBuilder(AttentionMetadataBuilder[RocmAttentionMetadat
 class RocmAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True
 
-    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
+    supported_dtypes: ClassVar[list[torch.dtype]] = [
+        torch.float16,
+        torch.bfloat16,
+        torch.float32,
+    ]
 
     @classmethod
     def get_supported_head_sizes(cls) -> list[int]: