[MoE-FP8-modelopt] Add FlashInfer alignment padding for intermediate dimensions (#29748)

Signed-off-by: Daniel Afrimi <dafrimi@pool0-00589.cm.cluster>
Signed-off-by: dafrimi <dafrimi@nvidia.com>
Co-authored-by: Daniel Afrimi <dafrimi@pool0-00589.cm.cluster>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
danielafrimi 2025-12-12 22:42:32 +02:00 committed by GitHub
parent 6ec0d8dbe4
commit 13618626df


@@ -81,6 +81,7 @@ from vllm.utils.flashinfer import (
    has_flashinfer,
    has_flashinfer_moe,
)
from vllm.utils.math_utils import round_up

if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper
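The newly imported round_up helper is what lifts the per-partition intermediate size to the next multiple of the kernel alignment in the method added below. A minimal sketch of the assumed ceil-to-multiple semantics (not the actual vLLM implementation, shown only for illustration):

```python
# Illustrative only: assumed behavior of round_up(x, multiple), i.e. the
# smallest multiple of `multiple` that is >= x.
def round_up(x: int, multiple: int) -> int:
    return ((x + multiple - 1) // multiple) * multiple

assert round_up(1400, 16) == 1408  # needs padding
assert round_up(1408, 16) == 1408  # already aligned, no-op
```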
@@ -607,6 +608,9 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
        Only supports pre-quantized checkpoints with FP8 weights and scales.
        """
        if self.flashinfer_moe_backend is not None:
            self._maybe_pad_intermediate_for_flashinfer(layer)

        layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
        layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
@@ -684,6 +688,50 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
            rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight)
            register_moe_scaling_factors(layer)

    def _maybe_pad_intermediate_for_flashinfer(self, layer: torch.nn.Module) -> None:
        """Pad intermediate size so FlashInfer kernels' alignment constraints hold.

        Some FlashInfer FP8 MoE kernels require the (gated) intermediate size
        used for GEMM to be divisible by a small alignment value. When this is
        not satisfied (e.g. with certain tensor-parallel sizes), we pad the
        gate/up and down projection weights along the intermediate dim.
        """
        if not hasattr(layer, "w13_weight") or not hasattr(layer, "w2_weight"):
            return

        # Current local intermediate size (per partition) is the K dimension of
        # the down projection.
        num_experts, hidden_size, intermediate = layer.w2_weight.shape
        min_alignment = 16
        padded_intermediate = round_up(intermediate, min_alignment)
        if padded_intermediate == intermediate:
            return

        logger.info(
            "Padding intermediate size from %d to %d for up/down projection weights.",
            intermediate,
            padded_intermediate,
        )

        up_mult = 2 if self.moe.is_act_and_mul else 1
        padded_gate_up_dim = up_mult * padded_intermediate

        # Pad w13 and w2 along their intermediate dimensions.
        w13 = layer.w13_weight.data
        padded_w13 = w13.new_zeros((num_experts, padded_gate_up_dim, hidden_size))
        padded_w13[:, : w13.shape[1], :] = w13
        layer.w13_weight.data = padded_w13

        w2 = layer.w2_weight.data
        padded_w2 = w2.new_zeros((num_experts, hidden_size, padded_intermediate))
        padded_w2[:, :, :intermediate] = w2
        layer.w2_weight.data = padded_w2

        if hasattr(layer, "intermediate_size_per_partition"):
            layer.intermediate_size_per_partition = padded_intermediate
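For illustration, a standalone sketch of the padding arithmetic performed above, using made-up shapes (2 experts, hidden size 64, per-partition intermediate size 1400) rather than values from any real checkpoint; it mirrors the zero-pad-and-copy pattern of the new method without going through a layer object:

```python
import torch

# Hypothetical shapes, chosen only for the example.
num_experts, hidden_size, intermediate = 2, 64, 1400
min_alignment = 16
padded = ((intermediate + min_alignment - 1) // min_alignment) * min_alignment  # 1408

# Gated MoE case (is_act_and_mul): w13 stacks gate and up projections,
# so its intermediate dimension is 2 * intermediate.
w13 = torch.randn(num_experts, 2 * intermediate, hidden_size)
w2 = torch.randn(num_experts, hidden_size, intermediate)

# Zero-filled buffers at the aligned sizes; original weights copied in.
padded_w13 = w13.new_zeros((num_experts, 2 * padded, hidden_size))
padded_w13[:, : w13.shape[1], :] = w13
padded_w2 = w2.new_zeros((num_experts, hidden_size, padded))
padded_w2[:, :, :intermediate] = w2

assert padded_w13.shape == (2, 2816, 64)
assert padded_w2.shape == (2, 64, 1408)
```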
    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None: