Add bias handling to CPUFusedMOE kernel (#26289)
Signed-off-by: Crefeda Rodrigues <crefeda.rodrigues@arm.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Crefeda Rodrigues <65665931+cfRod@users.noreply.github.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Sharif Inamdar <Sharif.Inamdar@arm.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
commit c02058c222 (parent b2ea5ba677)
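In short, the change lets each expert's two projections carry an optional bias: `w13_bias` is added after the fused gate/up projection and `w2_bias` after the down projection. As a rough orientation before the diff, here is a minimal plain-PyTorch sketch of a single expert's forward pass with those biases; the function name `expert_forward_ref` and its argument layout are illustrative only and are not part of the kernel.

import torch
import torch.nn.functional as F


def expert_forward_ref(x, w13, w2, b1=None, b2=None):
    # Biased gate/up projection: (tokens, k) @ (2n, k)^T + (2n,) -> (tokens, 2n)
    gate_up = F.linear(x, w13, bias=b1)
    gate, up = gate_up.chunk(2, dim=-1)
    # SiLU-and-mul activation, matching the kernel's "silu" path
    hidden = F.silu(gate) * up
    # Biased down projection: (tokens, n) @ (k, n)^T + (k,) -> (tokens, k)
    return F.linear(hidden, w2, bias=b2)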
@@ -909,3 +909,72 @@ def test_moe_sum(m: int, topk: int, k: int, dtype: torch.dtype):
     torch.testing.assert_close(actual, expected, atol=2e-2, rtol=0)
 
     opcheck(torch.ops._moe_C.moe_sum, (input, actual))
+
+
+@pytest.mark.parametrize("m", [1, 33])
+@pytest.mark.parametrize("n,k", [(128, 128)])
+@pytest.mark.parametrize("e", [8])
+@pytest.mark.parametrize("topk", [2])
+@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
+@pytest.mark.parametrize("with_bias", [False, True])
+@pytest.mark.parametrize("activation", ["silu"])
+@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only test")
+def test_cpu_fused_moe_basic(m, n, k, e, topk, dtype, with_bias, activation):
+    from vllm.model_executor.layers.fused_moe.cpu_fused_moe import CPUFusedMOE
+
+    device = "cpu"
+    torch.manual_seed(7)
+
+    a = torch.randn((m, k), device=device, dtype=dtype) / 10
+    w13 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 10
+    w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 10
+    router_logits = torch.randn((m, e), device=device, dtype=dtype)
+
+    b1 = b2 = None
+    if with_bias:
+        b1 = torch.randn((e, 2 * n), device=device, dtype=dtype) / 10
+        b2 = torch.randn((e, k), device=device, dtype=dtype) / 10
+
+    ref = (
+        torch_moe(a, w13, w2, router_logits, topk, b1, b2)
+        if with_bias
+        else torch_moe(a, w13, w2, router_logits, topk)
+    )
+
+    class _Dummy(torch.nn.Module):
+        def __init__(self, w13, w2, b1=None, b2=None):
+            super().__init__()
+            self.w13_weight = torch.nn.Parameter(w13, requires_grad=False)
+            self.w2_weight = torch.nn.Parameter(w2, requires_grad=False)
+            if b1 is not None:
+                self.w13_bias = torch.nn.Parameter(b1, requires_grad=False)
+            if b2 is not None:
+                self.w2_bias = torch.nn.Parameter(b2, requires_grad=False)
+
+    layer = _Dummy(w13, w2, b1, b2).to(dtype)
+    fused = CPUFusedMOE(layer)
+    out = fused(
+        layer=layer,
+        x=a,
+        use_grouped_topk=False,
+        top_k=topk,
+        router_logits=router_logits,
+        renormalize=False,
+        global_num_experts=e,
+        expert_map=None,
+        custom_routing_function=None,
+        scoring_func="softmax",
+        routed_scaling_factor=1.0,
+        e_score_correction_bias=None,
+        apply_router_weight_on_input=False,
+        activation=activation,
+    )
+
+    # Tolerances: fp32 tight; bf16 looser (esp. with bias)
+    if dtype == torch.float32:
+        atol = 1e-3
+    elif with_bias:
+        atol = 8e-2
+    else:
+        atol = 5e-2
+    torch.testing.assert_close(out, ref, atol=atol, rtol=0)
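The `torch_moe` reference called above lives in the shared test utilities and is not shown in this diff. As a hedge, the following standalone sketch captures what such a reference plausibly computes for the settings used here (softmax scoring, top-k routing, weighted sum of per-expert outputs); it reuses the illustrative `expert_forward_ref` from the sketch above and is not the actual `torch_moe` implementation.

def moe_ref(a, w13, w2, router_logits, topk, b1=None, b2=None):
    # Softmax over experts, then pick the top-k experts per token.
    scores = torch.softmax(router_logits.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(scores, topk, dim=-1)
    out = torch.zeros_like(a)
    for e_id in range(w13.shape[0]):
        # Tokens routed to this expert and the slot they occupy in top-k.
        token_ids, slot = (topk_ids == e_id).nonzero(as_tuple=True)
        if token_ids.numel() == 0:
            continue
        y = expert_forward_ref(
            a[token_ids],
            w13[e_id],
            w2[e_id],
            None if b1 is None else b1[e_id],
            None if b2 is None else b2[e_id],
        )
        w = topk_weights[token_ids, slot].to(a.dtype).unsqueeze(-1)
        out[token_ids] += w * y
    return out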
@@ -276,6 +276,9 @@ class CPUFusedMOE:
 
         outputs = []
         start_idx = 0
+        has_w13_bias = hasattr(layer, "w13_bias")
+        has_w2_bias = hasattr(layer, "w2_bias")
+
         for i, num_tokens in enumerate(tokens_per_expert):
             end_idx = start_idx + num_tokens
             if num_tokens == 0:
@@ -283,11 +286,15 @@ class CPUFusedMOE:
             tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
 
             layer_w13_weight = layer.w13_weight[i]
+            layer_w13_bias = layer.w13_bias[i] if has_w13_bias else None
             layer_w2_weight = layer.w2_weight[i]
+            layer_w2_bias = layer.w2_bias[i] if has_w2_bias else None
 
-            gate_up = F.linear(tokens_for_this_expert, layer_w13_weight)
+            gate_up = F.linear(
+                tokens_for_this_expert, layer_w13_weight, bias=layer_w13_bias
+            )
             gate_up = silu_and_mul(gate_up)
-            expert_out = F.linear(gate_up, layer_w2_weight)
+            expert_out = F.linear(gate_up, layer_w2_weight, bias=layer_w2_bias)
             outputs.append(expert_out)
             start_idx = end_idx
 