diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index 762e47b7ae35a..1871138b088ca 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -532,22 +532,29 @@ class NixlConnectorWorker:
             )
             return self.tp_size // remote_tp_size
         else:
+            assert remote_tp_size % self.tp_size == 0, (
+                f"Remote tensor parallel size {remote_tp_size} is not divisible "
+                f"by local tensor parallel size {self.tp_size}."
+            )
             # P TP > D TP case, return the ratio as negative
             return -remote_tp_size // self.tp_size
 
-    def is_kv_replicated(self, engine_id: EngineId) -> bool:
+    def is_kv_replicated(self, engine_id: Optional[EngineId] = None, tp_size: Optional[int] = None) -> bool:
         """
         Whether the KV cache is replicated across TP workers due to the number
         of TP workers being greater than the number of KV heads.
         """
-        tp_size = self.remote_tp_size[engine_id]
+        if tp_size is None:
+            assert engine_id is not None
+            tp_size = self.remote_tp_size[engine_id]
         return tp_size // self.total_num_kv_heads >= 1
-
-    def replicates_kv_cache(self, remote_engine_id: EngineId) -> bool:
+
+    def replicates_kv_cache(self, remote_engine_id: Optional[EngineId] = None, remote_tp_size: Optional[int] = None) -> bool:
         # MLA is always replicated as the hidden dim can't be split.
-        # TODO docs
-        remote_tp_size = self.remote_tp_size[remote_engine_id]
-        return self.is_mla or self.is_kv_replicated(remote_engine_id) or self.tp_size < remote_tp_size
+        return (
+            self.is_mla
+            or self.is_kv_replicated(remote_engine_id, remote_tp_size)
+        )
 
     def get_target_remote_ranks(
         self,
@@ -564,7 +571,11 @@ class NixlConnectorWorker:
         else:
             # P TP > D TP case, D reads from |tp_ratio| remote workers.
             tp_ratio = -tp_ratio
-            return [self.tp_rank*tp_ratio + i for i in range(tp_ratio)]
+            if self.replicates_kv_cache(remote_engine_id, remote_tp_size):
+                # When the cache is replicated on the remote, we only need to
+                # read from one remote rank.
+                return [self.tp_rank * tp_ratio]
+            return [self.tp_rank * tp_ratio + i for i in range(tp_ratio)]
 
     def __init__(self, vllm_config: VllmConfig, engine_id: str):
         if NixlWrapper is None:
@@ -650,7 +661,9 @@ class NixlConnectorWorker:
 
         # Map of engine_id -> kv_caches_base_addr. For TP case, each local
        # rank may pull from multiple remote TP workers.
-        self.kv_caches_base_addr: defaultdict[EngineId, dict[int, list[int]]] = defaultdict(dict)
+        self.kv_caches_base_addr: defaultdict[EngineId, dict[int, list[int]]] = (
+            defaultdict(dict)
+        )
 
         # Number of NIXL regions. Currently one region per cache
         # (so 1 per layer for MLA, otherwise 2 per layer)
@@ -659,9 +672,15 @@ class NixlConnectorWorker:
         # nixl_prepped_dlist_handle.
         self.src_xfer_side_handle: int = 0
+        # TODO flexible enough to handle different P TP destinations?
+        # tp_ratio -> handles.
+        # Only populated during handshake when we read from multiple sources.
+        self.src_xfer_side_chunked_handles: dict[int, list[int]] = {}
 
         # Map of engine_id -> nixl_prepped_dlist_handle (int)].
         # TODO do I need tp_Ratio of this?
-        self.dst_xfer_side_handles: defaultdict[EngineId, dict[int, int]] = defaultdict(dict)
+        self.dst_xfer_side_handles: defaultdict[EngineId, dict[int, int]] = defaultdict(
+            dict
+        )
 
         # Map of engine_id -> num_blocks. All ranks in the same deployment will
         # have the same number of blocks.
@@ -771,7 +790,7 @@ class NixlConnectorWorker:
         # When target instance TP > local TP, we need to perform multiple
         # handshakes. Do it in a single background job for simplicity.
         # Regardless, only handshake with the remote TP rank(s) that current
-        # local rank will read from. Note that With homogeneous TP,
+        # local rank will read from. Note that with homogeneous TP,
         # this happens to be the same single rank_i.
         p_remote_ranks = self.kv_info.get_target_remote_ranks(
             remote_tp_size=remote_tp_size
@@ -791,7 +810,8 @@ class NixlConnectorWorker:
             metadata = decoder.decode(metadata_bytes)
             got_metadata_time = time.perf_counter()
             logger.debug(
-                "NIXL handshake: get metadata took: %s", got_metadata_time - start_time
+                "NIXL handshake: get metadata took: %s",
+                got_metadata_time - start_time,
             )
 
             # Ensure engine id matches.
@@ -812,7 +832,6 @@ class NixlConnectorWorker:
                 setup_agent_time - got_metadata_time,
             )
             remote_rank_to_agent_name[remote_rank] = remote_agent_name
-
        return remote_rank_to_agent_name
 
     def initialize_host_xfer_buffer(self, kv_caches: dict[str, torch.Tensor]) -> None:
@@ -983,7 +1002,7 @@ class NixlConnectorWorker:
            self.num_regions *= 2
 
         # Register local/src descr for NIXL xfer.
-        blocks_data = []
+        self.src_blocks_data = []
         for i, base_addr in enumerate(seen_base_addresses):
             kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
             # NOTE With heter-TP, more blocks are prepared than what are
@@ -995,7 +1014,7 @@ class NixlConnectorWorker:
                 block_offset = block_id * self.block_len_per_layer[i]
                 addr = base_addr + block_offset
                 # (addr, len, device id)
-                blocks_data.append((addr, kv_block_len, self.tp_rank))
+                self.src_blocks_data.append((addr, kv_block_len, self.tp_rank))
 
             if self._use_flashinfer:
                 # Separate and interleave K/V regions to maintain the same
@@ -1006,15 +1025,15 @@ class NixlConnectorWorker:
                     addr = base_addr + block_offset
                     # Register addresses for V cache (K registered first).
                     v_addr = addr + kv_block_len
-                    blocks_data.append((v_addr, kv_block_len, self.tp_rank))
+                    self.src_blocks_data.append((v_addr, kv_block_len, self.tp_rank))
         logger.debug(
             "Created %s blocks for src engine %s and rank %s",
-            len(blocks_data),
+            len(self.src_blocks_data),
             self.engine_id,
             self.tp_rank,
         )
 
-        descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)
+        descs = self.nixl_wrapper.get_xfer_descs(self.src_blocks_data, self.nixl_memory_type)
         # NIXL_INIT_AGENT to be used for preparations of local descs.
         self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist(
             "NIXL_INIT_AGENT", descs
@@ -1125,24 +1144,49 @@ class NixlConnectorWorker:
             nixl_agent_meta.agent_metadata
         )
 
-        # Handle tp_size>num_kv_heads: replicate KV cache.
-        replicates_kv_cache = self.kv_info.replicates_kv_cache(engine_id)
-        print("REPLICATES KV CACHE", replicates_kv_cache, "\n")
-
         # Create dst descs and xfer side handles. TP workers have same #blocks
         # so we only register once per engine_id.
         if engine_id not in self.dst_num_blocks:
             self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks
 
         # Keep track of remote agent kv caches base addresses.
-        self.kv_caches_base_addr[engine_id][self.tp_rank] = nixl_agent_meta.kv_caches_base_addr
+        self.kv_caches_base_addr[engine_id][remote_tp_rank] = (
+            nixl_agent_meta.kv_caches_base_addr
+        )
         self._validate_remote_agent_handshake(nixl_agent_meta, remote_tp_size)
 
         # Number of D TP workers reading from a single P TP worker. This is
         # 1 when P and D `--tensor-parallel-size` match. If P TP > D TP,
-        # we don't need to use this for spliting the remote kv cache.
+        # we don't need to use this for splitting the remote kv cache.
         tp_ratio = self.kv_info.tp_ratio(engine_id)
+        # Handle tp_size>num_kv_heads: replicate KV cache.
+        indexes_into_remote = (
+            not self.kv_info.replicates_kv_cache(engine_id) and tp_ratio > 0)
+
+        # When P TP > D TP, the local (src) regions must be split into |tp_ratio| chunks.
+        if tp_ratio < 0 and tp_ratio not in self.src_xfer_side_chunked_handles:
+            # TODO use positive tp_ratio value?
+            self.src_xfer_side_chunked_handles[tp_ratio] = []
+            # This is still needed even for MLA.
+            # TODO actually only needs one!!
+            # Check if we have a split we can re-use, i.e. a remote P with the same tp_ratio.
+            for i in range(-tp_ratio):
+                blocks_data = []
+                for memory_region in self.src_blocks_data:
+                    addr, local_block_len, own_tp_rank = memory_region
+                    # Computing the block len layer by layer allows for
+                    # different block sizes per layer.
+                    # TODO this needs to be an assert when validating
+                    # TODO is this the right dim we're splitting on? H?
+                    remote_block_len = local_block_len // (-tp_ratio)
+                    # Offset into the local region for the i-th chunk.
+                    addr = addr + i * remote_block_len
+                    blocks_data.append((addr, remote_block_len, own_tp_rank))  # TODO same tp_rank?
+                descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)
+                handle = self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs)
+                self.src_xfer_side_chunked_handles[tp_ratio].append(handle)
+        ### Register remote agent memory regions
 
         blocks_data = []
         # With homogeneous TP, D pulls the whole kv cache from corresponding
@@ -1152,10 +1196,10 @@ class NixlConnectorWorker:
 
         # Register all remote blocks, but only the corresponding kv heads.
         for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
-            # TODO
+            # TODO workaround
             kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i) // 2
             rank_offset = (
-                self.tp_rank % tp_ratio * kv_block_len if not replicates_kv_cache else 0
+                self.tp_rank % tp_ratio * kv_block_len if indexes_into_remote else 0
             )
             for block_id in range(nixl_agent_meta.num_blocks):
                 block_offset = block_id * nixl_agent_meta.block_lens[i]
@@ -1184,8 +1228,8 @@ class NixlConnectorWorker:
 
         # Register with NIXL.
         descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)
-        self.dst_xfer_side_handles[engine_id][remote_tp_rank] = self.nixl_wrapper.prep_xfer_dlist(
-            remote_agent_name, descs
+        self.dst_xfer_side_handles[engine_id][remote_tp_rank] = (
+            self.nixl_wrapper.prep_xfer_dlist(remote_agent_name, descs)
         )
 
         return remote_agent_name
@@ -1447,7 +1491,7 @@ class NixlConnectorWorker:
     def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
         remote_ranks = self.kv_info.get_target_remote_ranks(meta.remote_engine_id)
         # D may perform multiple reads from different remote ranks.
-        for remote_rank in remote_ranks:
+        for i, remote_rank in enumerate(remote_ranks):
             logger.debug(
                 "Remote agent %s available, calling _read_blocks"
                 " on remote rank %s for req %s",
@@ -1455,6 +1499,12 @@ class NixlConnectorWorker:
                 remote_rank,
                 req_id,
             )
+            # TODO refactor properly and only do this for P TP > D TP
+            tp_ratio = self.kv_info.tp_ratio(meta.remote_engine_id)
+            # Get nixl desc handles depending on whether we're reading from
+            # multiple sources or reading a chunk of a single remote worker.
+            local_xfer_side_handle = self.src_xfer_side_chunked_handles[tp_ratio][i]
+            remote_xfer_side_handle = self.dst_xfer_side_handles[meta.remote_engine_id][remote_rank]  # TODO multiread; notifs to all twice?? SPLIT LOCAL BLOCKS!
             self._read_blocks(
                 request_id=req_id,
                 dst_engine_id=meta.remote_engine_id,
@@ -1462,6 +1512,8 @@ class NixlConnectorWorker:
                 local_block_ids=meta.local_block_ids,
                 remote_block_ids=meta.remote_block_ids,
                 remote_rank=remote_rank,
+                local_xfer_side_handle=local_xfer_side_handle,
+                remote_xfer_side_handle=remote_xfer_side_handle,
             )
 
     def _read_blocks(
@@ -1471,6 +1523,8 @@ class NixlConnectorWorker:
         dst_engine_id: str,
         request_id: str,
         remote_rank: int,
+        local_xfer_side_handle: int,
+        remote_xfer_side_handle: int,
     ):
         # NOTE(rob): having the staging blocks be on the READER side is
         # not going to work well (since we will have to call rearrange tensors).
@@ -1503,8 +1557,8 @@ class NixlConnectorWorker:
             remote_block_ids = remote_block_ids[-num_local_blocks:]
 
         # Get side handles.
-        local_xfer_side_handle = self.src_xfer_side_handle
-        remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id][remote_rank]
+        # The local/remote side handles are now passed in by the caller, since
+        # the local handle may be a per-tp_ratio chunked handle (P TP > D TP).
 
         # NOTE (nicolo) With homogeneous TP, each TP worker loads KV from
         # corresponding rank. With heterogeneous TP, fixing D>P, the D tp
@@ -1640,6 +1694,11 @@ class NixlConnectorWorker:
         if self.src_xfer_side_handle:
             self.nixl_wrapper.release_dlist_handle(self.src_xfer_side_handle)
             self.src_xfer_side_handle = 0
+        if self.src_xfer_side_chunked_handles:
+            for handles in self.src_xfer_side_chunked_handles.values():
+                for handle in handles:
+                    self.nixl_wrapper.release_dlist_handle(handle)
+            self.src_xfer_side_chunked_handles.clear()
         for dst_xfer_side_handles in self.dst_xfer_side_handles.values():
             for dst_xfer_side_handle in dst_xfer_side_handles.values():
                 self.nixl_wrapper.release_dlist_handle(dst_xfer_side_handle)
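
Reviewer note (not part of the patch): the rank mapping this change implements can be summarized standalone. The sketch below is illustrative only; it assumes exact divisibility between the two tensor-parallel sizes, and the helper names (tp_ratio, target_remote_ranks) mirror kv_info.tp_ratio / get_target_remote_ranks but are hypothetical, not the vLLM API.

def tp_ratio(local_tp: int, remote_tp: int) -> int:
    # Positive: number of D (decode) workers reading from one P (prefill)
    # worker, i.e. D TP >= P TP. Negative when P TP > D TP.
    if local_tp >= remote_tp:
        assert local_tp % remote_tp == 0
        return local_tp // remote_tp
    assert remote_tp % local_tp == 0
    return -(remote_tp // local_tp)


def target_remote_ranks(
    d_rank: int, local_tp: int, remote_tp: int, replicated: bool
) -> list[int]:
    ratio = tp_ratio(local_tp, remote_tp)
    if ratio > 0:
        # D TP >= P TP: each D worker reads (a kv-head slice of) one P worker.
        return [d_rank // ratio]
    ratio = -ratio
    if replicated:
        # Remote cache is replicated (e.g. MLA): one remote read is enough.
        return [d_rank * ratio]
    # P TP > D TP: read from |ratio| consecutive P workers, one chunk each.
    return [d_rank * ratio + i for i in range(ratio)]


# Example: D TP=2 pulling from P TP=8 without replication reads 4 P ranks;
# the reverse direction reads a slice of a single P rank.
assert target_remote_ranks(1, local_tp=2, remote_tp=8, replicated=False) == [4, 5, 6, 7]
assert target_remote_ranks(1, local_tp=8, remote_tp=2, replicated=False) == [0]

In the P TP > D TP case, each of those per-rank reads lands in one chunk of the local regions, which is what the new src_xfer_side_chunked_handles descriptors cover on the local side.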