Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
NickLucche 2025-10-08 17:50:07 +00:00
parent 5d45b77124
commit 684c9b7b6d

View File

@ -525,11 +525,15 @@ class NixlConnectorWorker:
if remote_tp_size is None: if remote_tp_size is None:
assert remote_engine_id is not None assert remote_engine_id is not None
remote_tp_size = self.remote_tp_size[remote_engine_id] remote_tp_size = self.remote_tp_size[remote_engine_id]
assert self.tp_size % remote_tp_size == 0, ( if self.tp_size >= remote_tp_size:
f"Local tensor parallel size {self.tp_size} is not divisible " assert self.tp_size % remote_tp_size == 0, (
f"by remote tensor parallel size {remote_tp_size}." f"Local tensor parallel size {self.tp_size} is not divisible "
) f"by remote tensor parallel size {remote_tp_size}."
return self.tp_size // remote_tp_size )
return self.tp_size // remote_tp_size
else:
# P TP > D TP case, return the ratio as negative
return -remote_tp_size // self.tp_size
def is_kv_replicated(self, engine_id: EngineId) -> bool: def is_kv_replicated(self, engine_id: EngineId) -> bool:
""" """
@ -541,19 +545,26 @@ class NixlConnectorWorker:
def replicates_kv_cache(self, remote_engine_id: EngineId) -> bool: def replicates_kv_cache(self, remote_engine_id: EngineId) -> bool:
# MLA is always replicated as the hidden dim can't be split. # MLA is always replicated as the hidden dim can't be split.
return self.is_mla or self.is_kv_replicated(remote_engine_id) # TODO docs
remote_tp_size = self.remote_tp_size[remote_engine_id]
return self.is_mla or self.is_kv_replicated(remote_engine_id) or self.tp_size < remote_tp_size
def get_target_remote_rank( def get_target_remote_ranks(
self, self,
remote_engine_id: Optional[EngineId] = None, remote_engine_id: Optional[EngineId] = None,
remote_tp_size: Optional[int] = None, remote_tp_size: Optional[int] = None,
) -> int: ) -> list[int]:
""" """
Get the remote TP rank (on P) that the current local TP rank Get the remote TP rank (on P) that the current local TP rank
(on D) will read from. (on D) will read from.
""" """
tp_ratio = self.tp_ratio(remote_engine_id, remote_tp_size) tp_ratio = self.tp_ratio(remote_engine_id, remote_tp_size)
return self.tp_rank // tp_ratio if tp_ratio > 0:
return [self.tp_rank // tp_ratio]
else:
# P TP > D TP case, D reads from |tp_ratio| remote workers.
tp_ratio = -tp_ratio
return [self.tp_rank*tp_ratio + i for i in range(tp_ratio)]
def __init__(self, vllm_config: VllmConfig, engine_id: str): def __init__(self, vllm_config: VllmConfig, engine_id: str):
if NixlWrapper is None: if NixlWrapper is None:
@ -638,8 +649,8 @@ class NixlConnectorWorker:
self.copy_blocks: Optional[CopyBlocksOp] = None self.copy_blocks: Optional[CopyBlocksOp] = None
# Map of engine_id -> kv_caches_base_addr. For TP case, each local # Map of engine_id -> kv_caches_base_addr. For TP case, each local
# rank will still only pull from a single remote TP worker. # rank may pull from multiple remote TP workers.
self.kv_caches_base_addr: dict[EngineId, list[int]] = {} self.kv_caches_base_addr: defaultdict[EngineId, dict[int, list[int]]] = defaultdict(dict)
# Number of NIXL regions. Currently one region per cache # Number of NIXL regions. Currently one region per cache
# (so 1 per layer for MLA, otherwise 2 per layer) # (so 1 per layer for MLA, otherwise 2 per layer)
@ -649,7 +660,8 @@ class NixlConnectorWorker:
# nixl_prepped_dlist_handle. # nixl_prepped_dlist_handle.
self.src_xfer_side_handle: int = 0 self.src_xfer_side_handle: int = 0
# Map of engine_id -> nixl_prepped_dlist_handle (int)]. # Map of engine_id -> nixl_prepped_dlist_handle (int)].
self.dst_xfer_side_handles: dict[EngineId, int] = {} # TODO do I need tp_Ratio of this?
self.dst_xfer_side_handles: defaultdict[EngineId, dict[int, int]] = defaultdict(dict)
# Map of engine_id -> num_blocks. All ranks in the same deployment will # Map of engine_id -> num_blocks. All ranks in the same deployment will
# have the same number of blocks. # have the same number of blocks.
@ -756,51 +768,52 @@ class NixlConnectorWorker:
start_time = time.perf_counter() start_time = time.perf_counter()
# NOTE(rob): we need each rank to have a unique port. This is # When target instance TP > local TP, we need to perform multiple
# a hack to keep us moving. We will switch when moving to etcd # handshakes. Do it in a single background job for simplicity.
# or where we have a single ZMQ socket in the scheduler. # Regardless, only handshake with the remote TP rank(s) that current
# local rank will read from. Note that with homogeneous TP,
# Handshake only with the remote TP rank that current local rank will # this happens to be the same single rank_i.
# pull from. With homogeneous TP it happens to be the same rank_i. p_remote_ranks = self.kv_info.get_target_remote_ranks(
p_remote_rank = self.kv_info.get_target_remote_rank(
remote_tp_size=remote_tp_size remote_tp_size=remote_tp_size
) )
path = make_zmq_path("tcp", host, port + p_remote_rank) remote_rank_to_agent_name = {}
logger.debug( for remote_rank in p_remote_ranks:
"Querying metadata on path: %s at remote rank %s", path, p_remote_rank path = make_zmq_path("tcp", host, port + remote_rank)
) logger.warning(
"Querying metadata on path: %s at remote rank %s", path, remote_rank
# Send query for the request.
with zmq_ctx(zmq.REQ, path) as sock:
sock.send(GET_META_MSG)
metadata_bytes = sock.recv()
decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
metadata = decoder.decode(metadata_bytes)
got_metadata_time = time.perf_counter()
logger.debug(
"NIXL handshake: get metadata took: %s", got_metadata_time - start_time
) )
# Ensure engine id matches. # Send query for the request.
if metadata.engine_id != expected_engine_id: with zmq_ctx(zmq.REQ, path) as sock:
raise RuntimeError( sock.send(GET_META_MSG)
f"Remote NIXL agent engine ID mismatch. " metadata_bytes = sock.recv()
f"Expected {expected_engine_id}," decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
f"received {metadata.engine_id}." metadata = decoder.decode(metadata_bytes)
got_metadata_time = time.perf_counter()
logger.debug(
"NIXL handshake: get metadata took: %s", got_metadata_time - start_time
) )
# Register Remote agent. # Ensure engine id matches.
remote_agent_name = self.add_remote_agent( if metadata.engine_id != expected_engine_id:
metadata, p_remote_rank, remote_tp_size raise RuntimeError(
) f"Remote NIXL agent engine ID mismatch. "
setup_agent_time = time.perf_counter() f"Expected {expected_engine_id},"
logger.debug( f"received {metadata.engine_id}."
"NIXL handshake: add agent took: %s", )
setup_agent_time - got_metadata_time,
)
# Remote rank -> agent name. # Register Remote agent.
return {p_remote_rank: remote_agent_name} remote_agent_name = self.add_remote_agent(
metadata, remote_rank, remote_tp_size
)
setup_agent_time = time.perf_counter()
logger.debug(
"NIXL handshake: add agent took: %s",
setup_agent_time - got_metadata_time,
)
remote_rank_to_agent_name[remote_rank] = remote_agent_name
return remote_rank_to_agent_name
def initialize_host_xfer_buffer(self, kv_caches: dict[str, torch.Tensor]) -> None: def initialize_host_xfer_buffer(self, kv_caches: dict[str, torch.Tensor]) -> None:
""" """
@ -944,7 +957,7 @@ class NixlConnectorWorker:
assert len(self.block_len_per_layer) == len(seen_base_addresses) assert len(self.block_len_per_layer) == len(seen_base_addresses)
assert self.num_blocks != 0 assert self.num_blocks != 0
self.kv_caches_base_addr[self.engine_id] = seen_base_addresses self.kv_caches_base_addr[self.engine_id][self.tp_rank] = seen_base_addresses
self.num_regions = len(caches_data) self.num_regions = len(caches_data)
self.num_layers = len(xfer_buffers.keys()) self.num_layers = len(xfer_buffers.keys())
@ -1033,7 +1046,7 @@ class NixlConnectorWorker:
metadata = NixlAgentMetadata( metadata = NixlAgentMetadata(
engine_id=self.engine_id, engine_id=self.engine_id,
agent_metadata=self.nixl_wrapper.get_agent_metadata(), agent_metadata=self.nixl_wrapper.get_agent_metadata(),
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id], kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id][self.tp_rank],
num_blocks=self.num_blocks, num_blocks=self.num_blocks,
block_lens=self.block_len_per_layer, block_lens=self.block_len_per_layer,
attn_backend_name=self.backend_name, attn_backend_name=self.backend_name,
@ -1096,7 +1109,7 @@ class NixlConnectorWorker:
engine_id = nixl_agent_meta.engine_id engine_id = nixl_agent_meta.engine_id
# TODO re-evaluate refreshing for scaling/recovery # TODO re-evaluate refreshing for scaling/recovery
if remote_tp_rank in self._remote_agents.get(engine_id, {}): if remote_tp_rank in self._remote_agents.get(engine_id, {}):
logger.debug( logger.warning(
"Remote agent with engine_id %s and rank" "Remote agent with engine_id %s and rank"
"%s already exchanged metadata, skip handshake.", "%s already exchanged metadata, skip handshake.",
engine_id, engine_id,
@ -1114,6 +1127,7 @@ class NixlConnectorWorker:
# Handle tp_size>num_kv_heads: replicate KV cache. # Handle tp_size>num_kv_heads: replicate KV cache.
replicates_kv_cache = self.kv_info.replicates_kv_cache(engine_id) replicates_kv_cache = self.kv_info.replicates_kv_cache(engine_id)
print("REPLICATES KV CACHE", replicates_kv_cache, "\n")
# Create dst descs and xfer side handles. TP workers have same #blocks # Create dst descs and xfer side handles. TP workers have same #blocks
# so we only register once per engine_id. # so we only register once per engine_id.
@ -1121,11 +1135,12 @@ class NixlConnectorWorker:
self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks
# Keep track of remote agent kv caches base addresses. # Keep track of remote agent kv caches base addresses.
self.kv_caches_base_addr[engine_id] = nixl_agent_meta.kv_caches_base_addr self.kv_caches_base_addr[engine_id][self.tp_rank] = nixl_agent_meta.kv_caches_base_addr
self._validate_remote_agent_handshake(nixl_agent_meta, remote_tp_size) self._validate_remote_agent_handshake(nixl_agent_meta, remote_tp_size)
# Number of D TP workers reading from a single P TP worker. This is # Number of D TP workers reading from a single P TP worker. This is
# 1 when P and D `--tensor-parallel-size` match. # 1 when P and D `--tensor-parallel-size` match. If P TP > D TP,
# we don't need to use this for splitting the remote kv cache.
tp_ratio = self.kv_info.tp_ratio(engine_id) tp_ratio = self.kv_info.tp_ratio(engine_id)
### Register remote agent memory regions ### Register remote agent memory regions
@ -1137,7 +1152,8 @@ class NixlConnectorWorker:
# Register all remote blocks, but only the corresponding kv heads. # Register all remote blocks, but only the corresponding kv heads.
for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr): for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i) # TODO
kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i) // 2
rank_offset = ( rank_offset = (
self.tp_rank % tp_ratio * kv_block_len if not replicates_kv_cache else 0 self.tp_rank % tp_ratio * kv_block_len if not replicates_kv_cache else 0
) )
@ -1168,7 +1184,7 @@ class NixlConnectorWorker:
# Register with NIXL. # Register with NIXL.
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type) descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)
self.dst_xfer_side_handles[engine_id] = self.nixl_wrapper.prep_xfer_dlist( self.dst_xfer_side_handles[engine_id][remote_tp_rank] = self.nixl_wrapper.prep_xfer_dlist(
remote_agent_name, descs remote_agent_name, descs
) )
@ -1189,7 +1205,6 @@ class NixlConnectorWorker:
assert nixl_agent_meta.kv_cache_layout == self.kv_cache_layout assert nixl_agent_meta.kv_cache_layout == self.kv_cache_layout
tp_ratio = self.kv_info.tp_ratio(remote_engine_id) tp_ratio = self.kv_info.tp_ratio(remote_engine_id)
assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP"
assert not self._use_pallas or tp_ratio == 1, ( assert not self._use_pallas or tp_ratio == 1, (
"TPU (pallas_v1) DOES NOT support heterogeneous TP yet." "TPU (pallas_v1) DOES NOT support heterogeneous TP yet."
) )
@ -1217,21 +1232,20 @@ class NixlConnectorWorker:
if self._use_flashinfer: if self._use_flashinfer:
# With flashinfer, KV are sent in the same message. # With flashinfer, KV are sent in the same message.
remote_block_size //= 2 remote_block_size //= 2
# TODO add asserts for P TP > D TP
# assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, (
# "Remote P worker KV layer cache must be of shape [2, N, "
# "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
# )
assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, ( # assert self.block_size == remote_block_size, (
"Remote P worker KV layer cache must be of shape [2, N, " # "Remote P worker with different page/block size is not supported "
"local_kv_heads*tp_ratio, block_size, head_dim] and same dtype." # f"{self.block_size=}, {remote_block_size=}"
) # )
assert self.block_size == remote_block_size, ( # # TP workers have same #blocks.
"Remote P worker with different page/block size is not supported " # assert self.dst_num_blocks[remote_engine_id] == nixl_agent_meta.num_blocks
f"{self.block_size=}, {remote_block_size=}" # assert len(nixl_agent_meta.kv_caches_base_addr) == len(self.block_len_per_layer)
)
# TP workers have same #blocks.
assert self.dst_num_blocks[remote_engine_id] == nixl_agent_meta.num_blocks
assert len(nixl_agent_meta.kv_caches_base_addr) == len(self.block_len_per_layer)
def sync_recved_kv_to_device(self, req_id: str, meta: ReqMeta): def sync_recved_kv_to_device(self, req_id: str, meta: ReqMeta):
"""copy recved kv from host buffer to device.""" """copy recved kv from host buffer to device."""
@ -1431,17 +1445,24 @@ class NixlConnectorWorker:
self._reqs_to_send[req_id] = expiration_time self._reqs_to_send[req_id] = expiration_time
def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
logger.debug( remote_ranks = self.kv_info.get_target_remote_ranks(meta.remote_engine_id)
"Remote agent %s available, calling _read_blocks for req %s", # D may perform multiple reads from different remote ranks.
meta.remote_engine_id, for remote_rank in remote_ranks:
req_id, logger.debug(
) "Remote agent %s available, calling _read_blocks"
self._read_blocks( " on remote rank %s for req %s",
request_id=req_id, meta.remote_engine_id,
dst_engine_id=meta.remote_engine_id, remote_rank,
local_block_ids=meta.local_block_ids, req_id,
remote_block_ids=meta.remote_block_ids, )
) # TODO multiread; notifs to all twice?? SPLIT LOCAL BLOCKS!
self._read_blocks(
request_id=req_id,
dst_engine_id=meta.remote_engine_id,
local_block_ids=meta.local_block_ids,
remote_block_ids=meta.remote_block_ids,
remote_rank=remote_rank,
)
def _read_blocks( def _read_blocks(
self, self,
@ -1449,6 +1470,7 @@ class NixlConnectorWorker:
remote_block_ids: list[int], remote_block_ids: list[int],
dst_engine_id: str, dst_engine_id: str,
request_id: str, request_id: str,
remote_rank: int,
): ):
# NOTE(rob): having the staging blocks be on the READER side is # NOTE(rob): having the staging blocks be on the READER side is
# not going to work well (since we will have to call rearrange tensors). # not going to work well (since we will have to call rearrange tensors).
@ -1462,14 +1484,14 @@ class NixlConnectorWorker:
# Number of D TP workers that will read from dst P. Propagate tp_ratio # Number of D TP workers that will read from dst P. Propagate tp_ratio
# on notification so that dst worker can wait before freeing blocks. # on notification so that dst worker can wait before freeing blocks.
tp_ratio = self.kv_info.tp_ratio(dst_engine_id) # Cap to 1 when P TP > D TP: only a single rank will read from remote.
tp_ratio = max(1, self.kv_info.tp_ratio(dst_engine_id))
notif_id = f"{request_id}:{tp_ratio}".encode() notif_id = f"{request_id}:{tp_ratio}".encode()
# Full prefix cache hit: do not need to read remote blocks, # Full prefix cache hit: do not need to read remote blocks,
# just notify P worker that we have the blocks we need. # just notify P worker(s) that we have the blocks we need.
num_local_blocks = len(local_block_ids) num_local_blocks = len(local_block_ids)
if num_local_blocks == 0: if num_local_blocks == 0:
remote_rank = self.kv_info.get_target_remote_rank(dst_engine_id)
agent_name = self._remote_agents[dst_engine_id][remote_rank] agent_name = self._remote_agents[dst_engine_id][remote_rank]
self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id) self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id)
return return
@ -1482,7 +1504,7 @@ class NixlConnectorWorker:
# Get side handles. # Get side handles.
local_xfer_side_handle = self.src_xfer_side_handle local_xfer_side_handle = self.src_xfer_side_handle
remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id] remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id][remote_rank]
# NOTE (nicolo) With homogeneous TP, each TP worker loads KV from # NOTE (nicolo) With homogeneous TP, each TP worker loads KV from
# corresponding rank. With heterogeneous TP, fixing D>P, the D tp # corresponding rank. With heterogeneous TP, fixing D>P, the D tp
@ -1571,6 +1593,7 @@ class NixlConnectorWorker:
assert self.num_layers == self.num_regions assert self.num_layers == self.num_regions
region_ids = np.arange(layer_idx, layer_idx + 1) region_ids = np.arange(layer_idx, layer_idx + 1)
# TODO can this vary?
num_blocks = self.dst_num_blocks[engine_id] num_blocks = self.dst_num_blocks[engine_id]
# Compute the desc ids for each block. # Compute the desc ids for each block.
@ -1617,8 +1640,9 @@ class NixlConnectorWorker:
if self.src_xfer_side_handle: if self.src_xfer_side_handle:
self.nixl_wrapper.release_dlist_handle(self.src_xfer_side_handle) self.nixl_wrapper.release_dlist_handle(self.src_xfer_side_handle)
self.src_xfer_side_handle = 0 self.src_xfer_side_handle = 0
for dst_xfer_side_handle in self.dst_xfer_side_handles.values(): for dst_xfer_side_handles in self.dst_xfer_side_handles.values():
self.nixl_wrapper.release_dlist_handle(dst_xfer_side_handle) for dst_xfer_side_handle in dst_xfer_side_handles.values():
self.nixl_wrapper.release_dlist_handle(dst_xfer_side_handle)
self.dst_xfer_side_handles.clear() self.dst_xfer_side_handles.clear()
for remote_agents in self._remote_agents.values(): for remote_agents in self._remote_agents.values():
for agent_name in remote_agents.values(): for agent_name in remote_agents.values():