From b8d520232fd455155112d6b2cf1a36b64da90aee Mon Sep 17 00:00:00 2001
From: NickLucche
Date: Mon, 13 Oct 2025 14:01:34 +0000
Subject: [PATCH] fix mla

Signed-off-by: NickLucche
---
 .../kv_connector/v1/nixl_connector.py         | 32 +++++++++++--------
 1 file changed, 18 insertions(+), 14 deletions(-)

diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index 0f71addb67191..83d0df41ead1c 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -1181,16 +1181,18 @@ class NixlConnectorWorker:
             tp_ratio,
         )
 
-        ### (Optional) Register local agent memory regions
-        if tp_ratio < 0 and tp_ratio not in self.src_xfer_side_chunked_handles:
+        ### (Optional) Register local agent memory regions.
+        # MLA-optimization: only prepare one region.
+        if (
+            tp_ratio < 0
+            and not self.use_mla
+            and tp_ratio not in self.src_xfer_side_chunked_handles
+        ):
             # Remote tp_size > local tp_size: read from multiple remote ranks.
             # Logically "split" own regions into |tp_ratio| chunks. Mind that
             # we only do this once per remote tp_size (replica-friendly).
             self.src_xfer_side_chunked_handles[tp_ratio] = []
-            # MLA-optimization: only prepare one region.
-            # NOTE NickLucche: only a chunk of whole cache is used with MLA!
-            tp_ratio_opt = 1 if self.use_mla else -tp_ratio
-            for i in range(tp_ratio_opt):
+            for i in range(-tp_ratio):
                 blocks_data = []
                 for memory_region in self.src_blocks_data:
                     addr, local_block_len, own_tp_rank = memory_region
@@ -1215,12 +1217,12 @@ class NixlConnectorWorker:
         # Register all remote blocks, but only the corresponding kv heads.
         for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
             # Read our whole local region size from remote.
-            kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
-            if tp_ratio < 0:
+            local_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
+            if tp_ratio < 0 and not self.use_mla:
                 # Remote tp is bigger: read a chunk of local region from remote
-                kv_block_len = kv_block_len // (-tp_ratio)
+                local_block_len = local_block_len // (-tp_ratio)
             rank_offset = (
-                self.tp_rank % tp_ratio * kv_block_len if indexes_into_remote else 0
+                self.tp_rank % tp_ratio * local_block_len if indexes_into_remote else 0
             )
             for block_id in range(nixl_agent_meta.num_blocks):
                 block_offset = block_id * nixl_agent_meta.block_lens[i]
@@ -1229,7 +1231,7 @@ class NixlConnectorWorker:
                 # self.block_len == remote_block_len//tp_ratio bytes.
                 addr = base_addr + block_offset + rank_offset
                 # (addr, len, device id)
-                blocks_data.append((addr, kv_block_len, remote_tp_rank))
+                blocks_data.append((addr, local_block_len, remote_tp_rank))
 
             if self._use_flashinfer:
                 # With FlashInfer index V separately to allow head splitting.
@@ -1237,7 +1239,7 @@ class NixlConnectorWorker:
                     block_offset = block_id * nixl_agent_meta.block_lens[i]
                     addr = base_addr + block_offset + rank_offset
                     v_addr = addr + nixl_agent_meta.block_lens[i] // 2
-                    blocks_data.append((v_addr, kv_block_len, remote_tp_rank))
+                    blocks_data.append((v_addr, local_block_len, remote_tp_rank))
 
         logger.debug(
            "Created %s blocks for dst engine %s with remote rank %s and local rank %s",
@@ -1273,10 +1275,12 @@ class NixlConnectorWorker:
         assert not self._use_pallas or tp_ratio == 1, (
             "TPU (pallas_v1) DOES NOT support heterogeneous TP yet."
         )
+        # Num kv_heads > tp_size and P TP > D TP case, not supported
+        assert not (tp_ratio < 0 and self.kv_info.is_kv_replicated(remote_engine_id))
 
         # Block len can only vary across layers when using MLA.
         remote_block_len = nixl_agent_meta.block_lens[0]
-        if self.use_mla or self.kv_info.is_kv_replicated(remote_engine_id):
+        if self.kv_info.replicates_kv_cache(remote_engine_id):
             # With replicated KV cache, only the number of blocks can differ.
             assert self.block_len_per_layer == nixl_agent_meta.block_lens, (
                 "KV cache sizes must match between P and D when replicated"
@@ -1540,7 +1544,7 @@ class NixlConnectorWorker:
                 remote_rank,
                 req_id,
             )
-            if tp_ratio < 0:
+            if tp_ratio < 0 and not self.use_mla:
                 # Remote tp_size > local tp_size: we must perform multiple
                 # reads. Get the memory chunk onto which we will write to.
                 local_xfer_side_handle = self.src_xfer_side_chunked_handles[tp_ratio][i]
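Note (not part of the patch): below is a minimal standalone sketch of the chunked-read layout the diff adjusts. The names local_tp_size, remote_tp_size, block_len and use_mla are hypothetical stand-ins for the connector's internal state, and the tp_ratio sign convention simply mirrors the comments in the diff (negative when remote TP is larger). The sketch illustrates why the "and not self.use_mla" guards are added: with MLA the latent KV cache is not split across KV heads, so a decode rank reads whole blocks from a single prefill rank, whereas without MLA a smaller decode TP must read |tp_ratio| chunks per block, one per remote rank.

# Illustrative sketch only; not the connector implementation.
def chunk_layout(
    local_tp_size: int, remote_tp_size: int, block_len: int, use_mla: bool
) -> list[tuple[int, int]]:
    """Return (offset, length) pairs to read for one local KV block."""
    # Negative when the remote (prefill) TP is larger, as in the patch.
    tp_ratio = (
        local_tp_size // remote_tp_size
        if local_tp_size >= remote_tp_size
        else -(remote_tp_size // local_tp_size)
    )
    if use_mla or tp_ratio >= 0:
        # MLA caches are not sharded by KV head, so the whole block is
        # read from a single remote rank; same for local TP >= remote TP.
        return [(0, block_len)]
    # Remote TP is larger: split the local block into |tp_ratio| chunks,
    # one per remote rank holding a slice of the KV heads.
    chunk_len = block_len // -tp_ratio
    return [(i * chunk_len, chunk_len) for i in range(-tp_ratio)]


if __name__ == "__main__":
    # Decode TP=2 pulling from prefill TP=8 without MLA: 4 chunks per block.
    print(chunk_layout(2, 8, block_len=4096, use_mla=False))
    # Same topology with MLA: a single full-block read.
    print(chunk_layout(2, 8, block_len=4096, use_mla=True))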