diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py
index 869e80a1af88..a80abf10a8a1 100644
--- a/tests/v1/kv_connector/unit/test_nixl_connector.py
+++ b/tests/v1/kv_connector/unit/test_nixl_connector.py
@@ -565,8 +565,6 @@ class TestNixlHandshake:
             kv_cache_layout=mismatched_layout,
         )
 
-        # We don't check layout for homogeneous TP and MLA for now, as the
-        # whole block is moved.
         with pytest.raises(RuntimeError):
             # mismatched layout is expected to fail
             worker.add_remote_agent(meta, remote_tp_size=2)
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index 6d80667788d6..ede5db646f3d 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -36,7 +36,6 @@ from vllm.distributed.parallel_state import (
     get_tensor_model_parallel_world_size,
     get_tp_group,
 )
-from vllm.distributed.utils import divide
 from vllm.forward_context import ForwardContext
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
@@ -521,6 +520,72 @@ class NixlConnectorScheduler:
 class NixlConnectorWorker:
     """Implementation of Worker side methods"""
 
+    @dataclass
+    class TpKVTopology:
+        """
+        Helper class for tensor parallel and KV topology information for
+        mapping between local and remote TP workers.
+        """
+
+        tp_size: int
+        tp_rank: int
+        remote_tp_size: dict[EngineId, int]
+        is_mla: bool
+        total_num_kv_heads: int
+
+        def tp_ratio(
+            self,
+            remote_tp_size: int,
+        ) -> int:
+            """
+            Calculate the tensor parallel ratio between local and remote TP.
+            We can think of it as the number of local TP workers per remote
+            TP worker. Local workers will read from the same remote TP worker
+            in groups of size `tp_ratio`.
+            """
+            assert self.tp_size % remote_tp_size == 0, (
+                f"Local tensor parallel size {self.tp_size} is not divisible "
+                f"by remote tensor parallel size {remote_tp_size}."
+            )
+            return self.tp_size // remote_tp_size
+
+        def tp_ratio_from_engine_id(
+            self,
+            remote_engine_id: EngineId,
+        ) -> int:
+            remote_tp_size = self.remote_tp_size[remote_engine_id]
+            return self.tp_ratio(remote_tp_size)
+
+        def is_kv_replicated(self, engine_id: EngineId) -> bool:
+            """
+            Whether the KV cache is replicated across TP workers because the
+            number of TP workers is at least the number of KV heads.
+            """
+            tp_size = self.remote_tp_size[engine_id]
+            return tp_size // self.total_num_kv_heads >= 1
+
+        def replicates_kv_cache(self, remote_engine_id: EngineId) -> bool:
+            # MLA is always replicated as the hidden dim can't be split.
+            return self.is_mla or self.is_kv_replicated(remote_engine_id)
+
+        def get_target_remote_rank(
+            self,
+            remote_tp_size: int,
+        ) -> int:
+            """
+            Get the remote TP rank (on P) that the current local TP rank
+            (on D) will read from.
+            """
+            tp_ratio = self.tp_ratio(remote_tp_size)
+            return self.tp_rank // tp_ratio
+
+        def get_target_remote_rank_from_engine_id(
+            self,
+            remote_engine_id: EngineId,
+        ) -> int:
+            remote_tp_size = self.remote_tp_size[remote_engine_id]
+            return self.get_target_remote_rank(remote_tp_size)
+
     def __init__(self, vllm_config: VllmConfig, engine_id: str):
         if NixlWrapper is None:
             logger.error("NIXL is not available")
@@ -534,6 +599,7 @@ class NixlConnectorWorker:
         if vllm_config.kv_transfer_config is None:
             raise ValueError("kv_transfer_config must be set for NixlConnector")
+        self.kv_transfer_config = vllm_config.kv_transfer_config
 
         self.nixl_backends = vllm_config.kv_transfer_config.get_from_extra_config(
             "backends", ["UCX"]
         )
@@ -654,7 +720,6 @@ class NixlConnectorWorker:
         # Protects _handshake_futures and _remote_agents.
         self._handshake_lock = threading.RLock()
 
-        self.vllm_config = vllm_config
         self.block_size = vllm_config.cache_config.block_size
         self.model_config = vllm_config.model_config
         self.cache_config = vllm_config.cache_config
@@ -686,6 +751,14 @@ class NixlConnectorWorker:
         self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int)
         self.xfer_stats = NixlKVConnectorStats()
 
+        self.kv_topo = self.TpKVTopology(
+            tp_size=self.world_size,
+            tp_rank=self.tp_rank,
+            remote_tp_size=self._tp_size,  # shared state
+            is_mla=self.use_mla,
+            total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
+        )
+
     @staticmethod
     def _nixl_handshake_listener(
         metadata: NixlAgentMetadata,
@@ -731,8 +804,7 @@ class NixlConnectorWorker:
 
         # Handshake only with the remote TP rank that current local rank will
        # pull from. With homogeneous TP it happens to be the same rank_i.
-        tp_ratio = self._tp_size[self.engine_id] // remote_tp_size
-        p_remote_rank = self.tp_rank // tp_ratio
+        p_remote_rank = self.kv_topo.get_target_remote_rank(remote_tp_size)
         path = make_zmq_path("tcp", host, port + p_remote_rank)
         logger.debug(
             "Querying metadata on path: %s at remote rank %s", path, p_remote_rank
@@ -989,13 +1061,11 @@ class NixlConnectorWorker:
 
         # TODO(mgoin): Hybrid memory allocator is currently disabled for
         # models with local attention (Llama 4). Can remove this once enabled.
-        if self.vllm_config.model_config.hf_config.model_type == "llama4":
+        if self.model_config.hf_config.model_type == "llama4":
             from transformers import Llama4TextConfig
 
-            assert isinstance(
-                self.vllm_config.model_config.hf_text_config, Llama4TextConfig
-            )
-            llama4_config = self.vllm_config.model_config.hf_text_config
+            assert isinstance(self.model_config.hf_text_config, Llama4TextConfig)
+            llama4_config = self.model_config.hf_text_config
             no_rope_layers = llama4_config.no_rope_layers
             chunk_size = llama4_config.attention_chunk_size
             chunk_block_size = math.ceil(chunk_size / self.block_size)
@@ -1078,107 +1148,51 @@ class NixlConnectorWorker:
         engine_id = nixl_agent_meta.engine_id
         # TODO re-evaluate refreshing for scaling/recovery
         if remote_tp_rank in self._remote_agents.get(engine_id, {}):
+            logger.debug(
+                "Remote agent with engine_id %s and rank "
+                "%s already exchanged metadata, skip handshake.",
+                engine_id,
+                remote_tp_rank,
+            )
             return self._remote_agents[engine_id][remote_tp_rank]
 
+        ### Register remote agent metadata
         if engine_id not in self._tp_size:
             self._tp_size[engine_id] = remote_tp_size
-        else:
-            assert self._tp_size[engine_id] == remote_tp_size
-        # TODO We may eventually want to skip enforcing the same attn backend.
-        assert nixl_agent_meta.attn_backend_name == self.backend_name
 
         remote_agent_name = self.nixl_wrapper.add_remote_agent(
             nixl_agent_meta.agent_metadata
         )
 
-        # Number of D TP workers reading from a single P TP worker. This is
-        # 1 when P and D `--tensor-parallel-size` match.
-        tp_ratio = divide(self._tp_size[self.engine_id], self._tp_size[engine_id])
-        assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP"
-        assert not self._use_pallas or tp_ratio == 1, (
-            "TPU (pallas_v1) DOES NOT support heterogeneous TP yet."
-        )
-        # Handle tp_size>num_kv_heads: replicate KV cache.
-        total_num_kv_heads = self.model_config.get_total_num_kv_heads()
-        is_kv_replicated = self._tp_size[engine_id] // total_num_kv_heads >= 1
+        replicates_kv_cache = self.kv_topo.replicates_kv_cache(engine_id)
 
-        remote_block_len = nixl_agent_meta.block_lens[0]
-        if nixl_agent_meta.kv_cache_layout != self.kv_cache_layout:
-            if (
-                self.vllm_config.kv_transfer_config is not None
-                and self.vllm_config.kv_transfer_config.enable_permute_local_kv
-                and nixl_agent_meta.kv_cache_layout == "HND"
-            ):
-                logger.info(
-                    "Remote is HND and local is NHD, enabled additional permute "
-                    "on local device KV."
-                )
-                self.enable_permute_local_kv = True
-            else:
-                raise RuntimeError(
-                    "Heterogeneous TP expects same kv_cache_layout. "
-                    "Or enable experimental feature to use HND to NHD support by "
-                    "setting 'enable_permute_local_kv'=True in --kv-transfer-config."
-                )
-        if self.use_mla or is_kv_replicated:
-            # With replicated KV cache, only the number of blocks can differ.
-            assert self.block_len_per_layer == nixl_agent_meta.block_lens, (
-                "KV cache sizes must match between P and D when replicated"
-            )
-            remote_block_size = remote_block_len // (self.slot_size_per_layer[0])
-        else:
-            # When MLA is not used, this is a list of the same block length
-            for block_len in nixl_agent_meta.block_lens:
-                assert block_len == remote_block_len, (
-                    "All remote layers must have the same block size"
-                )
-            remote_block_size = remote_block_len // (
-                self.slot_size_per_layer[0] * tp_ratio
-            )
-            if self._use_flashinfer:
-                # With flashinfer, KV are sent in the same message.
-                remote_block_size //= 2
-            if tp_ratio > 1:
-                # Heterogeneous TP expects same kv_cache_layout.
-                if nixl_agent_meta.kv_cache_layout == "NHD":
-                    raise ValueError(
-                        "Heterogeneous TP is not supported for remote with NHD."
-                    )
-                if self.device_type == "xpu":
-                    raise ValueError("Heterogeneous TP is not supported on XPU")
-
-        assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, (
-            "Remote P worker KV layer cache must be of shape [2, N, "
-            "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
-        )
-
-        assert self.block_size == remote_block_size, (
-            "Remote P worker with different page/block size is not supported "
-            f"{self.block_size=}, {remote_block_size=}"
-        )
-
-        # Create dst descs and xfer side handles. TP workers have same #blocks.
-        if engine_id in self.dst_num_blocks:
-            assert self.dst_num_blocks[engine_id] == nixl_agent_meta.num_blocks
-        else:
+        # Create dst descs and xfer side handles. TP workers have same #blocks
+        # so we only register once per engine_id.
+        if engine_id not in self.dst_num_blocks:
             self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks
 
+        # Keep track of remote agent kv caches base addresses.
+        self.kv_caches_base_addr[engine_id] = nixl_agent_meta.kv_caches_base_addr
+
+        self._validate_remote_agent_handshake(nixl_agent_meta, remote_tp_size)
+
+        # Number of D TP workers reading from a single P TP worker. This is
+        # 1 when P and D `--tensor-parallel-size` match.
+        tp_ratio = self.kv_topo.tp_ratio_from_engine_id(engine_id)
+
+        ### Register remote agent memory regions
         blocks_data = []
         # With homogeneous TP, D pulls the whole kv cache from corresponding
         # rank. With heterogeneous TP, prepare the descriptors by splitting the
         # P KV cache along kv_head dim, of D worker's kv_head size (D>P).
         # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
-        self.kv_caches_base_addr[engine_id] = nixl_agent_meta.kv_caches_base_addr
-        assert len(nixl_agent_meta.kv_caches_base_addr) == len(self.block_len_per_layer)
 
         # Register all remote blocks, but only the corresponding kv heads.
         for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
             kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
             rank_offset = (
-                self.tp_rank % tp_ratio * kv_block_len
-                if not (self.use_mla or is_kv_replicated)
-                else 0
+                self.tp_rank % tp_ratio * kv_block_len if not replicates_kv_cache else 0
             )
             for block_id in range(nixl_agent_meta.num_blocks):
                 block_offset = block_id * nixl_agent_meta.block_lens[i]
@@ -1213,6 +1227,80 @@ class NixlConnectorWorker:
 
         return remote_agent_name
 
+    def _validate_remote_agent_handshake(
+        self, nixl_agent_meta: NixlAgentMetadata, remote_tp_size: int
+    ):
+        """
+        Validate the remote agent handshake metadata ensuring the
+        invariants hold true.
+        """
+        remote_engine_id = nixl_agent_meta.engine_id
+
+        assert self._tp_size[remote_engine_id] == remote_tp_size
+        # TODO We may eventually want to skip enforcing the same attn backend.
+        assert nixl_agent_meta.attn_backend_name == self.backend_name
+
+        tp_ratio = self.kv_topo.tp_ratio_from_engine_id(remote_engine_id)
+        assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP"
+        assert not self._use_pallas or tp_ratio == 1, (
+            "TPU (pallas_v1) DOES NOT support heterogeneous TP yet."
+        )
+        if not self.use_mla and nixl_agent_meta.kv_cache_layout != self.kv_cache_layout:
+            if (
+                self.kv_transfer_config.enable_permute_local_kv
+                and nixl_agent_meta.kv_cache_layout == "HND"
+            ):
+                logger.info(
+                    "Remote is HND and local is NHD, enabled additional permute "
+                    "on local device KV."
+                )
+                self.enable_permute_local_kv = True
+            else:
+                raise RuntimeError(
+                    "Heterogeneous TP expects same kv_cache_layout. "
+                    "Or enable experimental feature to use HND to NHD support by "
+                    "setting 'enable_permute_local_kv'=True in --kv-transfer-config."
+                )
+
+        # Block len can only vary across layers when using MLA.
+        remote_block_len = nixl_agent_meta.block_lens[0]
+        if self.use_mla or self.kv_topo.is_kv_replicated(remote_engine_id):
+            # With replicated KV cache, only the number of blocks can differ.
+            assert self.block_len_per_layer == nixl_agent_meta.block_lens, (
+                "KV cache sizes must match between P and D when replicated"
+            )
+            remote_block_size = remote_block_len // (self.slot_size_per_layer[0])
+        else:
+            if tp_ratio > 1 and self.device_type == "xpu":
+                # XPU uses NHD, hence it does not support splitting on H
+                raise ValueError("Heterogeneous TP is not supported on XPU")
+            # When MLA is not used, this is a list of the same block length
+            for block_len in nixl_agent_meta.block_lens:
+                assert block_len == remote_block_len, (
+                    "All remote layers must have the same block size"
+                )
+            remote_block_size = remote_block_len // (
+                self.slot_size_per_layer[0] * tp_ratio
+            )
+            if self._use_flashinfer:
+                # With flashinfer, KV are sent in the same message.
+                remote_block_size //= 2
+
+        assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, (
+            "Remote P worker KV layer cache must be of shape [2, N, "
+            "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
+        )
+
+        assert self.block_size == remote_block_size, (
+            "Remote P worker with different page/block size is not supported "
+            f"{self.block_size=}, {remote_block_size=}"
+        )
+
+        # TP workers have same #blocks.
+        assert self.dst_num_blocks[remote_engine_id] == nixl_agent_meta.num_blocks
+
+        assert len(nixl_agent_meta.kv_caches_base_addr) == len(self.block_len_per_layer)
+
     def sync_recved_kv_to_device(self, req_id: str, meta: ReqMeta):
         """copy recved kv from host buffer to device."""
         assert self.use_host_buffer
@@ -1505,14 +1593,16 @@ class NixlConnectorWorker:
 
         # Number of D TP workers that will read from dst P. Propagate tp_ratio
         # on notification so that dst worker can wait before freeing blocks.
-        tp_ratio = self._tp_size[self.engine_id] // self._tp_size[dst_engine_id]
+        tp_ratio = self.kv_topo.tp_ratio_from_engine_id(dst_engine_id)
         notif_id = f"{request_id}:{tp_ratio}".encode()
 
         # Full prefix cache hit: do not need to read remote blocks,
         # just notify P worker that we have the blocks we need.
         num_local_blocks = len(local_block_ids)
         if num_local_blocks == 0:
-            remote_rank = self.tp_rank // tp_ratio
+            remote_rank = self.kv_topo.get_target_remote_rank_from_engine_id(
+                dst_engine_id
+            )
             agent_name = self._remote_agents[dst_engine_id][remote_rank]
             try:
                 self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id)
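
For reference, the standalone sketch below (not part of the patch) illustrates the rank-mapping arithmetic that the new TpKVTopology helper encapsulates: with decode TP a multiple of prefill TP, local workers map onto remote workers in groups of size tp_ratio. The class name, fields, and values here are illustrative assumptions for demonstration only, not the vLLM API.

# Illustrative sketch of the TpKVTopology rank mapping (assumed names/values).
from dataclasses import dataclass


@dataclass
class TopologySketch:
    tp_size: int          # local (decode) tensor-parallel size
    tp_rank: int          # local tensor-parallel rank
    remote_tp_size: int   # remote (prefill) tensor-parallel size

    def tp_ratio(self) -> int:
        # Number of local TP workers reading from one remote TP worker.
        assert self.tp_size % self.remote_tp_size == 0
        return self.tp_size // self.remote_tp_size

    def target_remote_rank(self) -> int:
        # Local workers map onto remote workers in groups of size tp_ratio.
        return self.tp_rank // self.tp_ratio()


if __name__ == "__main__":
    # Decode TP=4 pulling from prefill TP=2: tp_ratio=2, so local ranks
    # 0 and 1 read from remote rank 0, while local ranks 2 and 3 read
    # from remote rank 1.
    for rank in range(4):
        topo = TopologySketch(tp_size=4, tp_rank=rank, remote_tp_size=2)
        print(f"local rank {rank} -> remote rank {topo.target_remote_rank()}")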