Add bias handling to CPUFusedMOE kernel (#26289)

Signed-off-by: Crefeda Rodrigues <crefeda.rodrigues@arm.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Crefeda Rodrigues <65665931+cfRod@users.noreply.github.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Sharif Inamdar <Sharif.Inamdar@arm.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
Crefeda Rodrigues 2025-10-06 19:39:10 +01:00 committed by GitHub
parent b2ea5ba677
commit c02058c222
2 changed files with 78 additions and 2 deletions


@@ -909,3 +909,72 @@ def test_moe_sum(m: int, topk: int, k: int, dtype: torch.dtype):
     torch.testing.assert_close(actual, expected, atol=2e-2, rtol=0)
 
     opcheck(torch.ops._moe_C.moe_sum, (input, actual))
+
+
+@pytest.mark.parametrize("m", [1, 33])
+@pytest.mark.parametrize("n,k", [(128, 128)])
+@pytest.mark.parametrize("e", [8])
+@pytest.mark.parametrize("topk", [2])
+@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
+@pytest.mark.parametrize("with_bias", [False, True])
+@pytest.mark.parametrize("activation", ["silu"])
+@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only test")
+def test_cpu_fused_moe_basic(m, n, k, e, topk, dtype, with_bias, activation):
+    from vllm.model_executor.layers.fused_moe.cpu_fused_moe import CPUFusedMOE
+
+    device = "cpu"
+    torch.manual_seed(7)
+    a = torch.randn((m, k), device=device, dtype=dtype) / 10
+    w13 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 10
+    w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 10
+    router_logits = torch.randn((m, e), device=device, dtype=dtype)
+
+    b1 = b2 = None
+    if with_bias:
+        b1 = torch.randn((e, 2 * n), device=device, dtype=dtype) / 10
+        b2 = torch.randn((e, k), device=device, dtype=dtype) / 10
+
+    ref = (
+        torch_moe(a, w13, w2, router_logits, topk, b1, b2)
+        if with_bias
+        else torch_moe(a, w13, w2, router_logits, topk)
+    )
+
+    class _Dummy(torch.nn.Module):
+        def __init__(self, w13, w2, b1=None, b2=None):
+            super().__init__()
+            self.w13_weight = torch.nn.Parameter(w13, requires_grad=False)
+            self.w2_weight = torch.nn.Parameter(w2, requires_grad=False)
+            if b1 is not None:
+                self.w13_bias = torch.nn.Parameter(b1, requires_grad=False)
+            if b2 is not None:
+                self.w2_bias = torch.nn.Parameter(b2, requires_grad=False)
+
+    layer = _Dummy(w13, w2, b1, b2).to(dtype)
+    fused = CPUFusedMOE(layer)
+
+    out = fused(
+        layer=layer,
+        x=a,
+        use_grouped_topk=False,
+        top_k=topk,
+        router_logits=router_logits,
+        renormalize=False,
+        global_num_experts=e,
+        expert_map=None,
+        custom_routing_function=None,
+        scoring_func="softmax",
+        routed_scaling_factor=1.0,
+        e_score_correction_bias=None,
+        apply_router_weight_on_input=False,
+        activation=activation,
+    )
+
+    # Tolerances: fp32 tight; bf16 looser (esp. with bias)
+    if dtype == torch.float32:
+        atol = 1e-3
+    elif with_bias:
+        atol = 8e-2
+    else:
+        atol = 5e-2
+    torch.testing.assert_close(out, ref, atol=atol, rtol=0)
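
Note: torch_moe is the dense reference helper from the kernel test utilities; its implementation is not part of this diff. As a rough, illustrative sketch only, a bias-aware reference consistent with the test's settings (softmax scoring, top-k selection without renormalization, SiLU gated activation) could look like the code below; the name ref_moe_with_bias and its loop structure are hypothetical, not the actual helper.

# Illustrative only: an assumed dense reference for the bias-aware MoE path,
# not the torch_moe helper used by the test above.
import torch
import torch.nn.functional as F


def ref_moe_with_bias(a, w13, w2, router_logits, topk, b1=None, b2=None):
    num_tokens = a.shape[0]
    n2 = w13.shape[1]  # 2 * n: gate and up projections stacked
    scores = torch.softmax(router_logits.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(scores, topk, dim=-1)
    out = torch.zeros_like(a)
    for token in range(num_tokens):
        for slot in range(topk):
            e_id = int(topk_ids[token, slot])
            x = a[token]
            # Optional per-expert biases, mirroring w13_bias / w2_bias rows.
            gate_up = F.linear(x, w13[e_id], None if b1 is None else b1[e_id])
            gate, up = gate_up[: n2 // 2], gate_up[n2 // 2:]
            h = F.silu(gate) * up  # matches silu_and_mul's assumed semantics
            y = F.linear(h, w2[e_id], None if b2 is None else b2[e_id])
            out[token] += topk_weights[token, slot].to(a.dtype) * y
    return out

When with_bias is False, b1 and b2 stay None and both F.linear calls reduce to the unbiased path, which is what the else branch of the ref expression in the test exercises.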


@@ -276,6 +276,9 @@ class CPUFusedMOE:
         outputs = []
         start_idx = 0
+        has_w13_bias = hasattr(layer, "w13_bias")
+        has_w2_bias = hasattr(layer, "w2_bias")
         for i, num_tokens in enumerate(tokens_per_expert):
             end_idx = start_idx + num_tokens
             if num_tokens == 0:
@@ -283,11 +286,15 @@ class CPUFusedMOE:
             tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
             layer_w13_weight = layer.w13_weight[i]
+            layer_w13_bias = layer.w13_bias[i] if has_w13_bias else None
             layer_w2_weight = layer.w2_weight[i]
+            layer_w2_bias = layer.w2_bias[i] if has_w2_bias else None
-            gate_up = F.linear(tokens_for_this_expert, layer_w13_weight)
+            gate_up = F.linear(
+                tokens_for_this_expert, layer_w13_weight, bias=layer_w13_bias
+            )
             gate_up = silu_and_mul(gate_up)
-            expert_out = F.linear(gate_up, layer_w2_weight)
+            expert_out = F.linear(gate_up, layer_w2_weight, bias=layer_w2_bias)
             outputs.append(expert_out)
             start_idx = end_idx
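
For orientation, here is a minimal, self-contained shape sketch of the per-expert bias path added above. The sizes (e=8, n=k=128) mirror the new test, and silu_and_mul is assumed to apply SiLU to the first half of the last dimension and multiply by the second half.

# Standalone shape check for the bias-aware expert computation above.
# Sizes are illustrative; F.linear broadcasts the per-expert bias row
# over the token dimension.
import torch
import torch.nn.functional as F

e, n, k, num_tokens = 8, 128, 128, 5
w13, b13 = torch.randn(e, 2 * n, k), torch.randn(e, 2 * n)
w2, b2 = torch.randn(e, k, n), torch.randn(e, k)

x = torch.randn(num_tokens, k)  # tokens routed to expert i
i = 3
gate_up = F.linear(x, w13[i], bias=b13[i])    # -> (num_tokens, 2 * n)
gate, up = gate_up.chunk(2, dim=-1)
h = F.silu(gate) * up                         # silu_and_mul, assumed semantics
expert_out = F.linear(h, w2[i], bias=b2[i])   # -> (num_tokens, k)
assert expert_out.shape == (num_tokens, k)

Because layer_w13_bias and layer_w2_bias fall back to None when the layer does not define w13_bias / w2_bias, existing bias-free checkpoints go through F.linear exactly as before.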