From 460d02a417b440ce8b3b8d09c6f5214a2a346426 Mon Sep 17 00:00:00 2001
From: "Chendi.Xue"
Date: Fri, 21 Nov 2025 10:55:27 -0600
Subject: [PATCH] [NIXL] Fix after virtual block_size for host_buffer with heter kv_layout (#29122)

Signed-off-by: Chendi Xue
---
 .../kv_transfer/kv_connector/v1/nixl_connector.py | 14 +++++++++++++-
 vllm/platforms/xpu.py                             |  8 --------
 2 files changed, 13 insertions(+), 9 deletions(-)

diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index 1626f819af8b5..7c0911240493c 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -1042,10 +1042,12 @@ class NixlConnectorWorker:
         NOT directly supported by NIXL (e.g., tpu)
         """
         xfer_buffers: dict[str, torch.Tensor] = {}
+        inv_order = [0, 1, 3, 2, 4]
         try:
             for layer_name, kv_cache in kv_caches.items():
                 kv_shape = kv_cache.shape
                 kv_dtype = kv_cache.dtype
+                permute_shape = False
                 if (
                     self.kv_cache_layout == "NHD"
                     and self.vllm_config.kv_transfer_config is not None
@@ -1059,10 +1061,20 @@ class NixlConnectorWorker:
                     # Since NHD will not support Decode/Prefill TP_ratio > 1,
                     # we can leverage host_buffer for permute
                     self.host_buffer_kv_cache_layout = "HND"
-                    kv_shape = tuple(kv_shape[i] for i in [0, 1, 3, 2, 4])
+                    kv_shape = (
+                        tuple(kv_shape[i] for i in inv_order)
+                        if not self.use_mla
+                        else kv_shape
+                    )
+                    permute_shape = not self.use_mla
+
                 xfer_buffers[layer_name] = torch.empty(
                     kv_shape, dtype=kv_dtype, device="cpu"
                 )
+                if permute_shape:
+                    xfer_buffers[layer_name] = xfer_buffers[layer_name].permute(
+                        inv_order
+                    )
         except MemoryError as e:
             logger.error("NIXLConnectorWorker gets %s.", e)
             raise
diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py
index 65516827a16da..18a3186b142f1 100644
--- a/vllm/platforms/xpu.py
+++ b/vllm/platforms/xpu.py
@@ -251,10 +251,6 @@ class XPUPlatform(Platform):
     ) -> None:
         """Copy blocks from src_cache to dst_cache on XPU."""
         _src_cache = src_cache[:, src_block_indices]
-        if _src_cache.shape[2:] != dst_cache.shape[2:]:
-            # To support TP_ratio, HOST KV might be initiated with HND
-            # while XPU device KV is with NHD
-            _src_cache = _src_cache.permute(0, 1, 3, 2, 4)
         dst_cache[:, dst_block_indices] = _src_cache.to(dst_cache.device)

     @classmethod
@@ -267,8 +263,4 @@ class XPUPlatform(Platform):
     ) -> None:
         """Copy blocks from XPU to host (CPU)."""
         _src_cache = src_cache[:, src_block_indices]
-        if _src_cache.shape[2:] != dst_cache.shape[2:]:
-            # XPU device KV is with NHD while HOST KV
-            # might be initiated with HND for TP_ratio support
-            _src_cache = _src_cache.permute(0, 1, 3, 2, 4)
         dst_cache[:, dst_block_indices] = _src_cache.cpu()
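
Reviewer note (not part of the patch): a minimal sketch of the host-buffer trick the nixl_connector.py hunks introduce, assuming a 5-D non-MLA KV cache shaped like (2, num_blocks, block_size, num_kv_heads, head_size) for the NHD layout; the dimension sizes below are made up for illustration, and only the swap of dims 2 and 3 matters. Allocating the CPU buffer with those dims swapped (HND) and then permuting the view back with inv_order yields a tensor that is indexed exactly like the NHD device cache while its underlying storage stays HND-contiguous, which is what NIXL registers and transfers. That is also why the explicit permutes in the XPU copy helpers can be dropped: the element-wise block copy already lands data in HND order in host memory.

    import torch

    # Illustrative sizes only; real values come from the attention backend.
    num_blocks, block_size, num_kv_heads, head_size = 4, 16, 8, 64
    inv_order = [0, 1, 3, 2, 4]  # swaps dims 2 and 3 (its own inverse)

    # Device KV cache layout (NHD): (2, num_blocks, block_size, num_kv_heads, head_size)
    nhd_shape = (2, num_blocks, block_size, num_kv_heads, head_size)

    # Allocate the host buffer in HND order, then permute the view back,
    # mirroring what the patch does on the non-MLA path.
    hnd_shape = tuple(nhd_shape[i] for i in inv_order)
    host_buffer = torch.empty(hnd_shape, dtype=torch.float32, device="cpu").permute(inv_order)

    # The view is indexed like the NHD device cache ...
    assert host_buffer.shape == nhd_shape
    # ... but the raw memory NIXL sees is laid out HND-contiguously.
    assert host_buffer.permute(inv_order).is_contiguous()

    # Copying an NHD-shaped device block in therefore needs no explicit permute.
    block = torch.randn(block_size, num_kv_heads, head_size)
    host_buffer[0, 0].copy_(block)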