diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index e3e3389fd1643..43c446d7f912b 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -36,7 +36,6 @@ from vllm.distributed.parallel_state import (
     get_tensor_model_parallel_world_size,
     get_tp_group,
 )
-from vllm.distributed.utils import divide
 from vllm.forward_context import ForwardContext
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
@@ -504,6 +503,39 @@ class NixlConnectorScheduler:
 class NixlConnectorWorker:
     """Implementation of Worker side methods"""

+    @dataclass
+    class KVInfo:
+        tp_size: int
+        tp_rank: int
+        remote_tp_size: dict[EngineId, int]
+        is_mla: bool
+        total_num_kv_heads: int
+
+        def tp_ratio(
+            self,
+            remote_engine_id: Optional[EngineId] = None,
+            remote_tp_size: Optional[int] = None,
+        ) -> int:
+            if remote_tp_size is None:
+                assert remote_engine_id is not None
+                remote_tp_size = self.remote_tp_size[remote_engine_id]
+            return self.tp_size // remote_tp_size
+
+        def is_kv_replicated(self, remote_engine_id: EngineId) -> bool:
+            tp_size = self.remote_tp_size[remote_engine_id]
+            return tp_size // self.total_num_kv_heads >= 1
+
+        def replicates_kv_cache(self, remote_engine_id: EngineId) -> bool:
+            return self.is_mla or self.is_kv_replicated(remote_engine_id)
+
+        def get_target_remote_rank(
+            self,
+            remote_engine_id: Optional[EngineId] = None,
+            remote_tp_size: Optional[int] = None,
+        ) -> int:
+            tp_ratio = self.tp_ratio(remote_engine_id, remote_tp_size)
+            return self.tp_rank // tp_ratio
+
     def __init__(self, vllm_config: VllmConfig, engine_id: str):
         if NixlWrapper is None:
             logger.error("NIXL is not available")
@@ -627,7 +659,6 @@ class NixlConnectorWorker:
         # Protects _handshake_futures and _remote_agents.
         self._handshake_lock = threading.RLock()

-        self.vllm_config = vllm_config
         self.block_size = vllm_config.cache_config.block_size
         self.model_config = vllm_config.model_config
         self.cache_config = vllm_config.cache_config
@@ -659,6 +690,14 @@ class NixlConnectorWorker:
         self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int)
         self.xfer_stats = NixlKVConnectorStats()

+        self.kv_info = self.KVInfo(
+            tp_size=self.world_size,
+            tp_rank=self.tp_rank,
+            remote_tp_size=self._tp_size,  # shared state
+            is_mla=self.use_mla,
+            total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
+        )
+
     @staticmethod
     def _nixl_handshake_listener(
         metadata: NixlAgentMetadata,
@@ -704,8 +743,9 @@ class NixlConnectorWorker:

         # Handshake only with the remote TP rank that current local rank will
         # pull from. With homogeneous TP it happens to be the same rank_i.
-        tp_ratio = self._tp_size[self.engine_id] // remote_tp_size
-        p_remote_rank = self.tp_rank // tp_ratio
+        p_remote_rank = self.kv_info.get_target_remote_rank(
+            remote_tp_size=remote_tp_size
+        )
         path = make_zmq_path("tcp", host, port + p_remote_rank)
         logger.debug(
             "Querying metadata on path: %s at remote rank %s", path, p_remote_rank
@@ -950,13 +990,11 @@ class NixlConnectorWorker:

         # TODO(mgoin): Hybrid memory allocator is currently disabled for
         # models with local attention (Llama 4). Can remove this once enabled.
-        if self.vllm_config.model_config.hf_config.model_type == "llama4":
+        if self.model_config.hf_config.model_type == "llama4":
             from transformers import Llama4TextConfig

-            assert isinstance(
-                self.vllm_config.model_config.hf_text_config, Llama4TextConfig
-            )
-            llama4_config = self.vllm_config.model_config.hf_text_config
+            assert isinstance(self.model_config.hf_text_config, Llama4TextConfig)
+            llama4_config = self.model_config.hf_text_config
             no_rope_layers = llama4_config.no_rope_layers
             chunk_size = llama4_config.attention_chunk_size
             chunk_block_size = math.ceil(chunk_size / self.block_size)
@@ -1039,87 +1077,50 @@ class NixlConnectorWorker:
         engine_id = nixl_agent_meta.engine_id
         # TODO re-evaluate refreshing for scaling/recovery
         if remote_tp_rank in self._remote_agents.get(engine_id, {}):
+            logger.debug(
+                "Remote agent with engine_id %s and rank "
+                "%s already exchanged metadata, skip handshake.",
+                engine_id,
+                remote_tp_rank,
+            )
             return self._remote_agents[engine_id][remote_tp_rank]

+        ### Register remote agent metadata
         if engine_id not in self._tp_size:
             self._tp_size[engine_id] = remote_tp_size
-        else:
-            assert self._tp_size[engine_id] == remote_tp_size
-        # TODO We may eventually want to skip enforcing the same attn backend.
-        assert nixl_agent_meta.attn_backend_name == self.backend_name

         remote_agent_name = self.nixl_wrapper.add_remote_agent(
             nixl_agent_meta.agent_metadata
         )

-        # Number of D TP workers reading from a single P TP worker. This is
-        # 1 when P and D `--tensor-parallel-size` match.
-        tp_ratio = divide(self._tp_size[self.engine_id], self._tp_size[engine_id])
-        assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP"
-        assert not self._use_pallas or tp_ratio == 1, (
-            "TPU (pallas_v1) DOES NOT support heterogeneous TP yet."
-        )
-        # Handle tp_size>num_kv_heads: replicate KV cache.
-        total_num_kv_heads = self.model_config.get_total_num_kv_heads()
-        is_kv_replicated = self._tp_size[engine_id] // total_num_kv_heads >= 1
+        replicates_kv_cache = self.kv_info.replicates_kv_cache(engine_id)

-        remote_block_len = nixl_agent_meta.block_lens[0]
-        if self.use_mla or is_kv_replicated:
-            # With replicated KV cache, only the number of blocks can differ.
-            assert self.block_len_per_layer == nixl_agent_meta.block_lens, (
-                "KV cache sizes must match between P and D when replicated"
-            )
-            remote_block_size = remote_block_len // (self.slot_size_per_layer[0])
-        else:
-            # When MLA is not used, this is a list of the same block length
-            for block_len in nixl_agent_meta.block_lens:
-                assert block_len == remote_block_len, (
-                    "All remote layers must have the same block size"
-                )
-            remote_block_size = remote_block_len // (
-                self.slot_size_per_layer[0] * tp_ratio
-            )
-            if self._use_flashinfer:
-                # With flashinfer, KV are sent in the same message.
-                remote_block_size //= 2
-            if tp_ratio > 1:
-                # Heterogeneous TP expects same kv_cache_layout.
-                assert nixl_agent_meta.kv_cache_layout == self.kv_cache_layout
-                if self.device_type == "xpu":
-                    raise ValueError("Heterogeneous TP is not supported on XPU")
-
-        assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, (
-            "Remote P worker KV layer cache must be of shape [2, N, "
-            "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
-        )
-
-        assert self.block_size == remote_block_size, (
-            "Remote P worker with different page/block size is not supported "
-            f"{self.block_size=}, {remote_block_size=}"
-        )
-
-        # Create dst descs and xfer side handles. TP workers have same #blocks.
-        if engine_id in self.dst_num_blocks:
-            assert self.dst_num_blocks[engine_id] == nixl_agent_meta.num_blocks
-        else:
+        # Create dst descs and xfer side handles. TP workers have same #blocks
+        # so we only register once per engine_id.
+        if engine_id not in self.dst_num_blocks:
             self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks

+        # Keep track of remote agent kv caches base addresses.
+        self.kv_caches_base_addr[engine_id] = nixl_agent_meta.kv_caches_base_addr
+        self._validate_remote_agent_handshake(nixl_agent_meta, remote_tp_size)
+
+        # Number of D TP workers reading from a single P TP worker. This is
+        # 1 when P and D `--tensor-parallel-size` match.
+        tp_ratio = self.kv_info.tp_ratio(engine_id)
+
+        ### Register remote agent memory regions
         blocks_data = []
         # With homogeneous TP, D pulls the whole kv cache from corresponding
         # rank. With heterogeneous TP, prepare the descriptors by splitting the
         # P KV cache along kv_head dim, of D worker's kv_head size (D>P).
         # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
-        self.kv_caches_base_addr[engine_id] = nixl_agent_meta.kv_caches_base_addr
-        assert len(nixl_agent_meta.kv_caches_base_addr) == len(self.block_len_per_layer)
         # Register all remote blocks, but only the corresponding kv heads.
         for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
             kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
             rank_offset = (
-                self.tp_rank % tp_ratio * kv_block_len
-                if not (self.use_mla or is_kv_replicated)
-                else 0
+                self.tp_rank % tp_ratio * kv_block_len if not replicates_kv_cache else 0
             )
             for block_id in range(nixl_agent_meta.num_blocks):
                 block_offset = block_id * nixl_agent_meta.block_lens[i]
@@ -1154,6 +1155,64 @@ class NixlConnectorWorker:

         return remote_agent_name

+    def _validate_remote_agent_handshake(
+        self, nixl_agent_meta: NixlAgentMetadata, remote_tp_size: int
+    ):
+        """
+        Validate the remote agent handshake metadata and ensure
+        the expected invariants hold.
+        """
+        remote_engine_id = nixl_agent_meta.engine_id
+
+        assert self._tp_size[remote_engine_id] == remote_tp_size
+        # TODO We may eventually want to skip enforcing the same attn backend.
+        assert nixl_agent_meta.attn_backend_name == self.backend_name
+        assert nixl_agent_meta.kv_cache_layout == self.kv_cache_layout
+
+        tp_ratio = self.kv_info.tp_ratio(remote_engine_id)
+        assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP"
+        assert not self._use_pallas or tp_ratio == 1, (
+            "TPU (pallas_v1) DOES NOT support heterogeneous TP yet."
+        )
+        if tp_ratio > 1 and self.device_type == "xpu":
+            raise ValueError("Heterogeneous TP is not supported on XPU")
+
+        # Block len can only vary across layers when using MLA.
+        remote_block_len = nixl_agent_meta.block_lens[0]
+        if self.use_mla or self.kv_info.is_kv_replicated(remote_engine_id):
+            # With replicated KV cache, only the number of blocks can differ.
+            assert self.block_len_per_layer == nixl_agent_meta.block_lens, (
+                "KV cache sizes must match between P and D when replicated"
+            )
+            remote_block_size = remote_block_len // (self.slot_size_per_layer[0])
+        else:
+            # When MLA is not used, this is a list of the same block length
+            for block_len in nixl_agent_meta.block_lens:
+                assert block_len == remote_block_len, (
+                    "All remote layers must have the same block size"
+                )
+            remote_block_size = remote_block_len // (
+                self.slot_size_per_layer[0] * tp_ratio
+            )
+            if self._use_flashinfer:
+                # With flashinfer, KV are sent in the same message.
+                remote_block_size //= 2
+
+        assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, (
+            "Remote P worker KV layer cache must be of shape [2, N, "
+            "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
+        )
+
+        assert self.block_size == remote_block_size, (
+            "Remote P worker with different page/block size is not supported "
+            f"{self.block_size=}, {remote_block_size=}"
+        )
+
+        # TP workers have same #blocks.
+        assert self.dst_num_blocks[remote_engine_id] == nixl_agent_meta.num_blocks
+
+        assert len(nixl_agent_meta.kv_caches_base_addr) == len(self.block_len_per_layer)
+
     def sync_recved_kv_to_device(self, req_id: str, meta: ReqMeta):
         """copy recved kv from host buffer to device."""
         assert self.use_host_buffer
@@ -1383,14 +1442,14 @@ class NixlConnectorWorker:

         # Number of D TP workers that will read from dst P. Propagate tp_ratio
         # on notification so that dst worker can wait before freeing blocks.
-        tp_ratio = self._tp_size[self.engine_id] // self._tp_size[dst_engine_id]
+        tp_ratio = self.kv_info.tp_ratio(dst_engine_id)
         notif_id = f"{request_id}:{tp_ratio}".encode()

         # Full prefix cache hit: do not need to read remote blocks,
         # just notify P worker that we have the blocks we need.
         num_local_blocks = len(local_block_ids)
         if num_local_blocks == 0:
-            remote_rank = self.tp_rank // tp_ratio
+            remote_rank = self.kv_info.get_target_remote_rank(dst_engine_id)
            agent_name = self._remote_agents[dst_engine_id][remote_rank]
            self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id)
            return
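
The KVInfo helpers introduced above centralize the TP-ratio arithmetic that was previously repeated at each call site. The sketch below is illustrative only: it uses a simplified stand-in for the nested dataclass (no vLLM types) and made-up TP sizes, to show how tp_ratio and get_target_remote_rank map decode ranks onto prefill ranks.

# Illustrative sketch, not part of the patch; mirrors the KVInfo helpers above.
from dataclasses import dataclass

@dataclass
class KVInfoSketch:
    tp_size: int                    # local (decode) TP size
    remote_tp_size: dict[str, int]  # remote (prefill) TP size per engine
    tp_rank: int                    # local TP rank

    def tp_ratio(self, remote_engine_id: str) -> int:
        # Number of D TP workers reading from a single P TP worker.
        return self.tp_size // self.remote_tp_size[remote_engine_id]

    def get_target_remote_rank(self, remote_engine_id: str) -> int:
        # Local rank i pulls from remote rank i // tp_ratio.
        return self.tp_rank // self.tp_ratio(remote_engine_id)

# Decode TP=4 pulling from prefill TP=2: tp_ratio=2, so decode ranks
# 0,1,2,3 read from prefill ranks 0,0,1,1 respectively.
workers = [KVInfoSketch(4, {"prefill": 2}, r) for r in range(4)]
assert [w.get_target_remote_rank("prefill") for w in workers] == [0, 0, 1, 1]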
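
For the block-length invariant checked in _validate_remote_agent_handshake, the back-of-the-envelope sketch below uses assumed shapes (block size, head count, head dim, and dtype width are made up, and slot size is taken here to mean the bytes of one K or V token slot) to show why remote_block_len equals block_len_per_layer[0] * tp_ratio and why dividing by slot_size_per_layer[0] * tp_ratio recovers the same block size on both sides.

# Illustrative arithmetic with assumed shapes; not taken from the patch.
def block_len(block_size: int, kv_heads: int, head_dim: int, dtype_bytes: int) -> int:
    # Bytes of one K (or V) cache region of a single block for one layer.
    return block_size * kv_heads * head_dim * dtype_bytes

tp_ratio = 2                                 # decode TP 4, prefill TP 2
block_size, head_dim, dtype_bytes = 16, 128, 2
local_kv_heads = 4                           # heads held by one decode worker
remote_kv_heads = local_kv_heads * tp_ratio  # prefill worker holds tp_ratio x heads

local_block_len = block_len(block_size, local_kv_heads, head_dim, dtype_bytes)
remote_block_len = block_len(block_size, remote_kv_heads, head_dim, dtype_bytes)
assert remote_block_len == local_block_len * tp_ratio

# Bytes per local token slot; dividing the remote block by slot_size * tp_ratio
# recovers the shared page size, which is what the validation asserts.
slot_size = local_kv_heads * head_dim * dtype_bytes
assert remote_block_len // (slot_size * tp_ratio) == block_size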
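
The rank_offset computation in add_remote_agent splits each remote prefill block along the kv_head dimension so that every decode rank registers only its own slice, as the "PTP1 DTP2" comment in the diff describes. A small sketch with hypothetical byte sizes:

# Sketch with hypothetical sizes; follows
#   rank_offset = tp_rank % tp_ratio * kv_block_len
# for the non-replicated case (replicates_kv_cache == False).
tp_ratio = 2
kv_block_len = 16384                       # local worker's share of one remote block, bytes
remote_block_len = kv_block_len * tp_ratio

for tp_rank in range(4):                   # decode TP = 4
    rank_offset = tp_rank % tp_ratio * kv_block_len
    target_p_rank = tp_rank // tp_ratio
    # D ranks 0 and 2 read bytes [0, 16384) of their target P rank's block;
    # D ranks 1 and 3 read bytes [16384, 32768).
    print(tp_rank, target_p_rank, (rank_offset, rank_offset + kv_block_len))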