Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[XPU][P/D] Add XPU support in NixlConnector (#22436)
Signed-off-by: zhenwei <zhenwei.liu@intel.com>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
commit e599e2c65e
parent c29fb540ff
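For orientation, here is a minimal sketch of how this connector is typically enabled from Python. It is not part of the diff; the config fields (`kv_connector`, `kv_role`) are taken from vLLM's `KVTransferConfig` and the model name is a placeholder.

# Hypothetical usage sketch: run a vLLM engine on an XPU host with the
# NIXL-based KV connector for prefill/decode (P/D) disaggregation.
# Requires the nixl==0.3.0 package added to the XPU requirements below.
from vllm import LLM
from vllm.config import KVTransferConfig

llm = LLM(
    model="facebook/opt-125m",  # placeholder model
    kv_transfer_config=KVTransferConfig(
        kv_connector="NixlConnector",
        kv_role="kv_both",
    ),
)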
@@ -10,6 +10,7 @@ wheel
 jinja2>=3.1.6
 datasets # for benchmark scripts
 numba == 0.60.0 # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding
+nixl==0.3.0 # for PD disaggregation
 --extra-index-url=https://download.pytorch.org/whl/xpu
 torch==2.8.0+xpu
 torchaudio
@@ -6,7 +6,7 @@ KV cache helper for store.
 from collections import defaultdict
 from collections.abc import Sequence
 from concurrent.futures import CancelledError, Future
-from typing import Optional, cast
+from typing import Literal, Optional, Union, cast
 
 import torch
 
@@ -196,3 +196,51 @@ class KVOutputAggregator:
             output_future.add_done_callback(make_callback(i))
 
         return result_future
+
+
+def _make_src_and_dst_indices(
+    src_block_ids: list[int],
+    dst_block_ids: list[int],
+    src_device: Union[torch.device, str],
+    dst_device: Union[torch.device, str],
+) -> tuple[torch.Tensor, torch.Tensor]:
+    src_indices = torch.tensor(src_block_ids,
+                               device=src_device,
+                               dtype=torch.int64)
+    dst_indices = torch.tensor(dst_block_ids,
+                               device=dst_device,
+                               dtype=torch.int64)
+    return src_indices, dst_indices
+
+
+def copy_kv_blocks(
+    src_kv_caches: dict[str, torch.Tensor],
+    dst_kv_caches: dict[str, torch.Tensor],
+    src_block_ids: list[int],
+    dst_block_ids: list[int],
+    direction: Literal["h2d", "d2h"],
+) -> None:
+    """Copy kv blocks between different buffers."""
+    if not src_kv_caches or not dst_kv_caches or \
+        not src_block_ids or not dst_block_ids or \
+        len(src_block_ids) != len(dst_block_ids):
+        return
+
+    src_device = next(iter(src_kv_caches.values())).device
+    dst_device = next(iter(dst_kv_caches.values())).device
+
+    src_indices, dst_indices = _make_src_and_dst_indices(
+        src_block_ids=src_block_ids,
+        dst_block_ids=dst_block_ids,
+        src_device=src_device,
+        dst_device=dst_device)
+
+    from vllm.platforms import current_platform
+    if direction == "h2d":
+        copy_fn = current_platform.insert_blocks_to_device
+    else:
+        copy_fn = current_platform.swap_out_blocks_to_host
+    for layer_name in src_kv_caches:
+        src_tensor = src_kv_caches[layer_name]
+        dst_tensor = dst_kv_caches[layer_name]
+        copy_fn(src_tensor, dst_tensor, src_indices, dst_indices)
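The new `copy_kv_blocks` helper is the single entry point the model runners use to stage KV blocks between a host buffer and the device caches. Below is a minimal, CPU-only sketch of the call pattern it expects; the layer name, tensor shape, and stand-in copy function are illustrative only, and real runs dispatch to `current_platform.insert_blocks_to_device` / `swap_out_blocks_to_host`.

import torch

def _copy(src, dst, src_idx, dst_idx):
    # Stand-in with the same signature as the platform copy_fn; it mirrors
    # the XPU routines, where the block axis is dim 1.
    dst[:, dst_idx] = src[:, src_idx].to(dst.device)

# One dummy "layer": [2 (K/V), num_blocks, block_size, num_heads, head_dim]
host_caches = {"layer.0": torch.randn(2, 8, 16, 4, 8)}
dev_caches = {"layer.0": torch.zeros(2, 8, 16, 4, 8)}

src_idx = torch.tensor([1, 3], dtype=torch.int64)
dst_idx = torch.tensor([0, 2], dtype=torch.int64)
for name in host_caches:  # "h2d": host buffer -> device cache
    _copy(host_caches[name], dev_caches[name], src_idx, dst_idx)

assert torch.equal(dev_caches["layer.0"][:, 0], host_caches["layer.0"][:, 1])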
@@ -61,6 +61,7 @@ except ImportError:
 _NIXL_SUPPORTED_XPUS = {
     "cuda": ("cuda", ),
     "tpu": ("cpu", ),
+    "xpu": ("cpu", ),
 }
 
 
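The new `"xpu"` entry mirrors the TPU one: the only supported KV buffer device for XPU is `"cpu"`, so NIXL transfers go through a host buffer and the KV blocks are staged with the `copy_kv_blocks` helper wired up in the GPUModelRunner hunk below.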
@@ -200,6 +200,32 @@ class TpuPlatform(Platform):
                                     model_config: "ModelConfig") -> bool:
         return True
 
+    @classmethod
+    @torch.compile(backend="openxla")
+    def insert_blocks_to_device(
+        cls,
+        src_cache: torch.Tensor,
+        dst_cache: torch.Tensor,
+        src_block_indices: torch.Tensor,
+        dst_block_indices: torch.Tensor,
+    ) -> None:
+        torch.ops.xla.dynamo_set_buffer_donor_(dst_cache, True)
+        dst_cache[dst_block_indices] = src_cache[src_block_indices].to(
+            dst_cache.device)
+
+    @classmethod
+    @torch.compile(backend="openxla")
+    def swap_out_blocks_to_host(
+        cls,
+        src_cache: torch.Tensor,
+        dst_cache: torch.Tensor,
+        src_block_indices: torch.Tensor,
+        dst_block_indices: torch.Tensor,
+    ) -> None:
+        """ tpu blocks to cpu blocks"""
+        torch.ops.xla.dynamo_set_buffer_donor_(src_cache, True)
+        dst_cache[dst_block_indices] = src_cache[src_block_indices].cpu()
+
 
 try:
     from tpu_commons.platforms import TpuPlatform as TpuCommonsPlatform
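These `TpuPlatform` class methods are lifted from the TPU model runner's module-level `_insert_blocks_to_tpu` / `_swap_out_tpu_blocks` helpers, which are removed at the end of this diff; exposing them on the platform lets the shared `copy_kv_blocks` dispatch through `current_platform` instead of TPU-specific code.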
@@ -164,6 +164,13 @@ class XPUPlatform(Platform):
                 vllm_config.scheduler_config.max_model_len,
                 DEFAULT_MAX_NUM_BATCHED_TOKENS)
 
+        if (envs.VLLM_KV_CACHE_LAYOUT is None
+                or envs.VLLM_KV_CACHE_LAYOUT != "NHD"):
+            os.environ["VLLM_KV_CACHE_LAYOUT"] = "NHD"
+            logger.info(
+                "Setting VLLM_KV_CACHE_LAYOUT to 'NHD' for XPU; "
+                "only NHD layout is supported by XPU attention kernels.")
+
     @classmethod
     def is_pin_memory_available(cls):
         return True
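A note on the layout check: it pins the KV-cache layout to NHD on XPU, here taken to mean the token axis ordered ahead of the head and head-dim axes (a common KV-cache layout naming convention, assumed rather than stated in this diff), because per the log message the XPU attention kernels only support that layout.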
@@ -210,3 +217,27 @@ class XPUPlatform(Platform):
     @classmethod
     def opaque_attention_op(cls) -> bool:
         return True
+
+    @classmethod
+    def insert_blocks_to_device(
+        cls,
+        src_cache: torch.Tensor,
+        dst_cache: torch.Tensor,
+        src_block_indices: torch.Tensor,
+        dst_block_indices: torch.Tensor,
+    ) -> None:
+        """Copy blocks from src_cache to dst_cache on XPU."""
+        _src_cache = src_cache[:, src_block_indices]
+        dst_cache[:, dst_block_indices] = _src_cache.to(dst_cache.device)
+
+    @classmethod
+    def swap_out_blocks_to_host(
+        cls,
+        src_cache: torch.Tensor,
+        dst_cache: torch.Tensor,
+        src_block_indices: torch.Tensor,
+        dst_block_indices: torch.Tensor,
+    ) -> None:
+        """Copy blocks from XPU to host (CPU)."""
+        _src_cache = src_cache[:, src_block_indices]
+        dst_cache[:, dst_block_indices] = _src_cache.cpu()
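Unlike the TPU variants above, which index the block axis directly (`cache[indices]`), the XPU cache tensors carry a leading axis ahead of the block axis (presumably the key/value pair), so blocks are selected along dim 1 (`cache[:, indices]`); `swap_out_blocks_to_host` materializes them in CPU memory via `.cpu()`.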
@@ -28,6 +28,7 @@ from vllm.config import (CompilationLevel, CUDAGraphMode, VllmConfig,
 from vllm.distributed.eplb.eplb_state import EplbState
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
+from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
 from vllm.distributed.parallel_state import (
     get_pp_group, get_tp_group, graph_capture, is_global_first_rank,
     prepare_communication_buffer_for_model)
@@ -3139,6 +3140,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
         if has_kv_transfer_group():
             get_kv_transfer_group().register_kv_caches(kv_caches)
+            if self.device.type == 'xpu':
+                get_kv_transfer_group().set_host_xfer_buffer_ops(
+                    copy_kv_blocks)
 
     def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
         """
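On XPU the runner now also hands `copy_kv_blocks` to the connector via `set_host_xfer_buffer_ops`, so transfers into and out of the NIXL host buffer reuse the platform block-copy routines added above; CUDA and other device types are unaffected by this branch.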
@@ -3,7 +3,7 @@
 import bisect
 import gc
 import time
-from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
+from typing import TYPE_CHECKING, Any, Optional, cast
 from unittest.mock import patch
 
 import numpy as np
@@ -23,6 +23,7 @@ from vllm.config import (ParallelConfig, VllmConfig,
                          get_layers_from_vllm_config, update_config)
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
+from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.lora.layers import BaseLayerWithLoRA
@@ -1887,75 +1888,6 @@ def _get_padded_token_len(paddings: list[int], x: int) -> int:
     return paddings[index]
 
 
-def _make_src_and_dst_indices(
-    src_block_ids: list[int],
-    dst_block_ids: list[int],
-    src_device: Union[torch.device, str],
-    dst_device: Union[torch.device, str],
-) -> tuple[torch.Tensor, torch.Tensor]:
-    src_indices = torch.tensor(src_block_ids,
-                               device=src_device,
-                               dtype=torch.int64)
-    dst_indices = torch.tensor(dst_block_ids,
-                               device=dst_device,
-                               dtype=torch.int64)
-    return src_indices, dst_indices
-
-
-@torch.compile(backend="openxla")
-def _insert_blocks_to_tpu(
-    cpu_cache: torch.Tensor,
-    tpu_cache: torch.Tensor,
-    cpu_block_indices: torch.Tensor,
-    tpu_block_indices: torch.Tensor,
-) -> None:
-    torch.ops.xla.dynamo_set_buffer_donor_(tpu_cache, True)
-    tpu_cache[tpu_block_indices] = cpu_cache[cpu_block_indices].to(
-        tpu_cache.device)
-
-
-@torch.compile(backend="openxla")
-def _swap_out_tpu_blocks(
-    tpu_cache: torch.Tensor,
-    cpu_cache: torch.Tensor,
-    tpu_block_indices: torch.Tensor,
-    cpu_block_indices: torch.Tensor,
-) -> None:
-    """ tpu blocks to cpu blocks"""
-    torch.ops.xla.dynamo_set_buffer_donor_(tpu_cache, True)
-    cpu_cache[cpu_block_indices] = tpu_cache[tpu_block_indices].cpu()
-
-
-def copy_kv_blocks(
-    src_kv_caches: dict[str, torch.Tensor],
-    dst_kv_caches: dict[str, torch.Tensor],
-    src_block_ids: list[int],
-    dst_block_ids: list[int],
-    direction: Literal["h2d", "d2h"],
-) -> None:
-    """Copy kv blocks between different buffers."""
-    if not src_kv_caches or not dst_kv_caches or \
-        not src_block_ids or not dst_block_ids or \
-        len(src_block_ids) != len(dst_block_ids):
-        return
-
-    src_device = next(iter(src_kv_caches.values())).device
-    dst_device = next(iter(dst_kv_caches.values())).device
-
-    src_indices, dst_indices = _make_src_and_dst_indices(
-        src_block_ids=src_block_ids,
-        dst_block_ids=dst_block_ids,
-        src_device=src_device,
-        dst_device=dst_device)
-
-    _copy_fn = _insert_blocks_to_tpu if direction == "h2d" else \
-        _swap_out_tpu_blocks
-    for layer_name in src_kv_caches:
-        src_tensor = src_kv_caches[layer_name]
-        dst_tensor = dst_kv_caches[layer_name]
-        _copy_fn(src_tensor, dst_tensor, src_indices, dst_indices)
-
-
 def _get_padded_num_kv_cache_update_slices(num_tokens: int, max_num_reqs: int,
                                            page_size: int) -> int:
     """Calculates the padded number of KV cache update slices to avoid
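With the shared helper in place, the TPU model runner drops its local copies of `_make_src_and_dst_indices`, `_insert_blocks_to_tpu`, `_swap_out_tpu_blocks`, and `copy_kv_blocks`; it now imports `copy_kv_blocks` from `vllm.distributed.kv_transfer.kv_connector.utils` (see the import hunk above), which dispatches to the new `TpuPlatform` class methods.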