[XPU][P/D] Add XPU support in NixlConnector (#22436)

Signed-off-by: zhenwei <zhenwei.liu@intel.com>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
liuzhenwei 2025-09-05 12:03:12 +08:00 committed by GitHub
parent c29fb540ff
commit e599e2c65e
7 changed files with 114 additions and 71 deletions

View File

@@ -10,6 +10,7 @@ wheel
jinja2>=3.1.6
datasets # for benchmark scripts
numba == 0.60.0 # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding
nixl==0.3.0 # for PD disaggregation
--extra-index-url=https://download.pytorch.org/whl/xpu
torch==2.8.0+xpu
torchaudio
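A quick way to confirm the pinned wheel is actually present in the XPU environment, using only the standard library (a hypothetical check, not part of this change):

# Hypothetical sanity check: is the nixl wheel needed for PD disaggregation installed?
from importlib.metadata import PackageNotFoundError, version

try:
    print("nixl", version("nixl"))  # requirements pin it to 0.3.0
except PackageNotFoundError:
    print("nixl is not installed; NixlConnector-based PD disaggregation is unavailable")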

View File

@@ -6,7 +6,7 @@ KV cache helper for store.
from collections import defaultdict
from collections.abc import Sequence
from concurrent.futures import CancelledError, Future
from typing import Optional, cast
from typing import Literal, Optional, Union, cast
import torch
@@ -196,3 +196,51 @@ class KVOutputAggregator:
            output_future.add_done_callback(make_callback(i))

        return result_future


def _make_src_and_dst_indices(
    src_block_ids: list[int],
    dst_block_ids: list[int],
    src_device: Union[torch.device, str],
    dst_device: Union[torch.device, str],
) -> tuple[torch.Tensor, torch.Tensor]:
    src_indices = torch.tensor(src_block_ids,
                               device=src_device,
                               dtype=torch.int64)
    dst_indices = torch.tensor(dst_block_ids,
                               device=dst_device,
                               dtype=torch.int64)
    return src_indices, dst_indices


def copy_kv_blocks(
    src_kv_caches: dict[str, torch.Tensor],
    dst_kv_caches: dict[str, torch.Tensor],
    src_block_ids: list[int],
    dst_block_ids: list[int],
    direction: Literal["h2d", "d2h"],
) -> None:
    """Copy kv blocks between different buffers."""
    if not src_kv_caches or not dst_kv_caches or \
        not src_block_ids or not dst_block_ids or \
        len(src_block_ids) != len(dst_block_ids):
        return

    src_device = next(iter(src_kv_caches.values())).device
    dst_device = next(iter(dst_kv_caches.values())).device

    src_indices, dst_indices = _make_src_and_dst_indices(
        src_block_ids=src_block_ids,
        dst_block_ids=dst_block_ids,
        src_device=src_device,
        dst_device=dst_device)

    # Dispatch to the platform-specific copy routine for the requested direction.
    from vllm.platforms import current_platform
    if direction == "h2d":
        copy_fn = current_platform.insert_blocks_to_device
    else:
        copy_fn = current_platform.swap_out_blocks_to_host
    for layer_name in src_kv_caches:
        src_tensor = src_kv_caches[layer_name]
        dst_tensor = dst_kv_caches[layer_name]
        copy_fn(src_tensor, dst_tensor, src_indices, dst_indices)
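To make the data flow concrete, here is a CPU-only sketch of the per-layer copy that copy_kv_blocks drives. A plain indexed assignment stands in for the platform hook, and the shapes and layer names are illustrative, not taken from vLLM:

import torch

def toy_copy_fn(src, dst, src_idx, dst_idx):
    # Stand-in for insert_blocks_to_device / swap_out_blocks_to_host.
    dst[dst_idx] = src[src_idx]

# One tensor per attention layer; the first dim is indexed by block id.
src_kv = {"layers.0.attn": torch.randn(8, 16, 4, 32)}  # host staging buffer
dst_kv = {"layers.0.attn": torch.zeros(8, 16, 4, 32)}  # device cache stand-in

src_idx = torch.tensor([0, 2], dtype=torch.int64)
dst_idx = torch.tensor([5, 6], dtype=torch.int64)

for name in src_kv:
    toy_copy_fn(src_kv[name], dst_kv[name], src_idx, dst_idx)

assert torch.equal(dst_kv["layers.0.attn"][5], src_kv["layers.0.attn"][0])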

View File

@@ -61,6 +61,7 @@ except ImportError:
_NIXL_SUPPORTED_XPUS = {
    "cuda": ("cuda", ),
    "tpu": ("cpu", ),
    "xpu": ("cpu", ),
}
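This table reads as: a worker's device type maps to the kv_buffer_device values the connector can register with NIXL, so XPU (like TPU) must stage KV blocks through a CPU buffer. A simplified stand-in for the validation the connector performs (illustrative, not the connector's actual code):

_NIXL_SUPPORTED_XPUS = {
    "cuda": ("cuda", ),
    "tpu": ("cpu", ),
    "xpu": ("cpu", ),
}

def check_nixl_support(device_type: str, kv_buffer_device: str) -> None:
    # Illustrative only: reject device/buffer combinations NIXL cannot serve.
    allowed = _NIXL_SUPPORTED_XPUS.get(device_type)
    if allowed is None:
        raise RuntimeError(f"{device_type} is not supported by NixlConnector")
    if kv_buffer_device not in allowed:
        raise RuntimeError(
            f"kv_buffer_device={kv_buffer_device!r} is not supported on "
            f"{device_type}; expected one of {allowed}")

check_nixl_support("xpu", "cpu")  # passes after this change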

View File

@@ -200,6 +200,32 @@ class TpuPlatform(Platform):
                                   model_config: "ModelConfig") -> bool:
        return True

    @classmethod
    @torch.compile(backend="openxla")
    def insert_blocks_to_device(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy host (CPU) blocks into the TPU cache (h2d)."""
        torch.ops.xla.dynamo_set_buffer_donor_(dst_cache, True)
        dst_cache[dst_block_indices] = src_cache[src_block_indices].to(
            dst_cache.device)

    @classmethod
    @torch.compile(backend="openxla")
    def swap_out_blocks_to_host(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy TPU blocks out to host (CPU) blocks (d2h)."""
        torch.ops.xla.dynamo_set_buffer_donor_(src_cache, True)
        dst_cache[dst_block_indices] = src_cache[src_block_indices].cpu()

try:
    from tpu_commons.platforms import TpuPlatform as TpuCommonsPlatform
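Once the classmethod binds cls, both hooks (here and in the XPU change below) expose the same call shape, which is what lets copy_kv_blocks select either one purely by direction. A typing sketch of that shared signature (an illustration, not an interface defined by this PR):

from typing import Protocol

import torch

class BlockCopyOp(Protocol):
    """Call shape shared by insert_blocks_to_device and swap_out_blocks_to_host."""

    def __call__(self, src_cache: torch.Tensor, dst_cache: torch.Tensor,
                 src_block_indices: torch.Tensor,
                 dst_block_indices: torch.Tensor) -> None:
        ...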

View File

@@ -164,6 +164,13 @@ class XPUPlatform(Platform):
                vllm_config.scheduler_config.max_model_len,
                DEFAULT_MAX_NUM_BATCHED_TOKENS)

        if (envs.VLLM_KV_CACHE_LAYOUT is None
                or envs.VLLM_KV_CACHE_LAYOUT != "NHD"):
            os.environ["VLLM_KV_CACHE_LAYOUT"] = "NHD"
            logger.info(
                "Setting VLLM_KV_CACHE_LAYOUT to 'NHD' for XPU; "
                "only NHD layout is supported by XPU attention kernels.")

    @classmethod
    def is_pin_memory_available(cls):
        return True
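For orientation on what the layout forced above means: "NHD" keeps the token dimension ahead of the head dimension inside each KV-cache block, while "HND" swaps them. A small illustrative sketch with made-up shape names (not taken from this PR):

import torch

# Illustrative: N = tokens per block, H = kv heads, D = head size.
num_blocks, block_size, num_kv_heads, head_size = 4, 16, 8, 64

nhd = torch.empty(num_blocks, block_size, num_kv_heads, head_size)  # "NHD"
hnd = nhd.permute(0, 2, 1, 3)                                       # "HND" view

print(nhd.shape, hnd.shape)  # [4, 16, 8, 64] vs [4, 8, 16, 64]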
@@ -210,3 +217,27 @@ class XPUPlatform(Platform):
    @classmethod
    def opaque_attention_op(cls) -> bool:
        return True

    @classmethod
    def insert_blocks_to_device(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy blocks from src_cache to dst_cache on XPU."""
        _src_cache = src_cache[:, src_block_indices]
        dst_cache[:, dst_block_indices] = _src_cache.to(dst_cache.device)

    @classmethod
    def swap_out_blocks_to_host(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy blocks from XPU to host (CPU)."""
        _src_cache = src_cache[:, src_block_indices]
        dst_cache[:, dst_block_indices] = _src_cache.cpu()
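Note the indexing: blocks live in dim 1 of the XPU cache tensor (dim 0 presumably separates K and V), so both copies slice dim 1, whereas the TPU hooks above index dim 0. A CPU-only sketch of the same pattern with made-up shapes:

import torch

# Made-up shape: (kv, num_blocks, block_size, num_kv_heads, head_size).
host_cache = torch.randn(2, 8, 16, 4, 32)  # host staging buffer
xpu_cache = torch.zeros(2, 8, 16, 4, 32)   # stands in for the device cache

src_idx = torch.tensor([0, 3], dtype=torch.int64)
dst_idx = torch.tensor([2, 4], dtype=torch.int64)

# Same pattern as insert_blocks_to_device ("h2d"): slice blocks along dim 1.
xpu_cache[:, dst_idx] = host_cache[:, src_idx].to(xpu_cache.device)

assert torch.equal(xpu_cache[:, 4], host_cache[:, 3])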

View File

@@ -28,6 +28,7 @@ from vllm.config import (CompilationLevel, CUDAGraphMode, VllmConfig,
from vllm.distributed.eplb.eplb_state import EplbState
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                          has_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
from vllm.distributed.parallel_state import (
    get_pp_group, get_tp_group, graph_capture, is_global_first_rank,
    prepare_communication_buffer_for_model)
@@ -3139,6 +3140,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        if has_kv_transfer_group():
            get_kv_transfer_group().register_kv_caches(kv_caches)
            if self.device.type == 'xpu':
                get_kv_transfer_group().set_host_xfer_buffer_ops(
                    copy_kv_blocks)

    def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
        """

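For context on why only XPU takes this branch: with kv_buffer_device on the CPU, the connector needs a callable to stage blocks between the device cache and its host-resident NIXL buffer, and copy_kv_blocks is that callable. A toy stand-in for the registration side (not the connector's actual implementation):

from typing import Callable, Literal, Optional

import torch

# Signature of the staging callable being registered above (copy_kv_blocks).
CopyBlocksOp = Callable[[dict[str, torch.Tensor], dict[str, torch.Tensor],
                         list[int], list[int], Literal["h2d", "d2h"]], None]

class ToyHostBufferConnector:
    """Toy model of a connector whose NIXL-registered buffer lives on the host."""

    def __init__(self) -> None:
        self.copy_blocks: Optional[CopyBlocksOp] = None

    def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp) -> None:
        # Stored and later used to stage blocks host<->device around transfers.
        self.copy_blocks = copy_operation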
View File

@@ -3,7 +3,7 @@
import bisect
import gc
import time
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
from typing import TYPE_CHECKING, Any, Optional, cast
from unittest.mock import patch

import numpy as np

@@ -23,6 +23,7 @@ from vllm.config import (ParallelConfig, VllmConfig,
                         get_layers_from_vllm_config, update_config)
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                          has_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.lora.layers import BaseLayerWithLoRA

@@ -1887,75 +1888,6 @@ def _get_padded_token_len(paddings: list[int], x: int) -> int:
    return paddings[index]
def _make_src_and_dst_indices(
    src_block_ids: list[int],
    dst_block_ids: list[int],
    src_device: Union[torch.device, str],
    dst_device: Union[torch.device, str],
) -> tuple[torch.Tensor, torch.Tensor]:
    src_indices = torch.tensor(src_block_ids,
                               device=src_device,
                               dtype=torch.int64)
    dst_indices = torch.tensor(dst_block_ids,
                               device=dst_device,
                               dtype=torch.int64)
    return src_indices, dst_indices


@torch.compile(backend="openxla")
def _insert_blocks_to_tpu(
    cpu_cache: torch.Tensor,
    tpu_cache: torch.Tensor,
    cpu_block_indices: torch.Tensor,
    tpu_block_indices: torch.Tensor,
) -> None:
    torch.ops.xla.dynamo_set_buffer_donor_(tpu_cache, True)
    tpu_cache[tpu_block_indices] = cpu_cache[cpu_block_indices].to(
        tpu_cache.device)


@torch.compile(backend="openxla")
def _swap_out_tpu_blocks(
    tpu_cache: torch.Tensor,
    cpu_cache: torch.Tensor,
    tpu_block_indices: torch.Tensor,
    cpu_block_indices: torch.Tensor,
) -> None:
    """ tpu blocks to cpu blocks"""
    torch.ops.xla.dynamo_set_buffer_donor_(tpu_cache, True)
    cpu_cache[cpu_block_indices] = tpu_cache[tpu_block_indices].cpu()


def copy_kv_blocks(
    src_kv_caches: dict[str, torch.Tensor],
    dst_kv_caches: dict[str, torch.Tensor],
    src_block_ids: list[int],
    dst_block_ids: list[int],
    direction: Literal["h2d", "d2h"],
) -> None:
    """Copy kv blocks between different buffers."""
    if not src_kv_caches or not dst_kv_caches or \
        not src_block_ids or not dst_block_ids or \
        len(src_block_ids) != len(dst_block_ids):
        return

    src_device = next(iter(src_kv_caches.values())).device
    dst_device = next(iter(dst_kv_caches.values())).device

    src_indices, dst_indices = _make_src_and_dst_indices(
        src_block_ids=src_block_ids,
        dst_block_ids=dst_block_ids,
        src_device=src_device,
        dst_device=dst_device)

    _copy_fn = _insert_blocks_to_tpu if direction == "h2d" else \
        _swap_out_tpu_blocks
    for layer_name in src_kv_caches:
        src_tensor = src_kv_caches[layer_name]
        dst_tensor = dst_kv_caches[layer_name]
        _copy_fn(src_tensor, dst_tensor, src_indices, dst_indices)


def _get_padded_num_kv_cache_update_slices(num_tokens: int, max_num_reqs: int,
                                           page_size: int) -> int:
    """Calculates the padded number of KV cache update slices to avoid