From 1a1c81ca2f8c9c4ec1f4fdce9da57833803ccd0c Mon Sep 17 00:00:00 2001 From: NickLucche Date: Wed, 8 Oct 2025 10:29:47 +0000 Subject: [PATCH 1/9] init Signed-off-by: NickLucche --- .../kv_connector/v1/nixl_connector.py | 197 ++++++++++++------ 1 file changed, 128 insertions(+), 69 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index e3e3389fd1643..43c446d7f912b 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -36,7 +36,6 @@ from vllm.distributed.parallel_state import ( get_tensor_model_parallel_world_size, get_tp_group, ) -from vllm.distributed.utils import divide from vllm.forward_context import ForwardContext from vllm.logger import init_logger from vllm.platforms import current_platform @@ -504,6 +503,39 @@ class NixlConnectorScheduler: class NixlConnectorWorker: """Implementation of Worker side methods""" + @dataclass + class KVInfo: + tp_size: int + tp_rank: int + remote_tp_size: dict[EngineId, int] + is_mla: bool + total_num_kv_heads: int + + def tp_ratio( + self, + remote_engine_id: Optional[EngineId] = None, + remote_tp_size: Optional[int] = None, + ) -> int: + if remote_tp_size is None: + assert remote_engine_id is not None + remote_tp_size = self.remote_tp_size[remote_engine_id] + return self.tp_size // remote_tp_size + + def is_kv_replicated(self, remote_engine_id: EngineId) -> bool: + tp_size = self.remote_tp_size[remote_engine_id] + return tp_size // self.total_num_kv_heads >= 1 + + def replicates_kv_cache(self, remote_engine_id: EngineId) -> bool: + return self.is_mla or self.is_kv_replicated(remote_engine_id) + + def get_target_remote_rank( + self, + remote_engine_id: Optional[EngineId] = None, + remote_tp_size: Optional[int] = None, + ) -> int: + tp_ratio = self.tp_ratio(remote_engine_id, remote_tp_size) + return self.tp_rank // tp_ratio + def __init__(self, vllm_config: VllmConfig, engine_id: str): if NixlWrapper is None: logger.error("NIXL is not available") @@ -627,7 +659,6 @@ class NixlConnectorWorker: # Protects _handshake_futures and _remote_agents. self._handshake_lock = threading.RLock() - self.vllm_config = vllm_config self.block_size = vllm_config.cache_config.block_size self.model_config = vllm_config.model_config self.cache_config = vllm_config.cache_config @@ -659,6 +690,14 @@ class NixlConnectorWorker: self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int) self.xfer_stats = NixlKVConnectorStats() + self.kv_info = self.KVInfo( + tp_size=self.world_size, + tp_rank=self.tp_rank, + remote_tp_size=self._tp_size, # shared state + is_mla=self.use_mla, + total_num_kv_heads=self.model_config.get_total_num_kv_heads(), + ) + @staticmethod def _nixl_handshake_listener( metadata: NixlAgentMetadata, @@ -704,8 +743,9 @@ class NixlConnectorWorker: # Handshake only with the remote TP rank that current local rank will # pull from. With homogeneous TP it happens to be the same rank_i. 
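# Illustrative usage of the KVInfo helper introduced above (sizes and the
# engine id are assumed examples, not taken from this change): a decode
# worker at TP=4, rank 3, pulling from a prefill deployment at TP=2 with 8
# total KV heads:
#   info = NixlConnectorWorker.KVInfo(tp_size=4, tp_rank=3,
#                                     remote_tp_size={"prefill": 2},
#                                     is_mla=False, total_num_kv_heads=8)
#   info.tp_ratio("prefill")               -> 4 // 2 == 2
#   info.get_target_remote_rank("prefill") -> 3 // 2 == 1 (pulls from P rank 1)
#   info.is_kv_replicated("prefill")       -> False (2 // 8 == 0)
#   info.replicates_kv_cache("prefill")    -> False (not MLA, not replicated)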
- tp_ratio = self._tp_size[self.engine_id] // remote_tp_size - p_remote_rank = self.tp_rank // tp_ratio + p_remote_rank = self.kv_info.get_target_remote_rank( + remote_tp_size=remote_tp_size + ) path = make_zmq_path("tcp", host, port + p_remote_rank) logger.debug( "Querying metadata on path: %s at remote rank %s", path, p_remote_rank @@ -950,13 +990,11 @@ class NixlConnectorWorker: # TODO(mgoin): Hybrid memory allocator is currently disabled for # models with local attention (Llama 4). Can remove this once enabled. - if self.vllm_config.model_config.hf_config.model_type == "llama4": + if self.model_config.hf_config.model_type == "llama4": from transformers import Llama4TextConfig - assert isinstance( - self.vllm_config.model_config.hf_text_config, Llama4TextConfig - ) - llama4_config = self.vllm_config.model_config.hf_text_config + assert isinstance(self.model_config.hf_text_config, Llama4TextConfig) + llama4_config = self.model_config.hf_text_config no_rope_layers = llama4_config.no_rope_layers chunk_size = llama4_config.attention_chunk_size chunk_block_size = math.ceil(chunk_size / self.block_size) @@ -1039,87 +1077,50 @@ class NixlConnectorWorker: engine_id = nixl_agent_meta.engine_id # TODO re-evaluate refreshing for scaling/recovery if remote_tp_rank in self._remote_agents.get(engine_id, {}): + logger.debug( + "Remote agent with engine_id %s and rank" + "%s already exchanged metadata, skip handshake.", + engine_id, + remote_tp_rank, + ) return self._remote_agents[engine_id][remote_tp_rank] + ### Register remote agent metadata if engine_id not in self._tp_size: self._tp_size[engine_id] = remote_tp_size - else: - assert self._tp_size[engine_id] == remote_tp_size - # TODO We may eventually want to skip enforcing the same attn backend. - assert nixl_agent_meta.attn_backend_name == self.backend_name remote_agent_name = self.nixl_wrapper.add_remote_agent( nixl_agent_meta.agent_metadata ) - # Number of D TP workers reading from a single P TP worker. This is - # 1 when P and D `--tensor-parallel-size` match. - tp_ratio = divide(self._tp_size[self.engine_id], self._tp_size[engine_id]) - assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP" - assert not self._use_pallas or tp_ratio == 1, ( - "TPU (pallas_v1) DOES NOT support heterogeneous TP yet." - ) - # Handle tp_size>num_kv_heads: replicate KV cache. - total_num_kv_heads = self.model_config.get_total_num_kv_heads() - is_kv_replicated = self._tp_size[engine_id] // total_num_kv_heads >= 1 + replicates_kv_cache = self.kv_info.replicates_kv_cache(engine_id) - remote_block_len = nixl_agent_meta.block_lens[0] - if self.use_mla or is_kv_replicated: - # With replicated KV cache, only the number of blocks can differ. - assert self.block_len_per_layer == nixl_agent_meta.block_lens, ( - "KV cache sizes must match between P and D when replicated" - ) - remote_block_size = remote_block_len // (self.slot_size_per_layer[0]) - else: - # When MLA is not used, this is a list of the same block length - for block_len in nixl_agent_meta.block_lens: - assert block_len == remote_block_len, ( - "All remote layers must have the same block size" - ) - remote_block_size = remote_block_len // ( - self.slot_size_per_layer[0] * tp_ratio - ) - if self._use_flashinfer: - # With flashinfer, KV are sent in the same message. - remote_block_size //= 2 - if tp_ratio > 1: - # Heterogeneous TP expects same kv_cache_layout. 
- assert nixl_agent_meta.kv_cache_layout == self.kv_cache_layout - if self.device_type == "xpu": - raise ValueError("Heterogeneous TP is not supported on XPU") - - assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, ( - "Remote P worker KV layer cache must be of shape [2, N, " - "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype." - ) - - assert self.block_size == remote_block_size, ( - "Remote P worker with different page/block size is not supported " - f"{self.block_size=}, {remote_block_size=}" - ) - - # Create dst descs and xfer side handles. TP workers have same #blocks. - if engine_id in self.dst_num_blocks: - assert self.dst_num_blocks[engine_id] == nixl_agent_meta.num_blocks - else: + # Create dst descs and xfer side handles. TP workers have same #blocks + # so we only register once per engine_id. + if engine_id not in self.dst_num_blocks: self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks + # Keep track of remote agent kv caches base addresses. + self.kv_caches_base_addr[engine_id] = nixl_agent_meta.kv_caches_base_addr + self._validate_remote_agent_handshake(nixl_agent_meta, remote_tp_size) + + # Number of D TP workers reading from a single P TP worker. This is + # 1 when P and D `--tensor-parallel-size` match. + tp_ratio = self.kv_info.tp_ratio(engine_id) + + ### Register remote agent memory regions blocks_data = [] # With homogeneous TP, D pulls the whole kv cache from corresponding # rank. With heterogeneous TP, prepare the descriptors by splitting the # P KV cache along kv_head dim, of D worker's kv_head size (D>P). # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..]. - self.kv_caches_base_addr[engine_id] = nixl_agent_meta.kv_caches_base_addr - assert len(nixl_agent_meta.kv_caches_base_addr) == len(self.block_len_per_layer) # Register all remote blocks, but only the corresponding kv heads. for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr): kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i) rank_offset = ( - self.tp_rank % tp_ratio * kv_block_len - if not (self.use_mla or is_kv_replicated) - else 0 + self.tp_rank % tp_ratio * kv_block_len if not replicates_kv_cache else 0 ) for block_id in range(nixl_agent_meta.num_blocks): block_offset = block_id * nixl_agent_meta.block_lens[i] @@ -1154,6 +1155,64 @@ class NixlConnectorWorker: return remote_agent_name + def _validate_remote_agent_handshake( + self, nixl_agent_meta: NixlAgentMetadata, remote_tp_size: int + ): + """ + Validate the remote agent handshake metadata ensuring the + invariants hold true. + """ + remote_engine_id = nixl_agent_meta.engine_id + + assert self._tp_size[remote_engine_id] == remote_tp_size + # TODO We may eventually want to skip enforcing the same attn backend. + assert nixl_agent_meta.attn_backend_name == self.backend_name + assert nixl_agent_meta.kv_cache_layout == self.kv_cache_layout + + tp_ratio = self.kv_info.tp_ratio(remote_engine_id) + assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP" + assert not self._use_pallas or tp_ratio == 1, ( + "TPU (pallas_v1) DOES NOT support heterogeneous TP yet." + ) + if tp_ratio > 1 and self.device_type == "xpu": + raise ValueError("Heterogeneous TP is not supported on XPU") + + # Block len can only vary across layers when using MLA. + remote_block_len = nixl_agent_meta.block_lens[0] + if self.use_mla or self.kv_info.is_kv_replicated(remote_engine_id): + # With replicated KV cache, only the number of blocks can differ. 
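# Illustrative example (assumed values): with total_num_kv_heads=8, a remote
# deployment at TP=16 stores all 8 heads on every rank (16 // 8 >= 1), so the
# cache is replicated and local/remote per-layer block lengths must match
# exactly; the same holds for MLA, whose latent KV is never split across ranks.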
+ assert self.block_len_per_layer == nixl_agent_meta.block_lens, ( + "KV cache sizes must match between P and D when replicated" + ) + remote_block_size = remote_block_len // (self.slot_size_per_layer[0]) + else: + # When MLA is not used, this is a list of the same block length + for block_len in nixl_agent_meta.block_lens: + assert block_len == remote_block_len, ( + "All remote layers must have the same block size" + ) + remote_block_size = remote_block_len // ( + self.slot_size_per_layer[0] * tp_ratio + ) + if self._use_flashinfer: + # With flashinfer, KV are sent in the same message. + remote_block_size //= 2 + + assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, ( + "Remote P worker KV layer cache must be of shape [2, N, " + "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype." + ) + + assert self.block_size == remote_block_size, ( + "Remote P worker with different page/block size is not supported " + f"{self.block_size=}, {remote_block_size=}" + ) + + # TP workers have same #blocks. + assert self.dst_num_blocks[remote_engine_id] == nixl_agent_meta.num_blocks + + assert len(nixl_agent_meta.kv_caches_base_addr) == len(self.block_len_per_layer) + def sync_recved_kv_to_device(self, req_id: str, meta: ReqMeta): """copy recved kv from host buffer to device.""" assert self.use_host_buffer @@ -1383,14 +1442,14 @@ class NixlConnectorWorker: # Number of D TP workers that will read from dst P. Propagate tp_ratio # on notification so that dst worker can wait before freeing blocks. - tp_ratio = self._tp_size[self.engine_id] // self._tp_size[dst_engine_id] + tp_ratio = self.kv_info.tp_ratio(dst_engine_id) notif_id = f"{request_id}:{tp_ratio}".encode() # Full prefix cache hit: do not need to read remote blocks, # just notify P worker that we have the blocks we need. num_local_blocks = len(local_block_ids) if num_local_blocks == 0: - remote_rank = self.tp_rank // tp_ratio + remote_rank = self.kv_info.get_target_remote_rank(dst_engine_id) agent_name = self._remote_agents[dst_engine_id][remote_rank] self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id) return From 84dfd367a11862a65e60ccab4809924d9e9ec314 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Wed, 8 Oct 2025 12:56:04 +0000 Subject: [PATCH 2/9] review Signed-off-by: NickLucche --- .../kv_transfer/kv_connector/v1/nixl_connector.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 43c446d7f912b..ca04f5565411c 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -519,6 +519,10 @@ class NixlConnectorWorker: if remote_tp_size is None: assert remote_engine_id is not None remote_tp_size = self.remote_tp_size[remote_engine_id] + assert self.tp_size % remote_tp_size == 0, ( + f"Local tensor parallel size {self.tp_size} is not divisible " + f"by remote tensor parallel size {remote_tp_size}." + ) return self.tp_size // remote_tp_size def is_kv_replicated(self, remote_engine_id: EngineId) -> bool: @@ -1174,8 +1178,6 @@ class NixlConnectorWorker: assert not self._use_pallas or tp_ratio == 1, ( "TPU (pallas_v1) DOES NOT support heterogeneous TP yet." ) - if tp_ratio > 1 and self.device_type == "xpu": - raise ValueError("Heterogeneous TP is not supported on XPU") # Block len can only vary across layers when using MLA. 
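# Illustrative numbers (assumed, non-flashinfer) for the checks below, with
# D TP=4 and P TP=2 (tp_ratio=2), block_size=16, head_dim=128, fp16 and 8
# total KV heads, i.e. 2 heads per D rank and 4 heads per P rank:
#   slot_size_per_layer[0]  = 2 * 128 * 2 B             = 512 B
#   block_len_per_layer[0]  = 16 * 512 B                = 8192 B
#   remote block_lens[0]    = 16 * 4 * 128 * 2 B        = 16384 B (== 8192 * tp_ratio)
#   remote_block_size       = 16384 // (512 * tp_ratio) = 16      (== self.block_size)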
remote_block_len = nixl_agent_meta.block_lens[0] @@ -1186,6 +1188,9 @@ class NixlConnectorWorker: ) remote_block_size = remote_block_len // (self.slot_size_per_layer[0]) else: + if tp_ratio > 1 and self.device_type == "xpu": + # XPU uses NHD, hence it does not support splitting on H + raise ValueError("Heterogeneous TP is not supported on XPU") # When MLA is not used, this is a list of the same block length for block_len in nixl_agent_meta.block_lens: assert block_len == remote_block_len, ( From 5d45b77124a1e8e464274a1e81d4c3106f02597a Mon Sep 17 00:00:00 2001 From: NickLucche Date: Wed, 8 Oct 2025 14:12:35 +0000 Subject: [PATCH 3/9] docs Signed-off-by: NickLucche --- .../kv_connector/v1/nixl_connector.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index ca04f5565411c..ca53d1df92aec 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -516,6 +516,12 @@ class NixlConnectorWorker: remote_engine_id: Optional[EngineId] = None, remote_tp_size: Optional[int] = None, ) -> int: + """ + Calculate the tensor parallel ratio between local and remote TP. + We can think of it as the number of local TP workers-per-remote TP + workers. Local workers will read from the same remote TP worker in + groups of size `tp_ratio`. + """ if remote_tp_size is None: assert remote_engine_id is not None remote_tp_size = self.remote_tp_size[remote_engine_id] @@ -525,11 +531,16 @@ class NixlConnectorWorker: ) return self.tp_size // remote_tp_size - def is_kv_replicated(self, remote_engine_id: EngineId) -> bool: - tp_size = self.remote_tp_size[remote_engine_id] + def is_kv_replicated(self, engine_id: EngineId) -> bool: + """ + Whether the KV cache is replicated across TP workers due to the + number of TP workers being greater than the number of KV heads. + """ + tp_size = self.remote_tp_size[engine_id] return tp_size // self.total_num_kv_heads >= 1 def replicates_kv_cache(self, remote_engine_id: EngineId) -> bool: + # MLA is always replicated as the hidden dim can't be split. return self.is_mla or self.is_kv_replicated(remote_engine_id) def get_target_remote_rank( @@ -537,6 +548,10 @@ class NixlConnectorWorker: remote_engine_id: Optional[EngineId] = None, remote_tp_size: Optional[int] = None, ) -> int: + """ + Get the remote TP rank (on P) that the current local TP rank + (on D) will read from. 
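        E.g. (illustrative sizes) with local TP=4 and remote TP=2, tp_ratio
        is 2: local ranks 0 and 1 read from remote rank 0, while local ranks
        2 and 3 read from remote rank 1.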
+ """ tp_ratio = self.tp_ratio(remote_engine_id, remote_tp_size) return self.tp_rank // tp_ratio From 684c9b7b6dc127928930483ec93de48d2ce532ce Mon Sep 17 00:00:00 2001 From: NickLucche Date: Wed, 8 Oct 2025 17:50:07 +0000 Subject: [PATCH 4/9] init Signed-off-by: NickLucche --- .../kv_connector/v1/nixl_connector.py | 202 ++++++++++-------- 1 file changed, 113 insertions(+), 89 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index ca53d1df92aec..762e47b7ae35a 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -525,11 +525,15 @@ class NixlConnectorWorker: if remote_tp_size is None: assert remote_engine_id is not None remote_tp_size = self.remote_tp_size[remote_engine_id] - assert self.tp_size % remote_tp_size == 0, ( - f"Local tensor parallel size {self.tp_size} is not divisible " - f"by remote tensor parallel size {remote_tp_size}." - ) - return self.tp_size // remote_tp_size + if self.tp_size >= remote_tp_size: + assert self.tp_size % remote_tp_size == 0, ( + f"Local tensor parallel size {self.tp_size} is not divisible " + f"by remote tensor parallel size {remote_tp_size}." + ) + return self.tp_size // remote_tp_size + else: + # P TP > D TP case, return the ratio as negative + return -remote_tp_size // self.tp_size def is_kv_replicated(self, engine_id: EngineId) -> bool: """ @@ -538,22 +542,29 @@ class NixlConnectorWorker: """ tp_size = self.remote_tp_size[engine_id] return tp_size // self.total_num_kv_heads >= 1 - + def replicates_kv_cache(self, remote_engine_id: EngineId) -> bool: # MLA is always replicated as the hidden dim can't be split. - return self.is_mla or self.is_kv_replicated(remote_engine_id) + # TODO docs + remote_tp_size = self.remote_tp_size[remote_engine_id] + return self.is_mla or self.is_kv_replicated(remote_engine_id) or self.tp_size < remote_tp_size - def get_target_remote_rank( + def get_target_remote_ranks( self, remote_engine_id: Optional[EngineId] = None, remote_tp_size: Optional[int] = None, - ) -> int: + ) -> list[int]: """ Get the remote TP rank (on P) that the current local TP rank (on D) will read from. """ tp_ratio = self.tp_ratio(remote_engine_id, remote_tp_size) - return self.tp_rank // tp_ratio + if tp_ratio > 0: + return [self.tp_rank // tp_ratio] + else: + # P TP > D TP case, D reads from |tp_ratio| remote workers. + tp_ratio = -tp_ratio + return [self.tp_rank*tp_ratio + i for i in range(tp_ratio)] def __init__(self, vllm_config: VllmConfig, engine_id: str): if NixlWrapper is None: @@ -638,8 +649,8 @@ class NixlConnectorWorker: self.copy_blocks: Optional[CopyBlocksOp] = None # Map of engine_id -> kv_caches_base_addr. For TP case, each local - # rank will still only pull from a single remote TP worker. - self.kv_caches_base_addr: dict[EngineId, list[int]] = {} + # rank may pull from multiple remote TP workers. + self.kv_caches_base_addr: defaultdict[EngineId, dict[int, list[int]]] = defaultdict(dict) # Number of NIXL regions. Currently one region per cache # (so 1 per layer for MLA, otherwise 2 per layer) @@ -649,7 +660,8 @@ class NixlConnectorWorker: # nixl_prepped_dlist_handle. self.src_xfer_side_handle: int = 0 # Map of engine_id -> nixl_prepped_dlist_handle (int)]. - self.dst_xfer_side_handles: dict[EngineId, int] = {} + # TODO do I need tp_Ratio of this? 
+ self.dst_xfer_side_handles: defaultdict[EngineId, dict[int, int]] = defaultdict(dict) # Map of engine_id -> num_blocks. All ranks in the same deployment will # have the same number of blocks. @@ -756,51 +768,52 @@ class NixlConnectorWorker: start_time = time.perf_counter() - # NOTE(rob): we need each rank to have a unique port. This is - # a hack to keep us moving. We will switch when moving to etcd - # or where we have a single ZMQ socket in the scheduler. - - # Handshake only with the remote TP rank that current local rank will - # pull from. With homogeneous TP it happens to be the same rank_i. - p_remote_rank = self.kv_info.get_target_remote_rank( + # When target instance TP > local TP, we need to perform multiple + # handshakes. Do it in a single background job for simplicity. + # Regardless, only handshake with the remote TP rank(s) that current + # local rank will read from. Note that With homogeneous TP, + # this happens to be the same single rank_i. + p_remote_ranks = self.kv_info.get_target_remote_ranks( remote_tp_size=remote_tp_size ) - path = make_zmq_path("tcp", host, port + p_remote_rank) - logger.debug( - "Querying metadata on path: %s at remote rank %s", path, p_remote_rank - ) - - # Send query for the request. - with zmq_ctx(zmq.REQ, path) as sock: - sock.send(GET_META_MSG) - metadata_bytes = sock.recv() - decoder = msgspec.msgpack.Decoder(NixlAgentMetadata) - metadata = decoder.decode(metadata_bytes) - got_metadata_time = time.perf_counter() - logger.debug( - "NIXL handshake: get metadata took: %s", got_metadata_time - start_time + remote_rank_to_agent_name = {} + for remote_rank in p_remote_ranks: + path = make_zmq_path("tcp", host, port + remote_rank) + logger.warning( + "Querying metadata on path: %s at remote rank %s", path, remote_rank ) - # Ensure engine id matches. - if metadata.engine_id != expected_engine_id: - raise RuntimeError( - f"Remote NIXL agent engine ID mismatch. " - f"Expected {expected_engine_id}," - f"received {metadata.engine_id}." + # Send query for the request. + with zmq_ctx(zmq.REQ, path) as sock: + sock.send(GET_META_MSG) + metadata_bytes = sock.recv() + decoder = msgspec.msgpack.Decoder(NixlAgentMetadata) + metadata = decoder.decode(metadata_bytes) + got_metadata_time = time.perf_counter() + logger.debug( + "NIXL handshake: get metadata took: %s", got_metadata_time - start_time ) - # Register Remote agent. - remote_agent_name = self.add_remote_agent( - metadata, p_remote_rank, remote_tp_size - ) - setup_agent_time = time.perf_counter() - logger.debug( - "NIXL handshake: add agent took: %s", - setup_agent_time - got_metadata_time, - ) + # Ensure engine id matches. + if metadata.engine_id != expected_engine_id: + raise RuntimeError( + f"Remote NIXL agent engine ID mismatch. " + f"Expected {expected_engine_id}," + f"received {metadata.engine_id}." + ) - # Remote rank -> agent name. - return {p_remote_rank: remote_agent_name} + # Register Remote agent. 
+ remote_agent_name = self.add_remote_agent( + metadata, remote_rank, remote_tp_size + ) + setup_agent_time = time.perf_counter() + logger.debug( + "NIXL handshake: add agent took: %s", + setup_agent_time - got_metadata_time, + ) + remote_rank_to_agent_name[remote_rank] = remote_agent_name + + return remote_rank_to_agent_name def initialize_host_xfer_buffer(self, kv_caches: dict[str, torch.Tensor]) -> None: """ @@ -944,7 +957,7 @@ class NixlConnectorWorker: assert len(self.block_len_per_layer) == len(seen_base_addresses) assert self.num_blocks != 0 - self.kv_caches_base_addr[self.engine_id] = seen_base_addresses + self.kv_caches_base_addr[self.engine_id][self.tp_rank] = seen_base_addresses self.num_regions = len(caches_data) self.num_layers = len(xfer_buffers.keys()) @@ -1033,7 +1046,7 @@ class NixlConnectorWorker: metadata = NixlAgentMetadata( engine_id=self.engine_id, agent_metadata=self.nixl_wrapper.get_agent_metadata(), - kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id], + kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id][self.tp_rank], num_blocks=self.num_blocks, block_lens=self.block_len_per_layer, attn_backend_name=self.backend_name, @@ -1096,7 +1109,7 @@ class NixlConnectorWorker: engine_id = nixl_agent_meta.engine_id # TODO re-evaluate refreshing for scaling/recovery if remote_tp_rank in self._remote_agents.get(engine_id, {}): - logger.debug( + logger.warning( "Remote agent with engine_id %s and rank" "%s already exchanged metadata, skip handshake.", engine_id, @@ -1114,6 +1127,7 @@ class NixlConnectorWorker: # Handle tp_size>num_kv_heads: replicate KV cache. replicates_kv_cache = self.kv_info.replicates_kv_cache(engine_id) + print("REPLICATES KV CACHE", replicates_kv_cache, "\n") # Create dst descs and xfer side handles. TP workers have same #blocks # so we only register once per engine_id. @@ -1121,11 +1135,12 @@ class NixlConnectorWorker: self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks # Keep track of remote agent kv caches base addresses. - self.kv_caches_base_addr[engine_id] = nixl_agent_meta.kv_caches_base_addr + self.kv_caches_base_addr[engine_id][self.tp_rank] = nixl_agent_meta.kv_caches_base_addr self._validate_remote_agent_handshake(nixl_agent_meta, remote_tp_size) # Number of D TP workers reading from a single P TP worker. This is - # 1 when P and D `--tensor-parallel-size` match. + # 1 when P and D `--tensor-parallel-size` match. If P TP > D TP, + # we don't need to use this for spliting the remote kv cache. tp_ratio = self.kv_info.tp_ratio(engine_id) ### Register remote agent memory regions @@ -1137,7 +1152,8 @@ class NixlConnectorWorker: # Register all remote blocks, but only the corresponding kv heads. for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr): - kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i) + # TODO + kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i) // 2 rank_offset = ( self.tp_rank % tp_ratio * kv_block_len if not replicates_kv_cache else 0 ) @@ -1168,7 +1184,7 @@ class NixlConnectorWorker: # Register with NIXL. 
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type) - self.dst_xfer_side_handles[engine_id] = self.nixl_wrapper.prep_xfer_dlist( + self.dst_xfer_side_handles[engine_id][remote_tp_rank] = self.nixl_wrapper.prep_xfer_dlist( remote_agent_name, descs ) @@ -1189,7 +1205,6 @@ class NixlConnectorWorker: assert nixl_agent_meta.kv_cache_layout == self.kv_cache_layout tp_ratio = self.kv_info.tp_ratio(remote_engine_id) - assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP" assert not self._use_pallas or tp_ratio == 1, ( "TPU (pallas_v1) DOES NOT support heterogeneous TP yet." ) @@ -1217,21 +1232,20 @@ class NixlConnectorWorker: if self._use_flashinfer: # With flashinfer, KV are sent in the same message. remote_block_size //= 2 + # TODO add asserts for P TP > D TP + # assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, ( + # "Remote P worker KV layer cache must be of shape [2, N, " + # "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype." + # ) - assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, ( - "Remote P worker KV layer cache must be of shape [2, N, " - "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype." - ) + # assert self.block_size == remote_block_size, ( + # "Remote P worker with different page/block size is not supported " + # f"{self.block_size=}, {remote_block_size=}" + # ) - assert self.block_size == remote_block_size, ( - "Remote P worker with different page/block size is not supported " - f"{self.block_size=}, {remote_block_size=}" - ) - - # TP workers have same #blocks. - assert self.dst_num_blocks[remote_engine_id] == nixl_agent_meta.num_blocks - - assert len(nixl_agent_meta.kv_caches_base_addr) == len(self.block_len_per_layer) + # # TP workers have same #blocks. + # assert self.dst_num_blocks[remote_engine_id] == nixl_agent_meta.num_blocks + # assert len(nixl_agent_meta.kv_caches_base_addr) == len(self.block_len_per_layer) def sync_recved_kv_to_device(self, req_id: str, meta: ReqMeta): """copy recved kv from host buffer to device.""" @@ -1431,17 +1445,24 @@ class NixlConnectorWorker: self._reqs_to_send[req_id] = expiration_time def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): - logger.debug( - "Remote agent %s available, calling _read_blocks for req %s", - meta.remote_engine_id, - req_id, - ) - self._read_blocks( - request_id=req_id, - dst_engine_id=meta.remote_engine_id, - local_block_ids=meta.local_block_ids, - remote_block_ids=meta.remote_block_ids, - ) + remote_ranks = self.kv_info.get_target_remote_ranks(meta.remote_engine_id) + # D may perform multiple reads from different remote ranks. + for remote_rank in remote_ranks: + logger.debug( + "Remote agent %s available, calling _read_blocks" + " on remote rank %s for req %s", + meta.remote_engine_id, + remote_rank, + req_id, + ) + # TODO multiread; notifs to all twice?? SPLIT LOCAL BLOCKS! + self._read_blocks( + request_id=req_id, + dst_engine_id=meta.remote_engine_id, + local_block_ids=meta.local_block_ids, + remote_block_ids=meta.remote_block_ids, + remote_rank=remote_rank, + ) def _read_blocks( self, @@ -1449,6 +1470,7 @@ class NixlConnectorWorker: remote_block_ids: list[int], dst_engine_id: str, request_id: str, + remote_rank: int, ): # NOTE(rob): having the staging blocks be on the READER side is # not going to work well (since we will have to call rearrange tensors). @@ -1462,14 +1484,14 @@ class NixlConnectorWorker: # Number of D TP workers that will read from dst P. 
Propagate tp_ratio # on notification so that dst worker can wait before freeing blocks. - tp_ratio = self.kv_info.tp_ratio(dst_engine_id) + # Cap to 1 when P TP > D TP: only a single rank will read from remote. + tp_ratio = max(1, self.kv_info.tp_ratio(dst_engine_id)) notif_id = f"{request_id}:{tp_ratio}".encode() # Full prefix cache hit: do not need to read remote blocks, - # just notify P worker that we have the blocks we need. + # just notify P worker(s) that we have the blocks we need. num_local_blocks = len(local_block_ids) if num_local_blocks == 0: - remote_rank = self.kv_info.get_target_remote_rank(dst_engine_id) agent_name = self._remote_agents[dst_engine_id][remote_rank] self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id) return @@ -1482,7 +1504,7 @@ class NixlConnectorWorker: # Get side handles. local_xfer_side_handle = self.src_xfer_side_handle - remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id] + remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id][remote_rank] # NOTE (nicolo) With homogeneous TP, each TP worker loads KV from # corresponding rank. With heterogeneous TP, fixing D>P, the D tp @@ -1571,6 +1593,7 @@ class NixlConnectorWorker: assert self.num_layers == self.num_regions region_ids = np.arange(layer_idx, layer_idx + 1) + # TODO can this vary? num_blocks = self.dst_num_blocks[engine_id] # Compute the desc ids for each block. @@ -1617,8 +1640,9 @@ class NixlConnectorWorker: if self.src_xfer_side_handle: self.nixl_wrapper.release_dlist_handle(self.src_xfer_side_handle) self.src_xfer_side_handle = 0 - for dst_xfer_side_handle in self.dst_xfer_side_handles.values(): - self.nixl_wrapper.release_dlist_handle(dst_xfer_side_handle) + for dst_xfer_side_handles in self.dst_xfer_side_handles.values(): + for dst_xfer_side_handle in dst_xfer_side_handles.values(): + self.nixl_wrapper.release_dlist_handle(dst_xfer_side_handle) self.dst_xfer_side_handles.clear() for remote_agents in self._remote_agents.values(): for agent_name in remote_agents.values(): From 7bb3861faf17b6d6650f8287d2e43c82ac369128 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Thu, 9 Oct 2025 13:24:04 +0000 Subject: [PATCH 5/9] hacky Signed-off-by: NickLucche --- .../kv_connector/v1/nixl_connector.py | 121 +++++++++++++----- 1 file changed, 90 insertions(+), 31 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 762e47b7ae35a..1871138b088ca 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -532,22 +532,29 @@ class NixlConnectorWorker: ) return self.tp_size // remote_tp_size else: + assert remote_tp_size % self.tp_size == 0, ( + f"Remote tensor parallel size {remote_tp_size} is not divisible " + f"by local tensor parallel size {self.tp_size}." + ) # P TP > D TP case, return the ratio as negative return -remote_tp_size // self.tp_size - def is_kv_replicated(self, engine_id: EngineId) -> bool: + def is_kv_replicated(self, engine_id: Optional[EngineId] = None, tp_size: Optional[int] = None) -> bool: """ Whether the KV cache is replicated across TP workers due to the number of TP workers being greater than the number of KV heads. 
""" - tp_size = self.remote_tp_size[engine_id] + if tp_size is None: + assert engine_id is not None + tp_size = self.remote_tp_size[engine_id] return tp_size // self.total_num_kv_heads >= 1 - - def replicates_kv_cache(self, remote_engine_id: EngineId) -> bool: + + def replicates_kv_cache(self, remote_engine_id: Optional[EngineId] = None, remote_tp_size: Optional[int] = None) -> bool: # MLA is always replicated as the hidden dim can't be split. - # TODO docs - remote_tp_size = self.remote_tp_size[remote_engine_id] - return self.is_mla or self.is_kv_replicated(remote_engine_id) or self.tp_size < remote_tp_size + return ( + self.is_mla + or self.is_kv_replicated(remote_engine_id, remote_tp_size) + ) def get_target_remote_ranks( self, @@ -564,7 +571,11 @@ class NixlConnectorWorker: else: # P TP > D TP case, D reads from |tp_ratio| remote workers. tp_ratio = -tp_ratio - return [self.tp_rank*tp_ratio + i for i in range(tp_ratio)] + if self.replicates_kv_cache(remote_engine_id, remote_tp_size): + # When cache is replicated on remote, we only need to read + # from one remote. + return [self.tp_rank*tp_ratio] + return [self.tp_rank * tp_ratio + i for i in range(tp_ratio)] def __init__(self, vllm_config: VllmConfig, engine_id: str): if NixlWrapper is None: @@ -650,7 +661,9 @@ class NixlConnectorWorker: # Map of engine_id -> kv_caches_base_addr. For TP case, each local # rank may pull from multiple remote TP workers. - self.kv_caches_base_addr: defaultdict[EngineId, dict[int, list[int]]] = defaultdict(dict) + self.kv_caches_base_addr: defaultdict[EngineId, dict[int, list[int]]] = ( + defaultdict(dict) + ) # Number of NIXL regions. Currently one region per cache # (so 1 per layer for MLA, otherwise 2 per layer) @@ -659,9 +672,15 @@ class NixlConnectorWorker: # nixl_prepped_dlist_handle. self.src_xfer_side_handle: int = 0 + # TODO flexible enough to handle different P TP destinations? + # tp_ratio->handles + # Only poulated during handshake when we read from multiple sources + self.src_xfer_side_chunked_handles: dict[int, list[int]] = {} # Map of engine_id -> nixl_prepped_dlist_handle (int)]. # TODO do I need tp_Ratio of this? - self.dst_xfer_side_handles: defaultdict[EngineId, dict[int, int]] = defaultdict(dict) + self.dst_xfer_side_handles: defaultdict[EngineId, dict[int, int]] = defaultdict( + dict + ) # Map of engine_id -> num_blocks. All ranks in the same deployment will # have the same number of blocks. @@ -771,7 +790,7 @@ class NixlConnectorWorker: # When target instance TP > local TP, we need to perform multiple # handshakes. Do it in a single background job for simplicity. # Regardless, only handshake with the remote TP rank(s) that current - # local rank will read from. Note that With homogeneous TP, + # local rank will read from. Note that With homogeneous TP, # this happens to be the same single rank_i. p_remote_ranks = self.kv_info.get_target_remote_ranks( remote_tp_size=remote_tp_size @@ -791,7 +810,8 @@ class NixlConnectorWorker: metadata = decoder.decode(metadata_bytes) got_metadata_time = time.perf_counter() logger.debug( - "NIXL handshake: get metadata took: %s", got_metadata_time - start_time + "NIXL handshake: get metadata took: %s", + got_metadata_time - start_time, ) # Ensure engine id matches. 
@@ -812,7 +832,6 @@ class NixlConnectorWorker: setup_agent_time - got_metadata_time, ) remote_rank_to_agent_name[remote_rank] = remote_agent_name - return remote_rank_to_agent_name def initialize_host_xfer_buffer(self, kv_caches: dict[str, torch.Tensor]) -> None: @@ -983,7 +1002,7 @@ class NixlConnectorWorker: self.num_regions *= 2 # Register local/src descr for NIXL xfer. - blocks_data = [] + self.src_blocks_data = [] for i, base_addr in enumerate(seen_base_addresses): kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i) # NOTE With heter-TP, more blocks are prepared than what are @@ -995,7 +1014,7 @@ class NixlConnectorWorker: block_offset = block_id * self.block_len_per_layer[i] addr = base_addr + block_offset # (addr, len, device id) - blocks_data.append((addr, kv_block_len, self.tp_rank)) + self.src_blocks_data.append((addr, kv_block_len, self.tp_rank)) if self._use_flashinfer: # Separate and interleave K/V regions to maintain the same @@ -1006,15 +1025,15 @@ class NixlConnectorWorker: addr = base_addr + block_offset # Register addresses for V cache (K registered first). v_addr = addr + kv_block_len - blocks_data.append((v_addr, kv_block_len, self.tp_rank)) + self.src_blocks_data.append((v_addr, kv_block_len, self.tp_rank)) logger.debug( "Created %s blocks for src engine %s and rank %s", - len(blocks_data), + len(self.src_blocks_data), self.engine_id, self.tp_rank, ) - descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type) + descs = self.nixl_wrapper.get_xfer_descs(self.src_blocks_data, self.nixl_memory_type) # NIXL_INIT_AGENT to be used for preparations of local descs. self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist( "NIXL_INIT_AGENT", descs @@ -1125,24 +1144,49 @@ class NixlConnectorWorker: nixl_agent_meta.agent_metadata ) - # Handle tp_size>num_kv_heads: replicate KV cache. - replicates_kv_cache = self.kv_info.replicates_kv_cache(engine_id) - print("REPLICATES KV CACHE", replicates_kv_cache, "\n") - # Create dst descs and xfer side handles. TP workers have same #blocks # so we only register once per engine_id. if engine_id not in self.dst_num_blocks: self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks # Keep track of remote agent kv caches base addresses. - self.kv_caches_base_addr[engine_id][self.tp_rank] = nixl_agent_meta.kv_caches_base_addr + self.kv_caches_base_addr[engine_id][remote_tp_rank] = ( + nixl_agent_meta.kv_caches_base_addr + ) self._validate_remote_agent_handshake(nixl_agent_meta, remote_tp_size) # Number of D TP workers reading from a single P TP worker. This is # 1 when P and D `--tensor-parallel-size` match. If P TP > D TP, - # we don't need to use this for spliting the remote kv cache. + # we don't need to use this for splitting the remote kv cache. tp_ratio = self.kv_info.tp_ratio(engine_id) + # Handle tp_size>num_kv_heads: replicate KV cache. + indexes_into_remote = (not self.kv_info.replicates_kv_cache(engine_id) \ + and tp_ratio < 0) + + # When you realize you're in P TP>DTP you have to split your regions + if tp_ratio < 0 and tp_ratio not in self.src_xfer_side_chunked_handles: + # TODO use positive tp_ratio value? + self.src_xfer_side_chunked_handles[tp_ratio] = [] + # This is still needed even for MLA + # TODO actually only needs one!! 
+ # Check if we have a split we can re-use, ie a remote P with same tp_ratio + for i in range(-tp_ratio): + blocks_data = [] + for memory_region in self.src_blocks_data: + addr, local_block_len, own_tp_rank = memory_region + # Computing block len layer by layer allow for different + # block sizes per layer + # TODO this needs to be an assert when validating + # TODO is this the right dim we're splitting on? H? + remote_block_len = local_block_len//(-tp_ratio) + # Offset + addr = addr + i * remote_block_len + blocks_data.append((addr, remote_block_len, own_tp_rank)) # TODO same tp_rank? + descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type) + handle = self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs) + self.src_xfer_side_chunked_handles[tp_ratio].append(handle) + ### Register remote agent memory regions blocks_data = [] # With homogeneous TP, D pulls the whole kv cache from corresponding @@ -1152,10 +1196,10 @@ class NixlConnectorWorker: # Register all remote blocks, but only the corresponding kv heads. for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr): - # TODO + # TODO workaround kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i) // 2 rank_offset = ( - self.tp_rank % tp_ratio * kv_block_len if not replicates_kv_cache else 0 + self.tp_rank % tp_ratio * kv_block_len if indexes_into_remote else 0 ) for block_id in range(nixl_agent_meta.num_blocks): block_offset = block_id * nixl_agent_meta.block_lens[i] @@ -1184,8 +1228,8 @@ class NixlConnectorWorker: # Register with NIXL. descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type) - self.dst_xfer_side_handles[engine_id][remote_tp_rank] = self.nixl_wrapper.prep_xfer_dlist( - remote_agent_name, descs + self.dst_xfer_side_handles[engine_id][remote_tp_rank] = ( + self.nixl_wrapper.prep_xfer_dlist(remote_agent_name, descs) ) return remote_agent_name @@ -1447,7 +1491,7 @@ class NixlConnectorWorker: def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): remote_ranks = self.kv_info.get_target_remote_ranks(meta.remote_engine_id) # D may perform multiple reads from different remote ranks. - for remote_rank in remote_ranks: + for i, remote_rank in enumerate(remote_ranks): logger.debug( "Remote agent %s available, calling _read_blocks" " on remote rank %s for req %s", @@ -1455,6 +1499,12 @@ class NixlConnectorWorker: remote_rank, req_id, ) + # TODO refactor properly and ONLY DO THIS FOR PTP>DTP + tp_ratio = self.kv_info.tp_ratio(meta.remote_engine_id) + # Get nixl desc handles depending on whether we're reading from + # multiple sources or we're reading a chunk of + local_xfer_side_handle = self.src_xfer_side_chunked_handles[tp_ratio][i] + remote_xfer_side_handle = self.dst_xfer_side_handles[meta.remote_engine_id][remote_rank] # TODO multiread; notifs to all twice?? SPLIT LOCAL BLOCKS! self._read_blocks( request_id=req_id, @@ -1462,6 +1512,8 @@ class NixlConnectorWorker: local_block_ids=meta.local_block_ids, remote_block_ids=meta.remote_block_ids, remote_rank=remote_rank, + local_xfer_side_handle=local_xfer_side_handle, + remote_xfer_side_handle=remote_xfer_side_handle, ) def _read_blocks( @@ -1471,6 +1523,8 @@ class NixlConnectorWorker: dst_engine_id: str, request_id: str, remote_rank: int, + local_xfer_side_handle: int, + remote_xfer_side_handle: int, ): # NOTE(rob): having the staging blocks be on the READER side is # not going to work well (since we will have to call rearrange tensors). 
@@ -1503,8 +1557,8 @@ class NixlConnectorWorker: remote_block_ids = remote_block_ids[-num_local_blocks:] # Get side handles. - local_xfer_side_handle = self.src_xfer_side_handle - remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id][remote_rank] + # local_xfer_side_handle = self.src_xfer_side_handle + # remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id][remote_rank] # NOTE (nicolo) With homogeneous TP, each TP worker loads KV from # corresponding rank. With heterogeneous TP, fixing D>P, the D tp @@ -1640,6 +1694,11 @@ class NixlConnectorWorker: if self.src_xfer_side_handle: self.nixl_wrapper.release_dlist_handle(self.src_xfer_side_handle) self.src_xfer_side_handle = 0 + if self.src_xfer_side_chunked_handles: + for handles in self.src_xfer_side_chunked_handles.values(): + for handle in handles: + self.nixl_wrapper.release_dlist_handle(handle) + self.src_xfer_side_chunked_handles.clear() for dst_xfer_side_handles in self.dst_xfer_side_handles.values(): for dst_xfer_side_handle in dst_xfer_side_handles.values(): self.nixl_wrapper.release_dlist_handle(dst_xfer_side_handle) From 9f38fed93ccd57a4d152832ee30878eca0654fd6 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Thu, 9 Oct 2025 15:43:43 +0000 Subject: [PATCH 6/9] clean up Signed-off-by: NickLucche --- .../kv_connector/v1/nixl_connector.py | 186 +++++++++++------- 1 file changed, 111 insertions(+), 75 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 1871138b088ca..66d692e7387da 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -520,7 +520,9 @@ class NixlConnectorWorker: Calculate the tensor parallel ratio between local and remote TP. We can think of it as the number of local TP workers-per-remote TP workers. Local workers will read from the same remote TP worker in - groups of size `tp_ratio`. + groups of size `tp_ratio`. If remote tp_size > local tp_size, the + ratio is flipped (remote_size/local_size) and the returned value is + negative. """ if remote_tp_size is None: assert remote_engine_id is not None @@ -539,7 +541,9 @@ class NixlConnectorWorker: # P TP > D TP case, return the ratio as negative return -remote_tp_size // self.tp_size - def is_kv_replicated(self, engine_id: Optional[EngineId] = None, tp_size: Optional[int] = None) -> bool: + def is_kv_replicated( + self, engine_id: Optional[EngineId] = None, tp_size: Optional[int] = None + ) -> bool: """ Whether the KV cache is replicated across TP workers due to the number of TP workers being greater than the number of KV heads. @@ -549,11 +553,14 @@ class NixlConnectorWorker: tp_size = self.remote_tp_size[engine_id] return tp_size // self.total_num_kv_heads >= 1 - def replicates_kv_cache(self, remote_engine_id: Optional[EngineId] = None, remote_tp_size: Optional[int] = None) -> bool: + def replicates_kv_cache( + self, + remote_engine_id: Optional[EngineId] = None, + remote_tp_size: Optional[int] = None, + ) -> bool: # MLA is always replicated as the hidden dim can't be split. - return ( - self.is_mla - or self.is_kv_replicated(remote_engine_id, remote_tp_size) + return self.is_mla or self.is_kv_replicated( + remote_engine_id, remote_tp_size ) def get_target_remote_ranks( @@ -563,7 +570,8 @@ class NixlConnectorWorker: ) -> list[int]: """ Get the remote TP rank (on P) that the current local TP rank - (on D) will read from. + (on D) will read from. 
When remote tp_size > local tp_size, we + read from multiple remote ranks. """ tp_ratio = self.tp_ratio(remote_engine_id, remote_tp_size) if tp_ratio > 0: @@ -573,8 +581,8 @@ class NixlConnectorWorker: tp_ratio = -tp_ratio if self.replicates_kv_cache(remote_engine_id, remote_tp_size): # When cache is replicated on remote, we only need to read - # from one remote. - return [self.tp_rank*tp_ratio] + # from one remote (they all have the same cache). + return [self.tp_rank * tp_ratio] return [self.tp_rank * tp_ratio + i for i in range(tp_ratio)] def __init__(self, vllm_config: VllmConfig, engine_id: str): @@ -672,12 +680,10 @@ class NixlConnectorWorker: # nixl_prepped_dlist_handle. self.src_xfer_side_handle: int = 0 - # TODO flexible enough to handle different P TP destinations? - # tp_ratio->handles - # Only poulated during handshake when we read from multiple sources + # Populated dynamically during handshake based on remote configuration. + # Keep track of regions at different tp_ratio values. tp_ratio->handles self.src_xfer_side_chunked_handles: dict[int, list[int]] = {} # Map of engine_id -> nixl_prepped_dlist_handle (int)]. - # TODO do I need tp_Ratio of this? self.dst_xfer_side_handles: defaultdict[EngineId, dict[int, int]] = defaultdict( dict ) @@ -1033,7 +1039,9 @@ class NixlConnectorWorker: self.tp_rank, ) - descs = self.nixl_wrapper.get_xfer_descs(self.src_blocks_data, self.nixl_memory_type) + descs = self.nixl_wrapper.get_xfer_descs( + self.src_blocks_data, self.nixl_memory_type + ) # NIXL_INIT_AGENT to be used for preparations of local descs. self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist( "NIXL_INIT_AGENT", descs @@ -1093,10 +1101,12 @@ class NixlConnectorWorker: In particular, handle both homogeneous and heterogeneous TP. The former requires local rank_i to read from remote rank_i. - The latter, assuming D.world_size > P.world_size, requires that two or - more local TP worker share the xfer from a single TP worker. + The latter, in the case of D.world_size < P.world_size, requires that a + local (D) TP worker reads from multiple remote (P) TP workers. + Conversely, assuming D.world_size > P.world_size, two or more local TP + workers will read from a single remote TP worker. - Here's an example (non-MLA case): + Here's an example for the last case described above (non-MLA): rank_offset p_remote_tp_rank (kv split no) @@ -1155,35 +1165,36 @@ class NixlConnectorWorker: ) self._validate_remote_agent_handshake(nixl_agent_meta, remote_tp_size) - # Number of D TP workers reading from a single P TP worker. This is - # 1 when P and D `--tensor-parallel-size` match. If P TP > D TP, - # we don't need to use this for splitting the remote kv cache. + # This is 1 when P and D `--tensor-parallel-size` match. Otherwise, + # this is the ratio between the two sizes. tp_ratio = self.kv_info.tp_ratio(engine_id) # Handle tp_size>num_kv_heads: replicate KV cache. - indexes_into_remote = (not self.kv_info.replicates_kv_cache(engine_id) \ - and tp_ratio < 0) + indexes_into_remote = ( + not self.kv_info.replicates_kv_cache(engine_id) and tp_ratio < 0 + ) - # When you realize you're in P TP>DTP you have to split your regions + ### (Optional) Register local agent memory regions if tp_ratio < 0 and tp_ratio not in self.src_xfer_side_chunked_handles: - # TODO use positive tp_ratio value? + # Remote tp_size > local tp_size: read from multiple remote ranks. + # Logically "split" own regions into |tp_ratio| chunks. Mind that + # we only do this once per remote tp_size (replica-friendly). 
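# Illustrative example (assumed sizes): with local TP=1 and remote TP=2,
# tp_ratio == -2. Each local region of block_len bytes is split into two
# block_len // 2 chunks; chunk 0 is the write target for the heads pulled
# from P rank 0 and chunk 1 for the heads pulled from P rank 1 (with MLA a
# single chunk suffices, as only one remote rank is read).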
self.src_xfer_side_chunked_handles[tp_ratio] = [] - # This is still needed even for MLA - # TODO actually only needs one!! - # Check if we have a split we can re-use, ie a remote P with same tp_ratio - for i in range(-tp_ratio): + # MLA-optimization: only prepare one region. + # NOTE NickLucche: only a chunk of whole cache is used with MLA! + tp_ratio_opt = 1 if self.use_mla else -tp_ratio + for i in range(tp_ratio_opt): blocks_data = [] for memory_region in self.src_blocks_data: addr, local_block_len, own_tp_rank = memory_region - # Computing block len layer by layer allow for different - # block sizes per layer - # TODO this needs to be an assert when validating - # TODO is this the right dim we're splitting on? H? - remote_block_len = local_block_len//(-tp_ratio) - # Offset + # Computing block len layer by layer allows for different + # block sizes to be used. + remote_block_len = local_block_len // (-tp_ratio) addr = addr + i * remote_block_len - blocks_data.append((addr, remote_block_len, own_tp_rank)) # TODO same tp_rank? - descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type) + blocks_data.append((addr, remote_block_len, own_tp_rank)) + descs = self.nixl_wrapper.get_xfer_descs( + blocks_data, self.nixl_memory_type + ) handle = self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs) self.src_xfer_side_chunked_handles[tp_ratio].append(handle) @@ -1196,8 +1207,11 @@ class NixlConnectorWorker: # Register all remote blocks, but only the corresponding kv heads. for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr): - # TODO workaround - kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i) // 2 + # Read our whole local region size from remote. + kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i) + if tp_ratio < 0: + # Remote tp is bigger: read a chunk of local region from remote + kv_block_len = kv_block_len // (-tp_ratio) rank_offset = ( self.tp_rank % tp_ratio * kv_block_len if indexes_into_remote else 0 ) @@ -1262,7 +1276,7 @@ class NixlConnectorWorker: ) remote_block_size = remote_block_len // (self.slot_size_per_layer[0]) else: - if tp_ratio > 1 and self.device_type == "xpu": + if tp_ratio != 1 and self.device_type == "xpu": # XPU uses NHD, hence it does not support splitting on H raise ValueError("Heterogeneous TP is not supported on XPU") # When MLA is not used, this is a list of the same block length @@ -1270,26 +1284,42 @@ class NixlConnectorWorker: assert block_len == remote_block_len, ( "All remote layers must have the same block size" ) - remote_block_size = remote_block_len // ( - self.slot_size_per_layer[0] * tp_ratio - ) + + if tp_ratio > 0: + # Remote NHD/H'D*tp_ratio=N -page_size- + remote_block_size = remote_block_len // ( + self.slot_size_per_layer[0] * tp_ratio + ) + # Remote tp is smaller: remote block_len size is bigger + assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, ( + "Remote P worker KV layer cache must be of shape [2, N, " + "local_kv_heads*tp_ratio, page_size, head_dim] and same dtype." + ) # noqa: E501 + else: + # Remote NHD/(H'D/tp_ratio)=N -page_size- + remote_block_size = remote_block_len // ( + self.slot_size_per_layer[0] // (-tp_ratio) + ) + # Remote tp is bigger: remote block_len size is smaller + assert remote_block_len == self.block_len_per_layer[0] // (-tp_ratio), ( + "Remote P worker KV layer cache must be of shape [2, N, " + "local_kv_heads/tp_ratio, page_size, head_dim] and same dtype." 
+ ) # noqa: E501 + if self._use_flashinfer: # With flashinfer, KV are sent in the same message. remote_block_size //= 2 - # TODO add asserts for P TP > D TP - # assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, ( - # "Remote P worker KV layer cache must be of shape [2, N, " - # "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype." - # ) - # assert self.block_size == remote_block_size, ( - # "Remote P worker with different page/block size is not supported " - # f"{self.block_size=}, {remote_block_size=}" - # ) + # We may allow it in the future with logical kvcache manager block_size + assert self.block_size == remote_block_size, ( + "Remote P worker with different page/block size is not supported " + f"{self.block_size=}, {remote_block_size=}" + ) - # # TP workers have same #blocks. - # assert self.dst_num_blocks[remote_engine_id] == nixl_agent_meta.num_blocks - # assert len(nixl_agent_meta.kv_caches_base_addr) == len(self.block_len_per_layer) + # TP workers (handhshakes with same remote) have same #blocks. + assert self.dst_num_blocks[remote_engine_id] == nixl_agent_meta.num_blocks + # Same number of regions/~layers. + assert len(nixl_agent_meta.kv_caches_base_addr) == len(self.block_len_per_layer) def sync_recved_kv_to_device(self, req_id: str, meta: ReqMeta): """copy recved kv from host buffer to device.""" @@ -1421,7 +1451,7 @@ class NixlConnectorWorker: """ done_req_ids: set[str] = set() for req_id, handles in list(transfers.items()): - in_progress = False + in_progress = [] for handle, _xfer_stime in handles: xfer_state = self.nixl_wrapper.check_xfer_state(handle) if xfer_state == "DONE": @@ -1430,13 +1460,16 @@ class NixlConnectorWorker: self.xfer_stats.record_transfer(res) self.nixl_wrapper.release_xfer_handle(handle) elif xfer_state == "PROC": - in_progress = True + in_progress.append((handle, _xfer_stime)) continue else: raise RuntimeError("Transfer failed with state %s", xfer_state) if not in_progress: + # Only report request as completed when all transfers are done. done_req_ids.add(req_id) del transfers[req_id] + else: + transfers[req_id] = in_progress return done_req_ids def start_load_kv(self, metadata: NixlConnectorMetadata): @@ -1490,7 +1523,8 @@ class NixlConnectorWorker: def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): remote_ranks = self.kv_info.get_target_remote_ranks(meta.remote_engine_id) - # D may perform multiple reads from different remote ranks. + tp_ratio = self.kv_info.tp_ratio(meta.remote_engine_id) + # D may have to perform multiple reads from different remote ranks. for i, remote_rank in enumerate(remote_ranks): logger.debug( "Remote agent %s available, calling _read_blocks" @@ -1499,13 +1533,17 @@ class NixlConnectorWorker: remote_rank, req_id, ) - # TODO refactor properly and ONLY DO THIS FOR PTP>DTP - tp_ratio = self.kv_info.tp_ratio(meta.remote_engine_id) - # Get nixl desc handles depending on whether we're reading from - # multiple sources or we're reading a chunk of - local_xfer_side_handle = self.src_xfer_side_chunked_handles[tp_ratio][i] - remote_xfer_side_handle = self.dst_xfer_side_handles[meta.remote_engine_id][remote_rank] - # TODO multiread; notifs to all twice?? SPLIT LOCAL BLOCKS! + if tp_ratio < 0: + # Remote tp_size > local tp_size: we must perform multiple + # reads. Get the memory chunk onto which we will write to. + local_xfer_side_handle = self.src_xfer_side_chunked_handles[tp_ratio][i] + else: + # Single read from remote, we write to the whole memory region. 
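# Illustrative dispatch (assumed sizes): with tp_ratio >= 1 (e.g. D TP=4,
# P TP=2) the loop runs once and this whole-region handle is used; by
# contrast, with local TP=2 and remote TP=4 (tp_ratio == -2), D rank 1 takes
# the branch above twice, reading P rank 2 into chunk 0 of its region and
# P rank 3 into chunk 1.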
+ local_xfer_side_handle = self.src_xfer_side_handle + # Destination handle: remote_engine_id -> remote_rank -> handle. + remote_xfer_side_handle = self.dst_xfer_side_handles[meta.remote_engine_id][ + remote_rank + ] self._read_blocks( request_id=req_id, dst_engine_id=meta.remote_engine_id, @@ -1526,6 +1564,10 @@ class NixlConnectorWorker: local_xfer_side_handle: int, remote_xfer_side_handle: int, ): + """ + Post a READ xfer request from a single local worker to a single + remote worker. + """ # NOTE(rob): having the staging blocks be on the READER side is # not going to work well (since we will have to call rearrange tensors). # after we detect the txn is complete (which means we cannot make the @@ -1543,7 +1585,7 @@ class NixlConnectorWorker: notif_id = f"{request_id}:{tp_ratio}".encode() # Full prefix cache hit: do not need to read remote blocks, - # just notify P worker(s) that we have the blocks we need. + # just notify P worker that we have the blocks we need. num_local_blocks = len(local_block_ids) if num_local_blocks == 0: agent_name = self._remote_agents[dst_engine_id][remote_rank] @@ -1556,10 +1598,6 @@ class NixlConnectorWorker: if num_local_blocks < num_remote_blocks: remote_block_ids = remote_block_ids[-num_local_blocks:] - # Get side handles. - # local_xfer_side_handle = self.src_xfer_side_handle - # remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id][remote_rank] - # NOTE (nicolo) With homogeneous TP, each TP worker loads KV from # corresponding rank. With heterogeneous TP, fixing D>P, the D tp # workers will issue xfers to parts of the P worker remote kv caches. @@ -1647,7 +1685,6 @@ class NixlConnectorWorker: assert self.num_layers == self.num_regions region_ids = np.arange(layer_idx, layer_idx + 1) - # TODO can this vary? num_blocks = self.dst_num_blocks[engine_id] # Compute the desc ids for each block. 
@@ -1694,11 +1731,10 @@ class NixlConnectorWorker:
         if self.src_xfer_side_handle:
             self.nixl_wrapper.release_dlist_handle(self.src_xfer_side_handle)
             self.src_xfer_side_handle = 0
-        if self.src_xfer_side_chunked_handles:
-            for handles in self.src_xfer_side_chunked_handles.values():
-                for handle in handles:
-                    self.nixl_wrapper.release_dlist_handle(handle)
-            self.src_xfer_side_chunked_handles.clear()
+        for handles in self.src_xfer_side_chunked_handles.values():
+            for handle in handles:
+                self.nixl_wrapper.release_dlist_handle(handle)
+        self.src_xfer_side_chunked_handles.clear()
         for dst_xfer_side_handles in self.dst_xfer_side_handles.values():
             for dst_xfer_side_handle in dst_xfer_side_handles.values():
                 self.nixl_wrapper.release_dlist_handle(dst_xfer_side_handle)

From 6601c9c5bec80ad1b8474a08f98f5f77c520687b Mon Sep 17 00:00:00 2001
From: NickLucche
Date: Fri, 10 Oct 2025 10:30:38 +0000
Subject: [PATCH 7/9] add and update tests

Signed-off-by: NickLucche
---
 .../kv_connector/unit/test_nixl_connector.py | 134 +++++++++++++++---
 .../kv_connector/v1/nixl_connector.py        |  13 +-
 2 files changed, 123 insertions(+), 24 deletions(-)

diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py
index a1f53cb255630..cff1f08845721 100644
--- a/tests/v1/kv_connector/unit/test_nixl_connector.py
+++ b/tests/v1/kv_connector/unit/test_nixl_connector.py
@@ -291,6 +291,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
     def __init__(self, *args, hand_shake_latency: float = 1.8, **kwargs):
         super().__init__(*args, **kwargs)
         self._hand_shake_latency = hand_shake_latency
+        self.kv_cache_layout = "HND"

     def _nixl_handshake(
         self, host: str, port: int, remote_tp_size: int, expected_engine_id: str
@@ -307,21 +308,42 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):

         assert expected_engine_id == self.REMOTE_ENGINE_ID

-        remote_agent_name = self.add_remote_agent(
-            NixlAgentMetadata(
-                engine_id=self.REMOTE_ENGINE_ID,
-                agent_metadata=FakeNixlWrapper.AGENT_METADATA,
-                kv_caches_base_addr=[0],
-                num_blocks=1,
-                block_lens=self.block_len_per_layer,
-                attn_backend_name=self.backend_name,
-                # `self.kv_cache_layout` is only forced to HND when vllm engine
-                # is started. We mock HND here.
-                kv_cache_layout="HND",
-            ),
-            remote_tp_size=remote_tp_size,
-        )
-        return {0: remote_agent_name}
+        # Adjust remote block length metadata to satisfy heterogeneous TP
+        # invariants enforced during handshake validation.
+        remote_block_lens = list(self.block_len_per_layer)
+        tp_ratio = self.kv_info.tp_ratio(remote_tp_size=remote_tp_size)
+        if remote_tp_size > self.world_size:
+            # P TP > D TP case: the remote block_len is smaller
+            remote_block_lens = [
+                block_len // (-tp_ratio) for block_len in remote_block_lens
+            ]
+        elif remote_tp_size < self.world_size:
+            remote_block_lens = [
+                block_len * tp_ratio for block_len in remote_block_lens
+            ]
+
+        # When remote tp_size > local tp_size, we handshake with multiple
+        # remote ranks.
+        num_handshakes = 1 if tp_ratio > 0 else -tp_ratio
+        remote_agents: dict[int, str] = {}
+        for remote_tp_rank in range(num_handshakes):
+            remote_agent_name = self.add_remote_agent(
+                NixlAgentMetadata(
+                    engine_id=self.REMOTE_ENGINE_ID,
+                    agent_metadata=FakeNixlWrapper.AGENT_METADATA,
+                    kv_caches_base_addr=[0],
+                    num_blocks=1,
+                    block_lens=remote_block_lens,
+                    attn_backend_name=self.backend_name,
+                    # `self.kv_cache_layout` is only forced to HND when vllm engine
+                    # is started. We mock HND here.
+ kv_cache_layout="HND", + ), + remote_tp_rank=remote_tp_rank, + remote_tp_size=remote_tp_size, + ) + remote_agents[remote_tp_rank] = remote_agent_name + return remote_agents class TestNixlHandshake: @@ -352,7 +374,13 @@ class TestNixlHandshake: vllm_config, connector.engine_id, hand_shake_latency=0 ) assert isinstance(connector.connector_worker.nixl_wrapper, FakeNixlWrapper) - connector.connector_worker.nixl_wrapper.set_cycles_before_xfer_done(3) + worker = connector.connector_worker + worker.nixl_wrapper.set_cycles_before_xfer_done(3) + # simulate handshake + worker.dst_xfer_side_handles = { + FakeNixlConnectorWorker.REMOTE_ENGINE_ID: {0: 1} + } + worker.kv_cache_layout = "HND" num_xfers = 4 while True: # For the same request_id, initiate multiple xfers across different @@ -464,6 +492,70 @@ class TestNixlHandshake: return raise TimeoutError("Took too long to complete async handshake.") + @patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", + FakeNixlWrapper, + ) + @pytest.mark.parametrize("local_tp_size", [1, 2]) + def test_prefill_tp_size_greater_than_decode_tp_size( + self, local_tp_size: int, dist_init + ): + """ + Verify remote TP > local TP handshake succeeds with different + remote configurations. + """ + + vllm_config = create_vllm_config() + local_tp_size = 1 + vllm_config.parallel_config.tensor_parallel_size = local_tp_size + + connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) + connector.connector_worker = FakeNixlConnectorWorker( + vllm_config, connector.engine_id, hand_shake_latency=0 + ) + worker = connector.connector_worker + + # Minimal local registration params used by add_remote_agent + worker.slot_size_per_layer = [4096] + worker.block_len_per_layer = [4096 * worker.block_size] + worker.num_blocks = 1 + worker.dst_num_blocks[worker.engine_id] = worker.num_blocks + worker.src_blocks_data = [(0, worker.block_len_per_layer[0], worker.tp_rank)] + + def check_handshake(remote_tp_size: int): + tp_ratio = remote_tp_size // local_tp_size + assert set(remote_agents.keys()) == set(range(tp_ratio)) + + remote_engine_id = worker.REMOTE_ENGINE_ID + assert worker._tp_size[remote_engine_id] == remote_tp_size + assert -tp_ratio == worker.kv_info.tp_ratio(remote_engine_id) + # ensure src_xfer_side_chunked_handles is populated with tpratio chunks + assert -tp_ratio in worker.src_xfer_side_chunked_handles + assert len(worker.src_xfer_side_chunked_handles[-tp_ratio]) == tp_ratio + assert remote_engine_id in worker.dst_xfer_side_handles + assert set(worker.dst_xfer_side_handles[remote_engine_id].keys()) == set( + range(tp_ratio) + ) + + remote_agents = worker._nixl_handshake( + host="localhost", + port=1234, + remote_tp_size=2, + expected_engine_id=worker.REMOTE_ENGINE_ID, + ) + check_handshake(2) + + # NOTE flexiblity: a second remote with higher number of ranks + # is discovered + worker.REMOTE_ENGINE_ID = "remote_engine_2" + remote_agents = worker._nixl_handshake( + host="localhost", + port=1234, + remote_tp_size=6, + expected_engine_id=worker.REMOTE_ENGINE_ID, + ) + check_handshake(6) + @patch( "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", FakeNixlWrapper, @@ -564,10 +656,9 @@ class TestNixlHandshake: kv_cache_layout=mismatched_layout, ) - # We don't check layout for homogeneous TP and MLA for now, as the - # whole block is moved. - worker.add_remote_agent(meta, remote_tp_size=2) + # Layout check done for both homogeneous and heterogeneous TP. 
         with pytest.raises(AssertionError):
+            worker.add_remote_agent(meta, remote_tp_size=2)
+        with pytest.raises(AssertionError):
             worker.add_remote_agent(meta, remote_tp_size=1)


@@ -1057,7 +1148,8 @@ def test_shutdown_cleans_up_resources(dist_init):
     ):
         worker._recving_transfers = {"req1": [(123, time.perf_counter())]}
        worker.src_xfer_side_handle = 456
-        worker.dst_xfer_side_handles = {"engine1": 789}
+        worker.src_xfer_side_chunked_handles = {-2: [456]}
+        worker.dst_xfer_side_handles = {"engine1": {0: 789}}
         worker._remote_agents = {"engine1": {0: "agent1"}}
         worker._registered_descs = ["desc1", "desc2"]

@@ -1071,7 +1163,7 @@ def test_shutdown_cleans_up_resources(dist_init):
         mock_listener.join.assert_called_once_with(timeout=0)

         mock_rel_xfer.assert_called_once_with(123)
-        assert mock_rel_dlist.call_count == 2
+        assert mock_rel_dlist.call_count == 3
         mock_rel_dlist.assert_any_call(456)  # src handle
         mock_rel_dlist.assert_any_call(789)  # dst handle
         mock_rem_agent.assert_called_once_with("agent1")
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index 66d692e7387da..0f71addb67191 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -1171,7 +1171,14 @@ class NixlConnectorWorker:

         # Handle tp_size>num_kv_heads: replicate KV cache.
         indexes_into_remote = (
-            not self.kv_info.replicates_kv_cache(engine_id) and tp_ratio < 0
+            not self.kv_info.replicates_kv_cache(engine_id) and tp_ratio > 0
+        )
+
+        logger.debug(
+            "Registering remote agent (%s, rank %s) memory regions with tp_ratio %s",
+            engine_id,
+            remote_tp_rank,
+            tp_ratio,
         )

         ### (Optional) Register local agent memory regions
@@ -1724,8 +1731,8 @@ class NixlConnectorWorker:
         if self._nixl_handshake_listener_t is not None:
             self._nixl_handshake_listener_t.join(timeout=0)
             self._nixl_handshake_listener_t = None
-        for handles in self._recving_transfers.values():
-            for handle, _ in handles:
+        for rcv_handles in self._recving_transfers.values():
+            for handle, _ in rcv_handles:
                 self.nixl_wrapper.release_xfer_handle(handle)
         self._recving_transfers.clear()
         if self.src_xfer_side_handle:
From b8d520232fd455155112d6b2cf1a36b64da90aee Mon Sep 17 00:00:00 2001
From: NickLucche
Date: Mon, 13 Oct 2025 14:01:34 +0000
Subject: [PATCH 8/9] fix mla

Signed-off-by: NickLucche
---
 .../kv_connector/v1/nixl_connector.py | 32 ++++++++++--------
 1 file changed, 18 insertions(+), 14 deletions(-)

diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index 0f71addb67191..83d0df41ead1c 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -1181,16 +1181,18 @@ class NixlConnectorWorker:
             tp_ratio,
         )

-        ### (Optional) Register local agent memory regions
-        if tp_ratio < 0 and tp_ratio not in self.src_xfer_side_chunked_handles:
+        ### (Optional) Register local agent memory regions.
+        # MLA-optimization: only prepare one region.
+        if (
+            tp_ratio < 0
+            and not self.use_mla
+            and tp_ratio not in self.src_xfer_side_chunked_handles
+        ):
             # Remote tp_size > local tp_size: read from multiple remote ranks.
             # Logically "split" own regions into |tp_ratio| chunks. Mind that
             # we only do this once per remote tp_size (replica-friendly).
             self.src_xfer_side_chunked_handles[tp_ratio] = []
-            # MLA-optimization: only prepare one region.
-            # NOTE NickLucche: only a chunk of whole cache is used with MLA!
-            tp_ratio_opt = 1 if self.use_mla else -tp_ratio
-            for i in range(tp_ratio_opt):
+            for i in range(-tp_ratio):
                 blocks_data = []
                 for memory_region in self.src_blocks_data:
                     addr, local_block_len, own_tp_rank = memory_region
@@ -1215,12 +1217,12 @@ class NixlConnectorWorker:
         # Register all remote blocks, but only the corresponding kv heads.
         for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
             # Read our whole local region size from remote.
-            kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
-            if tp_ratio < 0:
+            local_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
+            if tp_ratio < 0 and not self.use_mla:
                 # Remote tp is bigger: read a chunk of local region from remote
-                kv_block_len = kv_block_len // (-tp_ratio)
+                local_block_len = local_block_len // (-tp_ratio)
             rank_offset = (
-                self.tp_rank % tp_ratio * kv_block_len if indexes_into_remote else 0
+                self.tp_rank % tp_ratio * local_block_len if indexes_into_remote else 0
             )
             for block_id in range(nixl_agent_meta.num_blocks):
                 block_offset = block_id * nixl_agent_meta.block_lens[i]
                 # self.block_len == remote_block_len//tp_ratio bytes.
                 addr = base_addr + block_offset + rank_offset
                 # (addr, len, device id)
-                blocks_data.append((addr, kv_block_len, remote_tp_rank))
+                blocks_data.append((addr, local_block_len, remote_tp_rank))

             if self._use_flashinfer:
                 # With FlashInfer index V separately to allow head splitting.
                 for block_id in range(nixl_agent_meta.num_blocks):
                     block_offset = block_id * nixl_agent_meta.block_lens[i]
                     addr = base_addr + block_offset + rank_offset
                     v_addr = addr + nixl_agent_meta.block_lens[i] // 2
-                    blocks_data.append((v_addr, kv_block_len, remote_tp_rank))
+                    blocks_data.append((v_addr, local_block_len, remote_tp_rank))

         logger.debug(
             "Created %s blocks for dst engine %s with remote rank %s and local rank %s",
@@ -1273,10 +1275,12 @@ class NixlConnectorWorker:
         assert not self._use_pallas or tp_ratio == 1, (
             "TPU (pallas_v1) DOES NOT support heterogeneous TP yet."
         )
+        # Not supported: KV replication (remote tp >= num_kv_heads) with P TP > D TP.
+        assert not (tp_ratio < 0 and self.kv_info.is_kv_replicated(remote_engine_id))

         # Block len can only vary across layers when using MLA.
         remote_block_len = nixl_agent_meta.block_lens[0]
-        if self.use_mla or self.kv_info.is_kv_replicated(remote_engine_id):
+        if self.kv_info.replicates_kv_cache(remote_engine_id):
             # With replicated KV cache, only the number of blocks can differ.
             assert self.block_len_per_layer == nixl_agent_meta.block_lens, (
                 "KV cache sizes must match between P and D when replicated"
             )
@@ -1540,7 +1544,7 @@ class NixlConnectorWorker:
                 remote_rank,
                 req_id,
             )
-            if tp_ratio < 0:
+            if tp_ratio < 0 and not self.use_mla:
                 # Remote tp_size > local tp_size: we must perform multiple
                 # reads. Get the memory chunk onto which we will write to.
local_xfer_side_handle = self.src_xfer_side_chunked_handles[tp_ratio][i] From 1dc9df9842eca4563890a85e8b5eb2c2c551488c Mon Sep 17 00:00:00 2001 From: NickLucche Date: Mon, 13 Oct 2025 14:20:41 +0000 Subject: [PATCH 9/9] more integration tests Signed-off-by: NickLucche --- .../nixl_integration/run_accuracy_test.sh | 16 ++++--- .../nixl_integration/test_accuracy.py | 6 ++- .../tp_config_sweep_accuracy_test.sh | 43 +++++++++++++++++++ 3 files changed, 59 insertions(+), 6 deletions(-) create mode 100755 tests/v1/kv_connector/nixl_integration/tp_config_sweep_accuracy_test.sh diff --git a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh index 3b0f2d102c1ff..78941e5edd4e3 100755 --- a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh +++ b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh @@ -27,15 +27,21 @@ else fi # Models to run -MODELS=( - "Qwen/Qwen3-0.6B" -) +MODEL_NAMES=${MODEL_NAMES:-} +if [[ -n "$MODEL_NAMES" ]]; then + MODELS=("$MODEL_NAMES") +else + MODELS=( + "Qwen/Qwen3-0.6B" + ) +fi # Number of prefill and decode instances to create NUM_PREFILL_INSTANCES=${NUM_PREFILL_INSTANCES:-1} # Default to 1 NUM_DECODE_INSTANCES=${NUM_DECODE_INSTANCES:-1} # Default to 1 PREFILLER_TP_SIZE=${PREFILLER_TP_SIZE:-1} DECODER_TP_SIZE=${DECODER_TP_SIZE:-1} +GPU_MEMORY_UTILIZATION=${GPU_MEMORY_UTILIZATION:-0.2} # Find the git repository root directory GIT_ROOT=$(git rev-parse --show-toplevel) @@ -116,7 +122,7 @@ run_tests_for_model() { vllm serve $model_name \ --port $PORT \ --enforce-eager \ - --gpu-memory-utilization 0.2 \ + --gpu-memory-utilization $GPU_MEMORY_UTILIZATION \ --tensor-parallel-size $PREFILLER_TP_SIZE \ --kv-transfer-config '$KV_CONFIG'" @@ -151,7 +157,7 @@ run_tests_for_model() { vllm serve $model_name \ --port $PORT \ --enforce-eager \ - --gpu-memory-utilization 0.2 \ + --gpu-memory-utilization $GPU_MEMORY_UTILIZATION \ --tensor-parallel-size $DECODER_TP_SIZE \ --kv-transfer-config '$KV_CONFIG'" diff --git a/tests/v1/kv_connector/nixl_integration/test_accuracy.py b/tests/v1/kv_connector/nixl_integration/test_accuracy.py index b301968e5bf84..8f4fbe6ff7a26 100644 --- a/tests/v1/kv_connector/nixl_integration/test_accuracy.py +++ b/tests/v1/kv_connector/nixl_integration/test_accuracy.py @@ -12,7 +12,11 @@ FILTER = "exact_match,strict-match" RTOL = 0.03 # Model-specific expected values -EXPECTED_VALUES = {"Qwen/Qwen3-0.6B": 0.41, "deepseek-ai/deepseek-vl2-small": 0.59} +EXPECTED_VALUES = { + "Qwen/Qwen3-0.6B": 0.41, + "deepseek-ai/deepseek-vl2-small": 0.59, + "deepseek-ai/DeepSeek-V2-Lite-Chat": 0.65, +} SIMPLE_PROMPT = ( "The best part about working on vLLM is that I got to meet so many people across " diff --git a/tests/v1/kv_connector/nixl_integration/tp_config_sweep_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/tp_config_sweep_accuracy_test.sh new file mode 100755 index 0000000000000..d82c79081d7c8 --- /dev/null +++ b/tests/v1/kv_connector/nixl_integration/tp_config_sweep_accuracy_test.sh @@ -0,0 +1,43 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Utility to run integration tests sequentially with varying TP configurations. +# If FLASHINFER is set, reruns all tests with VLLM_ATTENTION_BACKEND=FLASHINFER. 
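+#
+# Example invocations (illustrative only; assumes the repository root as the working
+# directory and that the FLASHINFER variable is only set when that backend is desired):
+#   bash tests/v1/kv_connector/nixl_integration/tp_config_sweep_accuracy_test.sh
+#   FLASHINFER=1 bash tests/v1/kv_connector/nixl_integration/tp_config_sweep_accuracy_test.sh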
+ +SCRIPT="tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh" + +# Define test configurations +configs=( + "PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2" + "PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2" + "PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=1" + "GPU_MEMORY_UTILIZATION=0.6 MODEL_NAMES=deepseek-ai/DeepSeek-V2-Lite-Chat" # MLA case + # TP greater than num heads +) + +run_tests() { + local label=$1 + local extra_env=$2 + + echo "=== Running tests (${label}) ===" + for cfg in "${configs[@]}"; do + echo "-> Running with ${cfg} ${extra_env:+and ${extra_env}}" + # Use 'env' to safely set variables without eval + if ! env ${extra_env} ${cfg} bash "${SCRIPT}"; then + echo "❌ Test failed for config: ${cfg} ${extra_env:+(${extra_env})}" + exit 1 + fi + done + echo "✅ All ${label} tests passed!" +} + +# Run base tests +run_tests "default backend" "" + +# Check if FLASHINFER is set (non-empty) +if [[ -n "${FLASHINFER:-}" ]]; then + echo "FLASHINFER is set, rerunning with VLLM_ATTENTION_BACKEND=FLASHINFER" + run_tests "FLASHINFER backend" "VLLM_ATTENTION_BACKEND=FLASHINFER" +else + echo "FLASHINFER not set, skipping FLASHINFER runs." +fi