mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-08 01:15:43 +08:00
[Kernels] Enable FlashInfer FP8 Blockscale on SM90 (for TEP DSR1) (#27134)
Signed-off-by: Duncan Moss <djm.moss@gmail.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
This commit is contained in:
parent
511a6b611d
commit
3f8a874065
@ -57,6 +57,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
tp_rank: int = 0,
|
tp_rank: int = 0,
|
||||||
tp_size: int = 1,
|
tp_size: int = 1,
|
||||||
use_dp: bool = False,
|
use_dp: bool = False,
|
||||||
|
use_deepseek_fp8_block_scale: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__(quant_config)
|
super().__init__(quant_config)
|
||||||
assert quant_config.quant_dtype in ("nvfp4", torch.float8_e4m3fn, None), (
|
assert quant_config.quant_dtype in ("nvfp4", torch.float8_e4m3fn, None), (
|
||||||
@ -69,6 +70,10 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
self.tp_size = tp_size
|
self.tp_size = tp_size
|
||||||
self.out_dtype = out_dtype
|
self.out_dtype = out_dtype
|
||||||
self.use_dp = use_dp
|
self.use_dp = use_dp
|
||||||
|
# Enables DeepSeek-style FP8 block-scale path:
|
||||||
|
# - pass per-block weight scales to the kernel
|
||||||
|
# - skip input activation quantization (kernel applies scaling)
|
||||||
|
self.use_deepseek_fp8_block_scale = use_deepseek_fp8_block_scale
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def activation_formats(
|
def activation_formats(
|
||||||
@ -147,7 +152,12 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
"Only activation silu is supported in FlashInferExperts"
|
"Only activation silu is supported in FlashInferExperts"
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.quant_dtype == torch.float8_e4m3fn:
|
# Select quantization metadata based on FP8 format/path
|
||||||
|
if (
|
||||||
|
self.quant_dtype == torch.float8_e4m3fn
|
||||||
|
and not self.use_deepseek_fp8_block_scale
|
||||||
|
):
|
||||||
|
# FP8 per-tensor path: use global alphas/scales; do not pass input_sf
|
||||||
quant_scales = [
|
quant_scales = [
|
||||||
self.g1_alphas,
|
self.g1_alphas,
|
||||||
self.a2_gscale,
|
self.a2_gscale,
|
||||||
@ -176,6 +186,15 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
# FlashInfer API requires weight to be long for nvfp4
|
# FlashInfer API requires weight to be long for nvfp4
|
||||||
fc1_expert_weights = w1.view(torch.long)
|
fc1_expert_weights = w1.view(torch.long)
|
||||||
fc2_expert_weights = w2.view(torch.long)
|
fc2_expert_weights = w2.view(torch.long)
|
||||||
|
elif self.use_deepseek_fp8_block_scale:
|
||||||
|
# FP8 block-scale path: pass the weight block scales (w1_scale, w2_scale) and omit a1q_scale
|
||||||
|
quant_scales = [
|
||||||
|
self.w1_scale,
|
||||||
|
self.w2_scale,
|
||||||
|
]
|
||||||
|
a1q_scale = None
|
||||||
|
fc1_expert_weights = w1
|
||||||
|
fc2_expert_weights = w2
|
||||||
else:
|
else:
|
||||||
quant_scales = None
|
quant_scales = None
|
||||||
a1q_scale = None
|
a1q_scale = None
|
||||||
@ -196,6 +215,8 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
ep_size=self.ep_size,
|
ep_size=self.ep_size,
|
||||||
ep_rank=self.ep_rank,
|
ep_rank=self.ep_rank,
|
||||||
output=output,
|
output=output,
|
||||||
|
# Informs FlashInfer to use the block-scale decoding path when True
|
||||||
|
use_deepseek_fp8_block_scale=self.use_deepseek_fp8_block_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -28,11 +28,15 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
self,
|
self,
|
||||||
use_dp: bool,
|
use_dp: bool,
|
||||||
num_dispatchers: int = 1,
|
num_dispatchers: int = 1,
|
||||||
|
use_deepseek_fp8_block_scale: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.num_dispatchers_ = num_dispatchers
|
self.num_dispatchers_ = num_dispatchers
|
||||||
self.use_dp = use_dp
|
self.use_dp = use_dp
|
||||||
self.local_tokens = None
|
self.local_tokens = None
|
||||||
|
# Toggle for DeepSeek-style FP8 block-scale path where activations are
|
||||||
|
# not quantized here and weight block scales are consumed by the kernel.
|
||||||
|
self.use_deepseek_fp8_block_scale = use_deepseek_fp8_block_scale
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def activation_format(self) -> mk.FusedMoEActivationFormat:
|
def activation_format(self) -> mk.FusedMoEActivationFormat:
|
||||||
@ -73,8 +77,9 @@ class FlashInferAllToAllMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFina
|
|||||||
self,
|
self,
|
||||||
use_dp: bool,
|
use_dp: bool,
|
||||||
num_dispatchers: int = 1,
|
num_dispatchers: int = 1,
|
||||||
|
use_deepseek_fp8_block_scale: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__(use_dp, num_dispatchers)
|
super().__init__(use_dp, num_dispatchers, use_deepseek_fp8_block_scale)
|
||||||
self.alltoall_info = None
|
self.alltoall_info = None
|
||||||
|
|
||||||
# Initialize all2all_manager only for DP case
|
# Initialize all2all_manager only for DP case
|
||||||
@ -97,15 +102,19 @@ class FlashInferAllToAllMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFina
|
|||||||
)
|
)
|
||||||
|
|
||||||
if not self.use_dp:
|
if not self.use_dp:
|
||||||
# Non-DP case: standard quantization
|
# Non-DP case: quantize activations unless using block-scale path
|
||||||
a1q, a1q_scale = moe_kernel_quantize_input(
|
if not self.use_deepseek_fp8_block_scale:
|
||||||
a1,
|
a1q, a1q_scale = moe_kernel_quantize_input(
|
||||||
quant_config.a1_gscale,
|
a1,
|
||||||
quant_config.quant_dtype,
|
quant_config.a1_gscale,
|
||||||
quant_config.per_act_token_quant,
|
quant_config.quant_dtype,
|
||||||
quant_config.block_shape,
|
quant_config.per_act_token_quant,
|
||||||
is_fp4_scale_swizzled=not self.use_dp,
|
quant_config.block_shape,
|
||||||
)
|
is_fp4_scale_swizzled=not self.use_dp,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
a1q = a1
|
||||||
|
a1q_scale = None
|
||||||
else:
|
else:
|
||||||
# DP case: use FlashInfer AllToAll
|
# DP case: use FlashInfer AllToAll
|
||||||
global_num_tokens_cpu = get_local_sizes()
|
global_num_tokens_cpu = get_local_sizes()
|
||||||
@ -122,6 +131,7 @@ class FlashInferAllToAllMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFina
|
|||||||
top_k,
|
top_k,
|
||||||
num_experts,
|
num_experts,
|
||||||
quant_config,
|
quant_config,
|
||||||
|
use_deepseek_fp8_block_scale=self.use_deepseek_fp8_block_scale,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -154,8 +164,9 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin
|
|||||||
self,
|
self,
|
||||||
use_dp: bool,
|
use_dp: bool,
|
||||||
num_dispatchers: int = 1,
|
num_dispatchers: int = 1,
|
||||||
|
use_deepseek_fp8_block_scale: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__(use_dp, num_dispatchers)
|
super().__init__(use_dp, num_dispatchers, use_deepseek_fp8_block_scale)
|
||||||
|
|
||||||
def prepare(
|
def prepare(
|
||||||
self,
|
self,
|
||||||
@ -173,22 +184,42 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin
|
|||||||
if not self.use_dp and quant_config.quant_dtype == "nvfp4":
|
if not self.use_dp and quant_config.quant_dtype == "nvfp4":
|
||||||
return a1, None, None, topk_ids, topk_weights
|
return a1, None, None, topk_ids, topk_weights
|
||||||
|
|
||||||
a1q, a1q_scale = moe_kernel_quantize_input(
|
if not self.use_deepseek_fp8_block_scale:
|
||||||
a1,
|
a1q, a1q_scale = moe_kernel_quantize_input(
|
||||||
quant_config.a1_gscale,
|
a1,
|
||||||
quant_config.quant_dtype,
|
quant_config.a1_gscale,
|
||||||
quant_config.per_act_token_quant,
|
quant_config.quant_dtype,
|
||||||
quant_config.block_shape,
|
quant_config.per_act_token_quant,
|
||||||
is_fp4_scale_swizzled=not self.use_dp,
|
quant_config.block_shape,
|
||||||
)
|
is_fp4_scale_swizzled=not self.use_dp,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Block-scale path: pass activations through, omit per-token scales
|
||||||
|
a1q = a1
|
||||||
|
a1q_scale = None
|
||||||
|
|
||||||
if self.use_dp:
|
if self.use_dp:
|
||||||
topk_weights, topk_ids, a1q, a1q_scale = get_dp_group().all_gatherv(
|
# Build gather list conditionally - omit a1q_scale if None
|
||||||
[topk_weights, topk_ids, a1q, a1q_scale],
|
# (block-scale path)
|
||||||
dim=0,
|
gather_list = [topk_weights, topk_ids, a1q]
|
||||||
sizes=get_local_sizes(),
|
if a1q_scale is not None:
|
||||||
)
|
gather_list.append(a1q_scale)
|
||||||
if quant_config.quant_dtype == "nvfp4":
|
gathered = get_dp_group().all_gatherv(
|
||||||
|
gather_list,
|
||||||
|
dim=0,
|
||||||
|
sizes=get_local_sizes(),
|
||||||
|
)
|
||||||
|
topk_weights, topk_ids, a1q, a1q_scale = gathered
|
||||||
|
else:
|
||||||
|
gathered = get_dp_group().all_gatherv(
|
||||||
|
gather_list,
|
||||||
|
dim=0,
|
||||||
|
sizes=get_local_sizes(),
|
||||||
|
)
|
||||||
|
topk_weights, topk_ids, a1q = gathered
|
||||||
|
a1q_scale = None
|
||||||
|
|
||||||
|
if quant_config.quant_dtype == "nvfp4" and a1q_scale is not None:
|
||||||
a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
|
a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
|
||||||
|
|
||||||
return a1q, a1q_scale, None, topk_ids, topk_weights
|
return a1q, a1q_scale, None, topk_ids, topk_weights
|
||||||
@ -221,6 +252,7 @@ def flashinfer_alltoall_dispatch(
|
|||||||
top_k: int,
|
top_k: int,
|
||||||
num_experts: int,
|
num_experts: int,
|
||||||
quant_config: FusedMoEQuantConfig,
|
quant_config: FusedMoEQuantConfig,
|
||||||
|
use_deepseek_fp8_block_scale: bool = False,
|
||||||
):
|
):
|
||||||
from flashinfer.comm.trtllm_alltoall import MnnvlMoe
|
from flashinfer.comm.trtllm_alltoall import MnnvlMoe
|
||||||
|
|
||||||
@ -250,30 +282,42 @@ def flashinfer_alltoall_dispatch(
|
|||||||
)
|
)
|
||||||
topk_weights = topk_weights.view(dtype=orig_topk_weights_dtype)
|
topk_weights = topk_weights.view(dtype=orig_topk_weights_dtype)
|
||||||
|
|
||||||
x, x_sf = moe_kernel_quantize_input(
|
if not use_deepseek_fp8_block_scale:
|
||||||
x,
|
x, x_sf = moe_kernel_quantize_input(
|
||||||
gs,
|
x,
|
||||||
quant_config.quant_dtype,
|
gs,
|
||||||
quant_config.per_act_token_quant,
|
quant_config.quant_dtype,
|
||||||
quant_config.block_shape,
|
quant_config.per_act_token_quant,
|
||||||
is_fp4_scale_swizzled=False, # delay swizzle to after comm
|
quant_config.block_shape,
|
||||||
)
|
is_fp4_scale_swizzled=False, # delay swizzle to after comm
|
||||||
x = MnnvlMoe.mnnvl_moe_alltoallv(
|
)
|
||||||
x,
|
x = MnnvlMoe.mnnvl_moe_alltoallv(
|
||||||
alltoall_info,
|
x,
|
||||||
all2all_manager.workspace_tensor,
|
alltoall_info,
|
||||||
ep_rank,
|
all2all_manager.workspace_tensor,
|
||||||
ep_size,
|
ep_rank,
|
||||||
)
|
ep_size,
|
||||||
|
)
|
||||||
|
|
||||||
x_sf = MnnvlMoe.mnnvl_moe_alltoallv(
|
x_sf = MnnvlMoe.mnnvl_moe_alltoallv(
|
||||||
x_sf,
|
x_sf,
|
||||||
alltoall_info,
|
alltoall_info,
|
||||||
all2all_manager.workspace_tensor,
|
all2all_manager.workspace_tensor,
|
||||||
ep_rank,
|
ep_rank,
|
||||||
ep_size,
|
ep_size,
|
||||||
)
|
)
|
||||||
x_sf = nvfp4_block_scale_interleave(x_sf)
|
if quant_config.quant_dtype == "nvfp4":
|
||||||
|
x_sf = nvfp4_block_scale_interleave(x_sf)
|
||||||
|
else:
|
||||||
|
# Block-scale path: pass activations through without quantization
|
||||||
|
x_sf = None
|
||||||
|
x = MnnvlMoe.mnnvl_moe_alltoallv(
|
||||||
|
x,
|
||||||
|
alltoall_info,
|
||||||
|
all2all_manager.workspace_tensor,
|
||||||
|
ep_rank,
|
||||||
|
ep_size,
|
||||||
|
)
|
||||||
return alltoall_info, topk_ids, topk_weights, x, x_sf
|
return alltoall_info, topk_ids, topk_weights, x, x_sf
|
||||||
|
|
||||||
|
|
||||||
@ -304,6 +348,7 @@ def create_flashinfer_prepare_finalize(
|
|||||||
use_dp: bool,
|
use_dp: bool,
|
||||||
use_nvfp4: bool = False,
|
use_nvfp4: bool = False,
|
||||||
enable_alltoallv: bool = False,
|
enable_alltoallv: bool = False,
|
||||||
|
use_deepseek_fp8_block_scale: bool = False,
|
||||||
) -> FlashInferCutlassMoEPrepareAndFinalize:
|
) -> FlashInferCutlassMoEPrepareAndFinalize:
|
||||||
"""Factory function to create the appropriate FlashInfer implementation."""
|
"""Factory function to create the appropriate FlashInfer implementation."""
|
||||||
if use_nvfp4:
|
if use_nvfp4:
|
||||||
@ -311,5 +356,7 @@ def create_flashinfer_prepare_finalize(
|
|||||||
return FlashInferAllToAllMoEPrepareAndFinalize(use_dp)
|
return FlashInferAllToAllMoEPrepareAndFinalize(use_dp)
|
||||||
else:
|
else:
|
||||||
return FlashInferAllGatherMoEPrepareAndFinalize(use_dp)
|
return FlashInferAllGatherMoEPrepareAndFinalize(use_dp)
|
||||||
# Fp8 only supports AllGather
|
# FP8 path currently supported via AllGather; optionally enable block-scale
|
||||||
return FlashInferAllGatherMoEPrepareAndFinalize(use_dp)
|
return FlashInferAllGatherMoEPrepareAndFinalize(
|
||||||
|
use_dp=use_dp, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale
|
||||||
|
)
|
||||||
|
|||||||
@ -3,6 +3,7 @@
|
|||||||
|
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from functools import partial
|
||||||
from typing import TYPE_CHECKING, Any, Optional
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -122,10 +123,13 @@ def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
|
|||||||
Select the primary FP8 MoE backend
|
Select the primary FP8 MoE backend
|
||||||
Note: Shape-specific fallbacks may still occur at runtime.
|
Note: Shape-specific fallbacks may still occur at runtime.
|
||||||
"""
|
"""
|
||||||
# prefer FlashInfer backends when available and enabled on supported GPUs
|
# Prefer FlashInfer backends on supported GPUs; allow SM90 and SM100.
|
||||||
if (
|
if (
|
||||||
current_platform.is_cuda()
|
current_platform.is_cuda()
|
||||||
and current_platform.is_device_capability(100)
|
and (
|
||||||
|
current_platform.is_device_capability(100)
|
||||||
|
or current_platform.is_device_capability(90)
|
||||||
|
)
|
||||||
and envs.VLLM_USE_FLASHINFER_MOE_FP8
|
and envs.VLLM_USE_FLASHINFER_MOE_FP8
|
||||||
and has_flashinfer_moe()
|
and has_flashinfer_moe()
|
||||||
):
|
):
|
||||||
@ -134,14 +138,14 @@ def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
|
|||||||
logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100")
|
logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100")
|
||||||
return Fp8MoeBackend.FLASHINFER_TRTLLM
|
return Fp8MoeBackend.FLASHINFER_TRTLLM
|
||||||
else:
|
else:
|
||||||
if block_quant:
|
if block_quant and current_platform.is_device_capability(100):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"FlashInfer FP8 MoE throughput backend does not "
|
"FlashInfer FP8 MoE throughput backend does not "
|
||||||
"support block quantization. Please use "
|
"support block quantization. Please use "
|
||||||
"VLLM_FLASHINFER_MOE_BACKEND=latency "
|
"VLLM_FLASHINFER_MOE_BACKEND=latency "
|
||||||
"instead."
|
"instead."
|
||||||
)
|
)
|
||||||
logger.info_once("Using FlashInfer FP8 MoE CUTLASS backend for SM100")
|
logger.info_once("Using FlashInfer FP8 MoE CUTLASS backend for SM90/SM100")
|
||||||
return Fp8MoeBackend.FLASHINFER_CUTLASS
|
return Fp8MoeBackend.FLASHINFER_CUTLASS
|
||||||
|
|
||||||
# weight-only path for older GPUs without native FP8
|
# weight-only path for older GPUs without native FP8
|
||||||
@ -641,6 +645,16 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
|
self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
|
||||||
elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
|
elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
|
||||||
self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
|
self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
|
||||||
|
if self.block_quant:
|
||||||
|
assert self.weight_block_size == [128, 128], (
|
||||||
|
f"Only support weight_block_size == [128, 128], "
|
||||||
|
f"got {self.weight_block_size}"
|
||||||
|
)
|
||||||
|
self.flashinfer_moe_fn = partial(
|
||||||
|
flashinfer_cutlass_moe_fp8,
|
||||||
|
moe=self.moe,
|
||||||
|
use_deepseek_fp8_block_scale=self.block_quant,
|
||||||
|
)
|
||||||
|
|
||||||
self.allow_deep_gemm = self.fp8_backend == Fp8MoeBackend.DEEPGEMM
|
self.allow_deep_gemm = self.fp8_backend == Fp8MoeBackend.DEEPGEMM
|
||||||
self.allow_cutlass_block_scaled_grouped_gemm = (
|
self.allow_cutlass_block_scaled_grouped_gemm = (
|
||||||
@ -1012,8 +1026,15 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
):
|
):
|
||||||
return None
|
return None
|
||||||
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
|
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
|
||||||
|
if self.block_quant:
|
||||||
|
assert self.weight_block_size == [128, 128], (
|
||||||
|
f"Only support weight_block_size == [128, 128], "
|
||||||
|
f"got {self.weight_block_size}"
|
||||||
|
)
|
||||||
|
# Wire block-scale flag through prepare/finalize when using CUTLASS
|
||||||
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
|
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
|
||||||
self.moe
|
self.moe,
|
||||||
|
use_deepseek_fp8_block_scale=self.block_quant,
|
||||||
)
|
)
|
||||||
logger.debug_once("%s", prepare_finalize.__class__.__name__)
|
logger.debug_once("%s", prepare_finalize.__class__.__name__)
|
||||||
return prepare_finalize
|
return prepare_finalize
|
||||||
@ -1062,9 +1083,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
|
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
|
||||||
|
# Select GEMM experts with block-scale when weights are block-quantized
|
||||||
experts = select_cutlass_fp8_gemm_impl(
|
experts = select_cutlass_fp8_gemm_impl(
|
||||||
self.moe,
|
self.moe,
|
||||||
self.moe_quant_config,
|
self.moe_quant_config,
|
||||||
|
use_deepseek_fp8_block_scale=self.block_quant,
|
||||||
)
|
)
|
||||||
logger.debug_once("Using %s", experts.__class__.__name__)
|
logger.debug_once("Using %s", experts.__class__.__name__)
|
||||||
return experts
|
return experts
|
||||||
@ -1251,16 +1274,17 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
workspace=layer.workspace,
|
workspace=layer.workspace,
|
||||||
)
|
)
|
||||||
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
|
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
|
||||||
assert not self.block_quant
|
|
||||||
assert not renormalize and custom_routing_function is not None
|
|
||||||
assert activation == "silu", (
|
assert activation == "silu", (
|
||||||
f"Expected 'silu' activation but got {activation}"
|
f"Expected 'silu' activation but got {activation}"
|
||||||
)
|
)
|
||||||
assert scoring_func == "sigmoid", (
|
if not self.block_quant:
|
||||||
f"Expected 'sigmoid' scoring func but got {scoring_func}"
|
assert not renormalize and custom_routing_function is not None
|
||||||
)
|
assert scoring_func == "sigmoid", (
|
||||||
|
f"Expected 'sigmoid' scoring func but got {scoring_func}"
|
||||||
result = flashinfer_cutlass_moe_fp8(
|
)
|
||||||
|
# Delegate to CUTLASS FlashInfer path; function already bound with
|
||||||
|
# use_deepseek_fp8_block_scale for block-quant when applicable
|
||||||
|
result = self.flashinfer_moe_fn(
|
||||||
x,
|
x,
|
||||||
layer,
|
layer,
|
||||||
topk_weights,
|
topk_weights,
|
||||||
|
|||||||
@ -17,6 +17,7 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
|
|||||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
|
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
|
||||||
create_flashinfer_prepare_finalize,
|
create_flashinfer_prepare_finalize,
|
||||||
)
|
)
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -190,17 +191,22 @@ def register_moe_scaling_factors(layer: torch.nn.Module) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def build_flashinfer_fp8_cutlass_moe_prepare_finalize(
|
def build_flashinfer_fp8_cutlass_moe_prepare_finalize(
|
||||||
moe: FusedMoEConfig | None,
|
moe: FusedMoEConfig | None, use_deepseek_fp8_block_scale: bool = False
|
||||||
) -> mk.FusedMoEPrepareAndFinalize:
|
) -> mk.FusedMoEPrepareAndFinalize:
|
||||||
"""Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel"""
|
"""Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel"""
|
||||||
use_dp = moe.moe_parallel_config.dp_size > 1 if moe is not None else False
|
use_dp = moe.moe_parallel_config.dp_size > 1 if moe is not None else False
|
||||||
return create_flashinfer_prepare_finalize(use_dp)
|
# Propagate block-scale flag so prepare/finalize can skip act quantization
|
||||||
|
# and inform the kernel to consume per-block weight scales.
|
||||||
|
return create_flashinfer_prepare_finalize(
|
||||||
|
use_dp, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def select_cutlass_fp8_gemm_impl(
|
def select_cutlass_fp8_gemm_impl(
|
||||||
moe: FusedMoEConfig | None,
|
moe: FusedMoEConfig | None,
|
||||||
quant_config: FusedMoEQuantConfig,
|
quant_config: FusedMoEQuantConfig,
|
||||||
out_dtype: torch.dtype | None = None,
|
out_dtype: torch.dtype | None = None,
|
||||||
|
use_deepseek_fp8_block_scale: bool = False,
|
||||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||||
"""Return a GEMM *experts* implementation for fused-MoE layers"""
|
"""Return a GEMM *experts* implementation for fused-MoE layers"""
|
||||||
|
|
||||||
@ -212,12 +218,14 @@ def select_cutlass_fp8_gemm_impl(
|
|||||||
ep_size=moe.moe_parallel_config.ep_size,
|
ep_size=moe.moe_parallel_config.ep_size,
|
||||||
tp_rank=moe.moe_parallel_config.tp_rank,
|
tp_rank=moe.moe_parallel_config.tp_rank,
|
||||||
tp_size=moe.moe_parallel_config.tp_size,
|
tp_size=moe.moe_parallel_config.tp_size,
|
||||||
|
use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert out_dtype is not None, "If moe config is None, out_dtype must be passed"
|
assert out_dtype is not None, "If moe config is None, out_dtype must be passed"
|
||||||
return FlashInferExperts(
|
return FlashInferExperts(
|
||||||
out_dtype=out_dtype,
|
out_dtype=out_dtype,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
|
use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -231,14 +239,22 @@ def flashinfer_cutlass_moe_fp8(
|
|||||||
global_num_experts: int = -1,
|
global_num_experts: int = -1,
|
||||||
expert_map: torch.Tensor | None = None,
|
expert_map: torch.Tensor | None = None,
|
||||||
apply_router_weight_on_input: bool = False,
|
apply_router_weight_on_input: bool = False,
|
||||||
|
use_deepseek_fp8_block_scale: bool = False,
|
||||||
|
moe: FusedMoEConfig | None = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
quant_config = layer.quant_method.get_fused_moe_quant_config(layer)
|
quant_config = layer.quant_method.get_fused_moe_quant_config(layer)
|
||||||
assert quant_config is not None
|
assert quant_config is not None
|
||||||
|
|
||||||
|
# Construct modular kernel with block-scale support when requested.
|
||||||
fused_experts = mk.FusedMoEModularKernel(
|
fused_experts = mk.FusedMoEModularKernel(
|
||||||
build_flashinfer_fp8_cutlass_moe_prepare_finalize(moe=None),
|
build_flashinfer_fp8_cutlass_moe_prepare_finalize(
|
||||||
|
moe=moe, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale
|
||||||
|
),
|
||||||
select_cutlass_fp8_gemm_impl(
|
select_cutlass_fp8_gemm_impl(
|
||||||
moe=None, quant_config=quant_config, out_dtype=hidden_states.dtype
|
moe=moe,
|
||||||
|
quant_config=quant_config,
|
||||||
|
out_dtype=hidden_states.dtype,
|
||||||
|
use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -258,7 +274,10 @@ def flashinfer_cutlass_moe_fp8(
|
|||||||
|
|
||||||
def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
|
def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
|
||||||
flashinfer_moe_backend = envs.VLLM_FLASHINFER_MOE_BACKEND
|
flashinfer_moe_backend = envs.VLLM_FLASHINFER_MOE_BACKEND
|
||||||
if flashinfer_moe_backend == "throughput":
|
# SM90 always selects CUTLASS, whatever VLLM_FLASHINFER_MOE_BACKEND is set to;
# on other devices CUTLASS is selected only when the setting is "throughput".
|
||||||
|
if flashinfer_moe_backend == "throughput" or current_platform.is_device_capability(
|
||||||
|
90
|
||||||
|
):
|
||||||
return FlashinferMoeBackend.CUTLASS
|
return FlashinferMoeBackend.CUTLASS
|
||||||
elif flashinfer_moe_backend == "latency":
|
elif flashinfer_moe_backend == "latency":
|
||||||
return FlashinferMoeBackend.TENSORRT_LLM
|
return FlashinferMoeBackend.TENSORRT_LLM
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user