Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[NVIDIA] Add SM100 Flashinfer MoE blockscale fp8 backend for low latency (#20645)
Signed-off-by: kaixih <kaixih@nvidia.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
parent 7d94577138
commit 6d0734c562
vllm/envs.py (11 changes)
@@ -119,7 +119,8 @@ if TYPE_CHECKING:
     VLLM_TPU_BUCKET_PADDING_GAP: int = 0
     VLLM_TPU_MOST_MODEL_LEN: Optional[int] = None
     VLLM_USE_DEEP_GEMM: bool = False
-    VLLM_USE_FLASHINFER_MOE: bool = False
+    VLLM_USE_FLASHINFER_MOE_FP8: bool = False
+    VLLM_USE_FLASHINFER_MOE_FP4: bool = False
     VLLM_XGRAMMAR_CACHE_MB: int = 0
     VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
     VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False
@@ -854,9 +855,13 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_USE_DEEP_GEMM":
     lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))),
 
+    # Allow use of FlashInfer MoE kernels for fused moe ops.
+    "VLLM_USE_FLASHINFER_MOE_FP8":
+    lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP8", "0"))),
+
     # Allow use of FlashInfer CUTLASS kernels for fused moe ops.
-    "VLLM_USE_FLASHINFER_MOE":
-    lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE", "0"))),
+    "VLLM_USE_FLASHINFER_MOE_FP4":
+    lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP4", "0"))),
 
     # Control the cache sized used by the xgrammar compiler. The default
     # of 512 MB should be enough for roughly 1000 JSON schemas.
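The two new flags follow the existing envs.py pattern: an integer-valued environment variable ("0"/"1") converted to a bool at lookup time. A minimal, self-contained sketch of that pattern (the _env_flag helper is made up for illustration and is not part of this commit):

import os


def _env_flag(name: str, default: str = "0") -> bool:
    # Same shape as the lambdas registered above: "0"/"1" strings become bools.
    return bool(int(os.getenv(name, default)))


# VLLM_USE_FLASHINFER_MOE_FP8=1 enables the new FlashInfer blockscale FP8 MoE
# path; VLLM_USE_FLASHINFER_MOE_FP4 now gates the CUTLASS FP4 path that was
# previously behind VLLM_USE_FLASHINFER_MOE.
use_flashinfer_moe_fp8 = _env_flag("VLLM_USE_FLASHINFER_MOE_FP8")
use_flashinfer_moe_fp4 = _env_flag("VLLM_USE_FLASHINFER_MOE_FP4")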
@@ -191,7 +191,7 @@ class FusedMoEParallelConfig:
 
     @property
     def use_flashinfer_cutlass_kernels(self):
-        return (envs.VLLM_USE_FLASHINFER_MOE
+        return (envs.VLLM_USE_FLASHINFER_MOE_FP4
                 and has_flashinfer_cutlass_fused_moe())
 
     @staticmethod
@@ -28,7 +28,7 @@ from vllm.model_executor.layers.fused_moe.prepare_finalize import (
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceNoOP)
 from vllm.model_executor.layers.fused_moe.utils import (
-    _resize_cache, moe_kernel_quantize_input)
+    _resize_cache, moe_kernel_quantize_input, per_token_group_quant_fp8)
 from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
     dequant_mxfp4)
 from vllm.platforms import current_platform
@@ -1061,6 +1061,104 @@ direct_register_custom_op(
 )
 
 
+def next_positive_power_of_2(x: int) -> int:
+    if x < 1:
+        return 1
+    return 1 << (x - 1).bit_length()
+
+
+def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
+    # Guess tokens per expert assuming perfect expert distribution first.
+    num_tokens_per_expert = (num_tokens * top_k) // num_experts
+    # And pad the number to the next power of 2.
+    tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
+    # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
+    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
+    return tile_tokens_dim
+
+
+def flashinfer_fused_moe_blockscale_fp8(
+        routing_logits: torch.Tensor,
+        routing_bias: torch.Tensor,
+        x: torch.Tensor,
+        w13_weight: torch.Tensor,
+        w13_weight_scale_inv: torch.Tensor,
+        w2_weight: torch.Tensor,
+        w2_weight_scale_inv: torch.Tensor,
+        global_num_experts: int,
+        top_k: int,
+        num_expert_group: int,
+        topk_group: int,
+        intermediate_size: int,
+        expert_offset: int,
+        local_num_experts: int,
+        block_shape: list[int],
+        routed_scaling: float = 1.0) -> torch.Tensor:
+    from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe
+    assert top_k <= global_num_experts
+    assert top_k <= 8
+    assert topk_group <= 4
+    assert global_num_experts > num_expert_group
+    assert global_num_experts % num_expert_group == 0
+    assert global_num_experts % 4 == 0
+    assert top_k < (topk_group * global_num_experts / num_expert_group)
+    assert block_shape == [128, 128]
+
+    a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1])
+    # NOTE: scales of hidden states have to be transposed!
+    a_sf_t = a_sf.t().contiguous()
+    return flashinfer_trtllm_fp8_block_scale_moe(
+        routing_logits=routing_logits,
+        routing_bias=routing_bias,
+        hidden_states=a_q,
+        hidden_states_scale=a_sf_t,
+        gemm1_weights=w13_weight,
+        gemm1_weights_scale=w13_weight_scale_inv,
+        gemm2_weights=w2_weight,
+        gemm2_weights_scale=w2_weight_scale_inv,
+        num_experts=global_num_experts,
+        top_k=top_k,
+        n_group=num_expert_group,
+        topk_group=topk_group,
+        intermediate_size=intermediate_size,
+        local_expert_offset=expert_offset,
+        local_num_experts=local_num_experts,
+        routed_scaling_factor=routed_scaling,
+        tile_tokens_dim=_get_tile_tokens_dim(x.shape[0], top_k,
+                                             global_num_experts),
+        routing_method_type=2,  # DeepSeek-styled routing method
+    )
+
+
+def flashinfer_fused_moe_blockscale_fp8_fake(
+        routing_logits: torch.Tensor,
+        routing_bias: torch.Tensor,
+        x: torch.Tensor,
+        w13_weight: torch.Tensor,
+        w13_weight_scale_inv: torch.Tensor,
+        w2_weight: torch.Tensor,
+        w2_weight_scale_inv: torch.Tensor,
+        global_num_experts: int,
+        top_k: int,
+        num_expert_group: int,
+        topk_group: int,
+        intermediate_size: int,
+        expert_offset: int,
+        local_num_experts: int,
+        block_shape: list[int],
+        routed_scaling: float = 1.0) -> torch.Tensor:
+    return torch.empty_like(x)
+
+
+direct_register_custom_op(
+    op_name="flashinfer_fused_moe_blockscale_fp8",
+    op_func=flashinfer_fused_moe_blockscale_fp8,
+    mutates_args=[],
+    fake_impl=flashinfer_fused_moe_blockscale_fp8_fake,
+    tags=(torch.Tag.needs_fixed_stride_order, ),
+)
+
+
 def outplace_fused_experts(
         hidden_states: torch.Tensor,
         w1: torch.Tensor,
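To make the tile sizing above concrete, here is a small standalone sketch that re-runs _get_tile_tokens_dim for a few batch sizes; the expert counts and token counts are illustrative, not taken from this commit:

def next_positive_power_of_2(x: int) -> int:
    if x < 1:
        return 1
    return 1 << (x - 1).bit_length()


def get_tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
    # Assume a perfectly balanced router, pad to a power of 2, clamp to [8, 64].
    num_tokens_per_expert = (num_tokens * top_k) // num_experts
    return min(max(next_positive_power_of_2(num_tokens_per_expert), 8), 64)


# e.g. with 256 experts and top_k=8:
#   4 tokens    -> (4*8)//256 = 0      -> pow2 = 1   -> clamped up to 8
#   512 tokens  -> (512*8)//256 = 16   -> pow2 = 16  -> stays 16
#   8192 tokens -> (8192*8)//256 = 256 -> pow2 = 256 -> clamped down to 64
print(get_tile_tokens_dim(4, 8, 256),
      get_tile_tokens_dim(512, 8, 256),
      get_tile_tokens_dim(8192, 8, 256))  # 8 16 64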
@@ -43,6 +43,7 @@ from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
 from vllm.utils import has_deep_gemm
 from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
+from vllm.utils.flashinfer import has_flashinfer_moe
 
 if TYPE_CHECKING:
     from vllm.model_executor.models.utils import WeightsMapper
@@ -52,6 +53,11 @@ ACTIVATION_SCHEMES = ["static", "dynamic"]
 logger = init_logger(__name__)
 
 
+def _swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor:
+    return x.reshape(-1, 2, x.shape[-2] // 2,
+                     x.shape[-1]).flip(dims=[1]).reshape(x.shape)
+
+
 def _is_col_major(x: torch.Tensor) -> bool:
     assert x.dim() == 3
     b, m, n = x.shape
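A tiny runnable sketch of what _swap_w13_to_w31 does (toy shapes chosen only for illustration): the fused w13 weight stores two halves along dim -2, and the helper swaps them so the half FlashInfer applies the activation to lines up with vLLM's layout.

import torch


def _swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor:
    return x.reshape(-1, 2, x.shape[-2] // 2,
                     x.shape[-1]).flip(dims=[1]).reshape(x.shape)


# One "expert" with a fused (2*N, K) weight, N=2, K=2.
w13 = torch.arange(8).reshape(1, 4, 2)
print(w13[0].tolist())                    # [[0, 1], [2, 3], [4, 5], [6, 7]]
print(_swap_w13_to_w31(w13)[0].tolist())  # [[4, 5], [6, 7], [0, 1], [2, 3]]
# The first and second halves along dim -2 trade places; the same swap is
# applied to the block-wise weight scales in process_weights_after_loading below.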
@@ -473,6 +479,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         self.quant_config = quant_config
         self.block_quant = self.quant_config.weight_block_size is not None
 
+        self.flashinfer_moe_enabled = False
+        if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
+            logger.info_once(
+                "Using FlashInfer MoE FP8 kernels for Fp8MoEMethod.")
+            self.flashinfer_moe_enabled = True
         # For GPUs that lack FP8 hardware support, we can leverage the Marlin
         # kernel for fast weight-only FP8 quantization
         self.use_marlin = (not current_platform.has_device_capability(89)
@@ -674,6 +685,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                     normalize_e4m3fn_to_e4m3fnuz(
                         layer.w2_weight, layer.w2_weight_scale_inv,
                         layer.w2_input_scale)
+            elif self.flashinfer_moe_enabled:
+                # NOTE: weights have to be swapped since the activation is
+                # applied on different half for flashinfer vs vllm
+                w13_weight = _swap_w13_to_w31(layer.w13_weight.data)
+                w13_weight_scale_inv = _swap_w13_to_w31(
+                    layer.w13_weight_scale_inv.data)
+                w2_weight = layer.w2_weight.data
+                w2_weight_scale_inv = layer.w2_weight_scale_inv.data
             else:
                 w13_weight = layer.w13_weight.data
                 w13_weight_scale_inv = layer.w13_weight_scale_inv.data
@@ -915,25 +934,25 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             assert logical_to_physical_map is not None
             assert logical_replica_count is not None
             assert isinstance(layer, FusedMoE)
-
-        topk_weights, topk_ids = FusedMoE.select_experts(
-            hidden_states=x,
-            router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            e_score_correction_bias=e_score_correction_bias,
-            indices_type=self.topk_indices_dtype,
-            enable_eplb=enable_eplb,
-            expert_map=expert_map,
-            expert_load_view=expert_load_view,
-            logical_to_physical_map=logical_to_physical_map,
-            logical_replica_count=logical_replica_count,
-        )
+        if not self.flashinfer_moe_enabled:
+            topk_weights, topk_ids = FusedMoE.select_experts(
+                hidden_states=x,
+                router_logits=router_logits,
+                use_grouped_topk=use_grouped_topk,
+                top_k=top_k,
+                renormalize=renormalize,
+                topk_group=topk_group,
+                num_expert_group=num_expert_group,
+                custom_routing_function=custom_routing_function,
+                scoring_func=scoring_func,
+                e_score_correction_bias=e_score_correction_bias,
+                indices_type=self.topk_indices_dtype,
+                enable_eplb=enable_eplb,
+                expert_map=expert_map,
+                expert_load_view=expert_load_view,
+                logical_to_physical_map=logical_to_physical_map,
+                logical_replica_count=logical_replica_count,
+            )
 
         if self.rocm_aiter_moe_enabled:
             from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
@@ -971,6 +990,31 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 apply_router_weight_on_input=apply_router_weight_on_input,
                 global_num_experts=global_num_experts,
                 expert_map=expert_map)
+        elif self.flashinfer_moe_enabled:
+            # Currently only work with DS models
+            assert self.block_quant
+            assert (renormalize and use_grouped_topk
+                    and scoring_func == 'sigmoid'
+                    and custom_routing_function is None)
+            assert activation == "silu"
+            return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
+                routing_logits=router_logits.to(torch.float32),
+                routing_bias=e_score_correction_bias,
+                x=x,
+                w13_weight=layer.w13_weight,
+                w13_weight_scale_inv=layer.w13_weight_scale_inv,
+                w2_weight=layer.w2_weight,
+                w2_weight_scale_inv=layer.w2_weight_scale_inv,
+                global_num_experts=global_num_experts,
+                top_k=top_k,
+                num_expert_group=num_expert_group,
+                topk_group=topk_group,
+                intermediate_size=layer.intermediate_size_per_partition,
+                expert_offset=layer.ep_rank * layer.local_num_experts,
+                local_num_experts=layer.local_num_experts,
+                block_shape=self.quant_config.weight_block_size,
+                routed_scaling=1.0,
+            )
         else:
             return self.fused_experts(
                 hidden_states=x,
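The asserts in this new branch encode the DeepSeek-style grouped routing that the TRT-LLM kernel expects (routing_method_type=2 in the op above): the raw router logits and e_score_correction_bias go straight to the kernel, which is why the usual FusedMoE.select_experts call is bypassed for this path (see the `if not self.flashinfer_moe_enabled` hunk earlier). A quick numeric check with an illustrative DeepSeek-V3-like configuration (the numbers are assumptions for illustration, not taken from this commit):

# Illustrative grouped-routing configuration (not from this commit):
global_num_experts = 256
num_expert_group = 8
topk_group = 4
top_k = 8

assert top_k <= global_num_experts
assert top_k <= 8
assert topk_group <= 4
assert global_num_experts > num_expert_group
assert global_num_experts % num_expert_group == 0
assert global_num_experts % 4 == 0
# The top_k experts must fit inside the topk_group selected groups:
# 8 < 4 * 256 / 8 = 128 candidate experts
assert top_k < (topk_group * global_num_experts / num_expert_group)
print("routing config satisfies the kernel's asserts")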
@@ -721,7 +721,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         self.use_marlin = False
         self.allow_flashinfer_cutlass = False
 
-        if envs.VLLM_USE_FLASHINFER_MOE:
+        if envs.VLLM_USE_FLASHINFER_MOE_FP4:
             if self.cutlass_nvfp4_supported and current_platform.is_cuda() \
                and current_platform.is_device_capability(100):
                 logger.info_once(
@@ -800,10 +800,9 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
             assert moe.dp_size > 1
             logger.debug_once("Using CutlassExpertsFp4")
             # Currently CutlassExpertsFp4 doesn't support DP
-            raise ValueError(
-                "CutlassExpertsFp4 doesn't support DP. "
-                "Use flashinfer CUTLASS FusedMoE(VLLM_USE_FLASHINFER_MOE)"
-                " backend instead.")
+            raise ValueError("CutlassExpertsFp4 doesn't support DP. "
+                             "Use flashinfer CUTLASS FusedMoE backend instead "
+                             "(set VLLM_USE_FLASHINFER_MOE_FP4=1)")
 
         return experts
 
@@ -64,6 +64,8 @@ def _lazy_import_wrapper(module_name: str,
 
 
 # Create lazy wrappers for each function
+flashinfer_trtllm_fp8_block_scale_moe = _lazy_import_wrapper(
+    "flashinfer.fused_moe", "trtllm_fp8_block_scale_moe")
 flashinfer_cutlass_fused_moe = _lazy_import_wrapper("flashinfer.fused_moe",
                                                     "cutlass_fused_moe")
 fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize")
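This hunk only registers a new name with _lazy_import_wrapper; the wrapper itself is outside the diff. As an assumption for illustration, a minimal sketch of what a lazy import wrapper of this shape typically does (this is not the vLLM implementation):

import importlib
from typing import Any, Callable, Optional


def lazy_import_wrapper(
        module_name: str,
        attr_name: str,
        fallback_fn: Optional[Callable[..., Any]] = None) -> Callable[..., Any]:
    # Defer the import to the first call so that flashinfer stays optional.
    def wrapper(*args, **kwargs):
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            if fallback_fn is not None:
                return fallback_fn(*args, **kwargs)
            raise
        return getattr(module, attr_name)(*args, **kwargs)

    return wrapper


# Usage mirroring the registration above (assumed pattern, not the real helper):
trtllm_fp8_block_scale_moe = lazy_import_wrapper(
    "flashinfer.fused_moe", "trtllm_fp8_block_scale_moe")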
@@ -77,10 +79,16 @@ autotune = _lazy_import_wrapper(
     fallback_fn=lambda *args, **kwargs: contextlib.nullcontext())
 
 
+@functools.cache
+def has_flashinfer_moe() -> bool:
+    """Return ``True`` if FlashInfer MoE module is available."""
+    return importlib.util.find_spec("flashinfer.fused_moe") is not None
+
+
 @functools.cache
 def has_flashinfer_cutlass_fused_moe() -> bool:
     """Return ``True`` if FlashInfer CUTLASS fused MoE is available."""
-    if not has_flashinfer():
+    if not has_flashinfer_moe():
         return False
 
     # Check if all required functions are available
@@ -99,9 +107,11 @@ def has_flashinfer_cutlass_fused_moe() -> bool:
 
 __all__ = [
     "has_flashinfer",
-    "has_flashinfer_cutlass_fused_moe",
+    "flashinfer_trtllm_fp8_block_scale_moe",
     "flashinfer_cutlass_fused_moe",
     "fp4_quantize",
     "fp4_swizzle_blockscale",
     "autotune",
+    "has_flashinfer_moe",
+    "has_flashinfer_cutlass_fused_moe",
 ]