feat: BF16 FlashInfer Fused Cutlass MOE for Hopper and Blackwell Expert Parallel (#25503)

Signed-off-by: Duncan Moss <djm.moss@gmail.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>

parent b4a80dad98
commit 461aa1463b
@@ -144,6 +144,7 @@ if TYPE_CHECKING:
     VLLM_USE_DEEP_GEMM_E8M0_HOPPER: bool = False
     VLLM_SKIP_DEEP_GEMM_WARMUP: bool = False
     VLLM_USE_FUSED_MOE_GROUPED_TOPK: bool = True
+    VLLM_USE_FLASHINFER_MOE_FP16: bool = False
     VLLM_USE_FLASHINFER_MOE_FP8: bool = False
     VLLM_USE_FLASHINFER_MOE_FP4: bool = False
     VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput",
@@ -1145,6 +1146,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_USE_FUSED_MOE_GROUPED_TOPK":
     lambda: bool(int(os.getenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "1"))),

     # Allow use of FlashInfer MoE kernels for fused moe ops.
+    "VLLM_USE_FLASHINFER_MOE_FP16":
+    lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP16", "0"))),
+
+    # Allow use of FlashInfer MoE kernels for fused moe ops.
     "VLLM_USE_FLASHINFER_MOE_FP8":
     lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP8", "0"))),
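
The new flag follows the same lazy-getter pattern as the existing FlashInfer switches: the dict maps the variable name to a lambda that parses the raw string at access time, with "0" (off) as the default. A minimal, self-contained sketch of that pattern (the _FLAGS dict here is illustrative, not vLLM's envs module):

import os

# Minimal sketch of the env-flag pattern above (illustrative only):
# the raw string is parsed lazily, so unset/"0" -> False and "1" -> True.
_FLAGS = {
    "VLLM_USE_FLASHINFER_MOE_FP16":
    lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP16", "0"))),
}

if __name__ == "__main__":
    print(_FLAGS["VLLM_USE_FLASHINFER_MOE_FP16"]())  # False when unset
    os.environ["VLLM_USE_FLASHINFER_MOE_FP16"] = "1"
    print(_FLAGS["VLLM_USE_FLASHINFER_MOE_FP16"]())  # True
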
@@ -1516,6 +1521,7 @@ def compute_hash() -> str:
         "VLLM_USE_DEEP_GEMM_E8M0_HOPPER",
         "VLLM_USE_TRTLLM_FP4_GEMM",
         "VLLM_USE_FUSED_MOE_GROUPED_TOPK",
+        "VLLM_USE_FLASHINFER_MOE_FP16",
         "VLLM_USE_FLASHINFER_MOE_FP8",
         "VLLM_USE_FLASHINFER_MOE_FP4",
         "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8",
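
Listing the flag in compute_hash folds its value into the hash vLLM derives from environment factors, so toggling VLLM_USE_FLASHINFER_MOE_FP16 changes the resulting key (e.g. for cached compiled artifacts). A hedged sketch of hashing such a factor list; factor_hash is a hypothetical helper, not vLLM's actual implementation:

import hashlib
import os

# Illustrative sketch (not vLLM's exact compute_hash): hash the values of the
# listed flags so that flipping any one of them changes the resulting key.
def factor_hash(env_names: list[str]) -> str:
    factors = [f"{name}={os.getenv(name, '')}" for name in env_names]
    return hashlib.sha256(";".join(factors).encode()).hexdigest()

print(factor_hash(["VLLM_USE_FLASHINFER_MOE_FP16",
                   "VLLM_USE_FLASHINFER_MOE_FP8"]))
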
@@ -52,8 +52,10 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
         tp_size: int = 1,
     ):
         super().__init__(quant_config)
-        assert quant_config.quant_dtype in ("nvfp4", torch.float8_e4m3fn), (
-            "Only nvfp4,fp8 quantization are currently supported.")
+        assert quant_config.quant_dtype in (
+            "nvfp4", torch.float8_e4m3fn,
+            None), ("Only nvfp4, fp8, bfloat16 and"
+                    " float16 quantization are currently supported.")
         self.ep_rank = ep_rank
         self.ep_size = ep_size
         self.tp_rank = tp_rank
@@ -109,8 +111,9 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
         """
         aq_m, aq_n = aq.shape
         workspace2 = (0, )
-        output_shape = (aq_m, aq_n * 2) if self.quant_dtype != \
-            torch.float8_e4m3fn else (aq_m, aq_n)
+        output_shape = (aq_m,
+                        aq_n * 2) if self.quant_dtype == "nvfp4" else (aq_m,
+                                                                       aq_n)
         workspace_dtype = a.dtype
         workspace1 = output_shape
         # The workspace is determined by `aq`, since it comes after any
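
The reworked shape rule doubles the output width only for nvfp4, where two 4-bit values are packed per stored element; fp8 and the new unquantized path (quant_dtype is None, i.e. bf16/fp16 activations) keep the width of aq unchanged. A small sketch of that rule in isolation (moe_output_shape is an illustrative helper):

import torch

# Sketch of the output-shape rule above: nvfp4 packs two 4-bit values per
# stored element, so the unpacked width is aq_n * 2; fp8 and the unquantized
# (quant_dtype is None) path keep the width unchanged.
def moe_output_shape(aq_shape: tuple[int, int], quant_dtype) -> tuple[int, int]:
    aq_m, aq_n = aq_shape
    return (aq_m, aq_n * 2) if quant_dtype == "nvfp4" else (aq_m, aq_n)

assert moe_output_shape((8, 128), "nvfp4") == (8, 256)
assert moe_output_shape((8, 128), torch.float8_e4m3fn) == (8, 128)
assert moe_output_shape((8, 128), None) == (8, 128)
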
@@ -135,6 +138,10 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
         apply_router_weight_on_input: Optional[bool],
     ):
+
+        assert activation == "silu", ("Only activation silu is supported in "
+                                      "FlashInferExperts")
+
         if self.quant_dtype == torch.float8_e4m3fn:
             quant_scales = [
                 self.g1_alphas, self.a2_gscale, self.g2_alphas, self.a1_gscale
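
The asserted "silu" activation corresponds to the usual SiLU-gated expert MLP. As a point of reference, a plain-PyTorch per-expert computation is sketched below (illustrative only, assuming the common gate/up pairing of w1/w3; the fused kernel computes the same result without materializing these intermediates):

import torch
import torch.nn.functional as F

# Reference per-expert computation for the "silu" activation path, written in
# plain PyTorch for clarity; the fused kernel avoids these intermediates.
def expert_mlp(h: torch.Tensor, w1: torch.Tensor, w3: torch.Tensor,
               w2: torch.Tensor) -> torch.Tensor:
    gate = F.silu(h @ w1.t())      # (tokens, inter)
    up = h @ w3.t()                # (tokens, inter)
    return (gate * up) @ w2.t()    # (tokens, hidden)

h = torch.randn(4, 16)
w1 = torch.randn(32, 16)
w3 = torch.randn(32, 16)
w2 = torch.randn(16, 32)
print(expert_mlp(h, w1, w3, w2).shape)  # torch.Size([4, 16])
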
@@ -143,7 +150,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
             a1q_scale = None  # not passing input_sf in fp8
             fc1_expert_weights = w1
             fc2_expert_weights = w2
-        else:
+        elif self.quant_dtype == "nvfp4":
             # Ensure w1_scale and w2_scale are not None before calling view
             assert self.w1_scale is not None and self.w2_scale is not None, (
                 "w1_scale and w2_scale must not "
@@ -161,6 +168,11 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
             # FlashInfer API requires weight to be long for nvfp4
             fc1_expert_weights = w1.view(torch.long)
             fc2_expert_weights = w2.view(torch.long)
+        else:
+            quant_scales = None
+            a1q_scale = None
+            fc1_expert_weights = w1
+            fc2_expert_weights = w2

         _ = flashinfer_cutlass_fused_moe(
             input=hidden_states,
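
With the new branch, apply dispatches three ways: fp8 passes the per-expert scale factors, nvfp4 additionally reinterprets the packed weights as torch.long for the FlashInfer API, and the unquantized bf16/fp16 path passes the weights through with no scales. The hypothetical helper below mirrors those branches in isolation (it only selects the kernel inputs, it runs no GEMM):

from typing import Optional

import torch

# Hypothetical helper mirroring the three branches above.
def select_moe_inputs(w1: torch.Tensor, w2: torch.Tensor, quant_dtype,
                      scales: Optional[list] = None):
    if quant_dtype == torch.float8_e4m3fn:
        return w1, w2, scales          # fp8: pass per-expert scales
    elif quant_dtype == "nvfp4":
        # FlashInfer expects packed fp4 weights viewed as int64 ("long")
        return w1.view(torch.long), w2.view(torch.long), scales
    else:
        return w1, w2, None            # bf16/fp16: no quant scales

w1 = torch.randn(4, 256, 128, dtype=torch.bfloat16)
w2 = torch.randn(4, 128, 128, dtype=torch.bfloat16)
print(select_moe_inputs(w1, w2, None)[2])  # None: unquantized path
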
@@ -211,3 +223,46 @@ def flashinfer_cutlass_moe_fp4(
         expert_map=expert_map,
         apply_router_weight_on_input=apply_router_weight_on_input,
     )
+
+
+def flashinfer_cutlass_moe(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    quant_config: FusedMoEQuantConfig,
+    inplace: bool = False,
+    activation: str = "silu",
+    global_num_experts: int = -1,
+    expert_map: Optional[torch.Tensor] = None,
+    apply_router_weight_on_input: bool = False,
+    tp_rank: int = 0,
+    tp_size: int = 1,
+    ep_rank: int = 0,
+    ep_size: int = 1,
+    use_dp: bool = False,
+) -> torch.Tensor:
+    fused_experts = mk.FusedMoEModularKernel(
+        create_flashinfer_prepare_finalize(use_dp=use_dp),
+        FlashInferExperts(
+            out_dtype=hidden_states.dtype,
+            quant_config=quant_config,
+            tp_rank=tp_rank,
+            tp_size=tp_size,
+            ep_rank=ep_rank,
+            ep_size=ep_size,
+        ))
+
+    return fused_experts(
+        hidden_states=hidden_states,
+        w1=w1,
+        w2=w2,
+        topk_weights=topk_weights,
+        topk_ids=topk_ids,
+        inplace=inplace,
+        activation=activation,
+        global_num_experts=global_num_experts,
+        expert_map=expert_map,
+        apply_router_weight_on_input=apply_router_weight_on_input,
+    )
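
A hedged usage sketch of the new entry point for the unquantized path. Shapes are illustrative, a Hopper/Blackwell GPU and a FlashInfer build with the CUTLASS fused-MoE kernels are assumed, and the import paths for flashinfer_cutlass_moe and FUSED_MOE_UNQUANTIZED_CONFIG are assumed from the fused_moe package layout rather than confirmed by this diff:

import torch

# Import paths below are assumptions based on the fused_moe package layout.
from vllm.model_executor.layers.fused_moe.config import (
    FUSED_MOE_UNQUANTIZED_CONFIG)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
    flashinfer_cutlass_moe)

num_tokens, hidden, inter, num_experts, topk = 16, 1024, 4096, 8, 2
x = torch.randn(num_tokens, hidden, dtype=torch.bfloat16, device="cuda")
# Per the layer.py change below, the fused w13 weight should be in [w3; w1]
# order; random weights are used here purely to exercise the call.
w13 = torch.randn(num_experts, 2 * inter, hidden,
                  dtype=torch.bfloat16, device="cuda")
w2 = torch.randn(num_experts, hidden, inter,
                 dtype=torch.bfloat16, device="cuda")
topk_weights = torch.rand(num_tokens, topk, device="cuda")
topk_ids = torch.randint(0, num_experts, (num_tokens, topk), device="cuda")

out = flashinfer_cutlass_moe(
    hidden_states=x,
    w1=w13,
    w2=w2,
    topk_weights=topk_weights,
    topk_ids=topk_ids,
    quant_config=FUSED_MOE_UNQUANTIZED_CONFIG,
)
print(out.shape)  # (num_tokens, hidden)
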
@@ -183,7 +183,8 @@ class FlashInferAllGatherMoEPrepareAndFinalize(
                 dim=0,
                 sizes=get_local_sizes(),
             )
-            a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
+            if quant_config.quant_dtype == "nvfp4":
+                a1q_scale = nvfp4_block_scale_interleave(a1q_scale)

         return a1q, a1q_scale, None, topk_ids, topk_weights

@@ -39,6 +39,7 @@ from vllm.platforms import current_platform
 from vllm.platforms.interface import CpuArchEnum
 from vllm.utils import (cdiv, direct_register_custom_op, has_deep_ep, has_pplx,
                         round_up)
+from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
 from vllm.v1.worker.ubatching import dbo_current_ubatch_id

 if current_platform.is_cuda_alike():
@@ -296,6 +297,40 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         else:
             self.rocm_aiter_fused_experts = None  # type: ignore

+        # FlashInfer CUTLASS MoE is only supported on Hopper and later GPUS
+        self.flashinfer_cutlass_moe_enabled = (
+            has_flashinfer_cutlass_fused_moe()
+            and envs.VLLM_USE_FLASHINFER_MOE_FP16
+            and self.moe.moe_parallel_config.use_ep
+            and self.moe.moe_parallel_config.dp_size == 1
+            and current_platform.get_device_capability()[0] >= 9)
+        if self.flashinfer_cutlass_moe_enabled:
+            logger.info_once(
+                "Enabling FlashInfer CUTLASS MoE for UnquantizedFusedMoEMethod"
+            )
+            from functools import partial
+
+            from .flashinfer_cutlass_moe import flashinfer_cutlass_moe
+            self.flashinfer_cutlass_moe = partial(
+                flashinfer_cutlass_moe,
+                quant_config=FUSED_MOE_UNQUANTIZED_CONFIG,
+                tp_rank=self.moe.moe_parallel_config.tp_rank,
+                tp_size=self.moe.moe_parallel_config.tp_size,
+                ep_rank=self.moe.moe_parallel_config.ep_rank,
+                ep_size=self.moe.moe_parallel_config.ep_size)
+        else:
+            if (self.moe.moe_parallel_config.use_ep
+                    and self.moe.moe_parallel_config.dp_size == 1):
+                logger.info_once(
+                    "FlashInfer CUTLASS MoE is available for EP"
+                    " but not enabled, consider setting"
+                    " VLLM_USE_FLASHINFER_MOE_FP16=1 to enable it.")
+            elif self.moe.moe_parallel_config.dp_size > 1:
+                logger.info_once(
+                    "FlashInfer CUTLASS MoE is currently not available for DP."
+                )
+            self.flashinfer_cutlass_moe = None  # type: ignore
+
     def maybe_make_prepare_finalize(
             self) -> Optional[FusedMoEPrepareAndFinalize]:
         if self.rocm_aiter_moe_enabled:
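
The gate above enables the kernel only when every condition holds: the FlashInfer CUTLASS fused-MoE op is importable, VLLM_USE_FLASHINFER_MOE_FP16=1 is set, expert parallelism is in use, the data-parallel size is 1, and the device is SM90 (Hopper) or newer. Restated as a standalone predicate for clarity (flashinfer_cutlass_moe_usable is an illustrative helper; the real code reads these values from the parallel config and platform):

# Illustrative restatement of the gating conditions above.
def flashinfer_cutlass_moe_usable(has_kernel: bool, flag_fp16: bool,
                                  use_ep: bool, dp_size: int,
                                  sm_major: int) -> bool:
    return (has_kernel            # FlashInfer CUTLASS fused-MoE importable
            and flag_fp16         # VLLM_USE_FLASHINFER_MOE_FP16=1
            and use_ep            # expert parallelism enabled
            and dp_size == 1      # not available for DP (yet)
            and sm_major >= 9)    # Hopper (SM90) or Blackwell

# e.g. EP enabled, DP size 1, Hopper, flag set:
assert flashinfer_cutlass_moe_usable(True, True, True, 1, 9)
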
@@ -367,6 +402,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             num_pad = 256 // weight.element_size()
             weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
             torch.cuda.empty_cache()

         return weight

     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
@@ -386,6 +422,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             layer.w13_weight.data = shuffled_w13
             layer.w2_weight.data = shuffled_w2

+        if self.flashinfer_cutlass_moe_enabled:
+            # Swap halves to arrange as [w3; w1] (kernel expectation)
+            w1_w, w3_w = torch.chunk(layer.w13_weight.data, 2, dim=1)
+            w13_weight_swapped = torch.cat([w3_w, w1_w], dim=1)
+            layer.w13_weight.data = w13_weight_swapped.contiguous()
+
         if current_platform.is_xpu():
             import intel_extension_for_pytorch as ipex
             layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
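
vLLM stores the fused gate/up projection with the halves in [w1; w3] order along dim 1, while the FlashInfer CUTLASS kernel expects [w3; w1], so the halves are swapped once after loading. A small numeric sketch of the chunk-and-concat swap:

import torch

# Numeric sketch of the one-time swap above: the fused w13 weight holds the
# two projections stacked along dim 1, and the halves are reordered so the
# kernel sees [w3; w1] instead of [w1; w3].
num_experts, inter, hidden = 2, 4, 3
w1 = torch.full((num_experts, inter, hidden), 1.0)
w3 = torch.full((num_experts, inter, hidden), 3.0)
w13 = torch.cat([w1, w3], dim=1)                    # stored order: [w1; w3]

w1_part, w3_part = torch.chunk(w13, 2, dim=1)
w13_swapped = torch.cat([w3_part, w1_part], dim=1).contiguous()

assert torch.equal(w13_swapped[:, :inter], w3)      # first half is now w3
assert torch.equal(w13_swapped[:, inter:], w1)      # second half is now w1
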
@@ -536,6 +578,15 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                 expert_map=expert_map,
                 activation=activation,
                 apply_router_weight_on_input=apply_router_weight_on_input)
+        elif self.flashinfer_cutlass_moe_enabled:
+            return self.flashinfer_cutlass_moe(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                activation=activation,
+                apply_router_weight_on_input=apply_router_weight_on_input)
         elif self.fused_experts is not None:
             if self.moe.has_bias:
                 raise ValueError(
@@ -598,6 +598,8 @@ class SharedResizableBuffer:

     def get(self, shape: tuple[int, ...], device: torch.device,
             dtype: torch.dtype):
+        if shape == () or shape is None:
+            return None
         shape_numel = prod(shape)
         if (self.buffer is None or self.buffer.numel() < shape_numel
                 or self.buffer.device != device or self.buffer.dtype != dtype):
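
The early return lets callers request an empty workspace shape and get None back without touching the shared buffer; non-empty requests keep reusing one buffer that grows on demand. A minimal sketch of that pattern (ResizableBuffer is a simplified stand-in, not the exact vLLM class):

from math import prod
from typing import Optional

import torch

# Simplified stand-in for a shared workspace buffer: it grows as needed and
# hands out reshaped views; empty/None shapes short-circuit to None, as in
# the change above.
class ResizableBuffer:
    def __init__(self):
        self.buffer: Optional[torch.Tensor] = None

    def get(self, shape, device: torch.device, dtype: torch.dtype):
        if shape == () or shape is None:
            return None
        numel = prod(shape)
        if (self.buffer is None or self.buffer.numel() < numel
                or self.buffer.device != device or self.buffer.dtype != dtype):
            self.buffer = torch.empty(numel, device=device, dtype=dtype)
        return self.buffer[:numel].view(*shape)

buf = ResizableBuffer()
assert buf.get((), torch.device("cpu"), torch.float32) is None
t = buf.get((4, 8), torch.device("cpu"), torch.float32)
assert t.shape == (4, 8)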