[Bugfix] Fix accuracy issue when using flashinfer cutlass moe, TP=1 and modelopt. (#23125)

Signed-off-by: Bill Nell <bnell@redhat.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
bnellnm authored 2025-08-19 14:00:51 -04:00, committed by GitHub
parent 5b5f350d67
commit b94faf9d50
4 changed files with 134 additions and 34 deletions

vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py

@@ -7,6 +7,8 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (  # noqa: E501
    FlashInferCutlassMoEPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
    TopKWeightAndReduceNoOP)
from vllm.utils.flashinfer import (flashinfer_cutlass_fused_moe,
@@ -181,3 +183,50 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
            ep_rank=self.ep_rank,
            output=output,
        )


def flashinfer_cutlass_moe_fp4(
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        w1_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        g1_alphas: torch.Tensor,
        g2_alphas: torch.Tensor,
        a1_gscale: torch.Tensor,
        a2_gscale: torch.Tensor,
        inplace: bool = False,
        activation: str = "silu",
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
    fused_experts = mk.FusedMoEModularKernel(
        FlashInferCutlassMoEPrepareAndFinalize(use_dp=False,
                                               a1_gscale=a1_gscale),
        FlashInferExperts(
            g1_alphas=g1_alphas,
            g2_alphas=g2_alphas,
            a1_gscale=a1_gscale,
            a2_gscale=a2_gscale,
            out_dtype=hidden_states.dtype,
            quant_dtype="nvfp4",
        ))
    return fused_experts(
        hidden_states=hidden_states,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=inplace,
        activation=activation,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        apply_router_weight_on_input=apply_router_weight_on_input,
    )
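For context, a minimal sketch of how this new helper is meant to be called from a quantization method's apply() path. The run_nvfp4_moe_layer wrapper is hypothetical; its arguments mirror the call sites added later in this commit, and the layer attributes must already hold NVFP4-packed weights with swizzled block scales.

from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
    flashinfer_cutlass_moe_fp4)


def run_nvfp4_moe_layer(layer, x, topk_weights, topk_ids):
    # Hypothetical wrapper: mirrors the call sites added to
    # compressed_tensors_moe.py and modelopt.py in this commit.
    return flashinfer_cutlass_moe_fp4(
        hidden_states=x,
        w1=layer.w13_weight,
        w2=layer.w2_weight,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        w1_scale=layer.w13_blockscale_swizzled,
        w2_scale=layer.w2_blockscale_swizzled,
        g1_alphas=layer.g1_alphas,
        g2_alphas=layer.g2_alphas,
        a1_gscale=layer.w13_input_scale_quant,
        a2_gscale=layer.w2_input_scale_quant,
        inplace=False,
        activation="silu",
        global_num_experts=layer.w13_weight.shape[0],
        expert_map=None,
        apply_router_weight_on_input=False,
    )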

vllm/model_executor/layers/fused_moe/layer.py

@@ -198,6 +198,8 @@ class FusedMoEMethodBase(QuantizeMethodBase):
        else:
            return None

    # Note: init_prepare_finalize should only be called by
    # prepare_communication_buffer_for_model.
    def init_prepare_finalize(self):
        assert self.moe is not None
        prepare_finalize = self.maybe_make_prepare_finalize(self.moe)
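A rough illustration of the contract the new comment describes: the prepare/finalize objects are built once, from prepare_communication_buffer_for_model, rather than lazily from forward paths. The helper below is hypothetical and only sketches the intended call order.

import torch.nn as nn


def init_moe_prepare_finalize_once(model: nn.Module) -> None:
    # Hypothetical sketch, not vllm's actual entry point: once distributed
    # communication buffers exist, walk the model a single time and let each
    # fused-MoE quant method build its prepare/finalize object.
    for module in model.modules():
        quant_method = getattr(module, "quant_method", None)
        if quant_method is not None and hasattr(quant_method,
                                                "init_prepare_finalize"):
            quant_method.init_prepare_finalize()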

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py

@@ -388,6 +388,33 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        elif self.allow_flashinfer:
            from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
                flashinfer_cutlass_moe_fp4)
            assert is_valid_flashinfer_cutlass_fused_moe(
                x, layer.w13_weight, layer.w2_weight), (
                    "Flashinfer CUTLASS Fused MoE not applicable!")
            return flashinfer_cutlass_moe_fp4(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=False,  # TODO(shuw): fix later, now output is high prec
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                w1_scale=layer.w13_blockscale_swizzled,
                w2_scale=layer.w2_blockscale_swizzled,
                g1_alphas=layer.g1_alphas,
                g2_alphas=layer.g2_alphas,
                a1_gscale=layer.w13_input_scale_quant,
                a2_gscale=layer.w2_input_scale_quant,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )

        assert expert_map is None, ("Expert Parallelism / expert_map "
                                    "is currently not supported for "
                                    "CompressedTensorsW4A4MoeMethod.")

vllm/model_executor/layers/quantization/modelopt.py

@@ -966,22 +966,21 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
                f"Unknown flashinfer moe backend: {flashinfer_moe_backend}"
                f" expected one of {allowed_backends}")

        self.fused_experts: Optional[
            mk.FusedMoEModularKernel] = None  # type: ignore[assignment]

    def maybe_make_prepare_finalize(
        self,
        moe: FusedMoEConfig,
    ) -> Optional[mk.FusedMoEPrepareAndFinalize]:
        if not self.allow_flashinfer:
            return super().maybe_make_prepare_finalize(moe)
        if (self.allow_flashinfer and self.flashinfer_moe_backend
                == FlashinferMoeBackend.CUTLASS):
            prepare_finalize = (
                build_flashinfer_fp4_cutlass_moe_prepare_finalize(
                    moe,
                    a1_gscale=self.layer.w13_input_scale_quant,
                ))
            logger.debug_once("%s", prepare_finalize.__class__.__name__)
            return prepare_finalize
        prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(
            moe,
            a1_gscale=self.layer.w13_input_scale_quant,
        )
        logger.debug_once("%s", prepare_finalize.__class__.__name__)
        return prepare_finalize
        return super().maybe_make_prepare_finalize(moe)

    def select_gemm_impl(
        self,
@@ -1409,7 +1408,52 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
            global_num_experts=global_num_experts,
            expert_map=expert_map)

        if self.fused_experts is None:
        if self.fused_experts is not None:
            assert self.allow_flashinfer and \
                self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
            assert is_valid_flashinfer_cutlass_fused_moe(
                x, layer.w13_weight, layer.w2_weight), (
                    "Flashinfer CUTLASS Fused MoE not applicable!")
            out = self.fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=False,  # TODO(shuw): fix later, now output is high prec
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                w1_scale=layer.w13_blockscale_swizzled,
                w2_scale=layer.w2_blockscale_swizzled,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        elif (self.allow_flashinfer
              and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS):
            from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
                flashinfer_cutlass_moe_fp4)
            out = flashinfer_cutlass_moe_fp4(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                w1_scale=layer.w13_blockscale_swizzled,
                w2_scale=layer.w2_blockscale_swizzled,
                g1_alphas=layer.g1_alphas,
                g2_alphas=layer.g2_alphas,
                a1_gscale=layer.w13_input_scale_quant,
                a2_gscale=layer.w2_input_scale_quant,
                inplace=False,  # TODO(shuw): fix later, now output is high prec
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        else:
            # If no modular kernel is provided, use cutlass_moe_fp4 for TP case
            # only (no EP).
            from vllm.model_executor.layers.fused_moe.cutlass_moe import (
@@ -1432,27 +1476,5 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
                e=layer.w13_weight.shape[0],
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input)
        else:
            assert self.allow_flashinfer and \
                self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
            assert is_valid_flashinfer_cutlass_fused_moe(
                x, layer.w13_weight, layer.w2_weight), (
                    "Flashinfer CUTLASS Fused MoE not applicable!")
            out = self.fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=False,  # TODO(shuw): fix later, now output is high prec
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                w1_scale=layer.w13_blockscale_swizzled,
                w2_scale=layer.w2_blockscale_swizzled,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )

        return out
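To summarize the modelopt.py change (a paraphrase of the hunks above, not literal source): apply() now prefers the modular kernel when one was built, otherwise routes FlashInfer CUTLASS configurations through the new flashinfer_cutlass_moe_fp4 helper, and only then falls back to cutlass_moe_fp4. Previously a TP=1 FlashInfer CUTLASS setup could fall through to the generic CUTLASS path, which appears to be the accuracy issue named in the commit title.

# Pseudocode summary of the new dispatch order in ModelOptNvFp4FusedMoE.apply():
#
#   if self.fused_experts is not None:
#       out = self.fused_experts(...)            # modular kernel (EP/DP) path
#   elif allow_flashinfer and backend == CUTLASS:
#       out = flashinfer_cutlass_moe_fp4(...)    # new TP=1 FlashInfer path
#   else:
#       out = cutlass_moe_fp4(...)               # plain CUTLASS TP fallback (no EP)
#   return out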