[xpu]support moe models on XPU platform (#21643)

Signed-off-by: yan <yan.ma@intel.com>
Signed-off-by: Yan Ma <yan.ma@intel.com>
This commit is contained in:
Yan Ma 2025-08-02 22:49:08 +08:00 committed by GitHub
parent 4abfd8796f
commit 73e1b9b1d4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -327,7 +327,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer.w13_weight.data = shuffled_w13
layer.w2_weight.data = shuffled_w2
if current_platform.is_cpu():
if current_platform.is_xpu():
import intel_extension_for_pytorch as ipex
layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
layer.w13_weight,
layer.w2_weight,
use_prepack=True,
)
elif current_platform.is_cpu():
if current_platform.get_cpu_architecture() == CpuArchEnum.X86:
from vllm.model_executor.layers.fused_moe import cpu_fused_moe
dtype = layer.w13_weight.dtype
@ -509,6 +516,44 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
activation,
)
def forward_xpu(
self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
):
if enable_eplb is not False or expert_load_view is not None or \
logical_to_physical_map is not None or \
logical_replica_count is not None:
raise NotImplementedError("Expert load balancing is not supported "
"for XPU.")
assert custom_routing_function is None
return layer.ipex_fusion(
x,
use_grouped_topk,
top_k,
router_logits,
renormalize,
topk_group,
num_expert_group,
)
def forward_tpu(
self,
layer: torch.nn.Module,