Support Tensorrt-LLM MoE fp4 for low-latency (#21331)

Signed-off-by: Shu Wang <shuw@nvidia.com>
Signed-off-by: Po-Han Huang <pohanh@nvidia.com>
Signed-off-by: Shu Wang. <shuw@nvidia.com>
Signed-off-by: XIn Li <xinli@nvidia.com>
Co-authored-by: XIn Li <xinli@nvidia.com>

commit a3b9c17b56  (parent d57dc2364e)
vllm/envs.py | 15 +++++++++++++++

@@ -129,6 +129,7 @@ if TYPE_CHECKING:
     VLLM_SKIP_DEEP_GEMM_WARMUP: bool = False
     VLLM_USE_FLASHINFER_MOE_FP8: bool = False
     VLLM_USE_FLASHINFER_MOE_FP4: bool = False
+    VLLM_FLASHINFER_MOE_BACKEND: str = "throughput"
     VLLM_XGRAMMAR_CACHE_MB: int = 0
     VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
     VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False
@@ -982,6 +983,20 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_ALL2ALL_BACKEND":
     lambda: os.getenv("VLLM_ALL2ALL_BACKEND", "naive"),

+    # Flashinfer MoE backend for vLLM's fused Mixture-of-Experts support. Both
+    # require compute capability 10.0 or above.
+    # Available options:
+    # - "throughput": [default]
+    #   Uses CUTLASS kernels optimized for high-throughput batch inference.
+    # - "latency":
+    #   Uses TensorRT-LLM kernels optimized for low-latency inference.
+    # To set this backend, define the environment variable:
+    #     export VLLM_FLASHINFER_MOE_BACKEND=latency.
+    # If not set, defaults to "throughput".
+    "VLLM_FLASHINFER_MOE_BACKEND": lambda: os.getenv(
+        "VLLM_FLASHINFER_MOE_BACKEND", "throughput"
+    ),
+
     # Control the maximum number of tokens per expert supported by the
     # NVFP4 MoE CUTLASS Kernel. This value is used to create a buffer for
     # the blockscale tensor of activations NVFP4 Quantization.
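
As a usage note for the knob documented above: a minimal standalone sketch of how the value is read and validated. pick_moe_backend is a hypothetical helper written for this note, not code from the patch; vLLM itself reads the variable through envs.VLLM_FLASHINFER_MOE_BACKEND.

import os


def pick_moe_backend() -> str:
    # Unset defaults to "throughput" (FlashInfer CUTLASS kernels);
    # "latency" selects the TensorRT-LLM low-latency kernels added here.
    backend = os.getenv("VLLM_FLASHINFER_MOE_BACKEND", "throughput")
    if backend not in ("throughput", "latency"):
        raise ValueError(f"Unknown FlashInfer MoE backend: {backend}")
    return backend


if __name__ == "__main__":
    os.environ["VLLM_FLASHINFER_MOE_BACKEND"] = "latency"
    print(pick_moe_backend())  # -> latency

Setting the variable to "latency" is what routes ModelOptNvFp4FusedMoE onto the TensorRT-LLM kernels in the hunks below.
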
@@ -192,7 +192,8 @@ class FusedMoEParallelConfig:
     @property
     def use_flashinfer_cutlass_kernels(self):
         return (envs.VLLM_USE_FLASHINFER_MOE_FP4
-                and has_flashinfer_cutlass_fused_moe())
+                and has_flashinfer_cutlass_fused_moe()
+                and envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput")

     @staticmethod
     def make(tp_size_: int, dp_size_: int,
@@ -105,7 +105,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
             detect_nvfp4_moe_support)
         _nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__)
         self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported
-        self.allow_flashinfer_cutlass = _nvfp4.allow_flashinfer_cutlass
+        self.allow_flashinfer = _nvfp4.allow_flashinfer
         self.use_marlin = _nvfp4.use_marlin
         self.group_size = 16
         self.fused_experts = None  # type: ignore[assignment]
@@ -212,7 +212,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
                                             requires_grad=False)

         # reorder GEMM1 weights and block scales for FlashInfer CUTLASS kernel.
-        if self.allow_flashinfer_cutlass:
+        if self.allow_flashinfer:
             w, s = reorder_w1w3_to_w3w1(layer.w13_weight.data,
                                         layer.w13_weight_scale.data,
                                         dim=-2)
@@ -266,7 +266,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
             (layer.w2_input_global_scale), requires_grad=False)

     def maybe_swap_experts_impl(self, moe_parallel_config):
-        if not self.allow_flashinfer_cutlass:
+        if not self.allow_flashinfer:
             return
         self.fused_experts = build_flashinfer_fp4_cutlass_moe_kernel(
             moe_parallel_config)
@@ -277,8 +277,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
         from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (  # noqa: E501
             select_nvfp4_gemm_impl)

-        return select_nvfp4_gemm_impl(self.allow_flashinfer_cutlass, moe,
-                                      logger)
+        return select_nvfp4_gemm_impl(self.allow_flashinfer, moe, logger)

     def apply(
         self,
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

+from enum import Enum
 from typing import Any, Callable, Optional, Union

 import torch
@@ -36,6 +37,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
 from vllm.model_executor.parameter import (ModelWeightParameter,
                                            PerTensorScaleParameter)
 from vllm.scalar_type import scalar_types
+from vllm.utils import next_power_of_2
 from vllm.utils.flashinfer import has_flashinfer_moe

 logger = init_logger(__name__)
@@ -44,6 +46,11 @@ QUANT_ALGOS = ["FP8", "NVFP4"]
 KV_CACHE_QUANT_ALGOS = ["FP8"]


+class FlashinferMoeBackend(Enum):
+    TENSORRT_LLM = "TensorRT-LLM"
+    CUTLASS = "CUTLASS"
+
+
 class ModelOptFp8Config(QuantizationConfig):
     """Config class for ModelOpt FP8."""

@@ -185,7 +192,7 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
     Args: quant_config: The ModelOpt quantization config.
     """

-    def __init__(self, quant_config: ModelOptFp8Config):
+    def __init__(self, quant_config: ModelOptFp8Config) -> None:
         self.quant_config = quant_config
         self.fp8_linear = Fp8LinearOp(
             act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR)
@@ -265,7 +272,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         quant_config: The ModelOpt quantization config.
     """

-    def __init__(self, quant_config: ModelOptFp8Config):
+    def __init__(self, quant_config: ModelOptFp8Config) -> None:
         self.quant_config = quant_config
         from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
             cutlass_fp8_supported)
@@ -670,7 +677,8 @@ class ModelOptNvFp4Config(QuantizationConfig):
         return cls(is_checkpoint_nvfp4_serialized, kv_cache_quant_algo,
                    exclude_modules, group_size)

-    def is_layer_excluded(self, prefix: str, exclude_modules: list):
+    def is_layer_excluded(self, prefix: str,
+                          exclude_modules: list[str]) -> bool:
         import regex as re
         for pattern in exclude_modules:
             regex_str = pattern.replace('.', r'\.').replace('*', r'.*')
@@ -714,7 +722,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
     Args: quant_config: The ModelOpt quantization config.
     """

-    def __init__(self, quant_config: ModelOptNvFp4Config):
+    def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
         self.quant_config = quant_config
         self.cutlass_nvfp4_supported = cutlass_fp4_supported()
         self.use_marlin = False
@@ -859,6 +867,16 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
         return out.view(*output_shape)


+def _get_tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
+    # Guess tokens per expert assuming perfect expert distribution first.
+    num_tokens_per_expert = (num_tokens * top_k) // num_experts
+    # And pad the number to the next power of 2.
+    tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
+    # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
+    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
+    return tile_tokens_dim
+
+
 class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
     """
     MoE Method for FP4 Quantization.
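
The helper added above picks the CTA tile size from the expected per-expert token count: guess a uniform split, round up to a power of two, and clamp to the 8-64 range the kernel supports. Below is a standalone re-statement with a local stand-in for vllm.utils.next_power_of_2 (assumed behavior) and two worked cases; it is a sketch, not the vLLM code itself.

def next_power_of_2(n: int) -> int:
    # Local stand-in: round a positive integer up to the next power of two.
    return 1 if n <= 0 else 1 << (n - 1).bit_length()


def tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
    # Expected tokens per expert under perfectly uniform routing.
    per_expert = (num_tokens * top_k) // num_experts
    # Pad to a power of two, then clamp to the kernel-supported 8..64 range.
    return min(max(next_power_of_2(per_expert), 8), 64)


assert tile_tokens_dim(num_tokens=4, top_k=8, num_experts=128) == 8       # decode-sized batch
assert tile_tokens_dim(num_tokens=4096, top_k=8, num_experts=128) == 64   # large prefill
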
@@ -866,22 +884,40 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
     quant_config: NVFP4 Quant Config
     """

-    def __init__(self, quant_config: ModelOptNvFp4Config):
+    def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
         self.quant_config = quant_config
         from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import (  # noqa: E501
             detect_nvfp4_moe_support)
         _nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__)
         self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported
-        self.allow_flashinfer_cutlass = _nvfp4.allow_flashinfer_cutlass
+        self.allow_flashinfer = _nvfp4.allow_flashinfer
         self.use_marlin = _nvfp4.use_marlin
+        self.flashinfer_moe_backend = None

-        self.fused_experts = None  # type: ignore
+        if self.allow_flashinfer:
+            flashinfer_moe_backend = envs.VLLM_FLASHINFER_MOE_BACKEND
+            if flashinfer_moe_backend == "throughput":
+                self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
+                logger.info_once("Using FlashInfer CUTLASS kernels for "
+                                 "ModelOptNvFp4FusedMoE.")
+            elif flashinfer_moe_backend == "latency":
+                self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
+                logger.info_once("Using FlashInfer TensorRT-LLM kernels for "
+                                 "ModelOptNvFp4FusedMoE.")
+            else:
+                allowed_backends = ["throughput", "latency"]
+                raise ValueError(
+                    f"Unknown flashinfer moe backend: {flashinfer_moe_backend}"
+                    f" expected one of {allowed_backends}")
+
+        self.fused_experts: Optional[
+            mk.FusedMoEModularKernel] = None  # type: ignore[assignment]

     def maybe_swap_experts_impl(
         self,
         moe_parallel_config: FusedMoEParallelConfig,
     ):
-        if not self.allow_flashinfer_cutlass:
+        if not self.allow_flashinfer:
             return
         self.fused_experts = build_flashinfer_fp4_cutlass_moe_kernel(
             moe_parallel_config)
@@ -897,8 +933,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (  # noqa: E501
             select_nvfp4_gemm_impl)

-        return select_nvfp4_gemm_impl(self.allow_flashinfer_cutlass, moe,
-                                      logger)
+        return select_nvfp4_gemm_impl(self.allow_flashinfer, moe, logger)

     def uses_weight_scale_2_pattern(self) -> bool:
         """
@@ -996,14 +1031,101 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
                                     weight_loader=weight_loader)
         layer.register_parameter("w2_input_scale", w2_input_scale)

+    def prepare_static_weight_layouts_for_trtllm_moe(
+        self,
+        gemm1_weights: torch.Tensor,
+        gemm2_weights: torch.Tensor,
+        gemm1_scales_linear_fp4_bytes: torch.Tensor,
+        gemm2_scales_linear_fp4_bytes: torch.Tensor,
+        hidden_size: int,
+        intermediate_size: int,
+        num_experts: int,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Prepare quantized weights for kernel (done offline with weights)."""
+        from flashinfer import (reorder_rows_for_gated_act_gemm,
+                                shuffle_matrix_a, shuffle_matrix_sf_a)
+        epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
+
+        # Convert quantized weights to proper formats
+        gemm1_weights_fp4 = gemm1_weights.view(torch.float8_e4m3fn).reshape(
+            num_experts, 2 * intermediate_size, hidden_size // 2)  # packed fp4
+        gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes.view(
+            torch.float8_e4m3fn).reshape(num_experts, 2 * intermediate_size,
+                                         hidden_size //
+                                         16)  # fp8 scaling factors
+
+        gemm2_weights_fp4 = gemm2_weights.view(torch.float8_e4m3fn).reshape(
+            num_experts, hidden_size, intermediate_size // 2)  # packed fp4
+        gemm2_scales_linear_fp4 = gemm2_scales_linear_fp4_bytes.view(
+            torch.float8_e4m3fn).reshape(num_experts, hidden_size,
+                                         intermediate_size //
+                                         16)  # fp8 scaling factors
+
+        # Reorder rows of W1 and scales for fused gated activation
+        gemm1_weights_fp4_interleaved = []
+        gemm1_scales_fp4_interleaved = []
+        for i in range(num_experts):
+            gemm1_weights_fp4_interleaved.append(
+                reorder_rows_for_gated_act_gemm(gemm1_weights_fp4[i].clone()))
+            gemm1_scales_fp4_interleaved.append(
+                reorder_rows_for_gated_act_gemm(
+                    gemm1_scales_linear_fp4[i].clone()))
+
+        # Stack weights and scales for all experts
+        gemm1_weights_fp4_interleaved = torch.stack(
+            gemm1_weights_fp4_interleaved).reshape(num_experts,
+                                                   2 * intermediate_size,
+                                                   hidden_size // 2)
+        gemm1_scales_fp4_interleaved = torch.stack(
+            gemm1_scales_fp4_interleaved).reshape(num_experts,
+                                                  2 * intermediate_size,
+                                                  hidden_size // 16)
+
+        # Shuffle weights and scaling factors for transposed mma output
+        gemm1_weights_fp4_shuffled = []
+        gemm1_scales_fp4_shuffled = []
+        gemm2_weights_fp4_shuffled = []
+        gemm2_scales_fp4_shuffled = []
+        for i in range(num_experts):
+            gemm1_weights_fp4_shuffled.append(
+                shuffle_matrix_a(
+                    gemm1_weights_fp4_interleaved[i].view(torch.uint8),
+                    epilogue_tile_m))
+            gemm1_scales_fp4_shuffled.append(
+                shuffle_matrix_sf_a(
+                    gemm1_scales_fp4_interleaved[i].view(torch.uint8),
+                    epilogue_tile_m))
+
+            gemm2_weights_fp4_shuffled.append(
+                shuffle_matrix_a(gemm2_weights_fp4[i].view(torch.uint8),
+                                 epilogue_tile_m))
+            gemm2_scales_fp4_shuffled.append(
+                shuffle_matrix_sf_a(
+                    gemm2_scales_linear_fp4[i].view(torch.uint8),
+                    epilogue_tile_m))
+
+        # Stack weights for all experts
+        gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled)
+        gemm1_scales_fp4_shuffled = (
+            torch.stack(gemm1_scales_fp4_shuffled).view(
+                torch.float8_e4m3fn).reshape(num_experts,
+                                             2 * intermediate_size,
+                                             hidden_size // 16))
+
+        gemm2_weights_fp4_shuffled = torch.stack(gemm2_weights_fp4_shuffled)
+        gemm2_scales_fp4_shuffled = (
+            torch.stack(gemm2_scales_fp4_shuffled).view(
+                torch.float8_e4m3fn).reshape(num_experts, hidden_size,
+                                             intermediate_size // 16))
+        return (gemm1_weights_fp4_shuffled, gemm1_scales_fp4_shuffled,
+                gemm2_weights_fp4_shuffled, gemm2_scales_fp4_shuffled)
+
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        # GEMM 1
-        # The FlashInfer Cutlass fused MoE kernel expects the combined weights
-        # to be ordered as [w3, w1], unlike the standard [w1, w3] layout.
+        # GEMM 1 processing
         gemm1_weight = layer.w13_weight.data
         gemm1_weight_scale = layer.w13_weight_scale.data

-        if self.allow_flashinfer_cutlass:
+        if self.allow_flashinfer:
             gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1(
                 gemm1_weight, gemm1_weight_scale, dim=-2)

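
The reshapes in prepare_static_weight_layouts_for_trtllm_moe rely on the NVFP4 storage convention: two 4-bit values packed per byte (last dimension divided by 2) and one FP8 block scale per group of 16 elements (last dimension divided by 16). A small sketch of that arithmetic follows; the tensor sizes are made up for illustration and are not from a real checkpoint.

import torch

num_experts, hidden_size, intermediate_size = 8, 1024, 2048

# Packed gate+up (w13) weights and their block scales.
w13 = torch.empty(num_experts, 2 * intermediate_size, hidden_size // 2,
                  dtype=torch.uint8)                # two FP4 values per byte
w13_scale = torch.empty(num_experts, 2 * intermediate_size, hidden_size // 16,
                        dtype=torch.float8_e4m3fn)  # one FP8 scale per 16 values

# Reinterpreting the packed bytes as float8_e4m3fn keeps the element count,
# so the view/reshape in the method above is size-neutral.
assert w13.view(torch.float8_e4m3fn).shape == w13.shape
assert w13_scale.shape[-1] * 16 == hidden_size
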
@@ -1011,6 +1133,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         layer.w13_weight_scale = Parameter(gemm1_weight_scale,
                                            requires_grad=False)

+        # Common processing for w13_weight_scale_2
         if not torch.allclose(layer.w13_weight_scale_2[:, 0],
                               layer.w13_weight_scale_2[:, 1]):
             logger.warning_once(
@@ -1021,26 +1144,18 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2,
                                              requires_grad=False)

+        # Common processing for input scales and alphas
         w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(
             torch.float32)
         layer.g1_alphas = Parameter(
             (w13_input_scale * w13_weight_scale_2).to(torch.float32),
             requires_grad=False)

-        assert (layer.w13_weight_scale.shape[2] % 16 == 0), (
-            "Expected weight_scale.dim(1) to be divisible by 16")
-        assert (layer.w13_weight_scale.dtype == torch.float8_e4m3fn), (
-            "Weight Blockscale must be represented as FP8-E4M3")
-        w13_blockscale_swizzled = swizzle_blockscale(layer.w13_weight_scale)
-
-        layer.w13_blockscale_swizzled = Parameter(w13_blockscale_swizzled,
-                                                  requires_grad=False)
-
         # This is for quantization, so we need to invert it.
         layer.w13_input_scale_quant = Parameter(
             (1 / w13_input_scale).to(torch.float32), requires_grad=False)

-        # GEMM 2
+        # GEMM 2 processing
         layer.g2_alphas = Parameter(
             (layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
             requires_grad=False)
@@ -1049,15 +1164,63 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         layer.w2_input_scale_quant = Parameter(
             (1 / layer.w2_input_scale).to(torch.float32), requires_grad=False)

-        assert (layer.w2_weight_scale.shape[2] % 16 == 0), (
-            "Expected weight_scale.dim(1) to be divisible by 16")
-        assert (layer.w2_weight_scale.dtype == torch.float8_e4m3fn), (
-            "Weight Blockscale must be represented as FP8-E4M3")
-        w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale)
-
-        layer.w2_blockscale_swizzled = Parameter(w2_blockscale_swizzled,
-                                                 requires_grad=False)
-        layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
+        # TensorRT-LLM specific processing
+        if self.allow_flashinfer and \
+            self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
+            # Prepare static weights for TRT-LLM kernel
+            (gemm1_weights_fp4_shuffled, gemm1_scales_fp4_shuffled,
+             gemm2_weights_fp4_shuffled, gemm2_scales_fp4_shuffled
+             ) = self.prepare_static_weight_layouts_for_trtllm_moe(
+                 layer.w13_weight,
+                 layer.w2_weight,
+                 layer.w13_weight_scale,
+                 layer.w2_weight_scale,
+                 layer.w2_weight.size(-2),  # hidden_size
+                 layer.w13_weight.size(-2) // 2,  # intermediate_size
+                 layer.w13_weight.size(0),  # num_experts
+             )
+
+            layer.gemm1_weights_fp4_shuffled = Parameter(
+                gemm1_weights_fp4_shuffled, requires_grad=False)
+            layer.gemm2_weights_fp4_shuffled = Parameter(
+                gemm2_weights_fp4_shuffled, requires_grad=False)
+            layer.gemm1_scales_fp4_shuffled = Parameter(
+                gemm1_scales_fp4_shuffled, requires_grad=False)
+            layer.gemm2_scales_fp4_shuffled = Parameter(
+                gemm2_scales_fp4_shuffled, requires_grad=False)
+
+            # Additional parameter needed for TRT-LLM
+            layer.g1_scale_c = Parameter(
+                (layer.w2_input_scale_quant * layer.g1_alphas).to(
+                    torch.float32),
+                requires_grad=False,
+            )
+
+            # Clean up weights that won't be used by TRT-LLM
+            del layer.w2_weight
+            del layer.w2_weight_scale
+            del layer.w13_weight
+            del layer.w13_weight_scale
+        else:
+            # Non-TRT-LLM processing (Cutlass or non-flashinfer)
+            assert (layer.w13_weight_scale.shape[2] % 16 == 0), (
+                "Expected weight_scale.dim(1) to be divisible by 16")
+            assert (layer.w13_weight_scale.dtype == torch.float8_e4m3fn), (
+                "Weight Blockscale must be represented as FP8-E4M3")
+            w13_blockscale_swizzled = swizzle_blockscale(
+                layer.w13_weight_scale)
+            layer.w13_blockscale_swizzled = Parameter(w13_blockscale_swizzled,
+                                                      requires_grad=False)
+
+            assert (layer.w2_weight_scale.shape[2] % 16 == 0), (
+                "Expected weight_scale.dim(1) to be divisible by 16")
+            assert (layer.w2_weight_scale.dtype == torch.float8_e4m3fn), (
+                "Weight Blockscale must be represented as FP8-E4M3")
+            w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale)
+            layer.w2_blockscale_swizzled = Parameter(w2_blockscale_swizzled,
+                                                     requires_grad=False)
+            layer.w2_weight = Parameter(layer.w2_weight.data,
+                                        requires_grad=False)

         if self.use_marlin:
             prepare_moe_fp4_layer_for_marlin(layer)
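
The call above recovers the layer geometry from the stored weight shapes: the w2 row count gives hidden_size, half the w13 row count gives intermediate_size (gate and up projections are concatenated), and the leading dimension gives num_experts. A sketch of that shape arithmetic with illustrative sizes, not values from a real checkpoint:

# Shapes as stored on the layer for a hypothetical model.
num_experts, hidden_size, intermediate_size = 8, 1024, 2048
w13_weight_shape = (num_experts, 2 * intermediate_size, hidden_size // 2)
w2_weight_shape = (num_experts, hidden_size, intermediate_size // 2)

assert w2_weight_shape[-2] == hidden_size              # w2_weight.size(-2)
assert w13_weight_shape[-2] // 2 == intermediate_size  # w13_weight.size(-2) // 2
assert w13_weight_shape[0] == num_experts              # w13_weight.size(0)
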
@@ -1095,6 +1258,60 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
             "EPLB not supported for `ModelOptNvFp4FusedMoE` yet.")
         assert activation == "silu", "Only SiLU activation is supported."

+        if self.allow_flashinfer and \
+            self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
+            import flashinfer
+
+            from vllm.model_executor.models.llama4 import Llama4MoE
+
+            a1_gscale = layer.w13_input_scale_quant
+            (hidden_states_fp4,
+             hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
+                 x,
+                 a1_gscale,
+                 is_sf_swizzled_layout=False,
+             )
+            use_llama4_routing = \
+                custom_routing_function is Llama4MoE.custom_routing_function
+            routing_method_type = flashinfer.RoutingMethodType.DeepSeekV3
+            if use_llama4_routing:
+                routing_method_type = flashinfer.RoutingMethodType.Llama4
+            out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
+                routing_logits=router_logits
+                if use_llama4_routing else router_logits.to(torch.float32),
+                routing_bias=e_score_correction_bias,
+                hidden_states=hidden_states_fp4,
+                hidden_states_scale=hidden_states_scale_linear_fp4.view(
+                    torch.float8_e4m3fn).flatten(),
+                gemm1_weights=layer.gemm1_weights_fp4_shuffled.data,
+                gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view(
+                    torch.float8_e4m3fn),
+                gemm1_bias=None,
+                gemm1_alpha=None,
+                gemm1_beta=None,
+                gemm1_clamp_limit=None,
+                gemm2_weights=layer.gemm2_weights_fp4_shuffled.data,
+                gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view(
+                    torch.float8_e4m3fn),
+                gemm2_bias=None,
+                output1_scale_scalar=layer.g1_scale_c.data,
+                output1_scale_gate_scalar=layer.g1_alphas.data,
+                output2_scale_scalar=layer.g2_alphas.data,
+                num_experts=global_num_experts,
+                top_k=top_k,
+                n_group=num_expert_group,
+                topk_group=topk_group,
+                intermediate_size=layer.intermediate_size_per_partition,
+                local_expert_offset=layer.ep_rank * layer.local_num_experts,
+                local_num_experts=layer.local_num_experts,
+                routed_scaling_factor=None,
+                tile_tokens_dim=_get_tile_tokens_dim(x.shape[0], top_k,
+                                                     layer.local_num_experts),
+                routing_method_type=routing_method_type,
+                do_finalize=True,
+            )[0]
+            return out
+
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
             router_logits=router_logits,
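
The local_expert_offset and local_num_experts arguments above assume experts are split evenly across expert-parallel ranks, each rank owning a contiguous block starting at its offset. A small illustration with made-up numbers:

global_num_experts, ep_size = 128, 4          # illustrative numbers only
local_num_experts = global_num_experts // ep_size

for ep_rank in range(ep_size):
    local_expert_offset = ep_rank * local_num_experts
    # Rank ep_rank owns experts [offset, offset + local_num_experts).
    print(f"rank {ep_rank}: experts {local_expert_offset}.."
          f"{local_expert_offset + local_num_experts - 1}")
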
@@ -1149,6 +1366,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
                 expert_map=expert_map,
                 apply_router_weight_on_input=apply_router_weight_on_input)
         else:
+            assert self.allow_flashinfer and \
+                self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
             out = flashinfer_fp4_cutlass_moe_forward(
                 self.fused_experts,
                 layer,
@@ -1160,4 +1379,5 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
                 expert_map=expert_map,
                 apply_router_weight_on_input=apply_router_weight_on_input,
             )
+
         return out
@@ -126,7 +126,7 @@ def flashinfer_fp4_cutlass_moe_forward(


 def select_nvfp4_gemm_impl(
-    allow_flashinfer_cutlass: bool,
+    allow_flashinfer: bool,
     moe,  # FusedMoEConfig
     logger):
     """Return a GEMM *experts* implementation for NV-FP4 fused-MoE layers"""
@@ -137,8 +137,14 @@ def select_nvfp4_gemm_impl(
     all2all_manager = get_ep_group().device_communicator.all2all_manager
     assert all2all_manager is not None

-    if allow_flashinfer_cutlass:
-        logger.debug_once("Using FlashInferExperts")
+    if allow_flashinfer:
+        flashinfer_backend = envs.VLLM_FLASHINFER_MOE_BACKEND
+        if flashinfer_backend != "throughput":
+            raise ValueError(
+                f"Only throughput backend is supported for FlashInferExperts, "
+                f"but got {flashinfer_backend}.")
+        logger.debug_once(
+            "Initializing FlashInferExperts with throughput backend.")
         return FlashInferExperts(
             use_nvfp4_w4a4=True,
             use_dp=moe.moe_parallel_config.dp_size > 1,
@@ -21,7 +21,7 @@ class NvFp4Support:
     """Result container for NV-FP4 capability probing."""

     cutlass_supported: bool
-    allow_flashinfer_cutlass: bool
+    allow_flashinfer: bool
     use_marlin: bool


@@ -54,6 +54,6 @@ def detect_nvfp4_moe_support(class_name: str = "") -> NvFp4Support:

     return NvFp4Support(
         cutlass_supported=cutlass_supported,
-        allow_flashinfer_cutlass=allow_flashinfer,
+        allow_flashinfer=allow_flashinfer,
         use_marlin=use_marlin,
     )
@@ -86,6 +86,8 @@ flashinfer_cutlass_fused_moe = _lazy_import_wrapper("flashinfer.fused_moe",
 fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize")
 nvfp4_block_scale_interleave = _lazy_import_wrapper(
     "flashinfer", "nvfp4_block_scale_interleave")
+trtllm_fp4_block_scale_moe = _lazy_import_wrapper(
+    "flashinfer", "trtllm_fp4_block_scale_moe")

 # Special case for autotune since it returns a context manager
 autotune = _lazy_import_wrapper(
@@ -112,6 +114,7 @@ def has_flashinfer_cutlass_fused_moe() -> bool:
         ("flashinfer.fused_moe", "cutlass_fused_moe"),
         ("flashinfer", "fp4_quantize"),
         ("flashinfer", "nvfp4_block_scale_interleave"),
+        ("flashinfer.fused_moe", "trtllm_fp4_block_scale_moe"),
     ]

     for module_name, attr_name in required_functions:
@@ -188,6 +191,7 @@ __all__ = [
     "flashinfer_cutlass_fused_moe",
     "fp4_quantize",
     "nvfp4_block_scale_interleave",
+    "trtllm_fp4_block_scale_moe",
     "autotune",
     "has_flashinfer_moe",
     "has_flashinfer_cutlass_fused_moe",
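
has_flashinfer_cutlass_fused_moe now also requires trtllm_fp4_block_scale_moe to resolve before the feature is reported available. The following is a standalone sketch of that probe pattern; probe is a hypothetical helper written for this note, not vLLM's implementation.

import importlib


def probe(required: list[tuple[str, str]]) -> bool:
    # Every (module, attribute) pair must import and resolve.
    for module_name, attr_name in required:
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            return False
        if not hasattr(module, attr_name):
            return False
    return True


# False unless a FlashInfer build shipping the TRT-LLM FP4 MoE kernel is installed.
print(probe([("flashinfer.fused_moe", "trtllm_fp4_block_scale_moe")]))
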