[Nixl][P/D] Add cuda2cpu support (HD->DH transfer) (#24690)

Signed-off-by: Chenxi Yang <cxyang@fb.com>
Co-authored-by: Chenxi Yang <cxyang@fb.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Chenxi Yang 2025-09-29 07:31:51 -07:00 committed by yewentao256
parent 9f78b9ca84
commit f84b2a0dd0
6 changed files with 96 additions and 15 deletions

View File

@@ -1,6 +1,31 @@
 #!/bin/bash
 set -xe
 
+# Parse command line arguments
+KV_BUFFER_DEVICE="cuda"  # Default to cuda
+while [[ $# -gt 0 ]]; do
+    case $1 in
+        --kv_buffer_device)
+            KV_BUFFER_DEVICE="$2"
+            shift 2
+            ;;
+        *)
+            echo "Unknown option $1"
+            echo "Usage: $0 [--kv_buffer_device <cuda|cpu>]"
+            exit 1
+            ;;
+    esac
+done
+
+echo "Running accuracy tests with kv_buffer_device=$KV_BUFFER_DEVICE"
+
+# Build the kv-transfer-config once
+if [[ "$KV_BUFFER_DEVICE" == "cuda" ]]; then
+    KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"}'
+else
+    KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\"}"
+fi
+
 # Models to run
 MODELS=(
     "Qwen/Qwen3-0.6B"
@@ -93,7 +118,7 @@ run_tests_for_model() {
         --enforce-eager \
         --gpu-memory-utilization 0.2 \
         --tensor-parallel-size $PREFILLER_TP_SIZE \
-        --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'"
+        --kv-transfer-config '$KV_CONFIG'"
 
     if [ -n "$model_args" ]; then
         FULL_CMD="$BASE_CMD $model_args"
@@ -128,7 +153,7 @@ run_tests_for_model() {
         --enforce-eager \
         --gpu-memory-utilization 0.2 \
         --tensor-parallel-size $DECODER_TP_SIZE \
-        --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'"
+        --kv-transfer-config '$KV_CONFIG'"
 
     if [ -n "$model_args" ]; then
         FULL_CMD="$BASE_CMD $model_args"

View File

@@ -1,6 +1,33 @@
 #!/bin/bash
 set -xe
 
+# Parse command line arguments
+KV_BUFFER_DEVICE="cuda"  # Default to cuda
+PREFILL_GPU_ID=4  # Default GPU IDs
+DECODE_GPU_ID=5
+while [[ $# -gt 0 ]]; do
+    case $1 in
+        --kv_buffer_device)
+            KV_BUFFER_DEVICE="$2"
+            shift 2
+            ;;
+        *)
+            echo "Unknown option $1"
+            echo "Usage: $0 [--kv_buffer_device <cuda|cpu>]"
+            exit 1
+            ;;
+    esac
+done
+
+echo "Running edge case tests with kv_buffer_device=$KV_BUFFER_DEVICE (GPUs: $PREFILL_GPU_ID, $DECODE_GPU_ID)"
+
+# Build the kv-transfer-config once
+if [[ "$KV_BUFFER_DEVICE" == "cuda" ]]; then
+    KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"}'
+else
+    KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\"}"
+fi
+
 # Models to run
 MODELS=(
     "Qwen/Qwen3-0.6B"
@@ -54,11 +81,11 @@ run_tests_for_model() {
   # Start prefill instance
   PREFILL_PORT=8001
 
-  BASE_CMD="CUDA_VISIBLE_DEVICES=0 VLLM_NIXL_SIDE_CHANNEL_PORT=5559 vllm serve $model_name \
+  BASE_CMD="CUDA_VISIBLE_DEVICES=$PREFILL_GPU_ID VLLM_NIXL_SIDE_CHANNEL_PORT=5559 vllm serve $model_name \
       --port $PREFILL_PORT \
       --enforce-eager \
       --gpu-memory-utilization 0.2 \
-      --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'"
+      --kv-transfer-config '$KV_CONFIG'"
 
   if [ -n "$model_args" ]; then
       FULL_CMD="$BASE_CMD $model_args"
@@ -72,11 +99,11 @@ run_tests_for_model() {
   DECODE_PORT=8002
 
   # Build the command with or without model-specific args
-  BASE_CMD="CUDA_VISIBLE_DEVICES=1 VLLM_NIXL_SIDE_CHANNEL_PORT=6000 vllm serve $model_name \
+  BASE_CMD="CUDA_VISIBLE_DEVICES=$DECODE_GPU_ID VLLM_NIXL_SIDE_CHANNEL_PORT=6000 vllm serve $model_name \
      --port $DECODE_PORT \
      --enforce-eager \
      --gpu-memory-utilization 0.2 \
-     --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'"
+     --kv-transfer-config '$KV_CONFIG'"
 
   if [ -n "$model_args" ]; then
      FULL_CMD="$BASE_CMD $model_args"

View File

@@ -28,8 +28,8 @@ class KVTransferConfig:
     """The engine id for KV transfers."""
     kv_buffer_device: Optional[str] = "cuda"
-    """The device used by kv connector to buffer the KV cache.
-    Currently only support 'cuda'."""
+    """The device used by kv connector to buffer the KV cache. Choices are
+    'cuda' and 'cpu'."""
     kv_buffer_size: float = 1e9
     """The buffer size for TorchDistributedConnector. Measured in number of

View File

@@ -67,7 +67,10 @@ except ImportError:
 # Supported platforms and types of kv transfer buffer.
 # {device: tuple of supported kv buffer types}
 _NIXL_SUPPORTED_DEVICE = {
-    "cuda": ("cuda", ),
+    "cuda": (
+        "cuda",
+        "cpu",
+    ),
     "tpu": ("cpu", ),
     "xpu": ("cpu", ),
 }
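The mapping now lets a CUDA platform stage KV blocks in host memory. A standalone sketch of the kind of lookup a connector can do against this table (helper name and error message are illustrative, not the connector's API):

    # Table mirrors _NIXL_SUPPORTED_DEVICE above (sketch only).
    _SUPPORTED = {
        "cuda": ("cuda", "cpu"),
        "tpu": ("cpu",),
        "xpu": ("cpu",),
    }

    def validate_kv_buffer_device(platform: str, kv_buffer_device: str) -> None:
        supported = _SUPPORTED.get(platform, ())
        if kv_buffer_device not in supported:
            raise ValueError(
                f"kv_buffer_device={kv_buffer_device!r} is not supported on "
                f"{platform!r}; supported: {supported}")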
@@ -701,6 +704,9 @@ class NixlConnectorWorker:
 
     def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp):
         """Assign copy (d2h, h2d) operations when host buffer is used."""
+        # Set a no-op if the host buffer is not cpu.
+        if self.kv_buffer_device != "cpu":
+            return
         assert self.use_host_buffer
         self.copy_blocks = copy_operation
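The guard makes the hook a no-op unless the buffer is CPU-backed, so callers can register the copy operation unconditionally. For orientation, a self-contained sketch of a callable of the shape this hook expects (the real CopyBlocksOp type is defined in vLLM and may differ; the signature here is an assumption):

    from typing import Literal
    import torch

    def copy_blocks_sketch(
        src_caches: dict[str, torch.Tensor],  # per-layer source KV caches
        dst_caches: dict[str, torch.Tensor],  # per-layer destination KV caches
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
        direction: Literal["d2h", "h2d"],
    ) -> None:
        # d2h: device -> host buffer; h2d: host buffer -> device.
        for layer_name, src in src_caches.items():
            dst = dst_caches[layer_name]
            blocks = src[:, src_block_indices]
            dst[:, dst_block_indices] = (
                blocks.cpu() if direction == "d2h" else blocks.to(dst.device))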

View File

@@ -500,6 +500,30 @@ class CudaPlatformBase(Platform):
                 "You can use float16 instead by explicitly setting the "
                 "`dtype` flag in CLI, for example: --dtype=half.")
 
+    @classmethod
+    def insert_blocks_to_device(
+        cls,
+        src_cache: torch.Tensor,
+        dst_cache: torch.Tensor,
+        src_block_indices: torch.Tensor,
+        dst_block_indices: torch.Tensor,
+    ) -> None:
+        """Copy blocks from src_cache to dst_cache on GPU."""
+        _src_cache = src_cache[:, src_block_indices]
+        dst_cache[:, dst_block_indices] = _src_cache.to(dst_cache.device)
+
+    @classmethod
+    def swap_out_blocks_to_host(
+        cls,
+        src_cache: torch.Tensor,
+        dst_cache: torch.Tensor,
+        src_block_indices: torch.Tensor,
+        dst_block_indices: torch.Tensor,
+    ) -> None:
+        """Copy blocks from GPU to host (CPU)."""
+        _src_cache = src_cache[:, src_block_indices]
+        dst_cache[:, dst_block_indices] = _src_cache.cpu()
+
     @classmethod
     def support_hybrid_kv_cache(cls) -> bool:
         return True
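Both helpers use the same gather-move-scatter pattern over the block axis. A runnable sketch with tiny CPU-only tensors (shapes are illustrative; the real caches are per-layer KV tensors, and .cpu() is the device-to-host move when src_cache lives on GPU):

    import torch

    num_layers, num_blocks, block_size = 2, 8, 4
    src_cache = torch.randn(num_layers, num_blocks, block_size)  # "device" cache
    dst_cache = torch.zeros(num_layers, num_blocks, block_size)  # host buffer

    src_idx = torch.tensor([1, 3, 5])
    dst_idx = torch.tensor([0, 1, 2])

    # Same pattern as swap_out_blocks_to_host: select blocks, move, write back.
    dst_cache[:, dst_idx] = src_cache[:, src_idx].cpu()
    assert torch.equal(dst_cache[:, dst_idx], src_cache[:, src_idx])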

View File

@@ -4059,10 +4059,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             self.drafter.validate_same_kv_cache_group(kv_cache_config)
 
         if has_kv_transfer_group():
-            get_kv_transfer_group().register_kv_caches(kv_caches)
-            if self.device.type == 'xpu':
-                get_kv_transfer_group().set_host_xfer_buffer_ops(
-                    copy_kv_blocks)
+            kv_transfer_group = get_kv_transfer_group()
+            kv_transfer_group.register_kv_caches(kv_caches)
+            kv_transfer_group.set_host_xfer_buffer_ops(copy_kv_blocks)
 
         if self.dcp_world_size > 1:
             layer_names = self.attn_groups[0][0].layer_names