From f0c503f66e2f6aafa966318d488fd92ac662cdf0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicol=C3=B2=20Lucchesi?= Date: Wed, 3 Sep 2025 09:19:54 +0200 Subject: [PATCH] [Nixl] Heterogeneous TP support FlashInfer (#20189) Signed-off-by: NickLucche --- .../kv_connector/v1/nixl_connector.py | 62 ++++++++++++++++--- 1 file changed, 53 insertions(+), 9 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index efe023d5595e..8f16babfe2ae 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -715,7 +715,7 @@ class NixlConnectorWorker: # are non-contiguous (it's not locally guaranteed that they will be) # Disadvantage is that the encoded NixlAgentMetadata is now larger # (roughly 8KB vs 5KB). - # Conversely for FlashInfer, K and V are transferred in the same tensor + # Conversely for FlashInfer, K and V are registered in the same region # to better exploit the memory layout (ie num_blocks is the first dim). split_k_and_v = not (self.use_mla or self._use_pallas_v1 or self._use_flashinfer) @@ -758,12 +758,21 @@ class NixlConnectorWorker: assert tensor_size_bytes % self.num_blocks == 0 self.block_len = tensor_size_bytes // self.num_blocks self.slot_size_bytes = self.block_len // self.block_size + self.device_kv_caches = kv_caches + self.dst_num_blocks[self.engine_id] = self.num_blocks if self._use_flashinfer: assert self.slot_size_bytes % 2 == 0 self.slot_size_bytes /= 2 - self.device_kv_caches = kv_caches - self.dst_num_blocks[self.engine_id] = self.num_blocks + # NOTE (NickLucche) When FlashInfer is used, memory is registered + # with joint KV for each block. This minimizes the overhead in + # registerMem allowing faster descs queries. In order to be able to + # split on kv_heads dim as required by heterogeneous TP, one must + # be able to index K/V separately. Hence the we double the number + # of 'virtual' regions here and halve `block_len` below. + self.num_regions *= 2 + + kv_block_len = self.get_backend_aware_kv_block_len() # Register local/src descr for NIXL xfer. blocks_data = [] for base_addr in seen_base_addresses: @@ -776,8 +785,18 @@ class NixlConnectorWorker: block_offset = block_id * self.block_len addr = base_addr + block_offset # (addr, len, device id) - # TODO: does device_id matter to DRAM? - blocks_data.append((addr, self.block_len, self.tp_rank)) + blocks_data.append((addr, kv_block_len, self.tp_rank)) + + if self._use_flashinfer: + # Separate and interleave K/V regions to maintain the same + # descs ordering. This is needed for selecting contiguous heads + # when split across TP ranks. + for block_id in range(self.num_blocks): + block_offset = block_id * self.block_len + addr = base_addr + block_offset + # Register addresses for V cache (K registered first). + v_addr = addr + kv_block_len + blocks_data.append((v_addr, kv_block_len, self.tp_rank)) logger.debug("Created %s blocks for src engine %s and rank %s", len(blocks_data), self.engine_id, self.tp_rank) @@ -903,7 +922,7 @@ class NixlConnectorWorker: remote_block_size = nixl_agent_meta.block_len // ( self.slot_size_bytes * tp_ratio) if self._use_flashinfer: - # Account for joint KV in FlashInfer. + # With flashinfer, KV are sent in the same message. remote_block_size //= 2 if tp_ratio > 1: # Heterogeneous TP expects same kv_cache_layout. @@ -929,10 +948,10 @@ class NixlConnectorWorker: # rank. With heterogeneous TP, prepare the descriptors by splitting the # P KV cache along kv_head dim, of D worker's kv_head size (D>P). # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..]. - # Only register the remote's descriptors if current rank pulls from it. self.kv_caches_base_addr[ engine_id] = nixl_agent_meta.kv_caches_base_addr - rank_offset = self.tp_rank % tp_ratio * self.block_len \ + kv_block_len = self.get_backend_aware_kv_block_len() + rank_offset = self.tp_rank % tp_ratio * kv_block_len \ if not (self.use_mla or is_kv_replicated) else 0 # Register all remote blocks, but only the corresponding kv heads. for base_addr in nixl_agent_meta.kv_caches_base_addr: @@ -943,7 +962,16 @@ class NixlConnectorWorker: # self.block_len == remote_block_len//tp_ratio bytes. addr = base_addr + block_offset + rank_offset # (addr, len, device id) - blocks_data.append((addr, self.block_len, remote_tp_rank)) + blocks_data.append((addr, kv_block_len, remote_tp_rank)) + + if self._use_flashinfer: + # With FlashInfer index V separately to allow head splitting. + for block_id in range(nixl_agent_meta.num_blocks): + block_offset = block_id * nixl_agent_meta.block_len + addr = base_addr + block_offset + rank_offset + v_addr = addr + nixl_agent_meta.block_len // 2 + blocks_data.append((v_addr, kv_block_len, remote_tp_rank)) + logger.debug( "Created %s blocks for dst engine %s with remote rank %s and " "local rank %s", len(blocks_data), engine_id, remote_tp_rank, @@ -1249,6 +1277,22 @@ class NixlConnectorWorker: descs_ids.append(reg_id * num_blocks + block_id) return descs_ids + def get_backend_aware_kv_block_len(self): + """ + Get the block length for one K/V element (K and V have the same size). + + For FA and other backends, this is equal to the length of the whole + block, as K and V are in separate regions. + For FlashInfer, this is half the length of the whole block, as K and V + share the same region. + """ + if self._use_flashinfer: + # For indexing only half (either just the K or V part). + block_len = self.block_len // 2 + else: + block_len = self.block_len + return block_len + @contextlib.contextmanager def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]: