Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
NickLucche 2025-10-08 17:50:07 +00:00
parent 5d45b77124
commit 684c9b7b6d

View File

@ -525,11 +525,15 @@ class NixlConnectorWorker:
if remote_tp_size is None:
assert remote_engine_id is not None
remote_tp_size = self.remote_tp_size[remote_engine_id]
assert self.tp_size % remote_tp_size == 0, (
f"Local tensor parallel size {self.tp_size} is not divisible "
f"by remote tensor parallel size {remote_tp_size}."
)
return self.tp_size // remote_tp_size
if self.tp_size >= remote_tp_size:
assert self.tp_size % remote_tp_size == 0, (
f"Local tensor parallel size {self.tp_size} is not divisible "
f"by remote tensor parallel size {remote_tp_size}."
)
return self.tp_size // remote_tp_size
else:
# P TP > D TP case, return the ratio as negative
return -remote_tp_size // self.tp_size
def is_kv_replicated(self, engine_id: EngineId) -> bool:
"""
@ -538,22 +542,29 @@ class NixlConnectorWorker:
"""
tp_size = self.remote_tp_size[engine_id]
return tp_size // self.total_num_kv_heads >= 1
def replicates_kv_cache(self, remote_engine_id: EngineId) -> bool:
# MLA is always replicated as the hidden dim can't be split.
return self.is_mla or self.is_kv_replicated(remote_engine_id)
# TODO docs
remote_tp_size = self.remote_tp_size[remote_engine_id]
return self.is_mla or self.is_kv_replicated(remote_engine_id) or self.tp_size < remote_tp_size
def get_target_remote_rank(
def get_target_remote_ranks(
self,
remote_engine_id: Optional[EngineId] = None,
remote_tp_size: Optional[int] = None,
) -> int:
) -> list[int]:
"""
Get the remote TP rank (on P) that the current local TP rank
(on D) will read from.
"""
tp_ratio = self.tp_ratio(remote_engine_id, remote_tp_size)
return self.tp_rank // tp_ratio
if tp_ratio > 0:
return [self.tp_rank // tp_ratio]
else:
# P TP > D TP case, D reads from |tp_ratio| remote workers.
tp_ratio = -tp_ratio
return [self.tp_rank*tp_ratio + i for i in range(tp_ratio)]
def __init__(self, vllm_config: VllmConfig, engine_id: str):
if NixlWrapper is None:
@ -638,8 +649,8 @@ class NixlConnectorWorker:
self.copy_blocks: Optional[CopyBlocksOp] = None
# Map of engine_id -> kv_caches_base_addr. For TP case, each local
# rank will still only pull from a single remote TP worker.
self.kv_caches_base_addr: dict[EngineId, list[int]] = {}
# rank may pull from multiple remote TP workers.
self.kv_caches_base_addr: defaultdict[EngineId, dict[int, list[int]]] = defaultdict(dict)
# Number of NIXL regions. Currently one region per cache
# (so 1 per layer for MLA, otherwise 2 per layer)
@ -649,7 +660,8 @@ class NixlConnectorWorker:
# nixl_prepped_dlist_handle.
self.src_xfer_side_handle: int = 0
# Map of engine_id -> nixl_prepped_dlist_handle (int)].
self.dst_xfer_side_handles: dict[EngineId, int] = {}
# TODO do I need tp_Ratio of this?
self.dst_xfer_side_handles: defaultdict[EngineId, dict[int, int]] = defaultdict(dict)
# Map of engine_id -> num_blocks. All ranks in the same deployment will
# have the same number of blocks.
@ -756,51 +768,52 @@ class NixlConnectorWorker:
start_time = time.perf_counter()
# NOTE(rob): we need each rank to have a unique port. This is
# a hack to keep us moving. We will switch when moving to etcd
# or where we have a single ZMQ socket in the scheduler.
# Handshake only with the remote TP rank that current local rank will
# pull from. With homogeneous TP it happens to be the same rank_i.
p_remote_rank = self.kv_info.get_target_remote_rank(
# When target instance TP > local TP, we need to perform multiple
# handshakes. Do it in a single background job for simplicity.
# Regardless, only handshake with the remote TP rank(s) that current
# local rank will read from. Note that With homogeneous TP,
# this happens to be the same single rank_i.
p_remote_ranks = self.kv_info.get_target_remote_ranks(
remote_tp_size=remote_tp_size
)
path = make_zmq_path("tcp", host, port + p_remote_rank)
logger.debug(
"Querying metadata on path: %s at remote rank %s", path, p_remote_rank
)
# Send query for the request.
with zmq_ctx(zmq.REQ, path) as sock:
sock.send(GET_META_MSG)
metadata_bytes = sock.recv()
decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
metadata = decoder.decode(metadata_bytes)
got_metadata_time = time.perf_counter()
logger.debug(
"NIXL handshake: get metadata took: %s", got_metadata_time - start_time
remote_rank_to_agent_name = {}
for remote_rank in p_remote_ranks:
path = make_zmq_path("tcp", host, port + remote_rank)
logger.warning(
"Querying metadata on path: %s at remote rank %s", path, remote_rank
)
# Ensure engine id matches.
if metadata.engine_id != expected_engine_id:
raise RuntimeError(
f"Remote NIXL agent engine ID mismatch. "
f"Expected {expected_engine_id},"
f"received {metadata.engine_id}."
# Send query for the request.
with zmq_ctx(zmq.REQ, path) as sock:
sock.send(GET_META_MSG)
metadata_bytes = sock.recv()
decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
metadata = decoder.decode(metadata_bytes)
got_metadata_time = time.perf_counter()
logger.debug(
"NIXL handshake: get metadata took: %s", got_metadata_time - start_time
)
# Register Remote agent.
remote_agent_name = self.add_remote_agent(
metadata, p_remote_rank, remote_tp_size
)
setup_agent_time = time.perf_counter()
logger.debug(
"NIXL handshake: add agent took: %s",
setup_agent_time - got_metadata_time,
)
# Ensure engine id matches.
if metadata.engine_id != expected_engine_id:
raise RuntimeError(
f"Remote NIXL agent engine ID mismatch. "
f"Expected {expected_engine_id},"
f"received {metadata.engine_id}."
)
# Remote rank -> agent name.
return {p_remote_rank: remote_agent_name}
# Register Remote agent.
remote_agent_name = self.add_remote_agent(
metadata, remote_rank, remote_tp_size
)
setup_agent_time = time.perf_counter()
logger.debug(
"NIXL handshake: add agent took: %s",
setup_agent_time - got_metadata_time,
)
remote_rank_to_agent_name[remote_rank] = remote_agent_name
return remote_rank_to_agent_name
def initialize_host_xfer_buffer(self, kv_caches: dict[str, torch.Tensor]) -> None:
"""
@ -944,7 +957,7 @@ class NixlConnectorWorker:
assert len(self.block_len_per_layer) == len(seen_base_addresses)
assert self.num_blocks != 0
self.kv_caches_base_addr[self.engine_id] = seen_base_addresses
self.kv_caches_base_addr[self.engine_id][self.tp_rank] = seen_base_addresses
self.num_regions = len(caches_data)
self.num_layers = len(xfer_buffers.keys())
@ -1033,7 +1046,7 @@ class NixlConnectorWorker:
metadata = NixlAgentMetadata(
engine_id=self.engine_id,
agent_metadata=self.nixl_wrapper.get_agent_metadata(),
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id][self.tp_rank],
num_blocks=self.num_blocks,
block_lens=self.block_len_per_layer,
attn_backend_name=self.backend_name,
@ -1096,7 +1109,7 @@ class NixlConnectorWorker:
engine_id = nixl_agent_meta.engine_id
# TODO re-evaluate refreshing for scaling/recovery
if remote_tp_rank in self._remote_agents.get(engine_id, {}):
logger.debug(
logger.warning(
"Remote agent with engine_id %s and rank"
"%s already exchanged metadata, skip handshake.",
engine_id,
@ -1114,6 +1127,7 @@ class NixlConnectorWorker:
# Handle tp_size>num_kv_heads: replicate KV cache.
replicates_kv_cache = self.kv_info.replicates_kv_cache(engine_id)
print("REPLICATES KV CACHE", replicates_kv_cache, "\n")
# Create dst descs and xfer side handles. TP workers have same #blocks
# so we only register once per engine_id.
@ -1121,11 +1135,12 @@ class NixlConnectorWorker:
self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks
# Keep track of remote agent kv caches base addresses.
self.kv_caches_base_addr[engine_id] = nixl_agent_meta.kv_caches_base_addr
self.kv_caches_base_addr[engine_id][self.tp_rank] = nixl_agent_meta.kv_caches_base_addr
self._validate_remote_agent_handshake(nixl_agent_meta, remote_tp_size)
# Number of D TP workers reading from a single P TP worker. This is
# 1 when P and D `--tensor-parallel-size` match.
# 1 when P and D `--tensor-parallel-size` match. If P TP > D TP,
# we don't need to use this for spliting the remote kv cache.
tp_ratio = self.kv_info.tp_ratio(engine_id)
### Register remote agent memory regions
@ -1137,7 +1152,8 @@ class NixlConnectorWorker:
# Register all remote blocks, but only the corresponding kv heads.
for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
# TODO
kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i) // 2
rank_offset = (
self.tp_rank % tp_ratio * kv_block_len if not replicates_kv_cache else 0
)
@ -1168,7 +1184,7 @@ class NixlConnectorWorker:
# Register with NIXL.
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)
self.dst_xfer_side_handles[engine_id] = self.nixl_wrapper.prep_xfer_dlist(
self.dst_xfer_side_handles[engine_id][remote_tp_rank] = self.nixl_wrapper.prep_xfer_dlist(
remote_agent_name, descs
)
@ -1189,7 +1205,6 @@ class NixlConnectorWorker:
assert nixl_agent_meta.kv_cache_layout == self.kv_cache_layout
tp_ratio = self.kv_info.tp_ratio(remote_engine_id)
assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP"
assert not self._use_pallas or tp_ratio == 1, (
"TPU (pallas_v1) DOES NOT support heterogeneous TP yet."
)
@ -1217,21 +1232,20 @@ class NixlConnectorWorker:
if self._use_flashinfer:
# With flashinfer, KV are sent in the same message.
remote_block_size //= 2
# TODO add asserts for P TP > D TP
# assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, (
# "Remote P worker KV layer cache must be of shape [2, N, "
# "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
# )
assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, (
"Remote P worker KV layer cache must be of shape [2, N, "
"local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
)
# assert self.block_size == remote_block_size, (
# "Remote P worker with different page/block size is not supported "
# f"{self.block_size=}, {remote_block_size=}"
# )
assert self.block_size == remote_block_size, (
"Remote P worker with different page/block size is not supported "
f"{self.block_size=}, {remote_block_size=}"
)
# TP workers have same #blocks.
assert self.dst_num_blocks[remote_engine_id] == nixl_agent_meta.num_blocks
assert len(nixl_agent_meta.kv_caches_base_addr) == len(self.block_len_per_layer)
# # TP workers have same #blocks.
# assert self.dst_num_blocks[remote_engine_id] == nixl_agent_meta.num_blocks
# assert len(nixl_agent_meta.kv_caches_base_addr) == len(self.block_len_per_layer)
def sync_recved_kv_to_device(self, req_id: str, meta: ReqMeta):
"""copy recved kv from host buffer to device."""
@ -1431,17 +1445,24 @@ class NixlConnectorWorker:
self._reqs_to_send[req_id] = expiration_time
def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
logger.debug(
"Remote agent %s available, calling _read_blocks for req %s",
meta.remote_engine_id,
req_id,
)
self._read_blocks(
request_id=req_id,
dst_engine_id=meta.remote_engine_id,
local_block_ids=meta.local_block_ids,
remote_block_ids=meta.remote_block_ids,
)
remote_ranks = self.kv_info.get_target_remote_ranks(meta.remote_engine_id)
# D may perform multiple reads from different remote ranks.
for remote_rank in remote_ranks:
logger.debug(
"Remote agent %s available, calling _read_blocks"
" on remote rank %s for req %s",
meta.remote_engine_id,
remote_rank,
req_id,
)
# TODO multiread; notifs to all twice?? SPLIT LOCAL BLOCKS!
self._read_blocks(
request_id=req_id,
dst_engine_id=meta.remote_engine_id,
local_block_ids=meta.local_block_ids,
remote_block_ids=meta.remote_block_ids,
remote_rank=remote_rank,
)
def _read_blocks(
self,
@ -1449,6 +1470,7 @@ class NixlConnectorWorker:
remote_block_ids: list[int],
dst_engine_id: str,
request_id: str,
remote_rank: int,
):
# NOTE(rob): having the staging blocks be on the READER side is
# not going to work well (since we will have to call rearrange tensors).
@ -1462,14 +1484,14 @@ class NixlConnectorWorker:
# Number of D TP workers that will read from dst P. Propagate tp_ratio
# on notification so that dst worker can wait before freeing blocks.
tp_ratio = self.kv_info.tp_ratio(dst_engine_id)
# Cap to 1 when P TP > D TP: only a single rank will read from remote.
tp_ratio = max(1, self.kv_info.tp_ratio(dst_engine_id))
notif_id = f"{request_id}:{tp_ratio}".encode()
# Full prefix cache hit: do not need to read remote blocks,
# just notify P worker that we have the blocks we need.
# just notify P worker(s) that we have the blocks we need.
num_local_blocks = len(local_block_ids)
if num_local_blocks == 0:
remote_rank = self.kv_info.get_target_remote_rank(dst_engine_id)
agent_name = self._remote_agents[dst_engine_id][remote_rank]
self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id)
return
@ -1482,7 +1504,7 @@ class NixlConnectorWorker:
# Get side handles.
local_xfer_side_handle = self.src_xfer_side_handle
remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id]
remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id][remote_rank]
# NOTE (nicolo) With homogeneous TP, each TP worker loads KV from
# corresponding rank. With heterogeneous TP, fixing D>P, the D tp
@ -1571,6 +1593,7 @@ class NixlConnectorWorker:
assert self.num_layers == self.num_regions
region_ids = np.arange(layer_idx, layer_idx + 1)
# TODO can this vary?
num_blocks = self.dst_num_blocks[engine_id]
# Compute the desc ids for each block.
@ -1617,8 +1640,9 @@ class NixlConnectorWorker:
if self.src_xfer_side_handle:
self.nixl_wrapper.release_dlist_handle(self.src_xfer_side_handle)
self.src_xfer_side_handle = 0
for dst_xfer_side_handle in self.dst_xfer_side_handles.values():
self.nixl_wrapper.release_dlist_handle(dst_xfer_side_handle)
for dst_xfer_side_handles in self.dst_xfer_side_handles.values():
for dst_xfer_side_handle in dst_xfer_side_handles.values():
self.nixl_wrapper.release_dlist_handle(dst_xfer_side_handle)
self.dst_xfer_side_handles.clear()
for remote_agents in self._remote_agents.values():
for agent_name in remote_agents.values():