diff --git a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
index bc88370791096..3b0f2d102c1ff 100755
--- a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
+++ b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
@@ -1,6 +1,31 @@
 #!/bin/bash
 set -xe
 
+# Parse command line arguments
+KV_BUFFER_DEVICE="cuda"  # Default to cuda
+while [[ $# -gt 0 ]]; do
+  case $1 in
+    --kv_buffer_device)
+      KV_BUFFER_DEVICE="$2"
+      shift 2
+      ;;
+    *)
+      echo "Unknown option $1"
+      echo "Usage: $0 [--kv_buffer_device <cuda|cpu>]"
+      exit 1
+      ;;
+  esac
+done
+
+echo "Running accuracy tests with kv_buffer_device=$KV_BUFFER_DEVICE"
+
+# Build the kv-transfer-config once
+if [[ "$KV_BUFFER_DEVICE" == "cuda" ]]; then
+  KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"}'
+else
+  KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\"}"
+fi
+
 # Models to run
 MODELS=(
     "Qwen/Qwen3-0.6B"
@@ -79,7 +104,7 @@ run_tests_for_model() {
 
     # Calculate port number (base port + instance number)
     PORT=$((8100 + i))
-    # Calculate side channel port. Avoid clash with with TP workers.
+    # Calculate side channel port. Avoid clash with TP workers.
    SIDE_CHANNEL_PORT=$((5559 + i))
 
     echo "Starting prefill instance $i on GPU $GPU_ID, port $PORT"
@@ -93,7 +118,7 @@ run_tests_for_model() {
       --enforce-eager \
       --gpu-memory-utilization 0.2 \
       --tensor-parallel-size $PREFILLER_TP_SIZE \
-      --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'"
+      --kv-transfer-config '$KV_CONFIG'"
 
     if [ -n "$model_args" ]; then
       FULL_CMD="$BASE_CMD $model_args"
@@ -128,7 +153,7 @@ run_tests_for_model() {
      --enforce-eager \
      --gpu-memory-utilization 0.2 \
      --tensor-parallel-size $DECODER_TP_SIZE \
-     --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'"
+     --kv-transfer-config '$KV_CONFIG'"
 
     if [ -n "$model_args" ]; then
       FULL_CMD="$BASE_CMD $model_args"
diff --git a/tests/v1/kv_connector/nixl_integration/run_edge_case_test.sh b/tests/v1/kv_connector/nixl_integration/run_edge_case_test.sh
old mode 100644
new mode 100755
index b64461292910d..c48b452e24cd4
--- a/tests/v1/kv_connector/nixl_integration/run_edge_case_test.sh
+++ b/tests/v1/kv_connector/nixl_integration/run_edge_case_test.sh
@@ -1,6 +1,33 @@
 #!/bin/bash
 set -xe
 
+# Parse command line arguments
+KV_BUFFER_DEVICE="cuda"  # Default to cuda
+PREFILL_GPU_ID=4  # Default GPU IDs
+DECODE_GPU_ID=5
+while [[ $# -gt 0 ]]; do
+  case $1 in
+    --kv_buffer_device)
+      KV_BUFFER_DEVICE="$2"
+      shift 2
+      ;;
+    *)
+      echo "Unknown option $1"
+      echo "Usage: $0 [--kv_buffer_device <cuda|cpu>]"
+      exit 1
+      ;;
+  esac
+done
+
+echo "Running edge case tests with kv_buffer_device=$KV_BUFFER_DEVICE (GPUs: $PREFILL_GPU_ID, $DECODE_GPU_ID)"
+
+# Build the kv-transfer-config once
+if [[ "$KV_BUFFER_DEVICE" == "cuda" ]]; then
+  KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"}'
+else
+  KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\"}"
+fi
+
 # Models to run
 MODELS=(
     "Qwen/Qwen3-0.6B"
@@ -50,15 +77,15 @@ run_tests_for_model() {
 
     # Get model-specific arguments
     local model_args=$(get_model_args "$model_name")
-    
+
     # Start prefill instance
     PREFILL_PORT=8001
 
-    BASE_CMD="CUDA_VISIBLE_DEVICES=0 VLLM_NIXL_SIDE_CHANNEL_PORT=5559 vllm serve $model_name \
+    BASE_CMD="CUDA_VISIBLE_DEVICES=$PREFILL_GPU_ID VLLM_NIXL_SIDE_CHANNEL_PORT=5559 vllm serve $model_name \
     --port $PREFILL_PORT \
     --enforce-eager \
     --gpu-memory-utilization 0.2 \
-    --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'"
+    --kv-transfer-config '$KV_CONFIG'"
 
     if [ -n "$model_args" ]; then
       FULL_CMD="$BASE_CMD $model_args"
@@ -72,11 +99,11 @@ run_tests_for_model() {
     DECODE_PORT=8002
 
     # Build the command with or without model-specific args
-    BASE_CMD="CUDA_VISIBLE_DEVICES=1 VLLM_NIXL_SIDE_CHANNEL_PORT=6000 vllm serve $model_name \
+    BASE_CMD="CUDA_VISIBLE_DEVICES=$DECODE_GPU_ID VLLM_NIXL_SIDE_CHANNEL_PORT=6000 vllm serve $model_name \
    --port $DECODE_PORT \
    --enforce-eager \
    --gpu-memory-utilization 0.2 \
-    --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'"
+    --kv-transfer-config '$KV_CONFIG'"
 
     if [ -n "$model_args" ]; then
       FULL_CMD="$BASE_CMD $model_args"
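Aside on the KV_CONFIG construction above: the escaped bash is easy to misread, so here is a minimal Python sketch of the same branching (build_kv_config is a hypothetical helper; the scripts inline this logic in bash). The default "cuda" branch omits kv_buffer_device entirely, so default runs send exactly the same config string as the pre-patch scripts:

    import json

    def build_kv_config(kv_buffer_device: str = "cuda") -> str:
        """Mirror of the scripts' KV_CONFIG branching (illustrative only)."""
        cfg = {"kv_connector": "NixlConnector", "kv_role": "kv_both"}
        # Omit the key for the default so the server falls back to its own
        # default; any other value is passed through verbatim.
        if kv_buffer_device != "cuda":
            cfg["kv_buffer_device"] = kv_buffer_device
        return json.dumps(cfg, separators=(",", ":"))

    print(build_kv_config())       # {"kv_connector":"NixlConnector","kv_role":"kv_both"}
    print(build_kv_config("cpu"))  # ...,"kv_role":"kv_both","kv_buffer_device":"cpu"}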
diff --git a/vllm/config/kv_transfer.py b/vllm/config/kv_transfer.py
index 9abf4acacfe81..c3d9a3309eb3a 100644
--- a/vllm/config/kv_transfer.py
+++ b/vllm/config/kv_transfer.py
@@ -28,8 +28,8 @@ class KVTransferConfig:
     """The engine id for KV transfers."""
 
     kv_buffer_device: Optional[str] = "cuda"
-    """The device used by kv connector to buffer the KV cache.
-    Currently only support 'cuda'."""
+    """The device used by kv connector to buffer the KV cache. Choices are
+    'cuda' and 'cpu'."""
 
     kv_buffer_size: float = 1e9
     """The buffer size for TorchDistributedConnector. Measured in number of
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index c11189d7ec109..1c7569515dec7
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -67,7 +67,10 @@ except ImportError:
 # Supported platforms and types of kv transfer buffer.
 # {device: tuple of supported kv buffer types}
 _NIXL_SUPPORTED_DEVICE = {
-    "cuda": ("cuda", ),
+    "cuda": (
+        "cuda",
+        "cpu",
+    ),
     "tpu": ("cpu", ),
     "xpu": ("cpu", ),
 }
@@ -701,6 +704,9 @@ class NixlConnectorWorker:
 
     def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp):
         """Assign copy (d2h, h2d) operations when host buffer is used."""
+        # Skip setting copy ops when the kv buffer is not on CPU.
+        if self.kv_buffer_device != "cpu":
+            return
         assert self.use_host_buffer
         self.copy_blocks = copy_operation
 
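Aside: with this table, a cuda platform now accepts either a "cuda" or a "cpu" kv buffer, while tpu and xpu remain cpu-only. A self-contained sketch of the kind of lookup a caller can perform against _NIXL_SUPPORTED_DEVICE (validate_kv_buffer_device is an illustrative helper, not the connector's actual validation code):

    _NIXL_SUPPORTED_DEVICE = {
        "cuda": ("cuda", "cpu"),
        "tpu": ("cpu",),
        "xpu": ("cpu",),
    }

    def validate_kv_buffer_device(platform: str, kv_buffer_device: str) -> None:
        # Reject any (platform, buffer) pair the table does not list.
        supported = _NIXL_SUPPORTED_DEVICE.get(platform, ())
        if kv_buffer_device not in supported:
            raise ValueError(
                f"kv_buffer_device={kv_buffer_device!r} is not supported on "
                f"platform {platform!r}; supported: {supported}")

    validate_kv_buffer_device("cuda", "cpu")   # OK after this patch
    try:
        validate_kv_buffer_device("tpu", "cuda")
    except ValueError as exc:
        print(exc)  # tpu only supports a cpu kv buffer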
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index 8b9f9f569206f..6738d3dec2861 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -500,6 +500,30 @@ class CudaPlatformBase(Platform):
                 "You can use float16 instead by explicitly setting the "
                 "`dtype` flag in CLI, for example: --dtype=half.")
 
+    @classmethod
+    def insert_blocks_to_device(
+        cls,
+        src_cache: torch.Tensor,
+        dst_cache: torch.Tensor,
+        src_block_indices: torch.Tensor,
+        dst_block_indices: torch.Tensor,
+    ) -> None:
+        """Copy blocks from src_cache to dst_cache on GPU."""
+        _src_cache = src_cache[:, src_block_indices]
+        dst_cache[:, dst_block_indices] = _src_cache.to(dst_cache.device)
+
+    @classmethod
+    def swap_out_blocks_to_host(
+        cls,
+        src_cache: torch.Tensor,
+        dst_cache: torch.Tensor,
+        src_block_indices: torch.Tensor,
+        dst_block_indices: torch.Tensor,
+    ) -> None:
+        """Copy blocks from GPU to host (CPU)."""
+        _src_cache = src_cache[:, src_block_indices]
+        dst_cache[:, dst_block_indices] = _src_cache.cpu()
+
     @classmethod
     def support_hybrid_kv_cache(cls) -> bool:
         return True
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 0960fe3a25fb8..f8b0b9cba1bc1 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -4059,10 +4059,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             self.drafter.validate_same_kv_cache_group(kv_cache_config)
 
         if has_kv_transfer_group():
-            get_kv_transfer_group().register_kv_caches(kv_caches)
-            if self.device.type == 'xpu':
-                get_kv_transfer_group().set_host_xfer_buffer_ops(
-                    copy_kv_blocks)
+            kv_transfer_group = get_kv_transfer_group()
+            kv_transfer_group.register_kv_caches(kv_caches)
+            kv_transfer_group.set_host_xfer_buffer_ops(copy_kv_blocks)
 
         if self.dcp_world_size > 1:
             layer_names = self.attn_groups[0][0].layer_names
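Aside: both new CudaPlatformBase helpers use the same advanced-indexing pattern over a cache assumed to be laid out as [2 (K and V), num_blocks, ...], differing only in where the copy lands (.to(dst_cache.device) vs .cpu()). A self-contained demo of that indexing, run on CPU tensors so it needs no GPU (shapes and indices are made up for illustration):

    import torch

    # Toy cache: [2 (K/V), num_blocks, block_numel].
    src_cache = torch.arange(2 * 4 * 3, dtype=torch.float32).reshape(2, 4, 3)
    dst_cache = torch.zeros(2, 8, 3)

    src_idx = torch.tensor([0, 2])  # blocks to read from src_cache
    dst_idx = torch.tensor([5, 1])  # slots to fill in dst_cache

    # Same gather/scatter as swap_out_blocks_to_host, minus the device hop:
    # the K and V planes of each selected block move together.
    dst_cache[:, dst_idx] = src_cache[:, src_idx]

    assert torch.equal(dst_cache[:, 5], src_cache[:, 0])
    assert torch.equal(dst_cache[:, 1], src_cache[:, 2])

Note that .cpu() in swap_out_blocks_to_host is a synchronous device-to-host transfer (torch defaults to non_blocking=False), so the swap-out completes before the method returns.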