diff --git a/tests/v1/kv_offload/test_cpu_offloading.py b/tests/v1/kv_offload/test_cpu_offloading.py index 406d4c0b4c1f..57474a3dc01e 100644 --- a/tests/v1/kv_offload/test_cpu_offloading.py +++ b/tests/v1/kv_offload/test_cpu_offloading.py @@ -20,6 +20,8 @@ ATTN_BACKENDS = ["FLASH_ATTN"] if current_platform.is_cuda(): ATTN_BACKENDS.append("FLASHINFER") +elif current_platform.is_rocm(): + ATTN_BACKENDS = ["TRITON_ATTN"] class MockSubscriber: