mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-13 09:19:11 +08:00
[MoE-FP8-modelopt] Add FlashInfer alignment padding for intermediate dimensions (#29748)
Signed-off-by: Daniel Afrimi <dafrimi@pool0-00589.cm.cluster> Signed-off-by: dafrimi <dafrimi@nvidia.com> Co-authored-by: Daniel Afrimi <dafrimi@pool0-00589.cm.cluster> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
parent
6ec0d8dbe4
commit
13618626df
@ -81,6 +81,7 @@ from vllm.utils.flashinfer import (
|
|||||||
has_flashinfer,
|
has_flashinfer,
|
||||||
has_flashinfer_moe,
|
has_flashinfer_moe,
|
||||||
)
|
)
|
||||||
|
from vllm.utils.math_utils import round_up
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.model_executor.models.utils import WeightsMapper
|
from vllm.model_executor.models.utils import WeightsMapper
|
||||||
@ -607,6 +608,9 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
|||||||
Only supports pre-quantized checkpoints with FP8 weights and scales.
|
Only supports pre-quantized checkpoints with FP8 weights and scales.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if self.flashinfer_moe_backend is not None:
|
||||||
|
self._maybe_pad_intermediate_for_flashinfer(layer)
|
||||||
|
|
||||||
layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
|
layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
|
||||||
layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
|
layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
|
||||||
|
|
||||||
@ -684,6 +688,50 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
|||||||
rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight)
|
rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight)
|
||||||
register_moe_scaling_factors(layer)
|
register_moe_scaling_factors(layer)
|
||||||
|
|
||||||
|
def _maybe_pad_intermediate_for_flashinfer(self, layer: torch.nn.Module) -> None:
    """Zero-pad the MoE intermediate dimension to FlashInfer's alignment.

    Some FlashInfer FP8 MoE kernels require the per-partition
    intermediate size used in the GEMMs to be a multiple of a small
    alignment (16 here). When it is not (e.g. with certain
    tensor-parallel splits), grow the gate/up (w13) and down (w2)
    projection weights with zeros along that dimension.

    NOTE(review): w13 is padded only at its tail. If w13 stacks gate and
    up halves contiguously, each half would normally need its own
    padding — presumably downstream FlashInfer weight processing
    tolerates this layout; confirm against the kernel's expectations.
    """
    # Layers that never materialized fused-MoE weights have nothing to pad.
    if not hasattr(layer, "w13_weight") or not hasattr(layer, "w2_weight"):
        return

    # w2 (down projection) is (experts, hidden, intermediate); its last
    # axis is the current local (per-partition) intermediate size.
    cur_inter = layer.w2_weight.shape[-1]

    alignment = 16
    target_inter = round_up(cur_inter, alignment)
    if target_inter == cur_inter:
        # Already aligned — leave the weights untouched.
        return

    logger.info(
        "Padding intermediate size from %d to %d for up/down projection weights.",
        cur_inter,
        target_inter,
    )

    def _grow(tensor: torch.Tensor, dim: int, new_size: int) -> torch.Tensor:
        # Zero-filled copy of `tensor` with axis `dim` enlarged to
        # `new_size`; the original data occupies the leading slice.
        shape = list(tensor.shape)
        shape[dim] = new_size
        grown = tensor.new_zeros(shape)
        index = [slice(None)] * tensor.dim()
        index[dim] = slice(0, tensor.shape[dim])
        grown[tuple(index)] = tensor
        return grown

    # Gated activations stack gate and up projections, doubling dim 1 of w13.
    width = 2 if self.moe.is_act_and_mul else 1
    layer.w13_weight.data = _grow(layer.w13_weight.data, 1, width * target_inter)
    layer.w2_weight.data = _grow(layer.w2_weight.data, 2, target_inter)

    # Keep the layer's bookkeeping consistent with the new weight shapes.
    if hasattr(layer, "intermediate_size_per_partition"):
        layer.intermediate_size_per_partition = target_inter
def get_fused_moe_quant_config(
|
def get_fused_moe_quant_config(
|
||||||
self, layer: torch.nn.Module
|
self, layer: torch.nn.Module
|
||||||
) -> FusedMoEQuantConfig | None:
|
) -> FusedMoEQuantConfig | None:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user