[NIXL] use Host buffer to support TP_ratio > 1 for XPU (#27140)

Signed-off-by: Chendi Xue <chendi.xue@intel.com>
Signed-off-by: Chendi.Xue <chendi.xue@intel.com>
Co-authored-by: Nicolò Lucchesi <nicolo.lucchesi@gmail.com>
This commit is contained in:
Chendi.Xue 2025-10-22 10:28:13 -05:00 committed by GitHub
parent 9771e0b432
commit 7c4767f1eb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 34 additions and 5 deletions

View File

@ -755,6 +755,7 @@ class NixlConnectorWorker:
self._use_flashinfer = attn_backend == _Backend.FLASHINFER self._use_flashinfer = attn_backend == _Backend.FLASHINFER
self._use_pallas = attn_backend == _Backend.PALLAS self._use_pallas = attn_backend == _Backend.PALLAS
self.kv_cache_layout = get_kv_cache_layout() self.kv_cache_layout = get_kv_cache_layout()
self.host_buffer_kv_cache_layout = self.kv_cache_layout
logger.debug("Detected attention backend %s", self.backend_name) logger.debug("Detected attention backend %s", self.backend_name)
logger.debug("Detected kv cache layout %s", self.kv_cache_layout) logger.debug("Detected kv cache layout %s", self.kv_cache_layout)
@ -875,6 +876,20 @@ class NixlConnectorWorker:
for layer_name, kv_cache in kv_caches.items(): for layer_name, kv_cache in kv_caches.items():
kv_shape = kv_cache.shape kv_shape = kv_cache.shape
kv_dtype = kv_cache.dtype kv_dtype = kv_cache.dtype
if (
self.kv_cache_layout == "NHD"
and self.vllm_config.kv_transfer_config is not None
and self.vllm_config.kv_transfer_config.enable_permute_local_kv
):
logger.info_once(
"'enable_permute_local_kv' flag is enabled while "
"device KV Layout is NHD. Init host buffer with"
" HND to better support Decode/Prefill TP_ratio > 1."
)
# Since NHD will not support Decode/Prefill TP_ratio > 1,
# we can leverage host_buffer for permute
self.host_buffer_kv_cache_layout = "HND"
kv_shape = tuple(kv_shape[i] for i in [0, 1, 3, 2, 4])
xfer_buffers[layer_name] = torch.empty( xfer_buffers[layer_name] = torch.empty(
kv_shape, dtype=kv_dtype, device="cpu" kv_shape, dtype=kv_dtype, device="cpu"
) )
@ -1110,7 +1125,9 @@ class NixlConnectorWorker:
num_blocks=self.num_blocks, num_blocks=self.num_blocks,
block_lens=self.block_len_per_layer, block_lens=self.block_len_per_layer,
attn_backend_name=self.backend_name, attn_backend_name=self.backend_name,
kv_cache_layout=self.kv_cache_layout, kv_cache_layout=self.kv_cache_layout
if not self.use_host_buffer
else self.host_buffer_kv_cache_layout,
) )
ready_event, stop_event = threading.Event(), threading.Event() ready_event, stop_event = threading.Event(), threading.Event()
self._nixl_handshake_listener_t = threading.Thread( self._nixl_handshake_listener_t = threading.Thread(
@ -1273,7 +1290,12 @@ class NixlConnectorWorker:
assert not self._use_pallas or tp_ratio == 1, ( assert not self._use_pallas or tp_ratio == 1, (
"TPU (pallas_v1) DOES NOT support heterogeneous TP yet." "TPU (pallas_v1) DOES NOT support heterogeneous TP yet."
) )
if not self.use_mla and nixl_agent_meta.kv_cache_layout != self.kv_cache_layout: kv_cache_layout = (
self.kv_cache_layout
if not self.use_host_buffer
else self.host_buffer_kv_cache_layout
)
if not self.use_mla and nixl_agent_meta.kv_cache_layout != kv_cache_layout:
if ( if (
self.kv_transfer_config.enable_permute_local_kv self.kv_transfer_config.enable_permute_local_kv
and nixl_agent_meta.kv_cache_layout == "HND" and nixl_agent_meta.kv_cache_layout == "HND"
@ -1299,9 +1321,6 @@ class NixlConnectorWorker:
) )
remote_block_size = remote_block_len // (self.slot_size_per_layer[0]) remote_block_size = remote_block_len // (self.slot_size_per_layer[0])
else: else:
if tp_ratio > 1 and self.device_type == "xpu":
# XPU uses NHD, hence it does not support splitting on H
raise ValueError("Heterogeneous TP is not supported on XPU")
# When MLA is not used, this is a list of the same block length # When MLA is not used, this is a list of the same block length
for block_len in nixl_agent_meta.block_lens: for block_len in nixl_agent_meta.block_lens:
assert block_len == remote_block_len, ( assert block_len == remote_block_len, (

View File

@ -144,6 +144,8 @@ class XPUPlatform(Platform):
# check and update parallel config # check and update parallel config
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
parallel_config.worker_cls = "vllm.v1.worker.xpu_worker.XPUWorker" parallel_config.worker_cls = "vllm.v1.worker.xpu_worker.XPUWorker"
if vllm_config.kv_transfer_config is not None:
vllm_config.kv_transfer_config.enable_permute_local_kv = True
if parallel_config.distributed_executor_backend is None: if parallel_config.distributed_executor_backend is None:
if parallel_config.world_size > 1: if parallel_config.world_size > 1:
@ -245,6 +247,10 @@ class XPUPlatform(Platform):
) -> None: ) -> None:
"""Copy blocks from src_cache to dst_cache on XPU.""" """Copy blocks from src_cache to dst_cache on XPU."""
_src_cache = src_cache[:, src_block_indices] _src_cache = src_cache[:, src_block_indices]
if _src_cache.shape[2:] != dst_cache.shape[2:]:
# To support TP_ratio, HOST KV might be initiated with HND
# while XPU device KV is with NHD
_src_cache = _src_cache.permute(0, 1, 3, 2, 4)
dst_cache[:, dst_block_indices] = _src_cache.to(dst_cache.device) dst_cache[:, dst_block_indices] = _src_cache.to(dst_cache.device)
@classmethod @classmethod
@ -257,4 +263,8 @@ class XPUPlatform(Platform):
) -> None: ) -> None:
"""Copy blocks from XPU to host (CPU).""" """Copy blocks from XPU to host (CPU)."""
_src_cache = src_cache[:, src_block_indices] _src_cache = src_cache[:, src_block_indices]
if _src_cache.shape[2:] != dst_cache.shape[2:]:
# XPU device KV is with NHD while HOST KV
# might be initiated with HND for TP_ratio support
_src_cache = _src_cache.permute(0, 1, 3, 2, 4)
dst_cache[:, dst_block_indices] = _src_cache.cpu() dst_cache[:, dst_block_indices] = _src_cache.cpu()