# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for the MOE layers. Run `pytest tests/kernels/test_moe.py`. """ import functools import importlib import sys from collections.abc import Callable from dataclasses import dataclass from typing import Any import pytest import torch from torch.nn import Parameter from torch.nn import functional as F from transformers import MixtralConfig from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock import vllm.model_executor.layers.fused_moe # noqa from tests.kernels.moe.utils import fused_moe from tests.kernels.utils import opcheck, stack_and_dev, torch_moe from vllm._aiter_ops import rocm_aiter_ops from vllm.config import VllmConfig, set_current_vllm_config from vllm.distributed.parallel_state import init_distributed_environment from vllm.forward_context import set_forward_context from vllm.model_executor.layers.fused_moe.config import ( FUSED_MOE_UNQUANTIZED_CONFIG, int4_w4a16_moe_quant_config, int8_w8a16_moe_quant_config, ) from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( batched_fused_marlin_moe, fused_marlin_moe, ) from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_topk, modular_triton_fused_moe, ) from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( fused_moe as iterative_moe, ) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( marlin_permute_bias, ) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( rand_marlin_weight_mxfp4_like, rand_marlin_weight_nvfp4_like, ) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( marlin_quant_fp8_torch, ) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( awq_marlin_quantize, marlin_quantize, ) from vllm.model_executor.layers.quantization.utils.quant_utils import quantize_weights from vllm.model_executor.models.mixtral import MixtralMoE from vllm.platforms import current_platform from vllm.scalar_type import ScalarType, scalar_types NUM_EXPERTS = [8, 64, 192] EP_SIZE = [1, 4] TOP_KS = [2, 6] FUSED_MOE_MNK_FACTORS = [ (1, 128, 128), (1, 2048, 128), (33, 2048, 128), (32768, 2048, 511), (40000, 1024, 1024), ] FUSED_MOE_WN16_MNK_FACTORS = [ (1, 128, 128), (1, 1024, 1024), (32, 2048, 128), (222, 2048, 1024), ] vllm_config = VllmConfig() vllm_config.scheduler_config.max_num_seqs = 128 vllm_config.scheduler_config.max_model_len = 8192 def run_moe_test( baseline: Callable | torch.Tensor, moe_fn: Callable, a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, score: torch.Tensor, topk: int, global_num_experts: int = -1, expert_map: torch.Tensor | None = None, padding: bool = False, use_compile: bool = False, use_cudagraph: bool = False, atol: float = 2e-2, rtol: float = 0, ) -> torch.Tensor: if isinstance(baseline, torch.Tensor): baseline_output = baseline else: baseline_output = baseline( a, w1, w2, score, topk, global_num_experts=global_num_experts, expert_map=expert_map, ) # Pad the weight if moe padding is enabled if padding: w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128] w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128] if use_compile: moe_fn = torch.compile(moe_fn, backend="inductor", fullgraph=True) torch._dynamo.mark_dynamic(a, 0) torch._dynamo.mark_dynamic(score, 0) test_output = moe_fn( a, w1, w2, score, topk, global_num_experts=global_num_experts, expert_map=expert_map, ) if use_cudagraph: test_output.fill_(0) stream = torch.cuda.Stream() graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph, stream=stream): test_output = moe_fn( a, w1, w2, score, topk, global_num_experts=global_num_experts, expert_map=expert_map, ) torch.cuda.synchronize() graph.replay() torch.cuda.synchronize() torch.testing.assert_close(test_output, baseline_output, atol=atol, rtol=rtol) return baseline_output @pytest.mark.parametrize("m,n,k", FUSED_MOE_MNK_FACTORS) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("ep_size", EP_SIZE) @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("padding", [True, False]) @pytest.mark.parametrize("chunk_size", [8192]) def test_fused_moe( m: int, n: int, k: int, e: int, topk: int, ep_size: int, dtype: torch.dtype, padding: bool, chunk_size: int, monkeypatch, ): current_platform.seed_everything(7) monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size)) # # Setup test data # # # Setup test data # a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 score = torch.randn((m, e), device="cuda", dtype=dtype) if ep_size > 1: local_e = e // ep_size e_ids = torch.randint(0, e, (local_e,), device="cuda", dtype=torch.int32) e_map = torch.full((e,), -1, device="cuda", dtype=torch.int32) e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32) w1 = w1[e_ids] w2 = w2[e_ids] else: e_map = None # # Setup test functions # quant_config = FUSED_MOE_UNQUANTIZED_CONFIG m_fused_moe_fn = modular_triton_fused_moe(quant_config) def m_fused_moe( a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, score: torch.Tensor, topk: int, global_num_experts: int = -1, expert_map: torch.Tensor | None = None, ) -> torch.Tensor: topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) return m_fused_moe_fn( a, w1, w2, topk_weights, topk_ids, global_num_experts=global_num_experts, expert_map=expert_map, ) fused_moe_fn = functools.partial(fused_moe, renormalize=False) # # Run tests # runner = functools.partial( run_moe_test, a=a, w1=w1, w2=w2, score=score, topk=topk, global_num_experts=e, expert_map=e_map, padding=padding, ) # Note: for now use_compile will error out if the problem size is # large enough to trigger chunking. I'm leaving the flag and # setup code in case we are able to revisit this later. use_compile = False use_cudagraph = n >= 1024 and k >= 1024 and current_platform.is_cuda_alike() with set_current_vllm_config(vllm_config): baseline_output = runner(torch_moe, iterative_moe) runner( baseline_output, fused_moe_fn, use_compile=use_compile, use_cudagraph=use_cudagraph, ) runner( baseline_output, m_fused_moe, use_compile=use_compile, use_cudagraph=use_cudagraph, ) @pytest.mark.parametrize("m,n,k", FUSED_MOE_WN16_MNK_FACTORS) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("ep_size", EP_SIZE) @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("group_size", [64, 128]) @pytest.mark.parametrize("has_zp", [True, False]) @pytest.mark.parametrize("weight_bits", [4, 8]) def test_fused_moe_wn16( m: int, n: int, k: int, e: int, topk: int, ep_size: int, dtype: torch.dtype, group_size: int, has_zp: bool, weight_bits: int, ): a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 score = torch.randn((m, e), device="cuda", dtype=dtype) if weight_bits == 4: pack_factor = 2 quant_type = scalar_types.uint4 if has_zp else scalar_types.uint4b8 elif weight_bits == 8: pack_factor = 1 quant_type = scalar_types.uint8 if has_zp else scalar_types.uint8b128 w1_ref = w1.clone() w2_ref = w2.clone() w1_qweight = torch.empty( (e, 2 * n, k // pack_factor), device="cuda", dtype=torch.uint8 ) w2_qweight = torch.empty((e, k, n // pack_factor), device="cuda", dtype=torch.uint8) w1_scales = torch.empty((e, 2 * n, k // group_size), device="cuda", dtype=dtype) w2_scales = torch.empty((e, k, n // group_size), device="cuda", dtype=dtype) w1_qzeros = torch.empty( (e, 2 * n // pack_factor, k // group_size), device="cuda", dtype=torch.uint8 ) w2_qzeros = torch.empty( (e, k // pack_factor, n // group_size), device="cuda", dtype=torch.uint8 ) for i in range(e * 2): expert_id = i % e if i // e == 0: w, w_ref, w_qweight, w_scales, w_qzeros = ( w1, w1_ref, w1_qweight, w1_scales, w1_qzeros, ) else: w, w_ref, w_qweight, w_scales, w_qzeros = ( w2, w2_ref, w2_qweight, w2_scales, w2_qzeros, ) weight, qweight, scales, qzeros = quantize_weights( w[expert_id].T, quant_type, group_size, has_zp, False ) weight = weight.T qweight = qweight.T.contiguous().to(torch.uint8) scales = scales.T if has_zp: qzeros = qzeros.T.contiguous().to(torch.uint8) if weight_bits == 4: qweight = qweight[:, 1::2] * 16 + qweight[:, ::2] if has_zp: qzeros = qzeros[1::2, :] * 16 + qzeros[::2, :] w_ref[expert_id] = weight w_qweight[expert_id] = qweight w_scales[expert_id] = scales if has_zp: w_qzeros[expert_id] = qzeros if ep_size > 1: local_e = e // ep_size e_ids = torch.randint(0, e, (local_e,), device="cuda", dtype=torch.int32) e_map = torch.full((e,), -1, device="cuda", dtype=torch.int32) e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32) w1_ref = w1_ref[e_ids] w2_ref = w2_ref[e_ids] w1_qweight = w1_qweight[e_ids] w2_qweight = w2_qweight[e_ids] w1_scales = w1_scales[e_ids] w2_scales = w2_scales[e_ids] w1_qzeros = w1_qzeros[e_ids] w2_qzeros = w2_qzeros[e_ids] else: e_map = None if weight_bits == 4: quant_config_builder = int4_w4a16_moe_quant_config else: assert weight_bits == 8 quant_config_builder = int8_w8a16_moe_quant_config quant_config = quant_config_builder( w1_scale=w1_scales, w2_scale=w2_scales, w1_zp=w1_qzeros if has_zp else None, w2_zp=w2_qzeros if has_zp else None, block_shape=[0, group_size], ) with set_current_vllm_config(vllm_config): triton_output = fused_moe( a, w1_qweight, w2_qweight, score, topk, renormalize=False, global_num_experts=e, expert_map=e_map, quant_config=quant_config, ) torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, expert_map=e_map) torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("padding", [True, False]) @pytest.mark.parametrize( "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False] ) @torch.inference_mode() def test_mixtral_moe( dist_init, dtype: torch.dtype, padding: bool, use_rocm_aiter: bool, monkeypatch ): """Make sure our Mixtral MoE implementation agrees with the one from huggingface.""" # clear the cache before every test # Force reload aiter_ops to pick up the new environment variables. if "rocm_aiter_ops" in sys.modules: importlib.reload(rocm_aiter_ops) if use_rocm_aiter: monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") if dtype == torch.float32: pytest.skip("AITER ROCm test skip for float32") monkeypatch.setenv("RANK", "0") monkeypatch.setenv("LOCAL_RANK", "0") monkeypatch.setenv("WORLD_SIZE", "1") monkeypatch.setenv("MASTER_ADDR", "localhost") monkeypatch.setenv("MASTER_PORT", "12345") init_distributed_environment() # Instantiate our and huggingface's MoE blocks vllm_config.compilation_config.static_forward_context = dict() with set_current_vllm_config(vllm_config), set_forward_context(None, vllm_config): config = MixtralConfig() hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda") vllm_moe = MixtralMoE( num_experts=config.num_local_experts, top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, params_dtype=dtype, tp_size=1, dp_size=1, ).cuda() # Load the weights vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data for i in range(config.num_local_experts): weights = ( hf_moe.experts[i].w1.weight.data, hf_moe.experts[i].w3.weight.data, ) vllm_moe.experts.w13_weight[i][:] = torch.cat(weights, dim=0) vllm_moe.experts.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data # Generate input batch of dimensions [batch_size, seq_len, hidden_dim] hf_inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda") # vLLM uses 1D query [num_tokens, hidden_dim] vllm_inputs = hf_inputs.flatten(0, 1) # Pad the weight if moe padding is enabled if padding: vllm_moe.experts.w13_weight = Parameter( F.pad(vllm_moe.experts.w13_weight, (0, 128), "constant", 0)[ ..., 0:-128 ], requires_grad=False, ) vllm_moe.experts.w2_weight = Parameter( F.pad(vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[..., 0:-128], requires_grad=False, ) torch.cuda.synchronize() torch.cuda.empty_cache() # Run forward passes for both MoE blocks hf_states, _ = hf_moe.forward(hf_inputs) vllm_states = vllm_moe.forward(vllm_inputs) mixtral_moe_tol = { torch.float32: 1e-3, torch.float16: 1e-3, torch.bfloat16: 1e-2, } if use_rocm_aiter: # The values of rtol and atol are set based on the tests in ROCM AITER package. # https://github.com/ROCm/aiter/blob/dfed377f4be7da96ca2d75ac0761f569676f7240/op_tests/test_moe.py#L174 torch.testing.assert_close( hf_states.flatten(0, 1), vllm_states, rtol=0.01, atol=100 ) else: torch.testing.assert_close( hf_states.flatten(0, 1), vllm_states, rtol=mixtral_moe_tol[dtype], atol=mixtral_moe_tol[dtype], ) def marlin_moe_generate_valid_test_cases(): import itertools m_list = [1, 123, 666] n_list = [128, 1024] k_list = [256, 2048] e_list = [4, 12] topk_list = [2, 3] ep_size_list = [1, 4] dtype_list = [torch.bfloat16] group_size_list = [-1, 32, 128] act_order_list = [True, False] quant_type_list = [ scalar_types.float4_e2m1f, scalar_types.float8_e4m3fn, scalar_types.uint4, scalar_types.uint4b8, scalar_types.uint8b128, ] is_k_full_list = [True, False] all_combinations = itertools.product( m_list, n_list, k_list, e_list, topk_list, ep_size_list, dtype_list, group_size_list, act_order_list, quant_type_list, is_k_full_list, ) def is_invalid( m, n, k, e, topk, ep_size, dtype, group_size, act_order, quant_type, is_k_full ): if quant_type == scalar_types.float8_e4m3fn and group_size not in [-1, 128]: return False if quant_type == scalar_types.float4_e2m1f: if group_size not in [16, 32]: return False if dtype == torch.float16 and group_size == 32: return False if quant_type != scalar_types.float4_e2m1f and group_size == 16: return False # Filter act_order if act_order: if group_size in (-1, k, n): return False if quant_type not in [scalar_types.uint4b8]: return False elif not is_k_full: return False return True cases = [] for case in all_combinations: if is_invalid(*case): cases.append(case) return cases @dataclass class MarlinMoEWeightData: w_ref: torch.Tensor qweight: torch.Tensor scales: torch.Tensor global_scale: torch.Tensor | None g_idx: torch.Tensor | None zeros: torch.Tensor | None sort_indices: torch.Tensor | None marlin_bias: torch.Tensor | None @staticmethod def make( w: torch.Tensor, quant_type: ScalarType, group_size: int, act_order: bool | None = None, bias: torch.Tensor | None = None, ) -> "MarlinMoEWeightData": assert w.ndim == 3 has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8] k = w.shape[-1] w_ref_l: list[torch.Tensor] = [] qweight_l: list[torch.Tensor] = [] scales_l: list[torch.Tensor] = [] global_scale_l: list[torch.Tensor] = [] zeros_l: list[torch.Tensor] = [] g_idx_l: list[torch.Tensor] = [] sort_indices_l: list[torch.Tensor] = [] bias_l: list[torch.Tensor] = [] for i in range(w.shape[0]): if quant_type == scalar_types.float4_e2m1f: if group_size == 16: w_ref, qweight, scales, global_scale = ( rand_marlin_weight_nvfp4_like(w[i], group_size) ) else: w_ref, qweight, scales = rand_marlin_weight_mxfp4_like( w[i], group_size ) global_scale = None w_ref_l.append(w_ref.T) qweight_l.append(qweight) scales_l.append(scales) if global_scale is not None: global_scale_l.append(global_scale) elif quant_type == scalar_types.float8_e4m3fn: w_ref, qweight, scales = marlin_quant_fp8_torch(w[i], group_size) w_ref_l.append(w_ref.T) qweight_l.append(qweight) scales_l.append(scales) elif has_zp: w_ref, qweight, scales, zeros = awq_marlin_quantize( w[i].transpose(1, 0), quant_type, group_size ) w_ref_l.append(w_ref.T) qweight_l.append(qweight) scales_l.append(scales) zeros_l.append(zeros) else: test_perm = torch.randperm(k) w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize( w[i].transpose(1, 0), quant_type, group_size, act_order, test_perm ) w_ref_l.append(w_ref.T) qweight_l.append(qweight) scales_l.append(scales) g_idx_l.append(g_idx) sort_indices_l.append(sort_indices) if bias is not None: bias_l.append(marlin_permute_bias(bias[i])) w_ref = stack_and_dev(w_ref_l) qweight = stack_and_dev(qweight_l).contiguous() scales = stack_and_dev(scales_l) global_scale = stack_and_dev(global_scale_l) if global_scale_l else None g_idx = stack_and_dev(g_idx_l) if g_idx_l else None zeros = stack_and_dev(zeros_l) if zeros_l else None sort_indices = stack_and_dev(sort_indices_l) if sort_indices_l else None marlin_bias = stack_and_dev(bias_l) if bias_l else None return MarlinMoEWeightData( w_ref=w_ref, qweight=qweight, scales=scales, global_scale=global_scale, g_idx=g_idx, zeros=zeros, sort_indices=sort_indices, marlin_bias=marlin_bias, ) @pytest.mark.flaky(reruns=2) @pytest.mark.parametrize( ("m, n, k, e, topk, ep_size, dtype, group_size,act_order, quant_type, is_k_full"), marlin_moe_generate_valid_test_cases(), ) @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") def test_fused_marlin_moe( m: int, n: int, k: int, e: int, topk: int, ep_size: int, dtype: torch.dtype, group_size: int, act_order: bool, quant_type: ScalarType, is_k_full: bool, ): torch.cuda.manual_seed(0) a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20 if ep_size > 1: local_e = e // ep_size e_ids = torch.randperm(e, device="cuda", dtype=torch.int32)[:local_e] e_map = torch.full((e,), -1, device="cuda", dtype=torch.int32) e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32) w1 = w1[e_ids] w2 = w2[e_ids] else: e_map = None w1_data = MarlinMoEWeightData.make( w=w1, quant_type=quant_type, group_size=group_size, act_order=act_order ) w2_data = MarlinMoEWeightData.make( w=w2, quant_type=quant_type, group_size=group_size, act_order=act_order ) score = torch.randn((m, e), device="cuda", dtype=dtype) topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) with set_current_vllm_config(vllm_config): torch_output = torch_moe( a, w1_data.w_ref, w2_data.w_ref, score, topk, expert_map=e_map ) marlin_output = fused_marlin_moe( a, w1_data.qweight, w2_data.qweight, None, None, w1_data.scales, w2_data.scales, score, topk_weights, topk_ids, global_num_experts=e, expert_map=e_map, global_scale1=w1_data.global_scale, global_scale2=w2_data.global_scale, g_idx1=w1_data.g_idx, g_idx2=w2_data.g_idx, sort_indices1=w1_data.sort_indices, sort_indices2=w2_data.sort_indices, w1_zeros=w1_data.zeros, w2_zeros=w2_data.zeros, quant_type_id=quant_type.id, is_k_full=is_k_full, ) torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0) @pytest.mark.flaky(reruns=2) @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") @pytest.mark.parametrize("m", [1, 256]) def test_fused_marlin_moe_with_bias(m): torch.cuda.manual_seed(0) e, topk = 32, 4 n, k = 2048, 2048 group_size = 128 act_order = False is_k_full = True quant_type = scalar_types.uint4b8 dtype = torch.half a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 b_bias1 = torch.randn((e, 2 * n), device="cuda", dtype=dtype) / 10 b_bias2 = torch.randn((e, k), device="cuda", dtype=dtype) / 10 w1_data = MarlinMoEWeightData.make( w=w1, quant_type=quant_type, group_size=group_size, act_order=act_order, bias=b_bias1, ) w2_data = MarlinMoEWeightData.make( w=w2, quant_type=quant_type, group_size=group_size, act_order=act_order, bias=b_bias2, ) score = torch.randn((m, e), device="cuda", dtype=dtype) topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) with set_current_vllm_config(vllm_config): torch_output = torch_moe( a, w1_data.w_ref, w2_data.w_ref, score, topk, b_bias1, b_bias2 ) marlin_output = fused_marlin_moe( a, w1_data.qweight, w2_data.qweight, w1_data.marlin_bias, w2_data.marlin_bias, w1_data.scales, w2_data.scales, score, topk_weights, topk_ids, global_num_experts=e, expert_map=None, global_scale1=w1_data.global_scale, global_scale2=w2_data.global_scale, g_idx1=w1_data.g_idx, g_idx2=w2_data.g_idx, sort_indices1=w1_data.sort_indices, sort_indices2=w2_data.sort_indices, w1_zeros=w1_data.zeros, w2_zeros=w2_data.zeros, quant_type_id=quant_type.id, is_k_full=is_k_full, ) torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0) def test_moe_align_block_size_opcheck(): num_experts = 4 block_size = 4 topk_ids = torch.randint(0, num_experts, (3, 4), dtype=torch.int32, device="cuda") max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) sorted_ids = torch.empty( (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device ) sorted_ids.fill_(topk_ids.numel()) max_num_m_blocks = max_num_tokens_padded // block_size expert_ids = torch.empty( (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device ) num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device) opcheck( torch.ops._moe_C.moe_align_block_size, ( topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad, ), ) def test_batched_moe_align_block_size_opcheck(): max_tokens_per_batch = 512 num_experts = 4 block_size = 16 expert_num_tokens = torch.randint( low=0, high=max_tokens_per_batch, size=(num_experts,), dtype=torch.int32, device="cuda", ) max_num_tokens_padded = num_experts * max(max_tokens_per_batch, block_size) sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device="cuda") assert max_num_tokens_padded % block_size == 0 max_num_m_blocks = max_num_tokens_padded // block_size expert_ids = torch.empty((max_num_m_blocks,), dtype=torch.int32, device="cuda") num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device="cuda") opcheck( torch.ops._moe_C.batched_moe_align_block_size, ( max_tokens_per_batch, block_size, expert_num_tokens, sorted_ids, expert_ids, num_tokens_post_pad, ), ) @pytest.mark.parametrize("m", [1, 33, 222]) @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("k", [128, 511, 1024]) @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") def test_moe_sum(m: int, topk: int, k: int, dtype: torch.dtype): input = torch.randn((m, topk, k), device="cuda", dtype=dtype) actual = torch.empty((m, k), device="cuda", dtype=dtype) expected = input.sum(dim=1) torch.ops._moe_C.moe_sum(input, actual) torch.testing.assert_close(actual, expected, atol=2e-2, rtol=0) opcheck(torch.ops._moe_C.moe_sum, (input, actual)) @pytest.mark.parametrize("m", [1, 33]) @pytest.mark.parametrize("n,k", [(128, 128)]) @pytest.mark.parametrize("e", [8]) @pytest.mark.parametrize("topk", [2]) @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) @pytest.mark.parametrize("with_bias", [False, True]) @pytest.mark.parametrize("activation", ["silu"]) @pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only test") def test_cpu_fused_moe_basic(m, n, k, e, topk, dtype, with_bias, activation): from vllm.model_executor.layers.fused_moe.cpu_fused_moe import CPUFusedMOE device = "cpu" torch.manual_seed(7) a = torch.randn((m, k), device=device, dtype=dtype) / 10 w13 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 10 w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 10 router_logits = torch.randn((m, e), device=device, dtype=dtype) b1 = b2 = None if with_bias: b1 = torch.randn((e, 2 * n), device=device, dtype=dtype) / 10 b2 = torch.randn((e, k), device=device, dtype=dtype) / 10 ref = ( torch_moe(a, w13, w2, router_logits, topk, b1, b2) if with_bias else torch_moe(a, w13, w2, router_logits, topk) ) class _Dummy(torch.nn.Module): def __init__(self, w13, w2, b1=None, b2=None): super().__init__() self.w13_weight = torch.nn.Parameter(w13, requires_grad=False) self.w2_weight = torch.nn.Parameter(w2, requires_grad=False) if b1 is not None: self.w13_bias = torch.nn.Parameter(b1, requires_grad=False) if b2 is not None: self.w2_bias = torch.nn.Parameter(b2, requires_grad=False) layer = _Dummy(w13, w2, b1, b2).to(dtype) fused = CPUFusedMOE(layer) out = fused( layer=layer, x=a, use_grouped_topk=False, top_k=topk, router_logits=router_logits, renormalize=False, global_num_experts=e, expert_map=None, custom_routing_function=None, scoring_func="softmax", routed_scaling_factor=1.0, e_score_correction_bias=None, apply_router_weight_on_input=False, activation=activation, ) # Tolerances: fp32 tight; bf16 looser (esp. with bias) if dtype == torch.float32: atol = 1e-3 elif with_bias: atol = 8e-2 else: atol = 5e-2 torch.testing.assert_close(out, ref, atol=atol, rtol=0) @pytest.mark.parametrize("m", [16, 32, 64]) @pytest.mark.parametrize("n", [128]) @pytest.mark.parametrize("k", [128]) @pytest.mark.parametrize("e", [8, 12, 16, 32]) @pytest.mark.parametrize("topk", [2, 4]) @pytest.mark.parametrize("max_tokens_per_batch", [16, 32, 64]) @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") def test_batched_fused_marlin_moe( m: int, n: int, k: int, e: int, topk: int, max_tokens_per_batch: int ): print( f"testing m={m}, n={n}, k={k}, e={e}, " f"topk={topk}, " f"max_tokens_per_batch={max_tokens_per_batch}" ) torch.cuda.manual_seed(0) dtype = torch.bfloat16 quant_dtype = scalar_types.float4_e2m1f group_size = 32 a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20 w1_data = MarlinMoEWeightData.make( w=w1, quant_type=quant_dtype, group_size=group_size, act_order=None ) w2_data = MarlinMoEWeightData.make( w=w2, quant_type=quant_dtype, group_size=group_size, act_order=None ) score = torch.randn((m, e), device="cuda", dtype=dtype) topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) class BatchedRun: @staticmethod def _make_expert_num_tokens_cpu( e: int, # num_experts topk_ids_cpu: torch.Tensor, ) -> torch.Tensor: expert_num_tokens_cpu = torch.zeros((e,), dtype=torch.int32, device="cpu") for topk_id in torch.flatten(topk_ids_cpu): expert_num_tokens_cpu[topk_id] += 1 return expert_num_tokens_cpu def __init__( self, max_tokens_per_batch: int, num_experts: int, _topk_ids: torch.Tensor, _topk_weights: torch.Tensor, ): self.max_tokens_per_batch = max_tokens_per_batch self.e = num_experts self.topk_ids_cpu = _topk_ids.to("cpu") self.topk_weights_cpu = _topk_weights.to("cpu") self.expert_num_tokens_cpu = self._make_expert_num_tokens_cpu( self.e, self.topk_ids_cpu ) def is_valid(self): """ Return True only if the input can be represented in a Batched format. """ return torch.all(self.expert_num_tokens_cpu <= self.max_tokens_per_batch) def _scatter(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states_cpu = hidden_states.to("cpu") K = hidden_states_cpu.size(1) batched_hidden_states_cpu = torch.empty( (e, max_tokens_per_batch, K), dtype=hidden_states_cpu.dtype, device="cpu", ) counter_cpu = torch.zeros_like(self.expert_num_tokens_cpu) for t_idx, token in enumerate(hidden_states_cpu): for topk_id in self.topk_ids_cpu[t_idx]: pos_in_batch = counter_cpu[topk_id] batched_hidden_states_cpu[topk_id, pos_in_batch] = token counter_cpu[topk_id] += 1 assert torch.allclose(counter_cpu, self.expert_num_tokens_cpu) return batched_hidden_states_cpu.to("cuda") def _gather( self, batched_outputs: torch.Tensor, gather_outputs: torch.Tensor ) -> torch.Tensor: batched_outputs_cpu = batched_outputs.to("cpu") gather_outputs_cpu = torch.zeros_like(gather_outputs) counter_cpu = torch.zeros((e,), device="cpu", dtype=torch.int32) md = gather_outputs_cpu.size(0) for t_idx in range(md): token = None for topk_id, topk_weight in zip( self.topk_ids_cpu[t_idx], self.topk_weights_cpu[t_idx] ): pos_in_batch = counter_cpu[topk_id] t = batched_outputs_cpu[topk_id, pos_in_batch] * topk_weight if token is None: token = t else: token += t counter_cpu[topk_id] += 1 assert token is not None gather_outputs_cpu[t_idx] = token gather_outputs.copy_(gather_outputs_cpu) return gather_outputs def run( self, hidden_states: torch.Tensor, fused_marlin_moe_kwargs: dict[Any, Any] ) -> torch.Tensor: assert hidden_states.ndim == 2 assert self.is_valid() batched_hidden_states = self._scatter(hidden_states) kwargs = fused_marlin_moe_kwargs | { "hidden_states": batched_hidden_states, "expert_num_tokens": self.expert_num_tokens_cpu.to("cuda"), } batched_outputs = batched_fused_marlin_moe(**kwargs) output = torch.zeros_like(hidden_states) output = self._gather(batched_outputs, output) return output kwargs = { "w1": w1_data.qweight, "w2": w2_data.qweight, "bias1": None, "bias2": None, "w1_scale": w1_data.scales, "w2_scale": w2_data.scales, "gating_output": score, "global_num_experts": e, "expert_map": None, "global_scale1": w1_data.global_scale, "global_scale2": w2_data.global_scale, "g_idx1": w1_data.g_idx, "g_idx2": w2_data.g_idx, "sort_indices1": w1_data.sort_indices, "sort_indices2": w2_data.sort_indices, "w1_zeros": w1_data.zeros, "w2_zeros": w2_data.zeros, "quant_type_id": quant_dtype.id, "is_k_full": True, } # Reference fused_marlin_moe_kwargs = kwargs | { "hidden_states": a, "topk_ids": topk_ids, "topk_weights": topk_weights, } ref_marlin_output = fused_marlin_moe(**fused_marlin_moe_kwargs) # Batched br = BatchedRun(max_tokens_per_batch, e, topk_ids, topk_weights) if not br.is_valid(): pytest.skip("Cannot represent data in Batched Format.") marlin_output = br.run(a, kwargs) torch.testing.assert_close(marlin_output, ref_marlin_output, atol=1e-3, rtol=0)