mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-28 10:07:14 +08:00
[Kernels] Enable FlashInfer FP8 Blockscale on SM90 (for TEP DSR1) (#27134)
Signed-off-by: Duncan Moss <djm.moss@gmail.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
This commit is contained in:
parent
511a6b611d
commit
3f8a874065
@ -57,6 +57,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
tp_rank: int = 0,
|
||||
tp_size: int = 1,
|
||||
use_dp: bool = False,
|
||||
use_deepseek_fp8_block_scale: bool = False,
|
||||
):
|
||||
super().__init__(quant_config)
|
||||
assert quant_config.quant_dtype in ("nvfp4", torch.float8_e4m3fn, None), (
|
||||
@ -69,6 +70,10 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
self.tp_size = tp_size
|
||||
self.out_dtype = out_dtype
|
||||
self.use_dp = use_dp
|
||||
# Enables DeepSeek-style FP8 block-scale path:
|
||||
# - pass per-block weight scales to the kernel
|
||||
# - skip input activation quantization (kernel applies scaling)
|
||||
self.use_deepseek_fp8_block_scale = use_deepseek_fp8_block_scale
|
||||
|
||||
@property
|
||||
def activation_formats(
|
||||
@ -147,7 +152,12 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
"Only activation silu is supported in FlashInferExperts"
|
||||
)
|
||||
|
||||
if self.quant_dtype == torch.float8_e4m3fn:
|
||||
# Select quantization metadata based on FP8 format/path
|
||||
if (
|
||||
self.quant_dtype == torch.float8_e4m3fn
|
||||
and not self.use_deepseek_fp8_block_scale
|
||||
):
|
||||
# FP8 per-tensor path: use global alphas/scales; do not pass input_sf
|
||||
quant_scales = [
|
||||
self.g1_alphas,
|
||||
self.a2_gscale,
|
||||
@ -176,6 +186,15 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
# FlashInfer API requires weight to be long for nvfp4
|
||||
fc1_expert_weights = w1.view(torch.long)
|
||||
fc2_expert_weights = w2.view(torch.long)
|
||||
elif self.use_deepseek_fp8_block_scale:
|
||||
# FP8 block-scale path: provide block-scale weights, omit a1q_scale
|
||||
quant_scales = [
|
||||
self.w1_scale,
|
||||
self.w2_scale,
|
||||
]
|
||||
a1q_scale = None
|
||||
fc1_expert_weights = w1
|
||||
fc2_expert_weights = w2
|
||||
else:
|
||||
quant_scales = None
|
||||
a1q_scale = None
|
||||
@ -196,6 +215,8 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
ep_size=self.ep_size,
|
||||
ep_rank=self.ep_rank,
|
||||
output=output,
|
||||
# Informs FlashInfer to use the block-scale decoding path when True
|
||||
use_deepseek_fp8_block_scale=self.use_deepseek_fp8_block_scale,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -28,11 +28,15 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
self,
|
||||
use_dp: bool,
|
||||
num_dispatchers: int = 1,
|
||||
use_deepseek_fp8_block_scale: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_dispatchers_ = num_dispatchers
|
||||
self.use_dp = use_dp
|
||||
self.local_tokens = None
|
||||
# Toggle for DeepSeek-style FP8 block-scale path where activations are
|
||||
# not quantized here and weight block scales are consumed by the kernel.
|
||||
self.use_deepseek_fp8_block_scale = use_deepseek_fp8_block_scale
|
||||
|
||||
@property
|
||||
def activation_format(self) -> mk.FusedMoEActivationFormat:
|
||||
@ -73,8 +77,9 @@ class FlashInferAllToAllMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFina
|
||||
self,
|
||||
use_dp: bool,
|
||||
num_dispatchers: int = 1,
|
||||
use_deepseek_fp8_block_scale: bool = False,
|
||||
):
|
||||
super().__init__(use_dp, num_dispatchers)
|
||||
super().__init__(use_dp, num_dispatchers, use_deepseek_fp8_block_scale)
|
||||
self.alltoall_info = None
|
||||
|
||||
# Initialize all2all_manager only for DP case
|
||||
@ -97,15 +102,19 @@ class FlashInferAllToAllMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFina
|
||||
)
|
||||
|
||||
if not self.use_dp:
|
||||
# Non-DP case: standard quantization
|
||||
a1q, a1q_scale = moe_kernel_quantize_input(
|
||||
a1,
|
||||
quant_config.a1_gscale,
|
||||
quant_config.quant_dtype,
|
||||
quant_config.per_act_token_quant,
|
||||
quant_config.block_shape,
|
||||
is_fp4_scale_swizzled=not self.use_dp,
|
||||
)
|
||||
# Non-DP case: quantize activations unless using block-scale path
|
||||
if not self.use_deepseek_fp8_block_scale:
|
||||
a1q, a1q_scale = moe_kernel_quantize_input(
|
||||
a1,
|
||||
quant_config.a1_gscale,
|
||||
quant_config.quant_dtype,
|
||||
quant_config.per_act_token_quant,
|
||||
quant_config.block_shape,
|
||||
is_fp4_scale_swizzled=not self.use_dp,
|
||||
)
|
||||
else:
|
||||
a1q = a1
|
||||
a1q_scale = None
|
||||
else:
|
||||
# DP case: use FlashInfer AllToAll
|
||||
global_num_tokens_cpu = get_local_sizes()
|
||||
@ -122,6 +131,7 @@ class FlashInferAllToAllMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFina
|
||||
top_k,
|
||||
num_experts,
|
||||
quant_config,
|
||||
use_deepseek_fp8_block_scale=self.use_deepseek_fp8_block_scale,
|
||||
)
|
||||
)
|
||||
|
||||
@ -154,8 +164,9 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin
|
||||
self,
|
||||
use_dp: bool,
|
||||
num_dispatchers: int = 1,
|
||||
use_deepseek_fp8_block_scale: bool = False,
|
||||
):
|
||||
super().__init__(use_dp, num_dispatchers)
|
||||
super().__init__(use_dp, num_dispatchers, use_deepseek_fp8_block_scale)
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
@ -173,22 +184,42 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin
|
||||
if not self.use_dp and quant_config.quant_dtype == "nvfp4":
|
||||
return a1, None, None, topk_ids, topk_weights
|
||||
|
||||
a1q, a1q_scale = moe_kernel_quantize_input(
|
||||
a1,
|
||||
quant_config.a1_gscale,
|
||||
quant_config.quant_dtype,
|
||||
quant_config.per_act_token_quant,
|
||||
quant_config.block_shape,
|
||||
is_fp4_scale_swizzled=not self.use_dp,
|
||||
)
|
||||
if not self.use_deepseek_fp8_block_scale:
|
||||
a1q, a1q_scale = moe_kernel_quantize_input(
|
||||
a1,
|
||||
quant_config.a1_gscale,
|
||||
quant_config.quant_dtype,
|
||||
quant_config.per_act_token_quant,
|
||||
quant_config.block_shape,
|
||||
is_fp4_scale_swizzled=not self.use_dp,
|
||||
)
|
||||
else:
|
||||
# Block-scale path: pass activations through, omit per-token scales
|
||||
a1q = a1
|
||||
a1q_scale = None
|
||||
|
||||
if self.use_dp:
|
||||
topk_weights, topk_ids, a1q, a1q_scale = get_dp_group().all_gatherv(
|
||||
[topk_weights, topk_ids, a1q, a1q_scale],
|
||||
dim=0,
|
||||
sizes=get_local_sizes(),
|
||||
)
|
||||
if quant_config.quant_dtype == "nvfp4":
|
||||
# Build gather list conditionally - omit a1q_scale if None
|
||||
# (block-scale path)
|
||||
gather_list = [topk_weights, topk_ids, a1q]
|
||||
if a1q_scale is not None:
|
||||
gather_list.append(a1q_scale)
|
||||
gathered = get_dp_group().all_gatherv(
|
||||
gather_list,
|
||||
dim=0,
|
||||
sizes=get_local_sizes(),
|
||||
)
|
||||
topk_weights, topk_ids, a1q, a1q_scale = gathered
|
||||
else:
|
||||
gathered = get_dp_group().all_gatherv(
|
||||
gather_list,
|
||||
dim=0,
|
||||
sizes=get_local_sizes(),
|
||||
)
|
||||
topk_weights, topk_ids, a1q = gathered
|
||||
a1q_scale = None
|
||||
|
||||
if quant_config.quant_dtype == "nvfp4" and a1q_scale is not None:
|
||||
a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
|
||||
|
||||
return a1q, a1q_scale, None, topk_ids, topk_weights
|
||||
@ -221,6 +252,7 @@ def flashinfer_alltoall_dispatch(
|
||||
top_k: int,
|
||||
num_experts: int,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
use_deepseek_fp8_block_scale: bool = False,
|
||||
):
|
||||
from flashinfer.comm.trtllm_alltoall import MnnvlMoe
|
||||
|
||||
@ -250,30 +282,42 @@ def flashinfer_alltoall_dispatch(
|
||||
)
|
||||
topk_weights = topk_weights.view(dtype=orig_topk_weights_dtype)
|
||||
|
||||
x, x_sf = moe_kernel_quantize_input(
|
||||
x,
|
||||
gs,
|
||||
quant_config.quant_dtype,
|
||||
quant_config.per_act_token_quant,
|
||||
quant_config.block_shape,
|
||||
is_fp4_scale_swizzled=False, # delay swizzle to after comm
|
||||
)
|
||||
x = MnnvlMoe.mnnvl_moe_alltoallv(
|
||||
x,
|
||||
alltoall_info,
|
||||
all2all_manager.workspace_tensor,
|
||||
ep_rank,
|
||||
ep_size,
|
||||
)
|
||||
if not use_deepseek_fp8_block_scale:
|
||||
x, x_sf = moe_kernel_quantize_input(
|
||||
x,
|
||||
gs,
|
||||
quant_config.quant_dtype,
|
||||
quant_config.per_act_token_quant,
|
||||
quant_config.block_shape,
|
||||
is_fp4_scale_swizzled=False, # delay swizzle to after comm
|
||||
)
|
||||
x = MnnvlMoe.mnnvl_moe_alltoallv(
|
||||
x,
|
||||
alltoall_info,
|
||||
all2all_manager.workspace_tensor,
|
||||
ep_rank,
|
||||
ep_size,
|
||||
)
|
||||
|
||||
x_sf = MnnvlMoe.mnnvl_moe_alltoallv(
|
||||
x_sf,
|
||||
alltoall_info,
|
||||
all2all_manager.workspace_tensor,
|
||||
ep_rank,
|
||||
ep_size,
|
||||
)
|
||||
x_sf = nvfp4_block_scale_interleave(x_sf)
|
||||
x_sf = MnnvlMoe.mnnvl_moe_alltoallv(
|
||||
x_sf,
|
||||
alltoall_info,
|
||||
all2all_manager.workspace_tensor,
|
||||
ep_rank,
|
||||
ep_size,
|
||||
)
|
||||
if quant_config.quant_dtype == "nvfp4":
|
||||
x_sf = nvfp4_block_scale_interleave(x_sf)
|
||||
else:
|
||||
# Block-scale path: pass activations through without quantization
|
||||
x_sf = None
|
||||
x = MnnvlMoe.mnnvl_moe_alltoallv(
|
||||
x,
|
||||
alltoall_info,
|
||||
all2all_manager.workspace_tensor,
|
||||
ep_rank,
|
||||
ep_size,
|
||||
)
|
||||
return alltoall_info, topk_ids, topk_weights, x, x_sf
|
||||
|
||||
|
||||
@ -304,6 +348,7 @@ def create_flashinfer_prepare_finalize(
|
||||
use_dp: bool,
|
||||
use_nvfp4: bool = False,
|
||||
enable_alltoallv: bool = False,
|
||||
use_deepseek_fp8_block_scale: bool = False,
|
||||
) -> FlashInferCutlassMoEPrepareAndFinalize:
|
||||
"""Factory function to create the appropriate FlashInfer implementation."""
|
||||
if use_nvfp4:
|
||||
@ -311,5 +356,7 @@ def create_flashinfer_prepare_finalize(
|
||||
return FlashInferAllToAllMoEPrepareAndFinalize(use_dp)
|
||||
else:
|
||||
return FlashInferAllGatherMoEPrepareAndFinalize(use_dp)
|
||||
# Fp8 only supports AllGather
|
||||
return FlashInferAllGatherMoEPrepareAndFinalize(use_dp)
|
||||
# FP8 path currently supported via AllGather; optionally enable block-scale
|
||||
return FlashInferAllGatherMoEPrepareAndFinalize(
|
||||
use_dp=use_dp, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale
|
||||
)
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
|
||||
from collections.abc import Callable
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import torch
|
||||
@ -122,10 +123,13 @@ def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
|
||||
Select the primary FP8 MoE backend
|
||||
Note: Shape-specific fallbacks may still occur at runtime.
|
||||
"""
|
||||
# prefer FlashInfer backends when available and enabled on supported GPUs
|
||||
# Prefer FlashInfer backends on supported GPUs; allow SM90 and SM100.
|
||||
if (
|
||||
current_platform.is_cuda()
|
||||
and current_platform.is_device_capability(100)
|
||||
and (
|
||||
current_platform.is_device_capability(100)
|
||||
or current_platform.is_device_capability(90)
|
||||
)
|
||||
and envs.VLLM_USE_FLASHINFER_MOE_FP8
|
||||
and has_flashinfer_moe()
|
||||
):
|
||||
@ -134,14 +138,14 @@ def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
|
||||
logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100")
|
||||
return Fp8MoeBackend.FLASHINFER_TRTLLM
|
||||
else:
|
||||
if block_quant:
|
||||
if block_quant and current_platform.is_device_capability(100):
|
||||
raise ValueError(
|
||||
"FlashInfer FP8 MoE throughput backend does not "
|
||||
"support block quantization. Please use "
|
||||
"VLLM_FLASHINFER_MOE_BACKEND=latency "
|
||||
"instead."
|
||||
)
|
||||
logger.info_once("Using FlashInfer FP8 MoE CUTLASS backend for SM100")
|
||||
logger.info_once("Using FlashInfer FP8 MoE CUTLASS backend for SM90/SM100")
|
||||
return Fp8MoeBackend.FLASHINFER_CUTLASS
|
||||
|
||||
# weight-only path for older GPUs without native FP8
|
||||
@ -641,6 +645,16 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
|
||||
elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
|
||||
self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
|
||||
if self.block_quant:
|
||||
assert self.weight_block_size == [128, 128], (
|
||||
f"Only support weight_block_size == [128, 128], "
|
||||
f"got {self.weight_block_size}"
|
||||
)
|
||||
self.flashinfer_moe_fn = partial(
|
||||
flashinfer_cutlass_moe_fp8,
|
||||
moe=self.moe,
|
||||
use_deepseek_fp8_block_scale=self.block_quant,
|
||||
)
|
||||
|
||||
self.allow_deep_gemm = self.fp8_backend == Fp8MoeBackend.DEEPGEMM
|
||||
self.allow_cutlass_block_scaled_grouped_gemm = (
|
||||
@ -1012,8 +1026,15 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
):
|
||||
return None
|
||||
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
|
||||
if self.block_quant:
|
||||
assert self.weight_block_size == [128, 128], (
|
||||
f"Only support weight_block_size == [128, 128], "
|
||||
f"got {self.weight_block_size}"
|
||||
)
|
||||
# Wire block-scale flag through prepare/finalize when using CUTLASS
|
||||
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
|
||||
self.moe
|
||||
self.moe,
|
||||
use_deepseek_fp8_block_scale=self.block_quant,
|
||||
)
|
||||
logger.debug_once("%s", prepare_finalize.__class__.__name__)
|
||||
return prepare_finalize
|
||||
@ -1062,9 +1083,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
|
||||
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
|
||||
# Select GEMM experts with block-scale when weights are block-quantized
|
||||
experts = select_cutlass_fp8_gemm_impl(
|
||||
self.moe,
|
||||
self.moe_quant_config,
|
||||
use_deepseek_fp8_block_scale=self.block_quant,
|
||||
)
|
||||
logger.debug_once("Using %s", experts.__class__.__name__)
|
||||
return experts
|
||||
@ -1251,16 +1274,17 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
workspace=layer.workspace,
|
||||
)
|
||||
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
|
||||
assert not self.block_quant
|
||||
assert not renormalize and custom_routing_function is not None
|
||||
assert activation == "silu", (
|
||||
f"Expected 'silu' activation but got {activation}"
|
||||
)
|
||||
assert scoring_func == "sigmoid", (
|
||||
f"Expected 'sigmoid' scoring func but got {scoring_func}"
|
||||
)
|
||||
|
||||
result = flashinfer_cutlass_moe_fp8(
|
||||
if not self.block_quant:
|
||||
assert not renormalize and custom_routing_function is not None
|
||||
assert scoring_func == "sigmoid", (
|
||||
f"Expected 'sigmoid' scoring func but got {scoring_func}"
|
||||
)
|
||||
# Delegate to CUTLASS FlashInfer path; function already bound with
|
||||
# use_deepseek_fp8_block_scale for block-quant when applicable
|
||||
result = self.flashinfer_moe_fn(
|
||||
x,
|
||||
layer,
|
||||
topk_weights,
|
||||
|
||||
@ -17,6 +17,7 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
|
||||
create_flashinfer_prepare_finalize,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -190,17 +191,22 @@ def register_moe_scaling_factors(layer: torch.nn.Module) -> None:
|
||||
|
||||
|
||||
def build_flashinfer_fp8_cutlass_moe_prepare_finalize(
|
||||
moe: FusedMoEConfig | None,
|
||||
moe: FusedMoEConfig | None, use_deepseek_fp8_block_scale: bool = False
|
||||
) -> mk.FusedMoEPrepareAndFinalize:
|
||||
"""Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel"""
|
||||
use_dp = moe.moe_parallel_config.dp_size > 1 if moe is not None else False
|
||||
return create_flashinfer_prepare_finalize(use_dp)
|
||||
# Propagate block-scale flag so prepare/finalize can skip act quantization
|
||||
# and inform the kernel to consume per-block weight scales.
|
||||
return create_flashinfer_prepare_finalize(
|
||||
use_dp, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale
|
||||
)
|
||||
|
||||
|
||||
def select_cutlass_fp8_gemm_impl(
|
||||
moe: FusedMoEConfig | None,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
out_dtype: torch.dtype | None = None,
|
||||
use_deepseek_fp8_block_scale: bool = False,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
"""Return a GEMM *experts* implementation for fused-MoE layers"""
|
||||
|
||||
@ -212,12 +218,14 @@ def select_cutlass_fp8_gemm_impl(
|
||||
ep_size=moe.moe_parallel_config.ep_size,
|
||||
tp_rank=moe.moe_parallel_config.tp_rank,
|
||||
tp_size=moe.moe_parallel_config.tp_size,
|
||||
use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
|
||||
)
|
||||
|
||||
assert out_dtype is not None, "If moe config is None, out_dtype must be passed"
|
||||
return FlashInferExperts(
|
||||
out_dtype=out_dtype,
|
||||
quant_config=quant_config,
|
||||
use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
|
||||
)
|
||||
|
||||
|
||||
@ -231,14 +239,22 @@ def flashinfer_cutlass_moe_fp8(
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_deepseek_fp8_block_scale: bool = False,
|
||||
moe: FusedMoEConfig | None = None,
|
||||
) -> torch.Tensor:
|
||||
quant_config = layer.quant_method.get_fused_moe_quant_config(layer)
|
||||
assert quant_config is not None
|
||||
|
||||
# Construct modular kernel with block-scale support when requested.
|
||||
fused_experts = mk.FusedMoEModularKernel(
|
||||
build_flashinfer_fp8_cutlass_moe_prepare_finalize(moe=None),
|
||||
build_flashinfer_fp8_cutlass_moe_prepare_finalize(
|
||||
moe=moe, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale
|
||||
),
|
||||
select_cutlass_fp8_gemm_impl(
|
||||
moe=None, quant_config=quant_config, out_dtype=hidden_states.dtype
|
||||
moe=moe,
|
||||
quant_config=quant_config,
|
||||
out_dtype=hidden_states.dtype,
|
||||
use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
|
||||
),
|
||||
)
|
||||
|
||||
@ -258,7 +274,10 @@ def flashinfer_cutlass_moe_fp8(
|
||||
|
||||
def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
|
||||
flashinfer_moe_backend = envs.VLLM_FLASHINFER_MOE_BACKEND
|
||||
if flashinfer_moe_backend == "throughput":
|
||||
# Prefer CUTLASS on SM90 to cover both SM90/SM100 generations
|
||||
if flashinfer_moe_backend == "throughput" or current_platform.is_device_capability(
|
||||
90
|
||||
):
|
||||
return FlashinferMoeBackend.CUTLASS
|
||||
elif flashinfer_moe_backend == "latency":
|
||||
return FlashinferMoeBackend.TENSORRT_LLM
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user