diff --git a/vllm/envs.py b/vllm/envs.py
index 8b12a7ee2b988..f81f6dacd87cd 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -129,6 +129,7 @@ if TYPE_CHECKING:
     VLLM_SKIP_DEEP_GEMM_WARMUP: bool = False
     VLLM_USE_FLASHINFER_MOE_FP8: bool = False
     VLLM_USE_FLASHINFER_MOE_FP4: bool = False
+    VLLM_FLASHINFER_MOE_BACKEND: str = "throughput"
     VLLM_XGRAMMAR_CACHE_MB: int = 0
     VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
     VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False
@@ -982,6 +983,20 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_ALL2ALL_BACKEND":
     lambda: os.getenv("VLLM_ALL2ALL_BACKEND", "naive"),

+    # FlashInfer MoE backend for vLLM's fused Mixture-of-Experts support.
+    # Both backends require compute capability 10.0 or above.
+    # Available options:
+    # - "throughput": [default]
+    #     Uses CUTLASS kernels optimized for high-throughput batch inference.
+    # - "latency":
+    #     Uses TensorRT-LLM kernels optimized for low-latency inference.
+    # Select the backend by setting the environment variable, e.g.:
+    #     export VLLM_FLASHINFER_MOE_BACKEND=latency
+    # If unset, it defaults to "throughput".
+    "VLLM_FLASHINFER_MOE_BACKEND": lambda: os.getenv(
+        "VLLM_FLASHINFER_MOE_BACKEND", "throughput"
+    ),
+
     # Control the maximum number of tokens per expert supported by the
     # NVFP4 MoE CUTLASS Kernel. This value is used to create a buffer for
     # the blockscale tensor of activations NVFP4 Quantization.
diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py
index 9e4ee5a3d7b95..f2242ade0c0f1 100644
--- a/vllm/model_executor/layers/fused_moe/config.py
+++ b/vllm/model_executor/layers/fused_moe/config.py
@@ -192,7 +192,8 @@ class FusedMoEParallelConfig:
     @property
     def use_flashinfer_cutlass_kernels(self):
         return (envs.VLLM_USE_FLASHINFER_MOE_FP4
-                and has_flashinfer_cutlass_fused_moe())
+                and has_flashinfer_cutlass_fused_moe()
+                and envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput")

     @staticmethod
     def make(tp_size_: int, dp_size_: int,
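A minimal usage sketch, not part of the patch: because vllm.envs resolves these variables lazily from the process environment, the backend can also be selected from Python before the engine is constructed. The checkpoint name below is a placeholder, and a compute capability 10.0 GPU with FlashInfer installed is assumed.

    import os

    # Enable the FlashInfer NVFP4 fused-MoE path and pick the low-latency
    # TensorRT-LLM kernels; leaving VLLM_FLASHINFER_MOE_BACKEND unset keeps
    # the default "throughput" (CUTLASS) path gated above.
    os.environ["VLLM_USE_FLASHINFER_MOE_FP4"] = "1"
    os.environ["VLLM_FLASHINFER_MOE_BACKEND"] = "latency"

    from vllm import LLM

    llm = LLM(model="<nvfp4-moe-checkpoint>")  # placeholder model name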
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index 09d8890888fa8..c04f7c39a5f5d 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -105,7 +105,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
             detect_nvfp4_moe_support)
         _nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__)
         self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported
-        self.allow_flashinfer_cutlass = _nvfp4.allow_flashinfer_cutlass
+        self.allow_flashinfer = _nvfp4.allow_flashinfer
         self.use_marlin = _nvfp4.use_marlin
         self.group_size = 16
         self.fused_experts = None  # type: ignore[assignment]
@@ -212,7 +212,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
                                               requires_grad=False)

         # reorder GEMM1 weights and block scales for FlashInfer CUTLASS kernel.
-        if self.allow_flashinfer_cutlass:
+        if self.allow_flashinfer:
             w, s = reorder_w1w3_to_w3w1(layer.w13_weight.data,
                                         layer.w13_weight_scale.data,
                                         dim=-2)
@@ -266,7 +266,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
             (layer.w2_input_global_scale), requires_grad=False)

     def maybe_swap_experts_impl(self, moe_parallel_config):
-        if not self.allow_flashinfer_cutlass:
+        if not self.allow_flashinfer:
             return
         self.fused_experts = build_flashinfer_fp4_cutlass_moe_kernel(
             moe_parallel_config)
@@ -277,8 +277,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
         from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (  # noqa: E501
             select_nvfp4_gemm_impl)

-        return select_nvfp4_gemm_impl(self.allow_flashinfer_cutlass, moe,
-                                      logger)
+        return select_nvfp4_gemm_impl(self.allow_flashinfer, moe, logger)

     def apply(
         self,
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index 0334a2824512d..147b275eaf525 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

+from enum import Enum
 from typing import Any, Callable, Optional, Union

 import torch
@@ -36,6 +37,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
 from vllm.model_executor.parameter import (ModelWeightParameter,
                                            PerTensorScaleParameter)
 from vllm.scalar_type import scalar_types
+from vllm.utils import next_power_of_2
 from vllm.utils.flashinfer import has_flashinfer_moe

 logger = init_logger(__name__)
@@ -44,6 +46,11 @@ QUANT_ALGOS = ["FP8", "NVFP4"]
 KV_CACHE_QUANT_ALGOS = ["FP8"]


+class FlashinferMoeBackend(Enum):
+    TENSORRT_LLM = "TensorRT-LLM"
+    CUTLASS = "CUTLASS"
+
+
 class ModelOptFp8Config(QuantizationConfig):
     """Config class for ModelOpt FP8."""

@@ -185,7 +192,7 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
         quant_config: The ModelOpt quantization config.
     """

-    def __init__(self, quant_config: ModelOptFp8Config):
+    def __init__(self, quant_config: ModelOptFp8Config) -> None:
         self.quant_config = quant_config
         self.fp8_linear = Fp8LinearOp(
             act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR)
@@ -265,7 +272,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         quant_config: The ModelOpt quantization config.
     """

-    def __init__(self, quant_config: ModelOptFp8Config):
+    def __init__(self, quant_config: ModelOptFp8Config) -> None:
         self.quant_config = quant_config
         from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
             cutlass_fp8_supported)
@@ -670,7 +677,8 @@ class ModelOptNvFp4Config(QuantizationConfig):
         return cls(is_checkpoint_nvfp4_serialized, kv_cache_quant_algo,
                    exclude_modules, group_size)

-    def is_layer_excluded(self, prefix: str, exclude_modules: list):
+    def is_layer_excluded(self, prefix: str,
+                          exclude_modules: list[str]) -> bool:
         import regex as re
         for pattern in exclude_modules:
             regex_str = pattern.replace('.', r'\.').replace('*', r'.*')
@@ -714,7 +722,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
         quant_config: The ModelOpt quantization config.
     """

-    def __init__(self, quant_config: ModelOptNvFp4Config):
+    def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
         self.quant_config = quant_config
         self.cutlass_nvfp4_supported = cutlass_fp4_supported()
         self.use_marlin = False
@@ -859,6 +867,16 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
         return out.view(*output_shape)


+def _get_tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
+    # Guess tokens per expert assuming perfect expert distribution first.
+    num_tokens_per_expert = (num_tokens * top_k) // num_experts
+    # And pad the number to the next power of 2.
+    tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
+    # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
+    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
+    return tile_tokens_dim
+
+
 class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
     """
     MoE Method for FP4 Quantization.
@@ -866,22 +884,40 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         quant_config: NVFP4 Quant Config
     """

-    def __init__(self, quant_config: ModelOptNvFp4Config):
+    def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
         self.quant_config = quant_config
         from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import (  # noqa: E501
             detect_nvfp4_moe_support)
         _nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__)
         self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported
-        self.allow_flashinfer_cutlass = _nvfp4.allow_flashinfer_cutlass
+        self.allow_flashinfer = _nvfp4.allow_flashinfer
         self.use_marlin = _nvfp4.use_marlin
+        self.flashinfer_moe_backend = None

-        self.fused_experts = None  # type: ignore
+        if self.allow_flashinfer:
+            flashinfer_moe_backend = envs.VLLM_FLASHINFER_MOE_BACKEND
+            if flashinfer_moe_backend == "throughput":
+                self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
+                logger.info_once("Using FlashInfer CUTLASS kernels for "
+                                 "ModelOptNvFp4FusedMoE.")
+            elif flashinfer_moe_backend == "latency":
+                self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
+                logger.info_once("Using FlashInfer TensorRT-LLM kernels for "
+                                 "ModelOptNvFp4FusedMoE.")
+            else:
+                allowed_backends = ["throughput", "latency"]
+                raise ValueError(
+                    f"Unknown flashinfer moe backend: {flashinfer_moe_backend},"
+                    f" expected one of {allowed_backends}")
+
+        self.fused_experts: Optional[
+            mk.FusedMoEModularKernel] = None  # type: ignore[assignment]

     def maybe_swap_experts_impl(
         self,
         moe_parallel_config: FusedMoEParallelConfig,
     ):
-        if not self.allow_flashinfer_cutlass:
+        if not self.allow_flashinfer:
             return
         self.fused_experts = build_flashinfer_fp4_cutlass_moe_kernel(
             moe_parallel_config)
@@ -897,8 +933,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (  # noqa: E501
             select_nvfp4_gemm_impl)

-        return select_nvfp4_gemm_impl(self.allow_flashinfer_cutlass, moe,
-                                      logger)
+        return select_nvfp4_gemm_impl(self.allow_flashinfer, moe, logger)

     def uses_weight_scale_2_pattern(self) -> bool:
         """
@@ -996,14 +1031,101 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
                 weight_loader=weight_loader)
             layer.register_parameter("w2_input_scale", w2_input_scale)

+    def prepare_static_weight_layouts_for_trtllm_moe(
+        self,
+        gemm1_weights: torch.Tensor,
+        gemm2_weights: torch.Tensor,
+        gemm1_scales_linear_fp4_bytes: torch.Tensor,
+        gemm2_scales_linear_fp4_bytes: torch.Tensor,
+        hidden_size: int,
+        intermediate_size: int,
+        num_experts: int,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Prepare quantized weights for kernel (done offline with weights)."""
+        from flashinfer import (reorder_rows_for_gated_act_gemm,
+                                shuffle_matrix_a, shuffle_matrix_sf_a)
+        epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
+
+        # Convert quantized weights to proper formats
+        gemm1_weights_fp4 = gemm1_weights.view(torch.float8_e4m3fn).reshape(
+            num_experts, 2 * intermediate_size, hidden_size // 2)  # packed fp4
+        gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes.view(
+            torch.float8_e4m3fn).reshape(num_experts, 2 * intermediate_size,
+                                         hidden_size //
+                                         16)  # fp8 scaling factors
+
+        gemm2_weights_fp4 = gemm2_weights.view(torch.float8_e4m3fn).reshape(
+            num_experts, hidden_size, intermediate_size // 2)  # packed fp4
+        gemm2_scales_linear_fp4 = gemm2_scales_linear_fp4_bytes.view(
+            torch.float8_e4m3fn).reshape(num_experts, hidden_size,
+                                         intermediate_size //
+                                         16)  # fp8 scaling factors
+
+        # Reorder rows of W1 and scales for fused gated activation
+        gemm1_weights_fp4_interleaved = []
+        gemm1_scales_fp4_interleaved = []
+        for i in range(num_experts):
+            gemm1_weights_fp4_interleaved.append(
+                reorder_rows_for_gated_act_gemm(gemm1_weights_fp4[i].clone()))
+            gemm1_scales_fp4_interleaved.append(
+                reorder_rows_for_gated_act_gemm(
+                    gemm1_scales_linear_fp4[i].clone()))
+
+        # Stack weights and scales for all experts
+        gemm1_weights_fp4_interleaved = torch.stack(
+            gemm1_weights_fp4_interleaved).reshape(num_experts,
+                                                   2 * intermediate_size,
+                                                   hidden_size // 2)
+        gemm1_scales_fp4_interleaved = torch.stack(
+            gemm1_scales_fp4_interleaved).reshape(num_experts,
+                                                  2 * intermediate_size,
+                                                  hidden_size // 16)
+
+        # Shuffle weights and scaling factors for transposed mma output
+        gemm1_weights_fp4_shuffled = []
+        gemm1_scales_fp4_shuffled = []
+        gemm2_weights_fp4_shuffled = []
+        gemm2_scales_fp4_shuffled = []
+        for i in range(num_experts):
+            gemm1_weights_fp4_shuffled.append(
+                shuffle_matrix_a(
+                    gemm1_weights_fp4_interleaved[i].view(torch.uint8),
+                    epilogue_tile_m))
+            gemm1_scales_fp4_shuffled.append(
+                shuffle_matrix_sf_a(
+                    gemm1_scales_fp4_interleaved[i].view(torch.uint8),
+                    epilogue_tile_m))
+
+            gemm2_weights_fp4_shuffled.append(
+                shuffle_matrix_a(gemm2_weights_fp4[i].view(torch.uint8),
+                                 epilogue_tile_m))
+            gemm2_scales_fp4_shuffled.append(
+                shuffle_matrix_sf_a(
+                    gemm2_scales_linear_fp4[i].view(torch.uint8),
+                    epilogue_tile_m))
+
+        # Stack weights for all experts
+        gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled)
+        gemm1_scales_fp4_shuffled = (
+            torch.stack(gemm1_scales_fp4_shuffled).view(
+                torch.float8_e4m3fn).reshape(num_experts,
+                                             2 * intermediate_size,
+                                             hidden_size // 16))
+
+        gemm2_weights_fp4_shuffled = torch.stack(gemm2_weights_fp4_shuffled)
+        gemm2_scales_fp4_shuffled = (
+            torch.stack(gemm2_scales_fp4_shuffled).view(
+                torch.float8_e4m3fn).reshape(num_experts, hidden_size,
+                                             intermediate_size // 16))
+        return (gemm1_weights_fp4_shuffled, gemm1_scales_fp4_shuffled,
+                gemm2_weights_fp4_shuffled, gemm2_scales_fp4_shuffled)
+
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        # GEMM 1
-        # The FlashInfer Cutlass fused MoE kernel expects the combined weights
-        # to be ordered as [w3, w1], unlike the standard [w1, w3] layout.
+        # GEMM 1 processing
         gemm1_weight = layer.w13_weight.data
         gemm1_weight_scale = layer.w13_weight_scale.data
-        if self.allow_flashinfer_cutlass:
+        if self.allow_flashinfer:
             gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1(
                 gemm1_weight, gemm1_weight_scale, dim=-2)
@@ -1011,6 +1133,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         layer.w13_weight_scale = Parameter(gemm1_weight_scale,
                                            requires_grad=False)

+        # Common processing for w13_weight_scale_2
         if not torch.allclose(layer.w13_weight_scale_2[:, 0],
                               layer.w13_weight_scale_2[:, 1]):
             logger.warning_once(
@@ -1021,26 +1144,18 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2,
                                              requires_grad=False)

+        # Common processing for input scales and alphas
         w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(
             torch.float32)
         layer.g1_alphas = Parameter(
             (w13_input_scale * w13_weight_scale_2).to(torch.float32),
             requires_grad=False)

-        assert (layer.w13_weight_scale.shape[2] % 16 == 0), (
-            "Expected weight_scale.dim(1) to be divisible by 16")
-        assert (layer.w13_weight_scale.dtype == torch.float8_e4m3fn), (
-            "Weight Blockscale must be represented as FP8-E4M3")
-        w13_blockscale_swizzled = swizzle_blockscale(layer.w13_weight_scale)
-
-        layer.w13_blockscale_swizzled = Parameter(w13_blockscale_swizzled,
-                                                  requires_grad=False)
-
         # This is for quantization, so we need to invert it.
         layer.w13_input_scale_quant = Parameter(
             (1 / w13_input_scale).to(torch.float32), requires_grad=False)

-        # GEMM 2
+        # GEMM 2 processing
         layer.g2_alphas = Parameter(
             (layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
             requires_grad=False)
@@ -1049,15 +1164,63 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         layer.w2_input_scale_quant = Parameter(
             (1 / layer.w2_input_scale).to(torch.float32), requires_grad=False)

-        assert (layer.w2_weight_scale.shape[2] % 16 == 0), (
-            "Expected weight_scale.dim(1) to be divisible by 16")
-        assert (layer.w2_weight_scale.dtype == torch.float8_e4m3fn), (
-            "Weight Blockscale must be represented as FP8-E4M3")
-        w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale)
+        # TensorRT-LLM specific processing
+        if self.allow_flashinfer and \
+            self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
+            # Prepare static weights for TRT-LLM kernel
+            (gemm1_weights_fp4_shuffled, gemm1_scales_fp4_shuffled,
+             gemm2_weights_fp4_shuffled, gemm2_scales_fp4_shuffled
+             ) = self.prepare_static_weight_layouts_for_trtllm_moe(
+                 layer.w13_weight,
+                 layer.w2_weight,
+                 layer.w13_weight_scale,
+                 layer.w2_weight_scale,
+                 layer.w2_weight.size(-2),  # hidden_size
+                 layer.w13_weight.size(-2) // 2,  # intermediate_size
+                 layer.w13_weight.size(0),  # num_experts
+             )
-        layer.w2_blockscale_swizzled = Parameter(w2_blockscale_swizzled,
-                                                 requires_grad=False)
-        layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
+            layer.gemm1_weights_fp4_shuffled = Parameter(
+                gemm1_weights_fp4_shuffled, requires_grad=False)
+            layer.gemm2_weights_fp4_shuffled = Parameter(
+                gemm2_weights_fp4_shuffled, requires_grad=False)
+            layer.gemm1_scales_fp4_shuffled = Parameter(
+                gemm1_scales_fp4_shuffled, requires_grad=False)
+            layer.gemm2_scales_fp4_shuffled = Parameter(
+                gemm2_scales_fp4_shuffled, requires_grad=False)
+
+            # Additional parameter needed for TRT-LLM
+            layer.g1_scale_c = Parameter(
+                (layer.w2_input_scale_quant * layer.g1_alphas).to(
+                    torch.float32),
+                requires_grad=False,
+            )
+
+            # Clean up weights that won't be used by TRT-LLM
+            del layer.w2_weight
+            del layer.w2_weight_scale
+            del layer.w13_weight
+            del layer.w13_weight_scale
+        else:
+            # Non-TRT-LLM processing (Cutlass or non-flashinfer)
+            assert (layer.w13_weight_scale.shape[2] % 16 == 0), (
+                "Expected weight_scale.dim(1) to be divisible by 16")
+            assert (layer.w13_weight_scale.dtype == torch.float8_e4m3fn), (
+                "Weight Blockscale must be represented as FP8-E4M3")
+            w13_blockscale_swizzled = swizzle_blockscale(
+                layer.w13_weight_scale)
+            layer.w13_blockscale_swizzled = Parameter(w13_blockscale_swizzled,
+                                                      requires_grad=False)
+
+            assert (layer.w2_weight_scale.shape[2] % 16 == 0), (
+                "Expected weight_scale.dim(1) to be divisible by 16")
+            assert (layer.w2_weight_scale.dtype == torch.float8_e4m3fn), (
+                "Weight Blockscale must be represented as FP8-E4M3")
+            w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale)
+            layer.w2_blockscale_swizzled = Parameter(w2_blockscale_swizzled,
+                                                     requires_grad=False)
+            layer.w2_weight = Parameter(layer.w2_weight.data,
+                                        requires_grad=False)

         if self.use_marlin:
             prepare_moe_fp4_layer_for_marlin(layer)
@@ -1095,6 +1258,60 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
             "EPLB not supported for `ModelOptNvFp4FusedMoE` yet.")
         assert activation == "silu", "Only SiLU activation is supported."

+        if self.allow_flashinfer and \
+            self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
+            import flashinfer
+
+            from vllm.model_executor.models.llama4 import Llama4MoE
+
+            a1_gscale = layer.w13_input_scale_quant
+            (hidden_states_fp4,
+             hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
+                 x,
+                 a1_gscale,
+                 is_sf_swizzled_layout=False,
+             )
+            use_llama4_routing = \
+                custom_routing_function is Llama4MoE.custom_routing_function
+            routing_method_type = flashinfer.RoutingMethodType.DeepSeekV3
+            if use_llama4_routing:
+                routing_method_type = flashinfer.RoutingMethodType.Llama4
+            out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
+                routing_logits=router_logits
+                if use_llama4_routing else router_logits.to(torch.float32),
+                routing_bias=e_score_correction_bias,
+                hidden_states=hidden_states_fp4,
+                hidden_states_scale=hidden_states_scale_linear_fp4.view(
+                    torch.float8_e4m3fn).flatten(),
+                gemm1_weights=layer.gemm1_weights_fp4_shuffled.data,
+                gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view(
+                    torch.float8_e4m3fn),
+                gemm1_bias=None,
+                gemm1_alpha=None,
+                gemm1_beta=None,
+                gemm1_clamp_limit=None,
+                gemm2_weights=layer.gemm2_weights_fp4_shuffled.data,
+                gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view(
+                    torch.float8_e4m3fn),
+                gemm2_bias=None,
+                output1_scale_scalar=layer.g1_scale_c.data,
+                output1_scale_gate_scalar=layer.g1_alphas.data,
+                output2_scale_scalar=layer.g2_alphas.data,
+                num_experts=global_num_experts,
+                top_k=top_k,
+                n_group=num_expert_group,
+                topk_group=topk_group,
+                intermediate_size=layer.intermediate_size_per_partition,
+                local_expert_offset=layer.ep_rank * layer.local_num_experts,
+                local_num_experts=layer.local_num_experts,
+                routed_scaling_factor=None,
+                tile_tokens_dim=_get_tile_tokens_dim(x.shape[0], top_k,
+                                                     layer.local_num_experts),
+                routing_method_type=routing_method_type,
+                do_finalize=True,
+            )[0]
+            return out
+
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
             router_logits=router_logits,
@@ -1149,6 +1366,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
                 expert_map=expert_map,
                 apply_router_weight_on_input=apply_router_weight_on_input)
         else:
+            assert self.allow_flashinfer and \
+                self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
             out = flashinfer_fp4_cutlass_moe_forward(
                 self.fused_experts,
                 layer,
@@ -1160,4 +1379,5 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
                 expert_map=expert_map,
                 apply_router_weight_on_input=apply_router_weight_on_input,
             )
+
         return out
diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
index 4c617e226041f..8ef91eeed406f 100644
--- a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
+++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
@@ -126,7 +126,7 @@ def flashinfer_fp4_cutlass_moe_forward(


 def select_nvfp4_gemm_impl(
-    allow_flashinfer_cutlass: bool,
+    allow_flashinfer: bool,
     moe,  # FusedMoEConfig
     logger):
     """Return a GEMM *experts* implementation for NV-FP4 fused-MoE layers"""
@@ -137,8 +137,14 @@ def select_nvfp4_gemm_impl(
         all2all_manager = get_ep_group().device_communicator.all2all_manager
         assert all2all_manager is not None

-    if allow_flashinfer_cutlass:
-        logger.debug_once("Using FlashInferExperts")
+    if allow_flashinfer:
+        flashinfer_backend = envs.VLLM_FLASHINFER_MOE_BACKEND
+        if flashinfer_backend != "throughput":
+            raise ValueError(
+                f"Only throughput backend is supported for FlashInferExperts, "
+                f"but got {flashinfer_backend}.")
+        logger.debug_once(
+            "Initializing FlashInferExperts with throughput backend.")
         return FlashInferExperts(
             use_nvfp4_w4a4=True,
             use_dp=moe.moe_parallel_config.dp_size > 1,
diff --git a/vllm/model_executor/layers/quantization/utils/nvfp4_moe_support.py b/vllm/model_executor/layers/quantization/utils/nvfp4_moe_support.py
index 23a749467f193..21af74c6b72b5 100644
--- a/vllm/model_executor/layers/quantization/utils/nvfp4_moe_support.py
+++ b/vllm/model_executor/layers/quantization/utils/nvfp4_moe_support.py
@@ -21,7 +21,7 @@ class NvFp4Support:
     """Result container for NV-FP4 capability probing."""

     cutlass_supported: bool
-    allow_flashinfer_cutlass: bool
+    allow_flashinfer: bool
     use_marlin: bool


@@ -54,6 +54,6 @@ def detect_nvfp4_moe_support(class_name: str = "") -> NvFp4Support:

     return NvFp4Support(
         cutlass_supported=cutlass_supported,
-        allow_flashinfer_cutlass=allow_flashinfer,
+        allow_flashinfer=allow_flashinfer,
         use_marlin=use_marlin,
     )
diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py
index 32c52612ca16f..5998d4c3127f6 100644
--- a/vllm/utils/flashinfer.py
+++ b/vllm/utils/flashinfer.py
@@ -86,6 +86,8 @@ flashinfer_cutlass_fused_moe = _lazy_import_wrapper("flashinfer.fused_moe",
 fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize")
 nvfp4_block_scale_interleave = _lazy_import_wrapper(
     "flashinfer", "nvfp4_block_scale_interleave")
+trtllm_fp4_block_scale_moe = _lazy_import_wrapper(
+    "flashinfer", "trtllm_fp4_block_scale_moe")

 # Special case for autotune since it returns a context manager
 autotune = _lazy_import_wrapper(
@@ -112,6 +114,7 @@ def has_flashinfer_cutlass_fused_moe() -> bool:
         ("flashinfer.fused_moe", "cutlass_fused_moe"),
         ("flashinfer", "fp4_quantize"),
         ("flashinfer", "nvfp4_block_scale_interleave"),
+        ("flashinfer.fused_moe", "trtllm_fp4_block_scale_moe"),
     ]

     for module_name, attr_name in required_functions:
@@ -188,6 +191,7 @@ __all__ = [
     "flashinfer_cutlass_fused_moe",
     "fp4_quantize",
     "nvfp4_block_scale_interleave",
+    "trtllm_fp4_block_scale_moe",
     "autotune",
     "has_flashinfer_moe",
     "has_flashinfer_cutlass_fused_moe",
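As a sanity check on the tile-sizing heuristic added in _get_tile_tokens_dim above, a small self-contained sketch follows. next_power_of_2 is re-implemented locally (assumed to mirror vllm.utils.next_power_of_2) so the snippet runs on its own, and the batch sizes, top_k, and expert counts are example values only.

    def next_power_of_2(n: int) -> int:
        # Smallest power of two >= n, treating n <= 1 as 1.
        return 1 if n <= 1 else 1 << (n - 1).bit_length()

    def get_tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
        # Expected tokens per expert under perfectly uniform routing.
        num_tokens_per_expert = (num_tokens * top_k) // num_experts
        # Round up to a power of two, then clamp to the 8-64 range the
        # TRT-LLM kernel supports.
        return min(max(next_power_of_2(num_tokens_per_expert), 8), 64)

    # A decode-sized batch lands on the minimum tile, a large prefill on the
    # maximum, and mid-sized batches scale in powers of two between them.
    assert get_tile_tokens_dim(num_tokens=8, top_k=8, num_experts=256) == 8
    assert get_tile_tokens_dim(num_tokens=1024, top_k=8, num_experts=256) == 32
    assert get_tile_tokens_dim(num_tokens=4096, top_k=8, num_experts=256) == 64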