# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the MOE layers.

Run `pytest tests/kernels/test_moe.py`.
"""

import functools
import importlib
import sys
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any

import pytest
import torch
from torch.nn import Parameter
from torch.nn import functional as F
from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

import vllm.model_executor.layers.fused_moe  # noqa
from tests.kernels.moe.utils import fused_moe
from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed.parallel_state import init_distributed_environment
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe.config import (
    FUSED_MOE_UNQUANTIZED_CONFIG,
    int4_w4a16_moe_quant_config,
    int8_w8a16_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
    batched_fused_marlin_moe,
    fused_marlin_moe,
)
from vllm.model_executor.layers.fused_moe.fused_moe import (
    fused_topk,
    modular_triton_fused_moe,
)
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
    fused_moe as iterative_moe,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    marlin_permute_bias,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
    rand_marlin_weight_mxfp4_like,
    rand_marlin_weight_nvfp4_like,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
    marlin_quant_fp8_torch,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
    awq_marlin_quantize,
    marlin_quantize,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import quantize_weights
from vllm.model_executor.models.mixtral import MixtralMoE
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types

NUM_EXPERTS = [8, 64, 192]
EP_SIZE = [1, 4]
TOP_KS = [2, 6]

FUSED_MOE_MNK_FACTORS = [
    (1, 128, 128),
    (1, 2048, 128),
    (33, 2048, 128),
    (32768, 2048, 511),
    (40000, 1024, 1024),
]

FUSED_MOE_WN16_MNK_FACTORS = [
    (1, 128, 128),
    (1, 1024, 1024),
    (32, 2048, 128),
    (222, 2048, 1024),
]

vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192


def run_moe_test(
    baseline: Callable | torch.Tensor,
    moe_fn: Callable,
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    score: torch.Tensor,
    topk: int,
    global_num_experts: int = -1,
    expert_map: torch.Tensor | None = None,
    padding: bool = False,
    use_compile: bool = False,
    use_cudagraph: bool = False,
    atol: float = 2e-2,
    rtol: float = 0,
) -> torch.Tensor:
    if isinstance(baseline, torch.Tensor):
        baseline_output = baseline
    else:
        baseline_output = baseline(
            a,
            w1,
            w2,
            score,
            topk,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
        )

    # Pad the weight if moe padding is enabled
    if padding:
        w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128]
        w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128]

    if use_compile:
        moe_fn = torch.compile(moe_fn, backend="inductor", fullgraph=True)
        torch._dynamo.mark_dynamic(a, 0)
        torch._dynamo.mark_dynamic(score, 0)

    test_output = moe_fn(
        a,
        w1,
        w2,
        score,
        topk,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
    )

    if use_cudagraph:
        test_output.fill_(0)
        stream = torch.cuda.Stream()
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, stream=stream):
            test_output = moe_fn(
                a,
                w1,
                w2,
                score,
                topk,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
            )
        torch.cuda.synchronize()
        graph.replay()
        torch.cuda.synchronize()

    torch.testing.assert_close(test_output, baseline_output, atol=atol, rtol=rtol)

    return baseline_output
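

# Illustrative sketch (not collected by pytest; a hedged example, not part of
# the upstream suite): `run_moe_test` only evaluates `baseline` when it is a
# Callable and always returns the reference output, so passing that Tensor
# back in lets later comparisons reuse the same reference without recomputing
# it. This is how test_fused_moe below chains its runners.
def _sketch_run_moe_test_chaining() -> None:
    m, n, k, e, topk = 4, 32, 16, 8, 2
    dtype = torch.bfloat16
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
    score = torch.randn((m, e), device="cuda", dtype=dtype)

    with set_current_vllm_config(vllm_config):
        # First call: the reference output is computed by torch_moe.
        baseline = run_moe_test(torch_moe, iterative_moe, a, w1, w2, score, topk)
        # Second call: `baseline` is already a Tensor, so no recomputation.
        run_moe_test(baseline, iterative_moe, a, w1, w2, score, topk)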


@pytest.mark.parametrize("m,n,k", FUSED_MOE_MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("padding", [True, False])
@pytest.mark.parametrize("chunk_size", [8192])
def test_fused_moe(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    ep_size: int,
    dtype: torch.dtype,
    padding: bool,
    chunk_size: int,
    monkeypatch,
):
    current_platform.seed_everything(7)

    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))

    #
    # Setup test data
    #

    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10

    score = torch.randn((m, e), device="cuda", dtype=dtype)

    if ep_size > 1:
        local_e = e // ep_size
        e_ids = torch.randint(0, e, (local_e,), device="cuda", dtype=torch.int32)
        e_map = torch.full((e,), -1, device="cuda", dtype=torch.int32)
        e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
        w1 = w1[e_ids]
        w2 = w2[e_ids]
    else:
        e_map = None

    #
    # Setup test functions
    #
    quant_config = FUSED_MOE_UNQUANTIZED_CONFIG

    m_fused_moe_fn = modular_triton_fused_moe(quant_config)

    def m_fused_moe(
        a: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        score: torch.Tensor,
        topk: int,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
    ) -> torch.Tensor:
        topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
        return m_fused_moe_fn(
            a,
            w1,
            w2,
            topk_weights,
            topk_ids,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
        )

    fused_moe_fn = functools.partial(fused_moe, renormalize=False)

    #
    # Run tests
    #
    runner = functools.partial(
        run_moe_test,
        a=a,
        w1=w1,
        w2=w2,
        score=score,
        topk=topk,
        global_num_experts=e,
        expert_map=e_map,
        padding=padding,
    )

    # Note: for now use_compile will error out if the problem size is
    # large enough to trigger chunking. I'm leaving the flag and
    # setup code in case we are able to revisit this later.
    use_compile = False

    use_cudagraph = n >= 1024 and k >= 1024 and current_platform.is_cuda_alike()

    with set_current_vllm_config(vllm_config):
        baseline_output = runner(torch_moe, iterative_moe)
        runner(
            baseline_output,
            fused_moe_fn,
            use_compile=use_compile,
            use_cudagraph=use_cudagraph,
        )
        runner(
            baseline_output,
            m_fused_moe,
            use_compile=use_compile,
            use_cudagraph=use_cudagraph,
        )
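

# Illustrative sketch (CPU-only, a hedged example rather than part of the
# test suite): the expert-parallel `e_map` built in test_fused_moe maps global
# expert ids to local ids, with -1 marking experts that are not local to this
# rank.  E.g. _sketch_expert_map(8, torch.tensor([1, 5])) returns
# [-1, 0, -1, -1, -1, 1, -1, -1].
def _sketch_expert_map(e: int, local_expert_ids: torch.Tensor) -> torch.Tensor:
    local_e = local_expert_ids.numel()
    e_map = torch.full((e,), -1, dtype=torch.int32)
    e_map[local_expert_ids] = torch.arange(local_e, dtype=torch.int32)
    return e_map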


@pytest.mark.parametrize("m,n,k", FUSED_MOE_WN16_MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("group_size", [64, 128])
@pytest.mark.parametrize("has_zp", [True, False])
@pytest.mark.parametrize("weight_bits", [4, 8])
def test_fused_moe_wn16(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    ep_size: int,
    dtype: torch.dtype,
    group_size: int,
    has_zp: bool,
    weight_bits: int,
):
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
    score = torch.randn((m, e), device="cuda", dtype=dtype)

    if weight_bits == 4:
        pack_factor = 2
        quant_type = scalar_types.uint4 if has_zp else scalar_types.uint4b8
    elif weight_bits == 8:
        pack_factor = 1
        quant_type = scalar_types.uint8 if has_zp else scalar_types.uint8b128

    w1_ref = w1.clone()
    w2_ref = w2.clone()
    w1_qweight = torch.empty(
        (e, 2 * n, k // pack_factor), device="cuda", dtype=torch.uint8
    )
    w2_qweight = torch.empty((e, k, n // pack_factor), device="cuda", dtype=torch.uint8)
    w1_scales = torch.empty((e, 2 * n, k // group_size), device="cuda", dtype=dtype)
    w2_scales = torch.empty((e, k, n // group_size), device="cuda", dtype=dtype)
    w1_qzeros = torch.empty(
        (e, 2 * n // pack_factor, k // group_size), device="cuda", dtype=torch.uint8
    )
    w2_qzeros = torch.empty(
        (e, k // pack_factor, n // group_size), device="cuda", dtype=torch.uint8
    )

    for i in range(e * 2):
        expert_id = i % e
        if i // e == 0:
            w, w_ref, w_qweight, w_scales, w_qzeros = (
                w1,
                w1_ref,
                w1_qweight,
                w1_scales,
                w1_qzeros,
            )
        else:
            w, w_ref, w_qweight, w_scales, w_qzeros = (
                w2,
                w2_ref,
                w2_qweight,
                w2_scales,
                w2_qzeros,
            )
        weight, qweight, scales, qzeros = quantize_weights(
            w[expert_id].T, quant_type, group_size, has_zp, False
        )
        weight = weight.T
        qweight = qweight.T.contiguous().to(torch.uint8)
        scales = scales.T
        if has_zp:
            qzeros = qzeros.T.contiguous().to(torch.uint8)
        if weight_bits == 4:
            qweight = qweight[:, 1::2] * 16 + qweight[:, ::2]
            if has_zp:
                qzeros = qzeros[1::2, :] * 16 + qzeros[::2, :]

        w_ref[expert_id] = weight
        w_qweight[expert_id] = qweight
        w_scales[expert_id] = scales
        if has_zp:
            w_qzeros[expert_id] = qzeros

    if ep_size > 1:
        local_e = e // ep_size
        e_ids = torch.randint(0, e, (local_e,), device="cuda", dtype=torch.int32)
        e_map = torch.full((e,), -1, device="cuda", dtype=torch.int32)
        e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
        w1_ref = w1_ref[e_ids]
        w2_ref = w2_ref[e_ids]
        w1_qweight = w1_qweight[e_ids]
        w2_qweight = w2_qweight[e_ids]
        w1_scales = w1_scales[e_ids]
        w2_scales = w2_scales[e_ids]
        w1_qzeros = w1_qzeros[e_ids]
        w2_qzeros = w2_qzeros[e_ids]
    else:
        e_map = None

    if weight_bits == 4:
        quant_config_builder = int4_w4a16_moe_quant_config
    else:
        assert weight_bits == 8
        quant_config_builder = int8_w8a16_moe_quant_config

    quant_config = quant_config_builder(
        w1_scale=w1_scales,
        w2_scale=w2_scales,
        w1_zp=w1_qzeros if has_zp else None,
        w2_zp=w2_qzeros if has_zp else None,
        block_shape=[0, group_size],
    )

    with set_current_vllm_config(vllm_config):
        triton_output = fused_moe(
            a,
            w1_qweight,
            w2_qweight,
            score,
            topk,
            renormalize=False,
            global_num_experts=e,
            expert_map=e_map,
            quant_config=quant_config,
        )
        torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, expert_map=e_map)

    torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
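

# Illustrative sketch (documents only the packing convention used in
# test_fused_moe_wn16 above; treat the layout as an assumption, not a kernel
# contract): two 4-bit values are packed per uint8, with even columns in the
# low nibble and odd columns in the high nibble.
def _sketch_int4_pack_roundtrip(q: torch.Tensor) -> None:
    # q: 2D uint8 tensor of 4-bit values (0..15) with an even number of columns.
    assert q.dtype == torch.uint8 and q.dim() == 2 and q.shape[1] % 2 == 0
    packed = q[:, 1::2] * 16 + q[:, ::2]
    low = packed % 16
    high = packed // 16
    unpacked = torch.stack((low, high), dim=-1).flatten(-2)
    assert torch.equal(unpacked, q)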


@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("padding", [True, False])
@pytest.mark.parametrize(
    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]
)
@torch.inference_mode()
def test_mixtral_moe(
    dist_init, dtype: torch.dtype, padding: bool, use_rocm_aiter: bool, monkeypatch
):
    """Make sure our Mixtral MoE implementation agrees with the one from
    huggingface."""

    # Clear any cached state before every test: force a reload of aiter_ops
    # so it picks up the environment variables set for this test.
    if "rocm_aiter_ops" in sys.modules:
        importlib.reload(rocm_aiter_ops)

    if use_rocm_aiter:
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
        if dtype == torch.float32:
            pytest.skip("AITER ROCm test skip for float32")

    monkeypatch.setenv("RANK", "0")
    monkeypatch.setenv("LOCAL_RANK", "0")
    monkeypatch.setenv("WORLD_SIZE", "1")
    monkeypatch.setenv("MASTER_ADDR", "localhost")
    monkeypatch.setenv("MASTER_PORT", "12345")
    init_distributed_environment()

    # Instantiate our and huggingface's MoE blocks
    vllm_config.compilation_config.static_forward_context = dict()
    with set_current_vllm_config(vllm_config), set_forward_context(None, vllm_config):
        config = MixtralConfig()
        hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda")
        vllm_moe = MixtralMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            params_dtype=dtype,
            tp_size=1,
            dp_size=1,
        ).cuda()

        # Load the weights
        vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data
        for i in range(config.num_local_experts):
            weights = (
                hf_moe.experts[i].w1.weight.data,
                hf_moe.experts[i].w3.weight.data,
            )
            vllm_moe.experts.w13_weight[i][:] = torch.cat(weights, dim=0)
            vllm_moe.experts.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data

        # Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
        hf_inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda")
        # vLLM uses 1D query [num_tokens, hidden_dim]
        vllm_inputs = hf_inputs.flatten(0, 1)

        # Pad the weight if moe padding is enabled
        if padding:
            vllm_moe.experts.w13_weight = Parameter(
                F.pad(vllm_moe.experts.w13_weight, (0, 128), "constant", 0)[
                    ..., 0:-128
                ],
                requires_grad=False,
            )
            vllm_moe.experts.w2_weight = Parameter(
                F.pad(vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[..., 0:-128],
                requires_grad=False,
            )
            torch.cuda.synchronize()
            torch.cuda.empty_cache()

        # Run forward passes for both MoE blocks
        hf_states, _ = hf_moe.forward(hf_inputs)
        vllm_states = vllm_moe.forward(vllm_inputs)

        mixtral_moe_tol = {
            torch.float32: 1e-3,
            torch.float16: 1e-3,
            torch.bfloat16: 1e-2,
        }

        if use_rocm_aiter:
            # The values of rtol and atol are set based on the tests in the
            # ROCm AITER package:
            # https://github.com/ROCm/aiter/blob/dfed377f4be7da96ca2d75ac0761f569676f7240/op_tests/test_moe.py#L174
            torch.testing.assert_close(
                hf_states.flatten(0, 1), vllm_states, rtol=0.01, atol=100
            )
        else:
            torch.testing.assert_close(
                hf_states.flatten(0, 1),
                vllm_states,
                rtol=mixtral_moe_tol[dtype],
                atol=mixtral_moe_tol[dtype],
            )
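

# Illustrative sketch (a hedged restatement of the weight-loading loop in
# test_mixtral_moe, not exercised by the tests): vLLM's fused w13 weight
# stacks each expert's gate projection (HF `w1`) on top of its up projection
# (HF `w3`) along the output dimension, which is what the
# torch.cat(..., dim=0) above produces.
def _sketch_w13_layout(w_gate: torch.Tensor, w_up: torch.Tensor) -> torch.Tensor:
    # w_gate, w_up: [intermediate_size, hidden_size] for a single expert.
    w13 = torch.cat((w_gate, w_up), dim=0)
    assert torch.equal(w13[: w_gate.shape[0]], w_gate)
    assert torch.equal(w13[w_gate.shape[0] :], w_up)
    return w13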


def marlin_moe_generate_valid_test_cases():
    import itertools

    m_list = [1, 123, 666]
    n_list = [128, 1024]
    k_list = [256, 2048]
    e_list = [4, 12]
    topk_list = [2, 3]
    ep_size_list = [1, 4]
    dtype_list = [torch.bfloat16]
    group_size_list = [-1, 32, 128]
    act_order_list = [True, False]
    quant_type_list = [
        scalar_types.float4_e2m1f,
        scalar_types.float8_e4m3fn,
        scalar_types.uint4,
        scalar_types.uint4b8,
        scalar_types.uint8b128,
    ]
    is_k_full_list = [True, False]

    all_combinations = itertools.product(
        m_list,
        n_list,
        k_list,
        e_list,
        topk_list,
        ep_size_list,
        dtype_list,
        group_size_list,
        act_order_list,
        quant_type_list,
        is_k_full_list,
    )

    def is_valid(
        m, n, k, e, topk, ep_size, dtype, group_size, act_order, quant_type, is_k_full
    ):
        if quant_type == scalar_types.float8_e4m3fn and group_size not in [-1, 128]:
            return False
        if quant_type == scalar_types.float4_e2m1f:
            if group_size not in [16, 32]:
                return False
            if dtype == torch.float16 and group_size == 32:
                return False
        if quant_type != scalar_types.float4_e2m1f and group_size == 16:
            return False

        # Filter act_order
        if act_order:
            if group_size in (-1, k, n):
                return False
            if quant_type not in [scalar_types.uint4b8]:
                return False
        elif not is_k_full:
            return False

        return True

    cases = []
    for case in all_combinations:
        if is_valid(*case):
            cases.append(case)
    return cases


@dataclass
class MarlinMoEWeightData:
    w_ref: torch.Tensor
    qweight: torch.Tensor
    scales: torch.Tensor
    global_scale: torch.Tensor | None
    g_idx: torch.Tensor | None
    zeros: torch.Tensor | None
    sort_indices: torch.Tensor | None
    marlin_bias: torch.Tensor | None

    @staticmethod
    def make(
        w: torch.Tensor,
        quant_type: ScalarType,
        group_size: int,
        act_order: bool | None = None,
        bias: torch.Tensor | None = None,
    ) -> "MarlinMoEWeightData":
        assert w.ndim == 3
        has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
        k = w.shape[-1]

        w_ref_l: list[torch.Tensor] = []
        qweight_l: list[torch.Tensor] = []
        scales_l: list[torch.Tensor] = []
        global_scale_l: list[torch.Tensor] = []
        zeros_l: list[torch.Tensor] = []
        g_idx_l: list[torch.Tensor] = []
        sort_indices_l: list[torch.Tensor] = []
        bias_l: list[torch.Tensor] = []

        for i in range(w.shape[0]):
            if quant_type == scalar_types.float4_e2m1f:
                if group_size == 16:
                    w_ref, qweight, scales, global_scale = (
                        rand_marlin_weight_nvfp4_like(w[i], group_size)
                    )
                else:
                    w_ref, qweight, scales = rand_marlin_weight_mxfp4_like(
                        w[i], group_size
                    )
                    global_scale = None

                w_ref_l.append(w_ref.T)
                qweight_l.append(qweight)
                scales_l.append(scales)
                if global_scale is not None:
                    global_scale_l.append(global_scale)
            elif quant_type == scalar_types.float8_e4m3fn:
                w_ref, qweight, scales = marlin_quant_fp8_torch(w[i], group_size)
                w_ref_l.append(w_ref.T)
                qweight_l.append(qweight)
                scales_l.append(scales)
            elif has_zp:
                w_ref, qweight, scales, zeros = awq_marlin_quantize(
                    w[i].transpose(1, 0), quant_type, group_size
                )

                w_ref_l.append(w_ref.T)
                qweight_l.append(qweight)
                scales_l.append(scales)
                zeros_l.append(zeros)
            else:
                test_perm = torch.randperm(k)
                w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(
                    w[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
                )

                w_ref_l.append(w_ref.T)
                qweight_l.append(qweight)
                scales_l.append(scales)
                g_idx_l.append(g_idx)
                sort_indices_l.append(sort_indices)

            if bias is not None:
                bias_l.append(marlin_permute_bias(bias[i]))

        w_ref = stack_and_dev(w_ref_l)
        qweight = stack_and_dev(qweight_l).contiguous()
        scales = stack_and_dev(scales_l)
        global_scale = stack_and_dev(global_scale_l) if global_scale_l else None
        g_idx = stack_and_dev(g_idx_l) if g_idx_l else None
        zeros = stack_and_dev(zeros_l) if zeros_l else None
        sort_indices = stack_and_dev(sort_indices_l) if sort_indices_l else None
        marlin_bias = stack_and_dev(bias_l) if bias_l else None

        return MarlinMoEWeightData(
            w_ref=w_ref,
            qweight=qweight,
            scales=scales,
            global_scale=global_scale,
            g_idx=g_idx,
            zeros=zeros,
            sort_indices=sort_indices,
            marlin_bias=marlin_bias,
        )


@pytest.mark.flaky(reruns=2)
@pytest.mark.parametrize(
    "m,n,k,e,topk,ep_size,dtype,group_size,act_order,quant_type,is_k_full",
    marlin_moe_generate_valid_test_cases(),
)
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_fused_marlin_moe(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    ep_size: int,
    dtype: torch.dtype,
    group_size: int,
    act_order: bool,
    quant_type: ScalarType,
    is_k_full: bool,
):
    torch.cuda.manual_seed(0)

    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20

    if ep_size > 1:
        local_e = e // ep_size
        e_ids = torch.randperm(e, device="cuda", dtype=torch.int32)[:local_e]
        e_map = torch.full((e,), -1, device="cuda", dtype=torch.int32)
        e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
        w1 = w1[e_ids]
        w2 = w2[e_ids]
    else:
        e_map = None

    w1_data = MarlinMoEWeightData.make(
        w=w1, quant_type=quant_type, group_size=group_size, act_order=act_order
    )

    w2_data = MarlinMoEWeightData.make(
        w=w2, quant_type=quant_type, group_size=group_size, act_order=act_order
    )

    score = torch.randn((m, e), device="cuda", dtype=dtype)

    topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)

    with set_current_vllm_config(vllm_config):
        torch_output = torch_moe(
            a, w1_data.w_ref, w2_data.w_ref, score, topk, expert_map=e_map
        )

        marlin_output = fused_marlin_moe(
            a,
            w1_data.qweight,
            w2_data.qweight,
            None,
            None,
            w1_data.scales,
            w2_data.scales,
            score,
            topk_weights,
            topk_ids,
            global_num_experts=e,
            expert_map=e_map,
            global_scale1=w1_data.global_scale,
            global_scale2=w2_data.global_scale,
            g_idx1=w1_data.g_idx,
            g_idx2=w2_data.g_idx,
            sort_indices1=w1_data.sort_indices,
            sort_indices2=w2_data.sort_indices,
            w1_zeros=w1_data.zeros,
            w2_zeros=w2_data.zeros,
            quant_type_id=quant_type.id,
            is_k_full=is_k_full,
        )

    torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)


@pytest.mark.flaky(reruns=2)
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
@pytest.mark.parametrize("m", [1, 256])
def test_fused_marlin_moe_with_bias(m):
    torch.cuda.manual_seed(0)

    e, topk = 32, 4
    n, k = 2048, 2048
    group_size = 128
    act_order = False
    is_k_full = True
    quant_type = scalar_types.uint4b8
    dtype = torch.half

    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
    b_bias1 = torch.randn((e, 2 * n), device="cuda", dtype=dtype) / 10
    b_bias2 = torch.randn((e, k), device="cuda", dtype=dtype) / 10

    w1_data = MarlinMoEWeightData.make(
        w=w1,
        quant_type=quant_type,
        group_size=group_size,
        act_order=act_order,
        bias=b_bias1,
    )

    w2_data = MarlinMoEWeightData.make(
        w=w2,
        quant_type=quant_type,
        group_size=group_size,
        act_order=act_order,
        bias=b_bias2,
    )

    score = torch.randn((m, e), device="cuda", dtype=dtype)

    topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)

    with set_current_vllm_config(vllm_config):
        torch_output = torch_moe(
            a, w1_data.w_ref, w2_data.w_ref, score, topk, b_bias1, b_bias2
        )

        marlin_output = fused_marlin_moe(
            a,
            w1_data.qweight,
            w2_data.qweight,
            w1_data.marlin_bias,
            w2_data.marlin_bias,
            w1_data.scales,
            w2_data.scales,
            score,
            topk_weights,
            topk_ids,
            global_num_experts=e,
            expert_map=None,
            global_scale1=w1_data.global_scale,
            global_scale2=w2_data.global_scale,
            g_idx1=w1_data.g_idx,
            g_idx2=w2_data.g_idx,
            sort_indices1=w1_data.sort_indices,
            sort_indices2=w2_data.sort_indices,
            w1_zeros=w1_data.zeros,
            w2_zeros=w2_data.zeros,
            quant_type_id=quant_type.id,
            is_k_full=is_k_full,
        )

    torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)


def test_moe_align_block_size_opcheck():
    num_experts = 4
    block_size = 4
    topk_ids = torch.randint(0, num_experts, (3, 4), dtype=torch.int32, device="cuda")

    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
    sorted_ids = torch.empty(
        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
    )
    sorted_ids.fill_(topk_ids.numel())
    max_num_m_blocks = max_num_tokens_padded // block_size
    expert_ids = torch.empty(
        (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
    )
    num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)

    opcheck(
        torch.ops._moe_C.moe_align_block_size,
        (
            topk_ids,
            num_experts,
            block_size,
            sorted_ids,
            expert_ids,
            num_tokens_post_pad,
        ),
    )
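

# Illustrative sketch (pure Python, explaining the sizing used above under the
# stated assumption that each expert's token count is rounded up to a multiple
# of block_size): rounding up adds at most (block_size - 1) padding slots per
# expert, which is why topk_ids.numel() + num_experts * (block_size - 1)
# upper-bounds the padded length allocated for sorted_ids.
def _sketch_padded_length_bound(tokens_per_expert: list[int], block_size: int) -> None:
    padded = sum(
        -(-count // block_size) * block_size for count in tokens_per_expert
    )
    bound = sum(tokens_per_expert) + len(tokens_per_expert) * (block_size - 1)
    assert padded <= bound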


def test_batched_moe_align_block_size_opcheck():
    max_tokens_per_batch = 512
    num_experts = 4
    block_size = 16

    expert_num_tokens = torch.randint(
        low=0,
        high=max_tokens_per_batch,
        size=(num_experts,),
        dtype=torch.int32,
        device="cuda",
    )

    max_num_tokens_padded = num_experts * max(max_tokens_per_batch, block_size)
    sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device="cuda")

    assert max_num_tokens_padded % block_size == 0
    max_num_m_blocks = max_num_tokens_padded // block_size
    expert_ids = torch.empty((max_num_m_blocks,), dtype=torch.int32, device="cuda")

    num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device="cuda")

    opcheck(
        torch.ops._moe_C.batched_moe_align_block_size,
        (
            max_tokens_per_batch,
            block_size,
            expert_num_tokens,
            sorted_ids,
            expert_ids,
            num_tokens_post_pad,
        ),
    )


@pytest.mark.parametrize("m", [1, 33, 222])
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_moe_sum(m: int, topk: int, k: int, dtype: torch.dtype):
    input = torch.randn((m, topk, k), device="cuda", dtype=dtype)
    actual = torch.empty((m, k), device="cuda", dtype=dtype)

    expected = input.sum(dim=1)
    torch.ops._moe_C.moe_sum(input, actual)

    torch.testing.assert_close(actual, expected, atol=2e-2, rtol=0)

    opcheck(torch.ops._moe_C.moe_sum, (input, actual))


@pytest.mark.parametrize("m", [1, 33])
@pytest.mark.parametrize("n,k", [(128, 128)])
@pytest.mark.parametrize("e", [8])
@pytest.mark.parametrize("topk", [2])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("with_bias", [False, True])
@pytest.mark.parametrize("activation", ["silu"])
@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only test")
def test_cpu_fused_moe_basic(m, n, k, e, topk, dtype, with_bias, activation):
    from vllm.model_executor.layers.fused_moe.cpu_fused_moe import CPUFusedMOE

    device = "cpu"
    torch.manual_seed(7)

    a = torch.randn((m, k), device=device, dtype=dtype) / 10
    w13 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 10
    router_logits = torch.randn((m, e), device=device, dtype=dtype)

    b1 = b2 = None
    if with_bias:
        b1 = torch.randn((e, 2 * n), device=device, dtype=dtype) / 10
        b2 = torch.randn((e, k), device=device, dtype=dtype) / 10

    ref = (
        torch_moe(a, w13, w2, router_logits, topk, b1, b2)
        if with_bias
        else torch_moe(a, w13, w2, router_logits, topk)
    )

    class _Dummy(torch.nn.Module):
        def __init__(self, w13, w2, b1=None, b2=None):
            super().__init__()
            self.w13_weight = torch.nn.Parameter(w13, requires_grad=False)
            self.w2_weight = torch.nn.Parameter(w2, requires_grad=False)
            if b1 is not None:
                self.w13_bias = torch.nn.Parameter(b1, requires_grad=False)
            if b2 is not None:
                self.w2_bias = torch.nn.Parameter(b2, requires_grad=False)

    layer = _Dummy(w13, w2, b1, b2).to(dtype)
    fused = CPUFusedMOE(layer)
    out = fused(
        layer=layer,
        x=a,
        use_grouped_topk=False,
        top_k=topk,
        router_logits=router_logits,
        renormalize=False,
        global_num_experts=e,
        expert_map=None,
        custom_routing_function=None,
        scoring_func="softmax",
        routed_scaling_factor=1.0,
        e_score_correction_bias=None,
        apply_router_weight_on_input=False,
        activation=activation,
    )

    # Tolerances: fp32 tight; bf16 looser (esp. with bias)
    if dtype == torch.float32:
        atol = 1e-3
    elif with_bias:
        atol = 8e-2
    else:
        atol = 5e-2
    torch.testing.assert_close(out, ref, atol=atol, rtol=0)


@pytest.mark.parametrize("m", [16, 32, 64])
@pytest.mark.parametrize("n", [128])
@pytest.mark.parametrize("k", [128])
@pytest.mark.parametrize("e", [8, 12, 16, 32])
@pytest.mark.parametrize("topk", [2, 4])
@pytest.mark.parametrize("max_tokens_per_batch", [16, 32, 64])
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_batched_fused_marlin_moe(
    m: int, n: int, k: int, e: int, topk: int, max_tokens_per_batch: int
):
    print(
        f"testing m={m}, n={n}, k={k}, e={e}, "
        f"topk={topk}, "
        f"max_tokens_per_batch={max_tokens_per_batch}"
    )
    torch.cuda.manual_seed(0)

    dtype = torch.bfloat16
    quant_dtype = scalar_types.float4_e2m1f
    group_size = 32

    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20

    w1_data = MarlinMoEWeightData.make(
        w=w1, quant_type=quant_dtype, group_size=group_size, act_order=None
    )
    w2_data = MarlinMoEWeightData.make(
        w=w2, quant_type=quant_dtype, group_size=group_size, act_order=None
    )

    score = torch.randn((m, e), device="cuda", dtype=dtype)
    topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)

    class BatchedRun:
        @staticmethod
        def _make_expert_num_tokens_cpu(
            e: int,  # num_experts
            topk_ids_cpu: torch.Tensor,
        ) -> torch.Tensor:
            expert_num_tokens_cpu = torch.zeros((e,), dtype=torch.int32, device="cpu")
            for topk_id in torch.flatten(topk_ids_cpu):
                expert_num_tokens_cpu[topk_id] += 1
            return expert_num_tokens_cpu

        def __init__(
            self,
            max_tokens_per_batch: int,
            num_experts: int,
            _topk_ids: torch.Tensor,
            _topk_weights: torch.Tensor,
        ):
            self.max_tokens_per_batch = max_tokens_per_batch
            self.e = num_experts
            self.topk_ids_cpu = _topk_ids.to("cpu")
            self.topk_weights_cpu = _topk_weights.to("cpu")
            self.expert_num_tokens_cpu = self._make_expert_num_tokens_cpu(
                self.e, self.topk_ids_cpu
            )

        def is_valid(self):
            """
            Return True only if the input can be represented in a Batched
            format.
            """
            return torch.all(self.expert_num_tokens_cpu <= self.max_tokens_per_batch)

        def _scatter(self, hidden_states: torch.Tensor) -> torch.Tensor:
            hidden_states_cpu = hidden_states.to("cpu")
            K = hidden_states_cpu.size(1)
            batched_hidden_states_cpu = torch.empty(
                (e, max_tokens_per_batch, K),
                dtype=hidden_states_cpu.dtype,
                device="cpu",
            )

            counter_cpu = torch.zeros_like(self.expert_num_tokens_cpu)
            for t_idx, token in enumerate(hidden_states_cpu):
                for topk_id in self.topk_ids_cpu[t_idx]:
                    pos_in_batch = counter_cpu[topk_id]
                    batched_hidden_states_cpu[topk_id, pos_in_batch] = token
                    counter_cpu[topk_id] += 1
            assert torch.allclose(counter_cpu, self.expert_num_tokens_cpu)
            return batched_hidden_states_cpu.to("cuda")

        def _gather(
            self, batched_outputs: torch.Tensor, gather_outputs: torch.Tensor
        ) -> torch.Tensor:
            batched_outputs_cpu = batched_outputs.to("cpu")
            gather_outputs_cpu = torch.zeros_like(gather_outputs)

            counter_cpu = torch.zeros((e,), device="cpu", dtype=torch.int32)
            md = gather_outputs_cpu.size(0)
            for t_idx in range(md):
                token = None
                for topk_id, topk_weight in zip(
                    self.topk_ids_cpu[t_idx], self.topk_weights_cpu[t_idx]
                ):
                    pos_in_batch = counter_cpu[topk_id]
                    t = batched_outputs_cpu[topk_id, pos_in_batch] * topk_weight
                    if token is None:
                        token = t
                    else:
                        token += t
                    counter_cpu[topk_id] += 1
                assert token is not None
                gather_outputs_cpu[t_idx] = token
            gather_outputs.copy_(gather_outputs_cpu)
            return gather_outputs

        def run(
            self, hidden_states: torch.Tensor, fused_marlin_moe_kwargs: dict[Any, Any]
        ) -> torch.Tensor:
            assert hidden_states.ndim == 2
            assert self.is_valid()

            batched_hidden_states = self._scatter(hidden_states)

            kwargs = fused_marlin_moe_kwargs | {
                "hidden_states": batched_hidden_states,
                "expert_num_tokens": self.expert_num_tokens_cpu.to("cuda"),
            }
            batched_outputs = batched_fused_marlin_moe(**kwargs)

            output = torch.zeros_like(hidden_states)
            output = self._gather(batched_outputs, output)
            return output

    kwargs = {
        "w1": w1_data.qweight,
        "w2": w2_data.qweight,
        "bias1": None,
        "bias2": None,
        "w1_scale": w1_data.scales,
        "w2_scale": w2_data.scales,
        "gating_output": score,
        "global_num_experts": e,
        "expert_map": None,
        "global_scale1": w1_data.global_scale,
        "global_scale2": w2_data.global_scale,
        "g_idx1": w1_data.g_idx,
        "g_idx2": w2_data.g_idx,
        "sort_indices1": w1_data.sort_indices,
        "sort_indices2": w2_data.sort_indices,
        "w1_zeros": w1_data.zeros,
        "w2_zeros": w2_data.zeros,
        "quant_type_id": quant_dtype.id,
        "is_k_full": True,
    }

    # Reference
    fused_marlin_moe_kwargs = kwargs | {
        "hidden_states": a,
        "topk_ids": topk_ids,
        "topk_weights": topk_weights,
    }
    ref_marlin_output = fused_marlin_moe(**fused_marlin_moe_kwargs)

    # Batched
    br = BatchedRun(max_tokens_per_batch, e, topk_ids, topk_weights)
    if not br.is_valid():
        pytest.skip("Cannot represent data in Batched Format.")
    marlin_output = br.run(a, kwargs)

    torch.testing.assert_close(marlin_output, ref_marlin_output, atol=1e-3, rtol=0)
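

# Illustrative sketch (an equivalence shown for explanation only, not a change
# to the test): the expert_num_tokens histogram built token-by-token in
# BatchedRun._make_expert_num_tokens_cpu can also be expressed with
# torch.bincount over the flattened topk_ids.
def _sketch_expert_histogram(topk_ids: torch.Tensor, num_experts: int) -> torch.Tensor:
    return torch.bincount(
        topk_ids.flatten().to(torch.int64), minlength=num_experts
    ).to(torch.int32)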