[NIXL] Add support for MLA caches with different latent dim (#25902)

Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
Nicolò Lucchesi, 2025-09-30 14:18:29 +02:00 (committed by simon-mo)
parent b3230e1ac0, commit d0b178cef1
2 changed files with 66 additions and 42 deletions
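Summary of the change: the connector previously kept a single scalar block_len / slot_size_bytes, which assumed every layer's KV cache block had the same byte size. With MLA models whose layers can hold latents of different widths, that assumption breaks, so the connector now tracks block_len_per_layer and slot_size_per_layer lists. A minimal standalone sketch of that bookkeeping (the cache shapes and latent dims below are invented for illustration, not taken from the commit):

    import torch

    # Hypothetical MLA caches: two layers with different latent dims, so the
    # per-block byte size differs per layer. Shape: [num_blocks, block_size, dim].
    block_size, num_blocks = 16, 4
    caches = {
        "layer.0": torch.zeros(num_blocks, block_size, 576, dtype=torch.float16),
        "layer.1": torch.zeros(num_blocks, block_size, 128, dtype=torch.float16),
    }

    block_len_per_layer: list[int] = []  # bytes of one block, per layer
    slot_size_per_layer: list[int] = []  # bytes of one token slot, per layer
    for cache in caches.values():
        tensor_size_bytes = cache.numel() * cache.element_size()
        block_len_per_layer.append(tensor_size_bytes // num_blocks)
        slot_size_per_layer.append(block_len_per_layer[-1] // block_size)

    print(block_len_per_layer)  # [18432, 4096]
    print(slot_size_per_layer)  # [1152, 256]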

File: tests/v1/kv_connector/unit/test_nixl_connector.py

@@ -255,8 +255,9 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
         time.sleep(self._hand_shake_latency)
         # These should've been done in register_kv_caches(), called by
         # gpu_model_runner. Here we just hardcode some dummy values.
-        self.slot_size_bytes = 4096
-        self.block_len = self.slot_size_bytes * self.block_size
+        slot_size_bytes = 4096
+        self.slot_size_per_layer = [slot_size_bytes]
+        self.block_len_per_layer = [slot_size_bytes * self.block_size]
         self.num_blocks = 1
         self.dst_num_blocks[self.engine_id] = self.num_blocks

@@ -268,7 +269,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
             agent_metadata=FakeNixlWrapper.AGENT_METADATA,
             kv_caches_base_addr=[0],
             num_blocks=1,
-            block_len=self.block_len,
+            block_lens=self.block_len_per_layer,
             attn_backend_name=self.backend_name,
             # `self.kv_cache_layout` is only forced to HND when vllm engine
             # is started. We mock HND here.

@@ -485,8 +486,8 @@ class TestNixlHandshake:
         worker = connector.connector_worker
         # Minimal local registration params used by add_remote_agent
-        worker.slot_size_bytes = 4096
-        worker.block_len = worker.slot_size_bytes * worker.block_size
+        worker.slot_size_per_layer = [4096]
+        worker.block_len_per_layer = [4096 * worker.block_size]
         worker.num_blocks = 1
         worker.dst_num_blocks[worker.engine_id] = worker.num_blocks

@@ -498,7 +499,7 @@ class TestNixlHandshake:
             agent_metadata=FakeNixlWrapper.AGENT_METADATA,
             kv_caches_base_addr=[0],
             num_blocks=1,
-            block_len=worker.block_len,
+            block_lens=worker.block_len_per_layer,
             attn_backend_name=worker.backend_name,
             kv_cache_layout=mismatched_layout,
         )

File: vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

@@ -84,7 +84,7 @@ class NixlAgentMetadata(
     agent_metadata: bytes
     kv_caches_base_addr: list[int]
     num_blocks: int
-    block_len: int
+    block_lens: list[int]
     attn_backend_name: str
     kv_cache_layout: str
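The handshake payload grows from one scalar block length to a per-region list. A schema-only sketch, using a plain dataclass as a stand-in for the real NixlAgentMetadata struct (all field values below are invented):

    from dataclasses import dataclass

    @dataclass
    class AgentMetadataSketch:
        agent_metadata: bytes
        kv_caches_base_addr: list[int]
        num_blocks: int
        block_lens: list[int]          # was: block_len: int
        attn_backend_name: str
        kv_cache_layout: str

    meta = AgentMetadataSketch(
        agent_metadata=b"...",
        kv_caches_base_addr=[0x7F00_0000_0000, 0x7F00_0010_0000],
        num_blocks=1024,
        block_lens=[18432, 4096],      # one entry per registered cache region
        attn_backend_name="FLASH_ATTN",
        kv_cache_layout="HND",
    )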
@@ -766,6 +766,9 @@ class NixlConnectorWorker:
         split_k_and_v = not (self.use_mla or self._use_pallas
                              or self._use_flashinfer)
         tensor_size_bytes = None
+        # Enable different block lengths for different layers when MLA is used.
+        self.block_len_per_layer = list[int]()
+        self.slot_size_per_layer = list[int]()  # HD bytes in kv terms
         for layer_name, cache_or_caches in xfer_buffers.items():
             cache_list = cache_or_caches if split_k_and_v else [
                 cache_or_caches
@@ -783,10 +786,25 @@ class NixlConnectorWorker:
                     tensor_size_bytes = curr_tensor_size_bytes
                     self.num_blocks = cache.shape[0]

-                assert tensor_size_bytes == curr_tensor_size_bytes, \
-                    "All kv cache tensors must have the same size"
+                assert cache.shape[0] == self.num_blocks, \
+                    "All kv cache tensors must have the same number of blocks"
+
+                self.block_len_per_layer.append(curr_tensor_size_bytes //
+                                                self.num_blocks)
+                self.slot_size_per_layer.append(self.block_len_per_layer[-1] //
+                                                self.block_size)
+
+                if not self.use_mla:
+                    # Different kv cache shape is not supported by HeteroTP
+                    assert tensor_size_bytes == curr_tensor_size_bytes, \
+                        "All kv cache tensors must have the same size"
             caches_data.append(
-                (base_addr, tensor_size_bytes, self.tp_rank, ""))
+                (base_addr, curr_tensor_size_bytes, self.tp_rank, ""))
+
+        logger.debug("Different block lengths collected: %s",
+                     set(self.block_len_per_layer))
+        assert len(self.block_len_per_layer) == len(seen_base_addresses)
+        assert self.num_blocks != 0

         self.kv_caches_base_addr[self.engine_id] = seen_base_addresses
         self.num_regions = len(caches_data)
@@ -799,16 +817,12 @@ class NixlConnectorWorker:
         logger.debug("Done registering descs")
         self._registered_descs.append(descs)

-        assert tensor_size_bytes is not None
-        assert self.num_blocks != 0
-        assert tensor_size_bytes % self.num_blocks == 0
-        self.block_len = tensor_size_bytes // self.num_blocks
-        self.slot_size_bytes = self.block_len // self.block_size
         self.device_kv_caches = kv_caches
         self.dst_num_blocks[self.engine_id] = self.num_blocks
         if self._use_flashinfer:
-            assert self.slot_size_bytes % 2 == 0
-            self.slot_size_bytes /= 2
+            for i in range(len(self.slot_size_per_layer)):
+                assert self.slot_size_per_layer[i] % 2 == 0
+                self.slot_size_per_layer[i] //= 2

             # NOTE (NickLucche) When FlashInfer is used, memory is registered
             # with joint KV for each block. This minimizes the overhead in
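With FlashInfer, K and V of a layer share one joint region, so each collected slot size covers both halves; the loop halves every entry. Note the list version also uses //= where the old scalar code used /=, which kept slot sizes as floats. A self-contained sketch (values invented):

    # Joint-KV slot sizes as collected above; each entry covers K plus V.
    slot_size_per_layer = [2304, 512]

    # Halve each so downstream indexing addresses a single K or V element.
    for i in range(len(slot_size_per_layer)):
        assert slot_size_per_layer[i] % 2 == 0, "joint KV slot must split evenly"
        slot_size_per_layer[i] //= 2

    print(slot_size_per_layer)  # [1152, 256] -- still ints, unlike the old /=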
@@ -818,17 +832,17 @@ class NixlConnectorWorker:
             # of 'virtual' regions here and halve `block_len` below.
             self.num_regions *= 2

-        kv_block_len = self.get_backend_aware_kv_block_len()
         # Register local/src descr for NIXL xfer.
         blocks_data = []
-        for base_addr in seen_base_addresses:
+        for i, base_addr in enumerate(seen_base_addresses):
+            kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
             # NOTE With heter-TP, more blocks are prepared than what are
             # needed as self.num_blocks >= nixl_agent_meta.num_blocks. We
             # could create fewer, but then _get_block_descs_ids needs to
             # select agent_meta.num_blocks instead of self.num_blocks for
             # local descr, and that makes handling regular flow less clean.
             for block_id in range(self.num_blocks):
-                block_offset = block_id * self.block_len
+                block_offset = block_id * self.block_len_per_layer[i]
                 addr = base_addr + block_offset
                 # (addr, len, device id)
                 blocks_data.append((addr, kv_block_len, self.tp_rank))
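The local descriptor loop now strides each layer by its own block length instead of one global stride. A runnable sketch of the address math with invented base addresses and sizes:

    # Building (addr, len, device_id) descriptors with per-layer block lengths.
    base_addrs = [0x1000_0000, 0x2000_0000]
    block_len_per_layer = [18432, 4096]
    num_blocks, tp_rank = 4, 0

    blocks_data: list[tuple[int, int, int]] = []
    for i, base_addr in enumerate(base_addrs):
        kv_block_len = block_len_per_layer[i]  # MLA path: whole block, no K/V split
        for block_id in range(num_blocks):
            addr = base_addr + block_id * block_len_per_layer[i]
            blocks_data.append((addr, kv_block_len, tp_rank))

    # Layer 0 blocks stride by 18432 bytes, layer 1 by 4096: the key point is
    # that each layer now advances by its own block length.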
@@ -838,7 +852,7 @@ class NixlConnectorWorker:
                 # descs ordering. This is needed for selecting contiguous heads
                 # when split across TP ranks.
                 for block_id in range(self.num_blocks):
-                    block_offset = block_id * self.block_len
+                    block_offset = block_id * self.block_len_per_layer[i]
                     addr = base_addr + block_offset
                     # Register addresses for V cache (K registered first).
                     v_addr = addr + kv_block_len
@@ -878,7 +892,7 @@ class NixlConnectorWorker:
             agent_metadata=self.nixl_wrapper.get_agent_metadata(),
             kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
             num_blocks=self.num_blocks,
-            block_len=self.block_len,
+            block_lens=self.block_len_per_layer,
             attn_backend_name=self.backend_name,
             kv_cache_layout=self.kv_cache_layout)
         ready_event = threading.Event()
@@ -903,7 +917,7 @@ class NixlConnectorWorker:
         The latter, assuming D.world_size > P.world_size, requires that two or
         more local TP worker share the xfer from a single TP worker.

-        Here's an example:
+        Here's an example (non-MLA case):

        rank_offset     p_remote_tp_rank
        (kv split no)
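To make the docstring's mapping concrete, a small sketch of the tp_ratio arithmetic (assumed prefill TP=2, decode TP=4; the 9216-byte block length is invented):

    # Heterogeneous TP: each P worker's block holds tp_ratio times the heads of
    # a D worker, so a D worker reads a kv_block_len-sized chunk at an offset
    # selecting its share of the heads (non-MLA; with MLA/replicated KV the
    # offset is 0).
    p_tp_size, d_tp_size = 2, 4
    tp_ratio = d_tp_size // p_tp_size          # 2
    kv_block_len = 9216                        # local (D-side) K/V bytes per block

    for d_rank in range(d_tp_size):
        p_remote_tp_rank = d_rank // tp_ratio
        rank_offset = d_rank % tp_ratio * kv_block_len
        print(d_rank, p_remote_tp_rank, rank_offset)
    # 0 0 0 / 1 0 9216 / 2 1 0 / 3 1 9216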
@@ -959,14 +973,20 @@ class NixlConnectorWorker:
         total_num_kv_heads = self.model_config.get_total_num_kv_heads()
         is_kv_replicated = self._tp_size[engine_id] // total_num_kv_heads >= 1

+        remote_block_len = nixl_agent_meta.block_lens[0]
         if self.use_mla or is_kv_replicated:
-            # With MLA the only difference is in the number of blocks.
-            remote_block_size = nixl_agent_meta.block_len // (
-                self.slot_size_bytes)
-            assert self.block_len == nixl_agent_meta.block_len
+            # With replicated KV cache, only the number of blocks can differ.
+            assert self.block_len_per_layer == nixl_agent_meta.block_lens, \
+                "KV cache sizes must match between P and D when replicated"
+            remote_block_size = remote_block_len // (
+                self.slot_size_per_layer[0])
         else:
-            remote_block_size = nixl_agent_meta.block_len // (
-                self.slot_size_bytes * tp_ratio)
+            # When MLA is not used, this is a list of the same block length
+            for block_len in nixl_agent_meta.block_lens:
+                assert block_len == remote_block_len, \
+                    "All remote layers must have the same block size"
+            remote_block_size = remote_block_len // (
+                self.slot_size_per_layer[0] * tp_ratio)
             if self._use_flashinfer:
                 # With flashinfer, KV are sent in the same message.
                 remote_block_size //= 2
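A worked check of both branches, with invented byte sizes:

    # MLA / replicated-KV path: local and remote layouts must match exactly,
    # so the remote block size falls out of the first layer's numbers.
    block_lens_local = [18432, 4096]
    block_lens_remote = [18432, 4096]
    slot_size_per_layer = [1152, 256]
    assert block_lens_local == block_lens_remote
    remote_block_size = block_lens_remote[0] // slot_size_per_layer[0]  # 16

    # Non-MLA heter-TP path: a remote P block carries tp_ratio x the local
    # heads, so the local slot size is scaled up before dividing.
    tp_ratio = 2
    remote_block_len = 18432
    local_slot_size = 576   # per-token K/V bytes for this rank's head shard
    remote_block_size_tp = remote_block_len // (local_slot_size * tp_ratio)  # 16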
@@ -977,14 +997,14 @@ class NixlConnectorWorker:
                 raise ValueError(
                     "Heterogeneous TP is not supported on XPU")

-            assert nixl_agent_meta.block_len == self.block_len * tp_ratio, (
+            assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, (
                 "Remote P worker KV layer cache must be of shape [2, N, "
                 "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
             )

         assert self.block_size == remote_block_size, (
-            "Remote P worker with different block size is not supported "
-            f"{self.block_size=} {remote_block_size=}")
+            "Remote P worker with different page/block size is not supported "
+            f"{self.block_size=}, {remote_block_size=}")

         # Create dst descs and xfer side handles. TP workers have same #blocks.
         if engine_id in self.dst_num_blocks:
@@ -999,13 +1019,16 @@ class NixlConnectorWorker:
             # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
             self.kv_caches_base_addr[
                 engine_id] = nixl_agent_meta.kv_caches_base_addr
-            kv_block_len = self.get_backend_aware_kv_block_len()
-            rank_offset = self.tp_rank % tp_ratio * kv_block_len \
-                if not (self.use_mla or is_kv_replicated) else 0
+            assert len(nixl_agent_meta.kv_caches_base_addr) == len(
+                self.block_len_per_layer)
             # Register all remote blocks, but only the corresponding kv heads.
-            for base_addr in nixl_agent_meta.kv_caches_base_addr:
+            for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
+                kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
+                rank_offset = self.tp_rank % tp_ratio * kv_block_len \
+                    if not (self.use_mla or is_kv_replicated) else 0
                 for block_id in range(nixl_agent_meta.num_blocks):
-                    block_offset = block_id * nixl_agent_meta.block_len
+                    block_offset = block_id * nixl_agent_meta.block_lens[i]
                     # For each block, grab the heads chunk belonging to rank_i
                     # of size remote_nheads // tp_ratio, which correspond to
                     # self.block_len == remote_block_len//tp_ratio bytes.
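Remote registration mirrors the local loop: kv_block_len and rank_offset now depend on the layer, since each layer may have its own block length. An illustrative sketch (addresses and sizes invented):

    # Remote descriptors with per-layer block lengths and offsets.
    remote_base_addrs = [0x9000_0000, 0xA000_0000]
    remote_block_lens = [18432, 4096]
    num_remote_blocks, tp_rank, tp_ratio, use_mla = 2, 1, 2, True
    remote_tp_rank = tp_rank // tp_ratio

    blocks_data = []
    for i, base_addr in enumerate(remote_base_addrs):
        kv_block_len = remote_block_lens[i]  # MLA: whole block, no head split
        rank_offset = 0 if use_mla else tp_rank % tp_ratio * kv_block_len
        for block_id in range(num_remote_blocks):
            addr = base_addr + block_id * remote_block_lens[i] + rank_offset
            blocks_data.append((addr, kv_block_len, remote_tp_rank))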
@@ -1016,9 +1039,9 @@ class NixlConnectorWorker:
                 if self._use_flashinfer:
                     # With FlashInfer index V separately to allow head splitting.
                     for block_id in range(nixl_agent_meta.num_blocks):
-                        block_offset = block_id * nixl_agent_meta.block_len
+                        block_offset = block_id * nixl_agent_meta.block_lens[i]
                         addr = base_addr + block_offset + rank_offset
-                        v_addr = addr + nixl_agent_meta.block_len // 2
+                        v_addr = addr + nixl_agent_meta.block_lens[i] // 2
                         blocks_data.append((v_addr, kv_block_len, remote_tp_rank))

         logger.debug(
@@ -1345,7 +1368,7 @@ class NixlConnectorWorker:
         descs_ids = region_ids * num_blocks + block_ids
         return descs_ids.flatten()

-    def get_backend_aware_kv_block_len(self):
+    def get_backend_aware_kv_block_len(self, layer_idx: int):
         """
         Get the block length for one K/V element (K and V have the same size).
@@ -1356,9 +1379,9 @@ class NixlConnectorWorker:
         """
         if self._use_flashinfer:
             # For indexing only half (either just the K or V part).
-            block_len = self.block_len // 2
+            block_len = self.block_len_per_layer[layer_idx] // 2
         else:
-            block_len = self.block_len
+            block_len = self.block_len_per_layer[layer_idx]
         return block_len

     def get_kv_connector_stats(self) -> Optional[KVConnectorStats]:
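Every call site of get_backend_aware_kv_block_len must now say which layer's block length it wants. A minimal stand-in showing the new contract (not the real class):

    # Sketch of the layer-indexed accessor, mirroring the diff above.
    class _WorkerSketch:
        def __init__(self, block_len_per_layer: list[int], use_flashinfer: bool):
            self.block_len_per_layer = block_len_per_layer
            self._use_flashinfer = use_flashinfer

        def get_backend_aware_kv_block_len(self, layer_idx: int) -> int:
            if self._use_flashinfer:
                # Joint KV region: half of it is one K (or V) element.
                return self.block_len_per_layer[layer_idx] // 2
            return self.block_len_per_layer[layer_idx]

    w = _WorkerSketch([18432, 4096], use_flashinfer=False)
    assert w.get_backend_aware_kv_block_len(layer_idx=0) == 18432
    assert w.get_backend_aware_kv_block_len(layer_idx=1) == 4096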