Flashinfer_CUTLASS_MOE fuses quantization for TP (#27223)

Signed-off-by: Shu Wang. <shuw@nvidia.com>
Author: Shu Wang, 2025-10-31 10:54:29 -07:00, committed by GitHub
parent bc306fe5e9
commit fc16f1c477
4 changed files with 15 additions and 32 deletions
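
In short: when only tensor parallelism is in play (no data parallelism), the activations are no longer quantized in the prepare step; they are handed to FlashInfer's CUTLASS fused-MoE kernel in full precision and the nvfp4 quantization happens inside the fused call. The data-parallel path is unchanged: quantize first so the all-gather moves packed fp4 data and block scales. A conceptual sketch of the two paths (step names are illustrative, not vLLM APIs):

def activation_path(use_dp: bool) -> list[str]:
    # Conceptual sketch only; step names are illustrative, not vLLM APIs.
    if use_dp:
        # DP (unchanged): quantize to nvfp4, all-gather the packed
        # activations and block scales, then run the fused MoE kernel.
        return ["quantize_nvfp4", "all_gather", "flashinfer_cutlass_moe"]
    # TP (this commit): pass full-precision activations straight to the
    # kernel, which performs the nvfp4 quantization internally.
    return ["flashinfer_cutlass_moe (quantization fused)"]

print(activation_path(False))  # TP-only path
print(activation_path(True))   # DP path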

@@ -56,6 +56,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
         ep_size: int = 1,
         tp_rank: int = 0,
         tp_size: int = 1,
+        use_dp: bool = False,
     ):
         super().__init__(quant_config)
         assert quant_config.quant_dtype in ("nvfp4", torch.float8_e4m3fn, None), (
@@ -67,6 +68,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
         self.tp_rank = tp_rank
         self.tp_size = tp_size
         self.out_dtype = out_dtype
+        self.use_dp = use_dp
 
     @property
     def activation_formats(
@@ -117,7 +119,8 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
         """
         workspace1 = (M, K)
         workspace2 = (0,)
-        output_shape = (M, K * 2 if self.quant_dtype == "nvfp4" else K)
+        # For TP, the quantization is fused with fused_moe call.
+        output_shape = (M, K * 2 if self.quant_dtype == "nvfp4" and self.use_dp else K)
         # The workspace is determined by `aq`, since it comes after any
         # potential communication op and is involved in the expert computation.
         return (workspace1, workspace2, output_shape)
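
Why the new `and self.use_dp` condition: in the DP path the activations reaching the experts are already nvfp4-packed (two values per byte), so their last dimension K counts packed pairs and the full-precision output needs K * 2 columns; in the fused TP path the activations arrive unquantized, so the output width is simply K. A minimal standalone sketch of that shape rule (hypothetical helper name):

def moe_output_width(k: int, quant_dtype, use_dp: bool) -> int:
    # Hypothetical helper mirroring the output_shape logic above. With DP,
    # incoming activations are nvfp4-packed (two values per byte), so K
    # counts packed pairs and the unpacked output needs 2 * K columns.
    # With fused TP quantization, K is already the full hidden size.
    if quant_dtype == "nvfp4" and use_dp:
        return k * 2
    return k

# Hidden size 4096: DP sees K=2048 packed -> width 4096; TP sees K=4096.
assert moe_output_width(2048, "nvfp4", use_dp=True) == 4096
assert moe_output_width(4096, "nvfp4", use_dp=False) == 4096
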
@@ -214,6 +217,7 @@ def flashinfer_cutlass_moe_fp4(
         FlashInferExperts(
             out_dtype=hidden_states.dtype,
             quant_config=quant_config,
+            use_dp=False,
         ),
     )

@@ -170,6 +170,8 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin
         self._apply_router_weight_on_input(
             a1, topk_weights, topk_ids, apply_router_weight_on_input
         )
+        if not self.use_dp:
+            return a1, None, None, topk_ids, topk_weights
 
         a1q, a1q_scale = moe_kernel_quantize_input(
             a1,
@@ -179,14 +181,13 @@
             quant_config.block_shape,
             is_fp4_scale_swizzled=not self.use_dp,
         )
 
-        if self.use_dp:
-            topk_weights, topk_ids, a1q, a1q_scale = get_dp_group().all_gatherv(
-                [topk_weights, topk_ids, a1q, a1q_scale],
-                dim=0,
-                sizes=get_local_sizes(),
-            )
-            if quant_config.quant_dtype == "nvfp4":
-                a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
+        topk_weights, topk_ids, a1q, a1q_scale = get_dp_group().all_gatherv(
+            [topk_weights, topk_ids, a1q, a1q_scale],
+            dim=0,
+            sizes=get_local_sizes(),
+        )
+        if quant_config.quant_dtype == "nvfp4":
+            a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
 
         return a1q, a1q_scale, None, topk_ids, topk_weights
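
The net effect on `prepare()`: for pure TP it now returns the unquantized activations immediately (with no scale tensor), deferring quantization to the fused kernel, and the DP branch loses its `if self.use_dp:` guard since it is only reached when DP is active. A condensed, self-contained sketch of the resulting control flow; `quantize_nvfp4`, `all_gather_dp`, and `swizzle_scales` are stand-ins for `moe_kernel_quantize_input`, `get_dp_group().all_gatherv`, and `nvfp4_block_scale_interleave`:

def quantize_nvfp4(a1):
    # Stand-in for moe_kernel_quantize_input: returns (packed data, scales).
    return a1, [1.0] * len(a1)

def all_gather_dp(tensors):
    # Stand-in for get_dp_group().all_gatherv on a single rank.
    return tensors

def swizzle_scales(scales):
    # Stand-in for nvfp4_block_scale_interleave.
    return scales

def prepare(a1, topk_weights, topk_ids, use_dp):
    if not use_dp:
        # TP-only: hand back unquantized activations and no scales; the
        # FlashInfer CUTLASS kernel fuses the nvfp4 quantization.
        return a1, None, None, topk_ids, topk_weights
    # DP: quantize locally, gather across DP ranks, swizzle block scales.
    a1q, a1q_scale = quantize_nvfp4(a1)
    topk_weights, topk_ids, a1q, a1q_scale = all_gather_dp(
        [topk_weights, topk_ids, a1q, a1q_scale]
    )
    a1q_scale = swizzle_scales(a1q_scale)
    return a1q, a1q_scale, None, topk_ids, topk_weights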

@@ -1769,29 +1769,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
                 expert_map=expert_map,
                 apply_router_weight_on_input=apply_router_weight_on_input,
             )
-        elif (
-            self.allow_flashinfer
-            and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
-        ):
-            from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
-                flashinfer_cutlass_moe_fp4,
-            )
-
-            assert self.moe_quant_config is not None
-            return flashinfer_cutlass_moe_fp4(
-                hidden_states=x,
-                w1=layer.w13_weight,
-                w2=layer.w2_weight,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                quant_config=self.moe_quant_config,
-                inplace=False,
-                activation=activation,
-                global_num_experts=global_num_experts,
-                expert_map=expert_map,
-                apply_router_weight_on_input=apply_router_weight_on_input,
-            )
         else:
             # If no modular kernel is provided, use cutlass_moe_fp4 for TP case
             # only (no EP).
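
With `use_dp` now threaded through `select_nvfp4_gemm_impl` (next hunk), the FlashInfer CUTLASS backend can be served by the modular-kernel path built around `FlashInferExperts`, so the dedicated `elif` that called `flashinfer_cutlass_moe_fp4` directly becomes redundant and is removed; `apply()` keeps only the non-modular `cutlass_moe_fp4` fallback for the TP-only (no EP) case. A minimal sketch of that dispatch intent (hypothetical helper and flag; the real logic lives in `ModelOptNvFp4FusedMoE.apply`):

def nvfp4_moe_path(has_modular_kernel: bool) -> str:
    # Hypothetical helper summarizing the post-change dispatch: the
    # FlashInfer CUTLASS case rides the modular kernel (FlashInferExperts
    # selected by select_nvfp4_gemm_impl) instead of a dedicated branch.
    if has_modular_kernel:
        return "modular_kernel (FlashInferExperts, quant fused for TP)"
    # Without a modular kernel, only the TP case is supported.
    return "cutlass_moe_fp4 (TP only, no EP)"

assert "modular_kernel" in nvfp4_moe_path(True)
assert "cutlass_moe_fp4" in nvfp4_moe_path(False)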

@@ -79,6 +79,7 @@ def select_nvfp4_gemm_impl(
         ep_size=moe.moe_parallel_config.ep_size,
         tp_rank=moe.moe_parallel_config.tp_rank,
         tp_size=moe.moe_parallel_config.tp_size,
+        use_dp=moe.moe_parallel_config.dp_size > 1,
     )
 
     # native cutlass experts currently don't support DP; TP case won't call this
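
Finally, the DP flag is derived from the parallel configuration: data parallelism counts as active only when `dp_size > 1`, and that single boolean now decides both whether `prepare()` pre-quantizes and gathers, and whether `FlashInferExperts` expects packed nvfp4 input. A small runnable sketch of the derivation (the dataclass is an illustrative stand-in for `moe.moe_parallel_config`):

from dataclasses import dataclass

@dataclass
class ParallelConfig:
    # Illustrative stand-in; only the fields touched by this diff.
    ep_size: int = 1
    tp_rank: int = 0
    tp_size: int = 1
    dp_size: int = 1

def wants_dp(cfg: ParallelConfig) -> bool:
    # Mirrors `use_dp=moe.moe_parallel_config.dp_size > 1`.
    return cfg.dp_size > 1

assert wants_dp(ParallelConfig(tp_size=8)) is False  # pure TP: quantization stays fused
assert wants_dp(ParallelConfig(dp_size=2)) is True   # DP: pre-quantize + all-gather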