diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index 1871138b088ca..66d692e7387da 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -520,7 +520,9 @@ class NixlConnectorWorker:
         Calculate the tensor parallel ratio between local and remote TP.
         We can think of it as the number of local TP workers-per-remote TP
         workers. Local workers will read from the same remote TP worker in
-        groups of size `tp_ratio`.
+        groups of size `tp_ratio`. If remote tp_size > local tp_size, the
+        ratio is flipped (remote_size/local_size) and the returned value is
+        negative.
         """
         if remote_tp_size is None:
             assert remote_engine_id is not None
@@ -539,7 +541,9 @@ class NixlConnectorWorker:
             # P TP > D TP case, return the ratio as negative
             return -remote_tp_size // self.tp_size
 
-    def is_kv_replicated(self, engine_id: Optional[EngineId] = None, tp_size: Optional[int] = None) -> bool:
+    def is_kv_replicated(
+        self, engine_id: Optional[EngineId] = None, tp_size: Optional[int] = None
+    ) -> bool:
         """
         Whether the KV cache is replicated across TP workers due to the
         number of TP workers being greater than the number of KV heads.
@@ -549,11 +553,14 @@ class NixlConnectorWorker:
             tp_size = self.remote_tp_size[engine_id]
         return tp_size // self.total_num_kv_heads >= 1
 
-    def replicates_kv_cache(self, remote_engine_id: Optional[EngineId] = None, remote_tp_size: Optional[int] = None) -> bool:
+    def replicates_kv_cache(
+        self,
+        remote_engine_id: Optional[EngineId] = None,
+        remote_tp_size: Optional[int] = None,
+    ) -> bool:
         # MLA is always replicated as the hidden dim can't be split.
-        return (
-            self.is_mla
-            or self.is_kv_replicated(remote_engine_id, remote_tp_size)
+        return self.is_mla or self.is_kv_replicated(
+            remote_engine_id, remote_tp_size
         )
 
     def get_target_remote_ranks(
@@ -563,7 +570,8 @@ class NixlConnectorWorker:
     ) -> list[int]:
         """
         Get the remote TP rank (on P) that the current local TP rank
-        (on D) will read from.
+        (on D) will read from. When remote tp_size > local tp_size, we
+        read from multiple remote ranks.
         """
         tp_ratio = self.tp_ratio(remote_engine_id, remote_tp_size)
         if tp_ratio > 0:
@@ -573,8 +581,8 @@ class NixlConnectorWorker:
         tp_ratio = -tp_ratio
         if self.replicates_kv_cache(remote_engine_id, remote_tp_size):
             # When cache is replicated on remote, we only need to read
-            # from one remote.
-            return [self.tp_rank*tp_ratio]
+            # from one remote (they all have the same cache).
+            return [self.tp_rank * tp_ratio]
         return [self.tp_rank * tp_ratio + i for i in range(tp_ratio)]
 
     def __init__(self, vllm_config: VllmConfig, engine_id: str):
@@ -672,12 +680,10 @@ class NixlConnectorWorker:
 
         # nixl_prepped_dlist_handle.
         self.src_xfer_side_handle: int = 0
-        # TODO flexible enough to handle different P TP destinations?
-        # tp_ratio->handles
-        # Only poulated during handshake when we read from multiple sources
+        # Populated dynamically during handshake based on remote configuration.
+        # Keep track of regions at different tp_ratio values. tp_ratio->handles
         self.src_xfer_side_chunked_handles: dict[int, list[int]] = {}
         # Map of engine_id -> nixl_prepped_dlist_handle (int)].
-        # TODO do I need tp_Ratio of this?
         self.dst_xfer_side_handles: defaultdict[EngineId, dict[int, int]] = defaultdict(
             dict
         )
@@ -1033,7 +1039,9 @@ class NixlConnectorWorker:
             self.tp_rank,
         )
 
-        descs = self.nixl_wrapper.get_xfer_descs(self.src_blocks_data, self.nixl_memory_type)
+        descs = self.nixl_wrapper.get_xfer_descs(
+            self.src_blocks_data, self.nixl_memory_type
+        )
         # NIXL_INIT_AGENT to be used for preparations of local descs.
         self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist(
             "NIXL_INIT_AGENT", descs
@@ -1093,10 +1101,12 @@ class NixlConnectorWorker:
 
         In particular, handle both homogeneous and heterogeneous TP. The former
         requires local rank_i to read from remote rank_i.
-        The latter, assuming D.world_size > P.world_size, requires that two or
-        more local TP worker share the xfer from a single TP worker.
+        The latter, in the case of D.world_size < P.world_size, requires that a
+        local (D) TP worker reads from multiple remote (P) TP workers.
+        Conversely, assuming D.world_size > P.world_size, two or more local TP
+        workers will read from a single remote TP worker.
 
-        Here's an example (non-MLA case):
+        Here's an example for the last case described above (non-MLA):
 
         rank_offset     p_remote_tp_rank
         (kv split no)
@@ -1155,35 +1165,36 @@ class NixlConnectorWorker:
         )
         self._validate_remote_agent_handshake(nixl_agent_meta, remote_tp_size)
 
-        # Number of D TP workers reading from a single P TP worker. This is
-        # 1 when P and D `--tensor-parallel-size` match. If P TP > D TP,
-        # we don't need to use this for splitting the remote kv cache.
+        # This is 1 when P and D `--tensor-parallel-size` match. Otherwise,
+        # this is the ratio between the two sizes.
         tp_ratio = self.kv_info.tp_ratio(engine_id)
         # Handle tp_size>num_kv_heads: replicate KV cache.
-        indexes_into_remote = (not self.kv_info.replicates_kv_cache(engine_id) \
-            and tp_ratio < 0)
+        indexes_into_remote = (
+            not self.kv_info.replicates_kv_cache(engine_id) and tp_ratio < 0
+        )
 
-        # When you realize you're in P TP>DTP you have to split your regions
+        ### (Optional) Register local agent memory regions
         if tp_ratio < 0 and tp_ratio not in self.src_xfer_side_chunked_handles:
-            # TODO use positive tp_ratio value?
+            # Remote tp_size > local tp_size: read from multiple remote ranks.
+            # Logically "split" own regions into |tp_ratio| chunks. Note that
+            # we only do this once per remote tp_size (replica-friendly).
            self.src_xfer_side_chunked_handles[tp_ratio] = []
-            # This is still needed even for MLA
-            # TODO actually only needs one!!
-            # Check if we have a split we can re-use, ie a remote P with same tp_ratio
-            for i in range(-tp_ratio):
+            # MLA optimization: only prepare one region.
+            # NOTE NickLucche: only a chunk of the whole cache is used with MLA!
+            tp_ratio_opt = 1 if self.use_mla else -tp_ratio
+            for i in range(tp_ratio_opt):
                 blocks_data = []
                 for memory_region in self.src_blocks_data:
                     addr, local_block_len, own_tp_rank = memory_region
-                    # Computing block len layer by layer allow for different
-                    # block sizes per layer
-                    # TODO this needs to be an assert when validating
-                    # TODO is this the right dim we're splitting on? H?
-                    remote_block_len = local_block_len//(-tp_ratio)
-                    # Offset
+                    # Computing block len layer by layer allows for different
+                    # block sizes to be used.
+                    remote_block_len = local_block_len // (-tp_ratio)
                     addr = addr + i * remote_block_len
-                    blocks_data.append((addr, remote_block_len, own_tp_rank)) # TODO same tp_rank?
-                descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)
+                    blocks_data.append((addr, remote_block_len, own_tp_rank))
+                descs = self.nixl_wrapper.get_xfer_descs(
+                    blocks_data, self.nixl_memory_type
+                )
                 handle = self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs)
                 self.src_xfer_side_chunked_handles[tp_ratio].append(handle)
 
@@ -1196,8 +1207,11 @@ class NixlConnectorWorker:
         # Register all remote blocks, but only the corresponding kv heads.
         for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
-            # TODO workaround
-            kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i) // 2
+            # By default, read the whole local region size from the remote.
+            kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
+            if tp_ratio < 0:
+                # Remote tp is bigger: read a chunk of the local region.
+                kv_block_len = kv_block_len // (-tp_ratio)
             rank_offset = (
                 self.tp_rank % tp_ratio * kv_block_len if indexes_into_remote else 0
             )
@@ -1262,7 +1276,7 @@ class NixlConnectorWorker:
             )
             remote_block_size = remote_block_len // (self.slot_size_per_layer[0])
         else:
-            if tp_ratio > 1 and self.device_type == "xpu":
+            if tp_ratio != 1 and self.device_type == "xpu":
                 # XPU uses NHD, hence it does not support splitting on H
                 raise ValueError("Heterogeneous TP is not supported on XPU")
             # When MLA is not used, this is a list of the same block length
@@ -1270,26 +1284,42 @@ class NixlConnectorWorker:
             assert block_len == remote_block_len, (
                 "All remote layers must have the same block size"
             )
-            remote_block_size = remote_block_len // (
-                self.slot_size_per_layer[0] * tp_ratio
-            )
+
+            if tp_ratio > 0:
+                # Remote NHD/H'D*tp_ratio=N -page_size-
+                remote_block_size = remote_block_len // (
+                    self.slot_size_per_layer[0] * tp_ratio
+                )
+                # Remote tp is smaller: remote block_len is bigger.
+                assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, (
+                    "Remote P worker KV layer cache must be of shape [2, N, "
+                    "local_kv_heads*tp_ratio, page_size, head_dim] and same dtype."
+                )  # noqa: E501
+            else:
+                # Remote NHD/(H'D/tp_ratio)=N -page_size-
+                remote_block_size = remote_block_len // (
+                    self.slot_size_per_layer[0] // (-tp_ratio)
+                )
+                # Remote tp is bigger: remote block_len is smaller.
+                assert remote_block_len == self.block_len_per_layer[0] // (-tp_ratio), (
+                    "Remote P worker KV layer cache must be of shape [2, N, "
+                    "local_kv_heads/tp_ratio, page_size, head_dim] and same dtype."
+                )  # noqa: E501
+
             if self._use_flashinfer:
                 # With flashinfer, KV are sent in the same message.
                 remote_block_size //= 2
 
-            # TODO add asserts for P TP > D TP
-            # assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, (
-            #     "Remote P worker KV layer cache must be of shape [2, N, "
-            #     "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
-            # )
-            # assert self.block_size == remote_block_size, (
-            #     "Remote P worker with different page/block size is not supported "
-            #     f"{self.block_size=}, {remote_block_size=}"
-            # )
+        # We may allow this in the future with a logical kv cache manager block_size.
+        assert self.block_size == remote_block_size, (
+            "Remote P worker with different page/block size is not supported "
+            f"{self.block_size=}, {remote_block_size=}"
+        )
 
-        # # TP workers have same #blocks.
-        # assert self.dst_num_blocks[remote_engine_id] == nixl_agent_meta.num_blocks
-        # assert len(nixl_agent_meta.kv_caches_base_addr) == len(self.block_len_per_layer)
+        # TP workers (handshakes with the same remote) have the same #blocks.
+        assert self.dst_num_blocks[remote_engine_id] == nixl_agent_meta.num_blocks
+        # Same number of regions (~layers).
+        assert len(nixl_agent_meta.kv_caches_base_addr) == len(self.block_len_per_layer)
 
     def sync_recved_kv_to_device(self, req_id: str, meta: ReqMeta):
         """copy recved kv from host buffer to device."""
@@ -1421,7 +1451,7 @@ class NixlConnectorWorker:
         """
         done_req_ids: set[str] = set()
         for req_id, handles in list(transfers.items()):
-            in_progress = False
+            in_progress = []
            for handle, _xfer_stime in handles:
                 xfer_state = self.nixl_wrapper.check_xfer_state(handle)
                 if xfer_state == "DONE":
@@ -1430,13 +1460,16 @@ class NixlConnectorWorker:
                         self.xfer_stats.record_transfer(res)
                     self.nixl_wrapper.release_xfer_handle(handle)
                 elif xfer_state == "PROC":
-                    in_progress = True
+                    in_progress.append((handle, _xfer_stime))
                     continue
                 else:
                     raise RuntimeError("Transfer failed with state %s", xfer_state)
             if not in_progress:
+                # Only report the request as completed when all transfers are done.
                 done_req_ids.add(req_id)
                 del transfers[req_id]
+            else:
+                transfers[req_id] = in_progress
         return done_req_ids
 
     def start_load_kv(self, metadata: NixlConnectorMetadata):
@@ -1490,7 +1523,8 @@ class NixlConnectorWorker:
 
     def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
         remote_ranks = self.kv_info.get_target_remote_ranks(meta.remote_engine_id)
-        # D may perform multiple reads from different remote ranks.
+        tp_ratio = self.kv_info.tp_ratio(meta.remote_engine_id)
+        # D may have to perform multiple reads from different remote ranks.
         for i, remote_rank in enumerate(remote_ranks):
             logger.debug(
                 "Remote agent %s available, calling _read_blocks"
@@ -1499,13 +1533,17 @@ class NixlConnectorWorker:
                 remote_rank,
                 req_id,
             )
-            # TODO refactor properly and ONLY DO THIS FOR PTP>DTP
-            tp_ratio = self.kv_info.tp_ratio(meta.remote_engine_id)
-            # Get nixl desc handles depending on whether we're reading from
-            # multiple sources or we're reading a chunk of
-            local_xfer_side_handle = self.src_xfer_side_chunked_handles[tp_ratio][i]
-            remote_xfer_side_handle = self.dst_xfer_side_handles[meta.remote_engine_id][remote_rank]
-            # TODO multiread; notifs to all twice?? SPLIT LOCAL BLOCKS!
+            if tp_ratio < 0:
+                # Remote tp_size > local tp_size: we must perform multiple
+                # reads. Get the memory chunk we will write to.
+                local_xfer_side_handle = self.src_xfer_side_chunked_handles[tp_ratio][i]
+            else:
+                # Single read from remote; we write to the whole memory region.
+                local_xfer_side_handle = self.src_xfer_side_handle
+            # Destination handle: remote_engine_id -> remote_rank -> handle.
+            remote_xfer_side_handle = self.dst_xfer_side_handles[meta.remote_engine_id][
+                remote_rank
+            ]
             self._read_blocks(
                 request_id=req_id,
                 dst_engine_id=meta.remote_engine_id,
@@ -1526,6 +1564,10 @@ class NixlConnectorWorker:
         local_xfer_side_handle: int,
         remote_xfer_side_handle: int,
     ):
+        """
+        Post a READ xfer request from a single local worker to a single
+        remote worker.
+        """
         # NOTE(rob): having the staging blocks be on the READER side is
         # not going to work well (since we will have to call rearrange tensors).
         # after we detect the txn is complete (which means we cannot make the
@@ -1543,7 +1585,7 @@ class NixlConnectorWorker:
         notif_id = f"{request_id}:{tp_ratio}".encode()
 
         # Full prefix cache hit: do not need to read remote blocks,
-        # just notify P worker(s) that we have the blocks we need.
+        # just notify P worker that we have the blocks we need.
         num_local_blocks = len(local_block_ids)
         if num_local_blocks == 0:
             agent_name = self._remote_agents[dst_engine_id][remote_rank]
@@ -1556,10 +1598,6 @@ class NixlConnectorWorker:
         if num_local_blocks < num_remote_blocks:
             remote_block_ids = remote_block_ids[-num_local_blocks:]
 
-        # Get side handles.
-        # local_xfer_side_handle = self.src_xfer_side_handle
-        # remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id][remote_rank]
-
         # NOTE (nicolo) With homogeneous TP, each TP worker loads KV from
         # corresponding rank. With heterogeneous TP, fixing D>P, the D tp
         # workers will issue xfers to parts of the P worker remote kv caches.
@@ -1647,7 +1685,6 @@ class NixlConnectorWorker:
             assert self.num_layers == self.num_regions
             region_ids = np.arange(layer_idx, layer_idx + 1)
 
-        # TODO can this vary?
         num_blocks = self.dst_num_blocks[engine_id]
 
         # Compute the desc ids for each block.
@@ -1694,11 +1731,10 @@ class NixlConnectorWorker:
         if self.src_xfer_side_handle:
             self.nixl_wrapper.release_dlist_handle(self.src_xfer_side_handle)
             self.src_xfer_side_handle = 0
-        if self.src_xfer_side_chunked_handles:
-            for handles in self.src_xfer_side_chunked_handles.values():
-                for handle in handles:
-                    self.nixl_wrapper.release_dlist_handle(handle)
-            self.src_xfer_side_chunked_handles.clear()
+        for handles in self.src_xfer_side_chunked_handles.values():
+            for handle in handles:
+                self.nixl_wrapper.release_dlist_handle(handle)
+        self.src_xfer_side_chunked_handles.clear()
         for dst_xfer_side_handles in self.dst_xfer_side_handles.values():
             for dst_xfer_side_handle in dst_xfer_side_handles.values():
                 self.nixl_wrapper.release_dlist_handle(dst_xfer_side_handle)
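
To make the tp_ratio sign convention concrete, below is a minimal standalone sketch of the rank-mapping logic from this change. It uses plain illustrative functions rather than the actual NixlConnectorWorker methods, assumes the positive-ratio branch returns tp_rank // tp_ratio as in the pre-existing code, and omits the KV-replication shortcut:

    # Illustrative sketch; mirrors NixlConnectorWorker.tp_ratio and
    # get_target_remote_ranks under the assumptions stated above.
    def tp_ratio(local_tp_size: int, remote_tp_size: int) -> int:
        if local_tp_size >= remote_tp_size:
            # D TP >= P TP: groups of `tp_ratio` local workers share one remote.
            return local_tp_size // remote_tp_size
        # P TP > D TP: flip the ratio and mark it with a negative sign.
        return -(remote_tp_size // local_tp_size)

    def target_remote_ranks(local_rank: int, local_tp_size: int,
                            remote_tp_size: int) -> list[int]:
        ratio = tp_ratio(local_tp_size, remote_tp_size)
        if ratio > 0:
            # Single remote rank, e.g. D TP=4, P TP=2: ranks [0,1,2,3] -> [0,0,1,1].
            return [local_rank // ratio]
        ratio = -ratio
        # Multiple remote ranks, e.g. D TP=2, P TP=4: local rank 0 reads [0, 1].
        return [local_rank * ratio + i for i in range(ratio)]

    assert target_remote_ranks(3, 4, 2) == [1]     # D TP > P TP: share one remote
    assert target_remote_ranks(1, 2, 4) == [2, 3]  # P TP > D TP: multiple reads
    assert target_remote_ranks(0, 2, 2) == [0]     # homogeneous TP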
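In the same spirit, here is a worked example of the block-length arithmetic behind the chunked local regions when the remote (P) TP size is larger: each remote worker holds 1/|tp_ratio| of the local KV heads, so each local region is logically split into |tp_ratio| chunks at offsets that are multiples of remote_block_len. All numbers are hypothetical, chosen only for illustration:

    # Hypothetical sizes, not taken from a real model config.
    local_kv_heads = 8   # KV heads per D (local) worker
    head_dim = 128
    page_size = 16
    dtype_bytes = 2      # fp16

    local_block_len = local_kv_heads * head_dim * page_size * dtype_bytes
    tp_ratio = -2        # P TP is 2x D TP; negative by the convention above

    # Each P worker covers 1/|tp_ratio| of the local heads, so its block is smaller.
    remote_block_len = local_block_len // (-tp_ratio)

    # Chunk i of the local region (the target of the read from the i-th remote
    # rank) starts at offset i * remote_block_len, mirroring
    # `addr = addr + i * remote_block_len` above.
    chunk_offsets = [i * remote_block_len for i in range(-tp_ratio)]
    assert chunk_offsets == [0, local_block_len // 2]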