# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

from tests.kernels.allclose_default import get_default_atol, get_default_rtol
from vllm._custom_ops import cpu_fused_moe, cpu_prepack_moe_weight
from vllm.model_executor.layers.activation import SiluAndMul, SwigluOAIAndMul
from vllm.platforms import current_platform

if not current_platform.is_cpu():
    pytest.skip("skipping CPU-only tests", allow_module_level=True)

EXPERT_NUM = [
    8,
]
HIDDEN_DIM = [128, 2880]
INTERMEDIATE_DIM = [128, 2880]
BATCH_SIZE = [1, 64, 256]
ACT = ["silu", "swigluoai"]
USE_BIAS = [True, False]
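# Only exercise the "amx" ISA when the host CPU reports AMX tile support;
# the "vec" path is tested on every CPU.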
ISA = ["amx", "vec"] if torch._C._cpu._is_amx_tile_supported() else ["vec"]
DTYPE = [torch.bfloat16]

_CPU_MOE_ACT = {
    "silu": SiluAndMul(),
    "swigluoai": SwigluOAIAndMul(),
}

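
# Reference MoE computation to compare against the fused CPU kernel: group the
# routed tokens by expert, run each expert's gate/up projection, activation,
# and down projection in float32, then scatter the results back to token order
# and combine them with the top-k routing weights.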
def ref_fused_moe(
    input: torch.Tensor,
    w13: torch.Tensor,
    w2: torch.Tensor,
    w13_bias: torch.Tensor | None,
    w2_bias: torch.Tensor | None,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    activation: str,
) -> torch.Tensor:
    len_experts = w13.size(0)

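    # Count how many tokens are routed to each expert and sort the flattened
    # (token, slot) indices by expert id so each expert's tokens form a
    # contiguous block.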
    cnts = topk_ids.new_zeros((topk_ids.shape[0], len_experts))
    cnts.scatter_(1, topk_ids.to(torch.int64), 1)
    tokens_per_expert = cnts.sum(dim=0)
    idxs = topk_ids.view(-1).argsort()

    sorted_tokens = input[idxs // topk_ids.shape[1]]
    tokens_per_expert = tokens_per_expert.cpu().numpy()

    outputs = []
    start_idx = 0

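    # Apply each expert's weights to its contiguous block of tokens in float32.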
    for i, num_tokens in enumerate(tokens_per_expert):
        end_idx = start_idx + num_tokens
        if num_tokens == 0:
            continue
        tokens_for_this_expert = sorted_tokens[start_idx:end_idx].float()
        curr_w13 = w13[i].float()
        curr_w2 = w2[i].float()

        curr_w13_bias = None
        if w13_bias is not None:
            curr_w13_bias = w13_bias[i].float()

        curr_w2_bias = None
        if w2_bias is not None:
            curr_w2_bias = w2_bias[i].float()

        gate_up = torch.nn.functional.linear(
            tokens_for_this_expert, curr_w13, curr_w13_bias
        )
        # Note: to simulate the kernel implementation
        gate_up = (
            _CPU_MOE_ACT[activation]
            .forward_native(gate_up)
            .to(dtype=input.dtype)
            .float()
        )
        expert_out = torch.nn.functional.linear(gate_up, curr_w2, curr_w2_bias)

        outputs.append(expert_out)
        start_idx = end_idx

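    # Scatter the per-expert outputs back to the original token order and
    # combine them with the router's top-k weights.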
    outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
    new_x = torch.empty_like(outs)

    new_x[idxs] = outs
    final_out = (
        new_x.view(*topk_ids.shape, -1)
        .mul_(topk_weights.unsqueeze(dim=-1))
        .sum(dim=1)
        .type(input.dtype)
    )
    return final_out


@pytest.mark.parametrize("batch_size", BATCH_SIZE)
@pytest.mark.parametrize("expert_num", EXPERT_NUM)
@pytest.mark.parametrize("hidden_size", HIDDEN_DIM)
@pytest.mark.parametrize("intermediate_size", INTERMEDIATE_DIM)
@pytest.mark.parametrize("use_bias", USE_BIAS)
@pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize("act", ACT)
@pytest.mark.parametrize("isa", ISA)
def test_cpu_fused_moe(
    batch_size: int,
    expert_num: int,
    hidden_size: int,
    intermediate_size: int,
    use_bias: bool,
    dtype: torch.dtype,
    act: str,
    isa: str,
):
    current_platform.seed_everything(0)

    topk_num = max(expert_num // 2, 1)
    up_dim = 2 * intermediate_size

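    # Random input activations, expert weights, and router logits; the tensors
    # are scaled down by ~sqrt(dim) to keep magnitudes small.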
    input = torch.randn((batch_size, hidden_size), dtype=dtype) / (
        0.5 * hidden_size**0.5
    )
    w13 = torch.randn((expert_num, up_dim, hidden_size), dtype=dtype) / (
        0.5 * hidden_size**0.5
    )
    w2 = torch.randn((expert_num, hidden_size, intermediate_size), dtype=dtype) / (
        0.5 * intermediate_size**0.5
    )
    router_logits = torch.randn((batch_size, expert_num), dtype=dtype)
    w13_bias = None
    w2_bias = None
    if use_bias:
        w13_bias = torch.randn((expert_num, up_dim), dtype=dtype) / (0.5 * up_dim**0.5)
        w2_bias = torch.randn((expert_num, hidden_size), dtype=dtype) / (
            0.5 * hidden_size**0.5
        )
    score = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(score, topk_num)
    topk_ids = topk_ids.to(torch.int32)

    ref_output = ref_fused_moe(
        input,
        w13,
        w2,
        w13_bias,
        w2_bias,
        topk_weight,
        topk_ids,
        act,
    )

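    # Prepack the expert weights for the selected ISA, then run the fused CPU
    # MoE kernel under test.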
    packed_w13 = cpu_prepack_moe_weight(w13, isa)
    packed_w2 = cpu_prepack_moe_weight(w2, isa)
    output = cpu_fused_moe(
        input,
        packed_w13,
        packed_w2,
        w13_bias,
        w2_bias,
        topk_weight,
        topk_ids,
        act,
        isa,
    )

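    # Compare the kernel output against the reference using the default
    # tolerances.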
    atol, rtol = get_default_atol(output), get_default_rtol(output)
    torch.testing.assert_close(
        output,
        ref_output,
        atol=atol,
        rtol=rtol,
        msg=f"max abs diff: {torch.max(torch.abs(output - ref_output))}",
    )