[XPU][P/D] Add XPU support in NixlConnector (#22436)
Signed-off-by: zhenwei <zhenwei.liu@intel.com>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>

parent c29fb540ff
commit e599e2c65e
requirements/xpu.txt
@@ -10,6 +10,7 @@ wheel
 jinja2>=3.1.6
 datasets # for benchmark scripts
 numba == 0.60.0 # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding
+nixl==0.3.0 # for PD disaggregation
 --extra-index-url=https://download.pytorch.org/whl/xpu
 torch==2.8.0+xpu
 torchaudio
vllm/distributed/kv_transfer/kv_connector/utils.py
@@ -6,7 +6,7 @@ KV cache helper for store.
 from collections import defaultdict
 from collections.abc import Sequence
 from concurrent.futures import CancelledError, Future
-from typing import Optional, cast
+from typing import Literal, Optional, Union, cast
 
 import torch
@@ -196,3 +196,51 @@ class KVOutputAggregator:
             output_future.add_done_callback(make_callback(i))
 
         return result_future
+
+
+def _make_src_and_dst_indices(
+    src_block_ids: list[int],
+    dst_block_ids: list[int],
+    src_device: Union[torch.device, str],
+    dst_device: Union[torch.device, str],
+) -> tuple[torch.Tensor, torch.Tensor]:
+    src_indices = torch.tensor(src_block_ids,
+                               device=src_device,
+                               dtype=torch.int64)
+    dst_indices = torch.tensor(dst_block_ids,
+                               device=dst_device,
+                               dtype=torch.int64)
+    return src_indices, dst_indices
+
+
+def copy_kv_blocks(
+    src_kv_caches: dict[str, torch.Tensor],
+    dst_kv_caches: dict[str, torch.Tensor],
+    src_block_ids: list[int],
+    dst_block_ids: list[int],
+    direction: Literal["h2d", "d2h"],
+) -> None:
+    """Copy kv blocks between different buffers."""
+    if not src_kv_caches or not dst_kv_caches or \
+        not src_block_ids or not dst_block_ids or \
+        len(src_block_ids) != len(dst_block_ids):
+        return
+
+    src_device = next(iter(src_kv_caches.values())).device
+    dst_device = next(iter(dst_kv_caches.values())).device
+
+    src_indices, dst_indices = _make_src_and_dst_indices(
+        src_block_ids=src_block_ids,
+        dst_block_ids=dst_block_ids,
+        src_device=src_device,
+        dst_device=dst_device)
+
+    from vllm.platforms import current_platform
+    if direction == "h2d":
+        copy_fn = current_platform.insert_blocks_to_device
+    else:
+        copy_fn = current_platform.swap_out_blocks_to_host
+    for layer_name in src_kv_caches:
+        src_tensor = src_kv_caches[layer_name]
+        dst_tensor = dst_kv_caches[layer_name]
+        copy_fn(src_tensor, dst_tensor, src_indices, dst_indices)
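
Note: copy_kv_blocks dispatches per layer to the platform block-copy hooks
(insert_blocks_to_device / swap_out_blocks_to_host). A minimal sketch of the
index-select semantics those hooks rely on, using plain CPU tensors and a
hypothetical cache shape (not vLLM's actual KV layout):

    import torch

    # Hypothetical shape: (2, num_blocks, block_size, num_heads, head_dim);
    # dim 0 holds K and V, dim 1 holds the paged blocks.
    src = torch.randn(2, 8, 16, 4, 32)
    dst = torch.zeros_like(src)

    src_idx = torch.tensor([0, 3, 5], dtype=torch.int64)
    dst_idx = torch.tensor([1, 2, 4], dtype=torch.int64)

    # XPU-style copy: blocks are selected on dim 1, K and V move together.
    dst[:, dst_idx] = src[:, src_idx]
    assert torch.equal(dst[:, dst_idx], src[:, src_idx])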
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -61,6 +61,7 @@ except ImportError:
 _NIXL_SUPPORTED_XPUS = {
     "cuda": ("cuda", ),
     "tpu": ("cpu", ),
+    "xpu": ("cpu", ),
 }
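
Note: _NIXL_SUPPORTED_XPUS maps each platform to the kv_buffer_device values
NixlConnector accepts, so on XPU the KV cache must be staged through host
("cpu") buffers. A hedged configuration sketch — KVTransferConfig is vLLM's
config class, but the exact kwargs below are an assumption, not taken from
this commit:

    from vllm.config import KVTransferConfig

    # Assumed launch config for an XPU host: only "cpu" is accepted as the
    # buffer device, matching the ("cpu", ) entry above.
    ktc = KVTransferConfig(
        kv_connector="NixlConnector",
        kv_role="kv_both",
        kv_buffer_device="cpu",
    )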
vllm/platforms/tpu.py
@@ -200,6 +200,32 @@ class TpuPlatform(Platform):
                                     model_config: "ModelConfig") -> bool:
         return True
 
+    @classmethod
+    @torch.compile(backend="openxla")
+    def insert_blocks_to_device(
+        cls,
+        src_cache: torch.Tensor,
+        dst_cache: torch.Tensor,
+        src_block_indices: torch.Tensor,
+        dst_block_indices: torch.Tensor,
+    ) -> None:
+        torch.ops.xla.dynamo_set_buffer_donor_(dst_cache, True)
+        dst_cache[dst_block_indices] = src_cache[src_block_indices].to(
+            dst_cache.device)
+
+    @classmethod
+    @torch.compile(backend="openxla")
+    def swap_out_blocks_to_host(
+        cls,
+        src_cache: torch.Tensor,
+        dst_cache: torch.Tensor,
+        src_block_indices: torch.Tensor,
+        dst_block_indices: torch.Tensor,
+    ) -> None:
+        """TPU blocks to CPU blocks."""
+        torch.ops.xla.dynamo_set_buffer_donor_(src_cache, True)
+        dst_cache[dst_block_indices] = src_cache[src_block_indices].cpu()
+
 
 try:
     from tpu_commons.platforms import TpuPlatform as TpuCommonsPlatform
vllm/platforms/xpu.py
@@ -164,6 +164,13 @@ class XPUPlatform(Platform):
                 vllm_config.scheduler_config.max_model_len,
                 DEFAULT_MAX_NUM_BATCHED_TOKENS)
 
+        if (envs.VLLM_KV_CACHE_LAYOUT is None
+                or envs.VLLM_KV_CACHE_LAYOUT != "NHD"):
+            os.environ["VLLM_KV_CACHE_LAYOUT"] = "NHD"
+            logger.info(
+                "Setting VLLM_KV_CACHE_LAYOUT to 'NHD' for XPU; "
+                "only NHD layout is supported by XPU attention kernels.")
+
     @classmethod
     def is_pin_memory_available(cls):
         return True
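
Note: the guard above rewrites os.environ before the engine consumes
envs.VLLM_KV_CACHE_LAYOUT, so any non-"NHD" value (including unset) is
replaced. A standalone sketch of that behavior; force_nhd_layout is a
hypothetical helper mirroring the check, not part of this commit:

    import os

    def force_nhd_layout(current):
        # Anything other than "NHD" (including None/unset) is overridden,
        # since XPU attention kernels only support the NHD layout.
        if current is None or current != "NHD":
            os.environ["VLLM_KV_CACHE_LAYOUT"] = "NHD"
        return os.environ["VLLM_KV_CACHE_LAYOUT"]

    assert force_nhd_layout("HND") == "NHD"
    assert force_nhd_layout(None) == "NHD"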
@@ -210,3 +217,27 @@ class XPUPlatform(Platform):
     @classmethod
     def opaque_attention_op(cls) -> bool:
         return True
+
+    @classmethod
+    def insert_blocks_to_device(
+        cls,
+        src_cache: torch.Tensor,
+        dst_cache: torch.Tensor,
+        src_block_indices: torch.Tensor,
+        dst_block_indices: torch.Tensor,
+    ) -> None:
+        """Copy blocks from src_cache to dst_cache on XPU."""
+        _src_cache = src_cache[:, src_block_indices]
+        dst_cache[:, dst_block_indices] = _src_cache.to(dst_cache.device)
+
+    @classmethod
+    def swap_out_blocks_to_host(
+        cls,
+        src_cache: torch.Tensor,
+        dst_cache: torch.Tensor,
+        src_block_indices: torch.Tensor,
+        dst_block_indices: torch.Tensor,
+    ) -> None:
+        """Copy blocks from XPU to host (CPU)."""
+        _src_cache = src_cache[:, src_block_indices]
+        dst_cache[:, dst_block_indices] = _src_cache.cpu()
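
Note: unlike the TPU hooks, which index blocks on dim 0 of each cache tensor,
the XPU hooks select on dim 1, leaving room for a leading K/V axis. A small
illustration with hypothetical shapes (not vLLM's actual layout):

    import torch

    tpu_style = torch.zeros(8, 16, 4, 32)     # (num_blocks, block_size, heads, head_dim)
    xpu_style = torch.zeros(2, 8, 16, 4, 32)  # (2, num_blocks, block_size, heads, head_dim)
    idx = torch.tensor([1, 3], dtype=torch.int64)

    print(tpu_style[idx].shape)     # TPU: dst_cache[dst_block_indices]
    print(xpu_style[:, idx].shape)  # XPU: dst_cache[:, dst_block_indices]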
vllm/v1/worker/gpu_model_runner.py
@@ -28,6 +28,7 @@ from vllm.config import (CompilationLevel, CUDAGraphMode, VllmConfig,
 from vllm.distributed.eplb.eplb_state import EplbState
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
+from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
 from vllm.distributed.parallel_state import (
     get_pp_group, get_tp_group, graph_capture, is_global_first_rank,
     prepare_communication_buffer_for_model)
@@ -3139,6 +3140,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
         if has_kv_transfer_group():
             get_kv_transfer_group().register_kv_caches(kv_caches)
+            if self.device.type == 'xpu':
+                get_kv_transfer_group().set_host_xfer_buffer_ops(
+                    copy_kv_blocks)
 
     def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
         """
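
Note: only XPU routes NIXL transfers through host staging buffers, so only
XPU registers the copy ops. A hedged sketch of the same wiring as a
standalone helper; the helper name is hypothetical, while
set_host_xfer_buffer_ops is the call shown above:

    import torch
    from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks

    def maybe_register_host_xfer_ops(device: torch.device,
                                     kv_transfer_group) -> None:
        # CUDA reads KV directly over NIXL; XPU stages blocks through
        # host buffers using copy_kv_blocks.
        if device.type == "xpu":
            kv_transfer_group.set_host_xfer_buffer_ops(copy_kv_blocks)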
vllm/v1/worker/tpu_model_runner.py
@@ -3,7 +3,7 @@
 import bisect
 import gc
 import time
-from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
+from typing import TYPE_CHECKING, Any, Optional, cast
 from unittest.mock import patch
 
 import numpy as np
@@ -23,6 +23,7 @@ from vllm.config import (ParallelConfig, VllmConfig,
                          get_layers_from_vllm_config, update_config)
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
+from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.lora.layers import BaseLayerWithLoRA
@@ -1887,75 +1888,6 @@ def _get_padded_token_len(paddings: list[int], x: int) -> int:
     return paddings[index]
 
 
-def _make_src_and_dst_indices(
-    src_block_ids: list[int],
-    dst_block_ids: list[int],
-    src_device: Union[torch.device, str],
-    dst_device: Union[torch.device, str],
-) -> tuple[torch.Tensor, torch.Tensor]:
-    src_indices = torch.tensor(src_block_ids,
-                               device=src_device,
-                               dtype=torch.int64)
-    dst_indices = torch.tensor(dst_block_ids,
-                               device=dst_device,
-                               dtype=torch.int64)
-    return src_indices, dst_indices
-
-
-@torch.compile(backend="openxla")
-def _insert_blocks_to_tpu(
-    cpu_cache: torch.Tensor,
-    tpu_cache: torch.Tensor,
-    cpu_block_indices: torch.Tensor,
-    tpu_block_indices: torch.Tensor,
-) -> None:
-    torch.ops.xla.dynamo_set_buffer_donor_(tpu_cache, True)
-    tpu_cache[tpu_block_indices] = cpu_cache[cpu_block_indices].to(
-        tpu_cache.device)
-
-
-@torch.compile(backend="openxla")
-def _swap_out_tpu_blocks(
-    tpu_cache: torch.Tensor,
-    cpu_cache: torch.Tensor,
-    tpu_block_indices: torch.Tensor,
-    cpu_block_indices: torch.Tensor,
-) -> None:
-    """ tpu blocks to cpu blocks"""
-    torch.ops.xla.dynamo_set_buffer_donor_(tpu_cache, True)
-    cpu_cache[cpu_block_indices] = tpu_cache[tpu_block_indices].cpu()
-
-
-def copy_kv_blocks(
-    src_kv_caches: dict[str, torch.Tensor],
-    dst_kv_caches: dict[str, torch.Tensor],
-    src_block_ids: list[int],
-    dst_block_ids: list[int],
-    direction: Literal["h2d", "d2h"],
-) -> None:
-    """Copy kv blocks between different buffers."""
-    if not src_kv_caches or not dst_kv_caches or \
-        not src_block_ids or not dst_block_ids or \
-        len(src_block_ids) != len(dst_block_ids):
-        return
-
-    src_device = next(iter(src_kv_caches.values())).device
-    dst_device = next(iter(dst_kv_caches.values())).device
-
-    src_indices, dst_indices = _make_src_and_dst_indices(
-        src_block_ids=src_block_ids,
-        dst_block_ids=dst_block_ids,
-        src_device=src_device,
-        dst_device=dst_device)
-
-    _copy_fn = _insert_blocks_to_tpu if direction == "h2d" else \
-        _swap_out_tpu_blocks
-    for layer_name in src_kv_caches:
-        src_tensor = src_kv_caches[layer_name]
-        dst_tensor = dst_kv_caches[layer_name]
-        _copy_fn(src_tensor, dst_tensor, src_indices, dst_indices)
-
-
 def _get_padded_num_kv_cache_update_slices(num_tokens: int, max_num_reqs: int,
                                            page_size: int) -> int:
     """Calculates the padded number of KV cache update slices to avoid