Signed-off-by: NickLucche <nlucches@redhat.com>
NickLucche 2025-10-09 15:43:43 +00:00
parent 7bb3861faf
commit 9f38fed93c


@ -520,7 +520,9 @@ class NixlConnectorWorker:
Calculate the tensor parallel ratio between local and remote TP.
We can think of it as the number of local TP workers per remote TP
worker. Local workers will read from the same remote TP worker in
groups of size `tp_ratio`. If remote tp_size > local tp_size, the
ratio is flipped (remote_size/local_size) and the returned value is
negative.
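For illustration, assuming evenly divisible sizes: local tp_size=4 and
remote tp_size=2 give tp_ratio=2 (two local workers per remote worker),
while local tp_size=2 and remote tp_size=8 give tp_ratio=-4.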
"""
if remote_tp_size is None:
assert remote_engine_id is not None
@ -539,7 +541,9 @@ class NixlConnectorWorker:
# P TP > D TP case, return the ratio as negative
return -remote_tp_size // self.tp_size
def is_kv_replicated(
self, engine_id: Optional[EngineId] = None, tp_size: Optional[int] = None
) -> bool:
"""
Whether the KV cache is replicated across TP workers due to the
number of TP workers being greater than the number of KV heads.
@ -549,11 +553,14 @@ class NixlConnectorWorker:
tp_size = self.remote_tp_size[engine_id]
return tp_size // self.total_num_kv_heads >= 1
def replicates_kv_cache(
self,
remote_engine_id: Optional[EngineId] = None,
remote_tp_size: Optional[int] = None,
) -> bool:
# MLA is always replicated as the hidden dim can't be split.
return self.is_mla or self.is_kv_replicated(
remote_engine_id, remote_tp_size
)
def get_target_remote_ranks(
@ -563,7 +570,8 @@ class NixlConnectorWorker:
) -> list[int]:
"""
Get the remote TP rank(s) (on P) that the current local TP rank
(on D) will read from. When remote tp_size > local tp_size, we
read from multiple remote ranks.
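For illustration: with local tp_size=2 and remote tp_size=8
(tp_ratio=-4), local rank 1 reads remote ranks [4, 5, 6, 7]; with
local tp_size=4 and remote tp_size=2 (tp_ratio=2), local ranks 0 and 1
both read remote rank 0.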
"""
tp_ratio = self.tp_ratio(remote_engine_id, remote_tp_size)
if tp_ratio > 0:
@ -573,8 +581,8 @@ class NixlConnectorWorker:
tp_ratio = -tp_ratio
if self.replicates_kv_cache(remote_engine_id, remote_tp_size):
# When cache is replicated on remote, we only need to read
# from one remote (they all have the same cache).
return [self.tp_rank * tp_ratio]
return [self.tp_rank * tp_ratio + i for i in range(tp_ratio)]
def __init__(self, vllm_config: VllmConfig, engine_id: str):
@ -672,12 +680,10 @@ class NixlConnectorWorker:
# nixl_prepped_dlist_handle.
self.src_xfer_side_handle: int = 0
# Populated dynamically during handshake based on remote configuration.
# Keep track of regions at different tp_ratio values. tp_ratio->handles
self.src_xfer_side_chunked_handles: dict[int, list[int]] = {}
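# Illustrative shape of this mapping (handle values are placeholders):
# a remote with tp_ratio=-4 yields {-4: [h0, h1, h2, h3]}, one chunk
# handle per remote rank read (a single handle when MLA is in use).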
# Map of engine_id -> nixl_prepped_dlist_handle (int).
self.dst_xfer_side_handles: defaultdict[EngineId, dict[int, int]] = defaultdict(
dict
)
@ -1033,7 +1039,9 @@ class NixlConnectorWorker:
self.tp_rank,
)
descs = self.nixl_wrapper.get_xfer_descs(
self.src_blocks_data, self.nixl_memory_type
)
# NIXL_INIT_AGENT to be used for preparations of local descs.
self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist(
"NIXL_INIT_AGENT", descs
@ -1093,10 +1101,12 @@ class NixlConnectorWorker:
In particular, handle both homogeneous and heterogeneous TP. The former
requires local rank_i to read from remote rank_i.
The latter, in the case of D.world_size < P.world_size, requires that a
local (D) TP worker reads from multiple remote (P) TP workers.
Conversely, assuming D.world_size > P.world_size, two or more local TP
workers will read from a single remote TP worker.
Here's an example for the last case described above (non-MLA):
rank_offset p_remote_tp_rank
(kv split no)
@ -1155,35 +1165,36 @@ class NixlConnectorWorker:
)
self._validate_remote_agent_handshake(nixl_agent_meta, remote_tp_size)
# Number of D TP workers reading from a single P TP worker.
# This is 1 when P and D `--tensor-parallel-size` match. Otherwise,
# it is the ratio between the two sizes (negative when P TP > D TP).
tp_ratio = self.kv_info.tp_ratio(engine_id)
# Handle tp_size>num_kv_heads: replicate KV cache.
indexes_into_remote = (
not self.kv_info.replicates_kv_cache(engine_id) and tp_ratio < 0
)
# When remote (P) TP > local (D) TP, the local regions must be split.
### (Optional) Register local agent memory regions
if tp_ratio < 0 and tp_ratio not in self.src_xfer_side_chunked_handles:
# Remote tp_size > local tp_size: read from multiple remote ranks.
# Logically "split" own regions into |tp_ratio| chunks. Mind that
# we only do this once per remote tp_size (replica-friendly).
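# Illustrative example (hypothetical sizes): local tp_size=2 and remote
# tp_size=8 give tp_ratio=-4, so a local region of block_len B is split
# into 4 chunks of B // 4 bytes, one per remote rank we read from.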
self.src_xfer_side_chunked_handles[tp_ratio] = []
# MLA-optimization: only prepare one region.
# NOTE NickLucche: only a chunk of the whole cache is used with MLA!
tp_ratio_opt = 1 if self.use_mla else -tp_ratio
for i in range(tp_ratio_opt):
blocks_data = []
for memory_region in self.src_blocks_data:
addr, local_block_len, own_tp_rank = memory_region
# Computing block len layer by layer allows for different
# block sizes to be used.
remote_block_len = local_block_len // (-tp_ratio)
addr = addr + i * remote_block_len
blocks_data.append((addr, remote_block_len, own_tp_rank))
descs = self.nixl_wrapper.get_xfer_descs(
blocks_data, self.nixl_memory_type
)
handle = self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs)
self.src_xfer_side_chunked_handles[tp_ratio].append(handle)
@ -1196,8 +1207,11 @@ class NixlConnectorWorker:
# Register all remote blocks, but only the corresponding kv heads.
for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
# Read our whole local region size from remote.
kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
if tp_ratio < 0:
# Remote tp is bigger: read a chunk of the local region from each remote
kv_block_len = kv_block_len // (-tp_ratio)
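# Illustrative example: with tp_ratio=-4 each remote block holds a
# quarter of the local heads, so only kv_block_len // 4 bytes are
# read per remote block.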
rank_offset = (
self.tp_rank % tp_ratio * kv_block_len if indexes_into_remote else 0
)
@ -1262,7 +1276,7 @@ class NixlConnectorWorker:
)
remote_block_size = remote_block_len // (self.slot_size_per_layer[0])
else:
if tp_ratio != 1 and self.device_type == "xpu":
# XPU uses NHD, hence it does not support splitting on H
raise ValueError("Heterogeneous TP is not supported on XPU")
# When MLA is not used, this is a list of the same block length
@ -1270,26 +1284,42 @@ class NixlConnectorWorker:
assert block_len == remote_block_len, (
"All remote layers must have the same block size"
)
if tp_ratio > 0:
# NHD layout: remote block_len // (local slot size * tp_ratio) = N
# (page size), since remote heads H' = local heads * tp_ratio.
remote_block_size = remote_block_len // (
self.slot_size_per_layer[0] * tp_ratio
)
# Remote tp is smaller: the remote block_len is larger
assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, (
"Remote P worker KV layer cache must be of shape [2, N, "
"local_kv_heads*tp_ratio, page_size, head_dim] and same dtype."
) # noqa: E501
else:
# NHD layout: remote block_len // (local slot size // |tp_ratio|) = N
# (page size), since remote heads H' = local heads // |tp_ratio|.
remote_block_size = remote_block_len // (
self.slot_size_per_layer[0] // (-tp_ratio)
)
# Remote tp is bigger: the remote block_len is smaller
assert remote_block_len == self.block_len_per_layer[0] // (-tp_ratio), (
"Remote P worker KV layer cache must be of shape [2, N, "
"local_kv_heads/tp_ratio, page_size, head_dim] and same dtype."
) # noqa: E501
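# Illustrative check (hypothetical sizes, assuming slot_size =
# kv_heads * head_dim * dtype_size): 8 local heads, head_dim=64, fp16
# give slot_size=1024 B and, with block_size=16, a local block_len of
# 16384 B. A remote with twice the TP (tp_ratio=-2) advertises
# block_len 16384 // 2 = 8192 B, and 8192 // (1024 // 2) = 16 matches
# the local block_size.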
if self._use_flashinfer:
# With flashinfer, KV are sent in the same message.
remote_block_size //= 2
# We may allow it in the future with logical kvcache manager block_size
assert self.block_size == remote_block_size, (
"Remote P worker with different page/block size is not supported "
f"{self.block_size=}, {remote_block_size=}"
)
# TP workers (handshakes with the same remote) have the same #blocks.
assert self.dst_num_blocks[remote_engine_id] == nixl_agent_meta.num_blocks
# Same number of regions/~layers.
assert len(nixl_agent_meta.kv_caches_base_addr) == len(self.block_len_per_layer)
def sync_recved_kv_to_device(self, req_id: str, meta: ReqMeta):
"""copy recved kv from host buffer to device."""
@ -1421,7 +1451,7 @@ class NixlConnectorWorker:
"""
done_req_ids: set[str] = set()
for req_id, handles in list(transfers.items()):
in_progress = []
for handle, _xfer_stime in handles:
xfer_state = self.nixl_wrapper.check_xfer_state(handle)
if xfer_state == "DONE":
@ -1430,13 +1460,16 @@ class NixlConnectorWorker:
self.xfer_stats.record_transfer(res)
self.nixl_wrapper.release_xfer_handle(handle)
elif xfer_state == "PROC":
in_progress.append((handle, _xfer_stime))
continue
else:
raise RuntimeError("Transfer failed with state %s", xfer_state)
if not in_progress:
# Only report request as completed when all transfers are done.
done_req_ids.add(req_id)
del transfers[req_id]
else:
transfers[req_id] = in_progress
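# Example: a request with three handles, two DONE and one PROC, stays
# in `transfers` holding only the in-flight handle and is reported as
# done on a later poll once that handle completes.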
return done_req_ids
def start_load_kv(self, metadata: NixlConnectorMetadata):
@ -1490,7 +1523,8 @@ class NixlConnectorWorker:
def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
remote_ranks = self.kv_info.get_target_remote_ranks(meta.remote_engine_id)
tp_ratio = self.kv_info.tp_ratio(meta.remote_engine_id)
# D may have to perform multiple reads from different remote ranks.
for i, remote_rank in enumerate(remote_ranks):
logger.debug(
"Remote agent %s available, calling _read_blocks"
@ -1499,13 +1533,17 @@ class NixlConnectorWorker:
remote_rank,
req_id,
)
# Get nixl desc handles depending on whether we're reading from
# multiple sources or from a single one.
if tp_ratio < 0:
# Remote tp_size > local tp_size: we must perform multiple
# reads. Get the memory chunk this read will write into.
local_xfer_side_handle = self.src_xfer_side_chunked_handles[tp_ratio][i]
else:
# Single read from remote, we write to the whole memory region.
local_xfer_side_handle = self.src_xfer_side_handle
# Destination handle: remote_engine_id -> remote_rank -> handle.
remote_xfer_side_handle = self.dst_xfer_side_handles[meta.remote_engine_id][
remote_rank
]
self._read_blocks(
request_id=req_id,
dst_engine_id=meta.remote_engine_id,
@ -1526,6 +1564,10 @@ class NixlConnectorWorker:
local_xfer_side_handle: int,
remote_xfer_side_handle: int,
):
"""
Post a READ xfer request from a single local worker to a single
remote worker.
"""
# NOTE(rob): having the staging blocks be on the READER side is
# not going to work well (since we will have to call rearrange tensors).
# after we detect the txn is complete (which means we cannot make the
@ -1543,7 +1585,7 @@ class NixlConnectorWorker:
notif_id = f"{request_id}:{tp_ratio}".encode()
# Full prefix cache hit: do not need to read remote blocks,
# just notify P worker that we have the blocks we need.
num_local_blocks = len(local_block_ids)
if num_local_blocks == 0:
agent_name = self._remote_agents[dst_engine_id][remote_rank]
@ -1556,10 +1598,6 @@ class NixlConnectorWorker:
if num_local_blocks < num_remote_blocks:
remote_block_ids = remote_block_ids[-num_local_blocks:]
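# e.g. if D only needs the last 2 blocks while P prepared 5 for this
# request, only the trailing 2 remote block ids are read.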
# NOTE (nicolo) With homogeneous TP, each TP worker loads KV from
# corresponding rank. With heterogeneous TP, fixing D>P, the D tp
# workers will issue xfers to parts of the P worker remote kv caches.
@ -1647,7 +1685,6 @@ class NixlConnectorWorker:
assert self.num_layers == self.num_regions
region_ids = np.arange(layer_idx, layer_idx + 1)
num_blocks = self.dst_num_blocks[engine_id]
# Compute the desc ids for each block.
@ -1694,11 +1731,10 @@ class NixlConnectorWorker:
if self.src_xfer_side_handle:
self.nixl_wrapper.release_dlist_handle(self.src_xfer_side_handle)
self.src_xfer_side_handle = 0
for handles in self.src_xfer_side_chunked_handles.values():
for handle in handles:
self.nixl_wrapper.release_dlist_handle(handle)
self.src_xfer_side_chunked_handles.clear()
for dst_xfer_side_handles in self.dst_xfer_side_handles.values():
for dst_xfer_side_handle in dst_xfer_side_handles.values():
self.nixl_wrapper.release_dlist_handle(dst_xfer_side_handle)