[NIXL] Add support for MLA caches with different latent dim (#25902)

Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
Nicolò Lucchesi 2025-09-30 14:18:29 +02:00 committed by simon-mo
parent b3230e1ac0
commit d0b178cef1
2 changed files with 66 additions and 42 deletions
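The core change: `NixlConnectorWorker` previously described every layer's KV cache with a single `block_len`/`slot_size_bytes` scalar; with MLA, latent caches can differ in width per layer, so both become per-layer lists (`block_len_per_layer`, `slot_size_per_layer`). A minimal sketch of why one scalar no longer suffices, using illustrative shapes and dtypes rather than any real model's layout:

```python
import torch

# Illustrative only: two MLA-style latent caches with different widths.
# Shapes are [num_blocks, block_size, latent_dim]; dtype bf16 (2 bytes).
block_size = 16
cache_a = torch.empty(4, block_size, 576, dtype=torch.bfloat16)
cache_b = torch.empty(4, block_size, 512, dtype=torch.bfloat16)

def block_len(cache: torch.Tensor) -> int:
    # bytes occupied by one block of this layer's cache
    return cache.element_size() * cache.numel() // cache.shape[0]

print(block_len(cache_a), block_len(cache_b))   # 18432 16384 -> not equal
```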

View File

@@ -255,8 +255,9 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
time.sleep(self._hand_shake_latency)
# These should've been done in register_kv_caches(), called by
# gpu_model_runner. Here we just hardcode some dummy values.
self.slot_size_bytes = 4096
self.block_len = self.slot_size_bytes * self.block_size
slot_size_bytes = 4096
self.slot_size_per_layer = [slot_size_bytes]
self.block_len_per_layer = [slot_size_bytes * self.block_size]
self.num_blocks = 1
self.dst_num_blocks[self.engine_id] = self.num_blocks
@@ -268,7 +269,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
kv_caches_base_addr=[0],
num_blocks=1,
block_len=self.block_len,
block_lens=self.block_len_per_layer,
attn_backend_name=self.backend_name,
# `self.kv_cache_layout` is only forced to HND when vllm engine
# is started. We mock HND here.
@@ -485,8 +486,8 @@ class TestNixlHandshake:
worker = connector.connector_worker
# Minimal local registration params used by add_remote_agent
worker.slot_size_bytes = 4096
worker.block_len = worker.slot_size_bytes * worker.block_size
worker.slot_size_per_layer = [4096]
worker.block_len_per_layer = [4096 * worker.block_size]
worker.num_blocks = 1
worker.dst_num_blocks[worker.engine_id] = worker.num_blocks
@@ -498,7 +499,7 @@ class TestNixlHandshake:
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
kv_caches_base_addr=[0],
num_blocks=1,
block_len=worker.block_len,
block_lens=worker.block_len_per_layer,
attn_backend_name=worker.backend_name,
kv_cache_layout=mismatched_layout,
)
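In both tests, the mocked worker state and the handshake metadata now carry per-layer lists. A condensed sketch of that setup, where the values mirror the dummy ones above and `block_size = 16` is an assumption rather than the tests' actual value:

```python
# Condensed sketch of the mocked per-layer state (dummy values; block_size
# is assumed here, the tests take it from the worker/vllm config).
block_size = 16
slot_size_bytes = 4096
slot_size_per_layer = [slot_size_bytes]                  # one mocked layer
block_len_per_layer = [slot_size_bytes * block_size]     # [65536]

# The handshake payload advertises the whole list instead of a scalar:
handshake_payload = dict(
    kv_caches_base_addr=[0],
    num_blocks=1,
    block_lens=block_len_per_layer,   # was: block_len=<scalar>
)
```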

View File

@@ -84,7 +84,7 @@ class NixlAgentMetadata(
agent_metadata: bytes
kv_caches_base_addr: list[int]
num_blocks: int
block_len: int
block_lens: list[int]
attn_backend_name: str
kv_cache_layout: str
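`NixlAgentMetadata` now advertises one block length per registered KV region. A rough sketch of the shape of that message, using a plain dataclass; the real class carries more fields and a different base type:

```python
from dataclasses import dataclass

# Rough sketch only; the real NixlAgentMetadata also carries agent_metadata
# bytes, attn_backend_name, etc., and is not a plain dataclass.
@dataclass
class AgentMetaSketch:
    kv_caches_base_addr: list[int]
    num_blocks: int
    block_lens: list[int]            # one entry per KV region/layer
    kv_cache_layout: str

meta = AgentMetaSketch(
    kv_caches_base_addr=[0x7F00_0000_0000, 0x7F00_0010_0000],
    num_blocks=8,
    block_lens=[18432, 16384],       # may differ per layer with MLA
    kv_cache_layout="HND",
)
assert len(meta.kv_caches_base_addr) == len(meta.block_lens)
```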
@@ -766,6 +766,9 @@ class NixlConnectorWorker:
split_k_and_v = not (self.use_mla or self._use_pallas
or self._use_flashinfer)
tensor_size_bytes = None
# Enable different block lengths for different layers when MLA is used.
self.block_len_per_layer = list[int]()
self.slot_size_per_layer = list[int]() # HD bytes in kv terms
for layer_name, cache_or_caches in xfer_buffers.items():
cache_list = cache_or_caches if split_k_and_v else [
cache_or_caches
@@ -783,10 +786,25 @@ class NixlConnectorWorker:
tensor_size_bytes = curr_tensor_size_bytes
self.num_blocks = cache.shape[0]
assert tensor_size_bytes == curr_tensor_size_bytes, \
"All kv cache tensors must have the same size"
assert cache.shape[0] == self.num_blocks, \
"All kv cache tensors must have the same number of blocks"
self.block_len_per_layer.append(curr_tensor_size_bytes //
self.num_blocks)
self.slot_size_per_layer.append(self.block_len_per_layer[-1] //
self.block_size)
if not self.use_mla:
# Different kv cache shape is not supported by HeteroTP
assert tensor_size_bytes == curr_tensor_size_bytes, \
"All kv cache tensors must have the same size"
caches_data.append(
(base_addr, tensor_size_bytes, self.tp_rank, ""))
(base_addr, curr_tensor_size_bytes, self.tp_rank, ""))
logger.debug("Different block lengths collected: %s",
set(self.block_len_per_layer))
assert len(self.block_len_per_layer) == len(seen_base_addresses)
assert self.num_blocks != 0
self.kv_caches_base_addr[self.engine_id] = seen_base_addresses
self.num_regions = len(caches_data)
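A compact sketch of the per-layer collection logic above, assuming `layer_nbytes` holds each layer's total cache size in bytes and that `num_blocks` and `block_size` are uniform across layers, as the asserts require:

```python
# Sketch of the per-layer collection with the non-MLA uniformity check.
def collect(layer_nbytes: list[int], num_blocks: int, block_size: int,
            use_mla: bool) -> tuple[list[int], list[int]]:
    block_len_per_layer: list[int] = []
    slot_size_per_layer: list[int] = []
    for nbytes in layer_nbytes:
        if not use_mla:
            # Heterogeneous TP relies on every layer having the same KV shape.
            assert nbytes == layer_nbytes[0], \
                "All kv cache tensors must have the same size"
        block_len_per_layer.append(nbytes // num_blocks)
        slot_size_per_layer.append(block_len_per_layer[-1] // block_size)
    return block_len_per_layer, slot_size_per_layer

# MLA: differing per-layer sizes are accepted.
print(collect([73728, 65536], num_blocks=4, block_size=16, use_mla=True))
# -> ([18432, 16384], [1152, 1024])
```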
@@ -799,16 +817,12 @@ class NixlConnectorWorker:
logger.debug("Done registering descs")
self._registered_descs.append(descs)
assert tensor_size_bytes is not None
assert self.num_blocks != 0
assert tensor_size_bytes % self.num_blocks == 0
self.block_len = tensor_size_bytes // self.num_blocks
self.slot_size_bytes = self.block_len // self.block_size
self.device_kv_caches = kv_caches
self.dst_num_blocks[self.engine_id] = self.num_blocks
if self._use_flashinfer:
assert self.slot_size_bytes % 2 == 0
self.slot_size_bytes /= 2
for i in range(len(self.slot_size_per_layer)):
assert self.slot_size_per_layer[i] % 2 == 0
self.slot_size_per_layer[i] //= 2
# NOTE (NickLucche) When FlashInfer is used, memory is registered
# with joint KV for each block. This minimizes the overhead in
@@ -818,17 +832,17 @@ class NixlConnectorWorker:
# of 'virtual' regions here and halve `block_len` below.
self.num_regions *= 2
kv_block_len = self.get_backend_aware_kv_block_len()
# Register local/src descr for NIXL xfer.
blocks_data = []
for base_addr in seen_base_addresses:
for i, base_addr in enumerate(seen_base_addresses):
kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
# NOTE With heter-TP, more blocks are prepared than what are
# needed as self.num_blocks >= nixl_agent_meta.num_blocks. We
# could create fewer, but then _get_block_descs_ids needs to
# select agent_meta.num_blocks instead of self.num_blocks for
# local descr, and that makes handling regular flow less clean.
for block_id in range(self.num_blocks):
block_offset = block_id * self.block_len
block_offset = block_id * self.block_len_per_layer[i]
addr = base_addr + block_offset
# (addr, len, device id)
blocks_data.append((addr, kv_block_len, self.tp_rank))
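With per-layer lengths, the local descriptor registration walks the base addresses and uses each layer's own block length for the offset. A small sketch with illustrative addresses rather than real device pointers:

```python
# Sketch of per-layer local block descriptors (illustrative values).
seen_base_addresses = [0x1000_0000, 0x2000_0000]
block_len_per_layer = [18432, 16384]
num_blocks, tp_rank = 4, 0

blocks_data = []
for i, base_addr in enumerate(seen_base_addresses):
    kv_block_len = block_len_per_layer[i]      # MLA: no K/V split
    for block_id in range(num_blocks):
        addr = base_addr + block_id * block_len_per_layer[i]
        blocks_data.append((addr, kv_block_len, tp_rank))

# Each layer contributes num_blocks (addr, len, device) descriptors,
# and the length can now differ between layers.
assert len(blocks_data) == len(seen_base_addresses) * num_blocks
```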
@@ -838,7 +852,7 @@ class NixlConnectorWorker:
# descs ordering. This is needed for selecting contiguous heads
# when split across TP ranks.
for block_id in range(self.num_blocks):
block_offset = block_id * self.block_len
block_offset = block_id * self.block_len_per_layer[i]
addr = base_addr + block_offset
# Register addresses for V cache (K registered first).
v_addr = addr + kv_block_len
@@ -878,7 +892,7 @@ class NixlConnectorWorker:
agent_metadata=self.nixl_wrapper.get_agent_metadata(),
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
num_blocks=self.num_blocks,
block_len=self.block_len,
block_lens=self.block_len_per_layer,
attn_backend_name=self.backend_name,
kv_cache_layout=self.kv_cache_layout)
ready_event = threading.Event()
@@ -903,7 +917,7 @@ class NixlConnectorWorker:
The latter, assuming D.world_size > P.world_size, requires that two or
more local TP worker share the xfer from a single TP worker.
Here's an example:
Here's an example (non-MLA case):
rank_offset p_remote_tp_rank
(kv split no)
@@ -959,14 +973,20 @@ class NixlConnectorWorker:
total_num_kv_heads = self.model_config.get_total_num_kv_heads()
is_kv_replicated = self._tp_size[engine_id] // total_num_kv_heads >= 1
remote_block_len = nixl_agent_meta.block_lens[0]
if self.use_mla or is_kv_replicated:
# With MLA the only difference is in the number of blocks.
remote_block_size = nixl_agent_meta.block_len // (
self.slot_size_bytes)
assert self.block_len == nixl_agent_meta.block_len
# With replicated KV cache, only the number of blocks can differ.
assert self.block_len_per_layer == nixl_agent_meta.block_lens, \
"KV cache sizes must match between P and D when replicated"
remote_block_size = remote_block_len // (
self.slot_size_per_layer[0])
else:
remote_block_size = nixl_agent_meta.block_len // (
self.slot_size_bytes * tp_ratio)
# When MLA is not used, this is a list of the same block length
for block_len in nixl_agent_meta.block_lens:
assert block_len == remote_block_len, \
"All remote layers must have the same block size"
remote_block_size = remote_block_len // (
self.slot_size_per_layer[0] * tp_ratio)
if self._use_flashinfer:
# With flashinfer, KV are sent in the same message.
remote_block_size //= 2
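The remote block size is now recovered from `block_lens[0]`: with MLA or replicated KV the per-layer lists must match exactly between P and D, while under heterogeneous TP (non-MLA) a remote block is `tp_ratio` times larger than the local chunk. A worked example with made-up numbers:

```python
# Worked example with illustrative numbers (not taken from a real model).
# Non-MLA, heterogeneous TP: prefill TP=1, decode TP=2 -> tp_ratio = 2.
tp_ratio = 2
block_size = 16                      # tokens per block on the decode side
local_slot_size = 1024               # bytes per token of the local KV chunk
remote_block_len = 32768             # bytes per remote block (tp_ratio x heads)

remote_block_size = remote_block_len // (local_slot_size * tp_ratio)
assert remote_block_size == block_size        # 32768 // 2048 == 16

# MLA / replicated KV: block lengths must match, only num_blocks may differ.
local_block_lens = [18432, 16384]
remote_block_lens = [18432, 16384]
assert local_block_lens == remote_block_lens
```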
@@ -977,14 +997,14 @@ class NixlConnectorWorker:
raise ValueError(
"Heterogeneous TP is not supported on XPU")
assert nixl_agent_meta.block_len == self.block_len * tp_ratio, (
assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, (
"Remote P worker KV layer cache must be of shape [2, N, "
"local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
)
assert self.block_size == remote_block_size, (
"Remote P worker with different block size is not supported "
f"{self.block_size=} {remote_block_size=}")
"Remote P worker with different page/block size is not supported "
f"{self.block_size=}, {remote_block_size=}")
# Create dst descs and xfer side handles. TP workers have same #blocks.
if engine_id in self.dst_num_blocks:
@@ -999,13 +1019,16 @@ class NixlConnectorWorker:
# Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
self.kv_caches_base_addr[
engine_id] = nixl_agent_meta.kv_caches_base_addr
kv_block_len = self.get_backend_aware_kv_block_len()
rank_offset = self.tp_rank % tp_ratio * kv_block_len \
if not (self.use_mla or is_kv_replicated) else 0
assert len(nixl_agent_meta.kv_caches_base_addr) == len(
self.block_len_per_layer)
# Register all remote blocks, but only the corresponding kv heads.
for base_addr in nixl_agent_meta.kv_caches_base_addr:
for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
rank_offset = self.tp_rank % tp_ratio * kv_block_len \
if not (self.use_mla or is_kv_replicated) else 0
for block_id in range(nixl_agent_meta.num_blocks):
block_offset = block_id * nixl_agent_meta.block_len
block_offset = block_id * nixl_agent_meta.block_lens[i]
# For each block, grab the heads chunk belonging to rank_i
# of size remote_nheads // tp_ratio, which correspond to
# self.block_len == remote_block_len//tp_ratio bytes.
@@ -1016,9 +1039,9 @@ class NixlConnectorWorker:
if self._use_flashinfer:
# With FlashInfer index V separately to allow head splitting.
for block_id in range(nixl_agent_meta.num_blocks):
block_offset = block_id * nixl_agent_meta.block_len
block_offset = block_id * nixl_agent_meta.block_lens[i]
addr = base_addr + block_offset + rank_offset
v_addr = addr + nixl_agent_meta.block_len // 2
v_addr = addr + nixl_agent_meta.block_lens[i] // 2
blocks_data.append((v_addr, kv_block_len, remote_tp_rank))
logger.debug(
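Putting the remote side together: each decode rank registers every remote block of every layer, but offsets into each block to grab only its own contiguous chunk of heads; with a joint-KV layout the matching V chunk sits half a remote block further. A sketch with illustrative numbers:

```python
# Sketch of remote address arithmetic for one layer under heterogeneous TP
# with a joint-KV (FlashInfer-style) layout. All numbers are illustrative.
tp_ratio, tp_rank = 2, 1
remote_block_len = 32768                           # bytes per remote block (K+V)
kv_block_len = remote_block_len // tp_ratio // 2   # this rank's K (or V) chunk
rank_offset = (tp_rank % tp_ratio) * kv_block_len  # 0 for MLA/replicated KV
base_addr, block_id = 0x4000_0000, 3

addr = base_addr + block_id * remote_block_len + rank_offset   # K chunk
v_addr = addr + remote_block_len // 2                          # matching V chunk
print(hex(addr), hex(v_addr))
```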
@@ -1345,7 +1368,7 @@ class NixlConnectorWorker:
descs_ids = region_ids * num_blocks + block_ids
return descs_ids.flatten()
def get_backend_aware_kv_block_len(self):
def get_backend_aware_kv_block_len(self, layer_idx: int):
"""
Get the block length for one K/V element (K and V have the same size).
@@ -1356,9 +1379,9 @@ class NixlConnectorWorker:
"""
if self._use_flashinfer:
# For indexing only half (either just the K or V part).
block_len = self.block_len // 2
block_len = self.block_len_per_layer[layer_idx] // 2
else:
block_len = self.block_len
block_len = self.block_len_per_layer[layer_idx]
return block_len
def get_kv_connector_stats(self) -> Optional[KVConnectorStats]:
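`get_backend_aware_kv_block_len` now takes a `layer_idx` so callers get the length for that specific layer. A standalone sketch of the same rule, where the function name, flag, and values are assumptions for illustration:

```python
# Sketch of a layer-aware block length helper mirroring the new layer_idx
# parameter; use_flashinfer and the lengths are illustrative assumptions.
def backend_aware_kv_block_len(block_len_per_layer: list[int],
                               layer_idx: int,
                               use_flashinfer: bool) -> int:
    block_len = block_len_per_layer[layer_idx]
    # Joint-KV layouts register K and V in one region; index half of it.
    return block_len // 2 if use_flashinfer else block_len

assert backend_aware_kv_block_len([18432, 16384], 1, False) == 16384
assert backend_aware_kv_block_len([18432, 16384], 0, True) == 9216
```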