Signed-off-by: NickLucche <nlucches@redhat.com>
NickLucche 2025-10-08 14:12:35 +00:00
parent 84dfd367a1
commit 5d45b77124

@@ -516,6 +516,12 @@ class NixlConnectorWorker:
        remote_engine_id: Optional[EngineId] = None,
        remote_tp_size: Optional[int] = None,
    ) -> int:
        """
        Calculate the tensor parallel ratio between local and remote TP.
        We can think of it as the number of local TP workers per remote
        TP worker. Local workers will read from the same remote TP worker
        in groups of size `tp_ratio`.
        """
        if remote_tp_size is None:
            assert remote_engine_id is not None
            remote_tp_size = self.remote_tp_size[remote_engine_id]
@@ -525,11 +531,16 @@ class NixlConnectorWorker:
        )
        return self.tp_size // remote_tp_size

    def is_kv_replicated(self, remote_engine_id: EngineId) -> bool:
        tp_size = self.remote_tp_size[remote_engine_id]
    def is_kv_replicated(self, engine_id: EngineId) -> bool:
        """
        Whether the KV cache is replicated across TP workers due to the
        number of TP workers being greater than the number of KV heads.
        """
        tp_size = self.remote_tp_size[engine_id]
        return tp_size // self.total_num_kv_heads >= 1

    def replicates_kv_cache(self, remote_engine_id: EngineId) -> bool:
        # MLA is always replicated as the hidden dim can't be split.
        return self.is_mla or self.is_kv_replicated(remote_engine_id)

    def get_target_remote_rank(
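The helpers in the hunk above reduce to integer arithmetic over the local and remote TP sizes. A minimal standalone sketch of that logic follows; the sizes and variable names are illustrative assumptions, not values taken from the actual connector state:

# Assumed example: local (decode) TP=8, remote (prefill) TP=2, 4 KV heads, no MLA.
local_tp_size = 8
remote_tp_size = 2
total_num_kv_heads = 4
is_mla = False

# tp_ratio: local TP workers per remote TP worker (here 8 // 2 == 4),
# i.e. local workers read from the same remote worker in groups of 4.
tp_ratio = local_tp_size // remote_tp_size

# KV cache counts as replicated once the remote TP size reaches the KV head count.
is_kv_replicated = remote_tp_size // total_num_kv_heads >= 1  # 2 // 4 == 0 -> False
replicates_kv_cache = is_mla or is_kv_replicated  # MLA always replicates

print(tp_ratio, is_kv_replicated, replicates_kv_cache)  # 4 False False
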
@@ -537,6 +548,10 @@ class NixlConnectorWorker:
        remote_engine_id: Optional[EngineId] = None,
        remote_tp_size: Optional[int] = None,
    ) -> int:
        """
        Get the remote TP rank (on P) that the current local TP rank
        (on D) will read from.
        """
        tp_ratio = self.tp_ratio(remote_engine_id, remote_tp_size)
        return self.tp_rank // tp_ratio
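
Combined with tp_ratio, get_target_remote_rank yields a many-to-one mapping from local (D) ranks onto remote (P) ranks. A small illustrative sketch, reusing the assumed sizes above (local TP=8, remote TP=2, so tp_ratio=4):

# Each group of tp_ratio consecutive local ranks targets one remote rank.
tp_ratio = 4
for tp_rank in range(8):  # local decode-side TP ranks
    target_remote_rank = tp_rank // tp_ratio
    print(f"local rank {tp_rank} -> remote rank {target_remote_rank}")
# local ranks 0-3 -> remote rank 0, local ranks 4-7 -> remote rank 1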