From 5d45b77124a1e8e464274a1e81d4c3106f02597a Mon Sep 17 00:00:00 2001 From: NickLucche Date: Wed, 8 Oct 2025 14:12:35 +0000 Subject: [PATCH] docs Signed-off-by: NickLucche --- .../kv_connector/v1/nixl_connector.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index ca04f5565411c..ca53d1df92aec 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -516,6 +516,12 @@ class NixlConnectorWorker: remote_engine_id: Optional[EngineId] = None, remote_tp_size: Optional[int] = None, ) -> int: + """ + Calculate the tensor parallel ratio between local and remote TP. + We can think of it as the number of local TP workers per remote TP + worker. Local workers will read from the same remote TP worker in + groups of size `tp_ratio`. + """ if remote_tp_size is None: assert remote_engine_id is not None remote_tp_size = self.remote_tp_size[remote_engine_id] @@ -525,11 +531,16 @@ class NixlConnectorWorker: ) return self.tp_size // remote_tp_size - def is_kv_replicated(self, remote_engine_id: EngineId) -> bool: - tp_size = self.remote_tp_size[remote_engine_id] + def is_kv_replicated(self, engine_id: EngineId) -> bool: + """ + Whether the KV cache is replicated across TP workers due to the + number of TP workers being greater than or equal to the number of KV heads. + """ + tp_size = self.remote_tp_size[engine_id] return tp_size // self.total_num_kv_heads >= 1 def replicates_kv_cache(self, remote_engine_id: EngineId) -> bool: + # MLA is always replicated as the hidden dim can't be split. 
return self.is_mla or self.is_kv_replicated(remote_engine_id) def get_target_remote_rank( @@ -537,6 +548,10 @@ class NixlConnectorWorker: remote_engine_id: Optional[EngineId] = None, remote_tp_size: Optional[int] = None, ) -> int: + """ + Get the remote TP rank (on P) that the current local TP rank + (on D) will read from. + """ tp_ratio = self.tp_ratio(remote_engine_id, remote_tp_size) return self.tp_rank // tp_ratio