diff --git a/tests/v1/kv_offload/test_cpu_offloading.py b/tests/v1/kv_offload/test_cpu_offloading.py index a5cb23c4ef0f2..b654ea4298dbb 100644 --- a/tests/v1/kv_offload/test_cpu_offloading.py +++ b/tests/v1/kv_offload/test_cpu_offloading.py @@ -12,7 +12,6 @@ from tqdm import tqdm from vllm import LLM, SamplingParams, TokensPrompt from vllm.config import KVEventsConfig, KVTransferConfig from vllm.distributed.kv_events import BlockStored, KVEventBatch -from vllm.platforms import current_platform CPU_BLOCK_SIZES = [16, 48] @@ -64,9 +63,6 @@ class MockSubscriber: self.sub.close() -@pytest.mark.skipif( - not current_platform.is_cuda(), reason="CPU offloading only supported on CUDA" -) @pytest.mark.parametrize("cpu_block_size", CPU_BLOCK_SIZES) def test_cpu_offloading(cpu_block_size: int) -> None: """ diff --git a/vllm/v1/kv_offload/cpu.py b/vllm/v1/kv_offload/cpu.py index 250ed5e95af4b..f765d19ea0175 100644 --- a/vllm/v1/kv_offload/cpu.py +++ b/vllm/v1/kv_offload/cpu.py @@ -51,9 +51,9 @@ class CPUOffloadingSpec(OffloadingSpec): self, kv_caches: dict[str, torch.Tensor] ) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], OffloadingHandler]]: if not self._handler: - if not current_platform.is_cuda(): + if not current_platform.is_cuda_alike(): raise Exception( - "CPU Offloading is currently only supported on CUDA GPUs" + "CPU Offloading is currently only supported on CUDA-alike GPUs" ) layer_names = list(kv_caches.keys())