From e599e2c65ee32abcc986733ab0a55becea158bb4 Mon Sep 17 00:00:00 2001
From: liuzhenwei
Date: Fri, 5 Sep 2025 12:03:12 +0800
Subject: [PATCH] [XPU][P/D] Add XPU support in NixlConnector (#22436)

Signed-off-by: zhenwei
Co-authored-by: Kunshang Ji
---
 requirements/xpu.txt                          |  1 +
 .../kv_transfer/kv_connector/utils.py         | 50 ++++++++++++-
 .../kv_connector/v1/nixl_connector.py         |  1 +
 vllm/platforms/tpu.py                         | 26 +++++++
 vllm/platforms/xpu.py                         | 31 ++++++++
 vllm/v1/worker/gpu_model_runner.py            |  4 ++
 vllm/v1/worker/tpu_model_runner.py            | 72 +------------------
 7 files changed, 114 insertions(+), 71 deletions(-)

diff --git a/requirements/xpu.txt b/requirements/xpu.txt
index 4607c3efdf14..c44a2a9c74e5 100644
--- a/requirements/xpu.txt
+++ b/requirements/xpu.txt
@@ -10,6 +10,7 @@ wheel
 jinja2>=3.1.6
 datasets # for benchmark scripts
 numba == 0.60.0 # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding
+nixl==0.3.0 # for PD disaggregation
 --extra-index-url=https://download.pytorch.org/whl/xpu
 torch==2.8.0+xpu
 torchaudio
diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py
index 2364400b3d35..f4dc248a1279 100644
--- a/vllm/distributed/kv_transfer/kv_connector/utils.py
+++ b/vllm/distributed/kv_transfer/kv_connector/utils.py
@@ -6,7 +6,7 @@ KV cache helper for store.
 from collections import defaultdict
 from collections.abc import Sequence
 from concurrent.futures import CancelledError, Future
-from typing import Optional, cast
+from typing import Literal, Optional, Union, cast
 
 import torch
 
@@ -196,3 +196,51 @@ class KVOutputAggregator:
             output_future.add_done_callback(make_callback(i))
 
         return result_future
+
+
+def _make_src_and_dst_indices(
+    src_block_ids: list[int],
+    dst_block_ids: list[int],
+    src_device: Union[torch.device, str],
+    dst_device: Union[torch.device, str],
+) -> tuple[torch.Tensor, torch.Tensor]:
+    src_indices = torch.tensor(src_block_ids,
+                               device=src_device,
+                               dtype=torch.int64)
+    dst_indices = torch.tensor(dst_block_ids,
+                               device=dst_device,
+                               dtype=torch.int64)
+    return src_indices, dst_indices
+
+
+def copy_kv_blocks(
+    src_kv_caches: dict[str, torch.Tensor],
+    dst_kv_caches: dict[str, torch.Tensor],
+    src_block_ids: list[int],
+    dst_block_ids: list[int],
+    direction: Literal["h2d", "d2h"],
+) -> None:
+    """Copy kv blocks between different buffers."""
+    if not src_kv_caches or not dst_kv_caches or \
+        not src_block_ids or not dst_block_ids or \
+        len(src_block_ids) != len(dst_block_ids):
+        return
+
+    src_device = next(iter(src_kv_caches.values())).device
+    dst_device = next(iter(dst_kv_caches.values())).device
+
+    src_indices, dst_indices = _make_src_and_dst_indices(
+        src_block_ids=src_block_ids,
+        dst_block_ids=dst_block_ids,
+        src_device=src_device,
+        dst_device=dst_device)
+
+    from vllm.platforms import current_platform
+    if direction == "h2d":
+        copy_fn = current_platform.insert_blocks_to_device
+    else:
+        copy_fn = current_platform.swap_out_blocks_to_host
+    for layer_name in src_kv_caches:
+        src_tensor = src_kv_caches[layer_name]
+        dst_tensor = dst_kv_caches[layer_name]
+        copy_fn(src_tensor, dst_tensor, src_indices, dst_indices)
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index de9cbc660666..c2f73fa28155 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -61,6 +61,7 @@ except ImportError:
 
 _NIXL_SUPPORTED_XPUS = {
     "cuda": ("cuda", ),
     "tpu": ("cpu", ),
+    "xpu": ("cpu", ),
 }
 
diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py
index d7468d74b021..6a061956d814 100644
--- a/vllm/platforms/tpu.py
+++ b/vllm/platforms/tpu.py
@@ -200,6 +200,32 @@ class TpuPlatform(Platform):
                      model_config: "ModelConfig") -> bool:
         return True
 
+    @classmethod
+    @torch.compile(backend="openxla")
+    def insert_blocks_to_device(
+        cls,
+        src_cache: torch.Tensor,
+        dst_cache: torch.Tensor,
+        src_block_indices: torch.Tensor,
+        dst_block_indices: torch.Tensor,
+    ) -> None:
+        torch.ops.xla.dynamo_set_buffer_donor_(dst_cache, True)
+        dst_cache[dst_block_indices] = src_cache[src_block_indices].to(
+            dst_cache.device)
+
+    @classmethod
+    @torch.compile(backend="openxla")
+    def swap_out_blocks_to_host(
+        cls,
+        src_cache: torch.Tensor,
+        dst_cache: torch.Tensor,
+        src_block_indices: torch.Tensor,
+        dst_block_indices: torch.Tensor,
+    ) -> None:
+        """ tpu blocks to cpu blocks"""
+        torch.ops.xla.dynamo_set_buffer_donor_(src_cache, True)
+        dst_cache[dst_block_indices] = src_cache[src_block_indices].cpu()
+
 
 try:
     from tpu_commons.platforms import TpuPlatform as TpuCommonsPlatform
diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py
index 9f89334e9a8a..32208e7fff01 100644
--- a/vllm/platforms/xpu.py
+++ b/vllm/platforms/xpu.py
@@ -164,6 +164,13 @@ class XPUPlatform(Platform):
                 vllm_config.scheduler_config.max_model_len,
                 DEFAULT_MAX_NUM_BATCHED_TOKENS)
 
+        if (envs.VLLM_KV_CACHE_LAYOUT is None
+                or envs.VLLM_KV_CACHE_LAYOUT != "NHD"):
+            os.environ["VLLM_KV_CACHE_LAYOUT"] = "NHD"
+            logger.info(
+                "Setting VLLM_KV_CACHE_LAYOUT to 'NHD' for XPU; "
+                "only NHD layout is supported by XPU attention kernels.")
+
     @classmethod
     def is_pin_memory_available(cls):
         return True
@@ -210,3 +217,27 @@ class XPUPlatform(Platform):
     @classmethod
     def opaque_attention_op(cls) -> bool:
         return True
+
+    @classmethod
+    def insert_blocks_to_device(
+        cls,
+        src_cache: torch.Tensor,
+        dst_cache: torch.Tensor,
+        src_block_indices: torch.Tensor,
+        dst_block_indices: torch.Tensor,
+    ) -> None:
+        """Copy blocks from src_cache to dst_cache on XPU."""
+        _src_cache = src_cache[:, src_block_indices]
+        dst_cache[:, dst_block_indices] = _src_cache.to(dst_cache.device)
+
+    @classmethod
+    def swap_out_blocks_to_host(
+        cls,
+        src_cache: torch.Tensor,
+        dst_cache: torch.Tensor,
+        src_block_indices: torch.Tensor,
+        dst_block_indices: torch.Tensor,
+    ) -> None:
+        """Copy blocks from XPU to host (CPU)."""
+        _src_cache = src_cache[:, src_block_indices]
+        dst_cache[:, dst_block_indices] = _src_cache.cpu()
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 4556a51b809d..42baf020e9dc 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -28,6 +28,7 @@ from vllm.config import (CompilationLevel, CUDAGraphMode, VllmConfig,
 from vllm.distributed.eplb.eplb_state import EplbState
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
+from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
 from vllm.distributed.parallel_state import (
     get_pp_group, get_tp_group, graph_capture, is_global_first_rank,
     prepare_communication_buffer_for_model)
@@ -3139,6 +3140,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
         if has_kv_transfer_group():
             get_kv_transfer_group().register_kv_caches(kv_caches)
+            if self.device.type == 'xpu':
+                get_kv_transfer_group().set_host_xfer_buffer_ops(
+                    copy_kv_blocks)
 
     def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
         """
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index 985d5ba58c49..5947b54d33ce 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -3,7 +3,7 @@
 import bisect
 import gc
 import time
-from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
+from typing import TYPE_CHECKING, Any, Optional, cast
 from unittest.mock import patch
 
 import numpy as np
@@ -23,6 +23,7 @@ from vllm.config import (ParallelConfig, VllmConfig,
                          get_layers_from_vllm_config, update_config)
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
+from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.lora.layers import BaseLayerWithLoRA
@@ -1887,75 +1888,6 @@ def _get_padded_token_len(paddings: list[int], x: int) -> int:
     return paddings[index]
 
 
-def _make_src_and_dst_indices(
-    src_block_ids: list[int],
-    dst_block_ids: list[int],
-    src_device: Union[torch.device, str],
-    dst_device: Union[torch.device, str],
-) -> tuple[torch.Tensor, torch.Tensor]:
-    src_indices = torch.tensor(src_block_ids,
-                               device=src_device,
-                               dtype=torch.int64)
-    dst_indices = torch.tensor(dst_block_ids,
-                               device=dst_device,
-                               dtype=torch.int64)
-    return src_indices, dst_indices
-
-
-@torch.compile(backend="openxla")
-def _insert_blocks_to_tpu(
-    cpu_cache: torch.Tensor,
-    tpu_cache: torch.Tensor,
-    cpu_block_indices: torch.Tensor,
-    tpu_block_indices: torch.Tensor,
-) -> None:
-    torch.ops.xla.dynamo_set_buffer_donor_(tpu_cache, True)
-    tpu_cache[tpu_block_indices] = cpu_cache[cpu_block_indices].to(
-        tpu_cache.device)
-
-
-@torch.compile(backend="openxla")
-def _swap_out_tpu_blocks(
-    tpu_cache: torch.Tensor,
-    cpu_cache: torch.Tensor,
-    tpu_block_indices: torch.Tensor,
-    cpu_block_indices: torch.Tensor,
-) -> None:
-    """ tpu blocks to cpu blocks"""
-    torch.ops.xla.dynamo_set_buffer_donor_(tpu_cache, True)
-    cpu_cache[cpu_block_indices] = tpu_cache[tpu_block_indices].cpu()
-
-
-def copy_kv_blocks(
-    src_kv_caches: dict[str, torch.Tensor],
-    dst_kv_caches: dict[str, torch.Tensor],
-    src_block_ids: list[int],
-    dst_block_ids: list[int],
-    direction: Literal["h2d", "d2h"],
-) -> None:
-    """Copy kv blocks between different buffers."""
-    if not src_kv_caches or not dst_kv_caches or \
-        not src_block_ids or not dst_block_ids or \
-        len(src_block_ids) != len(dst_block_ids):
-        return
-
-    src_device = next(iter(src_kv_caches.values())).device
-    dst_device = next(iter(dst_kv_caches.values())).device
-
-    src_indices, dst_indices = _make_src_and_dst_indices(
-        src_block_ids=src_block_ids,
-        dst_block_ids=dst_block_ids,
-        src_device=src_device,
-        dst_device=dst_device)
-
-    _copy_fn = _insert_blocks_to_tpu if direction == "h2d" else \
-        _swap_out_tpu_blocks
-    for layer_name in src_kv_caches:
-        src_tensor = src_kv_caches[layer_name]
-        dst_tensor = dst_kv_caches[layer_name]
-        _copy_fn(src_tensor, dst_tensor, src_indices, dst_indices)
-
-
 def _get_padded_num_kv_cache_update_slices(num_tokens: int, max_num_reqs: int,
                                            page_size: int) -> int:
     """Calculates the padded number of KV cache update slices to avoid
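
The Platform hooks added above (insert_blocks_to_device / swap_out_blocks_to_host) are plain gather/scatter copies keyed by block indices, and copy_kv_blocks only selects which hook to call from the "h2d"/"d2h" direction before looping over the per-layer cache dict. The standalone sketch below illustrates that copy semantics with ordinary CPU tensors standing in for XPU memory; the tensor shapes, block counts, and variable names are assumptions made for illustration only and are not vLLM APIs. Note that the XPU hooks index the block axis at dim 1 (the leading dim is assumed here to hold the K/V split), while the TPU hooks index dim 0.

# Standalone illustration (not part of the patch) of the gather/scatter copy
# performed by the XPU-style insert_blocks_to_device / swap_out_blocks_to_host
# hooks. Plain CPU tensors stand in for XPU memory; shapes are made up.
import torch

num_blocks, block_numel = 8, 4
# Assumed layout: [kv_split, num_blocks, block_numel]; the block axis is dim 1,
# matching the src_cache[:, src_block_indices] indexing used by the XPU hooks.
host_cache = torch.arange(2 * num_blocks * block_numel,
                          dtype=torch.float32).reshape(2, num_blocks,
                                                       block_numel)
device_cache = torch.zeros_like(host_cache)  # stand-in for an XPU tensor

src_block_indices = torch.tensor([1, 3], dtype=torch.int64)
dst_block_indices = torch.tensor([0, 2], dtype=torch.int64)

# "h2d" path: gather the selected host blocks and scatter them into the device
# cache (the .to() is a no-op here because both tensors live on the CPU).
device_cache[:, dst_block_indices] = host_cache[:, src_block_indices].to(
    device_cache.device)

# "d2h" path back into a fresh host buffer.
host_copy = torch.zeros_like(host_cache)
host_copy[:, dst_block_indices] = device_cache[:, dst_block_indices].cpu()

assert torch.equal(device_cache[:, 0], host_cache[:, 1])
assert torch.equal(host_copy[:, 2], host_cache[:, 3])
print("block gather/scatter copy OK")

In the connector path itself, the patch registers copy_kv_blocks via set_host_xfer_buffer_ops when the device type is 'xpu', so the same per-layer indexing is applied whenever KV blocks are staged through the host buffer that NIXL registers for the "xpu": ("cpu", ) entry.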