From 7c4767f1eb29f331c817de7e41eb39a06f83420c Mon Sep 17 00:00:00 2001
From: "Chendi.Xue"
Date: Wed, 22 Oct 2025 10:28:13 -0500
Subject: [PATCH] [NIXL] use Host buffer to support TP_ratio > 1 for XPU (#27140)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Chendi Xue
Signed-off-by: Chendi.Xue
Co-authored-by: Nicolò Lucchesi
---
 .../kv_connector/v1/nixl_connector.py | 29 +++++++++++++++----
 vllm/platforms/xpu.py                 | 10 +++++++
 2 files changed, 34 insertions(+), 5 deletions(-)

diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index ae7144cf78472..77ff687afd9f9 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -755,6 +755,7 @@ class NixlConnectorWorker:
         self._use_flashinfer = attn_backend == _Backend.FLASHINFER
         self._use_pallas = attn_backend == _Backend.PALLAS
         self.kv_cache_layout = get_kv_cache_layout()
+        self.host_buffer_kv_cache_layout = self.kv_cache_layout
         logger.debug("Detected attention backend %s", self.backend_name)
         logger.debug("Detected kv cache layout %s", self.kv_cache_layout)
 
@@ -875,6 +876,20 @@ class NixlConnectorWorker:
             for layer_name, kv_cache in kv_caches.items():
                 kv_shape = kv_cache.shape
                 kv_dtype = kv_cache.dtype
+                if (
+                    self.kv_cache_layout == "NHD"
+                    and self.vllm_config.kv_transfer_config is not None
+                    and self.vllm_config.kv_transfer_config.enable_permute_local_kv
+                ):
+                    logger.info_once(
+                        "'enable_permute_local_kv' flag is enabled while "
+                        "device KV Layout is NHD. Init host buffer with"
+                        " HND to better support Decode/Prefill TP_ratio > 1."
+                    )
+                    # Since NHD will not support Decode/Prefill TP_ratio > 1,
+                    # we can leverage host_buffer for permute
+                    self.host_buffer_kv_cache_layout = "HND"
+                    kv_shape = tuple(kv_shape[i] for i in [0, 1, 3, 2, 4])
                 xfer_buffers[layer_name] = torch.empty(
                     kv_shape, dtype=kv_dtype, device="cpu"
                 )
@@ -1110,7 +1125,9 @@ class NixlConnectorWorker:
             num_blocks=self.num_blocks,
             block_lens=self.block_len_per_layer,
             attn_backend_name=self.backend_name,
-            kv_cache_layout=self.kv_cache_layout,
+            kv_cache_layout=self.kv_cache_layout
+            if not self.use_host_buffer
+            else self.host_buffer_kv_cache_layout,
         )
         ready_event, stop_event = threading.Event(), threading.Event()
         self._nixl_handshake_listener_t = threading.Thread(
@@ -1273,7 +1290,12 @@ class NixlConnectorWorker:
         assert not self._use_pallas or tp_ratio == 1, (
             "TPU (pallas_v1) DOES NOT support heterogeneous TP yet."
         )
-        if not self.use_mla and nixl_agent_meta.kv_cache_layout != self.kv_cache_layout:
+        kv_cache_layout = (
+            self.kv_cache_layout
+            if not self.use_host_buffer
+            else self.host_buffer_kv_cache_layout
+        )
+        if not self.use_mla and nixl_agent_meta.kv_cache_layout != kv_cache_layout:
             if (
                 self.kv_transfer_config.enable_permute_local_kv
                 and nixl_agent_meta.kv_cache_layout == "HND"
@@ -1299,9 +1321,6 @@ class NixlConnectorWorker:
             )
             remote_block_size = remote_block_len // (self.slot_size_per_layer[0])
         else:
-            if tp_ratio > 1 and self.device_type == "xpu":
-                # XPU uses NHD, hence it does not support splitting on H
-                raise ValueError("Heterogeneous TP is not supported on XPU")
             # When MLA is not used, this is a list of the same block length
             for block_len in nixl_agent_meta.block_lens:
                 assert block_len == remote_block_len, (
diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py
index db7c3549df519..cd65cba6b492c 100644
--- a/vllm/platforms/xpu.py
+++ b/vllm/platforms/xpu.py
@@ -144,6 +144,8 @@ class XPUPlatform(Platform):
         # check and update parallel config
         parallel_config = vllm_config.parallel_config
         parallel_config.worker_cls = "vllm.v1.worker.xpu_worker.XPUWorker"
+        if vllm_config.kv_transfer_config is not None:
+            vllm_config.kv_transfer_config.enable_permute_local_kv = True
 
         if parallel_config.distributed_executor_backend is None:
             if parallel_config.world_size > 1:
@@ -245,6 +247,10 @@
     ) -> None:
         """Copy blocks from src_cache to dst_cache on XPU."""
         _src_cache = src_cache[:, src_block_indices]
+        if _src_cache.shape[2:] != dst_cache.shape[2:]:
+            # To support TP_ratio, HOST KV might be initiated with HND
+            # while XPU device KV is with NHD
+            _src_cache = _src_cache.permute(0, 1, 3, 2, 4)
         dst_cache[:, dst_block_indices] = _src_cache.to(dst_cache.device)
 
     @classmethod
@@ -257,4 +263,8 @@
     ) -> None:
         """Copy blocks from XPU to host (CPU)."""
         _src_cache = src_cache[:, src_block_indices]
+        if _src_cache.shape[2:] != dst_cache.shape[2:]:
+            # XPU device KV is with NHD while HOST KV
+            # might be initiated with HND for TP_ratio support
+            _src_cache = _src_cache.permute(0, 1, 3, 2, 4)
         dst_cache[:, dst_block_indices] = _src_cache.cpu()
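
For reference, a minimal standalone sketch, not part of the patch above, of the NHD <-> HND permutation that the two patched XPU copy helpers rely on when the host transfer buffer and the device KV cache use different layouts. The tensor names and sizes below (device_kv, host_kv, block and head counts) are made-up illustration values.

import torch

# Made-up example sizes for illustration only.
num_entries, num_blocks, block_size, num_heads, head_dim = 4, 8, 16, 8, 64

# XPU device KV cache keeps NHD layout: (..., block_size, num_heads, head_dim).
device_kv = torch.randn(num_entries, num_blocks, block_size, num_heads, head_dim)

# Host buffer is allocated in HND layout: (..., num_heads, block_size, head_dim),
# so a block can be split along the head dimension when decode TP > prefill TP.
host_kv = torch.empty(num_entries, num_blocks, num_heads, block_size, head_dim)

src_block_indices = torch.tensor([0, 2])
dst_block_indices = torch.tensor([1, 3])

# Device -> host direction (the "Copy blocks from XPU to host (CPU)" helper):
# permute NHD to HND when the trailing dims disagree, then copy into the host buffer.
src = device_kv[:, src_block_indices]
if src.shape[2:] != host_kv.shape[2:]:
    src = src.permute(0, 1, 3, 2, 4)
host_kv[:, dst_block_indices] = src  # the real helper also moves the data to CPU

# Host -> device direction (the "Copy blocks from src_cache to dst_cache on XPU"
# helper): the same permutation maps HND back to NHD.
back = host_kv[:, dst_block_indices]
if back.shape[2:] != device_kv.shape[2:]:
    back = back.permute(0, 1, 3, 2, 4)
assert torch.equal(back, device_kv[:, src_block_indices])

The permutation (0, 1, 3, 2, 4) simply swaps the block_size and num_heads dimensions, so applying it in either direction converts between NHD and HND.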