mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-16 11:35:50 +08:00
Support Tensorrt-LLM MoE fp4 for low-latency (#21331)
Signed-off-by: Shu Wang <shuw@nvidia.com> Signed-off-by: Po-Han Huang <pohanh@nvidia.com> Signed-off-by: Shu Wang. <shuw@nvidia.com> Signed-off-by: XIn Li <xinli@nvidia.com> Co-authored-by: XIn Li <xinli@nvidia.com>
This commit is contained in:
parent
d57dc2364e
commit
a3b9c17b56
15
vllm/envs.py
15
vllm/envs.py
@ -129,6 +129,7 @@ if TYPE_CHECKING:
|
||||
VLLM_SKIP_DEEP_GEMM_WARMUP: bool = False
|
||||
VLLM_USE_FLASHINFER_MOE_FP8: bool = False
|
||||
VLLM_USE_FLASHINFER_MOE_FP4: bool = False
|
||||
VLLM_FLASHINFER_MOE_BACKEND: str = "throughput"
|
||||
VLLM_XGRAMMAR_CACHE_MB: int = 0
|
||||
VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
|
||||
VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False
|
||||
@ -982,6 +983,20 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_ALL2ALL_BACKEND":
|
||||
lambda: os.getenv("VLLM_ALL2ALL_BACKEND", "naive"),
|
||||
|
||||
# Flashinfer MoE backend for vLLM's fused Mixture-of-Experts support. Both
|
||||
# require compute capability 10.0 or above.
|
||||
# Available options:
|
||||
# - "throughput": [default]
|
||||
# Uses CUTLASS kernels optimized for high-throughput batch inference.
|
||||
# - "latency":
|
||||
# Uses TensorRT-LLM kernels optimized for low-latency inference.
|
||||
# To set this backend, define the environment variable:
|
||||
# export VLLM_FLASHINFER_MOE_BACKEND=latency.
|
||||
# If not set, defaults to "throughput".
|
||||
"VLLM_FLASHINFER_MOE_BACKEND": lambda: os.getenv(
|
||||
"VLLM_FLASHINFER_MOE_BACKEND", "throughput"
|
||||
),
|
||||
|
||||
# Control the maximum number of tokens per expert supported by the
|
||||
# NVFP4 MoE CUTLASS Kernel. This value is used to create a buffer for
|
||||
# the blockscale tensor of activations NVFP4 Quantization.
|
||||
|
||||
@ -192,7 +192,8 @@ class FusedMoEParallelConfig:
|
||||
@property
|
||||
def use_flashinfer_cutlass_kernels(self):
|
||||
return (envs.VLLM_USE_FLASHINFER_MOE_FP4
|
||||
and has_flashinfer_cutlass_fused_moe())
|
||||
and has_flashinfer_cutlass_fused_moe()
|
||||
and envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput")
|
||||
|
||||
@staticmethod
|
||||
def make(tp_size_: int, dp_size_: int,
|
||||
|
||||
@ -105,7 +105,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
||||
detect_nvfp4_moe_support)
|
||||
_nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__)
|
||||
self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported
|
||||
self.allow_flashinfer_cutlass = _nvfp4.allow_flashinfer_cutlass
|
||||
self.allow_flashinfer = _nvfp4.allow_flashinfer
|
||||
self.use_marlin = _nvfp4.use_marlin
|
||||
self.group_size = 16
|
||||
self.fused_experts = None # type: ignore[assignment]
|
||||
@ -212,7 +212,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
||||
requires_grad=False)
|
||||
|
||||
# reorder GEMM1 weights and block scales for FlashInfer CUTLASS kernel.
|
||||
if self.allow_flashinfer_cutlass:
|
||||
if self.allow_flashinfer:
|
||||
w, s = reorder_w1w3_to_w3w1(layer.w13_weight.data,
|
||||
layer.w13_weight_scale.data,
|
||||
dim=-2)
|
||||
@ -266,7 +266,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
||||
(layer.w2_input_global_scale), requires_grad=False)
|
||||
|
||||
def maybe_swap_experts_impl(self, moe_parallel_config):
|
||||
if not self.allow_flashinfer_cutlass:
|
||||
if not self.allow_flashinfer:
|
||||
return
|
||||
self.fused_experts = build_flashinfer_fp4_cutlass_moe_kernel(
|
||||
moe_parallel_config)
|
||||
@ -277,8 +277,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( # noqa: E501
|
||||
select_nvfp4_gemm_impl)
|
||||
|
||||
return select_nvfp4_gemm_impl(self.allow_flashinfer_cutlass, moe,
|
||||
logger)
|
||||
return select_nvfp4_gemm_impl(self.allow_flashinfer, moe, logger)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
@ -36,6 +37,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
from vllm.model_executor.parameter import (ModelWeightParameter,
|
||||
PerTensorScaleParameter)
|
||||
from vllm.scalar_type import scalar_types
|
||||
from vllm.utils import next_power_of_2
|
||||
from vllm.utils.flashinfer import has_flashinfer_moe
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -44,6 +46,11 @@ QUANT_ALGOS = ["FP8", "NVFP4"]
|
||||
KV_CACHE_QUANT_ALGOS = ["FP8"]
|
||||
|
||||
|
||||
class FlashinferMoeBackend(Enum):
|
||||
TENSORRT_LLM = "TensorRT-LLM"
|
||||
CUTLASS = "CUTLASS"
|
||||
|
||||
|
||||
class ModelOptFp8Config(QuantizationConfig):
|
||||
"""Config class for ModelOpt FP8."""
|
||||
|
||||
@ -185,7 +192,7 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
|
||||
Args: quant_config: The ModelOpt quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: ModelOptFp8Config):
|
||||
def __init__(self, quant_config: ModelOptFp8Config) -> None:
|
||||
self.quant_config = quant_config
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR)
|
||||
@ -265,7 +272,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
quant_config: The ModelOpt quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: ModelOptFp8Config):
|
||||
def __init__(self, quant_config: ModelOptFp8Config) -> None:
|
||||
self.quant_config = quant_config
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
cutlass_fp8_supported)
|
||||
@ -670,7 +677,8 @@ class ModelOptNvFp4Config(QuantizationConfig):
|
||||
return cls(is_checkpoint_nvfp4_serialized, kv_cache_quant_algo,
|
||||
exclude_modules, group_size)
|
||||
|
||||
def is_layer_excluded(self, prefix: str, exclude_modules: list):
|
||||
def is_layer_excluded(self, prefix: str,
|
||||
exclude_modules: list[str]) -> bool:
|
||||
import regex as re
|
||||
for pattern in exclude_modules:
|
||||
regex_str = pattern.replace('.', r'\.').replace('*', r'.*')
|
||||
@ -714,7 +722,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
|
||||
Args: quant_config: The ModelOpt quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: ModelOptNvFp4Config):
|
||||
def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
|
||||
self.quant_config = quant_config
|
||||
self.cutlass_nvfp4_supported = cutlass_fp4_supported()
|
||||
self.use_marlin = False
|
||||
@ -859,6 +867,16 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
|
||||
return out.view(*output_shape)
|
||||
|
||||
|
||||
def _get_tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
|
||||
# Guess tokens per expert assuming perfect expert distribution first.
|
||||
num_tokens_per_expert = (num_tokens * top_k) // num_experts
|
||||
# And pad the number to the next power of 2.
|
||||
tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
|
||||
# Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
|
||||
tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
|
||||
return tile_tokens_dim
|
||||
|
||||
|
||||
class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
"""
|
||||
MoE Method for FP4 Quantization.
|
||||
@ -866,22 +884,40 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
quant_config: NVFP4 Quant Config
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: ModelOptNvFp4Config):
|
||||
def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
|
||||
self.quant_config = quant_config
|
||||
from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501
|
||||
detect_nvfp4_moe_support)
|
||||
_nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__)
|
||||
self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported
|
||||
self.allow_flashinfer_cutlass = _nvfp4.allow_flashinfer_cutlass
|
||||
self.allow_flashinfer = _nvfp4.allow_flashinfer
|
||||
self.use_marlin = _nvfp4.use_marlin
|
||||
self.flashinfer_moe_backend = None
|
||||
|
||||
self.fused_experts = None # type: ignore
|
||||
if self.allow_flashinfer:
|
||||
flashinfer_moe_backend = envs.VLLM_FLASHINFER_MOE_BACKEND
|
||||
if flashinfer_moe_backend == "throughput":
|
||||
self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
|
||||
logger.info_once("Using FlashInfer CUTLASS kernels for "
|
||||
"ModelOptNvFp4FusedMoE.")
|
||||
elif flashinfer_moe_backend == "latency":
|
||||
self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
|
||||
logger.info_once("Using FlashInfer TensorRT-LLM kernels for "
|
||||
"ModelOptNvFp4FusedMoE.")
|
||||
else:
|
||||
allowed_backends = ["throughput", "latency"]
|
||||
raise ValueError(
|
||||
f"Unknown flashinfer moe backend: {flashinfer_moe_backend}"
|
||||
f" expected one of {allowed_backends}")
|
||||
|
||||
self.fused_experts: Optional[
|
||||
mk.FusedMoEModularKernel] = None # type: ignore[assignment]
|
||||
|
||||
def maybe_swap_experts_impl(
|
||||
self,
|
||||
moe_parallel_config: FusedMoEParallelConfig,
|
||||
):
|
||||
if not self.allow_flashinfer_cutlass:
|
||||
if not self.allow_flashinfer:
|
||||
return
|
||||
self.fused_experts = build_flashinfer_fp4_cutlass_moe_kernel(
|
||||
moe_parallel_config)
|
||||
@ -897,8 +933,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( # noqa: E501
|
||||
select_nvfp4_gemm_impl)
|
||||
|
||||
return select_nvfp4_gemm_impl(self.allow_flashinfer_cutlass, moe,
|
||||
logger)
|
||||
return select_nvfp4_gemm_impl(self.allow_flashinfer, moe, logger)
|
||||
|
||||
def uses_weight_scale_2_pattern(self) -> bool:
|
||||
"""
|
||||
@ -996,14 +1031,101 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("w2_input_scale", w2_input_scale)
|
||||
|
||||
def prepare_static_weight_layouts_for_trtllm_moe(
|
||||
self,
|
||||
gemm1_weights: torch.Tensor,
|
||||
gemm2_weights: torch.Tensor,
|
||||
gemm1_scales_linear_fp4_bytes: torch.Tensor,
|
||||
gemm2_scales_linear_fp4_bytes: torch.Tensor,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
num_experts: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Prepare quantized weights for kernel (done offline with weights)."""
|
||||
from flashinfer import (reorder_rows_for_gated_act_gemm,
|
||||
shuffle_matrix_a, shuffle_matrix_sf_a)
|
||||
epilogue_tile_m = 128 # FIXME: this depends on the kernel internals
|
||||
|
||||
# Convert quantized weights to proper formats
|
||||
gemm1_weights_fp4 = gemm1_weights.view(torch.float8_e4m3fn).reshape(
|
||||
num_experts, 2 * intermediate_size, hidden_size // 2) # packed fp4
|
||||
gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes.view(
|
||||
torch.float8_e4m3fn).reshape(num_experts, 2 * intermediate_size,
|
||||
hidden_size //
|
||||
16) # fp8 scaling factors
|
||||
|
||||
gemm2_weights_fp4 = gemm2_weights.view(torch.float8_e4m3fn).reshape(
|
||||
num_experts, hidden_size, intermediate_size // 2) # packed fp4
|
||||
gemm2_scales_linear_fp4 = gemm2_scales_linear_fp4_bytes.view(
|
||||
torch.float8_e4m3fn).reshape(num_experts, hidden_size,
|
||||
intermediate_size //
|
||||
16) # fp8 scaling factors
|
||||
|
||||
# Reorder rows of W1 and scales for fused gated activation
|
||||
gemm1_weights_fp4_interleaved = []
|
||||
gemm1_scales_fp4_interleaved = []
|
||||
for i in range(num_experts):
|
||||
gemm1_weights_fp4_interleaved.append(
|
||||
reorder_rows_for_gated_act_gemm(gemm1_weights_fp4[i].clone()))
|
||||
gemm1_scales_fp4_interleaved.append(
|
||||
reorder_rows_for_gated_act_gemm(
|
||||
gemm1_scales_linear_fp4[i].clone()))
|
||||
|
||||
# Stack weights and scales for all experts
|
||||
gemm1_weights_fp4_interleaved = torch.stack(
|
||||
gemm1_weights_fp4_interleaved).reshape(num_experts,
|
||||
2 * intermediate_size,
|
||||
hidden_size // 2)
|
||||
gemm1_scales_fp4_interleaved = torch.stack(
|
||||
gemm1_scales_fp4_interleaved).reshape(num_experts,
|
||||
2 * intermediate_size,
|
||||
hidden_size // 16)
|
||||
|
||||
# Shuffle weights and scaling factors for transposed mma output
|
||||
gemm1_weights_fp4_shuffled = []
|
||||
gemm1_scales_fp4_shuffled = []
|
||||
gemm2_weights_fp4_shuffled = []
|
||||
gemm2_scales_fp4_shuffled = []
|
||||
for i in range(num_experts):
|
||||
gemm1_weights_fp4_shuffled.append(
|
||||
shuffle_matrix_a(
|
||||
gemm1_weights_fp4_interleaved[i].view(torch.uint8),
|
||||
epilogue_tile_m))
|
||||
gemm1_scales_fp4_shuffled.append(
|
||||
shuffle_matrix_sf_a(
|
||||
gemm1_scales_fp4_interleaved[i].view(torch.uint8),
|
||||
epilogue_tile_m))
|
||||
|
||||
gemm2_weights_fp4_shuffled.append(
|
||||
shuffle_matrix_a(gemm2_weights_fp4[i].view(torch.uint8),
|
||||
epilogue_tile_m))
|
||||
gemm2_scales_fp4_shuffled.append(
|
||||
shuffle_matrix_sf_a(
|
||||
gemm2_scales_linear_fp4[i].view(torch.uint8),
|
||||
epilogue_tile_m))
|
||||
|
||||
# Stack weights for all experts
|
||||
gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled)
|
||||
gemm1_scales_fp4_shuffled = (
|
||||
torch.stack(gemm1_scales_fp4_shuffled).view(
|
||||
torch.float8_e4m3fn).reshape(num_experts,
|
||||
2 * intermediate_size,
|
||||
hidden_size // 16))
|
||||
|
||||
gemm2_weights_fp4_shuffled = torch.stack(gemm2_weights_fp4_shuffled)
|
||||
gemm2_scales_fp4_shuffled = (
|
||||
torch.stack(gemm2_scales_fp4_shuffled).view(
|
||||
torch.float8_e4m3fn).reshape(num_experts, hidden_size,
|
||||
intermediate_size // 16))
|
||||
return (gemm1_weights_fp4_shuffled, gemm1_scales_fp4_shuffled,
|
||||
gemm2_weights_fp4_shuffled, gemm2_scales_fp4_shuffled)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# GEMM 1
|
||||
# The FlashInfer Cutlass fused MoE kernel expects the combined weights
|
||||
# to be ordered as [w3, w1], unlike the standard [w1, w3] layout.
|
||||
# GEMM 1 processing
|
||||
gemm1_weight = layer.w13_weight.data
|
||||
gemm1_weight_scale = layer.w13_weight_scale.data
|
||||
|
||||
if self.allow_flashinfer_cutlass:
|
||||
if self.allow_flashinfer:
|
||||
gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1(
|
||||
gemm1_weight, gemm1_weight_scale, dim=-2)
|
||||
|
||||
@ -1011,6 +1133,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
layer.w13_weight_scale = Parameter(gemm1_weight_scale,
|
||||
requires_grad=False)
|
||||
|
||||
# Common processing for w13_weight_scale_2
|
||||
if not torch.allclose(layer.w13_weight_scale_2[:, 0],
|
||||
layer.w13_weight_scale_2[:, 1]):
|
||||
logger.warning_once(
|
||||
@ -1021,26 +1144,18 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2,
|
||||
requires_grad=False)
|
||||
|
||||
# Common processing for input scales and alphas
|
||||
w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(
|
||||
torch.float32)
|
||||
layer.g1_alphas = Parameter(
|
||||
(w13_input_scale * w13_weight_scale_2).to(torch.float32),
|
||||
requires_grad=False)
|
||||
|
||||
assert (layer.w13_weight_scale.shape[2] % 16 == 0), (
|
||||
"Expected weight_scale.dim(1) to be divisible by 16")
|
||||
assert (layer.w13_weight_scale.dtype == torch.float8_e4m3fn), (
|
||||
"Weight Blockscale must be represented as FP8-E4M3")
|
||||
w13_blockscale_swizzled = swizzle_blockscale(layer.w13_weight_scale)
|
||||
|
||||
layer.w13_blockscale_swizzled = Parameter(w13_blockscale_swizzled,
|
||||
requires_grad=False)
|
||||
|
||||
# This is for quantization, so we need to invert it.
|
||||
layer.w13_input_scale_quant = Parameter(
|
||||
(1 / w13_input_scale).to(torch.float32), requires_grad=False)
|
||||
|
||||
# GEMM 2
|
||||
# GEMM 2 processing
|
||||
layer.g2_alphas = Parameter(
|
||||
(layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
|
||||
requires_grad=False)
|
||||
@ -1049,15 +1164,63 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
layer.w2_input_scale_quant = Parameter(
|
||||
(1 / layer.w2_input_scale).to(torch.float32), requires_grad=False)
|
||||
|
||||
assert (layer.w2_weight_scale.shape[2] % 16 == 0), (
|
||||
"Expected weight_scale.dim(1) to be divisible by 16")
|
||||
assert (layer.w2_weight_scale.dtype == torch.float8_e4m3fn), (
|
||||
"Weight Blockscale must be represented as FP8-E4M3")
|
||||
w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale)
|
||||
# TensorRT-LLM specific processing
|
||||
if self.allow_flashinfer and \
|
||||
self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
|
||||
# Prepare static weights for TRT-LLM kernel
|
||||
(gemm1_weights_fp4_shuffled, gemm1_scales_fp4_shuffled,
|
||||
gemm2_weights_fp4_shuffled, gemm2_scales_fp4_shuffled
|
||||
) = self.prepare_static_weight_layouts_for_trtllm_moe(
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
layer.w13_weight_scale,
|
||||
layer.w2_weight_scale,
|
||||
layer.w2_weight.size(-2), # hidden_size
|
||||
layer.w13_weight.size(-2) // 2, # intermediate_size
|
||||
layer.w13_weight.size(0), # num_experts
|
||||
)
|
||||
|
||||
layer.w2_blockscale_swizzled = Parameter(w2_blockscale_swizzled,
|
||||
requires_grad=False)
|
||||
layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
|
||||
layer.gemm1_weights_fp4_shuffled = Parameter(
|
||||
gemm1_weights_fp4_shuffled, requires_grad=False)
|
||||
layer.gemm2_weights_fp4_shuffled = Parameter(
|
||||
gemm2_weights_fp4_shuffled, requires_grad=False)
|
||||
layer.gemm1_scales_fp4_shuffled = Parameter(
|
||||
gemm1_scales_fp4_shuffled, requires_grad=False)
|
||||
layer.gemm2_scales_fp4_shuffled = Parameter(
|
||||
gemm2_scales_fp4_shuffled, requires_grad=False)
|
||||
|
||||
# Additional parameter needed for TRT-LLM
|
||||
layer.g1_scale_c = Parameter(
|
||||
(layer.w2_input_scale_quant * layer.g1_alphas).to(
|
||||
torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
# Clean up weights that won't be used by TRT-LLM
|
||||
del layer.w2_weight
|
||||
del layer.w2_weight_scale
|
||||
del layer.w13_weight
|
||||
del layer.w13_weight_scale
|
||||
else:
|
||||
# Non-TRT-LLM processing (Cutlass or non-flashinfer)
|
||||
assert (layer.w13_weight_scale.shape[2] % 16 == 0), (
|
||||
"Expected weight_scale.dim(1) to be divisible by 16")
|
||||
assert (layer.w13_weight_scale.dtype == torch.float8_e4m3fn), (
|
||||
"Weight Blockscale must be represented as FP8-E4M3")
|
||||
w13_blockscale_swizzled = swizzle_blockscale(
|
||||
layer.w13_weight_scale)
|
||||
layer.w13_blockscale_swizzled = Parameter(w13_blockscale_swizzled,
|
||||
requires_grad=False)
|
||||
|
||||
assert (layer.w2_weight_scale.shape[2] % 16 == 0), (
|
||||
"Expected weight_scale.dim(1) to be divisible by 16")
|
||||
assert (layer.w2_weight_scale.dtype == torch.float8_e4m3fn), (
|
||||
"Weight Blockscale must be represented as FP8-E4M3")
|
||||
w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale)
|
||||
layer.w2_blockscale_swizzled = Parameter(w2_blockscale_swizzled,
|
||||
requires_grad=False)
|
||||
layer.w2_weight = Parameter(layer.w2_weight.data,
|
||||
requires_grad=False)
|
||||
|
||||
if self.use_marlin:
|
||||
prepare_moe_fp4_layer_for_marlin(layer)
|
||||
@ -1095,6 +1258,60 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
"EPLB not supported for `ModelOptNvFp4FusedMoE` yet.")
|
||||
assert activation == "silu", "Only SiLU activation is supported."
|
||||
|
||||
if self.allow_flashinfer and \
|
||||
self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
|
||||
import flashinfer
|
||||
|
||||
from vllm.model_executor.models.llama4 import Llama4MoE
|
||||
|
||||
a1_gscale = layer.w13_input_scale_quant
|
||||
(hidden_states_fp4,
|
||||
hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
|
||||
x,
|
||||
a1_gscale,
|
||||
is_sf_swizzled_layout=False,
|
||||
)
|
||||
use_llama4_routing = \
|
||||
custom_routing_function is Llama4MoE.custom_routing_function
|
||||
routing_method_type = flashinfer.RoutingMethodType.DeepSeekV3
|
||||
if use_llama4_routing:
|
||||
routing_method_type = flashinfer.RoutingMethodType.Llama4
|
||||
out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
|
||||
routing_logits=router_logits
|
||||
if use_llama4_routing else router_logits.to(torch.float32),
|
||||
routing_bias=e_score_correction_bias,
|
||||
hidden_states=hidden_states_fp4,
|
||||
hidden_states_scale=hidden_states_scale_linear_fp4.view(
|
||||
torch.float8_e4m3fn).flatten(),
|
||||
gemm1_weights=layer.gemm1_weights_fp4_shuffled.data,
|
||||
gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view(
|
||||
torch.float8_e4m3fn),
|
||||
gemm1_bias=None,
|
||||
gemm1_alpha=None,
|
||||
gemm1_beta=None,
|
||||
gemm1_clamp_limit=None,
|
||||
gemm2_weights=layer.gemm2_weights_fp4_shuffled.data,
|
||||
gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view(
|
||||
torch.float8_e4m3fn),
|
||||
gemm2_bias=None,
|
||||
output1_scale_scalar=layer.g1_scale_c.data,
|
||||
output1_scale_gate_scalar=layer.g1_alphas.data,
|
||||
output2_scale_scalar=layer.g2_alphas.data,
|
||||
num_experts=global_num_experts,
|
||||
top_k=top_k,
|
||||
n_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
intermediate_size=layer.intermediate_size_per_partition,
|
||||
local_expert_offset=layer.ep_rank * layer.local_num_experts,
|
||||
local_num_experts=layer.local_num_experts,
|
||||
routed_scaling_factor=None,
|
||||
tile_tokens_dim=_get_tile_tokens_dim(x.shape[0], top_k,
|
||||
layer.local_num_experts),
|
||||
routing_method_type=routing_method_type,
|
||||
do_finalize=True,
|
||||
)[0]
|
||||
return out
|
||||
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
@ -1149,6 +1366,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input)
|
||||
else:
|
||||
assert self.allow_flashinfer and \
|
||||
self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
|
||||
out = flashinfer_fp4_cutlass_moe_forward(
|
||||
self.fused_experts,
|
||||
layer,
|
||||
@ -1160,4 +1379,5 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
return out
|
||||
|
||||
@ -126,7 +126,7 @@ def flashinfer_fp4_cutlass_moe_forward(
|
||||
|
||||
|
||||
def select_nvfp4_gemm_impl(
|
||||
allow_flashinfer_cutlass: bool,
|
||||
allow_flashinfer: bool,
|
||||
moe, # FusedMoEConfig
|
||||
logger):
|
||||
"""Return a GEMM *experts* implementation for NV-FP4 fused-MoE layers"""
|
||||
@ -137,8 +137,14 @@ def select_nvfp4_gemm_impl(
|
||||
all2all_manager = get_ep_group().device_communicator.all2all_manager
|
||||
assert all2all_manager is not None
|
||||
|
||||
if allow_flashinfer_cutlass:
|
||||
logger.debug_once("Using FlashInferExperts")
|
||||
if allow_flashinfer:
|
||||
flashinfer_backend = envs.VLLM_FLASHINFER_MOE_BACKEND
|
||||
if flashinfer_backend != "throughput":
|
||||
raise ValueError(
|
||||
f"Only throughput backend is supported for FlashInferExperts, "
|
||||
f"but got {flashinfer_backend}.")
|
||||
logger.debug_once(
|
||||
"Initializing FlashInferExperts with throughput backend.")
|
||||
return FlashInferExperts(
|
||||
use_nvfp4_w4a4=True,
|
||||
use_dp=moe.moe_parallel_config.dp_size > 1,
|
||||
|
||||
@ -21,7 +21,7 @@ class NvFp4Support:
|
||||
"""Result container for NV-FP4 capability probing."""
|
||||
|
||||
cutlass_supported: bool
|
||||
allow_flashinfer_cutlass: bool
|
||||
allow_flashinfer: bool
|
||||
use_marlin: bool
|
||||
|
||||
|
||||
@ -54,6 +54,6 @@ def detect_nvfp4_moe_support(class_name: str = "") -> NvFp4Support:
|
||||
|
||||
return NvFp4Support(
|
||||
cutlass_supported=cutlass_supported,
|
||||
allow_flashinfer_cutlass=allow_flashinfer,
|
||||
allow_flashinfer=allow_flashinfer,
|
||||
use_marlin=use_marlin,
|
||||
)
|
||||
|
||||
@ -86,6 +86,8 @@ flashinfer_cutlass_fused_moe = _lazy_import_wrapper("flashinfer.fused_moe",
|
||||
fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize")
|
||||
nvfp4_block_scale_interleave = _lazy_import_wrapper(
|
||||
"flashinfer", "nvfp4_block_scale_interleave")
|
||||
trtllm_fp4_block_scale_moe = _lazy_import_wrapper(
|
||||
"flashinfer", "trtllm_fp4_block_scale_moe")
|
||||
|
||||
# Special case for autotune since it returns a context manager
|
||||
autotune = _lazy_import_wrapper(
|
||||
@ -112,6 +114,7 @@ def has_flashinfer_cutlass_fused_moe() -> bool:
|
||||
("flashinfer.fused_moe", "cutlass_fused_moe"),
|
||||
("flashinfer", "fp4_quantize"),
|
||||
("flashinfer", "nvfp4_block_scale_interleave"),
|
||||
("flashinfer.fused_moe", "trtllm_fp4_block_scale_moe"),
|
||||
]
|
||||
|
||||
for module_name, attr_name in required_functions:
|
||||
@ -188,6 +191,7 @@ __all__ = [
|
||||
"flashinfer_cutlass_fused_moe",
|
||||
"fp4_quantize",
|
||||
"nvfp4_block_scale_interleave",
|
||||
"trtllm_fp4_block_scale_moe",
|
||||
"autotune",
|
||||
"has_flashinfer_moe",
|
||||
"has_flashinfer_cutlass_fused_moe",
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user