diff --git a/tests/v1/kv_offload/test_cpu_gpu.py b/tests/v1/kv_offload/test_cpu_gpu.py
index 81b57f1ca0c8..0d4fa344d298 100644
--- a/tests/v1/kv_offload/test_cpu_gpu.py
+++ b/tests/v1/kv_offload/test_cpu_gpu.py
@@ -8,11 +8,20 @@ import torch
 
 from vllm.platforms import current_platform
 from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
-from vllm.v1.attention.backends.flashinfer import FlashInferBackend
-from vllm.v1.attention.backends.mla.flashattn_mla import FlashAttnMLABackend
 from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec
 from vllm.v1.kv_offload.worker.cpu_gpu import CpuGpuOffloadingHandler
 
+BACKENDS_TO_TEST = [FlashAttentionBackend]
+
+if not current_platform.is_rocm():
+    from vllm.v1.attention.backends.flashinfer import FlashInferBackend
+
+    BACKENDS_TO_TEST.append(FlashInferBackend)
+
+    from vllm.v1.attention.backends.mla.flashattn_mla import FlashAttnMLABackend
+
+    BACKENDS_TO_TEST.append(FlashAttnMLABackend)
+
 NUM_GPU_BLOCKS = [64]
 NUM_CPU_BLOCKS = [256]
 GPU_BLOCK_SIZES = [16]
@@ -55,8 +64,8 @@ def test_transfer(
 ) -> None:
     current_platform.seed_everything(seed)
 
-    # create per-layer GPU KV caches
-    attn_backends_list = [FlashAttentionBackend, FlashInferBackend, FlashAttnMLABackend]
+    # create per-layer GPU KV caches based on available attn_backends
+    attn_backends_list = BACKENDS_TO_TEST
 
     gpu_caches = {}
     attn_backends = {}