From f72a817bdf6bd04b223a9da3af6c4ad1a676a98e Mon Sep 17 00:00:00 2001
From: Shu Wang
Date: Sun, 30 Nov 2025 18:05:32 -0600
Subject: [PATCH] [MoE] CuteDSL MoE with Nvfp4 DeepEP dispatch (#27141)

Signed-off-by: Shu Wang
Signed-off-by: Shu Wang.
Signed-off-by: Michael Goin
Co-authored-by: root
Co-authored-by: Michael Goin
---
 vllm/envs.py                                |  7 ++
 .../fused_moe/deepep_ll_prepare_finalize.py | 83 +++++++++++++------
 .../fused_moe/flashinfer_cutedsl_moe.py     | 68 ++++++++++-----
 3 files changed, 112 insertions(+), 46 deletions(-)

diff --git a/vllm/envs.py b/vllm/envs.py
index 8ad62e1b8f50..541d5e20d5aa 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -147,6 +147,7 @@ if TYPE_CHECKING:
     VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
     VLLM_MARLIN_INPUT_DTYPE: Literal["int8", "fp8"] | None = None
     VLLM_MXFP4_USE_MARLIN: bool | None = None
+    VLLM_DEEPEPLL_NVFP4_DISPATCH: bool = False
     VLLM_V1_USE_OUTLINES_CACHE: bool = False
     VLLM_TPU_BUCKET_PADDING_GAP: int = 0
     VLLM_TPU_MOST_MODEL_LEN: int | None = None
@@ -1127,6 +1128,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_MARLIN_INPUT_DTYPE": env_with_choices(
         "VLLM_MARLIN_INPUT_DTYPE", None, ["int8", "fp8"]
     ),
+    # Whether to use the DeepEP low-latency (DeepEPLL) fused NVFP4 quantization
+    # and dispatch path. Only supported on Blackwell GPUs and requires
+    # https://github.com/deepseek-ai/DeepEP/pull/341
+    "VLLM_DEEPEPLL_NVFP4_DISPATCH": lambda: bool(
+        int(os.getenv("VLLM_DEEPEPLL_NVFP4_DISPATCH", "0"))
+    ),
     # Whether to turn on the outlines cache for V1
     # This cache is unbounded and on disk, so it's not safe to use in
     # an environment with potentially malicious users.
diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
index fea9f49c04b8..06e4a61133bd 100644
--- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
+++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
@@ -184,31 +184,47 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
             x_fp8, x_scales = x
             x = dequant_fp8(x_fp8, x_scales).to(dtype=a1_dtype)
 
-        assert isinstance(x, torch.Tensor)
-
-        num_experts, max_tokens, hidden_dim = x.size()
-
-        # TODO (varun): Optimization - Use a batched version of quant
-        x = x.view((-1, hidden_dim))
+        assert isinstance(x, (torch.Tensor, tuple))
 
         q_dtype = quant_config.quant_dtype
-        if envs.VLLM_FLASHINFER_MOE_BACKEND == "masked_gemm":
+        if q_dtype == "nvfp4" and envs.VLLM_DEEPEPLL_NVFP4_DISPATCH:
             logger.info_once(
-                "Skip quantization when using FlashInfer CUTEDSL(masked_gemm) "
-                "for ModelOptNvFp4FusedMoE."
+                "Since VLLM_DEEPEPLL_NVFP4_DISPATCH==1, make sure you are "
+                "using the hybrid-ep branch of DeepEP "
+                "(https://github.com/deepseek-ai/DeepEP/tree/hybrid-ep)."
             )
-            q_dtype = None
+            assert isinstance(x, tuple)
+            x_scales = x[1]
+            x = x[0].permute(2, 0, 1)
+            num_experts, max_tokens, hidden_dim_by_2 = x.shape
+            hidden_dim = hidden_dim_by_2 * 2
+            assert envs.VLLM_FLASHINFER_MOE_BACKEND == "masked_gemm"
+            logger.info_once(
+                "Quantization is fused with DeepEP nvfp4 dispatch for "
+                "FlashInfer CUTEDSL as VLLM_DEEPEPLL_NVFP4_DISPATCH==1"
+            )
+        else:
+            if q_dtype == "nvfp4":
+                q_dtype = None
+                logger.info_once(
+                    "Using DeepEP bfloat16 dispatch for FlashInfer CUTEDSL as "
+                    "VLLM_DEEPEPLL_NVFP4_DISPATCH==0"
+                )
+            assert isinstance(x, torch.Tensor)
+            num_experts, max_tokens, hidden_dim = x.size()
 
-        x, x_scales = moe_kernel_quantize_input(
-            x,
-            quant_config.a1_scale,
-            q_dtype,
-            quant_config.per_act_token_quant,
-            quant_config.block_shape,
-        )
-        x = x.view((num_experts, -1, hidden_dim))
+            # TODO (varun): Optimization - Use a batched version of quant
+            x = x.view((-1, hidden_dim))
+            x, x_scales = moe_kernel_quantize_input(
+                x,
+                quant_config.a1_scale,
+                q_dtype,
+                quant_config.per_act_token_quant,
+                quant_config.block_shape,
+            )
+            x = x.view((num_experts, -1, hidden_dim))
 
-        if q_dtype is not None:
+        if q_dtype is not None and q_dtype != "nvfp4":
             assert x_scales is not None
             x_scales = normalize_batched_scales_shape(x_scales, num_experts)
 
@@ -240,18 +256,28 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
             "DeepEP kernels quantize the inputs in blocks of shape 128"
         )
 
+        use_nvfp4 = False
+        nvfp4_dispatch = (
+            quant_config.quant_dtype == "nvfp4" and envs.VLLM_DEEPEPLL_NVFP4_DISPATCH
+        )
+        if nvfp4_dispatch:
+            use_nvfp4 = True
+        qc_a1_gscale_or_scale = (
+            quant_config.a1_gscale if nvfp4_dispatch else quant_config.a1_scale
+        )
         has_per_token_scales = (
-            quant_config.a1_scale.numel() != 1
-            if quant_config.a1_scale is not None
+            qc_a1_gscale_or_scale.numel() != 1
+            if qc_a1_gscale_or_scale is not None
             else (
                 quant_config.a2_scale.numel() != 1
                 if quant_config.a2_scale is not None
                 else False
            )
        )
-        assert not has_per_token_scales, (
-            "low_latency kernels doesn't support dispatching per-token scales"
-        )
+        if not use_nvfp4:
+            assert not has_per_token_scales, (
+                "low_latency kernels don't support dispatching per-token scales"
+            )
 
         if apply_router_weight_on_input:
             topk = topk_ids.size(1)
@@ -269,9 +295,12 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
             self.max_tokens_per_rank,
             num_experts,
             use_fp8=self.use_fp8_dispatch,
-            # round_scale needs to be set to dispatch in ue8m0
-            round_scale=self.use_ue8m0_dispatch,
-            use_ue8m0=self.use_ue8m0_dispatch,
+            **(dict(use_nvfp4=True) if use_nvfp4 else dict()),
+            **(
+                dict(x_global_scale=qc_a1_gscale_or_scale)
+                if qc_a1_gscale_or_scale is not None
+                else dict()
+            ),
             async_finish=False,
             return_recv_hook=True,
         )
diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py
index 2747ef04a349..6e0b57156cb3 100644
--- a/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py
+++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py
@@ -4,6 +4,7 @@
 import torch
 
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm import envs
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
@@ -109,7 +110,8 @@ class FlashInferCuteDSLExperts(mk.FusedMoEPermuteExpertsUnpermute):
         - Note: in order for activation chunking to work, the first dimension
           of each tuple must be the number of tokens.
         """
-        output_shape = (local_num_experts, M, K)
+        K_dim = K * 2 if envs.VLLM_DEEPEPLL_NVFP4_DISPATCH else K
+        output_shape = (local_num_experts, M, K_dim)
         workspace2 = (local_num_experts, M, N)
         workspace1 = output_shape
         return (workspace1, workspace2, output_shape)
@@ -144,9 +146,18 @@ class FlashInferCuteDSLExperts(mk.FusedMoEPermuteExpertsUnpermute):
         assert hidden_states.ndim == 3
         assert self.w1_scale.ndim == 3
         assert self.w2_scale.ndim == 3
+
+        input_global_scale = (
+            None if envs.VLLM_DEEPEPLL_NVFP4_DISPATCH else self.a1_gscale
+        )
+        flashinfer_hidden_states = (
+            (hidden_states, a1q_scale)
+            if envs.VLLM_DEEPEPLL_NVFP4_DISPATCH
+            else hidden_states
+        )
         flashinfer_cutedsl_moe_masked(
-            hidden_states=hidden_states,
-            input_global_scale=self.a1_gscale,
+            hidden_states=flashinfer_hidden_states,
+            input_global_scale=input_global_scale,
             w1=w1,
             w1_blockscale=self.w1_scale,
             w1_alpha=self.g1_alphas,
@@ -172,7 +183,7 @@ def get_cute_dtype(input: torch.Tensor) -> str:
 
 
 def flashinfer_cutedsl_moe_masked(
-    hidden_states: torch.Tensor,
+    hidden_states: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
     input_global_scale: torch.Tensor,
     w1: torch.Tensor,
     w1_blockscale: torch.Tensor,
@@ -190,7 +201,10 @@
     kernels.
 
     Args:
-        hidden_states (torch.Tensor): [num_experts, m, k], bf16
+        hidden_states: One of the following:
+            * torch.Tensor: [num_experts, m, k], bf16
+            * tuple[torch.Tensor, torch.Tensor]: [num_experts, m, k // 2],
+              uint8, and [num_experts, m, k // 16], float8_e4m3fn
         input_global_scale (torch.Tensor): (l,)
         w1 (torch.Tensor): fp4 weights, [l, 2 * n, k // 2], uint8
         w1_blockscale (torch.Tensor): blockscale factors, e4m3,
@@ -207,9 +221,6 @@
     """
     # === Assertions on dtypes ===
-    assert input_global_scale.dtype == torch.float32, (
-        f"input_global_scale must be float32, got {input_global_scale.dtype}"
-    )
     assert w1.dtype == torch.uint8, f"w1 must be uint8, got {w1.dtype}"
     assert w1_blockscale.dtype == torch.float8_e4m3fn, (
         f"w1_blockscale must be float8_e4m3fn, got {w1_blockscale.dtype}"
     )
@@ -230,7 +241,32 @@
 
     # === Assertions on shapes ===
     n = w2.shape[-1] * 2  # intermediate dimension
-    num_experts, m, k = hidden_states.shape
+    if isinstance(hidden_states, tuple):
+        assert input_global_scale is None, (
+            "input_global_scale must be None when hidden_states is pre-quantized"
+        )
+
+        aq = hidden_states[0].view(torch.uint8)
+        aq_sf = hidden_states[1].view(torch.float8_e4m3fn)
+        # packed fp4 activations: [num_experts, m, k // 2]
+        num_experts, m, k_by_2 = aq.shape
+        k = k_by_2 * 2
+        aq = aq.permute(1, 2, 0)
+    else:
+        num_experts, m, k = hidden_states.shape
+
+        assert input_global_scale.dtype == torch.float32, (
+            f"input_global_scale must be float32, got {input_global_scale.dtype}"
+        )
+        assert input_global_scale.shape == (num_experts,), (
+            f"input_global_scale must be (l,), got {input_global_scale.shape}"
+        )
+
+        aq, aq_sf = scaled_fp4_grouped_quantize(
+            hidden_states,
+            masked_m,
+            input_global_scale,
+        )
 
     assert w1.shape[-2] == 2 * n, f"w1 last-2 dim must be 2*n, got {w1.shape}"
     assert w1.shape[-1] * 2 == k, (
@@ -241,9 +277,6 @@
         n // 2,
     ), f"w2 shape mismatch, got {w2.shape[-2:]}, expected {(k, n // 2)}"
 
-    assert input_global_scale.shape == (num_experts,), (
-        f"input_global_scale must be (l,), got {input_global_scale.shape}"
-    )
     assert w1_alpha.shape == (num_experts,), (
         f"w1_alpha must be (l,), got {w1_alpha.shape}"
     )
@@ -254,12 +287,6 @@
         f"w2_alpha must be (l,), got {w2_alpha.shape}"
     )
 
-    aq, aq_sf = scaled_fp4_grouped_quantize(
-        hidden_states,
-        masked_m,
-        input_global_scale,
-    )
-
     workspace = workspace.permute(1, 2, 0)  # requirement of kernel
     sf_vec_size = 16
     assert aq_sf.dtype == torch.float8_e4m3fn
@@ -267,7 +294,10 @@
 
     ab_dtype = "float4_e2m1fn"
     sf_dtype = "float8_e4m3fn"
-    c_dtype = get_cute_dtype(hidden_states)
+    if isinstance(hidden_states, tuple):
+        c_dtype = "bfloat16"
+    else:
+        c_dtype = get_cute_dtype(hidden_states)
 
     # Gemm1
     flashinfer_cutedsl_grouped_gemm_nt_masked(
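
The snippet below is an illustrative sketch, not part of the patch: it shows the two activation layouts that flashinfer_cutedsl_moe_masked accepts after this change, as described in the updated docstring. The expert, token, and hidden sizes are arbitrary example values chosen only for the demonstration.

    import torch

    E, m, k = 4, 128, 7168  # arbitrary example sizes

    # bf16 path (VLLM_DEEPEPLL_NVFP4_DISPATCH=0): activations arrive unquantized
    # and are quantized inside the function via scaled_fp4_grouped_quantize.
    hidden_states_bf16 = torch.empty(E, m, k, dtype=torch.bfloat16)

    # Fused-dispatch path (VLLM_DEEPEPLL_NVFP4_DISPATCH=1): DeepEP already returns
    # packed fp4 data plus e4m3 blockscales, forwarded as a tuple together with
    # input_global_scale=None.
    hidden_states_nvfp4 = (
        torch.empty(E, m, k // 2, dtype=torch.uint8),            # packed fp4 pairs
        torch.empty(E, m, k // 16, dtype=torch.float8_e4m3fn),   # per-16-element scales
    )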