[MoE] CuteDSL MoE with Nvfp4 DeepEP dispatch (#27141)

Signed-off-by: Shu Wang <shuw@nvidia.com>
Signed-off-by: Shu Wang. <shuw@nvidia.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: root <root@umbriel-b200-017.ipp4a1.colossus.nvidia.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Shu Wang 2025-11-30 18:05:32 -06:00 committed by GitHub
parent ec38a7368d
commit f72a817bdf
3 changed files with 112 additions and 46 deletions

View File

@@ -147,6 +147,7 @@ if TYPE_CHECKING:
VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
VLLM_MARLIN_INPUT_DTYPE: Literal["int8", "fp8"] | None = None
VLLM_MXFP4_USE_MARLIN: bool | None = None
VLLM_DEEPEPLL_NVFP4_DISPATCH: bool = False
VLLM_V1_USE_OUTLINES_CACHE: bool = False
VLLM_TPU_BUCKET_PADDING_GAP: int = 0
VLLM_TPU_MOST_MODEL_LEN: int | None = None
@@ -1127,6 +1128,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_MARLIN_INPUT_DTYPE": env_with_choices(
"VLLM_MARLIN_INPUT_DTYPE", None, ["int8", "fp8"]
),
# Whether to use the DeepEPLL (DeepEP low-latency) NVFP4 quantize-and-dispatch
# path. Only supported on Blackwell GPUs, and requires
# https://github.com/deepseek-ai/DeepEP/pull/341
"VLLM_DEEPEPLL_NVFP4_DISPATCH": lambda: bool(
int(os.getenv("VLLM_DEEPEPLL_NVFP4_DISPATCH", "0"))
),
# Whether to turn on the outlines cache for V1
# This cache is unbounded and on disk, so it's not safe to use in
# an environment with potentially malicious users.
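
As a rough usage sketch (not part of this diff): the new flag is a plain "0"/"1" environment variable, parsed by the lambda above whenever vllm.envs is accessed. Assuming a Blackwell GPU and a DeepEP build containing the PR linked above, enabling the fused path could look like this:

# Hedged sketch: enable the fused NVFP4 dispatch path for a vLLM process.
import os

os.environ["VLLM_DEEPEPLL_NVFP4_DISPATCH"] = "1"           # "1" -> True, "0"/unset -> False
os.environ["VLLM_FLASHINFER_MOE_BACKEND"] = "masked_gemm"  # backend asserted by the new code path

# Mirrors how the registration above parses the value.
enabled = bool(int(os.getenv("VLLM_DEEPEPLL_NVFP4_DISPATCH", "0")))
assert enabled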

View File

@@ -184,31 +184,47 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
x_fp8, x_scales = x
x = dequant_fp8(x_fp8, x_scales).to(dtype=a1_dtype)
assert isinstance(x, torch.Tensor)
num_experts, max_tokens, hidden_dim = x.size()
# TODO (varun): Optimization - Use a batched version of quant
x = x.view((-1, hidden_dim))
assert isinstance(x, (torch.Tensor, tuple))
q_dtype = quant_config.quant_dtype
if envs.VLLM_FLASHINFER_MOE_BACKEND == "masked_gemm":
if q_dtype == "nvfp4" and envs.VLLM_DEEPEPLL_NVFP4_DISPATCH:
logger.info_once(
"Skip quantization when using FlashInfer CUTEDSL(masked_gemm) "
"for ModelOptNvFp4FusedMoE."
"Since VLLM_DEEPEPLL_NVFP4_DISPATCH==1, make sure "
"using the hybrid-ep branch of DeepEP"
"(https://github.com/deepseek-ai/DeepEP/tree/hybrid-ep)"
)
q_dtype = None
assert isinstance(x, tuple)
x_scales = x[1]
x = x[0].permute(2, 0, 1)
num_experts, max_tokens, hidden_dim_by_2 = x.shape
hidden_dim = hidden_dim_by_2 * 2
assert envs.VLLM_FLASHINFER_MOE_BACKEND == "masked_gemm"
logger.info_once(
"Quantization is fused with DeepEP nvfp4 dispatch for "
"FlashInfer CUTEDSL as VLLM_DEEPEPLL_NVFP4_DISPATCH==1"
)
else:
if q_dtype == "nvfp4":
q_dtype = None
logger.info_once(
"Using DeepEP bfloat16 dispatch for FlashInfer CUTEDSL as "
"VLLM_DEEPEPLL_NVFP4_DISPATCH==0"
)
assert isinstance(x, torch.Tensor)
num_experts, max_tokens, hidden_dim = x.size()
x, x_scales = moe_kernel_quantize_input(
x,
quant_config.a1_scale,
q_dtype,
quant_config.per_act_token_quant,
quant_config.block_shape,
)
x = x.view((num_experts, -1, hidden_dim))
# TODO (varun): Optimization - Use a batched version of quant
x = x.view((-1, hidden_dim))
x, x_scales = moe_kernel_quantize_input(
x,
quant_config.a1_scale,
q_dtype,
quant_config.per_act_token_quant,
quant_config.block_shape,
)
x = x.view((num_experts, -1, hidden_dim))
if q_dtype is not None:
if q_dtype is not None and q_dtype != "nvfp4":
assert x_scales is not None
x_scales = normalize_batched_scales_shape(x_scales, num_experts)
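
To make the shape handling in this hunk concrete, here is a small, self-contained sketch of the two branches. The pre-permute layout of the NVFP4 payload is inferred from the permute(2, 0, 1) call above, and all sizes are illustrative:

import torch

num_experts, max_tokens, hidden_dim = 4, 128, 7168

# VLLM_DEEPEPLL_NVFP4_DISPATCH==1: x arrives already quantized; two FP4 values
# are packed per uint8 byte, so the packed width is hidden_dim // 2.
packed = torch.empty(max_tokens, hidden_dim // 2, num_experts, dtype=torch.uint8)
x_nvfp4 = packed.permute(2, 0, 1)          # -> (num_experts, max_tokens, hidden_dim // 2)
assert x_nvfp4.shape == (num_experts, max_tokens, hidden_dim // 2)

# VLLM_DEEPEPLL_NVFP4_DISPATCH==0: x arrives in bf16 and is quantized locally.
x_bf16 = torch.empty(num_experts, max_tokens, hidden_dim, dtype=torch.bfloat16)
x_flat = x_bf16.view(-1, hidden_dim)       # flatten experts * tokens for the quant kernel
# ... moe_kernel_quantize_input(x_flat, ...) runs here in the real code ...
x_back = x_flat.view(num_experts, -1, hidden_dim)  # restore the batched layout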
@@ -240,18 +256,28 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
"DeepEP kernels quantize the inputs in blocks of shape 128"
)
use_nvfp4 = False
nvfp4_dispatch = (
quant_config.quant_dtype == "nvfp4" and envs.VLLM_DEEPEPLL_NVFP4_DISPATCH
)
if nvfp4_dispatch:
use_nvfp4 = True
qc_a1_gscale_or_scale = (
quant_config.a1_gscale if nvfp4_dispatch else quant_config.a1_scale
)
has_per_token_scales = (
quant_config.a1_scale.numel() != 1
if quant_config.a1_scale is not None
qc_a1_gscale_or_scale.numel() != 1
if qc_a1_gscale_or_scale is not None
else (
quant_config.a2_scale.numel() != 1
if quant_config.a2_scale is not None
else False
)
)
assert not has_per_token_scales, (
"low_latency kernels doesn't support dispatching per-token scales"
)
if not use_nvfp4:
assert not has_per_token_scales, (
"low_latency kernels doesn't support dispatching per-token scales"
)
if apply_router_weight_on_input:
topk = topk_ids.size(1)
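
The scale selection above can be summarized with a short sketch (names and values are illustrative, not the vLLM API): the global scale a1_gscale is used when NVFP4 dispatch is active, otherwise the activation scale a1_scale, and a scale with more than one element is treated as per-token, which only the non-NVFP4 path rejects:

import torch

nvfp4_dispatch = True                 # quant_dtype == "nvfp4" and the env flag is set
a1_gscale = torch.tensor([1.0])       # single global scale -> not per-token
a1_scale = torch.rand(128)            # one scale per token -> per-token

scale = a1_gscale if nvfp4_dispatch else a1_scale
has_per_token_scales = scale is not None and scale.numel() != 1

if not nvfp4_dispatch:
    # Mirrors the guarded assert above.
    assert not has_per_token_scales, "low_latency kernels don't support per-token scales"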
@@ -269,9 +295,12 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
self.max_tokens_per_rank,
num_experts,
use_fp8=self.use_fp8_dispatch,
# round_scale must be set for the scales to be dispatched in ue8m0 format
round_scale=self.use_ue8m0_dispatch,
use_ue8m0=self.use_ue8m0_dispatch,
**(dict(use_nvfp4=True) if use_nvfp4 else dict()),
**(
dict(x_global_scale=qc_a1_gscale_or_scale)
if qc_a1_gscale_or_scale is not None
else dict()
),
async_finish=False,
return_recv_hook=True,
)
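
The conditional dict-splat used above keeps the call compatible with DeepEP builds that do not accept the new keyword arguments; the extra kwargs are only passed when the NVFP4 path is active. A minimal illustration of the pattern, with a stand-in callee rather than the real DeepEP API:

def dispatch_stub(x, *, use_fp8=False, use_nvfp4=False, x_global_scale=None):
    # Stand-in for buffer.low_latency_dispatch; only the kwargs plumbing is shown.
    return use_nvfp4, x_global_scale

use_nvfp4 = True
gscale = 0.5  # placeholder for quant_config.a1_gscale

result = dispatch_stub(
    "tokens",
    use_fp8=False,
    **(dict(use_nvfp4=True) if use_nvfp4 else dict()),
    **(dict(x_global_scale=gscale) if gscale is not None else dict()),
)
assert result == (True, 0.5)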

View File

@@ -4,6 +4,7 @@
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import envs
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
@@ -109,7 +110,8 @@ class FlashInferCuteDSLExperts(mk.FusedMoEPermuteExpertsUnpermute):
- Note: in order for activation chunking to work, the first dimension
of each tuple must be the number of tokens.
"""
output_shape = (local_num_experts, M, K)
K_dim = K * 2 if envs.VLLM_DEEPEPLL_NVFP4_DISPATCH else K
output_shape = (local_num_experts, M, K_dim)
workspace2 = (local_num_experts, M, N)
workspace1 = output_shape
return (workspace1, workspace2, output_shape)
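
A plausible reading of the sizing change: with NVFP4 dispatch enabled, the K seen here is the packed uint8 width (half the logical hidden size), so the bf16 output buffer must be twice as wide. A standalone restatement of that arithmetic (simplified, not the vLLM signature):

def _workspace_shapes(local_num_experts, M, N, K, nvfp4_dispatch):
    K_dim = K * 2 if nvfp4_dispatch else K
    output_shape = (local_num_experts, M, K_dim)
    workspace2 = (local_num_experts, M, N)
    workspace1 = output_shape
    return workspace1, workspace2, output_shape

# Both calls yield an output shape of (8, 128, 7168):
print(_workspace_shapes(8, 128, 4096, 3584, nvfp4_dispatch=True)[2])   # packed K = 7168 // 2
print(_workspace_shapes(8, 128, 4096, 7168, nvfp4_dispatch=False)[2])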
@@ -144,9 +146,18 @@ class FlashInferCuteDSLExperts(mk.FusedMoEPermuteExpertsUnpermute):
assert hidden_states.ndim == 3
assert self.w1_scale.ndim == 3
assert self.w2_scale.ndim == 3
input_global_scale = (
None if envs.VLLM_DEEPEPLL_NVFP4_DISPATCH else self.a1_gscale
)
flashinfer_hidden_states = (
(hidden_states, a1q_scale)
if envs.VLLM_DEEPEPLL_NVFP4_DISPATCH
else hidden_states
)
flashinfer_cutedsl_moe_masked(
hidden_states=hidden_states,
input_global_scale=self.a1_gscale,
hidden_states=flashinfer_hidden_states,
input_global_scale=input_global_scale,
w1=w1,
w1_blockscale=self.w1_scale,
w1_alpha=self.g1_alphas,
@@ -172,7 +183,7 @@ def get_cute_dtype(input: torch.Tensor) -> str:
def flashinfer_cutedsl_moe_masked(
hidden_states: torch.Tensor,
hidden_states: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
input_global_scale: torch.Tensor,
w1: torch.Tensor,
w1_blockscale: torch.Tensor,
@@ -190,7 +201,10 @@ def flashinfer_cutedsl_moe_masked(
kernels.
Args:
hidden_states (torch.Tensor): [num_experts, m, k], bf16
hidden_states: Either of the following:
* torch.Tensor: [num_experts, m, k], bf16
* tuple[torch.Tensor, torch.Tensor]: [num_experts, m, k // 2],
uint8, [num_experts, m, k // 16], float8_e4m3fn
input_global_scale (torch.Tensor): (l,)
w1 (torch.Tensor): fp4 weights, [l, 2 * n, k // 2], uint8
w1_blockscale (torch.Tensor): blockscale factors, e4m3,
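
For reference, the two accepted forms of hidden_states described above can be constructed like this (shapes follow the docstring; sizes are illustrative):

import torch

l, m, k = 8, 64, 4096

bf16_input = torch.empty(l, m, k, dtype=torch.bfloat16)              # unquantized path

packed_fp4 = torch.empty(l, m, k // 2, dtype=torch.uint8)            # two fp4 values per byte
blockscales = torch.empty(l, m, k // 16, dtype=torch.float8_e4m3fn)  # one scale per 16 values
prequantized = (packed_fp4, blockscales)                             # NVFP4-dispatch path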
@@ -207,9 +221,6 @@ def flashinfer_cutedsl_moe_masked(
"""
# === Assertions on dtypes ===
assert input_global_scale.dtype == torch.float32, (
f"input_global_scale must be float32, got {input_global_scale.dtype}"
)
assert w1.dtype == torch.uint8, f"w1 must be uint8, got {w1.dtype}"
assert w1_blockscale.dtype == torch.float8_e4m3fn, (
f"w1_blockscale must be float8_e4m3fn, got {w1_blockscale.dtype}"
@@ -230,7 +241,32 @@ def flashinfer_cutedsl_moe_masked(
# === Assertions on shapes ===
n = w2.shape[-1] * 2 # intermediate dimension
num_experts, m, k = hidden_states.shape
if isinstance(hidden_states, tuple):
assert input_global_scale is None, (
"input_global_scale is needed when input needs quant"
)
aq = hidden_states[0].view(torch.uint8)
aq_sf = hidden_states[1].view(torch.float8_e4m3fn)
# m, k_by_2, num_experts = aq.shape
num_experts, m, k_by_2 = aq.shape
k = k_by_2 * 2
aq = aq.permute(1, 2, 0)
else:
num_experts, m, k = hidden_states.shape
assert input_global_scale.dtype == torch.float32, (
f"input_global_scale must be float32, got {input_global_scale.dtype}"
)
assert input_global_scale.shape == (num_experts,), (
f"input_global_scale must be (l,), got {input_global_scale.shape}"
)
aq, aq_sf = scaled_fp4_grouped_quantize(
hidden_states,
masked_m,
input_global_scale,
)
assert w1.shape[-2] == 2 * n, f"w1 last-2 dim must be 2*n, got {w1.shape}"
assert w1.shape[-1] * 2 == k, (
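
A short sketch of the pre-quantized branch above, using made-up sizes: the packed payload and block scales are reinterpreted to the dtypes the kernel expects, k is recovered from the packed width, and the payload is permuted to the expert-last layout required by the grouped GEMM:

import torch

num_experts, m, k_by_2 = 8, 64, 2048
hidden_states = (
    torch.empty(num_experts, m, k_by_2, dtype=torch.uint8),
    torch.empty(num_experts, m, k_by_2 // 8, dtype=torch.float8_e4m3fn),  # k // 16 scales
)

aq = hidden_states[0].view(torch.uint8)             # packed fp4 payload
aq_sf = hidden_states[1].view(torch.float8_e4m3fn)  # e4m3 block scales
k = aq.shape[-1] * 2                                # two fp4 values per byte
aq = aq.permute(1, 2, 0)                            # -> (m, k // 2, num_experts)
assert aq.shape == (m, k_by_2, num_experts) and k == 4096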
@@ -241,9 +277,6 @@ def flashinfer_cutedsl_moe_masked(
n // 2,
), f"w2 shape mismatch, got {w2.shape[-2:]}, expected {(k, n // 2)}"
assert input_global_scale.shape == (num_experts,), (
f"input_global_scale must be (l,), got {input_global_scale.shape}"
)
assert w1_alpha.shape == (num_experts,), (
f"w1_alpha must be (l,), got {w1_alpha.shape}"
)
@@ -254,12 +287,6 @@ def flashinfer_cutedsl_moe_masked(
f"w2_alpha must be (l,), got {w2_alpha.shape}"
)
aq, aq_sf = scaled_fp4_grouped_quantize(
hidden_states,
masked_m,
input_global_scale,
)
workspace = workspace.permute(1, 2, 0) # requirement of kernel
sf_vec_size = 16
assert aq_sf.dtype == torch.float8_e4m3fn
@@ -267,7 +294,10 @@ def flashinfer_cutedsl_moe_masked(
ab_dtype = "float4_e2m1fn"
sf_dtype = "float8_e4m3fn"
c_dtype = get_cute_dtype(hidden_states)
if isinstance(hidden_states, tuple):
c_dtype = "bfloat16"
else:
c_dtype = get_cute_dtype(hidden_states)
# Gemm1
flashinfer_cutedsl_grouped_gemm_nt_masked(