# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch

from tests.kernels.moe.utils import make_test_weights
from tests.kernels.quantization.nvfp4_utils import (
    FLOAT4_E2M1_MAX,
    FLOAT8_E4M3_MAX,
    dequantize_nvfp4_to_dtype,
)
from tests.kernels.utils import torch_moe
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import nvfp4_moe_quant_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.platforms import current_platform

# Skip the whole module when the device does not report capability >= 100.
if not current_platform.has_device_capability(100):
    pytest.skip(
        "Nvfp4 Requires compute capability of 10 or above.", allow_module_level=True
    )

# (m, n, k) triples fed to the "m,n,k" parametrization below.
MNK_FACTORS = [
    (2, 1024, 1024),
    (2, 1024, 1536),
    (2, 3072, 1024),
    (2, 3072, 1536),
    (64, 1024, 1024),
    (64, 1024, 1536),
    (64, 3072, 1024),
    (64, 2048, 1536),
    (224, 1024, 1024),
    (224, 1024, 1536),
]


@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", [40, 64, 256])
@pytest.mark.parametrize("topk", [1, 6, 8])
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
@torch.inference_mode()
def test_cutlass_fp4_moe_no_graph(
    m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype
):
    """Check ``cutlass_moe_fp4`` against a ``torch_moe`` reference.

    The reference is computed from activations and weights that were
    quantized to NVFP4 and dequantized back to ``dtype``; the two outputs
    must agree within ``atol=1e-1, rtol=1e-1``.

    Args:
        m: Number of rows of the activation tensor ``a`` (shape ``(m, k)``).
        n: Size passed to ``make_test_weights``; the dequantized weights are
            allocated as ``(e, 2 * n, k)`` and ``(e, k, n)``.
        k: Number of columns of the activation tensor ``a``.
        e: Number of experts (columns of the routing ``score``).
        topk: Number of experts selected per row by ``fused_topk``.
        dtype: Floating-point dtype of the activations and reference weights.
    """
    current_platform.seed_everything(7)

    # The entire test body runs under an explicit vLLM config
    # (pipeline_parallel_size=1).
    with set_current_vllm_config(
        VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
    ):
        quant_blocksize = 16

        a = torch.randn((m, k), device="cuda", dtype=dtype) / 10

        (_, w1_q, w1_blockscale, w1_gs), (_, w2_q, w2_blockscale, w2_gs) = (
            make_test_weights(
                e,
                n,
                k,
                in_dtype=dtype,
                quant_dtype="nvfp4",
                block_shape=None,  # use quant_blocksize?
                per_out_ch_quant=False,
            )
        )

        score = torch.randn((m, e), device="cuda", dtype=dtype)
        topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False)

        # Per-expert activation global scales are fixed to 1 for the kernel.
        a1_gs = torch.ones((e,), device="cuda", dtype=torch.float32)
        a2_gs = torch.ones((e,), device="cuda", dtype=torch.float32)

        assert w1_gs is not None
        assert w2_gs is not None
        assert w1_blockscale is not None
        assert w2_blockscale is not None

        # The alphas handed to the kernel are the reciprocals of the weight
        # global scales.
        quant_config = nvfp4_moe_quant_config(
            g1_alphas=(1 / w1_gs),
            g2_alphas=(1 / w2_gs),
            a1_gscale=a1_gs,
            a2_gscale=a2_gs,
            w1_scale=w1_blockscale,
            w2_scale=w2_blockscale,
        )

        cutlass_output = cutlass_moe_fp4(
            a=a,
            w1_fp4=w1_q,
            w2_fp4=w2_q,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            quant_config=quant_config,
            m=m,
            n=n,
            k=k,
            e=e,
        )

        # Reference check: quantize the activations to NVFP4 and dequantize
        # them again so the reference sees the same rounding of `a`.
        a_global_scale = (
            (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a.flatten(), dim=-1)
        ).to(torch.float32)
        a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a, a_global_scale)
        a_in_dtype = dequantize_nvfp4_to_dtype(
            a_fp4,
            a_scale_interleaved,
            a_global_scale,
            dtype=a.dtype,
            device=a.device,
            block_size=quant_blocksize,
        )

        # Dequantize every expert's weights back to `dtype` for the reference.
        w1_d = torch.empty((e, 2 * n, k), device="cuda", dtype=dtype)
        w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype)

        for idx in range(e):
            w1_d[idx] = dequantize_nvfp4_to_dtype(
                w1_q[idx],
                w1_blockscale[idx],
                w1_gs[idx],
                dtype=dtype,
                device=w1_q.device,
                block_size=quant_blocksize,
            )
            w2_d[idx] = dequantize_nvfp4_to_dtype(
                w2_q[idx],
                w2_blockscale[idx],
                w2_gs[idx],
                dtype=dtype,
                device=w2_q.device,
                block_size=quant_blocksize,
            )

        torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk)

        torch.testing.assert_close(torch_output, cutlass_output, atol=1e-1, rtol=1e-1)


if __name__ == "__main__":
    # m, n, k are separate positional parameters, followed by e, topk, dtype;
    # passing (m, n, k) as one tuple would leave topk and dtype unbound.
    test_cutlass_fp4_moe_no_graph(2, 1024, 1024, 40, 1, torch.half)