From 2ab27b70f5be145c36cc1a86e8ce4acba77aceb0 Mon Sep 17 00:00:00 2001
From: Fanli Lin
Date: Tue, 23 Sep 2025 22:32:57 +0800
Subject: [PATCH] [XPU] Fix MOE DP accuracy issue on XPU (#25465)

Signed-off-by: yewentao256
---
 examples/offline_inference/data_parallel.py   | 11 ++++++++++-
 .../device_communicators/xpu_communicator.py  | 19 +++++++++++++++++++
 2 files changed, 29 insertions(+), 1 deletion(-)

diff --git a/examples/offline_inference/data_parallel.py b/examples/offline_inference/data_parallel.py
index 98fe36d0fb796..0076d4d30ee8e 100644
--- a/examples/offline_inference/data_parallel.py
+++ b/examples/offline_inference/data_parallel.py
@@ -101,6 +101,13 @@ def parse_args():
         "--quantization",
         type=str,
     )
+    parser.add_argument(
+        "--disable-expert-parallel",
+        dest="enable_expert_parallel",
+        action="store_false",
+        help="Disable expert parallel (default: enabled).",
+    )
+    parser.set_defaults(enable_expert_parallel=True)
     return parser.parse_args()
 
 
@@ -113,6 +120,7 @@ def main(
     dp_master_port,
     GPUs_per_dp_rank,
     enforce_eager,
+    enable_expert_parallel,
     trust_remote_code,
     max_num_seqs,
     max_model_len,
@@ -168,7 +176,7 @@ def main(
         model=model,
         tensor_parallel_size=GPUs_per_dp_rank,
         enforce_eager=enforce_eager,
-        enable_expert_parallel=True,
+        enable_expert_parallel=enable_expert_parallel,
         trust_remote_code=trust_remote_code,
         max_num_seqs=max_num_seqs,
         max_model_len=max_model_len,
@@ -229,6 +237,7 @@ if __name__ == "__main__":
                 dp_master_port,
                 tp_size,
                 args.enforce_eager,
+                args.enable_expert_parallel,
                 args.trust_remote_code,
                 args.max_num_seqs,
                 args.max_model_len,
diff --git a/vllm/distributed/device_communicators/xpu_communicator.py b/vllm/distributed/device_communicators/xpu_communicator.py
index 067315deb773d..b236bae261e03 100644
--- a/vllm/distributed/device_communicators/xpu_communicator.py
+++ b/vllm/distributed/device_communicators/xpu_communicator.py
@@ -25,6 +25,12 @@ class XpuCommunicator(DeviceCommunicatorBase):
         super().__init__(cpu_group, device, device_group, unique_name)
         if self.use_all2all:
             all2all_backend = envs.VLLM_ALL2ALL_BACKEND
+            if all2all_backend != "naive":
+                logger.warning(
+                    "`%s` all2all manager is not supported on XPU. "
+                    "Falling back to the `naive` all2all manager.",
+                    all2all_backend)
+                all2all_backend = "naive"
             if all2all_backend == "naive":
                 from .all2all import NaiveAll2AllManager
                 self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
@@ -67,3 +73,16 @@ class XpuCommunicator(DeviceCommunicatorBase):
 
     def broadcast(self, input_: torch.Tensor, src: int = 0) -> None:
         dist.broadcast(input_, src=src, group=self.device_group)
+
+    def dispatch(
+            self, hidden_states: torch.Tensor,
+            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        assert self.all2all_manager is not None
+        hidden_states, router_logits = self.all2all_manager.dispatch(
+            hidden_states, router_logits)
+        return hidden_states, router_logits
+
+    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        assert self.all2all_manager is not None
+        hidden_states = self.all2all_manager.combine(hidden_states)
+        return hidden_states
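
Note on the change: the new dispatch()/combine() overrides delegate to the
all2all manager around the MoE layer; without them, a DP rank on XPU ran MoE
on its local token shard only, which is the accuracy issue this patch fixes.
As a conceptual illustration of what a naive all2all dispatch/combine pair
amounts to, a minimal sketch follows. It is not vLLM's implementation:
toy_dispatch and toy_combine are hypothetical names, and equal token counts
per DP rank are assumed for simplicity.

    import torch
    import torch.distributed as dist

    def toy_dispatch(hidden_states: torch.Tensor,
                     group: dist.ProcessGroup) -> torch.Tensor:
        # Before the MoE layer: gather every DP rank's tokens so the
        # locally hosted experts can process the full global batch.
        world = dist.get_world_size(group)
        gathered = [torch.empty_like(hidden_states) for _ in range(world)]
        dist.all_gather(gathered, hidden_states, group=group)
        return torch.cat(gathered, dim=0)

    def toy_combine(hidden_states: torch.Tensor,
                    group: dist.ProcessGroup) -> torch.Tensor:
        # After the MoE layer: sum the partial expert outputs from all
        # ranks, then keep only this rank's slice of the token batch.
        dist.all_reduce(hidden_states, group=group)
        rank = dist.get_rank(group)
        per_rank = hidden_states.shape[0] // dist.get_world_size(group)
        return hidden_states.narrow(0, rank * per_rank, per_rank)

With the example script, expert parallelism stays on by default and the new
flag turns it off, which makes an A/B accuracy comparison straightforward
(assuming the script's existing --model, --dp-size and --tp-size flags; the
model id below is only a placeholder for any MoE checkpoint):

    python examples/offline_inference/data_parallel.py \
        --model Qwen/Qwen1.5-MoE-A2.7B --dp-size 2 --tp-size 1
    python examples/offline_inference/data_parallel.py \
        --model Qwen/Qwen1.5-MoE-A2.7B --dp-size 2 --tp-size 1 \
        --disable-expert-parallel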