vllm/tests/kernels/moe/test_moe.py
Crefeda Rodrigues c02058c222
Add bias handling to CPUFusedMOE kernel (#26289)
Signed-off-by: Crefeda Rodrigues <crefeda.rodrigues@arm.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Crefeda Rodrigues <65665931+cfRod@users.noreply.github.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Sharif Inamdar <Sharif.Inamdar@arm.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
2025-10-06 18:39:10 +00:00

981 lines
31 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the MOE layers.
Run `pytest tests/kernels/moe/test_moe.py`.
"""
import functools
from typing import Callable, Optional, Union
import pytest
import torch
from torch.nn import Parameter
from torch.nn import functional as F
from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
import vllm.model_executor.layers.fused_moe # noqa
from tests.kernels.moe.utils import fused_moe
from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed.parallel_state import init_distributed_environment
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG,
int4_w4a16_moe_quant_config,
int8_w8a16_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk,
modular_triton_fused_moe,
)
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
fused_moe as iterative_moe,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_permute_bias,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
rand_marlin_weight_mxfp4_like,
rand_marlin_weight_nvfp4_like,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
marlin_quant_fp8_torch,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
awq_marlin_quantize,
marlin_quantize,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import quantize_weights
from vllm.model_executor.models.mixtral import MixtralMoE
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types
# Global expert counts swept through the ``e`` parameter of the fused MoE tests.
NUM_EXPERTS = [8, 64, 192]
# Expert-parallel sizes (``ep_size``); for values above 1 the tests keep only
# ``e // ep_size`` experts locally and build an expert map for the rest.
EP_SIZE = [1, 4]
# Values swept through the ``topk`` parameter.
TOP_KS = [2, 6]
# (m, n, k) problem sizes used by ``test_fused_moe``: activations are (m, k),
# w1 is (e, 2 * n, k) and w2 is (e, k, n).
FUSED_MOE_MNK_FACTORS = [
(1, 128, 128),
(1, 2048, 128),
(33, 2048, 128),
(222, 1024, 1024),
(32768, 128, 128),
(32768, 2048, 511),
(40000, 1024, 1024),
]
# (m, n, k) problem sizes used by ``test_fused_moe_wn16``.
FUSED_MOE_WN16_MNK_FACTORS = [
(1, 128, 128),
(1, 1024, 1024),
(32, 2048, 128),
(32, 1024, 1024),
(222, 2048, 1024),
]
# Module-wide config that the tests install via ``set_current_vllm_config``;
# the scheduler limits below are overridden once at import time.
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
def run_moe_test(
    baseline: Union[Callable, torch.Tensor],
    moe_fn: Callable,
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    score: torch.Tensor,
    topk: int,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    padding: bool = False,
    use_compile: bool = False,
    use_cudagraph: bool = False,
    atol: float = 2e-2,
    rtol: float = 0,
) -> torch.Tensor:
    """Run ``moe_fn`` and assert that its output is close to a baseline.

    Args:
        baseline: Either a precomputed reference tensor, or a callable with
            the same calling convention as ``moe_fn`` that produces one.
        moe_fn: Implementation under test, called as
            ``moe_fn(a, w1, w2, score, topk, global_num_experts=...,
            expert_map=...)``.
        a, w1, w2, score, topk: Forwarded positionally to both callables.
        global_num_experts, expert_map: Forwarded as keyword arguments.
        padding: When true, ``w1``/``w2`` are zero-padded by 128 along the
            last dimension and sliced back before ``moe_fn`` is called. The
            baseline callable always receives the unpadded weights.
        use_compile: Wrap ``moe_fn`` with ``torch.compile`` (inductor,
            ``fullgraph=True``) and mark dim 0 of ``a``/``score`` dynamic.
        use_cudagraph: After the eager run, zero its output, capture a second
            call in a CUDA graph, replay it, and compare the captured output.
        atol, rtol: Tolerances passed to ``torch.testing.assert_close``.

    Returns:
        The baseline output, so it can be reused by later calls.
    """
    routing_kwargs = {
        "global_num_experts": global_num_experts,
        "expert_map": expert_map,
    }

    # The baseline is computed before any padding is applied to the weights.
    if isinstance(baseline, torch.Tensor):
        baseline_output = baseline
    else:
        baseline_output = baseline(a, w1, w2, score, topk, **routing_kwargs)

    # Pad the weight if moe padding is enabled
    if padding:
        w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128]
        w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128]

    if use_compile:
        moe_fn = torch.compile(moe_fn, backend="inductor", fullgraph=True)
        torch._dynamo.mark_dynamic(a, 0)
        torch._dynamo.mark_dynamic(score, 0)

    def invoke() -> torch.Tensor:
        # Defined after the padding/compile steps so it sees the final
        # weights and the (possibly compiled) callable.
        return moe_fn(a, w1, w2, score, topk, **routing_kwargs)

    test_output = invoke()

    if use_cudagraph:
        # Clear the eager result before capturing the graph.
        test_output.fill_(0)
        stream = torch.cuda.Stream()
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, stream=stream):
            test_output = invoke()
        torch.cuda.synchronize()
        graph.replay()
        torch.cuda.synchronize()

    torch.testing.assert_close(test_output, baseline_output, atol=atol, rtol=rtol)
    return baseline_output
@pytest.mark.parametrize("m,n,k", FUSED_MOE_MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("padding", [True, False])
@pytest.mark.parametrize("chunk_size", [8192])
def test_fused_moe(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    ep_size: int,
    dtype: torch.dtype,
    padding: bool,
    chunk_size: int,
    monkeypatch,
):
    """Compare the fused MoE implementations against ``torch_moe``.

    ``iterative_moe`` is checked against ``torch_moe`` first; the resulting
    baseline is then reused to check ``fused_moe`` and the modular Triton
    kernel built from the unquantized config.
    """
    current_platform.seed_everything(7)
    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))

    #
    # Setup test data
    #
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
    score = torch.randn((m, e), device="cuda", dtype=dtype)

    local_expert_map = None
    if ep_size > 1:
        # Keep e // ep_size experts locally; entries left at -1 in the map
        # correspond to global experts that have no local slot.
        num_local = e // ep_size
        local_ids = torch.randint(
            0, e, (num_local,), device="cuda", dtype=torch.int32
        )
        local_expert_map = torch.full((e,), -1, device="cuda", dtype=torch.int32)
        local_expert_map[local_ids] = torch.arange(
            num_local, device="cuda", dtype=torch.int32
        )
        w1 = w1[local_ids]
        w2 = w2[local_ids]

    #
    # Setup test functions
    #
    modular_kernel = modular_triton_fused_moe(FUSED_MOE_UNQUANTIZED_CONFIG)

    def m_fused_moe(
        a: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        score: torch.Tensor,
        topk: int,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Adapt the modular kernel to the calling convention of run_moe_test."""
        # NOTE(review): the last positional argument is presumably
        # ``renormalize`` -- confirm against fused_topk's signature.
        weights, ids, _ = fused_topk(a, score, topk, False)
        return modular_kernel(
            a,
            w1,
            w2,
            weights,
            ids,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
        )

    fused_moe_fn = functools.partial(fused_moe, renormalize=False)

    #
    # Run tests
    #
    runner = functools.partial(
        run_moe_test,
        a=a,
        w1=w1,
        w2=w2,
        score=score,
        topk=topk,
        global_num_experts=e,
        expert_map=local_expert_map,
        padding=padding,
    )

    # Note: for now use_compile will error out if the problem size is
    # large enough to trigger chunking. I'm leaving the flag and
    # setup code in case we are able to revisit this later.
    use_compile = False

    # CUDA graph capture is only exercised for the larger n/k shapes.
    use_cudagraph = n >= 1024 and k >= 1024 and current_platform.is_cuda_alike()

    with set_current_vllm_config(vllm_config):
        baseline_output = runner(torch_moe, iterative_moe)
        for candidate in (fused_moe_fn, m_fused_moe):
            runner(
                baseline_output,
                candidate,
                use_compile=use_compile,
                use_cudagraph=use_cudagraph,
            )
@pytest.mark.parametrize("m,n,k", FUSED_MOE_WN16_MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("group_size", [64, 128])
@pytest.mark.parametrize("has_zp", [True, False])
@pytest.mark.parametrize("weight_bits", [4, 8])
def test_fused_moe_wn16(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    ep_size: int,
    dtype: torch.dtype,
    group_size: int,
    has_zp: bool,
    weight_bits: int,
):
    """Check ``fused_moe`` with int4/int8 weight-only quantized experts.

    Each expert's weights are quantized with ``quantize_weights``. The packed
    weights, scales and optional zero points go to ``fused_moe``, while the
    dequantized weights returned by the quantizer drive the ``torch_moe``
    reference.
    """
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
    score = torch.randn((m, e), device="cuda", dtype=dtype)

    if weight_bits == 4:
        # Two 4-bit values share one uint8.
        pack_factor = 2
        quant_type = scalar_types.uint4 if has_zp else scalar_types.uint4b8
    elif weight_bits == 8:
        pack_factor = 1
        quant_type = scalar_types.uint8 if has_zp else scalar_types.uint8b128

    # Output buffers, filled expert by expert in the loop below.
    w1_ref = w1.clone()
    w2_ref = w2.clone()
    w1_qweight = torch.empty(
        (e, 2 * n, k // pack_factor), device="cuda", dtype=torch.uint8
    )
    w2_qweight = torch.empty(
        (e, k, n // pack_factor), device="cuda", dtype=torch.uint8
    )
    w1_scales = torch.empty((e, 2 * n, k // group_size), device="cuda", dtype=dtype)
    w2_scales = torch.empty((e, k, n // group_size), device="cuda", dtype=dtype)
    w1_qzeros = torch.empty(
        (e, 2 * n // pack_factor, k // group_size), device="cuda", dtype=torch.uint8
    )
    w2_qzeros = torch.empty(
        (e, k // pack_factor, n // group_size), device="cuda", dtype=torch.uint8
    )

    projections = (
        (w1, w1_ref, w1_qweight, w1_scales, w1_qzeros),
        (w2, w2_ref, w2_qweight, w2_scales, w2_qzeros),
    )
    # All w1 experts are processed first, then all w2 experts.
    for w, w_ref, w_qweight, w_scales, w_qzeros in projections:
        for expert_id in range(e):
            weight, qweight, scales, qzeros = quantize_weights(
                w[expert_id].T, quant_type, group_size, has_zp, False
            )
            # The quantizer ran on the transposed weight; transpose back.
            weight = weight.T
            qweight = qweight.T.contiguous().to(torch.uint8)
            scales = scales.T
            if has_zp:
                qzeros = qzeros.T.contiguous().to(torch.uint8)

            if weight_bits == 4:
                # Pack pairs of nibbles: odd index -> high, even index -> low.
                qweight = qweight[:, 1::2] * 16 + qweight[:, ::2]
                if has_zp:
                    qzeros = qzeros[1::2, :] * 16 + qzeros[::2, :]

            w_ref[expert_id] = weight
            w_qweight[expert_id] = qweight
            w_scales[expert_id] = scales
            if has_zp:
                w_qzeros[expert_id] = qzeros

    e_map = None
    if ep_size > 1:
        local_e = e // ep_size
        e_ids = torch.randint(0, e, (local_e,), device="cuda", dtype=torch.int32)
        e_map = torch.full((e,), -1, device="cuda", dtype=torch.int32)
        e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
        # Keep only the locally owned experts in every per-expert tensor.
        w1_ref, w2_ref = w1_ref[e_ids], w2_ref[e_ids]
        w1_qweight, w2_qweight = w1_qweight[e_ids], w2_qweight[e_ids]
        w1_scales, w2_scales = w1_scales[e_ids], w2_scales[e_ids]
        w1_qzeros, w2_qzeros = w1_qzeros[e_ids], w2_qzeros[e_ids]

    if weight_bits == 4:
        quant_config_builder = int4_w4a16_moe_quant_config
    else:
        assert weight_bits == 8
        quant_config_builder = int8_w8a16_moe_quant_config

    quant_config = quant_config_builder(
        w1_scale=w1_scales,
        w2_scale=w2_scales,
        w1_zp=w1_qzeros if has_zp else None,
        w2_zp=w2_qzeros if has_zp else None,
        block_shape=[0, group_size],
    )

    with set_current_vllm_config(vllm_config):
        triton_output = fused_moe(
            a,
            w1_qweight,
            w2_qweight,
            score,
            topk,
            renormalize=False,
            global_num_experts=e,
            expert_map=e_map,
            quant_config=quant_config,
        )
        torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, expert_map=e_map)

    torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("padding", [True, False])
@pytest.mark.parametrize(
    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]
)
@torch.inference_mode()
def test_mixtral_moe(
    dist_init, dtype: torch.dtype, padding: bool, use_rocm_aiter: bool, monkeypatch
):
    """Make sure our Mixtral MoE implementation agrees with the one from
    huggingface."""
    # clear the cache before every test
    from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
        is_rocm_aiter_moe_enabled,
    )

    is_rocm_aiter_moe_enabled.cache_clear()

    if use_rocm_aiter:
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
        if dtype == torch.float32:
            pytest.skip("AITER ROCm test skip for float32")

    # Environment for a single-process distributed setup.
    single_process_env = (
        ("RANK", "0"),
        ("LOCAL_RANK", "0"),
        ("WORLD_SIZE", "1"),
        ("MASTER_ADDR", "localhost"),
        ("MASTER_PORT", "12345"),
    )
    for env_name, env_value in single_process_env:
        monkeypatch.setenv(env_name, env_value)
    init_distributed_environment()

    # Instantiate our and huggingface's MoE blocks
    vllm_config.compilation_config.static_forward_context = {}
    with set_current_vllm_config(vllm_config), set_forward_context(None, vllm_config):
        config = MixtralConfig()
        hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda")
        vllm_moe = MixtralMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            params_dtype=dtype,
            tp_size=1,
            dp_size=1,
        ).cuda()

        # Load the weights
        vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data
        for expert_idx in range(config.num_local_experts):
            hf_expert = hf_moe.experts[expert_idx]
            # w13 is the concatenation of the HF expert's w1 and w3 along dim 0.
            w13 = torch.cat(
                (hf_expert.w1.weight.data, hf_expert.w3.weight.data), dim=0
            )
            vllm_moe.experts.w13_weight[expert_idx][:] = w13
            vllm_moe.experts.w2_weight[expert_idx][:] = hf_expert.w2.weight.data

        # Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
        hf_inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda")
        # vLLM uses 1D query [num_tokens, hidden_dim]
        vllm_inputs = hf_inputs.flatten(0, 1)

        # Pad the weight if moe padding is enabled
        if padding:
            padded_w13 = F.pad(vllm_moe.experts.w13_weight, (0, 128), "constant", 0)
            vllm_moe.experts.w13_weight = Parameter(
                padded_w13[..., 0:-128],
                requires_grad=False,
            )
            padded_w2 = F.pad(vllm_moe.experts.w2_weight, (0, 128), "constant", 0)
            vllm_moe.experts.w2_weight = Parameter(
                padded_w2[..., 0:-128],
                requires_grad=False,
            )
            torch.cuda.synchronize()
            torch.cuda.empty_cache()

        # Run forward passes for both MoE blocks
        hf_states, _ = hf_moe.forward(hf_inputs)
        vllm_states = vllm_moe.forward(vllm_inputs)

    # Per-dtype tolerance, used for both rtol and atol on the non-AITER path.
    mixtral_moe_tol = {
        torch.float32: 1e-3,
        torch.float16: 1e-3,
        torch.bfloat16: 1e-2,
    }

    hf_flat = hf_states.flatten(0, 1)
    if use_rocm_aiter:
        # The values of rtol and atol are set based on the tests in ROCM AITER package.
        # https://github.com/ROCm/aiter/blob/dfed377f4be7da96ca2d75ac0761f569676f7240/op_tests/test_moe.py#L174
        torch.testing.assert_close(hf_flat, vllm_states, rtol=0.01, atol=100)
    else:
        tolerance = mixtral_moe_tol[dtype]
        torch.testing.assert_close(
            hf_flat,
            vllm_states,
            rtol=tolerance,
            atol=tolerance,
        )
def marlin_moe_generate_valid_test_cases():
    """Build the parameter tuples consumed by ``test_fused_marlin_moe``.

    The full cartesian product of the lists below is filtered down to the
    combinations the test should run. A combination is kept only if all of
    the following hold:

    * ``float8_e4m3fn`` is used with ``group_size`` -1 or 128;
    * ``float4_e2m1f`` is used with ``group_size`` 16 or 32, and not with
      ``group_size`` 32 together with ``torch.float16``;
    * ``group_size`` 16 is used only with ``float4_e2m1f``;
    * ``act_order=True`` is used only with ``uint4b8`` and a ``group_size``
      that is none of ``-1``, ``k`` or ``n``;
    * ``act_order=False`` is used only with ``is_k_full=True``.

    Returns:
        A list of ``(m, n, k, e, topk, ep_size, dtype, group_size, act_order,
        quant_type, is_k_full)`` tuples, in ``itertools.product`` order.
    """
    import itertools

    m_list = [1, 123, 666]
    n_list = [128, 1024]
    k_list = [256, 2048]
    e_list = [4, 12]
    topk_list = [2, 3]
    ep_size_list = [1, 4]
    dtype_list = [torch.half, torch.bfloat16]
    group_size_list = [-1, 16, 32, 128]
    act_order_list = [True, False]
    quant_type_list = [
        scalar_types.float4_e2m1f,
        scalar_types.float8_e4m3fn,
        scalar_types.uint4,
        scalar_types.uint4b8,
        scalar_types.uint8b128,
    ]
    is_k_full_list = [True, False]

    all_combinations = itertools.product(
        m_list,
        n_list,
        k_list,
        e_list,
        topk_list,
        ep_size_list,
        dtype_list,
        group_size_list,
        act_order_list,
        quant_type_list,
        is_k_full_list,
    )

    # Previously named ``is_invalid`` even though it returns True for the
    # combinations that are kept; the name now matches the return value.
    def is_valid(
        m, n, k, e, topk, ep_size, dtype, group_size, act_order, quant_type, is_k_full
    ):
        """Return True when this parameter combination should be tested."""
        if quant_type == scalar_types.float8_e4m3fn and group_size not in [-1, 128]:
            return False
        if quant_type == scalar_types.float4_e2m1f:
            if group_size not in [16, 32]:
                return False
            if dtype == torch.float16 and group_size == 32:
                return False
        if quant_type != scalar_types.float4_e2m1f and group_size == 16:
            return False

        # Filter act_order
        if act_order:
            if group_size in (-1, k, n):
                return False
            if quant_type not in [scalar_types.uint4b8]:
                return False
        elif not is_k_full:
            return False

        return True

    return [case for case in all_combinations if is_valid(*case)]
@pytest.mark.flaky(reruns=2)
@pytest.mark.parametrize(
    ("m, n, k, e, topk, ep_size, dtype, group_size,act_order, quant_type, is_k_full"),
    marlin_moe_generate_valid_test_cases(),
)
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_fused_marlin_moe(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    ep_size: int,
    dtype: torch.dtype,
    group_size: int,
    act_order: bool,
    quant_type: ScalarType,
    is_k_full: bool,
):
    """Compare ``fused_marlin_moe`` against ``torch_moe`` on quantized experts.

    Both projections are quantized expert by expert with the routine that
    matches ``quant_type``. The quantized tensors are passed to the Marlin
    op, and the reference weights returned by the quantizers are passed to
    ``torch_moe``.
    """
    torch.cuda.manual_seed(0)
    has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]

    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20

    if ep_size > 1:
        local_e = e // ep_size
        e_ids = torch.randperm(e, device="cuda", dtype=torch.int32)[:local_e]
        e_map = torch.full((e,), -1, device="cuda", dtype=torch.int32)
        e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
        w1 = w1[e_ids]
        w2 = w2[e_ids]
    else:
        e_map = None

    def stack_or_none(items):
        """Stack a per-expert list, or return None if nothing was collected."""
        return stack_and_dev(items) if items else None

    def quantize_experts(weights: torch.Tensor, perm_len: int):
        """Quantize every expert in ``weights`` for the current ``quant_type``.

        ``perm_len`` is the length of the random permutation handed to
        ``marlin_quantize`` on the non-zero-point integer path.

        Returns ``(w_ref, qweight, scales, global_scale, g_idx, zeros,
        sort_indices)``; entries the chosen quantizer does not produce are
        None.
        """
        w_ref_l = []
        qweight_l = []
        scales_l = []
        global_scale_l = []
        zeros_l = []
        g_idx_l = []
        sort_indices_l = []

        for i in range(weights.shape[0]):
            expert_w = weights[i]
            if quant_type == scalar_types.float4_e2m1f:
                if group_size == 16:
                    w_ref, qweight, scales, global_scale = (
                        rand_marlin_weight_nvfp4_like(expert_w, group_size)
                    )
                else:
                    w_ref, qweight, scales = rand_marlin_weight_mxfp4_like(
                        expert_w, group_size
                    )
                    global_scale = None
                if global_scale is not None:
                    global_scale_l.append(global_scale)
            elif quant_type == scalar_types.float8_e4m3fn:
                w_ref, qweight, scales = marlin_quant_fp8_torch(expert_w, group_size)
            elif has_zp:
                w_ref, qweight, scales, zeros = awq_marlin_quantize(
                    expert_w.transpose(1, 0), quant_type, group_size
                )
                zeros_l.append(zeros)
            else:
                test_perm = torch.randperm(perm_len)
                w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(
                    expert_w.transpose(1, 0),
                    quant_type,
                    group_size,
                    act_order,
                    test_perm,
                )
                g_idx_l.append(g_idx)
                sort_indices_l.append(sort_indices)

            # Every branch stores the transposed reference weight.
            w_ref_l.append(w_ref.T)
            qweight_l.append(qweight)
            scales_l.append(scales)

        return (
            stack_and_dev(w_ref_l),
            stack_and_dev(qweight_l).contiguous(),
            stack_and_dev(scales_l),
            stack_or_none(global_scale_l),
            stack_or_none(g_idx_l),
            stack_or_none(zeros_l),
            stack_or_none(sort_indices_l),
        )

    # w1 is quantized before w2 so random permutations are drawn in the same
    # order as when the two loops were written out separately.
    (
        w_ref1,
        qweight1,
        scales1,
        global_scale1,
        g_idx1,
        zeros1,
        sort_indices1,
    ) = quantize_experts(w1, k)
    (
        w_ref2,
        qweight2,
        scales2,
        global_scale2,
        g_idx2,
        zeros2,
        sort_indices2,
    ) = quantize_experts(w2, n)

    score = torch.randn((m, e), device="cuda", dtype=dtype)

    topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)

    with set_current_vllm_config(vllm_config):
        torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, expert_map=e_map)

    marlin_output = torch.ops.vllm.fused_marlin_moe(
        a,
        qweight1,
        qweight2,
        None,
        None,
        scales1,
        scales2,
        score,
        topk_weights,
        topk_ids,
        global_num_experts=e,
        expert_map=e_map,
        global_scale1=global_scale1,
        global_scale2=global_scale2,
        g_idx1=g_idx1,
        g_idx2=g_idx2,
        sort_indices1=sort_indices1,
        sort_indices2=sort_indices2,
        w1_zeros=zeros1,
        w2_zeros=zeros2,
        quant_type_id=quant_type.id,
        is_k_full=is_k_full,
    )

    torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)
@pytest.mark.flaky(reruns=2)
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
@pytest.mark.parametrize("m", [1, 256])
def test_fused_marlin_moe_with_bias(m):
    """Compare ``fused_marlin_moe`` with per-expert biases against ``torch_moe``.

    Uses a fixed uint4b8 configuration (group size 128, no act-order). The
    Marlin op receives biases passed through ``marlin_permute_bias``, while
    the reference receives the raw biases.
    """
    torch.cuda.manual_seed(0)

    e, topk = 32, 4
    n, k = 2048, 2048
    group_size = 128
    act_order = False
    is_k_full = True
    quant_type = scalar_types.uint4b8
    dtype = torch.half

    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
    b_bias1 = torch.randn((e, 2 * n), device="cuda", dtype=dtype) / 10
    b_bias2 = torch.randn((e, k), device="cuda", dtype=dtype) / 10

    def maybe_stack(items):
        """Stack a per-expert list, or return None when it is empty."""
        return stack_and_dev(items) if items else None

    def quantize_projection(weights, biases, perm_len):
        """Quantize each expert and permute its bias for one projection.

        Returns ``(w_ref, qweight, scales, g_idx, sort_indices, marlin_bias)``.
        """
        refs, qweights, scale_list = [], [], []
        g_idx_list, sort_idx_list, bias_list = [], [], []
        for i in range(weights.shape[0]):
            test_perm = torch.randperm(perm_len)
            w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(
                weights[i].transpose(1, 0),
                quant_type,
                group_size,
                act_order,
                test_perm,
            )
            refs.append(w_ref.T)
            qweights.append(qweight)
            scale_list.append(scales)
            g_idx_list.append(g_idx)
            sort_idx_list.append(sort_indices)
            bias_list.append(marlin_permute_bias(biases[i]))
        return (
            stack_and_dev(refs),
            stack_and_dev(qweights).contiguous(),
            stack_and_dev(scale_list),
            maybe_stack(g_idx_list),
            maybe_stack(sort_idx_list),
            maybe_stack(bias_list),
        )

    # First projection; the permutation length is k.
    w_ref1, qweight1, scales1, g_idx1, sort_indices1, marlin_bias1 = (
        quantize_projection(w1, b_bias1, k)
    )
    global_scale1 = None
    zeros1 = None

    # Second projection; the permutation length is n.
    w_ref2, qweight2, scales2, g_idx2, sort_indices2, marlin_bias2 = (
        quantize_projection(w2, b_bias2, n)
    )
    global_scale2 = None
    zeros2 = None

    score = torch.randn((m, e), device="cuda", dtype=dtype)

    topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)

    with set_current_vllm_config(vllm_config):
        torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, b_bias1, b_bias2)

    marlin_output = torch.ops.vllm.fused_marlin_moe(
        a,
        qweight1,
        qweight2,
        marlin_bias1,
        marlin_bias2,
        scales1,
        scales2,
        score,
        topk_weights,
        topk_ids,
        global_num_experts=e,
        expert_map=None,
        global_scale1=global_scale1,
        global_scale2=global_scale2,
        g_idx1=g_idx1,
        g_idx2=g_idx2,
        sort_indices1=sort_indices1,
        sort_indices2=sort_indices2,
        w1_zeros=zeros1,
        w2_zeros=zeros2,
        quant_type_id=quant_type.id,
        is_k_full=is_k_full,
    )

    torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)
def test_moe_align_block_size_opcheck():
    """Run ``opcheck`` on the ``moe_align_block_size`` custom op."""
    num_experts = 4
    block_size = 4
    topk_ids = torch.randint(0, num_experts, (3, 4), dtype=torch.int32, device="cuda")
    device = topk_ids.device
    num_ids = topk_ids.numel()

    # Buffer sized for up to (block_size - 1) padding slots per expert.
    max_num_tokens_padded = num_ids + num_experts * (block_size - 1)
    # Every slot starts at ``topk_ids.numel()``.
    sorted_ids = torch.full(
        (max_num_tokens_padded,), num_ids, dtype=torch.int32, device=device
    )

    max_num_m_blocks = max_num_tokens_padded // block_size
    expert_ids = torch.empty((max_num_m_blocks,), dtype=torch.int32, device=device)
    num_tokens_post_pad = torch.empty(1, dtype=torch.int32, device=device)

    op_args = (
        topk_ids,
        num_experts,
        block_size,
        sorted_ids,
        expert_ids,
        num_tokens_post_pad,
    )
    opcheck(torch.ops._moe_C.moe_align_block_size, op_args)
@pytest.mark.parametrize("m", [1, 33, 64, 222])
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_moe_sum(m: int, topk: int, k: int, dtype: torch.dtype):
    """Check the ``moe_sum`` custom op against ``Tensor.sum`` over dim 1."""
    moe_input = torch.randn((m, topk, k), device="cuda", dtype=dtype)
    expected = moe_input.sum(dim=1)

    # The op writes its result into a preallocated (m, k) output tensor.
    actual = torch.empty((m, k), device="cuda", dtype=dtype)
    torch.ops._moe_C.moe_sum(moe_input, actual)

    torch.testing.assert_close(actual, expected, atol=2e-2, rtol=0)

    opcheck(torch.ops._moe_C.moe_sum, (moe_input, actual))
@pytest.mark.parametrize("m", [1, 33])
@pytest.mark.parametrize("n,k", [(128, 128)])
@pytest.mark.parametrize("e", [8])
@pytest.mark.parametrize("topk", [2])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("with_bias", [False, True])
@pytest.mark.parametrize("activation", ["silu"])
@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only test")
def test_cpu_fused_moe_basic(m, n, k, e, topk, dtype, with_bias, activation):
    """Check ``CPUFusedMOE`` against ``torch_moe``, with and without biases."""
    from vllm.model_executor.layers.fused_moe.cpu_fused_moe import CPUFusedMOE

    device = "cpu"
    torch.manual_seed(7)

    a = torch.randn((m, k), device=device, dtype=dtype) / 10
    w13 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 10
    router_logits = torch.randn((m, e), device=device, dtype=dtype)

    b1 = b2 = None
    if with_bias:
        b1 = torch.randn((e, 2 * n), device=device, dtype=dtype) / 10
        b2 = torch.randn((e, k), device=device, dtype=dtype) / 10

    # The bias arguments are passed to the reference only when they exist.
    if with_bias:
        ref = torch_moe(a, w13, w2, router_logits, topk, b1, b2)
    else:
        ref = torch_moe(a, w13, w2, router_logits, topk)

    class _Dummy(torch.nn.Module):
        """Bare module exposing the expert weights (and optional biases)."""

        def __init__(self, w13, w2, b1=None, b2=None):
            super().__init__()
            self.w13_weight = torch.nn.Parameter(w13, requires_grad=False)
            self.w2_weight = torch.nn.Parameter(w2, requires_grad=False)
            # Bias attributes are only created when a bias tensor is given.
            if b1 is not None:
                self.w13_bias = torch.nn.Parameter(b1, requires_grad=False)
            if b2 is not None:
                self.w2_bias = torch.nn.Parameter(b2, requires_grad=False)

    layer = _Dummy(w13, w2, b1, b2).to(dtype)
    fused = CPUFusedMOE(layer)

    call_kwargs = dict(
        layer=layer,
        x=a,
        use_grouped_topk=False,
        top_k=topk,
        router_logits=router_logits,
        renormalize=False,
        global_num_experts=e,
        expert_map=None,
        custom_routing_function=None,
        scoring_func="softmax",
        routed_scaling_factor=1.0,
        e_score_correction_bias=None,
        apply_router_weight_on_input=False,
        activation=activation,
    )
    out = fused(**call_kwargs)

    # Tolerances: fp32 tight; bf16 looser (esp. with bias)
    if dtype == torch.float32:
        atol = 1e-3
    else:
        atol = 8e-2 if with_bias else 5e-2

    torch.testing.assert_close(out, ref, atol=atol, rtol=0)