diff --git a/tests/v1/cudagraph/test_cudagraph_mode.py b/tests/v1/cudagraph/test_cudagraph_mode.py index 12621d493e54..b1895e83b8b3 100644 --- a/tests/v1/cudagraph/test_cudagraph_mode.py +++ b/tests/v1/cudagraph/test_cudagraph_mode.py @@ -100,32 +100,20 @@ def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, supporte # test cudagraph_mode with different compilation mode. # (backend_name, cudagraph_mode, compilation_mode, supported) -if current_platform.is_rocm(): - combo_cases_2 = [ - ("RocmAttn", "FULL", CompilationMode.NONE, True), - ("RocmAttn", "FULL", CompilationMode.VLLM_COMPILE, True), - ("RocmAttn", "PIECEWISE", CompilationMode.NONE, False), - ("RocmAttn", "PIECEWISE", CompilationMode.VLLM_COMPILE, True), - ("RocmAttn", "FULL_AND_PIECEWISE", CompilationMode.NONE, False), - ("RocmAttn", "FULL_AND_PIECEWISE", CompilationMode.VLLM_COMPILE, True), - ("RocmAttn", "FULL_DECODE_ONLY", CompilationMode.NONE, True), - ("RocmAttn", "FULL_DECODE_ONLY", CompilationMode.VLLM_COMPILE, True), - ("RocmAttn", "NONE", CompilationMode.NONE, True), - ("RocmAttn", "NONE", CompilationMode.VLLM_COMPILE, True), - ] -else: - combo_cases_2 = [ - ("FA2", "FULL", CompilationMode.NONE, True), - ("FA2", "FULL", CompilationMode.VLLM_COMPILE, True), - ("FA2", "PIECEWISE", CompilationMode.NONE, True), - ("FA2", "PIECEWISE", CompilationMode.VLLM_COMPILE, True), - ("FA2", "FULL_AND_PIECEWISE", CompilationMode.NONE, True), - ("FA2", "FULL_AND_PIECEWISE", CompilationMode.VLLM_COMPILE, True), - ("FA2", "FULL_DECODE_ONLY", CompilationMode.NONE, True), - ("FA2", "FULL_DECODE_ONLY", CompilationMode.VLLM_COMPILE, True), - ("FA2", "NONE", CompilationMode.NONE, True), - ("FA2", "NONE", CompilationMode.VLLM_COMPILE, True), - ] +attn_backend = "RocmAttn" if current_platform.is_rocm() else "FA2" + +combo_cases_2 = [ + (attn_backend, "FULL", CompilationMode.NONE, True), + (attn_backend, "FULL", CompilationMode.VLLM_COMPILE, True), + (attn_backend, "PIECEWISE", CompilationMode.NONE, True), + (attn_backend, "PIECEWISE", CompilationMode.VLLM_COMPILE, True), + (attn_backend, "FULL_AND_PIECEWISE", CompilationMode.NONE, True), + (attn_backend, "FULL_AND_PIECEWISE", CompilationMode.VLLM_COMPILE, True), + (attn_backend, "FULL_DECODE_ONLY", CompilationMode.NONE, True), + (attn_backend, "FULL_DECODE_ONLY", CompilationMode.VLLM_COMPILE, True), + (attn_backend, "NONE", CompilationMode.NONE, True), + (attn_backend, "NONE", CompilationMode.VLLM_COMPILE, True), +] @pytest.mark.parametrize(