[NIXL] use Host buffer to support TP_ratio > 1 for XPU (#27140)
Signed-off-by: Chendi Xue <chendi.xue@intel.com>
Signed-off-by: Chendi.Xue <chendi.xue@intel.com>
Co-authored-by: Nicolò Lucchesi <nicolo.lucchesi@gmail.com>
parent 9771e0b432
commit 7c4767f1eb
@@ -755,6 +755,7 @@ class NixlConnectorWorker:
         self._use_flashinfer = attn_backend == _Backend.FLASHINFER
         self._use_pallas = attn_backend == _Backend.PALLAS
         self.kv_cache_layout = get_kv_cache_layout()
+        self.host_buffer_kv_cache_layout = self.kv_cache_layout
         logger.debug("Detected attention backend %s", self.backend_name)
         logger.debug("Detected kv cache layout %s", self.kv_cache_layout)
 
@@ -875,6 +876,20 @@ class NixlConnectorWorker:
         for layer_name, kv_cache in kv_caches.items():
             kv_shape = kv_cache.shape
             kv_dtype = kv_cache.dtype
+            if (
+                self.kv_cache_layout == "NHD"
+                and self.vllm_config.kv_transfer_config is not None
+                and self.vllm_config.kv_transfer_config.enable_permute_local_kv
+            ):
+                logger.info_once(
+                    "'enable_permute_local_kv' flag is enabled while "
+                    "device KV layout is NHD. Init host buffer with "
+                    "HND to better support Decode/Prefill TP_ratio > 1."
+                )
+                # NHD does not support Decode/Prefill TP_ratio > 1,
+                # so leverage the host buffer to permute KV into HND.
+                self.host_buffer_kv_cache_layout = "HND"
+                kv_shape = tuple(kv_shape[i] for i in [0, 1, 3, 2, 4])
             xfer_buffers[layer_name] = torch.empty(
                 kv_shape, dtype=kv_dtype, device="cpu"
             )
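Note: the `tuple(kv_shape[i] for i in [0, 1, 3, 2, 4])` above swaps the block_size and num_heads axes. A minimal sketch of that reshaping, assuming vLLM's usual (K/V, num_blocks, block_size, num_heads, head_dim) NHD shape and purely illustrative sizes:

import torch

# Illustrative sizes (assumptions); only the axis order matters here.
num_blocks, block_size, num_heads, head_dim = 4, 16, 8, 64

# NHD device layout: (K/V, num_blocks, block_size, num_heads, head_dim)
nhd_shape = (2, num_blocks, block_size, num_heads, head_dim)

# Swap dims 2 and 3 to allocate the host buffer as HND instead.
hnd_shape = tuple(nhd_shape[i] for i in [0, 1, 3, 2, 4])
host_buffer = torch.empty(hnd_shape, dtype=torch.float16, device="cpu")
assert host_buffer.shape == (2, num_blocks, num_heads, block_size, head_dim)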
@@ -1110,7 +1125,9 @@ class NixlConnectorWorker:
             num_blocks=self.num_blocks,
             block_lens=self.block_len_per_layer,
             attn_backend_name=self.backend_name,
-            kv_cache_layout=self.kv_cache_layout,
+            kv_cache_layout=self.kv_cache_layout
+            if not self.use_host_buffer
+            else self.host_buffer_kv_cache_layout,
         )
         ready_event, stop_event = threading.Event(), threading.Event()
         self._nixl_handshake_listener_t = threading.Thread(
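Note: the handshake metadata now advertises the layout a remote peer will actually touch. A hypothetical helper (the name effective_kv_layout is ours, not the patch's) expressing the same conditional:

def effective_kv_layout(worker) -> str:
    # When transfers are staged through the host buffer, peers must see
    # the host buffer's layout rather than the device layout.
    if worker.use_host_buffer:
        return worker.host_buffer_kv_cache_layout
    return worker.kv_cache_layout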
@@ -1273,7 +1290,12 @@ class NixlConnectorWorker:
         assert not self._use_pallas or tp_ratio == 1, (
             "TPU (pallas_v1) DOES NOT support heterogeneous TP yet."
         )
-        if not self.use_mla and nixl_agent_meta.kv_cache_layout != self.kv_cache_layout:
+        kv_cache_layout = (
+            self.kv_cache_layout
+            if not self.use_host_buffer
+            else self.host_buffer_kv_cache_layout
+        )
+        if not self.use_mla and nixl_agent_meta.kv_cache_layout != kv_cache_layout:
             if (
                 self.kv_transfer_config.enable_permute_local_kv
                 and nixl_agent_meta.kv_cache_layout == "HND"
@@ -1299,9 +1321,6 @@ class NixlConnectorWorker:
             )
             remote_block_size = remote_block_len // (self.slot_size_per_layer[0])
         else:
-            if tp_ratio > 1 and self.device_type == "xpu":
-                # XPU uses NHD, hence it does not support splitting on H
-                raise ValueError("Heterogeneous TP is not supported on XPU")
             # When MLA is not used, this is a list of the same block length
             for block_len in nixl_agent_meta.block_lens:
                 assert block_len == remote_block_len, (
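Note: the deleted ValueError existed because splitting on the head dimension requires each rank's head subset to be addressable as a contiguous region per block, which HND provides and NHD does not. A quick check of that property (sizes are illustrative assumptions):

import torch

num_blocks, block_size, num_heads, head_dim = 2, 4, 8, 16
tp_ratio = 2  # illustrative prefill/decode TP ratio
heads_per_rank = num_heads // tp_ratio

hnd = torch.randn(num_blocks, num_heads, block_size, head_dim)
nhd = torch.randn(num_blocks, block_size, num_heads, head_dim)

# HND: one rank's heads within a block form a single contiguous region.
assert hnd[0, :heads_per_rank].is_contiguous()
# NHD: the same head subset is strided across every token slot.
assert not nhd[0, :, :heads_per_rank].is_contiguous()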
@@ -144,6 +144,8 @@ class XPUPlatform(Platform):
         # check and update parallel config
         parallel_config = vllm_config.parallel_config
         parallel_config.worker_cls = "vllm.v1.worker.xpu_worker.XPUWorker"
+        if vllm_config.kv_transfer_config is not None:
+            vllm_config.kv_transfer_config.enable_permute_local_kv = True
 
         if parallel_config.distributed_executor_backend is None:
             if parallel_config.world_size > 1:
@@ -245,6 +247,10 @@ class XPUPlatform(Platform):
     ) -> None:
         """Copy blocks from src_cache to dst_cache on XPU."""
         _src_cache = src_cache[:, src_block_indices]
+        if _src_cache.shape[2:] != dst_cache.shape[2:]:
+            # To support TP_ratio > 1, the host KV buffer may be
+            # initialized as HND while the XPU device KV cache is NHD.
+            _src_cache = _src_cache.permute(0, 1, 3, 2, 4)
         dst_cache[:, dst_block_indices] = _src_cache.to(dst_cache.device)
 
     @classmethod
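Note: a self-contained, CPU-only sketch of the layout-adapting copy above (tensor names and sizes are assumptions; on real hardware the last line would move data to the XPU device):

import torch

num_blocks, block_size, num_heads, head_dim = 4, 16, 8, 64
host_hnd = torch.randn(2, num_blocks, num_heads, block_size, head_dim)  # HND host buffer
dev_nhd = torch.zeros(2, num_blocks, block_size, num_heads, head_dim)   # NHD device stand-in

src_idx, dst_idx = torch.tensor([0, 2]), torch.tensor([1, 3])
_src = host_hnd[:, src_idx]
if _src.shape[2:] != dev_nhd.shape[2:]:
    # HND -> NHD: swap heads and block_size before writing device blocks.
    _src = _src.permute(0, 1, 3, 2, 4)
dev_nhd[:, dst_idx] = _src  # would be _src.to(dst_cache.device) on XPU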
@@ -257,4 +263,8 @@ class XPUPlatform(Platform):
     ) -> None:
         """Copy blocks from XPU to host (CPU)."""
         _src_cache = src_cache[:, src_block_indices]
+        if _src_cache.shape[2:] != dst_cache.shape[2:]:
+            # The XPU device KV cache is NHD while the host KV buffer
+            # may be initialized as HND for TP_ratio > 1 support.
+            _src_cache = _src_cache.permute(0, 1, 3, 2, 4)
         dst_cache[:, dst_block_indices] = _src_cache.cpu()
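Note: both copy directions reuse the same index list because the permutation only swaps axes 2 and 3 and is therefore its own inverse:

perm = [0, 1, 3, 2, 4]
# Applying it twice restores identity order, so one constant serves
# both the HND->NHD and NHD->HND conversions.
assert [perm[p] for p in perm] == [0, 1, 2, 3, 4]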