mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-09 05:58:42 +08:00
[Bugfix] Fix accuracy issue when using flashinfer cutlass moe, TP=1 and modelopt. (#23125)
Signed-off-by: Bill Nell <bnell@redhat.com> Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
parent
5b5f350d67
commit
b94faf9d50
@ -7,6 +7,8 @@ import torch
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
|
||||
FlashInferCutlassMoEPrepareAndFinalize)
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceNoOP)
|
||||
from vllm.utils.flashinfer import (flashinfer_cutlass_fused_moe,
|
||||
@ -181,3 +183,50 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
ep_rank=self.ep_rank,
|
||||
output=output,
|
||||
)
|
||||
|
||||
|
||||
def flashinfer_cutlass_moe_fp4(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    g1_alphas: torch.Tensor,
    g2_alphas: torch.Tensor,
    a1_gscale: torch.Tensor,
    a2_gscale: torch.Tensor,
    inplace: bool = False,
    activation: str = "silu",
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
    """Run a fused-MoE pass through a freshly built FlashInfer CUTLASS kernel.

    Each call assembles a new ``mk.FusedMoEModularKernel`` out of two stages
    and immediately invokes it:

    * a ``FlashInferCutlassMoEPrepareAndFinalize`` stage created with
      ``use_dp=False`` and ``a1_gscale``;
    * a ``FlashInferExperts`` stage created with ``g1_alphas``,
      ``g2_alphas``, ``a1_gscale``, ``a2_gscale``, ``quant_dtype="nvfp4"``
      and an output dtype equal to ``hidden_states.dtype``.

    Args:
        hidden_states: Input activations; their dtype is used as the
            experts' ``out_dtype``.
        w1: First expert weight tensor, forwarded to the kernel.
        w2: Second expert weight tensor, forwarded to the kernel.
        topk_weights: Router weights, forwarded to the kernel.
        topk_ids: Selected expert ids, forwarded to the kernel.
        w1_scale: Scale tensor passed alongside ``w1``.
        w2_scale: Scale tensor passed alongside ``w2``.
        g1_alphas: Passed to ``FlashInferExperts``.
        g2_alphas: Passed to ``FlashInferExperts``.
        a1_gscale: Passed to both the prepare/finalize stage and the
            experts stage.
        a2_gscale: Passed to ``FlashInferExperts``.
        inplace: Forwarded to the kernel unchanged.
        activation: Activation name, forwarded to the kernel.
        global_num_experts: Forwarded to the kernel unchanged.
        expert_map: Optional tensor, forwarded to the kernel unchanged.
        apply_router_weight_on_input: Forwarded to the kernel unchanged.

    Returns:
        Whatever the assembled modular kernel returns for these inputs.
    """
    # NOTE(review): the scale/alpha arguments presumably carry the NVFP4
    # quantization parameters -- confirm against the calling quant method.
    prepare_finalize_stage = FlashInferCutlassMoEPrepareAndFinalize(
        use_dp=False,
        a1_gscale=a1_gscale,
    )
    experts_stage = FlashInferExperts(
        g1_alphas=g1_alphas,
        g2_alphas=g2_alphas,
        a1_gscale=a1_gscale,
        a2_gscale=a2_gscale,
        out_dtype=hidden_states.dtype,
        quant_dtype="nvfp4",
    )
    modular_kernel = mk.FusedMoEModularKernel(
        prepare_finalize_stage,
        experts_stage,
    )

    result = modular_kernel(
        hidden_states=hidden_states,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=inplace,
        activation=activation,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        apply_router_weight_on_input=apply_router_weight_on_input,
    )
    return result
|
||||
|
||||
@ -198,6 +198,8 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
else:
|
||||
return None
|
||||
|
||||
# Note: init_prepare_finalize should only be called by
|
||||
# prepare_communication_buffer_for_model.
|
||||
def init_prepare_finalize(self):
|
||||
assert self.moe is not None
|
||||
prepare_finalize = self.maybe_make_prepare_finalize(self.moe)
|
||||
|
||||
@ -388,6 +388,33 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
elif self.allow_flashinfer:
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
|
||||
flashinfer_cutlass_moe_fp4)
|
||||
|
||||
assert is_valid_flashinfer_cutlass_fused_moe(
|
||||
x, layer.w13_weight, layer.w2_weight), (
|
||||
"Flashinfer CUTLASS Fused MoE not applicable!")
|
||||
|
||||
return flashinfer_cutlass_moe_fp4(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=False, # TODO(shuw): fix later, now output is high prec
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=layer.w13_blockscale_swizzled,
|
||||
w2_scale=layer.w2_blockscale_swizzled,
|
||||
g1_alphas=layer.g1_alphas,
|
||||
g2_alphas=layer.g2_alphas,
|
||||
a1_gscale=layer.w13_input_scale_quant,
|
||||
a2_gscale=layer.w2_input_scale_quant,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
assert expert_map is None, ("Expert Parallelism / expert_map "
|
||||
"is currently not supported for "
|
||||
"CompressedTensorsW4A4MoeMethod.")
|
||||
|
||||
@ -966,22 +966,21 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
f"Unknown flashinfer moe backend: {flashinfer_moe_backend}"
|
||||
f" expected one of {allowed_backends}")
|
||||
|
||||
self.fused_experts: Optional[
|
||||
mk.FusedMoEModularKernel] = None # type: ignore[assignment]
|
||||
|
||||
def maybe_make_prepare_finalize(
|
||||
self,
|
||||
moe: FusedMoEConfig,
|
||||
) -> Optional[mk.FusedMoEPrepareAndFinalize]:
|
||||
if not self.allow_flashinfer:
|
||||
return super().maybe_make_prepare_finalize(moe)
|
||||
if (self.allow_flashinfer and self.flashinfer_moe_backend
|
||||
== FlashinferMoeBackend.CUTLASS):
|
||||
prepare_finalize = (
|
||||
build_flashinfer_fp4_cutlass_moe_prepare_finalize(
|
||||
moe,
|
||||
a1_gscale=self.layer.w13_input_scale_quant,
|
||||
))
|
||||
logger.debug_once("%s", prepare_finalize.__class__.__name__)
|
||||
return prepare_finalize
|
||||
|
||||
prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(
|
||||
moe,
|
||||
a1_gscale=self.layer.w13_input_scale_quant,
|
||||
)
|
||||
logger.debug_once("%s", prepare_finalize.__class__.__name__)
|
||||
return prepare_finalize
|
||||
return super().maybe_make_prepare_finalize(moe)
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
@ -1409,7 +1408,52 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map)
|
||||
|
||||
if self.fused_experts is None:
|
||||
if self.fused_experts is not None:
|
||||
assert self.allow_flashinfer and \
|
||||
self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
|
||||
|
||||
assert is_valid_flashinfer_cutlass_fused_moe(
|
||||
x, layer.w13_weight, layer.w2_weight), (
|
||||
"Flashinfer CUTLASS Fused MoE not applicable!")
|
||||
|
||||
out = self.fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=False, # TODO(shuw): fix later, now output is high prec
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=layer.w13_blockscale_swizzled,
|
||||
w2_scale=layer.w2_blockscale_swizzled,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
elif (self.allow_flashinfer
|
||||
and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS):
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
|
||||
flashinfer_cutlass_moe_fp4)
|
||||
|
||||
out = flashinfer_cutlass_moe_fp4(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
w1_scale=layer.w13_blockscale_swizzled,
|
||||
w2_scale=layer.w2_blockscale_swizzled,
|
||||
g1_alphas=layer.g1_alphas,
|
||||
g2_alphas=layer.g2_alphas,
|
||||
a1_gscale=layer.w13_input_scale_quant,
|
||||
a2_gscale=layer.w2_input_scale_quant,
|
||||
inplace=False, # TODO(shuw): fix later, now output is high prec
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
else:
|
||||
# If no modular kernel is provided, use cutlass_moe_fp4 for TP case
|
||||
# only (no EP).
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
||||
@ -1432,27 +1476,5 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
e=layer.w13_weight.shape[0],
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input)
|
||||
else:
|
||||
assert self.allow_flashinfer and \
|
||||
self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
|
||||
|
||||
assert is_valid_flashinfer_cutlass_fused_moe(
|
||||
x, layer.w13_weight, layer.w2_weight), (
|
||||
"Flashinfer CUTLASS Fused MoE not applicable!")
|
||||
|
||||
out = self.fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=False, # TODO(shuw): fix later, now output is high prec
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=layer.w13_blockscale_swizzled,
|
||||
w2_scale=layer.w2_blockscale_swizzled,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
return out
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user