[gpt-oss] Cache permute indices for faster MXFP4 MoE layer loading (#24154)
Signed-off-by: Wei Wei <wwei6@meta.com>
parent 53b42f4102
commit 0efdb5c3ba
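This change speeds up weight loading by memoizing the permutation indices used to shuffle MXFP4 MoE weights: instead of calling shuffle_matrix_a / shuffle_matrix_sf_a once per expert, the new path looks up (or computes once) the indices for each distinct tensor shape via FlashInfer's _maybe_get_cached_w2_permute_indices and applies them with a plain index-gather. A minimal sketch of the pattern, assuming a hypothetical get_permute_indices helper (the real FlashInfer function has a different signature and derives the permutation from the kernel's epilogue tile layout rather than randperm):

```python
import torch

# Shape-keyed cache: experts share one weight layout, so the permutation
# indices only need to be computed once per distinct tensor shape.
_cache_permute_indices: dict[torch.Size, torch.Tensor] = {}


def get_permute_indices(cache: dict[torch.Size, torch.Tensor],
                        weight: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for _maybe_get_cached_w2_permute_indices:
    # compute the row permutation for this shape once, then reuse it.
    if weight.shape not in cache:
        # Placeholder permutation; the real helper derives the row order
        # from the kernel's epilogue tile layout (epilogue_tile_m).
        cache[weight.shape] = torch.randperm(weight.shape[0])
    return cache[weight.shape]


# Applying the cached indices replaces a per-expert shuffle_matrix_a call:
w = torch.randn(256, 64)
idx = get_permute_indices(_cache_permute_indices, w)
w_shuffled = w[idx.to(w.device)].contiguous()
```

Because all experts in a layer share one weight geometry, the dict stays tiny while eliminating the redundant index computations for every expert after the first.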
@@ -24,6 +24,8 @@ if TRTLLM_GEN_MXFP4_AVAILABLE:
                             next_positive_power_of_2,
                             reorder_rows_for_gated_act_gemm, shuffle_matrix_a,
                             shuffle_matrix_sf_a, trtllm_fp4_block_scale_moe)
+    from flashinfer.fp4_quantization import nvfp4_block_scale_interleave
+    from flashinfer.fused_moe.core import _maybe_get_cached_w2_permute_indices
 
 
 @dataclass
@@ -204,6 +206,7 @@ def tg_mxfp4_moe(
     alpha,
     beta,
     limit,
+    transpose_optimized: bool = False,
 ) -> torch.Tensor:
     sf_block_size = 32
     assert (w13_weight.dim() == 3 and w13_weight.shape[0] == num_experts
@@ -267,22 +270,85 @@ def tg_mxfp4_moe(
     gemm1_bias_shuffled = []
     gemm2_bias_shuffled = []
     epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
-    for i in range(num_experts):
-        gemm1_weights_shuffled.append(
-            shuffle_matrix_a(w13_weight[i].view(torch.uint8), epilogue_tile_m))
-        gemm1_scales_shuffled.append(
-            shuffle_matrix_sf_a(w13_weight_scale[i].view(torch.uint8),
-                                epilogue_tile_m))
+    _cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
+    if transpose_optimized:
+        for i in range(num_experts):
+            # w13 weight shuffling
+            permute_indices = _maybe_get_cached_w2_permute_indices(
+                _cache_permute_indices,
+                w13_weight[i].view(torch.uint8),
+                epilogue_tile_m,
+            )
+            gemm1_weights_shuffled.append(w13_weight[i].view(
+                torch.uint8)[permute_indices.to(
+                    w13_weight.device)].contiguous())
+            # w13 scale shuffling
+            permute_sf_indices = _maybe_get_cached_w2_permute_indices(
+                _cache_permute_indices,
+                w13_weight_scale[i].view(torch.uint8),
+                epilogue_tile_m,
+                num_elts_per_sf=16,
+            )
+            gemm1_scales_shuffled.append(
+                nvfp4_block_scale_interleave(w13_weight_scale[i].view(
+                    torch.uint8)[permute_sf_indices.to(
+                        w13_weight_scale.device)].contiguous()))
+            # w13 bias shuffling
+            permute_bias_indices = _maybe_get_cached_w2_permute_indices(
+                _cache_permute_indices,
+                w13_bias[i].clone().reshape(-1, 1),
+                epilogue_tile_m,
+            )
+            gemm1_bias_shuffled.append(w13_bias[i].clone().reshape(
+                -1, 1)[permute_bias_indices.to(w13_bias.device)].contiguous())
+            # w2 weight shuffling
+            permute_indices = _maybe_get_cached_w2_permute_indices(
+                _cache_permute_indices,
+                w2_weight[i].view(torch.uint8),
+                epilogue_tile_m,
+            )
+            gemm2_weights_shuffled.append(w2_weight[i].view(
+                torch.uint8)[permute_indices.to(
+                    w2_weight.device)].contiguous())
+            # w2 scale shuffling
+            permute_sf_indices = _maybe_get_cached_w2_permute_indices(
+                _cache_permute_indices,
+                w2_weight_scale[i].view(torch.uint8),
+                epilogue_tile_m,
+                num_elts_per_sf=16,
+            )
+            gemm2_scales_shuffled.append(
+                nvfp4_block_scale_interleave(w2_weight_scale[i].view(
+                    torch.uint8)[permute_sf_indices.to(
+                        w2_weight_scale.device)].contiguous()))
+            # w2 bias shuffling
+            permute_indices = _maybe_get_cached_w2_permute_indices(
+                _cache_permute_indices,
+                w2_bias[i].clone().reshape(-1, 1),
+                epilogue_tile_m,
+            )
+            gemm2_bias_shuffled.append(w2_bias[i].clone().reshape(
+                -1, 1)[permute_indices.to(w2_bias.device)].contiguous())
 
-        gemm2_weights_shuffled.append(
-            shuffle_matrix_a(w2_weight[i].view(torch.uint8), epilogue_tile_m))
-        gemm2_scales_shuffled.append(
-            shuffle_matrix_sf_a(w2_weight_scale[i].view(torch.uint8),
-                                epilogue_tile_m))
-        gemm1_bias_shuffled.append(
-            shuffle_matrix_a(w13_bias[i].reshape(-1, 1), epilogue_tile_m))
-        gemm2_bias_shuffled.append(
-            shuffle_matrix_a(w2_bias[i].reshape(-1, 1), epilogue_tile_m))
+    else:
+        for i in range(num_experts):
+            gemm1_weights_shuffled.append(
+                shuffle_matrix_a(w13_weight[i].view(torch.uint8),
+                                 epilogue_tile_m))
+            gemm1_scales_shuffled.append(
+                shuffle_matrix_sf_a(w13_weight_scale[i].view(torch.uint8),
+                                    epilogue_tile_m))
+
+            gemm2_weights_shuffled.append(
+                shuffle_matrix_a(w2_weight[i].view(torch.uint8),
+                                 epilogue_tile_m))
+            gemm2_scales_shuffled.append(
+                shuffle_matrix_sf_a(w2_weight_scale[i].view(torch.uint8),
+                                    epilogue_tile_m))
+            gemm1_bias_shuffled.append(
+                shuffle_matrix_a(w13_bias[i].reshape(-1, 1), epilogue_tile_m))
+            gemm2_bias_shuffled.append(
+                shuffle_matrix_a(w2_bias[i].reshape(-1, 1), epilogue_tile_m))
 
     w13_weight = torch.stack(gemm1_weights_shuffled)
     w13_weight_scale = torch.stack(gemm1_scales_shuffled).reshape(
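Both branches produce one shuffled tensor per expert with identical shapes, so the torch.stack calls below the hunk are unchanged. The win comes from the loop: all num_experts iterations hit the same torch.Size keys, so the index computation runs once per distinct shape instead of once per expert. A small illustration of that reasoning (hypothetical sizes, randperm standing in for the real index computation):

```python
import torch

cache: dict[torch.Size, torch.Tensor] = {}
num_experts = 32
hits = 0
for _ in range(num_experts):
    w = torch.empty(256, 64)  # every expert shares this shape
    if w.shape in cache:
        hits += 1  # reuse the cached indices
    else:
        cache[w.shape] = torch.randperm(w.shape[0])
print(hits)  # 31 of 32 lookups reuse the cached indices
```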
@@ -356,6 +422,7 @@ def check_accuracy(a, b, atol, rtol, percent):
 @pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None),
                                               (1.702, 1.0, 7.0)])
 @pytest.mark.parametrize("act_type", ['mxfp8', 'bf16'])
+@pytest.mark.parametrize("transpose_optimized", [False, True])
 @pytest.mark.skipif(
     not TRTLLM_GEN_MXFP4_AVAILABLE,
     reason="nvidia gpu and compute capability sm100 is required for this test")
@@ -369,6 +436,7 @@ def test_trtllm_gen_mxfp4_fused_moe(
     beta: float,
     limit: Optional[float],
     act_type: str,
+    transpose_optimized: bool,
 ):
     seed = 42
     torch.manual_seed(seed)
@@ -470,6 +538,7 @@ def test_trtllm_gen_mxfp4_fused_moe(
         act_type,
         alpha=alpha,
         beta=beta,
-        limit=limit)
+        limit=limit,
+        transpose_optimized=transpose_optimized)
     # relatively loose check since the mxfp4 quantization is less accurate
     check_accuracy(ref_result, tg_result, atol=0, rtol=0.3, percent=0.8)

@@ -122,6 +122,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 "MXFP4 MoE is enabled on Blackwell but FlashInfer "
                 "is not available. This may result in degraded performance. "
                 "Please `pip install vllm[flashinfer]` for best results.")
+        self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
 
     def _should_use_marlin(self):
         if envs.VLLM_MXFP4_USE_MARLIN is not None:
@@ -266,7 +267,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         if self.use_marlin:
             prepare_moe_fp4_layer_for_marlin(layer)
         elif should_use_flashinfer_mxfp4():
-            from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a
+            from flashinfer.fp4_quantization import (
+                nvfp4_block_scale_interleave)
+            from flashinfer.fused_moe.core import (
+                _maybe_get_cached_w2_permute_indices)
             layer.gemm1_alpha = Parameter(torch.tensor(
                 [1.702] * self.num_experts, dtype=torch.float32).cuda(),
                                           requires_grad=False)
@@ -343,25 +347,63 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             gemm2_bias_shuffled = []
             epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
             for i in range(self.num_experts):
-                gemm1_weights_mxfp4_shuffled.append(
-                    shuffle_matrix_a(w13_weight[i].view(torch.uint8),
-                                     epilogue_tile_m))
+                # w13 weight shuffling
+                permute_indices = _maybe_get_cached_w2_permute_indices(
+                    self._cache_permute_indices,
+                    w13_weight[i].view(torch.uint8),
+                    epilogue_tile_m,
+                )
+                gemm1_weights_mxfp4_shuffled.append(w13_weight[i].view(
+                    torch.uint8)[permute_indices.to(
+                        w13_weight.device)].contiguous())
+                # w13 scale shuffling
+                permute_sf_indices = _maybe_get_cached_w2_permute_indices(
+                    self._cache_permute_indices,
+                    w13_weight_scale[i].view(torch.uint8),
+                    epilogue_tile_m,
+                    num_elts_per_sf=16,
+                )
                 gemm1_scales_mxfp4_shuffled.append(
-                    shuffle_matrix_sf_a(w13_weight_scale[i].view(torch.uint8),
-                                        epilogue_tile_m))
-                gemm1_bias_shuffled.append(
-                    shuffle_matrix_a(w13_bias[i].clone().reshape(-1, 1),
-                                     epilogue_tile_m))
-
-                gemm2_weights_mxfp4_shuffled.append(
-                    shuffle_matrix_a(w2_weight[i].view(torch.uint8),
-                                     epilogue_tile_m))
+                    nvfp4_block_scale_interleave(w13_weight_scale[i].view(
+                        torch.uint8)[permute_sf_indices.to(
+                            w13_weight_scale.device)].contiguous()))
+                # w13 bias shuffling
+                permute_bias_indices = _maybe_get_cached_w2_permute_indices(
+                    self._cache_permute_indices,
+                    w13_bias[i].clone().reshape(-1, 1),
+                    epilogue_tile_m,
+                )
+                gemm1_bias_shuffled.append(w13_bias[i].clone().reshape(
+                    -1,
+                    1)[permute_bias_indices.to(w13_bias.device)].contiguous())
+                # w2 weight shuffling
+                permute_indices = _maybe_get_cached_w2_permute_indices(
+                    self._cache_permute_indices,
+                    w2_weight[i].view(torch.uint8),
+                    epilogue_tile_m,
+                )
+                gemm2_weights_mxfp4_shuffled.append(w2_weight[i].view(
+                    torch.uint8)[permute_indices.to(
+                        w2_weight.device)].contiguous())
+                # w2 scale shuffling
+                permute_sf_indices = _maybe_get_cached_w2_permute_indices(
+                    self._cache_permute_indices,
+                    w2_weight_scale[i].view(torch.uint8),
+                    epilogue_tile_m,
+                    num_elts_per_sf=16,
+                )
                 gemm2_scales_mxfp4_shuffled.append(
-                    shuffle_matrix_sf_a(w2_weight_scale[i].view(torch.uint8),
-                                        epilogue_tile_m))
-                gemm2_bias_shuffled.append(
-                    shuffle_matrix_a(w2_bias[i].clone().reshape(-1, 1),
-                                     epilogue_tile_m))
+                    nvfp4_block_scale_interleave(w2_weight_scale[i].view(
+                        torch.uint8)[permute_sf_indices.to(
+                            w2_weight_scale.device)].contiguous()))
+                # w2 bias shuffling
+                permute_indices = _maybe_get_cached_w2_permute_indices(
+                    self._cache_permute_indices,
+                    w2_bias[i].clone().reshape(-1, 1),
+                    epilogue_tile_m,
+                )
+                gemm2_bias_shuffled.append(w2_bias[i].clone().reshape(
+                    -1, 1)[permute_indices.to(w2_bias.device)].contiguous())
 
             w13_weight = torch.stack(gemm1_weights_mxfp4_shuffled)
             w13_weight_scale = torch.stack(
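In the Mxfp4MoEMethod hunks above, the dict lives on the quant-method instance (self._cache_permute_indices) rather than being rebuilt per call, so indices computed while loading the first MoE layer are reused by every later expert and layer with the same weight shapes. A hedged sketch of that pattern, with hypothetical class and method names and randperm standing in for the real tile-layout permutation:

```python
import torch


class WeightShuffler:
    """Sketch: instance-level cache shared across experts and layers."""

    def __init__(self) -> None:
        # Same shape-keyed dict as in the diff above.
        self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}

    def shuffle(self, weight: torch.Tensor) -> torch.Tensor:
        # Only the first tensor of a given shape computes the permutation;
        # the real code derives it from the kernel's epilogue tile layout.
        if weight.shape not in self._cache_permute_indices:
            self._cache_permute_indices[weight.shape] = torch.randperm(
                weight.shape[0])
        idx = self._cache_permute_indices[weight.shape]
        return weight[idx.to(weight.device)].contiguous()


shuffler = WeightShuffler()
experts = [torch.randn(128, 64) for _ in range(4)]
shuffled = [shuffler.shuffle(w) for w in experts]  # one permutation total
```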