diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py index dea89babd4b4..df3d53332c7c 100644 --- a/tests/v1/attention/utils.py +++ b/tests/v1/attention/utils.py @@ -340,4 +340,11 @@ full_cg_backend_configs = { "cudagraph_mode": "FULL_AND_PIECEWISE", }, ), + "RocmAttn": BackendConfig( + name="RocmAttn", + env_vars={"VLLM_V1_USE_PREFILL_DECODE_ATTENTION": "1"}, + comp_config={ + "cudagraph_mode": "FULL", + }, + ), } diff --git a/tests/v1/cudagraph/test_cudagraph_mode.py b/tests/v1/cudagraph/test_cudagraph_mode.py index d6bde16eba36..7f9c2a0571c3 100644 --- a/tests/v1/cudagraph/test_cudagraph_mode.py +++ b/tests/v1/cudagraph/test_cudagraph_mode.py @@ -35,14 +35,22 @@ def temporary_environ(env_vars): # test attention backend and cudagraph_mode combo # (backend_name, cudagraph_mode, supported) -combo_cases_1 = [ - ("FA3", "FULL", True), - ("FA3", "FULL_AND_PIECEWISE", True), - ("FA2", "FULL", True), # Should fallback to FULL_AND_PIECEWISE - ("FA2", "FULL_AND_PIECEWISE", True), - ("FlashInfer", "FULL", True), # Should fallback to FULL_AND_PIECEWISE - ("FlashInfer", "FULL_AND_PIECEWISE", True), -] +if current_platform.is_rocm(): + combo_cases_1 = [ + ("RocmAttn", "FULL", True), + ("RocmAttn", "FULL_AND_PIECEWISE", True), + ("TritonAttn", "FULL", True), + ("TritonAttn", "FULL_AND_PIECEWISE", True), + ] +else: + combo_cases_1 = [ + ("FA3", "FULL", True), + ("FA3", "FULL_AND_PIECEWISE", True), + ("FA2", "FULL", True), # Should fallback to FULL_AND_PIECEWISE + ("FA2", "FULL_AND_PIECEWISE", True), + ("FlashInfer", "FULL", True), # Should fallback to FULL_AND_PIECEWISE + ("FlashInfer", "FULL_AND_PIECEWISE", True), + ] @pytest.mark.parametrize("backend_name, cudagraph_mode, supported", combo_cases_1) @@ -92,18 +100,32 @@ def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, supporte # test cudagraph_mode with different compilation mode. # (backend_name, cudagraph_mode, compilation_mode, supported) -combo_cases_2 = [ - ("FA2", "FULL", CompilationMode.NONE, True), - ("FA2", "FULL", CompilationMode.VLLM_COMPILE, True), - ("FA2", "PIECEWISE", CompilationMode.NONE, False), - ("FA2", "PIECEWISE", CompilationMode.VLLM_COMPILE, True), - ("FA2", "FULL_AND_PIECEWISE", CompilationMode.NONE, False), - ("FA2", "FULL_AND_PIECEWISE", CompilationMode.VLLM_COMPILE, True), - ("FA2", "FULL_DECODE_ONLY", CompilationMode.NONE, True), - ("FA2", "FULL_DECODE_ONLY", CompilationMode.VLLM_COMPILE, True), - ("FA2", "NONE", CompilationMode.NONE, True), - ("FA2", "NONE", CompilationMode.VLLM_COMPILE, True), -] +if current_platform.is_rocm(): + combo_cases_2 = [ + ("RocmAttn", "FULL", CompilationMode.NONE, True), + ("RocmAttn", "FULL", CompilationMode.VLLM_COMPILE, True), + ("RocmAttn", "PIECEWISE", CompilationMode.NONE, False), + ("RocmAttn", "PIECEWISE", CompilationMode.VLLM_COMPILE, True), + ("RocmAttn", "FULL_AND_PIECEWISE", CompilationMode.NONE, False), + ("RocmAttn", "FULL_AND_PIECEWISE", CompilationMode.VLLM_COMPILE, True), + ("RocmAttn", "FULL_DECODE_ONLY", CompilationMode.NONE, True), + ("RocmAttn", "FULL_DECODE_ONLY", CompilationMode.VLLM_COMPILE, True), + ("RocmAttn", "NONE", CompilationMode.NONE, True), + ("RocmAttn", "NONE", CompilationMode.VLLM_COMPILE, True), + ] +else: + combo_cases_2 = [ + ("FA2", "FULL", CompilationMode.NONE, True), + ("FA2", "FULL", CompilationMode.VLLM_COMPILE, True), + ("FA2", "PIECEWISE", CompilationMode.NONE, False), + ("FA2", "PIECEWISE", CompilationMode.VLLM_COMPILE, True), + ("FA2", "FULL_AND_PIECEWISE", CompilationMode.NONE, False), + ("FA2", "FULL_AND_PIECEWISE", CompilationMode.VLLM_COMPILE, True), + ("FA2", "FULL_DECODE_ONLY", CompilationMode.NONE, True), + ("FA2", "FULL_DECODE_ONLY", CompilationMode.VLLM_COMPILE, True), + ("FA2", "NONE", CompilationMode.NONE, True), + ("FA2", "NONE", CompilationMode.VLLM_COMPILE, True), + ] @pytest.mark.parametrize( diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index b0434b9642f0..0483f6c06ada 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -321,8 +321,8 @@ class RocmPlatform(Platform): return AttentionBackendEnum.TRITON_ATTN.get_path() raise RuntimeError( - "V0 attention backends have been removed. Set VLLM_USE_V1=1 " - "to select a supported backend." + f"Attention backend {selected_backend.name} is not supported on " + "ROCm. Note that V0 attention backends have been removed." ) @classmethod