[NIXL] Add support for MLA caches with different latent dim (#25902)
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
parent b3230e1ac0
commit d0b178cef1
@@ -255,8 +255,9 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
         time.sleep(self._hand_shake_latency)
         # These should've been done in register_kv_caches(), called by
         # gpu_model_runner. Here we just hardcode some dummy values.
-        self.slot_size_bytes = 4096
-        self.block_len = self.slot_size_bytes * self.block_size
+        slot_size_bytes = 4096
+        self.slot_size_per_layer = [slot_size_bytes]
+        self.block_len_per_layer = [slot_size_bytes * self.block_size]
         self.num_blocks = 1
         self.dst_num_blocks[self.engine_id] = self.num_blocks
 
@@ -268,7 +269,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
             agent_metadata=FakeNixlWrapper.AGENT_METADATA,
             kv_caches_base_addr=[0],
             num_blocks=1,
-            block_len=self.block_len,
+            block_lens=self.block_len_per_layer,
             attn_backend_name=self.backend_name,
             # `self.kv_cache_layout` is only forced to HND when vllm engine
             # is started. We mock HND here.
@@ -485,8 +486,8 @@ class TestNixlHandshake:
         worker = connector.connector_worker
 
         # Minimal local registration params used by add_remote_agent
-        worker.slot_size_bytes = 4096
-        worker.block_len = worker.slot_size_bytes * worker.block_size
+        worker.slot_size_per_layer = [4096]
+        worker.block_len_per_layer = [4096 * worker.block_size]
         worker.num_blocks = 1
         worker.dst_num_blocks[worker.engine_id] = worker.num_blocks
 
@@ -498,7 +499,7 @@ class TestNixlHandshake:
             agent_metadata=FakeNixlWrapper.AGENT_METADATA,
             kv_caches_base_addr=[0],
             num_blocks=1,
-            block_len=worker.block_len,
+            block_lens=worker.block_len_per_layer,
             attn_backend_name=worker.backend_name,
             kv_cache_layout=mismatched_layout,
         )
@@ -84,7 +84,7 @@ class NixlAgentMetadata(
     agent_metadata: bytes
     kv_caches_base_addr: list[int]
     num_blocks: int
-    block_len: int
+    block_lens: list[int]
     attn_backend_name: str
     kv_cache_layout: str
 
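For orientation: the handshake metadata now advertises one block length per registered KV cache region rather than a single scalar. Below is a small illustrative sketch of how a consumer of such metadata could use the list; the AgentMetaSketch class, the field defaults, and all the numbers are assumptions for illustration only, not the real NixlAgentMetadata.

# Illustrative stand-in for the metadata exchanged during the NIXL handshake;
# the real NixlAgentMetadata in vLLM may use a different base class and fields.
from dataclasses import dataclass, field


@dataclass
class AgentMetaSketch:
    engine_id: str
    kv_caches_base_addr: list[int]
    num_blocks: int
    # One entry per registered KV cache region (layer); entries may differ
    # when MLA layers hold latents of different sizes.
    block_lens: list[int] = field(default_factory=list)
    attn_backend_name: str = "FLASH_ATTN"
    kv_cache_layout: str = "HND"


def total_remote_bytes(meta: AgentMetaSketch) -> int:
    # Each region contributes num_blocks * its own block length.
    return sum(meta.num_blocks * bl for bl in meta.block_lens)


meta = AgentMetaSketch(
    engine_id="prefill-0",
    kv_caches_base_addr=[0x7F00_0000_0000, 0x7F00_1000_0000],
    num_blocks=4,
    block_lens=[4096 * 16, 2048 * 16],  # two layers with different latent dims
)
assert total_remote_bytes(meta) == 4 * (4096 * 16 + 2048 * 16)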
@@ -766,6 +766,9 @@ class NixlConnectorWorker:
         split_k_and_v = not (self.use_mla or self._use_pallas
                              or self._use_flashinfer)
         tensor_size_bytes = None
+        # Enable different block lengths for different layers when MLA is used.
+        self.block_len_per_layer = list[int]()
+        self.slot_size_per_layer = list[int]()  # HD bytes in kv terms
         for layer_name, cache_or_caches in xfer_buffers.items():
             cache_list = cache_or_caches if split_k_and_v else [
                 cache_or_caches
@@ -783,10 +786,25 @@ class NixlConnectorWorker:
                     tensor_size_bytes = curr_tensor_size_bytes
                     self.num_blocks = cache.shape[0]
 
-                assert tensor_size_bytes == curr_tensor_size_bytes, \
-                    "All kv cache tensors must have the same size"
                 assert cache.shape[0] == self.num_blocks, \
                     "All kv cache tensors must have the same number of blocks"
+
+                self.block_len_per_layer.append(curr_tensor_size_bytes //
+                                                self.num_blocks)
+                self.slot_size_per_layer.append(self.block_len_per_layer[-1] //
+                                                self.block_size)
+
+                if not self.use_mla:
+                    # Different kv cache shape is not supported by HeteroTP
+                    assert tensor_size_bytes == curr_tensor_size_bytes, \
+                        "All kv cache tensors must have the same size"
                 caches_data.append(
-                    (base_addr, tensor_size_bytes, self.tp_rank, ""))
+                    (base_addr, curr_tensor_size_bytes, self.tp_rank, ""))
+
+        logger.debug("Different block lengths collected: %s",
+                     set(self.block_len_per_layer))
+        assert len(self.block_len_per_layer) == len(seen_base_addresses)
+        assert self.num_blocks != 0
 
         self.kv_caches_base_addr[self.engine_id] = seen_base_addresses
         self.num_regions = len(caches_data)
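To make the arithmetic in this hunk concrete, here is a standalone sketch of how per-layer block lengths and slot sizes fall out of the cache tensor shapes. The tensor shapes, dtype, and block size are invented; only the formulas (block_len = bytes_per_tensor // num_blocks, slot_size = block_len // block_size) mirror the diff.

# Sketch: derive per-layer block/slot sizes the same way register_kv_caches()
# does above. Shapes and dtype are arbitrary example values.
import torch

block_size = 16          # tokens per KV block (example)
num_blocks = 8
caches = {
    # MLA-style caches: one latent tensor per layer; latent dims may differ.
    "layer.0": torch.zeros(num_blocks, block_size, 576, dtype=torch.bfloat16),
    "layer.1": torch.zeros(num_blocks, block_size, 512, dtype=torch.bfloat16),
}

block_len_per_layer: list[int] = []
slot_size_per_layer: list[int] = []
for name, cache in caches.items():
    tensor_bytes = cache.numel() * cache.element_size()
    assert cache.shape[0] == num_blocks
    block_len_per_layer.append(tensor_bytes // num_blocks)        # bytes per block
    slot_size_per_layer.append(block_len_per_layer[-1] // block_size)  # bytes per token slot

print(block_len_per_layer)  # [18432, 16384] -> differs per layer
print(slot_size_per_layer)  # [1152, 1024]   -> latent_dim * 2 bytes (bf16)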
@@ -799,16 +817,12 @@ class NixlConnectorWorker:
         logger.debug("Done registering descs")
         self._registered_descs.append(descs)
 
-        assert tensor_size_bytes is not None
-        assert self.num_blocks != 0
-        assert tensor_size_bytes % self.num_blocks == 0
-        self.block_len = tensor_size_bytes // self.num_blocks
-        self.slot_size_bytes = self.block_len // self.block_size
         self.device_kv_caches = kv_caches
         self.dst_num_blocks[self.engine_id] = self.num_blocks
         if self._use_flashinfer:
-            assert self.slot_size_bytes % 2 == 0
-            self.slot_size_bytes /= 2
+            for i in range(len(self.slot_size_per_layer)):
+                assert self.slot_size_per_layer[i] % 2 == 0
+                self.slot_size_per_layer[i] //= 2
 
             # NOTE (NickLucche) When FlashInfer is used, memory is registered
             # with joint KV for each block. This minimizes the overhead in
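The halving above follows from K and V living jointly in one cache tensor per layer, so the bytes per block cover both halves. A toy sketch of that accounting follows; the joint shape used here is an assumption for illustration and is not taken from the FlashInfer backend itself.

# Sketch: with a joint-KV layout (K and V in one tensor per layer), the bytes
# per block span both K and V, so the per-token slot size is halved when only
# the K half or only the V half should be addressed. Shape is illustrative.
import torch

block_size, num_heads, head_dim = 16, 8, 128
num_blocks = 4
# Assumed joint layout: [num_blocks, 2 (K/V), block_size, num_heads, head_dim]
cache = torch.zeros(num_blocks, 2, block_size, num_heads, head_dim,
                    dtype=torch.float16)

block_len = cache.numel() * cache.element_size() // num_blocks  # K + V bytes
slot_size = block_len // block_size                             # K + V bytes per token
assert slot_size % 2 == 0
slot_size //= 2                 # bytes per token for K alone (or V alone)
kv_block_len = block_len // 2   # span of a single K or V region of a block

assert kv_block_len == slot_size * block_size
print(block_len, kv_block_len, slot_size)  # 65536 32768 2048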
@@ -818,17 +832,17 @@ class NixlConnectorWorker:
             # of 'virtual' regions here and halve `block_len` below.
             self.num_regions *= 2
 
-        kv_block_len = self.get_backend_aware_kv_block_len()
         # Register local/src descr for NIXL xfer.
         blocks_data = []
-        for base_addr in seen_base_addresses:
+        for i, base_addr in enumerate(seen_base_addresses):
+            kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
             # NOTE With heter-TP, more blocks are prepared than what are
             # needed as self.num_blocks >= nixl_agent_meta.num_blocks. We
             # could create fewer, but then _get_block_descs_ids needs to
             # select agent_meta.num_blocks instead of self.num_blocks for
             # local descr, and that makes handling regular flow less clean.
             for block_id in range(self.num_blocks):
-                block_offset = block_id * self.block_len
+                block_offset = block_id * self.block_len_per_layer[i]
                 addr = base_addr + block_offset
                 # (addr, len, device id)
                 blocks_data.append((addr, kv_block_len, self.tp_rank))
@@ -838,7 +852,7 @@ class NixlConnectorWorker:
                 # descs ordering. This is needed for selecting contiguous heads
                 # when split across TP ranks.
                 for block_id in range(self.num_blocks):
-                    block_offset = block_id * self.block_len
+                    block_offset = block_id * self.block_len_per_layer[i]
                     addr = base_addr + block_offset
                     # Register addresses for V cache (K registered first).
                     v_addr = addr + kv_block_len
@@ -878,7 +892,7 @@ class NixlConnectorWorker:
             agent_metadata=self.nixl_wrapper.get_agent_metadata(),
             kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
             num_blocks=self.num_blocks,
-            block_len=self.block_len,
+            block_lens=self.block_len_per_layer,
             attn_backend_name=self.backend_name,
             kv_cache_layout=self.kv_cache_layout)
         ready_event = threading.Event()
@@ -903,7 +917,7 @@ class NixlConnectorWorker:
         The latter, assuming D.world_size > P.world_size, requires that two or
         more local TP worker share the xfer from a single TP worker.
 
-        Here's an example:
+        Here's an example (non-MLA case):
 
         rank_offset     p_remote_tp_rank
         (kv split no)
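To illustrate the mapping this docstring describes: with D.world_size > P.world_size, each decode rank reads a contiguous chunk of heads from one prefill rank. A hedged sketch of that index arithmetic follows, reusing the tp_ratio and rank_offset formulas that appear later in this diff; the helper name and the concrete sizes are invented.

# Sketch of the heterogeneous-TP mapping (non-MLA case): decode rank d pulls
# its head chunk from prefill rank d // tp_ratio, offset inside the remote
# block by (d % tp_ratio) * kv_block_len bytes. Values below are examples.
def remote_mapping(d_tp_rank: int, d_tp_size: int, p_tp_size: int,
                   kv_block_len: int) -> tuple[int, int]:
    assert d_tp_size % p_tp_size == 0
    tp_ratio = d_tp_size // p_tp_size
    p_remote_tp_rank = d_tp_rank // tp_ratio
    rank_offset = d_tp_rank % tp_ratio * kv_block_len
    return p_remote_tp_rank, rank_offset


# D has 4 TP workers, P has 2, so tp_ratio == 2 and each P worker serves two
# D workers, each reading half of the remote heads (kv_block_len bytes apart).
kv_block_len = 32768
for d in range(4):
    print(d, remote_mapping(d, d_tp_size=4, p_tp_size=2,
                            kv_block_len=kv_block_len))
# -> 0:(0, 0), 1:(0, 32768), 2:(1, 0), 3:(1, 32768)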
@@ -959,14 +973,20 @@ class NixlConnectorWorker:
         total_num_kv_heads = self.model_config.get_total_num_kv_heads()
         is_kv_replicated = self._tp_size[engine_id] // total_num_kv_heads >= 1
 
+        remote_block_len = nixl_agent_meta.block_lens[0]
         if self.use_mla or is_kv_replicated:
-            # With MLA the only difference is in the number of blocks.
-            remote_block_size = nixl_agent_meta.block_len // (
-                self.slot_size_bytes)
-            assert self.block_len == nixl_agent_meta.block_len
+            # With replicated KV cache, only the number of blocks can differ.
+            assert self.block_len_per_layer == nixl_agent_meta.block_lens, \
+                "KV cache sizes must match between P and D when replicated"
+            remote_block_size = remote_block_len // (
+                self.slot_size_per_layer[0])
         else:
-            remote_block_size = nixl_agent_meta.block_len // (
-                self.slot_size_bytes * tp_ratio)
+            # When MLA is not used, this is a list of the same block length
+            for block_len in nixl_agent_meta.block_lens:
+                assert block_len == remote_block_len, \
+                    "All remote layers must have the same block size"
+            remote_block_size = remote_block_len // (
+                self.slot_size_per_layer[0] * tp_ratio)
             if self._use_flashinfer:
                 # With flashinfer, KV are sent in the same message.
                 remote_block_size //= 2
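As a sanity check on the arithmetic in this hunk, here is a small sketch of remote_block_size in the two branches (MLA/replicated versus split heads), with invented byte counts; in the real connector these values come from the agent metadata and the local per-layer slot sizes.

# Sketch: recover the remote block size (tokens per block on the P side) from
# the remote block length in bytes. All numbers are illustrative.

def remote_block_size_mla(remote_block_len: int, local_slot_size: int) -> int:
    # MLA / replicated KV: layouts must match, so no tp_ratio scaling.
    return remote_block_len // local_slot_size


def remote_block_size_split(remote_block_len: int, local_slot_size: int,
                            tp_ratio: int, use_flashinfer: bool) -> int:
    # Non-MLA: the remote block holds tp_ratio times the local heads.
    size = remote_block_len // (local_slot_size * tp_ratio)
    if use_flashinfer:
        # Joint KV in one message: both halves share the block.
        size //= 2
    return size


# Example: 16-token blocks, local slot of 1152 bytes per token (bf16 latent 576).
assert remote_block_size_mla(16 * 1152, 1152) == 16
# Example: remote P worker holds 2x the heads (tp_ratio=2), 2048 bytes/token here.
assert remote_block_size_split(16 * 2048 * 2, 2048, tp_ratio=2,
                               use_flashinfer=False) == 16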
@@ -977,14 +997,14 @@ class NixlConnectorWorker:
                 raise ValueError(
                     "Heterogeneous TP is not supported on XPU")
 
-            assert nixl_agent_meta.block_len == self.block_len * tp_ratio, (
+            assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, (
                 "Remote P worker KV layer cache must be of shape [2, N, "
                 "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
             )
 
         assert self.block_size == remote_block_size, (
-            "Remote P worker with different block size is not supported "
-            f"{self.block_size=} {remote_block_size=}")
+            "Remote P worker with different page/block size is not supported "
+            f"{self.block_size=}, {remote_block_size=}")
 
         # Create dst descs and xfer side handles. TP workers have same #blocks.
         if engine_id in self.dst_num_blocks:
@@ -999,13 +1019,16 @@ class NixlConnectorWorker:
         # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
         self.kv_caches_base_addr[
             engine_id] = nixl_agent_meta.kv_caches_base_addr
-        kv_block_len = self.get_backend_aware_kv_block_len()
-        rank_offset = self.tp_rank % tp_ratio * kv_block_len \
-            if not (self.use_mla or is_kv_replicated) else 0
+
+        assert len(nixl_agent_meta.kv_caches_base_addr) == len(
+            self.block_len_per_layer)
         # Register all remote blocks, but only the corresponding kv heads.
-        for base_addr in nixl_agent_meta.kv_caches_base_addr:
+        for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
+            kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
+            rank_offset = self.tp_rank % tp_ratio * kv_block_len \
+                if not (self.use_mla or is_kv_replicated) else 0
             for block_id in range(nixl_agent_meta.num_blocks):
-                block_offset = block_id * nixl_agent_meta.block_len
+                block_offset = block_id * nixl_agent_meta.block_lens[i]
                 # For each block, grab the heads chunk belonging to rank_i
                 # of size remote_nheads // tp_ratio, which correspond to
                 # self.block_len == remote_block_len//tp_ratio bytes.
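The loop above walks the remote regions layer by layer; each descriptor's address is base_addr + block_id * block_lens[i] + rank_offset and spans kv_block_len bytes. A compact sketch of that address generation, with made-up addresses and lengths:

# Sketch: enumerate (addr, len, remote_rank) descriptors for the remote blocks,
# one layer (region) at a time, mirroring the loop above. All values invented.
remote_tp_rank = 0
num_blocks = 2
kv_caches_base_addr = [0x1000_0000, 0x2000_0000]   # one per layer/region
block_lens = [18432, 16384]                        # bytes per block, per layer
kv_block_len_per_layer = block_lens                # MLA: whole block is "K"
rank_offset = 0                                    # 0 for MLA / replicated KV

blocks_data: list[tuple[int, int, int]] = []
for i, base_addr in enumerate(kv_caches_base_addr):
    kv_block_len = kv_block_len_per_layer[i]
    for block_id in range(num_blocks):
        block_offset = block_id * block_lens[i]
        addr = base_addr + block_offset + rank_offset
        blocks_data.append((addr, kv_block_len, remote_tp_rank))

for entry in blocks_data:
    print(hex(entry[0]), entry[1], entry[2])
# -> 0x10000000 / 0x10004800 for layer 0, 0x20000000 / 0x20004000 for layer 1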
@@ -1016,9 +1039,9 @@ class NixlConnectorWorker:
             if self._use_flashinfer:
                 # With FlashInfer index V separately to allow head splitting.
                 for block_id in range(nixl_agent_meta.num_blocks):
-                    block_offset = block_id * nixl_agent_meta.block_len
+                    block_offset = block_id * nixl_agent_meta.block_lens[i]
                     addr = base_addr + block_offset + rank_offset
-                    v_addr = addr + nixl_agent_meta.block_len // 2
+                    v_addr = addr + nixl_agent_meta.block_lens[i] // 2
                     blocks_data.append((v_addr, kv_block_len, remote_tp_rank))
 
         logger.debug(
@@ -1345,7 +1368,7 @@ class NixlConnectorWorker:
         descs_ids = region_ids * num_blocks + block_ids
         return descs_ids.flatten()
 
-    def get_backend_aware_kv_block_len(self):
+    def get_backend_aware_kv_block_len(self, layer_idx: int):
         """
         Get the block length for one K/V element (K and V have the same size).
 
@@ -1356,9 +1379,9 @@ class NixlConnectorWorker:
         """
         if self._use_flashinfer:
             # For indexing only half (either just the K or V part).
-            block_len = self.block_len // 2
+            block_len = self.block_len_per_layer[layer_idx] // 2
         else:
-            block_len = self.block_len
+            block_len = self.block_len_per_layer[layer_idx]
         return block_len
 
     def get_kv_connector_stats(self) -> Optional[KVConnectorStats]:
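Putting the last two hunks together, the helper now indexes the per-layer list instead of a single block length. A standalone sketch of the resulting behaviour follows; _WorkerSketch is a trimmed-down stand-in, not the actual NixlConnectorWorker, and the numbers are examples.

# Sketch: per-layer block length helper after this change. Only the two
# attributes it reads are modelled here; everything else is omitted.
class _WorkerSketch:

    def __init__(self, block_len_per_layer: list[int], use_flashinfer: bool):
        self.block_len_per_layer = block_len_per_layer
        self._use_flashinfer = use_flashinfer

    def get_backend_aware_kv_block_len(self, layer_idx: int) -> int:
        if self._use_flashinfer:
            # Joint KV layout: index only half (either the K or the V part).
            return self.block_len_per_layer[layer_idx] // 2
        return self.block_len_per_layer[layer_idx]


w = _WorkerSketch(block_len_per_layer=[18432, 16384], use_flashinfer=False)
assert w.get_backend_aware_kv_block_len(layer_idx=1) == 16384
w_fi = _WorkerSketch(block_len_per_layer=[65536], use_flashinfer=True)
assert w_fi.get_backend_aware_kv_block_len(layer_idx=0) == 32768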