# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib.metadata
from dataclasses import dataclass
from importlib.util import find_spec

import pytest
import torch
from packaging import version

from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer

QUARK_MXFP4_AVAILABLE = find_spec("quark") is not None and version.parse(
    importlib.metadata.version("amd-quark")
) >= version.parse("0.8.99")

TRTLLM_GEN_MXFP4_AVAILABLE = (
    current_platform.is_cuda() and current_platform.is_device_capability(100)
)

HOPPER_MXFP4_BF16_AVAILABLE = (
    current_platform.is_cuda()
    and current_platform.is_device_capability(90)
    and has_flashinfer()
)

if TRTLLM_GEN_MXFP4_AVAILABLE:
    from flashinfer import (
        fp4_quantize,
        mxfp8_quantize,
        next_positive_power_of_2,
        reorder_rows_for_gated_act_gemm,
        shuffle_matrix_a,
        shuffle_matrix_sf_a,
        trtllm_fp4_block_scale_moe,
    )
    from flashinfer.fp4_quantization import nvfp4_block_scale_interleave
    from flashinfer.fused_moe.core import get_w2_permute_indices_with_cache


@dataclass
class ModelCase:
    model_id: str
    tp: int


@pytest.fixture(scope="function", autouse=True)
def enable_pickle(monkeypatch):
    """`LLM.apply_model` requires pickling a function."""
    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")


@pytest.mark.parametrize(
    "model_case",
    [
        ModelCase("fxmarty/qwen_1.5-moe-a2.7b-mxfp4", tp=2),
        ModelCase("fxmarty/deepseek_r1_3_layers_mxfp4", tp=8),
        ModelCase("fxmarty/Llama-4-Scout-17B-16E-Instruct-2-layers-mxfp4", tp=1),
        ModelCase("fxmarty/Llama-3.1-70B-Instruct-2-layers-mxfp6", tp=1),
        ModelCase("fxmarty/Llama-3.1-70B-Instruct-2-layers-mxfp6", tp=4),
    ],
)
@pytest.mark.skipif(
    not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available"
)
def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase):
    if torch.cuda.device_count() < model_case.tp:
        pytest.skip(
            f"This test requires >={model_case.tp} gpus, got only "
            f"{torch.cuda.device_count()}"
        )

    # `cuda_graph_sizes=[16]` to reduce load time.
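    # load_format="dummy" initializes random weights instead of downloading a
    # checkpoint, so this test only exercises the quantized loading/execution
    # path, not output quality.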
    with vllm_runner(
        model_case.model_id,
        tensor_parallel_size=model_case.tp,
        load_format="dummy",
        cuda_graph_sizes=[16],
    ) as llm:
        # Disabled as check_model is broken: https://github.com/vllm-project/vllm/pull/18465#issuecomment-3329880562
        # def check_model(model):
        #     from vllm.model_executor.layers.quantization.quark.quark import (  # noqa: E501
        #         QuarkLinearMethod)
        #     from vllm.model_executor.layers.quantization.quark.schemes.quark_ocp_mx import QuarkOCP_MX  # noqa: E501
        #     from vllm.model_executor.layers.quantization.quark.quark_moe import (  # noqa: E501
        #         QuarkOCP_MX_MoEMethod)
        #
        #     layer = model.model.layers[0]
        #     qkv_proj = layer.self_attn.qkv_proj
        #
        #     assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
        #     assert isinstance(qkv_proj.scheme, QuarkOCP_MX)
        #
        #     assert isinstance(layer.mlp.experts.quant_method,
        #                       QuarkOCP_MX_MoEMethod)
        #
        # if model_case.model_id == "fxmarty/qwen_1.5-moe-a2.7b-mxfp4":
        #     llm.apply_model(check_model)

        output = llm.generate_greedy(
            "Today I am in the French Alps and", max_tokens=20
        )
        assert output


def swiglu(x, alpha: float = 1.702, beta: float = 1.0, limit: float | None = None):
    # Note: an extra bias of `beta` (default 1.0) is added to the linear half.
    x_glu, x_linear = torch.chunk(x, 2, dim=-1)
    if limit is not None:
        x_glu = x_glu.clamp(max=limit)
        x_linear = x_linear.clamp(min=-limit, max=limit)
    out_glu = x_glu * torch.sigmoid(alpha * x_glu)
    return out_glu * (x_linear + beta)


fp4_lookup_table = [0, 0.5, 1, 1.5, 2, 3, 4, 6, -0, -0.5, -1, -1.5, -2, -3, -4, -6]


def mxfp4_dequantize(x, scale):
    assert x.dtype == torch.uint8
    x = x.view(torch.uint8).to(torch.int32)
    x_unpacked = torch.zeros(
        *x.shape[:-1], x.shape[-1] * 2, dtype=torch.int32, device=x.device
    )
    x_unpacked[..., 0::2].copy_(x & 0xF)
    x_unpacked[..., 1::2].copy_((x >> 4) & 0xF)
    x_float = torch.zeros(x_unpacked.shape, dtype=torch.float32, device=x.device)
    for i, val in enumerate(fp4_lookup_table):
        x_float[x_unpacked == i] = val

    scale = scale.view(torch.uint8).to(torch.int32)
    scale = (scale << 23).view(torch.float32)
    scale = scale.reshape(*x.shape[:-1], -1)
    scale = torch.stack([scale] * 32, dim=-1).reshape(*x_float.shape)

    return x_float * scale


def mxfp8_dequantize(x, scale):
    assert x.dtype == torch.float8_e4m3fn
    x_float = x.to(torch.float32)

    scale = scale.view(torch.uint8).to(torch.int32)
    scale = (scale << 23).view(torch.float32)
    scale = scale.reshape(*x.shape[:-1], -1)
    scale = torch.stack([scale] * 32, dim=-1).reshape(*x_float.shape)

    return x_float * scale


def reference_moe(
    routing_logits,
    topk,
    num_experts,
    hidden_states,
    w13,
    bias13,
    w2,
    bias2,
    alpha,
    beta,
    limit,
    act_type,
):
    # renormalize routing
    experts = torch.topk(routing_logits, k=topk, dim=-1, sorted=True)
    expert_weights = torch.nn.functional.softmax(experts.values, dim=1)
    expert_indices = experts.indices
    t = hidden_states.clone()
    # MLP #1
    mlp1_weight = w13[expert_indices, ...]
    mlp1_bias = bias13[expert_indices, ...]
    t = torch.einsum("beck,bk->bec", mlp1_weight, t) + mlp1_bias
    t = swiglu(t, alpha=alpha, beta=beta, limit=limit)
    if act_type == "mxfp8":
        t_quantized, t_scale = mxfp8_quantize(
            t.to(torch.bfloat16), is_sf_swizzled_layout=False
        )
        t = mxfp8_dequantize(t_quantized, t_scale)
    # MLP #2
    mlp2_weight = w2[expert_indices, ...]
    mlp2_bias = bias2[expert_indices, ...]
    t = torch.einsum("beck,bek->bec", mlp2_weight, t) + mlp2_bias
    # Weighted sum of experts
    t = torch.einsum("bec,be->bc", t, expert_weights)
    assert t.shape == hidden_states.shape
    return t.to(torch.bfloat16)


def get_tile_tokens_dim(x: torch.Tensor, top_k: int, num_experts: int):
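    """Pick the tokens-per-CTA-tile size for the trtllm-gen MoE kernel from a
    rough estimate of how many tokens each expert receives."""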
    # Number of tokens in the input tensor.
    num_tokens = x.shape[0]
    # Factor to account for the imbalance of the experts.
    # The factor equals
    # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert:
    # - 1.0 means a perfect expert distribution.
    # - > 1.0 means some experts receive more tokens than under a
    #   perfect distribution.
    # - < 1.0 does not make sense.
    imbalance_factor = 1.3
    # Calculate the number of tokens per expert assuming perfect distribution.
    num_tokens_per_expert = (num_tokens * top_k) // num_experts
    # Apply the imbalance factor.
    num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
    # And pad the number to the next power of 2.
    tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
    # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)

    return tile_tokens_dim


def tg_mxfp4_moe(
    router_logits,
    topk,
    num_experts,
    intermediate_size,
    hidden_size,
    hidden_states,
    hidden_states_scale,
    w13_weight,
    w13_weight_scale,
    w13_bias,
    w2_weight,
    w2_weight_scale,
    w2_bias,
    act_type,
    alpha,
    beta,
    limit,
    transpose_optimized: bool = False,
) -> torch.Tensor:
    sf_block_size = 32
    assert (
        w13_weight.dim() == 3
        and w13_weight.shape[0] == num_experts
        and w13_weight.shape[1] == intermediate_size * 2
        and w13_weight.shape[2] == hidden_size // 2
    )
    assert (
        w13_weight_scale.dim() == 3
        and w13_weight_scale.shape[0] == num_experts
        and w13_weight_scale.shape[1] == intermediate_size * 2
        and w13_weight_scale.shape[2] == hidden_size // sf_block_size
    )
    assert (
        w2_weight.dim() == 3
        and w2_weight.shape[0] == num_experts
        and w2_weight.shape[1] == hidden_size
        and w2_weight.shape[2] == intermediate_size // 2
    )
    assert (
        w2_weight_scale.dim() == 3
        and w2_weight_scale.shape[1] == hidden_size
        and w2_weight_scale.shape[2] == intermediate_size // sf_block_size
    )
    assert (
        w13_bias.dim() == 2
        and w13_bias.shape[0] == num_experts
        and w13_bias.shape[1] == intermediate_size * 2
    )
    assert (
        w2_bias.dim() == 2
        and w2_bias.shape[0] == num_experts
        and w2_bias.shape[1] == hidden_size
    )

    # Swap w1 and w3, as the definition of swiglu is different in trtllm-gen.
    w13_weight_scale_ = w13_weight_scale.clone()
    w13_weight_ = w13_weight.clone()
    w13_bias_ = w13_bias.clone()
    w13_weight[:, :intermediate_size, :].copy_(w13_weight_[:, intermediate_size:, :])
    w13_weight[:, intermediate_size:, :].copy_(w13_weight_[:, :intermediate_size, :])
    w13_weight_scale[:, :intermediate_size, :].copy_(
        w13_weight_scale_[:, intermediate_size:, :]
    )
    w13_weight_scale[:, intermediate_size:, :].copy_(
        w13_weight_scale_[:, :intermediate_size, :]
    )
    w13_bias[:, :intermediate_size].copy_(w13_bias_[:, intermediate_size:])
    w13_bias[:, intermediate_size:].copy_(w13_bias_[:, :intermediate_size])

    # Interleave the weights and scaling factors for activation
    w13_weight_interleaved = []
    w13_weight_scale_interleaved = []
    w13_bias_interleaved = []
    for i in range(num_experts):
        w13_weight_interleaved.append(
            reorder_rows_for_gated_act_gemm(w13_weight[i].clone())
        )
        w13_weight_scale_interleaved.append(
            reorder_rows_for_gated_act_gemm(w13_weight_scale[i].clone())
        )
        w13_bias_interleaved.append(
            reorder_rows_for_gated_act_gemm(w13_bias[i].clone().reshape(-1, 1))
        )
    w13_weight = torch.stack(w13_weight_interleaved).reshape(
        num_experts, 2 * intermediate_size, hidden_size // 2
    )
    w13_weight_scale = torch.stack(w13_weight_scale_interleaved).reshape(
        num_experts, 2 * intermediate_size, hidden_size // 32
    )
    w13_bias = torch.stack(w13_bias_interleaved).reshape(
        num_experts, 2 * intermediate_size
    )

    # Shuffle weights and scaling factors for transposed mma output
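    # Two paths produce the shuffled layout: a cached permute-index path
    # (transpose_optimized=True) and the direct shuffle_matrix_a/_sf_a helpers.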
    gemm1_weights_shuffled = []
    gemm1_scales_shuffled = []
    gemm2_weights_shuffled = []
    gemm2_scales_shuffled = []
    gemm1_bias_shuffled = []
    gemm2_bias_shuffled = []
    epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
    _cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
    if transpose_optimized:
        for i in range(num_experts):
            # w13 weight shuffling
            permute_indices = get_w2_permute_indices_with_cache(
                _cache_permute_indices,
                w13_weight[i].view(torch.uint8),
                epilogue_tile_m,
            )
            gemm1_weights_shuffled.append(
                w13_weight[i]
                .view(torch.uint8)[permute_indices.to(w13_weight.device)]
                .contiguous()
            )
            # w13 scale shuffling
            permute_sf_indices = get_w2_permute_indices_with_cache(
                _cache_permute_indices,
                w13_weight_scale[i].view(torch.uint8),
                epilogue_tile_m,
                num_elts_per_sf=16,
            )
            gemm1_scales_shuffled.append(
                nvfp4_block_scale_interleave(
                    w13_weight_scale[i]
                    .view(torch.uint8)[permute_sf_indices.to(w13_weight_scale.device)]
                    .contiguous()
                )
            )
            # w13 bias shuffling
            permute_bias_indices = get_w2_permute_indices_with_cache(
                _cache_permute_indices,
                w13_bias[i].clone().reshape(-1, 1),
                epilogue_tile_m,
            )
            gemm1_bias_shuffled.append(
                w13_bias[i]
                .clone()
                .reshape(-1, 1)[permute_bias_indices.to(w13_bias.device)]
                .contiguous()
            )

            # w2 weight shuffling
            permute_indices = get_w2_permute_indices_with_cache(
                _cache_permute_indices,
                w2_weight[i].view(torch.uint8),
                epilogue_tile_m,
            )
            gemm2_weights_shuffled.append(
                w2_weight[i]
                .view(torch.uint8)[permute_indices.to(w2_weight.device)]
                .contiguous()
            )
            # w2 scale shuffling
            permute_sf_indices = get_w2_permute_indices_with_cache(
                _cache_permute_indices,
                w2_weight_scale[i].view(torch.uint8),
                epilogue_tile_m,
                num_elts_per_sf=16,
            )
            gemm2_scales_shuffled.append(
                nvfp4_block_scale_interleave(
                    w2_weight_scale[i]
                    .view(torch.uint8)[permute_sf_indices.to(w2_weight_scale.device)]
                    .contiguous()
                )
            )
            # w2 bias shuffling
            permute_indices = get_w2_permute_indices_with_cache(
                _cache_permute_indices,
                w2_bias[i].clone().reshape(-1, 1),
                epilogue_tile_m,
            )
            gemm2_bias_shuffled.append(
                w2_bias[i]
                .clone()
                .reshape(-1, 1)[permute_indices.to(w2_bias.device)]
                .contiguous()
            )
    else:
        for i in range(num_experts):
            gemm1_weights_shuffled.append(
                shuffle_matrix_a(w13_weight[i].view(torch.uint8), epilogue_tile_m)
            )
            gemm1_scales_shuffled.append(
                shuffle_matrix_sf_a(
                    w13_weight_scale[i].view(torch.uint8), epilogue_tile_m
                )
            )
            gemm2_weights_shuffled.append(
                shuffle_matrix_a(w2_weight[i].view(torch.uint8), epilogue_tile_m)
            )
            gemm2_scales_shuffled.append(
                shuffle_matrix_sf_a(
                    w2_weight_scale[i].view(torch.uint8), epilogue_tile_m
                )
            )
            gemm1_bias_shuffled.append(
                shuffle_matrix_a(w13_bias[i].reshape(-1, 1), epilogue_tile_m)
            )
            gemm2_bias_shuffled.append(
                shuffle_matrix_a(w2_bias[i].reshape(-1, 1), epilogue_tile_m)
            )

    w13_weight = torch.stack(gemm1_weights_shuffled)
    w13_weight_scale = (
        torch.stack(gemm1_scales_shuffled)
        .reshape(num_experts, 2 * intermediate_size, hidden_size // sf_block_size)
        .view(torch.float8_e4m3fn)
    )
    w13_bias = torch.stack(gemm1_bias_shuffled).reshape(num_experts, -1)

    w2_weight = torch.stack(gemm2_weights_shuffled)
    w2_weight_scale = (
        torch.stack(gemm2_scales_shuffled)
        .reshape(num_experts, hidden_size, intermediate_size // sf_block_size)
        .view(torch.float8_e4m3fn)
    )
    w2_bias = torch.stack(gemm2_bias_shuffled).reshape(num_experts, -1)
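    # Invoke the trtllm-gen block-scale MoE kernel; routing_method_type=1 selects
    # top-k routing with renormalized (softmaxed) expert weights.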
    tg_result = trtllm_fp4_block_scale_moe(
        routing_logits=router_logits.to(torch.bfloat16),
        routing_bias=None,
        hidden_states=hidden_states,
        hidden_states_scale=hidden_states_scale,
        gemm1_weights=w13_weight,
        gemm1_weights_scale=w13_weight_scale,
        gemm1_bias=w13_bias,
        gemm1_alpha=alpha,
        gemm1_beta=beta,
        gemm1_clamp_limit=limit,
        gemm2_weights=w2_weight,
        gemm2_weights_scale=w2_weight_scale,
        gemm2_bias=w2_bias,
        output1_scale_scalar=None,
        output1_scale_gate_scalar=None,
        output2_scale_scalar=None,
        num_experts=num_experts,
        top_k=topk,
        n_group=None,
        topk_group=None,
        intermediate_size=intermediate_size,
        local_expert_offset=0,
        local_num_experts=num_experts,
        routed_scaling_factor=None,
        tile_tokens_dim=get_tile_tokens_dim(hidden_states, topk, num_experts),
        routing_method_type=1,  # renormalize
        do_finalize=True,
    )[0]
    return tg_result


def check_accuracy(a, b, atol, rtol, percent):
    """Allow a mismatch percentage of 1 - percent."""
    if torch.any(torch.isnan(a)):
        raise Exception("NaN in reference output")
    if torch.any(torch.isnan(b)):
        raise Exception("NaN in actual output")
    if torch.any(torch.isinf(a)):
        raise Exception("Inf in reference output")
    if torch.any(torch.isinf(b)):
        raise Exception("Inf in actual output")
    assert a.shape == b.shape, f"Shape mismatch: {a.shape} vs {b.shape}"

    left = torch.abs(a - b)
    right = atol + rtol * torch.abs(b)
    count = torch.sum(left > right)
    mismatch_percent = count / a.numel()
    if mismatch_percent > 1 - percent:
        raise Exception(
            f"Mismatch percentage is {mismatch_percent:.4f} for rtol {rtol} "
            f"(threshold: {1 - percent:.4f})"
        )


@pytest.mark.parametrize("topk", [1, 4])
@pytest.mark.parametrize("num_experts", [32, 128])
@pytest.mark.parametrize("num_tokens", [1, 128, 1024])
@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)])
@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None), (1.702, 1.0, 7.0)])
@pytest.mark.parametrize("act_type", ["mxfp8", "bf16"])
@pytest.mark.parametrize("transpose_optimized", [False, True])
@pytest.mark.skipif(
    not TRTLLM_GEN_MXFP4_AVAILABLE,
    reason="NVIDIA GPU with compute capability SM100 is required for this test",
)
def test_trtllm_gen_mxfp4_fused_moe(
    topk: int,
    num_experts: int,
    num_tokens: int,
    intermediate_size: int,
    hidden_size: int,
    alpha: float,
    beta: float,
    limit: float | None,
    act_type: str,
    transpose_optimized: bool,
):
    seed = 42
    torch.manual_seed(seed)
    hidden_states = torch.randn(
        num_tokens, hidden_size, device="cuda:0", dtype=torch.bfloat16
    )
    w13 = torch.randn(
        num_experts,
        intermediate_size * 2,
        hidden_size,
        device="cuda:0",
        dtype=torch.bfloat16,
    )
    w2 = torch.randn(
        num_experts,
        hidden_size,
        intermediate_size,
        device="cuda:0",
        dtype=torch.bfloat16,
    )
    bias13 = torch.randn(num_experts, intermediate_size * 2, device="cuda:0") * 10
    bias2 = torch.randn(num_experts, hidden_size, device="cuda:0") * 10
    router_logits = torch.rand(num_tokens, num_experts, dtype=torch.float32).cuda()

    w13, w13_scale = fp4_quantize(
        w13,
        torch.tensor(1.0, device="cuda:0"),
        32,
        sf_use_ue8m0=True,
        is_sf_swizzled_layout=False,
    )
    w13_scale = w13_scale.view(torch.float8_e4m3fn).reshape(
        num_experts, intermediate_size * 2, hidden_size // 32
    )
    w2, w2_scale = fp4_quantize(
        w2,
        torch.tensor(1.0, device="cuda:0"),
        32,
        sf_use_ue8m0=True,
        is_sf_swizzled_layout=False,
    )
    w2_scale = w2_scale.view(torch.float8_e4m3fn).reshape(
        num_experts, hidden_size, intermediate_size // 32
    )
    if act_type == "mxfp8":
        hidden_states, hidden_states_scale = mxfp8_quantize(
            hidden_states, is_sf_swizzled_layout=False
        )
        hidden_states_scale = hidden_states_scale.view(torch.float8_e4m3fn).reshape(-1)
    else:
        hidden_states_scale = None

    # reference result
    ref_result = torch.empty_like(hidden_states, dtype=torch.bfloat16)
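    # Dequantize the MXFP4 weights (and the MXFP8 activations, if used) so that
    # the reference MoE below runs on full-precision tensors.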
    w13_ref = mxfp4_dequantize(w13.clone(), w13_scale.clone())
    w2_ref = mxfp4_dequantize(w2.clone(), w2_scale.clone())
    bias13_ref = bias13
    bias2_ref = bias2
    if act_type == "mxfp8":
        hidden_states_ref = mxfp8_dequantize(hidden_states, hidden_states_scale).to(
            torch.float32
        )
    else:
        hidden_states_ref = hidden_states.to(torch.float32)

    # Process tokens in chunks of 32 to reduce memory usage
    chunk_size = 32
    num_chunks = (num_tokens + chunk_size - 1) // chunk_size
    for i in range(num_chunks):
        start_idx = i * chunk_size
        end_idx = min(start_idx + chunk_size, num_tokens)
        chunk_result = reference_moe(
            router_logits[start_idx:end_idx].to(torch.float32),
            topk,
            num_experts,
            hidden_states_ref[start_idx:end_idx],
            w13_ref,
            bias13_ref,
            w2_ref,
            bias2_ref,
            alpha,
            beta,
            limit,
            act_type,
        )
        ref_result[start_idx:end_idx].copy_(chunk_result)

    # trtllm-gen result
    if alpha is not None:
        alpha = torch.full((num_experts,), alpha, device=hidden_states.device)
    if limit is not None:
        limit = torch.full((num_experts,), limit, device=hidden_states.device)
    if beta is not None:
        beta = torch.full((num_experts,), beta, device=hidden_states.device)
    tg_result = tg_mxfp4_moe(
        router_logits,
        topk,
        num_experts,
        intermediate_size,
        hidden_size,
        hidden_states,
        hidden_states_scale,
        w13,
        w13_scale,
        bias13,
        w2,
        w2_scale,
        bias2,
        act_type,
        alpha=alpha,
        beta=beta,
        limit=limit,
        transpose_optimized=transpose_optimized,
    )

    # relatively loose check since the mxfp4 quantization is less accurate
    check_accuracy(ref_result, tg_result, atol=0, rtol=0.3, percent=0.8)


def _interleave_scales_lastdim_by4(scales: torch.Tensor) -> torch.Tensor:
    """Interleave scales on the last dimension by groups of 4, matching the
    transformation in mxfp4.py's BF16 (Hopper) path."""
    s = scales.to(torch.uint8)
    s_shape = s.shape
    assert s_shape[-1] % 4 == 0
    s = s.reshape(*s_shape[:-1], s_shape[-1] // 4, 4)
    # Move the 4-group dimension before the row dimension
    permuted = s.permute(0, 2, 1, 3)
    # Merge the row dim with the 4-group dim
    return permuted.reshape(s_shape[0], s_shape[-1] // 4, s_shape[1] * 4)


@pytest.mark.parametrize("topk", [1, 4])
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("num_tokens", [1, 128])
@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)])
@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None), (1.702, 1.0, 7.0)])
@pytest.mark.skipif(
    not HOPPER_MXFP4_BF16_AVAILABLE,
    reason="NVIDIA GPU SM90 and FlashInfer are required for this test",
)
def test_flashinfer_cutlass_mxfp4_fused_moe(
    topk: int,
    num_experts: int,
    num_tokens: int,
    intermediate_size: int,
    hidden_size: int,
    alpha: float,
    beta: float,
    limit: float | None,
):
    torch.manual_seed(42)
    device = "cuda:0"
    # Inputs
    hidden_states = torch.randn(
        num_tokens, hidden_size, device=device, dtype=torch.bfloat16
    )

    # Random MXFP4 weights and scales (uint8), contiguous [w1; w3]
    w13_q = torch.randint(
        0,
        256,
        (num_experts, 2 * intermediate_size, hidden_size // 2),
        device=device,
        dtype=torch.uint8,
    )
    w13_scale = torch.randint(
        118,
        123,
        (num_experts, 2 * intermediate_size, hidden_size // 32),
        device=device,
        dtype=torch.uint8,
    )
    w2_q = torch.randint(
        0,
        256,
        (num_experts, hidden_size, intermediate_size // 2),
        device=device,
        dtype=torch.uint8,
    )
    w2_scale = torch.randint(
        118,
        123,
        (num_experts, hidden_size, intermediate_size // 32),
        device=device,
        dtype=torch.uint8,
    )

    # Bias contiguous [b1; b3]
    bias13 = (
        torch.randn(
            num_experts, 2 * intermediate_size, device=device, dtype=torch.bfloat16
        )
        * 10
    )
    bias2 = (
        torch.randn(num_experts, hidden_size, device=device, dtype=torch.bfloat16) * 10
    )
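    # The same router logits drive both the reference MoE and the FlashInfer
    # kernel, so both paths see identical expert routing.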
    router_logits = torch.rand(
        num_tokens, num_experts, dtype=torch.float32, device=device
    )

    w13_ref = mxfp4_dequantize(w13_q.clone(), w13_scale.clone()).reshape(
        num_experts, 2 * intermediate_size, hidden_size
    )
    w2_ref = mxfp4_dequantize(w2_q.clone(), w2_scale.clone()).reshape(
        num_experts, hidden_size, intermediate_size
    )

    ref = reference_moe(
        router_logits.to(torch.float32),
        topk,
        num_experts,
        hidden_states.to(torch.float32),
        w13_ref,
        bias13.to(torch.float32),
        w2_ref,
        bias2.to(torch.float32),
        alpha,
        beta,
        limit,
        "bf16",
    )

    from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe

    # Swap halves to arrange as [w3; w1] (kernel expectation)
    w1_w, w3_w = torch.chunk(w13_q, 2, dim=1)
    w13_q_swapped = torch.cat([w3_w, w1_w], dim=1)

    b1, b3 = torch.chunk(bias13.to(torch.float32), 2, dim=-1)
    w13_b = torch.cat([b3, b1], dim=-1).to(torch.bfloat16)

    w1_s, w3_s = torch.chunk(w13_scale, 2, dim=1)
    w13_s = torch.cat([w3_s, w1_s], dim=1)
    w13_s_inter = _interleave_scales_lastdim_by4(w13_s)
    w2_s_inter = _interleave_scales_lastdim_by4(w2_scale)

    routing_weights = torch.nn.functional.softmax(
        router_logits, dim=1, dtype=torch.float32
    )
    token_final_scales, token_selected_experts = torch.topk(
        routing_weights, topk, dim=-1
    )
    token_final_scales = token_final_scales / token_final_scales.sum(
        dim=-1, keepdim=True
    )
    token_selected_experts = token_selected_experts.to(torch.int).contiguous()

    out = torch.empty_like(hidden_states, dtype=torch.bfloat16)

    if alpha is not None:
        alpha = torch.full((num_experts,), alpha, device=hidden_states.device)
    if beta is not None:
        beta = torch.full((num_experts,), beta, device=hidden_states.device)
    if limit is not None:
        limit = torch.full((num_experts,), limit, device=hidden_states.device)

    _ = flashinfer_cutlass_fused_moe(
        input=hidden_states,
        token_selected_experts=token_selected_experts,
        token_final_scales=token_final_scales,
        fc1_expert_weights=w13_q_swapped,
        fc2_expert_weights=w2_q,
        output_dtype=torch.bfloat16,
        output=out,
        quant_scales=[w13_s_inter.to(torch.uint8), w2_s_inter.to(torch.uint8)],
        fc1_expert_biases=w13_b,
        fc2_expert_biases=bias2.to(torch.bfloat16),
        swiglu_alpha=alpha,
        swiglu_beta=beta,
        swiglu_limit=limit,
        tp_size=1,
        tp_rank=0,
        ep_size=1,
        ep_rank=0,
        use_w4_group_scaling=True,
    )

    # Allow some mismatch due to MXFP4 quantization
    check_accuracy(ref, out, atol=0, rtol=0.3, percent=0.8)


@pytest.mark.parametrize("topk", [1, 4])
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("num_tokens", [1, 128])
@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)])
@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None), (1.702, 1.0, 7.0)])
@pytest.mark.skipif(
    not (
        current_platform.is_cuda()
        and current_platform.is_device_capability(100)
        and has_flashinfer()
    ),
    reason="NVIDIA GPU SM100 and FlashInfer are required for this test",
)
def test_flashinfer_cutlass_mxfp4_mxfp8_fused_moe(
    topk: int,
    num_experts: int,
    num_tokens: int,
    intermediate_size: int,
    hidden_size: int,
    alpha: float | None,
    beta: float | None,
    limit: float | None,
):
    torch.manual_seed(42)
    device = "cuda:0"
    # Inputs
    hidden_states = torch.randn(
        num_tokens, hidden_size, device=device, dtype=torch.bfloat16
    )

    # Float weights in w13 format [w1; w3]
    w13 = (
        torch.randn(
            num_experts,
            2 * intermediate_size,
            hidden_size,
            device=device,
            dtype=torch.bfloat16,
        )
        / 10
    )
    w2 = (
        torch.randn(
            num_experts,
            hidden_size,
            intermediate_size,
            device=device,
            dtype=torch.bfloat16,
        )
        / 10
    )

    # Bias contiguous [b1; b3]
    bias13 = (
        torch.randn(
            num_experts, 2 * intermediate_size, device=device, dtype=torch.bfloat16
        )
        * 10
    )
    bias2 = (
        torch.randn(num_experts, hidden_size, device=device, dtype=torch.bfloat16) * 10
    )
    router_logits = torch.rand(
        num_tokens, num_experts, dtype=torch.float32, device=device
    )

    # Quantize weights to MXFP4 per expert (SM100 path)
    from flashinfer import mxfp4_quantize

    def quant_mxfp4_batches(a: torch.Tensor, e: int):
        qs, sfs = [], []
        for i in range(e):
            q, sf = mxfp4_quantize(a[i].cuda())
            qs.append(q)
            sfs.append(sf)
        return torch.stack(qs), torch.stack(sfs)

    def dequant_mxfp4_batches(mat_fp4: torch.Tensor, scale_tensor: torch.Tensor):
        num_batches = mat_fp4.size(0)
        scale_tensor = scale_tensor.view(num_batches, -1)
        from flashinfer import mxfp4_dequantize

        return torch.stack(
            [
                mxfp4_dequantize(mat_fp4[b, :, :], scale_tensor[b, :])
                for b in range(num_batches)
            ]
        )

    w13_q, w13_scale = quant_mxfp4_batches(w13, num_experts)
    w2_q, w2_scale = quant_mxfp4_batches(w2, num_experts)

    # Reference result using dequantized tensors and reference_moe
    w13_ref = (
        dequant_mxfp4_batches(
            w13_q.view(torch.uint8), w13_scale.view(torch.uint8).reshape(-1)
        )
        .to(torch.float32)
        .reshape(num_experts, 2 * intermediate_size, hidden_size)
        .to(device)
    )
    w2_ref = (
        dequant_mxfp4_batches(
            w2_q.view(torch.uint8), w2_scale.view(torch.uint8).reshape(-1)
        )
        .to(torch.float32)
        .reshape(num_experts, hidden_size, intermediate_size)
        .to(device)
    )

    # Quantize activations for SM100 path and dequantize for reference
    hidden_states_q, hidden_states_sf = mxfp8_quantize(hidden_states, True, 32)

    # Reference uses BF16 input but quantizes intermediate activation to MXFP8
    ref = reference_moe(
        router_logits.to(torch.float32),
        topk,
        num_experts,
        hidden_states.to(torch.float32),
        w13_ref,
        bias13.to(torch.float32),
        w2_ref,
        bias2.to(torch.float32),
        alpha,
        beta,
        limit,
        "mxfp8",
    )

    # Prepare inputs for FlashInfer CUTLASS fused MoE
    from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe

    # Swap halves to arrange as [w3; w1] (kernel expectation)
    w1_w, w3_w = torch.chunk(w13_q, 2, dim=1)
    w13_q_swapped = torch.cat([w3_w, w1_w], dim=1)

    # Swap scales halves to match swapped weights
    s1, s3 = torch.chunk(w13_scale, 2, dim=1)
    w13_scale_swapped = torch.cat([s3, s1], dim=1)

    b1, b3 = torch.chunk(bias13.to(torch.float32), 2, dim=-1)
    w13_b = torch.cat([b3, b1], dim=-1).to(torch.bfloat16)

    # Build routing for kernel
    routing_weights = torch.nn.functional.softmax(
        router_logits, dim=1, dtype=torch.float32
    )
    token_final_scales, token_selected_experts = torch.topk(
        routing_weights, topk, dim=-1
    )
    token_final_scales = token_final_scales / token_final_scales.sum(
        dim=-1, keepdim=True
    )
    token_selected_experts = token_selected_experts.to(torch.int).contiguous()

    out = torch.empty_like(hidden_states, dtype=torch.bfloat16)

    if alpha is not None:
        alpha_t = torch.full((num_experts,), alpha, device=hidden_states.device)
    else:
        alpha_t = None
    if beta is not None:
        beta_t = torch.full((num_experts,), beta, device=hidden_states.device)
    else:
        beta_t = None
    if limit is not None:
        limit_t = torch.full((num_experts,), limit, device=hidden_states.device)
    else:
        limit_t = None

    # Quant scales for SM100 MXFP8+MXFP4 path
    fake_input_scale = torch.ones(num_experts, device=device)
    quant_scales = [
        w13_scale_swapped.view(torch.int32),
        fake_input_scale,
        w2_scale.view(torch.int32),
        fake_input_scale,
    ]
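    # Run the FlashInfer CUTLASS fused MoE with MXFP8 activations
    # (use_mxfp8_act_scaling=True plus input_sf) and MXFP4 weights.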
    _ = flashinfer_cutlass_fused_moe(
        input=hidden_states_q,
        token_selected_experts=token_selected_experts,
        token_final_scales=token_final_scales,
        fc1_expert_weights=w13_q_swapped.contiguous().view(torch.long),
        fc2_expert_weights=w2_q.contiguous().view(torch.long),
        output_dtype=torch.bfloat16,
        output=out,
        quant_scales=quant_scales,
        fc1_expert_biases=w13_b,
        fc2_expert_biases=bias2.to(torch.bfloat16),
        swiglu_alpha=alpha_t,
        swiglu_beta=beta_t,
        swiglu_limit=limit_t,
        tp_size=1,
        tp_rank=0,
        ep_size=1,
        ep_rank=0,
        use_mxfp8_act_scaling=True,
        input_sf=hidden_states_sf,
    )

    # Allow some mismatch due to MXFP4 quantization
    check_accuracy(ref, out, atol=0, rtol=0.3, percent=0.8)