mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-25 16:40:53 +08:00
[MoE-FP8-modelopt] Add FlashInfer alignment padding for intermediate dimensions (#29748)
Signed-off-by: Daniel Afrimi <dafrimi@pool0-00589.cm.cluster> Signed-off-by: dafrimi <dafrimi@nvidia.com> Co-authored-by: Daniel Afrimi <dafrimi@pool0-00589.cm.cluster> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
parent
6ec0d8dbe4
commit
13618626df
@ -81,6 +81,7 @@ from vllm.utils.flashinfer import (
|
||||
has_flashinfer,
|
||||
has_flashinfer_moe,
|
||||
)
|
||||
from vllm.utils.math_utils import round_up
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.models.utils import WeightsMapper
|
||||
@ -607,6 +608,9 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
Only supports pre-quantized checkpoints with FP8 weights and scales.
|
||||
"""
|
||||
|
||||
if self.flashinfer_moe_backend is not None:
|
||||
self._maybe_pad_intermediate_for_flashinfer(layer)
|
||||
|
||||
layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
|
||||
layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
|
||||
|
||||
@ -684,6 +688,50 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight)
|
||||
register_moe_scaling_factors(layer)
|
||||
|
||||
def _maybe_pad_intermediate_for_flashinfer(self, layer: torch.nn.Module) -> None:
|
||||
"""Pad intermediate size so FlashInfer kernels' alignment constraints hold.
|
||||
|
||||
Some FlashInfer FP8 MoE kernels require the (gated) intermediate size
|
||||
used for GEMM to be divisible by a small alignment value. When this is
|
||||
not satisfied (e.g. with certain tensor-parallel sizes), we pad the
|
||||
gate/up and down projection weights along the intermediate dim.
|
||||
"""
|
||||
if not hasattr(layer, "w13_weight") or not hasattr(layer, "w2_weight"):
|
||||
return
|
||||
|
||||
# Current local intermediate size (per partition) is the K dimension of
|
||||
# the down projection.
|
||||
num_experts, hidden_size, intermediate = layer.w2_weight.shape
|
||||
|
||||
min_alignment = 16
|
||||
padded_intermediate = round_up(intermediate, min_alignment)
|
||||
|
||||
if padded_intermediate == intermediate:
|
||||
return
|
||||
|
||||
logger.info(
|
||||
"Padding intermediate size from %d to %d for up/down projection weights.",
|
||||
intermediate,
|
||||
padded_intermediate,
|
||||
)
|
||||
|
||||
up_mult = 2 if self.moe.is_act_and_mul else 1
|
||||
padded_gate_up_dim = up_mult * padded_intermediate
|
||||
|
||||
# Pad w13 and w12 along its intermediate dimension.
|
||||
w13 = layer.w13_weight.data
|
||||
padded_w13 = w13.new_zeros((num_experts, padded_gate_up_dim, hidden_size))
|
||||
padded_w13[:, : w13.shape[1], :] = w13
|
||||
layer.w13_weight.data = padded_w13
|
||||
|
||||
w2 = layer.w2_weight.data
|
||||
padded_w2 = w2.new_zeros((num_experts, hidden_size, padded_intermediate))
|
||||
padded_w2[:, :, :intermediate] = w2
|
||||
layer.w2_weight.data = padded_w2
|
||||
|
||||
if hasattr(layer, "intermediate_size_per_partition"):
|
||||
layer.intermediate_size_per_partition = padded_intermediate
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user