[ROCm][CI] Fix test_cudagraph_mode.py Failure For AMD CI (#29808)

Signed-off-by: Micah Williamson <micah.williamson@amd.com>
This commit is contained in:
Micah Williamson 2025-12-02 16:54:36 -06:00 committed by GitHub
parent 1b1e35aaf9
commit c014de1ec7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -100,32 +100,20 @@ def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, supporte
# test cudagraph_mode with different compilation mode.
# (backend_name, cudagraph_mode, compilation_mode, supported)
if current_platform.is_rocm():
combo_cases_2 = [
("RocmAttn", "FULL", CompilationMode.NONE, True),
("RocmAttn", "FULL", CompilationMode.VLLM_COMPILE, True),
("RocmAttn", "PIECEWISE", CompilationMode.NONE, False),
("RocmAttn", "PIECEWISE", CompilationMode.VLLM_COMPILE, True),
("RocmAttn", "FULL_AND_PIECEWISE", CompilationMode.NONE, False),
("RocmAttn", "FULL_AND_PIECEWISE", CompilationMode.VLLM_COMPILE, True),
("RocmAttn", "FULL_DECODE_ONLY", CompilationMode.NONE, True),
("RocmAttn", "FULL_DECODE_ONLY", CompilationMode.VLLM_COMPILE, True),
("RocmAttn", "NONE", CompilationMode.NONE, True),
("RocmAttn", "NONE", CompilationMode.VLLM_COMPILE, True),
]
else:
combo_cases_2 = [
("FA2", "FULL", CompilationMode.NONE, True),
("FA2", "FULL", CompilationMode.VLLM_COMPILE, True),
("FA2", "PIECEWISE", CompilationMode.NONE, True),
("FA2", "PIECEWISE", CompilationMode.VLLM_COMPILE, True),
("FA2", "FULL_AND_PIECEWISE", CompilationMode.NONE, True),
("FA2", "FULL_AND_PIECEWISE", CompilationMode.VLLM_COMPILE, True),
("FA2", "FULL_DECODE_ONLY", CompilationMode.NONE, True),
("FA2", "FULL_DECODE_ONLY", CompilationMode.VLLM_COMPILE, True),
("FA2", "NONE", CompilationMode.NONE, True),
("FA2", "NONE", CompilationMode.VLLM_COMPILE, True),
]
attn_backend = "RocmAttn" if current_platform.is_rocm() else "FA2"
combo_cases_2 = [
(attn_backend, "FULL", CompilationMode.NONE, True),
(attn_backend, "FULL", CompilationMode.VLLM_COMPILE, True),
(attn_backend, "PIECEWISE", CompilationMode.NONE, True),
(attn_backend, "PIECEWISE", CompilationMode.VLLM_COMPILE, True),
(attn_backend, "FULL_AND_PIECEWISE", CompilationMode.NONE, True),
(attn_backend, "FULL_AND_PIECEWISE", CompilationMode.VLLM_COMPILE, True),
(attn_backend, "FULL_DECODE_ONLY", CompilationMode.NONE, True),
(attn_backend, "FULL_DECODE_ONLY", CompilationMode.VLLM_COMPILE, True),
(attn_backend, "NONE", CompilationMode.NONE, True),
(attn_backend, "NONE", CompilationMode.VLLM_COMPILE, True),
]
@pytest.mark.parametrize(