Flashinfer_CUTLASS_MOE fuses quantization for TP (#27223)

Signed-off-by: Shu Wang. <shuw@nvidia.com>

Parent: bc306fe5e9
Commit: fc16f1c477

The change threads a use_dp flag through FlashInferExperts and the all-gather
prepare/finalize path: in the pure tensor-parallel case (no data parallelism),
activations are no longer quantized in prepare(); quantization is instead fused
into the FlashInfer CUTLASS fused_moe call. The now redundant direct call site
in ModelOptNvFp4FusedMoE is removed.
@@ -56,6 +56,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
         ep_size: int = 1,
         tp_rank: int = 0,
         tp_size: int = 1,
+        use_dp: bool = False,
     ):
         super().__init__(quant_config)
         assert quant_config.quant_dtype in ("nvfp4", torch.float8_e4m3fn, None), (
@@ -67,6 +68,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
         self.tp_rank = tp_rank
         self.tp_size = tp_size
         self.out_dtype = out_dtype
+        self.use_dp = use_dp

     @property
     def activation_formats(
@@ -117,7 +119,8 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
         """
         workspace1 = (M, K)
         workspace2 = (0,)
-        output_shape = (M, K * 2 if self.quant_dtype == "nvfp4" else K)
+        # For TP, the quantization is fused with fused_moe call.
+        output_shape = (M, K * 2 if self.quant_dtype == "nvfp4" and self.use_dp else K)
         # The workspace is determined by `aq`, since it comes after any
         # potential communication op and is involved in the expert computation.
         return (workspace1, workspace2, output_shape)
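Why the output is K * 2 only when both nvfp4 and use_dp hold: a minimal sketch,
not code from this commit. On the DP path the activations reaching the experts
are already nvfp4-quantized, so their trailing dimension K counts packed bytes
(two 4-bit values per byte) and the unquantized output must be twice as wide.
On the TP path quantization is deferred into the fused_moe call, so K is
already the full hidden size. The helper name below is hypothetical.

# Hypothetical helper illustrating the output-shape rule above.
def moe_output_shape(M: int, K: int, quant_dtype, use_dp: bool) -> tuple[int, int]:
    if quant_dtype == "nvfp4" and use_dp:
        # DP path: input was quantized before the all-gather, so K is the
        # packed width (2 fp4 values per byte); the output is K * 2 wide.
        return (M, K * 2)
    # TP path: quantization is fused into the fused_moe call, so K here is
    # already the unquantized hidden size.
    return (M, K)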
@@ -214,6 +217,7 @@ def flashinfer_cutlass_moe_fp4(
         FlashInferExperts(
             out_dtype=hidden_states.dtype,
             quant_config=quant_config,
+            use_dp=False,
         ),
     )
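For reference, a call-shape sketch of this entry point. The keyword arguments
mirror the direct call visible in the removed branch further down; the tensors
are placeholders, and this is illustrative rather than executable without a
CUDA + FlashInfer environment and nvfp4-quantized weights.

# Call-shape sketch (placeholders; mirrors the removed branch below).
out = flashinfer_cutlass_moe_fp4(
    hidden_states=x,                    # (M, K) activations
    w1=layer.w13_weight,                # fused gate/up projection, nvfp4
    w2=layer.w2_weight,                 # down projection, nvfp4
    topk_weights=topk_weights,
    topk_ids=topk_ids,
    quant_config=moe_quant_config,
    inplace=False,
    activation=activation,
    global_num_experts=global_num_experts,
    expert_map=expert_map,
    apply_router_weight_on_input=apply_router_weight_on_input,
)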
@@ -170,6 +170,8 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFinalize):
         self._apply_router_weight_on_input(
             a1, topk_weights, topk_ids, apply_router_weight_on_input
         )
+        if not self.use_dp:
+            return a1, None, None, topk_ids, topk_weights

         a1q, a1q_scale = moe_kernel_quantize_input(
             a1,
@@ -179,14 +181,13 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFinalize):
             quant_config.block_shape,
             is_fp4_scale_swizzled=not self.use_dp,
         )
-        if self.use_dp:
-            topk_weights, topk_ids, a1q, a1q_scale = get_dp_group().all_gatherv(
-                [topk_weights, topk_ids, a1q, a1q_scale],
-                dim=0,
-                sizes=get_local_sizes(),
-            )
-            if quant_config.quant_dtype == "nvfp4":
-                a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
+        topk_weights, topk_ids, a1q, a1q_scale = get_dp_group().all_gatherv(
+            [topk_weights, topk_ids, a1q, a1q_scale],
+            dim=0,
+            sizes=get_local_sizes(),
+        )
+        if quant_config.quant_dtype == "nvfp4":
+            a1q_scale = nvfp4_block_scale_interleave(a1q_scale)

         return a1q, a1q_scale, None, topk_ids, topk_weights
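Taken together, the two hunks above give prepare() a clean TP/DP split: the
early return makes the if self.use_dp guard redundant, so the all-gather block
is dedented. A condensed, self-contained sketch of the resulting control flow;
the helpers are trivial stand-ins for moe_kernel_quantize_input,
get_dp_group().all_gatherv, and nvfp4_block_scale_interleave:

# Condensed sketch of prepare() after this commit; stand-ins, not real kernels.
def quantize_input(a1, quant_config):
    return a1, "scales"  # stand-in for moe_kernel_quantize_input

def dp_all_gather(tensors):
    return tensors       # stand-in for get_dp_group().all_gatherv

def interleave_scales(scales):
    return scales        # stand-in for nvfp4_block_scale_interleave

def prepare(a1, topk_weights, topk_ids, quant_config, use_dp):
    if not use_dp:
        # TP-only: return unquantized activations; FlashInferExperts
        # fuses quantization into the fused_moe call itself.
        return a1, None, None, topk_ids, topk_weights
    # DP: quantize locally first so the all-gather moves packed nvfp4
    # payloads and block scales instead of full-precision activations.
    a1q, a1q_scale = quantize_input(a1, quant_config)
    topk_weights, topk_ids, a1q, a1q_scale = dp_all_gather(
        [topk_weights, topk_ids, a1q, a1q_scale]
    )
    if quant_config.quant_dtype == "nvfp4":
        # Scales arrive unswizzled; interleave them once, globally.
        a1q_scale = interleave_scales(a1q_scale)
    return a1q, a1q_scale, None, topk_ids, topk_weights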
@@ -1769,29 +1769,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
             expert_map=expert_map,
             apply_router_weight_on_input=apply_router_weight_on_input,
         )
-        elif (
-            self.allow_flashinfer
-            and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
-        ):
-            from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
-                flashinfer_cutlass_moe_fp4,
-            )
-
-            assert self.moe_quant_config is not None
-
-            return flashinfer_cutlass_moe_fp4(
-                hidden_states=x,
-                w1=layer.w13_weight,
-                w2=layer.w2_weight,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                quant_config=self.moe_quant_config,
-                inplace=False,
-                activation=activation,
-                global_num_experts=global_num_experts,
-                expert_map=expert_map,
-                apply_router_weight_on_input=apply_router_weight_on_input,
-            )
         else:
             # If no modular kernel is provided, use cutlass_moe_fp4 for TP case
             # only (no EP).
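The removed branch was a direct, non-modular call into
flashinfer_cutlass_moe_fp4. Since select_nvfp4_gemm_impl (next hunk) builds
FlashInferExperts for the modular-kernel path with the right use_dp, the TP
case is presumably covered there, and dispatch reduces to roughly the
following shape (a simplified sketch, not the literal source):

# Simplified dispatch sketch after the removal; condition name abridged.
def apply_dispatch(modular_kernel_provided: bool) -> str:
    if modular_kernel_provided:
        # Modular-kernel path: now also serves the FlashInfer CUTLASS
        # backend, which made the removed direct call redundant.
        return "modular_kernel"
    # Per the comment above: with no modular kernel, fall back to
    # cutlass_moe_fp4 for the TP-only case (no EP).
    return "cutlass_moe_fp4"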
@@ -79,6 +79,7 @@ def select_nvfp4_gemm_impl(
         ep_size=moe.moe_parallel_config.ep_size,
         tp_rank=moe.moe_parallel_config.tp_rank,
         tp_size=moe.moe_parallel_config.tp_size,
+        use_dp=moe.moe_parallel_config.dp_size > 1,
     )

     # native cutlass experts currently don't support DP; TP case won't call this
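The flag is derived purely from the parallel configuration. A minimal runnable
sketch of the rule; the config class is a stand-in for moe.moe_parallel_config
with field names taken from the hunk above:

from dataclasses import dataclass

# Stand-in for moe.moe_parallel_config; field names match the diff above.
@dataclass
class MoEParallelConfig:
    ep_size: int = 1
    tp_rank: int = 0
    tp_size: int = 1
    dp_size: int = 1

def wants_dp(cfg: MoEParallelConfig) -> bool:
    # Quantize-then-all-gather only applies with more than one data-parallel
    # rank; otherwise FlashInferExperts runs the TP path with quantization
    # folded into the fused_moe call.
    return cfg.dp_size > 1

assert not wants_dp(MoEParallelConfig())         # TP-only: fused path
assert wants_dp(MoEParallelConfig(dp_size=2))    # DP: all-gather path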