Signed-off-by: NickLucche <nlucches@redhat.com>
NickLucche 2025-10-08 12:56:04 +00:00
parent 1a1c81ca2f
commit 84dfd367a1


@@ -519,6 +519,10 @@ class NixlConnectorWorker:
         if remote_tp_size is None:
             assert remote_engine_id is not None
             remote_tp_size = self.remote_tp_size[remote_engine_id]
+        assert self.tp_size % remote_tp_size == 0, (
+            f"Local tensor parallel size {self.tp_size} is not divisible "
+            f"by remote tensor parallel size {remote_tp_size}."
+        )
         return self.tp_size // remote_tp_size
 
     def is_kv_replicated(self, remote_engine_id: EngineId) -> bool:
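
For context, the added assertion guards the floor division that derives the TP ratio. A minimal standalone sketch (illustrative only, not the vLLM method itself; `tp_ratio` and its parameter names are made up here) of the check and the failure mode it rules out:

```python
def tp_ratio(local_tp_size: int, remote_tp_size: int) -> int:
    # Heterogeneous TP assumes each remote rank's KV cache maps onto a whole
    # number of local ranks, i.e. the local TP size is a multiple of the
    # remote TP size.
    assert local_tp_size % remote_tp_size == 0, (
        f"Local tensor parallel size {local_tp_size} is not divisible "
        f"by remote tensor parallel size {remote_tp_size}."
    )
    return local_tp_size // remote_tp_size


print(tp_ratio(8, 2))  # 4: four local ranks share each remote rank's KV cache
print(tp_ratio(8, 3))  # AssertionError, instead of silently truncating 8 // 3 to 2
```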
@@ -1174,8 +1178,6 @@ class NixlConnectorWorker:
         assert not self._use_pallas or tp_ratio == 1, (
             "TPU (pallas_v1) DOES NOT support heterogeneous TP yet."
         )
-        if tp_ratio > 1 and self.device_type == "xpu":
-            raise ValueError("Heterogeneous TP is not supported on XPU")
 
         # Block len can only vary across layers when using MLA.
         remote_block_len = nixl_agent_meta.block_lens[0]
@@ -1186,6 +1188,9 @@ class NixlConnectorWorker:
             )
             remote_block_size = remote_block_len // (self.slot_size_per_layer[0])
         else:
+            if tp_ratio > 1 and self.device_type == "xpu":
+                # XPU uses NHD, hence it does not support splitting on H
+                raise ValueError("Heterogeneous TP is not supported on XPU")
             # When MLA is not used, this is a list of the same block length
             for block_len in nixl_agent_meta.block_lens:
                 assert block_len == remote_block_len, (
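
The inline comment above carries the rationale for moving the check into the non-MLA branch: with `tp_ratio > 1` each local rank reads only its slice of the remote KV block along the head dimension, which is a single contiguous region only when heads are the outer axis (HND). A hedged illustration in plain PyTorch (shapes are illustrative; this is not the connector's transfer code) of why an NHD layout breaks that assumption:

```python
import torch

block_size, num_heads, head_dim = 16, 8, 128
tp_ratio = 2
heads_per_rank = num_heads // tp_ratio

# HND layout: heads are the outer axis, so one local rank's share of the
# remote block (half of the heads here) is one contiguous chunk that can be
# addressed with a simple offset into the remote buffer.
hnd = torch.zeros(num_heads, block_size, head_dim)
assert hnd[:heads_per_rank].is_contiguous()

# NHD layout: tokens are the outer axis; selecting a subset of heads yields a
# strided view, so the block cannot be split on H by offsetting alone.
nhd = torch.zeros(block_size, num_heads, head_dim)
assert not nhd[:, :heads_per_rank].is_contiguous()
```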