From 684c9b7b6dc127928930483ec93de48d2ce532ce Mon Sep 17 00:00:00 2001
From: NickLucche
Date: Wed, 8 Oct 2025 17:50:07 +0000
Subject: [PATCH] init

Signed-off-by: NickLucche
---
 .../kv_connector/v1/nixl_connector.py         | 202 ++++++++++--------
 1 file changed, 113 insertions(+), 89 deletions(-)

diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index ca53d1df92aec..762e47b7ae35a 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -525,11 +525,15 @@ class NixlConnectorWorker:
         if remote_tp_size is None:
             assert remote_engine_id is not None
             remote_tp_size = self.remote_tp_size[remote_engine_id]
-        assert self.tp_size % remote_tp_size == 0, (
-            f"Local tensor parallel size {self.tp_size} is not divisible "
-            f"by remote tensor parallel size {remote_tp_size}."
-        )
-        return self.tp_size // remote_tp_size
+        if self.tp_size >= remote_tp_size:
+            assert self.tp_size % remote_tp_size == 0, (
+                f"Local tensor parallel size {self.tp_size} is not divisible "
+                f"by remote tensor parallel size {remote_tp_size}."
+            )
+            return self.tp_size // remote_tp_size
+        else:
+            # P TP > D TP case: return the ratio as a negative number.
+            return -remote_tp_size // self.tp_size
 
     def is_kv_replicated(self, engine_id: EngineId) -> bool:
         """
@@ -538,22 +542,29 @@ class NixlConnectorWorker:
         """
         tp_size = self.remote_tp_size[engine_id]
         return tp_size // self.total_num_kv_heads >= 1
 
     def replicates_kv_cache(self, remote_engine_id: EngineId) -> bool:
         # MLA is always replicated as the hidden dim can't be split.
-        return self.is_mla or self.is_kv_replicated(remote_engine_id)
+        # P TP > D TP is also treated as replication: D reads whole remote blocks.
+        remote_tp_size = self.remote_tp_size[remote_engine_id]
+        return self.is_mla or self.is_kv_replicated(remote_engine_id) or self.tp_size < remote_tp_size
 
-    def get_target_remote_rank(
+    def get_target_remote_ranks(
         self,
         remote_engine_id: Optional[EngineId] = None,
         remote_tp_size: Optional[int] = None,
-    ) -> int:
+    ) -> list[int]:
         """
-        Get the remote TP rank (on P) that the current local TP rank (on D)
-        will read from.
+        Get the remote TP rank(s) (on P) that the current local TP rank (on D)
+        will read from.
         """
         tp_ratio = self.tp_ratio(remote_engine_id, remote_tp_size)
-        return self.tp_rank // tp_ratio
+        if tp_ratio > 0:
+            return [self.tp_rank // tp_ratio]
+        else:
+            # P TP > D TP case: D reads from |tp_ratio| remote workers.
+            tp_ratio = -tp_ratio
+            return [self.tp_rank * tp_ratio + i for i in range(tp_ratio)]
 
     def __init__(self, vllm_config: VllmConfig, engine_id: str):
         if NixlWrapper is None:
@@ -638,8 +649,8 @@ class NixlConnectorWorker:
         self.copy_blocks: Optional[CopyBlocksOp] = None
 
         # Map of engine_id -> kv_caches_base_addr. For TP case, each local
-        # rank will still only pull from a single remote TP worker.
-        self.kv_caches_base_addr: dict[EngineId, list[int]] = {}
+        # rank may pull from multiple remote TP workers.
+        self.kv_caches_base_addr: defaultdict[EngineId, dict[int, list[int]]] = defaultdict(dict)
 
         # Number of NIXL regions. Currently one region per cache
         # (so 1 per layer for MLA, otherwise 2 per layer)
@@ -649,7 +660,8 @@ class NixlConnectorWorker:
         # nixl_prepped_dlist_handle.
         self.src_xfer_side_handle: int = 0
         # Map of engine_id -> nixl_prepped_dlist_handle (int)].
-        self.dst_xfer_side_handles: dict[EngineId, int] = {}
+        # TODO: verify whether one handle per remote rank (|tp_ratio|) is needed.
+        self.dst_xfer_side_handles: defaultdict[EngineId, dict[int, int]] = defaultdict(dict)
 
         # Map of engine_id -> num_blocks. All ranks in the same deployment will
         # have the same number of blocks.
@@ -756,51 +768,52 @@ class NixlConnectorWorker:
 
         start_time = time.perf_counter()
 
-        # NOTE(rob): we need each rank to have a unique port. This is
-        # a hack to keep us moving. We will switch when moving to etcd
-        # or where we have a single ZMQ socket in the scheduler.
-
-        # Handshake only with the remote TP rank that current local rank will
-        # pull from. With homogeneous TP it happens to be the same rank_i.
-        p_remote_rank = self.kv_info.get_target_remote_rank(
+        # When the target instance TP > local TP, we need to perform multiple
+        # handshakes. Do them in a single background job for simplicity.
+        # Regardless, only handshake with the remote TP rank(s) that the
+        # current local rank will read from. Note that with homogeneous TP,
+        # this happens to be the single matching rank_i.
+        p_remote_ranks = self.kv_info.get_target_remote_ranks(
             remote_tp_size=remote_tp_size
         )
-        path = make_zmq_path("tcp", host, port + p_remote_rank)
-        logger.debug(
-            "Querying metadata on path: %s at remote rank %s", path, p_remote_rank
-        )
-
-        # Send query for the request.
-        with zmq_ctx(zmq.REQ, path) as sock:
-            sock.send(GET_META_MSG)
-            metadata_bytes = sock.recv()
-            decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
-            metadata = decoder.decode(metadata_bytes)
-            got_metadata_time = time.perf_counter()
-            logger.debug(
-                "NIXL handshake: get metadata took: %s", got_metadata_time - start_time
+        remote_rank_to_agent_name = {}
+        for remote_rank in p_remote_ranks:
+            path = make_zmq_path("tcp", host, port + remote_rank)
+            logger.debug(
+                "Querying metadata on path: %s at remote rank %s", path, remote_rank
             )
 
-            # Ensure engine id matches.
-            if metadata.engine_id != expected_engine_id:
-                raise RuntimeError(
-                    f"Remote NIXL agent engine ID mismatch. "
-                    f"Expected {expected_engine_id},"
-                    f"received {metadata.engine_id}."
+            # Send query for the request.
+            with zmq_ctx(zmq.REQ, path) as sock:
+                sock.send(GET_META_MSG)
+                metadata_bytes = sock.recv()
+                decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
+                metadata = decoder.decode(metadata_bytes)
+                got_metadata_time = time.perf_counter()
+                logger.debug(
+                    "NIXL handshake: get metadata took: %s", got_metadata_time - start_time
                 )
 
-            # Register Remote agent.
-            remote_agent_name = self.add_remote_agent(
-                metadata, p_remote_rank, remote_tp_size
-            )
-            setup_agent_time = time.perf_counter()
-            logger.debug(
-                "NIXL handshake: add agent took: %s",
-                setup_agent_time - got_metadata_time,
-            )
+                # Ensure engine id matches.
+                if metadata.engine_id != expected_engine_id:
+                    raise RuntimeError(
+                        f"Remote NIXL agent engine ID mismatch. "
+                        f"Expected {expected_engine_id},"
+                        f"received {metadata.engine_id}."
+                    )
 
-        # Remote rank -> agent name.
-        return {p_remote_rank: remote_agent_name}
+                # Register Remote agent.
+                remote_agent_name = self.add_remote_agent(
+                    metadata, remote_rank, remote_tp_size
+                )
+                setup_agent_time = time.perf_counter()
+                logger.debug(
+                    "NIXL handshake: add agent took: %s",
+                    setup_agent_time - got_metadata_time,
+                )
+                remote_rank_to_agent_name[remote_rank] = remote_agent_name
+
+        return remote_rank_to_agent_name
 
     def initialize_host_xfer_buffer(self, kv_caches: dict[str, torch.Tensor]) -> None:
         """
@@ -944,7 +957,7 @@ class NixlConnectorWorker:
         assert len(self.block_len_per_layer) == len(seen_base_addresses)
         assert self.num_blocks != 0
 
-        self.kv_caches_base_addr[self.engine_id] = seen_base_addresses
+        self.kv_caches_base_addr[self.engine_id][self.tp_rank] = seen_base_addresses
 
         self.num_regions = len(caches_data)
         self.num_layers = len(xfer_buffers.keys())
@@ -1033,7 +1046,7 @@ class NixlConnectorWorker:
         metadata = NixlAgentMetadata(
             engine_id=self.engine_id,
             agent_metadata=self.nixl_wrapper.get_agent_metadata(),
-            kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
+            kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id][self.tp_rank],
             num_blocks=self.num_blocks,
             block_lens=self.block_len_per_layer,
             attn_backend_name=self.backend_name,
@@ -1096,7 +1109,7 @@ class NixlConnectorWorker:
         engine_id = nixl_agent_meta.engine_id
         # TODO re-evaluate refreshing for scaling/recovery
         if remote_tp_rank in self._remote_agents.get(engine_id, {}):
-            logger.debug(
+            logger.warning(
                 "Remote agent with engine_id %s and rank"
                 "%s already exchanged metadata, skip handshake.",
                 engine_id,
@@ -1114,6 +1127,7 @@ class NixlConnectorWorker:
 
         # Handle tp_size>num_kv_heads: replicate KV cache.
        replicates_kv_cache = self.kv_info.replicates_kv_cache(engine_id)
+        logger.debug("Remote KV cache will be replicated: %s", replicates_kv_cache)
 
         # Create dst descs and xfer side handles. TP workers have same #blocks
         # so we only register once per engine_id.
@@ -1121,11 +1135,12 @@ class NixlConnectorWorker:
             self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks
 
         # Keep track of remote agent kv caches base addresses.
-        self.kv_caches_base_addr[engine_id] = nixl_agent_meta.kv_caches_base_addr
+        self.kv_caches_base_addr[engine_id][remote_tp_rank] = nixl_agent_meta.kv_caches_base_addr
         self._validate_remote_agent_handshake(nixl_agent_meta, remote_tp_size)
 
         # Number of D TP workers reading from a single P TP worker. This is
-        # 1 when P and D `--tensor-parallel-size` match.
+        # 1 when P and D `--tensor-parallel-size` match. When P TP > D TP, it
+        # is not used to split the remote KV cache (no head slicing needed).
         tp_ratio = self.kv_info.tp_ratio(engine_id)
 
         ### Register remote agent memory regions
@@ -1137,7 +1152,8 @@ class NixlConnectorWorker:
 
         # Register all remote blocks, but only the corresponding kv heads.
         for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
-            kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
+            # TODO: derive remote block len from tp_ratio; //2 assumes P TP == 2*D TP.
+            kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i) // 2
             rank_offset = (
                 self.tp_rank % tp_ratio * kv_block_len if not replicates_kv_cache else 0
             )
@@ -1168,7 +1184,7 @@ class NixlConnectorWorker:
 
         # Register with NIXL.
         descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)
-        self.dst_xfer_side_handles[engine_id] = self.nixl_wrapper.prep_xfer_dlist(
+        self.dst_xfer_side_handles[engine_id][remote_tp_rank] = self.nixl_wrapper.prep_xfer_dlist(
             remote_agent_name, descs
         )
 
@@ -1189,7 +1205,6 @@ class NixlConnectorWorker:
         assert nixl_agent_meta.kv_cache_layout == self.kv_cache_layout
 
         tp_ratio = self.kv_info.tp_ratio(remote_engine_id)
-        assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP"
         assert not self._use_pallas or tp_ratio == 1, (
             "TPU (pallas_v1) DOES NOT support heterogeneous TP yet."
         )
@@ -1217,21 +1232,20 @@ class NixlConnectorWorker:
         if self._use_flashinfer:
             # With flashinfer, KV are sent in the same message.
             remote_block_size //= 2
+        # TODO: adapt these invariants to the P TP > D TP case and re-enable.
+        # assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, (
+        #     "Remote P worker KV layer cache must be of shape [2, N, "
+        #     "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
+        # )
 
-        assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, (
-            "Remote P worker KV layer cache must be of shape [2, N, "
-            "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
-        )
+        # assert self.block_size == remote_block_size, (
+        #     "Remote P worker with different page/block size is not supported "
+        #     f"{self.block_size=}, {remote_block_size=}"
+        # )
 
-        assert self.block_size == remote_block_size, (
-            "Remote P worker with different page/block size is not supported "
-            f"{self.block_size=}, {remote_block_size=}"
-        )
-
-        # TP workers have same #blocks.
-        assert self.dst_num_blocks[remote_engine_id] == nixl_agent_meta.num_blocks
-
-        assert len(nixl_agent_meta.kv_caches_base_addr) == len(self.block_len_per_layer)
+        # TP workers have same #blocks.
+        # assert self.dst_num_blocks[remote_engine_id] == nixl_agent_meta.num_blocks
+        # assert len(nixl_agent_meta.kv_caches_base_addr) == len(self.block_len_per_layer)
 
     def sync_recved_kv_to_device(self, req_id: str, meta: ReqMeta):
         """copy recved kv from host buffer to device."""
@@ -1431,17 +1445,24 @@ class NixlConnectorWorker:
             self._reqs_to_send[req_id] = expiration_time
 
     def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
-        logger.debug(
-            "Remote agent %s available, calling _read_blocks for req %s",
-            meta.remote_engine_id,
-            req_id,
-        )
-        self._read_blocks(
-            request_id=req_id,
-            dst_engine_id=meta.remote_engine_id,
-            local_block_ids=meta.local_block_ids,
-            remote_block_ids=meta.remote_block_ids,
-        )
+        remote_ranks = self.kv_info.get_target_remote_ranks(meta.remote_engine_id)
+        # D may perform multiple reads from different remote ranks.
+        for remote_rank in remote_ranks:
+            logger.debug(
+                "Remote agent %s available, calling _read_blocks"
+                " on remote rank %s for req %s",
+                meta.remote_engine_id,
+                remote_rank,
+                req_id,
+            )
+            # TODO: split local blocks across remote ranks; avoid double notifs.
+            self._read_blocks(
+                request_id=req_id,
+                dst_engine_id=meta.remote_engine_id,
+                local_block_ids=meta.local_block_ids,
+                remote_block_ids=meta.remote_block_ids,
+                remote_rank=remote_rank,
+            )
 
     def _read_blocks(
         self,
@@ -1449,6 +1470,7 @@ class NixlConnectorWorker:
         remote_block_ids: list[int],
         dst_engine_id: str,
         request_id: str,
+        remote_rank: int,
     ):
         # NOTE(rob): having the staging blocks be on the READER side is
         # not going to work well (since we will have to call rearrange tensors).
@@ -1462,14 +1484,14 @@ class NixlConnectorWorker:
 
         # Number of D TP workers that will read from dst P. Propagate tp_ratio
         # on notification so that dst worker can wait before freeing blocks.
-        tp_ratio = self.kv_info.tp_ratio(dst_engine_id)
+        # Cap at 1 when P TP > D TP: each remote (P) rank has a single D reader.
+        tp_ratio = max(1, self.kv_info.tp_ratio(dst_engine_id))
         notif_id = f"{request_id}:{tp_ratio}".encode()
 
         # Full prefix cache hit: do not need to read remote blocks,
-        # just notify P worker that we have the blocks we need.
+        # just notify P worker(s) that we have the blocks we need.
         num_local_blocks = len(local_block_ids)
         if num_local_blocks == 0:
-            remote_rank = self.kv_info.get_target_remote_rank(dst_engine_id)
             agent_name = self._remote_agents[dst_engine_id][remote_rank]
             self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id)
             return
@@ -1482,7 +1504,7 @@ class NixlConnectorWorker:
 
         # Get side handles.
         local_xfer_side_handle = self.src_xfer_side_handle
-        remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id]
+        remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id][remote_rank]
 
         # NOTE (nicolo) With homogeneous TP, each TP worker loads KV from
         # corresponding rank. With heterogeneous TP, fixing D>P, the D tp
@@ -1571,6 +1593,7 @@ class NixlConnectorWorker:
         assert self.num_layers == self.num_regions
 
         region_ids = np.arange(layer_idx, layer_idx + 1)
+        # TODO: check whether num_blocks can differ across remote ranks.
         num_blocks = self.dst_num_blocks[engine_id]
 
         # Compute the desc ids for each block.
@@ -1617,8 +1640,9 @@ class NixlConnectorWorker:
         if self.src_xfer_side_handle:
             self.nixl_wrapper.release_dlist_handle(self.src_xfer_side_handle)
             self.src_xfer_side_handle = 0
-        for dst_xfer_side_handle in self.dst_xfer_side_handles.values():
-            self.nixl_wrapper.release_dlist_handle(dst_xfer_side_handle)
+        for dst_xfer_side_handles in self.dst_xfer_side_handles.values():
+            for dst_xfer_side_handle in dst_xfer_side_handles.values():
+                self.nixl_wrapper.release_dlist_handle(dst_xfer_side_handle)
         self.dst_xfer_side_handles.clear()
         for remote_agents in self._remote_agents.values():
            for agent_name in remote_agents.values():
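
Reviewer note: below is a minimal, standalone sketch of the rank-mapping rule this patch introduces. The plain functions only mirror the intent of tp_ratio() and get_target_remote_ranks() above; the function names and the worked TP sizes are illustrative assumptions, not part of the change itself.

# Illustrative sketch: sign convention of tp_ratio() and the mapping of
# get_target_remote_ranks(); not the vLLM classes themselves.

def tp_ratio(local_tp: int, remote_tp: int) -> int:
    if local_tp >= remote_tp:
        assert local_tp % remote_tp == 0
        return local_tp // remote_tp
    # P TP > D TP: the ratio is reported as a negative number.
    return -(remote_tp // local_tp)


def target_remote_ranks(local_rank: int, local_tp: int, remote_tp: int) -> list[int]:
    ratio = tp_ratio(local_tp, remote_tp)
    if ratio > 0:
        # D TP >= P TP: `ratio` local (D) ranks share a single remote (P) rank.
        return [local_rank // ratio]
    # P TP > D TP: each D rank reads from |ratio| consecutive remote (P) ranks.
    ratio = -ratio
    return [local_rank * ratio + i for i in range(ratio)]


# Worked example, D TP=2 pulling from P TP=4 (tp_ratio == -2):
assert target_remote_ranks(0, 2, 4) == [0, 1]
assert target_remote_ranks(1, 2, 4) == [2, 3]
# Homogeneous TP=2: each D rank reads from the matching P rank.
assert target_remote_ranks(1, 2, 2) == [1]

This is the mapping that the handshake loop and _read_blocks_for_req iterate over; because each P rank ends up with exactly one D reader when P TP > D TP, the tp_ratio sent on the notification path is capped at 1 in _read_blocks.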