Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
NickLucche 2025-10-08 12:56:04 +00:00
parent 1a1c81ca2f
commit 84dfd367a1

View File

@ -519,6 +519,10 @@ class NixlConnectorWorker:
if remote_tp_size is None: if remote_tp_size is None:
assert remote_engine_id is not None assert remote_engine_id is not None
remote_tp_size = self.remote_tp_size[remote_engine_id] remote_tp_size = self.remote_tp_size[remote_engine_id]
assert self.tp_size % remote_tp_size == 0, (
f"Local tensor parallel size {self.tp_size} is not divisible "
f"by remote tensor parallel size {remote_tp_size}."
)
return self.tp_size // remote_tp_size return self.tp_size // remote_tp_size
def is_kv_replicated(self, remote_engine_id: EngineId) -> bool: def is_kv_replicated(self, remote_engine_id: EngineId) -> bool:
@ -1174,8 +1178,6 @@ class NixlConnectorWorker:
assert not self._use_pallas or tp_ratio == 1, ( assert not self._use_pallas or tp_ratio == 1, (
"TPU (pallas_v1) DOES NOT support heterogeneous TP yet." "TPU (pallas_v1) DOES NOT support heterogeneous TP yet."
) )
if tp_ratio > 1 and self.device_type == "xpu":
raise ValueError("Heterogeneous TP is not supported on XPU")
# Block len can only vary across layers when using MLA. # Block len can only vary across layers when using MLA.
remote_block_len = nixl_agent_meta.block_lens[0] remote_block_len = nixl_agent_meta.block_lens[0]
@ -1186,6 +1188,9 @@ class NixlConnectorWorker:
) )
remote_block_size = remote_block_len // (self.slot_size_per_layer[0]) remote_block_size = remote_block_len // (self.slot_size_per_layer[0])
else: else:
if tp_ratio > 1 and self.device_type == "xpu":
# XPU uses NHD, hence it does not support splitting on H
raise ValueError("Heterogeneous TP is not supported on XPU")
# When MLA is not used, this is a list of the same block length # When MLA is not used, this is a list of the same block length
for block_len in nixl_agent_meta.block_lens: for block_len in nixl_agent_meta.block_lens:
assert block_len == remote_block_len, ( assert block_len == remote_block_len, (