Add SwigluOAI implementation for CPUFusedMOE (#26347)

Signed-off-by: Sharif Inamdar <sharif.inamdar@arm.com>
isharif168 2025-10-08 03:17:49 +01:00 committed by GitHub
parent b32260ab85
commit 046118b938


@@ -13,6 +13,17 @@ def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
     return F.silu(x[..., :d]) * x[..., d:]
 
 
+def swigluoai_and_mul(
+    x: torch.Tensor, alpha: float = 1.702, limit: float = 7.0
+) -> torch.Tensor:
+    d = x.shape[-1] // 2
+    gate, up = x[..., :d], x[..., d:]
+    gate = gate.clamp(max=limit)
+    up = up.clamp(min=-limit, max=limit)
+    glu = gate * torch.sigmoid(alpha * gate)
+    return (up + 1) * glu
+
+
 def grouped_topk(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
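
For reference, a small standalone sketch (separate from the diff) of what the new swigluoai_and_mul helper computes: the last dimension is split into gate/up halves, gate is clamped from above only, up is clamped on both sides, and the result is (up + 1) * gate * sigmoid(alpha * gate), so the output keeps half the input features. The tensor shapes below are arbitrary illustration values.

# Standalone reference sketch of the SwigluOAI activation added above.
# alpha/limit defaults follow the diff (1.702 and 7.0).
import torch

def swigluoai_and_mul_ref(
    x: torch.Tensor, alpha: float = 1.702, limit: float = 7.0
) -> torch.Tensor:
    d = x.shape[-1] // 2
    gate, up = x[..., :d], x[..., d:]
    gate = gate.clamp(max=limit)            # gate is clamped from above only
    up = up.clamp(min=-limit, max=limit)    # up is clamped symmetrically
    return (up + 1) * gate * torch.sigmoid(alpha * gate)

x = torch.randn(4, 16)                      # last dim must be even (2 * d)
out = swigluoai_and_mul_ref(x)
assert out.shape == (4, 8)                  # output keeps d features per token
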
@@ -247,7 +258,7 @@ class CPUFusedMOE:
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
     ) -> torch.Tensor:
-        assert activation == "silu", f"{activation} is not supported."
+        assert activation in {"silu", "swigluoai"}, f"{activation} is not supported."
         assert not apply_router_weight_on_input
         topk_weights, topk_ids = select_experts(
             hidden_states=x,
@@ -293,7 +304,10 @@
             gate_up = F.linear(
                 tokens_for_this_expert, layer_w13_weight, bias=layer_w13_bias
             )
-            gate_up = silu_and_mul(gate_up)
+            if activation == "swigluoai":
+                gate_up = swigluoai_and_mul(gate_up)
+            else:
+                gate_up = silu_and_mul(gate_up)
             expert_out = F.linear(gate_up, layer_w2_weight, bias=layer_w2_bias)
             outputs.append(expert_out)
             start_idx = end_idx
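
Below is a simplified, hypothetical sketch of the per-expert MLP after this change (routing, token gathering, and biases omitted; weight and activation names mirror the diff, function and parameter names are illustrative): the activation string now selects between SiLU and SwigluOAI gating before the w2 down projection.

# Hedged sketch only; not the CPUFusedMOE implementation itself.
import torch
import torch.nn.functional as F

def expert_mlp(
    tokens: torch.Tensor,        # [num_tokens, hidden_size]
    w13_weight: torch.Tensor,    # [2 * intermediate_size, hidden_size]
    w2_weight: torch.Tensor,     # [hidden_size, intermediate_size]
    activation: str = "silu",
) -> torch.Tensor:
    gate_up = F.linear(tokens, w13_weight)
    d = gate_up.shape[-1] // 2
    gate, up = gate_up[..., :d], gate_up[..., d:]
    if activation == "swigluoai":
        # clamped SwiGLU variant with the defaults used in the diff
        gate = gate.clamp(max=7.0)
        up = up.clamp(min=-7.0, max=7.0)
        act = (up + 1) * gate * torch.sigmoid(1.702 * gate)
    elif activation == "silu":
        act = F.silu(gate) * up
    else:
        raise ValueError(f"{activation} is not supported.")
    return F.linear(act, w2_weight)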