[MoE] CuteDSL MoE with Nvfp4 DeepEP dispatch (#27141)
Signed-off-by: Shu Wang <shuw@nvidia.com>
Signed-off-by: Shu Wang. <shuw@nvidia.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: root <root@umbriel-b200-017.ipp4a1.colossus.nvidia.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Parent: ec38a7368d
Commit: f72a817bdf
@@ -147,6 +147,7 @@ if TYPE_CHECKING:
     VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
     VLLM_MARLIN_INPUT_DTYPE: Literal["int8", "fp8"] | None = None
     VLLM_MXFP4_USE_MARLIN: bool | None = None
+    VLLM_DEEPEPLL_NVFP4_DISPATCH: bool = False
     VLLM_V1_USE_OUTLINES_CACHE: bool = False
     VLLM_TPU_BUCKET_PADDING_GAP: int = 0
     VLLM_TPU_MOST_MODEL_LEN: int | None = None
@@ -1127,6 +1128,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_MARLIN_INPUT_DTYPE": env_with_choices(
         "VLLM_MARLIN_INPUT_DTYPE", None, ["int8", "fp8"]
     ),
+    # Whether to use DeepEPLL kernels for NVFP4 quantization and dispatch method
+    # only supported on Blackwell GPUs and with
+    # https://github.com/deepseek-ai/DeepEP/pull/341
+    "VLLM_DEEPEPLL_NVFP4_DISPATCH": lambda: bool(
+        int(os.getenv("VLLM_DEEPEPLL_NVFP4_DISPATCH", "0"))
+    ),
     # Whether to turn on the outlines cache for V1
     # This cache is unbounded and on disk, so it's not safe to use in
     # an environment with potentially malicious users.
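For quick reference, here is how the new flag behaves at parse time. This is a minimal standalone sketch; the helper name _nvfp4_dispatch_enabled is ours and simply mirrors the lambda registered above.

    import os

    # Same parsing rule as the registered lambda: any non-zero integer enables it.
    def _nvfp4_dispatch_enabled() -> bool:
        return bool(int(os.getenv("VLLM_DEEPEPLL_NVFP4_DISPATCH", "0")))

    os.environ["VLLM_DEEPEPLL_NVFP4_DISPATCH"] = "1"
    assert _nvfp4_dispatch_enabled() is True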
@@ -184,31 +184,47 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
                 x_fp8, x_scales = x
                 x = dequant_fp8(x_fp8, x_scales).to(dtype=a1_dtype)

-        assert isinstance(x, torch.Tensor)
-
-        num_experts, max_tokens, hidden_dim = x.size()
-
-        # TODO (varun): Optimization - Use a batched version of quant
-        x = x.view((-1, hidden_dim))
+        assert isinstance(x, (torch.Tensor, tuple))
+
         q_dtype = quant_config.quant_dtype

-        if envs.VLLM_FLASHINFER_MOE_BACKEND == "masked_gemm":
+        if q_dtype == "nvfp4" and envs.VLLM_DEEPEPLL_NVFP4_DISPATCH:
             logger.info_once(
-                "Skip quantization when using FlashInfer CUTEDSL(masked_gemm) "
-                "for ModelOptNvFp4FusedMoE."
+                "Since VLLM_DEEPEPLL_NVFP4_DISPATCH==1, make sure "
+                "using the hybrid-ep branch of DeepEP"
+                "(https://github.com/deepseek-ai/DeepEP/tree/hybrid-ep)"
             )
-            q_dtype = None
+            assert isinstance(x, tuple)
+            x_scales = x[1]
+            x = x[0].permute(2, 0, 1)
+            num_experts, max_tokens, hidden_dim_by_2 = x.shape
+            hidden_dim = hidden_dim_by_2 * 2
+            assert envs.VLLM_FLASHINFER_MOE_BACKEND == "masked_gemm"
+            logger.info_once(
+                "Quantization is fused with DeepEP nvfp4 dispatch for "
+                "FlashInfer CUTEDSL as VLLM_DEEPEPLL_NVFP4_DISPATCH==1"
+            )
+        else:
+            if q_dtype == "nvfp4":
+                q_dtype = None
+                logger.info_once(
+                    "Using DeepEP bfloat16 dispatch for FlashInfer CUTEDSL as "
+                    "VLLM_DEEPEPLL_NVFP4_DISPATCH==0"
+                )
+            assert isinstance(x, torch.Tensor)
+            num_experts, max_tokens, hidden_dim = x.size()

-        x, x_scales = moe_kernel_quantize_input(
-            x,
-            quant_config.a1_scale,
-            q_dtype,
-            quant_config.per_act_token_quant,
-            quant_config.block_shape,
-        )
-        x = x.view((num_experts, -1, hidden_dim))
+            # TODO (varun): Optimization - Use a batched version of quant
+            x = x.view((-1, hidden_dim))
+            x, x_scales = moe_kernel_quantize_input(
+                x,
+                quant_config.a1_scale,
+                q_dtype,
+                quant_config.per_act_token_quant,
+                quant_config.block_shape,
+            )
+            x = x.view((num_experts, -1, hidden_dim))

-        if q_dtype is not None:
+        if q_dtype is not None and q_dtype != "nvfp4":
             assert x_scales is not None
             x_scales = normalize_batched_scales_shape(x_scales, num_experts)
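In the fused-dispatch branch above, the activations arrive as a tuple of packed FP4 payload plus block scales, with the expert axis last. A small shape-bookkeeping sketch with dummy tensors (sizes are made up; the real buffers come from DeepEP low-latency dispatch):

    import torch

    num_experts, max_tokens, hidden_dim = 8, 16, 4096

    # Packed NVFP4 payload: two 4-bit values per uint8 byte, expert axis last.
    packed = torch.zeros(max_tokens, hidden_dim // 2, num_experts, dtype=torch.uint8)

    x = packed.permute(2, 0, 1)                 # (num_experts, max_tokens, hidden_dim // 2)
    e, m, hidden_dim_by_2 = x.shape
    assert hidden_dim_by_2 * 2 == hidden_dim    # recover the unpacked hidden size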
@@ -240,18 +256,28 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
             "DeepEP kernels quantize the inputs in blocks of shape 128"
         )

+        use_nvfp4 = False
+        nvfp4_dispatch = (
+            quant_config.quant_dtype == "nvfp4" and envs.VLLM_DEEPEPLL_NVFP4_DISPATCH
+        )
+        if nvfp4_dispatch:
+            use_nvfp4 = True
+        qc_a1_gscale_or_scale = (
+            quant_config.a1_gscale if nvfp4_dispatch else quant_config.a1_scale
+        )
         has_per_token_scales = (
-            quant_config.a1_scale.numel() != 1
-            if quant_config.a1_scale is not None
+            qc_a1_gscale_or_scale.numel() != 1
+            if qc_a1_gscale_or_scale is not None
             else (
                 quant_config.a2_scale.numel() != 1
                 if quant_config.a2_scale is not None
                 else False
             )
         )
-        assert not has_per_token_scales, (
-            "low_latency kernels doesn't support dispatching per-token scales"
-        )
+        if not use_nvfp4:
+            assert not has_per_token_scales, (
+                "low_latency kernels doesn't support dispatching per-token scales"
+            )

         if apply_router_weight_on_input:
             topk = topk_ids.size(1)
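The per-token-scale check now looks at a1_gscale (a single global scale) when NVFP4 dispatch is active, and at a1_scale otherwise. A minimal sketch of that selection rule, using a SimpleNamespace stand-in for vLLM's FusedMoEQuantConfig and a plain boolean for the env flag:

    import torch
    from types import SimpleNamespace

    quant_config = SimpleNamespace(
        quant_dtype="nvfp4",
        a1_gscale=torch.ones(1),   # one-element global scale for the NVFP4 path
        a1_scale=None,
        a2_scale=None,
    )
    env_flag = True                # stands in for envs.VLLM_DEEPEPLL_NVFP4_DISPATCH

    nvfp4_dispatch = quant_config.quant_dtype == "nvfp4" and env_flag
    scale = quant_config.a1_gscale if nvfp4_dispatch else quant_config.a1_scale
    has_per_token_scales = scale.numel() != 1 if scale is not None else False
    assert not has_per_token_scales    # a single global scale passes the check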
@@ -269,9 +295,12 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
             self.max_tokens_per_rank,
             num_experts,
             use_fp8=self.use_fp8_dispatch,
-            # round_scale needs to be set to dispatch in ue8m0
-            round_scale=self.use_ue8m0_dispatch,
-            use_ue8m0=self.use_ue8m0_dispatch,
+            **(dict(use_nvfp4=True) if use_nvfp4 else dict()),
+            **(
+                dict(x_global_scale=qc_a1_gscale_or_scale)
+                if qc_a1_gscale_or_scale is not None
+                else dict()
+            ),
             async_finish=False,
             return_recv_hook=True,
         )
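The `**(dict(...) if cond else dict())` idiom forwards the new keywords only when the NVFP4 path is active, so the call stays compatible with DeepEP builds that do not accept them. A toy sketch of the pattern (the dispatch stub below is ours, not the DeepEP API):

    # Stub with the same keyword surface as the pattern above.
    def dispatch(x, *, use_fp8=False, use_nvfp4=False, x_global_scale=None,
                 async_finish=False, return_recv_hook=True):
        return use_nvfp4, x_global_scale

    use_nvfp4, gscale = True, 0.5
    result = dispatch(
        None,
        use_fp8=False,
        **(dict(use_nvfp4=True) if use_nvfp4 else dict()),
        **(dict(x_global_scale=gscale) if gscale is not None else dict()),
    )
    assert result == (True, 0.5)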
@@ -4,6 +4,7 @@
 import torch

 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm import envs
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
@@ -109,7 +110,8 @@ class FlashInferCuteDSLExperts(mk.FusedMoEPermuteExpertsUnpermute):
         - Note: in order for activation chunking to work, the first dimension
           of each tuple must be the number of tokens.
         """
-        output_shape = (local_num_experts, M, K)
+        K_dim = K * 2 if envs.VLLM_DEEPEPLL_NVFP4_DISPATCH else K
+        output_shape = (local_num_experts, M, K_dim)
         workspace2 = (local_num_experts, M, N)
         workspace1 = output_shape
         return (workspace1, workspace2, output_shape)
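The doubled output width reflects that, under fused NVFP4 dispatch, the K seen here is the packed width (half the bf16 hidden size) while the MoE output is still produced at full width. A short arithmetic check under that assumption:

    hidden_size = 4096                 # illustrative value
    K_packed = hidden_size // 2        # what the workspace code sees as K under NVFP4 dispatch
    K_dim = K_packed * 2               # bf16 output must span the full hidden size
    assert K_dim == hidden_size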
@@ -144,9 +146,18 @@ class FlashInferCuteDSLExperts(mk.FusedMoEPermuteExpertsUnpermute):
         assert hidden_states.ndim == 3
         assert self.w1_scale.ndim == 3
         assert self.w2_scale.ndim == 3
+
+        input_global_scale = (
+            None if envs.VLLM_DEEPEPLL_NVFP4_DISPATCH else self.a1_gscale
+        )
+        flashinfer_hidden_states = (
+            (hidden_states, a1q_scale)
+            if envs.VLLM_DEEPEPLL_NVFP4_DISPATCH
+            else hidden_states
+        )
         flashinfer_cutedsl_moe_masked(
-            hidden_states=hidden_states,
-            input_global_scale=self.a1_gscale,
+            hidden_states=flashinfer_hidden_states,
+            input_global_scale=input_global_scale,
             w1=w1,
             w1_blockscale=self.w1_scale,
             w1_alpha=self.g1_alphas,
@@ -172,7 +183,7 @@ def get_cute_dtype(input: torch.Tensor) -> str:


 def flashinfer_cutedsl_moe_masked(
-    hidden_states: torch.Tensor,
+    hidden_states: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
     input_global_scale: torch.Tensor,
     w1: torch.Tensor,
     w1_blockscale: torch.Tensor,
@@ -190,7 +201,10 @@ def flashinfer_cutedsl_moe_masked(
     kernels.

     Args:
-        hidden_states (torch.Tensor): [num_experts, m, k], bf16
+        hidden_states: Either of the following case
+            * torch.Tensor: [num_experts, m, k], bf16
+            * tuple[torch.Tensor, torch.Tensor]: [num_experts, m, k // 2],
+              uint8, [num_experts, m, k // 16], float8_e4m3fn
         input_global_scale (torch.Tensor): (l,)
         w1 (torch.Tensor): fp4 weights, [l, 2 * n, k // 2], uint8
         w1_blockscale (torch.Tensor): blockscale factors, e4m3,
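To make the two accepted layouts concrete, a sketch with toy shapes (placeholder tensors only, requires a PyTorch build that exposes float8_e4m3fn; the 16-element scale blocks match the sf_vec_size = 16 used later in this function):

    import torch

    e, m, k = 4, 8, 64
    bf16_input = torch.zeros(e, m, k, dtype=torch.bfloat16)             # unquantized case
    fp4_data = torch.zeros(e, m, k // 2, dtype=torch.uint8)             # two fp4 values per byte
    fp4_scales = torch.zeros(e, m, k // 16, dtype=torch.float8_e4m3fn)  # one scale per 16 values
    pre_quantized_input = (fp4_data, fp4_scales)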
@@ -207,9 +221,6 @@ def flashinfer_cutedsl_moe_masked(
     """

     # === Assertions on dtypes ===
-    assert input_global_scale.dtype == torch.float32, (
-        f"input_global_scale must be float32, got {input_global_scale.dtype}"
-    )
     assert w1.dtype == torch.uint8, f"w1 must be uint8, got {w1.dtype}"
     assert w1_blockscale.dtype == torch.float8_e4m3fn, (
         f"w1_blockscale must be float8_e4m3fn, got {w1_blockscale.dtype}"
@@ -230,7 +241,32 @@ def flashinfer_cutedsl_moe_masked(

     # === Assertions on shapes ===
     n = w2.shape[-1] * 2  # intermediate dimension
-    num_experts, m, k = hidden_states.shape
+    if isinstance(hidden_states, tuple):
+        assert input_global_scale is None, (
+            "input_global_scale is needed when input needs quant"
+        )
+
+        aq = hidden_states[0].view(torch.uint8)
+        aq_sf = hidden_states[1].view(torch.float8_e4m3fn)
+        # m, k_by_2, num_experts = aq.shape
+        num_experts, m, k_by_2 = aq.shape
+        k = k_by_2 * 2
+        aq = aq.permute(1, 2, 0)
+    else:
+        num_experts, m, k = hidden_states.shape
+
+        assert input_global_scale.dtype == torch.float32, (
+            f"input_global_scale must be float32, got {input_global_scale.dtype}"
+        )
+        assert input_global_scale.shape == (num_experts,), (
+            f"input_global_scale must be (l,), got {input_global_scale.shape}"
+        )
+
+        aq, aq_sf = scaled_fp4_grouped_quantize(
+            hidden_states,
+            masked_m,
+            input_global_scale,
+        )

     assert w1.shape[-2] == 2 * n, f"w1 last-2 dim must be 2*n, got {w1.shape}"
     assert w1.shape[-1] * 2 == k, (
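On the kernel side, the packed buffers are reinterpreted in place and re-laid out with the expert axis last, the opposite permute of the prepare step. A minimal sketch with a dummy buffer (the real tensors are the tuple elements passed into the function):

    import torch

    raw = torch.zeros(4, 8, 32, dtype=torch.uint8)   # [num_experts, m, k // 2]

    aq = raw.view(torch.uint8)                        # reinterpret, no copy
    num_experts, m, k_by_2 = aq.shape
    k = k_by_2 * 2                                    # unpacked hidden size
    aq = aq.permute(1, 2, 0)                          # grouped GEMM expects experts last
    assert aq.shape == (m, k_by_2, num_experts)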
@@ -241,9 +277,6 @@ def flashinfer_cutedsl_moe_masked(
         n // 2,
     ), f"w2 shape mismatch, got {w2.shape[-2:]}, expected {(k, n // 2)}"

-    assert input_global_scale.shape == (num_experts,), (
-        f"input_global_scale must be (l,), got {input_global_scale.shape}"
-    )
     assert w1_alpha.shape == (num_experts,), (
         f"w1_alpha must be (l,), got {w1_alpha.shape}"
     )
@@ -254,12 +287,6 @@ def flashinfer_cutedsl_moe_masked(
         f"w2_alpha must be (l,), got {w2_alpha.shape}"
     )

-    aq, aq_sf = scaled_fp4_grouped_quantize(
-        hidden_states,
-        masked_m,
-        input_global_scale,
-    )
-
     workspace = workspace.permute(1, 2, 0)  # requirement of kernel
     sf_vec_size = 16
     assert aq_sf.dtype == torch.float8_e4m3fn
@@ -267,7 +294,10 @@ def flashinfer_cutedsl_moe_masked(
     ab_dtype = "float4_e2m1fn"
     sf_dtype = "float8_e4m3fn"

-    c_dtype = get_cute_dtype(hidden_states)
+    if isinstance(hidden_states, tuple):
+        c_dtype = "bfloat16"
+    else:
+        c_dtype = get_cute_dtype(hidden_states)

     # Gemm1
     flashinfer_cutedsl_grouped_gemm_nt_masked(
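Putting the pieces together, the fused path is enabled purely through environment variables. A sketch of the combination this diff appears to target; VLLM_FLASHINFER_MOE_BACKEND and VLLM_DEEPEPLL_NVFP4_DISPATCH are taken from the diff itself, while the all-to-all backend name is our assumption from vLLM's env documentation, not from this commit:

    import os

    os.environ["VLLM_ALL2ALL_BACKEND"] = "deepep_low_latency"    # assumed backend name
    os.environ["VLLM_FLASHINFER_MOE_BACKEND"] = "masked_gemm"    # CuteDSL masked-GEMM experts
    os.environ["VLLM_DEEPEPLL_NVFP4_DISPATCH"] = "1"             # new flag from this commit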