From 0efdb5c3bad240121d91083189565df0b82502c2 Mon Sep 17 00:00:00 2001
From: Wei
Date: Tue, 9 Sep 2025 21:27:53 -0700
Subject: [PATCH] [gpt-oss] Cache permute indices for faster MXFP4 MoE layer loading (#24154)

Signed-off-by: Wei Wei
---
 tests/kernels/moe/test_mxfp4_moe.py           | 101 +++++++++++++++---
 .../layers/quantization/mxfp4.py              |  78 ++++++++++----
 2 files changed, 145 insertions(+), 34 deletions(-)

diff --git a/tests/kernels/moe/test_mxfp4_moe.py b/tests/kernels/moe/test_mxfp4_moe.py
index c29bed3dd6b32..882b034e2f230 100644
--- a/tests/kernels/moe/test_mxfp4_moe.py
+++ b/tests/kernels/moe/test_mxfp4_moe.py
@@ -24,6 +24,8 @@ if TRTLLM_GEN_MXFP4_AVAILABLE:
         next_positive_power_of_2, reorder_rows_for_gated_act_gemm,
         shuffle_matrix_a, shuffle_matrix_sf_a, trtllm_fp4_block_scale_moe)
+    from flashinfer.fp4_quantization import nvfp4_block_scale_interleave
+    from flashinfer.fused_moe.core import _maybe_get_cached_w2_permute_indices
 
 
 @dataclass
@@ -204,6 +206,7 @@ def tg_mxfp4_moe(
     alpha,
     beta,
     limit,
+    transpose_optimized: bool = False,
 ) -> torch.Tensor:
     sf_block_size = 32
     assert (w13_weight.dim() == 3 and w13_weight.shape[0] == num_experts
@@ -267,22 +270,85 @@ def tg_mxfp4_moe(
     gemm1_bias_shuffled = []
     gemm2_bias_shuffled = []
     epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
-    for i in range(num_experts):
-        gemm1_weights_shuffled.append(
-            shuffle_matrix_a(w13_weight[i].view(torch.uint8), epilogue_tile_m))
-        gemm1_scales_shuffled.append(
-            shuffle_matrix_sf_a(w13_weight_scale[i].view(torch.uint8),
-                                epilogue_tile_m))
+    _cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
+    if transpose_optimized:
+        for i in range(num_experts):
+            # w13 weight shuffling
+            permute_indices = _maybe_get_cached_w2_permute_indices(
+                _cache_permute_indices,
+                w13_weight[i].view(torch.uint8),
+                epilogue_tile_m,
+            )
+            gemm1_weights_shuffled.append(w13_weight[i].view(
+                torch.uint8)[permute_indices.to(
+                    w13_weight.device)].contiguous())
+            # w13 scale shuffling
+            permute_sf_indices = _maybe_get_cached_w2_permute_indices(
+                _cache_permute_indices,
+                w13_weight_scale[i].view(torch.uint8),
+                epilogue_tile_m,
+                num_elts_per_sf=16,
+            )
+            gemm1_scales_shuffled.append(
+                nvfp4_block_scale_interleave(w13_weight_scale[i].view(
+                    torch.uint8)[permute_sf_indices.to(
+                        w13_weight_scale.device)].contiguous()))
+            # w13 bias shuffling
+            permute_bias_indices = _maybe_get_cached_w2_permute_indices(
+                _cache_permute_indices,
+                w13_bias[i].clone().reshape(-1, 1),
+                epilogue_tile_m,
+            )
+            gemm1_bias_shuffled.append(w13_bias[i].clone().reshape(
+                -1, 1)[permute_bias_indices.to(w13_bias.device)].contiguous())
+            # w2 weight shuffling
+            permute_indices = _maybe_get_cached_w2_permute_indices(
+                _cache_permute_indices,
+                w2_weight[i].view(torch.uint8),
+                epilogue_tile_m,
+            )
+            gemm2_weights_shuffled.append(w2_weight[i].view(
+                torch.uint8)[permute_indices.to(
+                    w2_weight.device)].contiguous())
+            # w2 scale shuffling
+            permute_sf_indices = _maybe_get_cached_w2_permute_indices(
+                _cache_permute_indices,
+                w2_weight_scale[i].view(torch.uint8),
+                epilogue_tile_m,
+                num_elts_per_sf=16,
+            )
+            gemm2_scales_shuffled.append(
+                nvfp4_block_scale_interleave(w2_weight_scale[i].view(
+                    torch.uint8)[permute_sf_indices.to(
+                        w2_weight_scale.device)].contiguous()))
+            # w2 bias shuffling
+            permute_indices = _maybe_get_cached_w2_permute_indices(
+                _cache_permute_indices,
+                w2_bias[i].clone().reshape(-1, 1),
+                epilogue_tile_m,
+            )
+            gemm2_bias_shuffled.append(w2_bias[i].clone().reshape(
+                -1, 1)[permute_indices.to(w2_bias.device)].contiguous())
 
-        gemm2_weights_shuffled.append(
-            shuffle_matrix_a(w2_weight[i].view(torch.uint8), epilogue_tile_m))
-        gemm2_scales_shuffled.append(
-            shuffle_matrix_sf_a(w2_weight_scale[i].view(torch.uint8),
-                                epilogue_tile_m))
-        gemm1_bias_shuffled.append(
-            shuffle_matrix_a(w13_bias[i].reshape(-1, 1), epilogue_tile_m))
-        gemm2_bias_shuffled.append(
-            shuffle_matrix_a(w2_bias[i].reshape(-1, 1), epilogue_tile_m))
+    else:
+        for i in range(num_experts):
+            gemm1_weights_shuffled.append(
+                shuffle_matrix_a(w13_weight[i].view(torch.uint8),
+                                 epilogue_tile_m))
+            gemm1_scales_shuffled.append(
+                shuffle_matrix_sf_a(w13_weight_scale[i].view(torch.uint8),
+                                    epilogue_tile_m))
+
+            gemm2_weights_shuffled.append(
+                shuffle_matrix_a(w2_weight[i].view(torch.uint8),
+                                 epilogue_tile_m))
+            gemm2_scales_shuffled.append(
+                shuffle_matrix_sf_a(w2_weight_scale[i].view(torch.uint8),
+                                    epilogue_tile_m))
+            gemm1_bias_shuffled.append(
+                shuffle_matrix_a(w13_bias[i].reshape(-1, 1), epilogue_tile_m))
+            gemm2_bias_shuffled.append(
+                shuffle_matrix_a(w2_bias[i].reshape(-1, 1), epilogue_tile_m))
 
     w13_weight = torch.stack(gemm1_weights_shuffled)
     w13_weight_scale = torch.stack(gemm1_scales_shuffled).reshape(
@@ -356,6 +422,7 @@ def check_accuracy(a, b, atol, rtol, percent):
 @pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None),
                                               (1.702, 1.0, 7.0)])
 @pytest.mark.parametrize("act_type", ['mxfp8', 'bf16'])
+@pytest.mark.parametrize("transpose_optimized", [False, True])
 @pytest.mark.skipif(
     not TRTLLM_GEN_MXFP4_AVAILABLE,
     reason="nvidia gpu and compute capability sm100 is required for this test")
@@ -369,6 +436,7 @@ def test_trtllm_gen_mxfp4_fused_moe(
     beta: float,
     limit: Optional[float],
     act_type: str,
+    transpose_optimized: bool,
 ):
     seed = 42
     torch.manual_seed(seed)
@@ -470,6 +538,7 @@ def test_trtllm_gen_mxfp4_fused_moe(
                              act_type,
                              alpha=alpha,
                              beta=beta,
-                             limit=limit)
+                             limit=limit,
+                             transpose_optimized=transpose_optimized)
     # relatively loose check since the mxfp4 quantization is less accurate
     check_accuracy(ref_result, tg_result, atol=0, rtol=0.3, percent=0.8)
diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py
index 889c15df3c878..74922756afe56 100644
--- a/vllm/model_executor/layers/quantization/mxfp4.py
+++ b/vllm/model_executor/layers/quantization/mxfp4.py
@@ -122,6 +122,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 "MXFP4 MoE is enabled on Blackwell but FlashInfer "
                 "is not available. This may result in degraded performance. "
                 "Please `pip install vllm[flashinfer]` for best results.")
+        self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
 
     def _should_use_marlin(self):
         if envs.VLLM_MXFP4_USE_MARLIN is not None:
@@ -266,7 +267,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         if self.use_marlin:
             prepare_moe_fp4_layer_for_marlin(layer)
         elif should_use_flashinfer_mxfp4():
-            from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a
+            from flashinfer.fp4_quantization import (
+                nvfp4_block_scale_interleave)
+            from flashinfer.fused_moe.core import (
+                _maybe_get_cached_w2_permute_indices)
             layer.gemm1_alpha = Parameter(torch.tensor(
                 [1.702] * self.num_experts, dtype=torch.float32).cuda(),
                                           requires_grad=False)
@@ -343,25 +347,63 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             gemm2_bias_shuffled = []
             epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
             for i in range(self.num_experts):
-                gemm1_weights_mxfp4_shuffled.append(
-                    shuffle_matrix_a(w13_weight[i].view(torch.uint8),
-                                     epilogue_tile_m))
+                # w13 weight shuffling
+                permute_indices = _maybe_get_cached_w2_permute_indices(
+                    self._cache_permute_indices,
+                    w13_weight[i].view(torch.uint8),
+                    epilogue_tile_m,
+                )
+                gemm1_weights_mxfp4_shuffled.append(w13_weight[i].view(
+                    torch.uint8)[permute_indices.to(
+                        w13_weight.device)].contiguous())
+                # w13 scale shuffling
+                permute_sf_indices = _maybe_get_cached_w2_permute_indices(
+                    self._cache_permute_indices,
+                    w13_weight_scale[i].view(torch.uint8),
+                    epilogue_tile_m,
+                    num_elts_per_sf=16,
+                )
                 gemm1_scales_mxfp4_shuffled.append(
-                    shuffle_matrix_sf_a(w13_weight_scale[i].view(torch.uint8),
-                                        epilogue_tile_m))
-                gemm1_bias_shuffled.append(
-                    shuffle_matrix_a(w13_bias[i].clone().reshape(-1, 1),
-                                     epilogue_tile_m))
-
-                gemm2_weights_mxfp4_shuffled.append(
-                    shuffle_matrix_a(w2_weight[i].view(torch.uint8),
-                                     epilogue_tile_m))
+                    nvfp4_block_scale_interleave(w13_weight_scale[i].view(
+                        torch.uint8)[permute_sf_indices.to(
+                            w13_weight_scale.device)].contiguous()))
+                # w13 bias shuffling
+                permute_bias_indices = _maybe_get_cached_w2_permute_indices(
+                    self._cache_permute_indices,
+                    w13_bias[i].clone().reshape(-1, 1),
+                    epilogue_tile_m,
+                )
+                gemm1_bias_shuffled.append(w13_bias[i].clone().reshape(
+                    -1,
+                    1)[permute_bias_indices.to(w13_bias.device)].contiguous())
+                # w2 weight shuffling
+                permute_indices = _maybe_get_cached_w2_permute_indices(
+                    self._cache_permute_indices,
+                    w2_weight[i].view(torch.uint8),
+                    epilogue_tile_m,
+                )
+                gemm2_weights_mxfp4_shuffled.append(w2_weight[i].view(
+                    torch.uint8)[permute_indices.to(
+                        w2_weight.device)].contiguous())
+                # w2 scale shuffling
+                permute_sf_indices = _maybe_get_cached_w2_permute_indices(
+                    self._cache_permute_indices,
+                    w2_weight_scale[i].view(torch.uint8),
+                    epilogue_tile_m,
+                    num_elts_per_sf=16,
+                )
                 gemm2_scales_mxfp4_shuffled.append(
-                    shuffle_matrix_sf_a(w2_weight_scale[i].view(torch.uint8),
-                                        epilogue_tile_m))
-                gemm2_bias_shuffled.append(
-                    shuffle_matrix_a(w2_bias[i].clone().reshape(-1, 1),
-                                     epilogue_tile_m))
+                    nvfp4_block_scale_interleave(w2_weight_scale[i].view(
+                        torch.uint8)[permute_sf_indices.to(
+                            w2_weight_scale.device)].contiguous()))
+                # w2 bias shuffling
+                permute_indices = _maybe_get_cached_w2_permute_indices(
+                    self._cache_permute_indices,
+                    w2_bias[i].clone().reshape(-1, 1),
+                    epilogue_tile_m,
+                )
+                gemm2_bias_shuffled.append(w2_bias[i].clone().reshape(
+                    -1, 1)[permute_indices.to(w2_bias.device)].contiguous())
 
             w13_weight = torch.stack(gemm1_weights_mxfp4_shuffled)
             w13_weight_scale = torch.stack(