diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py
index 9354e819877a6..f357d149bd071 100644
--- a/tests/kernels/moe/test_moe.py
+++ b/tests/kernels/moe/test_moe.py
@@ -909,3 +909,72 @@ def test_moe_sum(m: int, topk: int, k: int, dtype: torch.dtype):
     torch.testing.assert_close(actual, expected, atol=2e-2, rtol=0)
 
     opcheck(torch.ops._moe_C.moe_sum, (input, actual))
+
+
+@pytest.mark.parametrize("m", [1, 33])
+@pytest.mark.parametrize("n,k", [(128, 128)])
+@pytest.mark.parametrize("e", [8])
+@pytest.mark.parametrize("topk", [2])
+@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
+@pytest.mark.parametrize("with_bias", [False, True])
+@pytest.mark.parametrize("activation", ["silu"])
+@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only test")
+def test_cpu_fused_moe_basic(m, n, k, e, topk, dtype, with_bias, activation):
+    from vllm.model_executor.layers.fused_moe.cpu_fused_moe import CPUFusedMOE
+
+    device = "cpu"
+    torch.manual_seed(7)
+
+    a = torch.randn((m, k), device=device, dtype=dtype) / 10
+    w13 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 10
+    w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 10
+    router_logits = torch.randn((m, e), device=device, dtype=dtype)
+
+    b1 = b2 = None
+    if with_bias:
+        b1 = torch.randn((e, 2 * n), device=device, dtype=dtype) / 10
+        b2 = torch.randn((e, k), device=device, dtype=dtype) / 10
+
+    ref = (
+        torch_moe(a, w13, w2, router_logits, topk, b1, b2)
+        if with_bias
+        else torch_moe(a, w13, w2, router_logits, topk)
+    )
+
+    class _Dummy(torch.nn.Module):
+        def __init__(self, w13, w2, b1=None, b2=None):
+            super().__init__()
+            self.w13_weight = torch.nn.Parameter(w13, requires_grad=False)
+            self.w2_weight = torch.nn.Parameter(w2, requires_grad=False)
+            if b1 is not None:
+                self.w13_bias = torch.nn.Parameter(b1, requires_grad=False)
+            if b2 is not None:
+                self.w2_bias = torch.nn.Parameter(b2, requires_grad=False)
+
+    layer = _Dummy(w13, w2, b1, b2).to(dtype)
+    fused = CPUFusedMOE(layer)
+    out = fused(
+        layer=layer,
+        x=a,
+        use_grouped_topk=False,
+        top_k=topk,
+        router_logits=router_logits,
+        renormalize=False,
+        global_num_experts=e,
+        expert_map=None,
+        custom_routing_function=None,
+        scoring_func="softmax",
+        routed_scaling_factor=1.0,
+        e_score_correction_bias=None,
+        apply_router_weight_on_input=False,
+        activation=activation,
+    )
+
+    # Tolerances: fp32 tight; bf16 looser (esp. with bias)
+    if dtype == torch.float32:
+        atol = 1e-3
+    elif with_bias:
+        atol = 8e-2
+    else:
+        atol = 5e-2
+    torch.testing.assert_close(out, ref, atol=atol, rtol=0)
diff --git a/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py
index b62817d0115f8..2ee91637a834b 100644
--- a/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py
@@ -276,6 +276,9 @@ class CPUFusedMOE:
         outputs = []
         start_idx = 0
 
+        has_w13_bias = hasattr(layer, "w13_bias")
+        has_w2_bias = hasattr(layer, "w2_bias")
+
         for i, num_tokens in enumerate(tokens_per_expert):
             end_idx = start_idx + num_tokens
             if num_tokens == 0:
@@ -283,11 +286,15 @@ class CPUFusedMOE:
             tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
 
             layer_w13_weight = layer.w13_weight[i]
+            layer_w13_bias = layer.w13_bias[i] if has_w13_bias else None
             layer_w2_weight = layer.w2_weight[i]
+            layer_w2_bias = layer.w2_bias[i] if has_w2_bias else None
 
-            gate_up = F.linear(tokens_for_this_expert, layer_w13_weight)
+            gate_up = F.linear(
+                tokens_for_this_expert, layer_w13_weight, bias=layer_w13_bias
+            )
             gate_up = silu_and_mul(gate_up)
-            expert_out = F.linear(gate_up, layer_w2_weight)
+            expert_out = F.linear(gate_up, layer_w2_weight, bias=layer_w2_bias)
 
             outputs.append(expert_out)
             start_idx = end_idx
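
For context, the cpu_fused_moe.py hunk boils down to threading an optional per-expert bias slice into each F.linear call. The standalone sketch below shows that pattern under assumed toy shapes; the names num_experts, hidden, inter, tokens, w13_bias, and w2_bias are illustrative placeholders rather than vLLM internals, and silu_and_mul is approximated here with F.silu(gate) * up.

# Standalone sketch of the optional per-expert bias path (illustrative only).
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_experts, hidden, inter, tokens = 2, 8, 4, 3    # assumed toy sizes
x = torch.randn(tokens, hidden)
w13 = torch.randn(num_experts, 2 * inter, hidden)  # fused gate/up weights
w2 = torch.randn(num_experts, hidden, inter)       # down-projection weights
w13_bias = torch.randn(num_experts, 2 * inter)     # present on bias checkpoints
w2_bias = None                                     # absent: fall back to bias=None

for i in range(num_experts):
    # Passing bias=None keeps the original bias-free behaviour; a per-expert
    # slice is used only when the bias tensor exists, mirroring the hasattr()
    # checks in the patch.
    gate_up = F.linear(x, w13[i], bias=w13_bias[i] if w13_bias is not None else None)
    gate, up = gate_up.chunk(2, dim=-1)
    act = F.silu(gate) * up                        # same effect as silu_and_mul
    out = F.linear(act, w2[i], bias=w2_bias[i] if w2_bias is not None else None)
    assert out.shape == (tokens, hidden)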