mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-29 06:17:05 +08:00
Flashinfer_CUTLASS_MOE fuses quantization for TP (#27223)
Signed-off-by: Shu Wang. <shuw@nvidia.com>
This commit is contained in:
parent
bc306fe5e9
commit
fc16f1c477
@ -56,6 +56,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
ep_size: int = 1,
|
ep_size: int = 1,
|
||||||
tp_rank: int = 0,
|
tp_rank: int = 0,
|
||||||
tp_size: int = 1,
|
tp_size: int = 1,
|
||||||
|
use_dp: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__(quant_config)
|
super().__init__(quant_config)
|
||||||
assert quant_config.quant_dtype in ("nvfp4", torch.float8_e4m3fn, None), (
|
assert quant_config.quant_dtype in ("nvfp4", torch.float8_e4m3fn, None), (
|
||||||
@ -67,6 +68,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
self.tp_rank = tp_rank
|
self.tp_rank = tp_rank
|
||||||
self.tp_size = tp_size
|
self.tp_size = tp_size
|
||||||
self.out_dtype = out_dtype
|
self.out_dtype = out_dtype
|
||||||
|
self.use_dp = use_dp
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def activation_formats(
|
def activation_formats(
|
||||||
@ -117,7 +119,8 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
"""
|
"""
|
||||||
workspace1 = (M, K)
|
workspace1 = (M, K)
|
||||||
workspace2 = (0,)
|
workspace2 = (0,)
|
||||||
output_shape = (M, K * 2 if self.quant_dtype == "nvfp4" else K)
|
# For TP, the quantization is fused with fused_moe call.
|
||||||
|
output_shape = (M, K * 2 if self.quant_dtype == "nvfp4" and self.use_dp else K)
|
||||||
# The workspace is determined by `aq`, since it comes after any
|
# The workspace is determined by `aq`, since it comes after any
|
||||||
# potential communication op and is involved in the expert computation.
|
# potential communication op and is involved in the expert computation.
|
||||||
return (workspace1, workspace2, output_shape)
|
return (workspace1, workspace2, output_shape)
|
||||||
@ -214,6 +217,7 @@ def flashinfer_cutlass_moe_fp4(
|
|||||||
FlashInferExperts(
|
FlashInferExperts(
|
||||||
out_dtype=hidden_states.dtype,
|
out_dtype=hidden_states.dtype,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
|
use_dp=False,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -170,6 +170,8 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin
|
|||||||
self._apply_router_weight_on_input(
|
self._apply_router_weight_on_input(
|
||||||
a1, topk_weights, topk_ids, apply_router_weight_on_input
|
a1, topk_weights, topk_ids, apply_router_weight_on_input
|
||||||
)
|
)
|
||||||
|
if not self.use_dp:
|
||||||
|
return a1, None, None, topk_ids, topk_weights
|
||||||
|
|
||||||
a1q, a1q_scale = moe_kernel_quantize_input(
|
a1q, a1q_scale = moe_kernel_quantize_input(
|
||||||
a1,
|
a1,
|
||||||
@ -179,14 +181,13 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin
|
|||||||
quant_config.block_shape,
|
quant_config.block_shape,
|
||||||
is_fp4_scale_swizzled=not self.use_dp,
|
is_fp4_scale_swizzled=not self.use_dp,
|
||||||
)
|
)
|
||||||
if self.use_dp:
|
topk_weights, topk_ids, a1q, a1q_scale = get_dp_group().all_gatherv(
|
||||||
topk_weights, topk_ids, a1q, a1q_scale = get_dp_group().all_gatherv(
|
[topk_weights, topk_ids, a1q, a1q_scale],
|
||||||
[topk_weights, topk_ids, a1q, a1q_scale],
|
dim=0,
|
||||||
dim=0,
|
sizes=get_local_sizes(),
|
||||||
sizes=get_local_sizes(),
|
)
|
||||||
)
|
if quant_config.quant_dtype == "nvfp4":
|
||||||
if quant_config.quant_dtype == "nvfp4":
|
a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
|
||||||
a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
|
|
||||||
|
|
||||||
return a1q, a1q_scale, None, topk_ids, topk_weights
|
return a1q, a1q_scale, None, topk_ids, topk_weights
|
||||||
|
|
||||||
|
|||||||
@ -1769,29 +1769,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
|||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
)
|
)
|
||||||
elif (
|
|
||||||
self.allow_flashinfer
|
|
||||||
and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
|
|
||||||
):
|
|
||||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
|
|
||||||
flashinfer_cutlass_moe_fp4,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert self.moe_quant_config is not None
|
|
||||||
|
|
||||||
return flashinfer_cutlass_moe_fp4(
|
|
||||||
hidden_states=x,
|
|
||||||
w1=layer.w13_weight,
|
|
||||||
w2=layer.w2_weight,
|
|
||||||
topk_weights=topk_weights,
|
|
||||||
topk_ids=topk_ids,
|
|
||||||
quant_config=self.moe_quant_config,
|
|
||||||
inplace=False,
|
|
||||||
activation=activation,
|
|
||||||
global_num_experts=global_num_experts,
|
|
||||||
expert_map=expert_map,
|
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
# If no modular kernel is provided, use cutlass_moe_fp4 for TP case
|
# If no modular kernel is provided, use cutlass_moe_fp4 for TP case
|
||||||
# only (no EP).
|
# only (no EP).
|
||||||
|
|||||||
@ -79,6 +79,7 @@ def select_nvfp4_gemm_impl(
|
|||||||
ep_size=moe.moe_parallel_config.ep_size,
|
ep_size=moe.moe_parallel_config.ep_size,
|
||||||
tp_rank=moe.moe_parallel_config.tp_rank,
|
tp_rank=moe.moe_parallel_config.tp_rank,
|
||||||
tp_size=moe.moe_parallel_config.tp_size,
|
tp_size=moe.moe_parallel_config.tp_size,
|
||||||
|
use_dp=moe.moe_parallel_config.dp_size > 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
# native cutlass experts currently don't support DP; TP case won't call this
|
# native cutlass experts currently don't support DP; TP case won't call this
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user