Signed-off-by: NickLucche <nlucches@redhat.com>
NickLucche 2025-10-13 14:01:34 +00:00
parent 6601c9c5be
commit b8d520232f


@@ -1181,16 +1181,18 @@ class NixlConnectorWorker:
             tp_ratio,
         )
 
-        ### (Optional) Register local agent memory regions
-        if tp_ratio < 0 and tp_ratio not in self.src_xfer_side_chunked_handles:
+        ### (Optional) Register local agent memory regions.
+        # MLA-optimization: only prepare one region.
+        if (
+            tp_ratio < 0
+            and not self.use_mla
+            and tp_ratio not in self.src_xfer_side_chunked_handles
+        ):
             # Remote tp_size > local tp_size: read from multiple remote ranks.
             # Logically "split" own regions into |tp_ratio| chunks. Mind that
             # we only do this once per remote tp_size (replica-friendly).
             self.src_xfer_side_chunked_handles[tp_ratio] = []
-            # MLA-optimization: only prepare one region.
-            # NOTE NickLucche: only a chunk of whole cache is used with MLA!
-            tp_ratio_opt = 1 if self.use_mla else -tp_ratio
-            for i in range(tp_ratio_opt):
+            for i in range(-tp_ratio):
                 blocks_data = []
                 for memory_region in self.src_blocks_data:
                     addr, local_block_len, own_tp_rank = memory_region
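
The hunk above gates the chunked-handle setup: when the remote tp_size exceeds the local one, tp_ratio is negative and each local region is logically split into |tp_ratio| chunks, one read per remote rank, while MLA now skips the chunking entirely. Below is a minimal sketch of that splitting arithmetic only, assuming the hypothetical names Region and split_region; it is not the vLLM implementation.

# A minimal sketch, not the vLLM implementation: how one registered region can
# be logically split into |tp_ratio| equal chunks when the remote tp_size is
# larger than the local one (tp_ratio < 0 in the diff's convention).
# `Region` and `split_region` are hypothetical names used only here.
from typing import NamedTuple


class Region(NamedTuple):
    addr: int       # base address of the registered region
    length: int     # region length in bytes
    device_id: int  # owning device / TP rank


def split_region(region: Region, num_chunks: int) -> list[Region]:
    """Split one region into `num_chunks` contiguous, equally sized chunks."""
    assert region.length % num_chunks == 0, "region length must divide evenly"
    chunk_len = region.length // num_chunks
    return [
        Region(region.addr + i * chunk_len, chunk_len, region.device_id)
        for i in range(num_chunks)
    ]


# Example: remote tp_size is 4x the local one (tp_ratio == -4), so the local
# region is read in 4 pieces, one per remote rank. Per the new condition above,
# MLA skips this path and keeps the single whole-region handle instead.
chunks = split_region(Region(addr=0x1000, length=4096, device_id=0), num_chunks=4)
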
@@ -1215,12 +1217,12 @@ class NixlConnectorWorker:
         # Register all remote blocks, but only the corresponding kv heads.
         for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
             # Read our whole local region size from remote.
-            kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
-            if tp_ratio < 0:
+            local_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
+            if tp_ratio < 0 and not self.use_mla:
                 # Remote tp is bigger: read a chunk of local region from remote
-                kv_block_len = kv_block_len // (-tp_ratio)
+                local_block_len = local_block_len // (-tp_ratio)
             rank_offset = (
-                self.tp_rank % tp_ratio * kv_block_len if indexes_into_remote else 0
+                self.tp_rank % tp_ratio * local_block_len if indexes_into_remote else 0
             )
             for block_id in range(nixl_agent_meta.num_blocks):
                 block_offset = block_id * nixl_agent_meta.block_lens[i]
@@ -1229,7 +1231,7 @@ class NixlConnectorWorker:
                 # self.block_len == remote_block_len//tp_ratio bytes.
                 addr = base_addr + block_offset + rank_offset
                 # (addr, len, device id)
-                blocks_data.append((addr, kv_block_len, remote_tp_rank))
+                blocks_data.append((addr, local_block_len, remote_tp_rank))
 
             if self._use_flashinfer:
                 # With FlashInfer index V separately to allow head splitting.
@@ -1237,7 +1239,7 @@ class NixlConnectorWorker:
                     block_offset = block_id * nixl_agent_meta.block_lens[i]
                     addr = base_addr + block_offset + rank_offset
                     v_addr = addr + nixl_agent_meta.block_lens[i] // 2
-                    blocks_data.append((v_addr, kv_block_len, remote_tp_rank))
+                    blocks_data.append((v_addr, local_block_len, remote_tp_rank))
 
         logger.debug(
             "Created %s blocks for dst engine %s with remote rank %s and local rank %s",
@@ -1273,10 +1275,12 @@ class NixlConnectorWorker:
         assert not self._use_pallas or tp_ratio == 1, (
             "TPU (pallas_v1) DOES NOT support heterogeneous TP yet."
         )
 
+        # Num kv_heads > tp_size and P TP > D TP case, not supported
+        assert not (tp_ratio < 0 and self.kv_info.is_kv_replicated(remote_engine_id))
         # Block len can only vary across layers when using MLA.
         remote_block_len = nixl_agent_meta.block_lens[0]
-        if self.use_mla or self.kv_info.is_kv_replicated(remote_engine_id):
+        if self.kv_info.replicates_kv_cache(remote_engine_id):
             # With replicated KV cache, only the number of blocks can differ.
             assert self.block_len_per_layer == nixl_agent_meta.block_lens, (
                 "KV cache sizes must match between P and D when replicated"
@@ -1540,7 +1544,7 @@ class NixlConnectorWorker:
                 remote_rank,
                 req_id,
             )
-            if tp_ratio < 0:
+            if tp_ratio < 0 and not self.use_mla:
                 # Remote tp_size > local tp_size: we must perform multiple
                 # reads. Get the memory chunk onto which we will write to.
                 local_xfer_side_handle = self.src_xfer_side_chunked_handles[tp_ratio][i]
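
The last hunk narrows when the pre-registered chunked handles are used during reads. A hedged sketch of that selection, with illustrative names rather than the worker's real attributes:

# Illustrative only; mirrors the diff's branching, not the real method.
def pick_local_xfer_handle(
    tp_ratio: int,
    use_mla: bool,
    read_index: int,
    whole_region_handle: object,
    chunked_handles_by_ratio: dict[int, list[object]],
) -> object:
    """Select the local-side handle for the read targeting remote rank `read_index`."""
    if tp_ratio < 0 and not use_mla:
        # Remote tp_size > local tp_size: write each remote rank's data into
        # the matching chunk of the local region.
        return chunked_handles_by_ratio[tp_ratio][read_index]
    # Otherwise a single read covers the whole local region.
    return whole_region_handle
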