Flashinfer_CUTLASS_MOE fuses quantization for TP (#27223)

Signed-off-by: Shu Wang. <shuw@nvidia.com>
This commit is contained in:
Shu Wang 2025-10-31 10:54:29 -07:00 committed by GitHub
parent bc306fe5e9
commit fc16f1c477
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 15 additions and 32 deletions

View File

@ -56,6 +56,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
ep_size: int = 1, ep_size: int = 1,
tp_rank: int = 0, tp_rank: int = 0,
tp_size: int = 1, tp_size: int = 1,
use_dp: bool = False,
): ):
super().__init__(quant_config) super().__init__(quant_config)
assert quant_config.quant_dtype in ("nvfp4", torch.float8_e4m3fn, None), ( assert quant_config.quant_dtype in ("nvfp4", torch.float8_e4m3fn, None), (
@ -67,6 +68,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
self.tp_rank = tp_rank self.tp_rank = tp_rank
self.tp_size = tp_size self.tp_size = tp_size
self.out_dtype = out_dtype self.out_dtype = out_dtype
self.use_dp = use_dp
@property @property
def activation_formats( def activation_formats(
@ -117,7 +119,8 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
""" """
workspace1 = (M, K) workspace1 = (M, K)
workspace2 = (0,) workspace2 = (0,)
output_shape = (M, K * 2 if self.quant_dtype == "nvfp4" else K) # For TP, the quantization is fused with the fused_moe call.
output_shape = (M, K * 2 if self.quant_dtype == "nvfp4" and self.use_dp else K)
# The workspace is determined by `aq`, since it comes after any # The workspace is determined by `aq`, since it comes after any
# potential communication op and is involved in the expert computation. # potential communication op and is involved in the expert computation.
return (workspace1, workspace2, output_shape) return (workspace1, workspace2, output_shape)
@ -214,6 +217,7 @@ def flashinfer_cutlass_moe_fp4(
FlashInferExperts( FlashInferExperts(
out_dtype=hidden_states.dtype, out_dtype=hidden_states.dtype,
quant_config=quant_config, quant_config=quant_config,
use_dp=False,
), ),
) )

View File

@ -170,6 +170,8 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin
self._apply_router_weight_on_input( self._apply_router_weight_on_input(
a1, topk_weights, topk_ids, apply_router_weight_on_input a1, topk_weights, topk_ids, apply_router_weight_on_input
) )
if not self.use_dp:
return a1, None, None, topk_ids, topk_weights
a1q, a1q_scale = moe_kernel_quantize_input( a1q, a1q_scale = moe_kernel_quantize_input(
a1, a1,
@ -179,14 +181,13 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin
quant_config.block_shape, quant_config.block_shape,
is_fp4_scale_swizzled=not self.use_dp, is_fp4_scale_swizzled=not self.use_dp,
) )
if self.use_dp: topk_weights, topk_ids, a1q, a1q_scale = get_dp_group().all_gatherv(
topk_weights, topk_ids, a1q, a1q_scale = get_dp_group().all_gatherv( [topk_weights, topk_ids, a1q, a1q_scale],
[topk_weights, topk_ids, a1q, a1q_scale], dim=0,
dim=0, sizes=get_local_sizes(),
sizes=get_local_sizes(), )
) if quant_config.quant_dtype == "nvfp4":
if quant_config.quant_dtype == "nvfp4": a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
return a1q, a1q_scale, None, topk_ids, topk_weights return a1q, a1q_scale, None, topk_ids, topk_weights

View File

@ -1769,29 +1769,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
expert_map=expert_map, expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
) )
elif (
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
):
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
flashinfer_cutlass_moe_fp4,
)
assert self.moe_quant_config is not None
return flashinfer_cutlass_moe_fp4(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
quant_config=self.moe_quant_config,
inplace=False,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
else: else:
# If no modular kernel is provided, use cutlass_moe_fp4 for TP case # If no modular kernel is provided, use cutlass_moe_fp4 for TP case
# only (no EP). # only (no EP).

View File

@ -79,6 +79,7 @@ def select_nvfp4_gemm_impl(
ep_size=moe.moe_parallel_config.ep_size, ep_size=moe.moe_parallel_config.ep_size,
tp_rank=moe.moe_parallel_config.tp_rank, tp_rank=moe.moe_parallel_config.tp_rank,
tp_size=moe.moe_parallel_config.tp_size, tp_size=moe.moe_parallel_config.tp_size,
use_dp=moe.moe_parallel_config.dp_size > 1,
) )
# native cutlass experts currently don't support DP; TP case won't call this # native cutlass experts currently don't support DP; TP case won't call this