From d0b178cef199ddef92f4ecdeb0be9ed3bf05c0c8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nicol=C3=B2=20Lucchesi?=
Date: Tue, 30 Sep 2025 14:18:29 +0200
Subject: [PATCH] [NIXL] Add support for MLA caches with different latent dim
 (#25902)

Signed-off-by: NickLucche
Signed-off-by: Chen Zhang
Co-authored-by: Chen Zhang
Signed-off-by: simon-mo
---
 .../kv_connector/unit/test_nixl_connector.py | 13 +--
 .../kv_connector/v1/nixl_connector.py        | 95 ++++++++++++-------
 2 files changed, 66 insertions(+), 42 deletions(-)

diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py
index 6b4bd29f18a56..578bf02eb5192 100644
--- a/tests/v1/kv_connector/unit/test_nixl_connector.py
+++ b/tests/v1/kv_connector/unit/test_nixl_connector.py
@@ -255,8 +255,9 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
         time.sleep(self._hand_shake_latency)
         # These should've been done in register_kv_caches(), called by
         # gpu_model_runner. Here we just hardcode some dummy values.
-        self.slot_size_bytes = 4096
-        self.block_len = self.slot_size_bytes * self.block_size
+        slot_size_bytes = 4096
+        self.slot_size_per_layer = [slot_size_bytes]
+        self.block_len_per_layer = [slot_size_bytes * self.block_size]
         self.num_blocks = 1
         self.dst_num_blocks[self.engine_id] = self.num_blocks
@@ -268,7 +269,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
                 agent_metadata=FakeNixlWrapper.AGENT_METADATA,
                 kv_caches_base_addr=[0],
                 num_blocks=1,
-                block_len=self.block_len,
+                block_lens=self.block_len_per_layer,
                 attn_backend_name=self.backend_name,
                 # `self.kv_cache_layout` is only forced to HND when vllm engine
                 # is started. We mock HND here.
@@ -485,8 +486,8 @@ class TestNixlHandshake:
         worker = connector.connector_worker

         # Minimal local registration params used by add_remote_agent
-        worker.slot_size_bytes = 4096
-        worker.block_len = worker.slot_size_bytes * worker.block_size
+        worker.slot_size_per_layer = [4096]
+        worker.block_len_per_layer = [4096 * worker.block_size]
         worker.num_blocks = 1
         worker.dst_num_blocks[worker.engine_id] = worker.num_blocks
@@ -498,7 +499,7 @@ class TestNixlHandshake:
             agent_metadata=FakeNixlWrapper.AGENT_METADATA,
             kv_caches_base_addr=[0],
             num_blocks=1,
-            block_len=worker.block_len,
+            block_lens=worker.block_len_per_layer,
             attn_backend_name=worker.backend_name,
             kv_cache_layout=mismatched_layout,
         )
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index c11189d7ec109..9e657703aa952 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -84,7 +84,7 @@ class NixlAgentMetadata(
     agent_metadata: bytes
     kv_caches_base_addr: list[int]
     num_blocks: int
-    block_len: int
+    block_lens: list[int]
     attn_backend_name: str
     kv_cache_layout: str
@@ -766,6 +766,9 @@ class NixlConnectorWorker:
         split_k_and_v = not (self.use_mla or self._use_pallas
                              or self._use_flashinfer)
         tensor_size_bytes = None
+        # Enable different block lengths for different layers when MLA is used.
+        self.block_len_per_layer = list[int]()
+        self.slot_size_per_layer = list[int]()  # HD bytes in kv terms
         for layer_name, cache_or_caches in xfer_buffers.items():
             cache_list = cache_or_caches if split_k_and_v else [
                 cache_or_caches
@@ -783,10 +786,25 @@ class NixlConnectorWorker:
                     tensor_size_bytes = curr_tensor_size_bytes
                     self.num_blocks = cache.shape[0]

-                assert tensor_size_bytes == curr_tensor_size_bytes, \
-                    "All kv cache tensors must have the same size"
+                assert cache.shape[0] == self.num_blocks, \
+                    "All kv cache tensors must have the same number of blocks"
+
+                self.block_len_per_layer.append(curr_tensor_size_bytes //
+                                                self.num_blocks)
+                self.slot_size_per_layer.append(self.block_len_per_layer[-1] //
+                                                self.block_size)
+
+                if not self.use_mla:
+                    # Different kv cache shape is not supported by HeteroTP
+                    assert tensor_size_bytes == curr_tensor_size_bytes, \
+                        "All kv cache tensors must have the same size"
                 caches_data.append(
-                    (base_addr, tensor_size_bytes, self.tp_rank, ""))
+                    (base_addr, curr_tensor_size_bytes, self.tp_rank, ""))
+
+        logger.debug("Different block lengths collected: %s",
+                     set(self.block_len_per_layer))
+        assert len(self.block_len_per_layer) == len(seen_base_addresses)
+        assert self.num_blocks != 0

         self.kv_caches_base_addr[self.engine_id] = seen_base_addresses
         self.num_regions = len(caches_data)
@@ -799,16 +817,12 @@ class NixlConnectorWorker:
         logger.debug("Done registering descs")
         self._registered_descs.append(descs)

-        assert tensor_size_bytes is not None
-        assert self.num_blocks != 0
-        assert tensor_size_bytes % self.num_blocks == 0
-        self.block_len = tensor_size_bytes // self.num_blocks
-        self.slot_size_bytes = self.block_len // self.block_size
         self.device_kv_caches = kv_caches
         self.dst_num_blocks[self.engine_id] = self.num_blocks
         if self._use_flashinfer:
-            assert self.slot_size_bytes % 2 == 0
-            self.slot_size_bytes /= 2
+            for i in range(len(self.slot_size_per_layer)):
+                assert self.slot_size_per_layer[i] % 2 == 0
+                self.slot_size_per_layer[i] //= 2

             # NOTE (NickLucche) When FlashInfer is used, memory is registered
             # with joint KV for each block. This minimizes the overhead in
@@ -818,17 +832,17 @@ class NixlConnectorWorker:
             # of 'virtual' regions here and halve `block_len` below.
             self.num_regions *= 2

-        kv_block_len = self.get_backend_aware_kv_block_len()
         # Register local/src descr for NIXL xfer.
         blocks_data = []
-        for base_addr in seen_base_addresses:
+        for i, base_addr in enumerate(seen_base_addresses):
+            kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
             # NOTE With heter-TP, more blocks are prepared than what are
             # needed as self.num_blocks >= nixl_agent_meta.num_blocks. We
             # could create fewer, but then _get_block_descs_ids needs to
             # select agent_meta.num_blocks instead of self.num_blocks for
             # local descr, and that makes handling regular flow less clean.
             for block_id in range(self.num_blocks):
-                block_offset = block_id * self.block_len
+                block_offset = block_id * self.block_len_per_layer[i]
                 addr = base_addr + block_offset
                 # (addr, len, device id)
                 blocks_data.append((addr, kv_block_len, self.tp_rank))
@@ -838,7 +852,7 @@ class NixlConnectorWorker:
                 # descs ordering. This is needed for selecting contiguous heads
                 # when split across TP ranks.
                 for block_id in range(self.num_blocks):
-                    block_offset = block_id * self.block_len
+                    block_offset = block_id * self.block_len_per_layer[i]
                     addr = base_addr + block_offset
                     # Register addresses for V cache (K registered first).
                    v_addr = addr + kv_block_len
@@ -878,7 +892,7 @@ class NixlConnectorWorker:
             agent_metadata=self.nixl_wrapper.get_agent_metadata(),
             kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
             num_blocks=self.num_blocks,
-            block_len=self.block_len,
+            block_lens=self.block_len_per_layer,
             attn_backend_name=self.backend_name,
             kv_cache_layout=self.kv_cache_layout)
         ready_event = threading.Event()
@@ -903,7 +917,7 @@ class NixlConnectorWorker:
         The latter, assuming D.world_size > P.world_size, requires that two or
         more local TP worker share the xfer from a single TP worker.

-        Here's an example:
+        Here's an example (non-MLA case):

         rank_offset     p_remote_tp_rank
         (kv split no)
@@ -959,14 +973,20 @@ class NixlConnectorWorker:
         total_num_kv_heads = self.model_config.get_total_num_kv_heads()
         is_kv_replicated = self._tp_size[engine_id] // total_num_kv_heads >= 1

+        remote_block_len = nixl_agent_meta.block_lens[0]
         if self.use_mla or is_kv_replicated:
-            # With MLA the only difference is in the number of blocks.
-            remote_block_size = nixl_agent_meta.block_len // (
-                self.slot_size_bytes)
-            assert self.block_len == nixl_agent_meta.block_len
+            # With replicated KV cache, only the number of blocks can differ.
+            assert self.block_len_per_layer == nixl_agent_meta.block_lens, \
+                "KV cache sizes must match between P and D when replicated"
+            remote_block_size = remote_block_len // (
+                self.slot_size_per_layer[0])
         else:
-            remote_block_size = nixl_agent_meta.block_len // (
-                self.slot_size_bytes * tp_ratio)
+            # When MLA is not used, this is a list of the same block length
+            for block_len in nixl_agent_meta.block_lens:
+                assert block_len == remote_block_len, \
+                    "All remote layers must have the same block size"
+            remote_block_size = remote_block_len // (
+                self.slot_size_per_layer[0] * tp_ratio)
         if self._use_flashinfer:
             # With flashinfer, KV are sent in the same message.
             remote_block_size //= 2
@@ -977,14 +997,14 @@ class NixlConnectorWorker:
                 raise ValueError(
                     "Heterogeneous TP is not supported on XPU")

-            assert nixl_agent_meta.block_len == self.block_len * tp_ratio, (
+            assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, (
                 "Remote P worker KV layer cache must be of shape [2, N, "
                 "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
             )

         assert self.block_size == remote_block_size, (
-            "Remote P worker with different block size is not supported "
-            f"{self.block_size=} {remote_block_size=}")
+            "Remote P worker with different page/block size is not supported "
+            f"{self.block_size=}, {remote_block_size=}")

         # Create dst descs and xfer side handles. TP workers have same #blocks.
         if engine_id in self.dst_num_blocks:
@@ -999,13 +1019,16 @@ class NixlConnectorWorker:
         # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
         self.kv_caches_base_addr[
             engine_id] = nixl_agent_meta.kv_caches_base_addr
-        kv_block_len = self.get_backend_aware_kv_block_len()
-        rank_offset = self.tp_rank % tp_ratio * kv_block_len \
-            if not (self.use_mla or is_kv_replicated) else 0
+
+        assert len(nixl_agent_meta.kv_caches_base_addr) == len(
+            self.block_len_per_layer)
         # Register all remote blocks, but only the corresponding kv heads.
-        for base_addr in nixl_agent_meta.kv_caches_base_addr:
+        for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
+            kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
+            rank_offset = self.tp_rank % tp_ratio * kv_block_len \
+                if not (self.use_mla or is_kv_replicated) else 0
             for block_id in range(nixl_agent_meta.num_blocks):
-                block_offset = block_id * nixl_agent_meta.block_len
+                block_offset = block_id * nixl_agent_meta.block_lens[i]
                 # For each block, grab the heads chunk belonging to rank_i
                 # of size remote_nheads // tp_ratio, which correspond to
                 # self.block_len == remote_block_len//tp_ratio bytes.
@@ -1016,9 +1039,9 @@ class NixlConnectorWorker:
             if self._use_flashinfer:
                 # With FlashInfer index V separately to allow head splitting.
                 for block_id in range(nixl_agent_meta.num_blocks):
-                    block_offset = block_id * nixl_agent_meta.block_len
+                    block_offset = block_id * nixl_agent_meta.block_lens[i]
                     addr = base_addr + block_offset + rank_offset
-                    v_addr = addr + nixl_agent_meta.block_len // 2
+                    v_addr = addr + nixl_agent_meta.block_lens[i] // 2
                     blocks_data.append((v_addr, kv_block_len, remote_tp_rank))

         logger.debug(
@@ -1345,7 +1368,7 @@ class NixlConnectorWorker:
             descs_ids = region_ids * num_blocks + block_ids
         return descs_ids.flatten()

-    def get_backend_aware_kv_block_len(self):
+    def get_backend_aware_kv_block_len(self, layer_idx: int):
         """
         Get the block length for one K/V element (K and V have the same size).

@@ -1356,9 +1379,9 @@ class NixlConnectorWorker:
         """
         if self._use_flashinfer:
             # For indexing only half (either just the K or V part).
-            block_len = self.block_len // 2
+            block_len = self.block_len_per_layer[layer_idx] // 2
         else:
-            block_len = self.block_len
+            block_len = self.block_len_per_layer[layer_idx]
         return block_len

     def get_kv_connector_stats(self) -> Optional[KVConnectorStats]:
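
A minimal standalone sketch of the per-layer bookkeeping this patch introduces, for readers who want to see the arithmetic in isolation: each layer's cache tensor yields its own block length (bytes per block) and slot size (bytes per token slot), instead of a single shared block_len/slot_size_bytes. The helper name compute_layer_lens and the example cache shapes below are illustrative only, not vLLM or NIXL APIs; the real logic lives in NixlConnectorWorker.register_kv_caches() in the diff above.

import torch

def compute_layer_lens(kv_caches: dict[str, torch.Tensor], block_size: int,
                       use_flashinfer: bool = False):
    """Derive per-layer block/slot byte lengths from KV cache tensors."""
    block_len_per_layer: list[int] = []
    slot_size_per_layer: list[int] = []  # bytes per token slot in each layer
    for cache in kv_caches.values():
        num_blocks = cache.shape[0]
        tensor_size_bytes = cache.numel() * cache.element_size()
        block_len = tensor_size_bytes // num_blocks  # bytes per block
        slot_size = block_len // block_size          # bytes per token slot
        if use_flashinfer:
            # K and V are registered jointly; each descriptor indexes half.
            assert slot_size % 2 == 0
            slot_size //= 2
        block_len_per_layer.append(block_len)
        slot_size_per_layer.append(slot_size)
    return block_len_per_layer, slot_size_per_layer

# Two hypothetical MLA layers with different latent dims, fp16 (2 bytes per
# element), 128 blocks each, block_size=16.
caches = {
    "layers.0": torch.zeros(128, 16, 576, dtype=torch.float16),
    "layers.1": torch.zeros(128, 16, 512, dtype=torch.float16),
}
block_lens, slot_sizes = compute_layer_lens(caches, block_size=16)
assert block_lens == [16 * 576 * 2, 16 * 512 * 2]  # differs per layer
assert slot_sizes == [576 * 2, 512 * 2]

With per-layer lists, add_remote_agent() can require equal block lengths between P and D workers in the MLA/replicated case, while the heterogeneous-TP (non-MLA) path still requires one uniform block length across layers, as asserted in the diff above.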