[gpt-oss] Cache permute indices for faster MXFP4 MoE layer loading (#24154)

Signed-off-by: Wei Wei <wwei6@meta.com>
Author: Wei Wei <wwei6@meta.com>, 2025-09-09 21:27:53 -07:00 (committed by GitHub)
parent 53b42f4102
commit 0efdb5c3ba
2 changed files with 145 additions and 34 deletions
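
The speedup comes from memoizing the row-permutation indices by tensor shape: within a model, every expert's w13/w2 weight, scale, and bias tensors have the same shapes, so the permutation needed for the kernel's epilogue tile layout only has to be computed once per shape, and each expert's tensor can then be shuffled with a cheap index-select. Below is a minimal sketch of that caching pattern, the same idea behind FlashInfer's `_maybe_get_cached_w2_permute_indices` used in this change; `compute_permute_indices` is a hypothetical stand-in for the real index computation, not FlashInfer's API.

```python
import torch


def compute_permute_indices(shape: torch.Size,
                            epilogue_tile_m: int) -> torch.Tensor:
    # Hypothetical stand-in for the expensive row-permutation computation
    # done inside FlashInfer's shuffle helpers; identity order for brevity.
    return torch.arange(shape[0])


_cache_permute_indices: dict[torch.Size, torch.Tensor] = {}


def maybe_get_cached_permute_indices(weight: torch.Tensor,
                                     epilogue_tile_m: int) -> torch.Tensor:
    # Key the cache on the tensor shape: all experts share weight shapes,
    # so the index computation runs once per distinct shape, not per expert.
    if weight.shape not in _cache_permute_indices:
        _cache_permute_indices[weight.shape] = compute_permute_indices(
            weight.shape, epilogue_tile_m)
    return _cache_permute_indices[weight.shape]


# Per-expert loading then reduces to a gather with the cached indices.
num_experts, rows, cols = 4, 256, 128
w = torch.randn(num_experts, rows, cols)
shuffled = [
    w[i][maybe_get_cached_permute_indices(w[i], epilogue_tile_m=128)].contiguous()
    for i in range(num_experts)
]
```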


@@ -24,6 +24,8 @@ if TRTLLM_GEN_MXFP4_AVAILABLE:
next_positive_power_of_2,
reorder_rows_for_gated_act_gemm, shuffle_matrix_a,
shuffle_matrix_sf_a, trtllm_fp4_block_scale_moe)
from flashinfer.fp4_quantization import nvfp4_block_scale_interleave
from flashinfer.fused_moe.core import _maybe_get_cached_w2_permute_indices
@dataclass
@@ -204,6 +206,7 @@ def tg_mxfp4_moe(
alpha,
beta,
limit,
transpose_optimized: bool = False,
) -> torch.Tensor:
sf_block_size = 32
assert (w13_weight.dim() == 3 and w13_weight.shape[0] == num_experts
@@ -267,22 +270,85 @@ def tg_mxfp4_moe(
gemm1_bias_shuffled = []
gemm2_bias_shuffled = []
epilogue_tile_m = 128 # FIXME: this depends on the kernel internals
_cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
if transpose_optimized:
for i in range(num_experts):
# w13 weight shuffling
permute_indices = _maybe_get_cached_w2_permute_indices(
_cache_permute_indices,
w13_weight[i].view(torch.uint8),
epilogue_tile_m,
)
gemm1_weights_shuffled.append(w13_weight[i].view(
torch.uint8)[permute_indices.to(
w13_weight.device)].contiguous())
# w13 scale shuffling
permute_sf_indices = _maybe_get_cached_w2_permute_indices(
_cache_permute_indices,
w13_weight_scale[i].view(torch.uint8),
epilogue_tile_m,
num_elts_per_sf=16,
)
gemm1_scales_shuffled.append(
nvfp4_block_scale_interleave(w13_weight_scale[i].view(
torch.uint8)[permute_sf_indices.to(
w13_weight_scale.device)].contiguous()))
# w13 bias shuffling
permute_bias_indices = _maybe_get_cached_w2_permute_indices(
_cache_permute_indices,
w13_bias[i].clone().reshape(-1, 1),
epilogue_tile_m,
)
gemm1_bias_shuffled.append(w13_bias[i].clone().reshape(
-1, 1)[permute_bias_indices.to(w13_bias.device)].contiguous())
# w2 weight shuffling
permute_indices = _maybe_get_cached_w2_permute_indices(
_cache_permute_indices,
w2_weight[i].view(torch.uint8),
epilogue_tile_m,
)
gemm2_weights_shuffled.append(w2_weight[i].view(
torch.uint8)[permute_indices.to(
w2_weight.device)].contiguous())
# w2 scale shuffling
permute_sf_indices = _maybe_get_cached_w2_permute_indices(
_cache_permute_indices,
w2_weight_scale[i].view(torch.uint8),
epilogue_tile_m,
num_elts_per_sf=16,
)
gemm2_scales_shuffled.append(
nvfp4_block_scale_interleave(w2_weight_scale[i].view(
torch.uint8)[permute_sf_indices.to(
w2_weight_scale.device)].contiguous()))
# w2 bias shuffling
permute_indices = _maybe_get_cached_w2_permute_indices(
_cache_permute_indices,
w2_bias[i].clone().reshape(-1, 1),
epilogue_tile_m,
)
gemm2_bias_shuffled.append(w2_bias[i].clone().reshape(
-1, 1)[permute_indices.to(w2_bias.device)].contiguous())
else:
for i in range(num_experts):
gemm1_weights_shuffled.append(
shuffle_matrix_a(w13_weight[i].view(torch.uint8),
epilogue_tile_m))
gemm1_scales_shuffled.append(
shuffle_matrix_sf_a(w13_weight_scale[i].view(torch.uint8),
epilogue_tile_m))
gemm2_weights_shuffled.append(
shuffle_matrix_a(w2_weight[i].view(torch.uint8),
epilogue_tile_m))
gemm2_scales_shuffled.append(
shuffle_matrix_sf_a(w2_weight_scale[i].view(torch.uint8),
epilogue_tile_m))
gemm1_bias_shuffled.append(
shuffle_matrix_a(w13_bias[i].reshape(-1, 1), epilogue_tile_m))
gemm2_bias_shuffled.append(
shuffle_matrix_a(w2_bias[i].reshape(-1, 1), epilogue_tile_m))
w13_weight = torch.stack(gemm1_weights_shuffled)
w13_weight_scale = torch.stack(gemm1_scales_shuffled).reshape(
@@ -356,6 +422,7 @@ def check_accuracy(a, b, atol, rtol, percent):
@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None),
(1.702, 1.0, 7.0)])
@pytest.mark.parametrize("act_type", ['mxfp8', 'bf16'])
@pytest.mark.parametrize("transpose_optimized", [False, True])
@pytest.mark.skipif(
not TRTLLM_GEN_MXFP4_AVAILABLE,
reason="nvidia gpu and compute capability sm100 is required for this test")
@@ -369,6 +436,7 @@ def test_trtllm_gen_mxfp4_fused_moe(
beta: float,
limit: Optional[float],
act_type: str,
transpose_optimized: bool,
):
seed = 42
torch.manual_seed(seed)
@@ -470,6 +538,7 @@ def test_trtllm_gen_mxfp4_fused_moe(
act_type,
alpha=alpha,
beta=beta,
limit=limit,
transpose_optimized=transpose_optimized)
# relatively loose check since the mxfp4 quantization is less accurate
check_accuracy(ref_result, tg_result, atol=0, rtol=0.3, percent=0.8)
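
The `check_accuracy` helper is defined earlier in the test file (only its signature is visible in this hunk). A plausible sketch of such a percent-based tolerance check, assuming it passes when at least `percent` of the elements agree within `atol`/`rtol`:

```python
import torch


def check_accuracy(a: torch.Tensor, b: torch.Tensor, atol: float,
                   rtol: float, percent: float) -> None:
    # Assumed semantics: pass if at least `percent` of b's elements are
    # within atol/rtol of the reference a; mxfp4 quantization error means
    # a strict elementwise allclose would be too tight.
    close = torch.isclose(b, a, atol=atol, rtol=rtol)
    ratio = close.float().mean().item()
    assert ratio >= percent, (
        f"only {ratio:.2%} of elements within atol={atol}, rtol={rtol}")
```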


@@ -122,6 +122,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
"MXFP4 MoE is enabled on Blackwell but FlashInfer "
"is not available. This may result in degraded performance. "
"Please `pip install vllm[flashinfer]` for best results.")
self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
def _should_use_marlin(self):
if envs.VLLM_MXFP4_USE_MARLIN is not None:
@@ -266,7 +267,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
if self.use_marlin:
prepare_moe_fp4_layer_for_marlin(layer)
elif should_use_flashinfer_mxfp4():
from flashinfer.fp4_quantization import (
nvfp4_block_scale_interleave)
from flashinfer.fused_moe.core import (
_maybe_get_cached_w2_permute_indices)
layer.gemm1_alpha = Parameter(torch.tensor(
[1.702] * self.num_experts, dtype=torch.float32).cuda(),
requires_grad=False)
@@ -343,25 +347,63 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
gemm2_bias_shuffled = []
epilogue_tile_m = 128 # FIXME: this depends on the kernel internals
for i in range(self.num_experts):
# w13 weight shuffling
permute_indices = _maybe_get_cached_w2_permute_indices(
self._cache_permute_indices,
w13_weight[i].view(torch.uint8),
epilogue_tile_m,
)
gemm1_weights_mxfp4_shuffled.append(w13_weight[i].view(
torch.uint8)[permute_indices.to(
w13_weight.device)].contiguous())
# w13 scale shuffling
permute_sf_indices = _maybe_get_cached_w2_permute_indices(
self._cache_permute_indices,
w13_weight_scale[i].view(torch.uint8),
epilogue_tile_m,
num_elts_per_sf=16,
)
gemm1_scales_mxfp4_shuffled.append(
nvfp4_block_scale_interleave(w13_weight_scale[i].view(
torch.uint8)[permute_sf_indices.to(
w13_weight_scale.device)].contiguous()))
# w13 bias shuffling
permute_bias_indices = _maybe_get_cached_w2_permute_indices(
self._cache_permute_indices,
w13_bias[i].clone().reshape(-1, 1),
epilogue_tile_m,
)
gemm1_bias_shuffled.append(w13_bias[i].clone().reshape(
-1,
1)[permute_bias_indices.to(w13_bias.device)].contiguous())
# w2 weight shuffling
permute_indices = _maybe_get_cached_w2_permute_indices(
self._cache_permute_indices,
w2_weight[i].view(torch.uint8),
epilogue_tile_m,
)
gemm2_weights_mxfp4_shuffled.append(w2_weight[i].view(
torch.uint8)[permute_indices.to(
w2_weight.device)].contiguous())
# w2 scale shuffling
permute_sf_indices = _maybe_get_cached_w2_permute_indices(
self._cache_permute_indices,
w2_weight_scale[i].view(torch.uint8),
epilogue_tile_m,
num_elts_per_sf=16,
)
gemm2_scales_mxfp4_shuffled.append(
nvfp4_block_scale_interleave(w2_weight_scale[i].view(
torch.uint8)[permute_sf_indices.to(
w2_weight_scale.device)].contiguous()))
# w2 bias shuffling
permute_indices = _maybe_get_cached_w2_permute_indices(
self._cache_permute_indices,
w2_bias[i].clone().reshape(-1, 1),
epilogue_tile_m,
)
gemm2_bias_shuffled.append(w2_bias[i].clone().reshape(
-1, 1)[permute_indices.to(w2_bias.device)].contiguous())
w13_weight = torch.stack(gemm1_weights_mxfp4_shuffled)
w13_weight_scale = torch.stack(