mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-11 06:45:01 +08:00
[ROCm][CI] Fix test_cudagraph_mode failure in AMD CI (#29367)
Signed-off-by: Micah Williamson <micah.williamson@amd.com>
This commit is contained in:
parent
12c007e288
commit
ef1f7030f0
@ -340,4 +340,11 @@ full_cg_backend_configs = {
|
|||||||
"cudagraph_mode": "FULL_AND_PIECEWISE",
|
"cudagraph_mode": "FULL_AND_PIECEWISE",
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
|
"RocmAttn": BackendConfig(
|
||||||
|
name="RocmAttn",
|
||||||
|
env_vars={"VLLM_V1_USE_PREFILL_DECODE_ATTENTION": "1"},
|
||||||
|
comp_config={
|
||||||
|
"cudagraph_mode": "FULL",
|
||||||
|
},
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|||||||
@ -35,14 +35,22 @@ def temporary_environ(env_vars):
|
|||||||
|
|
||||||
# test attention backend and cudagraph_mode combo
|
# test attention backend and cudagraph_mode combo
|
||||||
# (backend_name, cudagraph_mode, supported)
|
# (backend_name, cudagraph_mode, supported)
|
||||||
combo_cases_1 = [
|
if current_platform.is_rocm():
|
||||||
("FA3", "FULL", True),
|
combo_cases_1 = [
|
||||||
("FA3", "FULL_AND_PIECEWISE", True),
|
("RocmAttn", "FULL", True),
|
||||||
("FA2", "FULL", True), # Should fallback to FULL_AND_PIECEWISE
|
("RocmAttn", "FULL_AND_PIECEWISE", True),
|
||||||
("FA2", "FULL_AND_PIECEWISE", True),
|
("TritonAttn", "FULL", True),
|
||||||
("FlashInfer", "FULL", True), # Should fallback to FULL_AND_PIECEWISE
|
("TritonAttn", "FULL_AND_PIECEWISE", True),
|
||||||
("FlashInfer", "FULL_AND_PIECEWISE", True),
|
]
|
||||||
]
|
else:
|
||||||
|
combo_cases_1 = [
|
||||||
|
("FA3", "FULL", True),
|
||||||
|
("FA3", "FULL_AND_PIECEWISE", True),
|
||||||
|
("FA2", "FULL", True), # Should fallback to FULL_AND_PIECEWISE
|
||||||
|
("FA2", "FULL_AND_PIECEWISE", True),
|
||||||
|
("FlashInfer", "FULL", True), # Should fallback to FULL_AND_PIECEWISE
|
||||||
|
("FlashInfer", "FULL_AND_PIECEWISE", True),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("backend_name, cudagraph_mode, supported", combo_cases_1)
|
@pytest.mark.parametrize("backend_name, cudagraph_mode, supported", combo_cases_1)
|
||||||
@ -92,18 +100,32 @@ def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, supporte
|
|||||||
|
|
||||||
# test cudagraph_mode with different compilation mode.
|
# test cudagraph_mode with different compilation mode.
|
||||||
# (backend_name, cudagraph_mode, compilation_mode, supported)
|
# (backend_name, cudagraph_mode, compilation_mode, supported)
|
||||||
combo_cases_2 = [
|
if current_platform.is_rocm():
|
||||||
("FA2", "FULL", CompilationMode.NONE, True),
|
combo_cases_2 = [
|
||||||
("FA2", "FULL", CompilationMode.VLLM_COMPILE, True),
|
("RocmAttn", "FULL", CompilationMode.NONE, True),
|
||||||
("FA2", "PIECEWISE", CompilationMode.NONE, False),
|
("RocmAttn", "FULL", CompilationMode.VLLM_COMPILE, True),
|
||||||
("FA2", "PIECEWISE", CompilationMode.VLLM_COMPILE, True),
|
("RocmAttn", "PIECEWISE", CompilationMode.NONE, False),
|
||||||
("FA2", "FULL_AND_PIECEWISE", CompilationMode.NONE, False),
|
("RocmAttn", "PIECEWISE", CompilationMode.VLLM_COMPILE, True),
|
||||||
("FA2", "FULL_AND_PIECEWISE", CompilationMode.VLLM_COMPILE, True),
|
("RocmAttn", "FULL_AND_PIECEWISE", CompilationMode.NONE, False),
|
||||||
("FA2", "FULL_DECODE_ONLY", CompilationMode.NONE, True),
|
("RocmAttn", "FULL_AND_PIECEWISE", CompilationMode.VLLM_COMPILE, True),
|
||||||
("FA2", "FULL_DECODE_ONLY", CompilationMode.VLLM_COMPILE, True),
|
("RocmAttn", "FULL_DECODE_ONLY", CompilationMode.NONE, True),
|
||||||
("FA2", "NONE", CompilationMode.NONE, True),
|
("RocmAttn", "FULL_DECODE_ONLY", CompilationMode.VLLM_COMPILE, True),
|
||||||
("FA2", "NONE", CompilationMode.VLLM_COMPILE, True),
|
("RocmAttn", "NONE", CompilationMode.NONE, True),
|
||||||
]
|
("RocmAttn", "NONE", CompilationMode.VLLM_COMPILE, True),
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
combo_cases_2 = [
|
||||||
|
("FA2", "FULL", CompilationMode.NONE, True),
|
||||||
|
("FA2", "FULL", CompilationMode.VLLM_COMPILE, True),
|
||||||
|
("FA2", "PIECEWISE", CompilationMode.NONE, False),
|
||||||
|
("FA2", "PIECEWISE", CompilationMode.VLLM_COMPILE, True),
|
||||||
|
("FA2", "FULL_AND_PIECEWISE", CompilationMode.NONE, False),
|
||||||
|
("FA2", "FULL_AND_PIECEWISE", CompilationMode.VLLM_COMPILE, True),
|
||||||
|
("FA2", "FULL_DECODE_ONLY", CompilationMode.NONE, True),
|
||||||
|
("FA2", "FULL_DECODE_ONLY", CompilationMode.VLLM_COMPILE, True),
|
||||||
|
("FA2", "NONE", CompilationMode.NONE, True),
|
||||||
|
("FA2", "NONE", CompilationMode.VLLM_COMPILE, True),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
|||||||
@ -321,8 +321,8 @@ class RocmPlatform(Platform):
|
|||||||
return AttentionBackendEnum.TRITON_ATTN.get_path()
|
return AttentionBackendEnum.TRITON_ATTN.get_path()
|
||||||
|
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"V0 attention backends have been removed. Set VLLM_USE_V1=1 "
|
f"Attention backend {selected_backend.name} is not supported on "
|
||||||
"to select a supported backend."
|
"ROCm. Note that V0 attention backends have been removed."
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user