From c014de1ec777554d2954655bd564493476d92061 Mon Sep 17 00:00:00 2001 From: Micah Williamson Date: Tue, 2 Dec 2025 16:54:36 -0600 Subject: [PATCH] [ROCm][CI] Fix test_cudagraph_mode.py Failure For AMD CI (#29808) Signed-off-by: Micah Williamson --- tests/v1/cudagraph/test_cudagraph_mode.py | 40 ++++++++--------------- 1 file changed, 14 insertions(+), 26 deletions(-) diff --git a/tests/v1/cudagraph/test_cudagraph_mode.py b/tests/v1/cudagraph/test_cudagraph_mode.py index 12621d493e54..b1895e83b8b3 100644 --- a/tests/v1/cudagraph/test_cudagraph_mode.py +++ b/tests/v1/cudagraph/test_cudagraph_mode.py @@ -100,32 +100,20 @@ def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, supporte # test cudagraph_mode with different compilation mode. # (backend_name, cudagraph_mode, compilation_mode, supported) -if current_platform.is_rocm(): - combo_cases_2 = [ - ("RocmAttn", "FULL", CompilationMode.NONE, True), - ("RocmAttn", "FULL", CompilationMode.VLLM_COMPILE, True), - ("RocmAttn", "PIECEWISE", CompilationMode.NONE, False), - ("RocmAttn", "PIECEWISE", CompilationMode.VLLM_COMPILE, True), - ("RocmAttn", "FULL_AND_PIECEWISE", CompilationMode.NONE, False), - ("RocmAttn", "FULL_AND_PIECEWISE", CompilationMode.VLLM_COMPILE, True), - ("RocmAttn", "FULL_DECODE_ONLY", CompilationMode.NONE, True), - ("RocmAttn", "FULL_DECODE_ONLY", CompilationMode.VLLM_COMPILE, True), - ("RocmAttn", "NONE", CompilationMode.NONE, True), - ("RocmAttn", "NONE", CompilationMode.VLLM_COMPILE, True), - ] -else: - combo_cases_2 = [ - ("FA2", "FULL", CompilationMode.NONE, True), - ("FA2", "FULL", CompilationMode.VLLM_COMPILE, True), - ("FA2", "PIECEWISE", CompilationMode.NONE, True), - ("FA2", "PIECEWISE", CompilationMode.VLLM_COMPILE, True), - ("FA2", "FULL_AND_PIECEWISE", CompilationMode.NONE, True), - ("FA2", "FULL_AND_PIECEWISE", CompilationMode.VLLM_COMPILE, True), - ("FA2", "FULL_DECODE_ONLY", CompilationMode.NONE, True), - ("FA2", "FULL_DECODE_ONLY", CompilationMode.VLLM_COMPILE, True), - ("FA2", "NONE", CompilationMode.NONE, True), - ("FA2", "NONE", CompilationMode.VLLM_COMPILE, True), - ] +attn_backend = "RocmAttn" if current_platform.is_rocm() else "FA2" + +combo_cases_2 = [ + (attn_backend, "FULL", CompilationMode.NONE, True), + (attn_backend, "FULL", CompilationMode.VLLM_COMPILE, True), + (attn_backend, "PIECEWISE", CompilationMode.NONE, True), + (attn_backend, "PIECEWISE", CompilationMode.VLLM_COMPILE, True), + (attn_backend, "FULL_AND_PIECEWISE", CompilationMode.NONE, True), + (attn_backend, "FULL_AND_PIECEWISE", CompilationMode.VLLM_COMPILE, True), + (attn_backend, "FULL_DECODE_ONLY", CompilationMode.NONE, True), + (attn_backend, "FULL_DECODE_ONLY", CompilationMode.VLLM_COMPILE, True), + (attn_backend, "NONE", CompilationMode.NONE, True), + (attn_backend, "NONE", CompilationMode.VLLM_COMPILE, True), +] @pytest.mark.parametrize(