# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Callable from typing import TYPE_CHECKING, Any, Optional import torch from torch.nn import Module from torch.nn.parameter import Parameter import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEQuantConfig, fp8_w8a8_moe_quant_config, nvfp4_moe_quant_config, ) from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( is_valid_flashinfer_cutlass_fused_moe, ) from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported, ) from vllm.model_executor.layers.linear import ( LinearBase, LinearMethodBase, UnquantizedLinearMethod, ) from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, ) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( build_flashinfer_fp4_cutlass_moe_prepare_finalize, reorder_w1w3_to_w3w1, select_nvfp4_gemm_impl, ) from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( FlashinferMoeBackend, apply_flashinfer_per_tensor_scale_fp8, build_flashinfer_fp8_cutlass_moe_prepare_finalize, flashinfer_cutlass_moe_fp8, get_flashinfer_moe_backend, register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights, select_cutlass_fp8_gemm_impl, swap_w13_to_w31, ) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( apply_fp4_marlin_linear, is_fp4_marlin_supported, prepare_fp4_layer_for_marlin, prepare_moe_fp4_layer_for_marlin, ) from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, cutlass_fp4_supported, is_layer_skipped, swizzle_blockscale, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( Fp8LinearOp, requantize_with_max_scale, ) from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter from vllm.scalar_type import scalar_types from vllm.utils import next_power_of_2 from vllm.utils.flashinfer import ( flashinfer_scaled_fp4_mm, has_flashinfer, has_flashinfer_moe, ) if TYPE_CHECKING: from vllm.model_executor.models.utils import WeightsMapper logger = init_logger(__name__) QUANT_ALGOS = ["FP8", "NVFP4"] KV_CACHE_QUANT_ALGOS = ["FP8"] class ModelOptFp8Config(QuantizationConfig): """Config class for ModelOpt FP8.""" def __init__( self, is_checkpoint_fp8_serialized: bool = False, kv_cache_quant_method: str | None = None, exclude_modules: list[str] | None = None, ) -> None: super().__init__() self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized self.kv_cache_quant_method = kv_cache_quant_method self.exclude_modules = exclude_modules or [] if is_checkpoint_fp8_serialized: logger.warning( "Detected ModelOpt fp8 checkpoint. Please note that" " the format is experimental and could change." ) @classmethod def get_name(cls) -> QuantizationMethods: return "modelopt" @classmethod def get_supported_act_dtypes(cls) -> list[torch.dtype]: return [torch.bfloat16, torch.half] @classmethod def get_min_capability(cls) -> int: return 89 @classmethod def get_config_filenames(cls) -> list[str]: return ["hf_quant_config.json"] def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): if self.exclude_modules is not None: self.exclude_modules = hf_to_vllm_mapper.apply_list(self.exclude_modules) @classmethod def override_quantization_method( cls, hf_quant_cfg, user_quant ) -> QuantizationMethods | None: """Detect if this ModelOpt config should be used based on quantization config.""" if hf_quant_cfg is None: return None # Use the community standard 'quant_method' quant_method = hf_quant_cfg.get("quant_method", "").lower() # Only proceed if the method is explicitly "modelopt" if quant_method != "modelopt": return None # Look for ModelOpt-specific config structure if "quantization" in hf_quant_cfg: quant_config = hf_quant_cfg["quantization"] if isinstance(quant_config, dict): quant_algo = quant_config.get("quant_algo", "") if "FP8" in quant_algo: return "modelopt" else: # Check for compressed-tensors style config with specific quant_algo quant_algo = hf_quant_cfg.get("quant_algo", "") if isinstance(quant_algo, str) and "FP8" in quant_algo: return "modelopt" return None @classmethod def from_config(cls, config: dict[str, Any]) -> "ModelOptFp8Config": # Handle both ModelOpt format and compressed-tensors style format if "quantization" in config: # ModelOpt format: {"quantization": {"quant_algo": "..."}} quant_config = cls.get_from_keys(config, ["quantization"]) if not isinstance(quant_config, dict): raise ValueError("Expected 'quantization' to be a dictionary in config") quant_method = quant_config.get("quant_algo", "") if not quant_method: raise ValueError("Missing 'quant_algo' in quantization config") kv_cache_quant_method = quant_config.get("kv_cache_quant_algo") # "exclude_modules" is the key in the legacy hf_quant_config.json exclude_modules = quant_config.get("exclude_modules") else: # Compressed-tensors style format: # {"quant_algo": "...", "quant_method": "modelopt"} quant_method = config.get("quant_algo", "") kv_cache_quant_method = config.get("kv_cache_quant_algo") # "ignore" is the key in config.json exclude_modules = config.get("ignore") if quant_method not in QUANT_ALGOS: raise ValueError( f"ModelOpt currently only supports: {QUANT_ALGOS} " "quantizations in vLLM. Please check the " "`hf_quant_config.json` file for your model's " "quant configuration." ) is_checkpoint_fp8_serialized = "FP8" in quant_method return cls(is_checkpoint_fp8_serialized, kv_cache_quant_method, exclude_modules) def is_layer_excluded(self, prefix: str) -> bool: """ Check if a layer should be excluded from quantization. Handles both exact matching (for fused layers) and substring matching. This method handles both regular models and multimodal models that use the language_model prefix. For multimodal models, it checks if the module name (without the language_model prefix) is in the exclude list. """ if self.exclude_modules is None: return False # First check exact matching with fused layer support if is_layer_skipped(prefix, self.exclude_modules, self.packed_modules_mapping): return True # Then check substring matching for patterns not caught by exact match for module in self.exclude_modules: # Skip exact matches already handled above if module != prefix and ( module in prefix or ( prefix.startswith("language_model.") and module in prefix.removeprefix("language_model.") ) ): return True return False def get_quant_method( self, layer: torch.nn.Module, prefix: str ) -> Optional["QuantizeMethodBase"]: from vllm.attention.layer import Attention # Avoid circular import if isinstance(layer, LinearBase): if self.is_layer_excluded(prefix): return UnquantizedLinearMethod() # Check if this is a vision model layer that should not be quantized if "vision_tower" in prefix or "vision_model" in prefix: return UnquantizedLinearMethod() return ModelOptFp8LinearMethod(self) elif isinstance(layer, Attention): return ModelOptFp8KVCacheMethod(self) elif isinstance(layer, FusedMoE): return ModelOptFp8MoEMethod(self, layer) return None class ModelOptFp8LinearMethod(LinearMethodBase): """Linear method for Model Optimizer static quantization. Supports loading FP8 checkpoints with static weight scale and activation scale. Future support might be added for dynamic scales. Limitations: 1. Only support per-tensor quantization due to torch._scaled_mm support. 2. Only support float8_e4m3fn datatype Args: quant_config: The ModelOpt quantization config. """ def __init__(self, quant_config: ModelOptFp8Config) -> None: self.quant_config = quant_config self.fp8_linear = Fp8LinearOp( act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR ) def create_weights( self, layer: torch.nn.Module, input_size_per_partition: int, output_partition_sizes: list[int], input_size: int, output_size: int, params_dtype: torch.dtype, **extra_weight_attrs, ): del input_size, output_size output_size_per_partition = sum(output_partition_sizes) weight_loader = extra_weight_attrs.get("weight_loader") layer.logical_widths = output_partition_sizes layer.input_size_per_partition = input_size_per_partition layer.output_size_per_partition = output_size_per_partition weight_dtype = ( torch.float8_e4m3fn if self.quant_config.is_checkpoint_fp8_serialized else params_dtype ) weight = ModelWeightParameter( data=torch.empty( output_size_per_partition, input_size_per_partition, dtype=weight_dtype ), input_dim=1, output_dim=0, weight_loader=weight_loader, ) layer.register_parameter("weight", weight) if self.quant_config.is_checkpoint_fp8_serialized: # WEIGHT SCALE weight_scale = PerTensorScaleParameter( data=torch.empty(len(output_partition_sizes), dtype=torch.float32), weight_loader=weight_loader, ) weight_scale[:] = torch.finfo(torch.float32).min layer.register_parameter("weight_scale", weight_scale) # INPUT SCALE scale = PerTensorScaleParameter( data=torch.empty(len(output_partition_sizes), dtype=torch.float32), weight_loader=weight_loader, ) scale[:] = torch.finfo(torch.float32).min layer.register_parameter("input_scale", scale) def process_weights_after_loading(self, layer: Module) -> None: weight = layer.weight max_w_scale = layer.weight_scale.max() if not (layer.weight_scale == layer.weight_scale[0]).all(): max_w_scale, weight = requantize_with_max_scale( layer.weight, layer.weight_scale, layer.logical_widths ) layer.weight = Parameter(weight.t(), requires_grad=False) layer.weight_scale = Parameter(max_w_scale, requires_grad=False) layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False) def apply( self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: return self.fp8_linear.apply( input=x, weight=layer.weight, weight_scale=layer.weight_scale, input_scale=layer.input_scale, bias=bias, ) class ModelOptFp8MoEMethod(FusedMoEMethodBase): """MoE method for ModelOpt FP8. Supports loading FP8 checkpoints with static weight scale and activation scale. Args: quant_config: The ModelOpt quantization config. """ def __init__( self, quant_config: ModelOptFp8Config, layer: torch.nn.Module, ) -> None: super().__init__(layer.moe_config) self.layer = layer self.quant_config = quant_config from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( cutlass_fp8_supported, ) self.cutlass_fp8_supported = cutlass_fp8_supported() self.flashinfer_moe_backend: FlashinferMoeBackend | None = None if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe(): self.flashinfer_moe_backend = get_flashinfer_moe_backend() logger.info_once( f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels" ) def maybe_make_prepare_finalize( self, ) -> mk.FusedMoEPrepareAndFinalize | None: # TRT LLM not supported with all2all yet. if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: return None elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize( self.moe ) logger.debug_once("%s", prepare_finalize.__class__.__name__) return prepare_finalize else: return super().maybe_make_prepare_finalize() def select_gemm_impl( self, prepare_finalize: mk.FusedMoEPrepareAndFinalize, layer: torch.nn.Module, ) -> mk.FusedMoEPermuteExpertsUnpermute: assert self.moe_quant_config is not None experts = select_cutlass_fp8_gemm_impl( self.moe, self.moe_quant_config, ) logger.debug_once("Using %s", experts.__class__.__name__) return experts def create_weights( self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs, ): # Use FP8 dtype if checkpoint is serialized weight_dtype = ( torch.float8_e4m3fn if self.quant_config.is_checkpoint_fp8_serialized else params_dtype ) weight_loader = extra_weight_attrs.get("weight_loader") w13_weight = ModelWeightParameter( data=torch.empty( num_experts, 2 * intermediate_size_per_partition, hidden_size, dtype=weight_dtype, ), input_dim=2, output_dim=1, weight_loader=weight_loader, ) layer.register_parameter("w13_weight", w13_weight) w2_weight = ModelWeightParameter( data=torch.empty( num_experts, hidden_size, intermediate_size_per_partition, dtype=weight_dtype, ), input_dim=2, output_dim=1, weight_loader=weight_loader, ) layer.register_parameter("w2_weight", w2_weight) if self.quant_config.is_checkpoint_fp8_serialized: # WEIGHT SCALES - Per-tensor scaling for ModelOpts # Allocate 2 scales for w1 and w3 respectively. # They will be combined to a single scale after weight loading. w13_weight_scale = PerTensorScaleParameter( data=torch.full( (num_experts, 2), 1.0, dtype=torch.float32, ), weight_loader=weight_loader, ) w2_weight_scale = PerTensorScaleParameter( data=torch.full((num_experts,), 1.0, dtype=torch.float32), weight_loader=weight_loader, ) layer.register_parameter("w13_weight_scale", w13_weight_scale) layer.register_parameter("w2_weight_scale", w2_weight_scale) # Set weight loader attributes for scales extra_weight_attrs.update( {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} ) # INPUT SCALES - Per-tensor scaling for ModelOpt w13_input_scale = PerTensorScaleParameter( data=torch.full((num_experts,), 1.0, dtype=torch.float32), weight_loader=weight_loader, ) w2_input_scale = PerTensorScaleParameter( data=torch.full((num_experts,), 1.0, dtype=torch.float32), weight_loader=weight_loader, ) layer.register_parameter("w13_input_scale", w13_input_scale) layer.register_parameter("w2_input_scale", w2_input_scale) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: """Process FP8 MoE weights after loading from serialized checkpoint. Only supports pre-quantized checkpoints with FP8 weights and scales. """ layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False) layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False) from vllm._custom_ops import scaled_fp8_quant from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( per_tensor_dequantize, ) # Handle scale parameters if hasattr(layer, "w13_weight_scale") and layer.w13_weight_scale is not None: # Fp8 moe kernel needs single weight scale for w13 per expert. # We take the max of the w1 and w3 scales # then dequant and requant each expert. if layer.w13_weight_scale.dim() == 2: # Get the maximum scale across w1 and w3 for each expert max_w13_scales = layer.w13_weight_scale.max(dim=1).values # Requantize each expert's weights using the combined scale # w13_weight (num_experts, 2 * intermediate_size, hidden_size) # where the first intermediate_size rows are w1, the next are w3 intermediate_size = layer.w13_weight.shape[1] // 2 for expert_id in range(layer.w13_weight.shape[0]): start = 0 for shard_id in range(2): # w1 and w3 # Dequantize using the original scale for this shard dq_weight = per_tensor_dequantize( layer.w13_weight[expert_id][ start : start + intermediate_size, : ], layer.w13_weight_scale[expert_id][shard_id], ) # Requantize using the combined max scale ( layer.w13_weight[expert_id][ start : start + intermediate_size, : ], _, ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id]) start += intermediate_size # Update the scale parameter to be per-expert layer.w13_weight_scale = Parameter(max_w13_scales, requires_grad=False) else: layer.w13_weight_scale = Parameter( layer.w13_weight_scale.data, requires_grad=False ) if hasattr(layer, "w2_weight_scale") and layer.w2_weight_scale is not None: layer.w2_weight_scale = Parameter( layer.w2_weight_scale.data, requires_grad=False ) # Input scales must be equal for each expert in fp8 MoE layers. if hasattr(layer, "w13_input_scale") and layer.w13_input_scale is not None: layer.w13_input_scale = Parameter( layer.w13_input_scale.max(), requires_grad=False ) if hasattr(layer, "w2_input_scale") and layer.w2_input_scale is not None: layer.w2_input_scale = Parameter( layer.w2_input_scale.max(), requires_grad=False ) if self.flashinfer_moe_backend is not None: layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data) register_moe_scaling_factors(layer) if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight) def get_fused_moe_quant_config( self, layer: torch.nn.Module ) -> FusedMoEQuantConfig | None: if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: return None return fp8_w8a8_moe_quant_config( w1_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale, a1_scale=layer.w13_input_scale, a2_scale=layer.w2_input_scale, per_act_token_quant=False, ) def apply( self, layer: torch.nn.Module, x: torch.Tensor, router_logits: torch.Tensor, top_k: int, renormalize: bool, use_grouped_topk: bool = False, topk_group: int | None = None, num_expert_group: int | None = None, global_num_experts: int = -1, expert_map: torch.Tensor | None = None, custom_routing_function: Callable | None = None, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, e_score_correction_bias: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, activation: str = "silu", enable_eplb: bool = False, expert_load_view: torch.Tensor | None = None, logical_to_physical_map: torch.Tensor | None = None, logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: if enable_eplb: raise NotImplementedError( "EPLB not supported for `ModelOptFp8MoEMethod` yet." ) if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: assert self.fused_experts is None assert activation == "silu", ( f"Expected 'silu' activation but got {activation}" ) assert not renormalize return apply_flashinfer_per_tensor_scale_fp8( layer=layer, hidden_states=x, router_logits=router_logits, routing_bias=e_score_correction_bias, global_num_experts=global_num_experts, top_k=top_k, num_expert_group=num_expert_group, topk_group=topk_group, apply_router_weight_on_input=apply_router_weight_on_input, ) # Expert selection topk_weights, topk_ids, _ = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, use_grouped_topk=use_grouped_topk, top_k=top_k, renormalize=renormalize, topk_group=topk_group, num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, indices_type=self.topk_indices_dtype, ) # # Note: the order here is important. self.fused_experts can override # cutlass or fused_experts. # if self.fused_experts is not None: return self.fused_experts( x, layer.w13_weight, layer.w2_weight, topk_weights, topk_ids, inplace=False, activation=activation, global_num_experts=global_num_experts, expert_map=expert_map, apply_router_weight_on_input=apply_router_weight_on_input, ) elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: assert not renormalize assert activation == "silu", ( f"Expected 'silu' activation but got {activation}" ) return flashinfer_cutlass_moe_fp8( x, layer, topk_weights, topk_ids, inplace=False, activation=activation, global_num_experts=global_num_experts, expert_map=expert_map, apply_router_weight_on_input=apply_router_weight_on_input, ) else: from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts assert self.moe_quant_config is not None return fused_experts( x, layer.w13_weight, layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, inplace=True, activation=activation, quant_config=self.moe_quant_config, global_num_experts=global_num_experts, expert_map=expert_map, apply_router_weight_on_input=apply_router_weight_on_input, ) class ModelOptNvFp4Config(QuantizationConfig): """Config class for ModelOpt FP4.""" def __init__( self, is_checkpoint_nvfp4_serialized: bool, kv_cache_quant_algo: str | None, exclude_modules: list[str], group_size: int = 16, ) -> None: super().__init__() self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized if is_checkpoint_nvfp4_serialized: logger.warning( "Detected ModelOpt NVFP4 checkpoint. Please note that" " the format is experimental and could change in future." ) self.group_size = group_size self.kv_cache_quant_algo = kv_cache_quant_algo self.exclude_modules = exclude_modules @classmethod def get_name(cls) -> QuantizationMethods: return "modelopt_fp4" @classmethod def get_supported_act_dtypes(cls) -> list[torch.dtype]: return [torch.bfloat16, torch.half, torch.float8_e4m3fn] @classmethod def get_min_capability(cls) -> int: return 80 @classmethod def get_config_filenames(cls) -> list[str]: return ["hf_quant_config.json"] def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): if self.exclude_modules is not None: self.exclude_modules = hf_to_vllm_mapper.apply_list(self.exclude_modules) @classmethod def override_quantization_method( cls, hf_quant_cfg, user_quant ) -> QuantizationMethods | None: """Detect if this ModelOpt FP4 config should be used based on quantization config.""" if hf_quant_cfg is None: return None # Use the community standard 'quant_method' quant_method = hf_quant_cfg.get("quant_method", "").lower() # Only proceed if the method is explicitly "modelopt" if quant_method != "modelopt": return None # Look for ModelOpt-specific config structure if "quantization" in hf_quant_cfg: quant_config = hf_quant_cfg["quantization"] if isinstance(quant_config, dict): quant_algo = quant_config.get("quant_algo", "") if "NVFP4" in quant_algo: return "modelopt_fp4" else: # Check for compressed-tensors style config with specific # quant_algo field quant_algo = hf_quant_cfg.get("quant_algo", "") if isinstance(quant_algo, str) and "FP4" in quant_algo.upper(): return "modelopt_fp4" return None @classmethod def from_config(cls, config: dict[str, Any]) -> "ModelOptNvFp4Config": # Handle both traditional ModelOpt format and compressed-tensors # style format if "quantization" in config: # Traditional ModelOpt format: # {"quantization": {"quant_algo": "..."}} quant_config = cls.get_from_keys(config, ["quantization"]) if not isinstance(quant_config, dict): raise ValueError("Expected 'quantization' to be a dictionary in config") quant_method = quant_config.get("quant_algo", "") if not quant_method: raise ValueError("Missing 'quant_algo' in quantization config") # Handle kv_cache_quant_algo with proper type validation kv_cache_quant_algo_raw = quant_config.get("kv_cache_quant_algo") if kv_cache_quant_algo_raw is None: # No KV cache quantization by default kv_cache_quant_algo = None elif isinstance(kv_cache_quant_algo_raw, str): kv_cache_quant_algo = kv_cache_quant_algo_raw else: raise ValueError( f"kv_cache_quant_algo must be a string, got " f"{type(kv_cache_quant_algo_raw)}" ) # Handle group_size with proper type validation group_size_raw = quant_config.get("group_size") if group_size_raw is None: group_size = 16 # Default value elif isinstance(group_size_raw, int): group_size = group_size_raw else: try: group_size = int(group_size_raw) except (ValueError, TypeError): raise ValueError( f"group_size must be an integer, got {type(group_size_raw)}" ) from None # "exclude_modules" is the key in the legacy hf_quant_config.json exclude_modules = quant_config.get("exclude_modules", []) if not isinstance(exclude_modules, list): raise ValueError( f"exclude_modules must be a list, got {type(exclude_modules)}" ) else: # Compressed-tensors style format: # {"quant_algo": "...", "quant_method": "modelopt"} quant_method = config.get("quant_algo", "") # Handle kv_cache_quant_algo with proper type validation kv_cache_quant_algo_raw = config.get("kv_cache_quant_algo") if kv_cache_quant_algo_raw is None: # No KV cache quantization by default kv_cache_quant_algo = None elif isinstance(kv_cache_quant_algo_raw, str): kv_cache_quant_algo = kv_cache_quant_algo_raw else: raise ValueError( f"kv_cache_quant_algo must be a string, got " f"{type(kv_cache_quant_algo_raw)}" ) # Handle group_size with proper type validation group_size_raw = config.get("group_size") if group_size_raw is None: group_size = 16 # Default value elif isinstance(group_size_raw, int): group_size = group_size_raw else: try: group_size = int(group_size_raw) except (ValueError, TypeError): raise ValueError( f"group_size must be an integer, got {type(group_size_raw)}" ) from None # "ignore" is the key in config.json exclude_modules = config.get("ignore", []) if not isinstance(exclude_modules, list): raise ValueError( f"exclude_modules must be a list, got {type(exclude_modules)}" ) if quant_method not in QUANT_ALGOS: raise ValueError( f"ModelOpt currently only supports: {QUANT_ALGOS} " "quantizations in vLLM. Please check the " "`hf_quant_config.json` file for your model's " "quant configuration." ) is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method # For FP4, these fields are required if is_checkpoint_nvfp4_serialized and "quantization" in config: # Check if required fields are present in the quantization config quant_config = config["quantization"] required_fields = ["group_size", "kv_cache_quant_algo", "exclude_modules"] missing_fields = [ field for field in required_fields if field not in quant_config ] if missing_fields: raise ValueError( f"NVFP4 quantization requires the following fields in " f"hf_quant_config.json: {missing_fields}" ) return cls( is_checkpoint_nvfp4_serialized, kv_cache_quant_algo, exclude_modules, group_size, ) def is_layer_excluded(self, prefix: str) -> bool: """ Check if a layer should be excluded from quantization. Handles both exact matching (for fused layers) and pattern matching. """ # First check exact matching with fused layer support if is_layer_skipped(prefix, self.exclude_modules, self.packed_modules_mapping): return True # Check regex pattern matching for patterns not caught by exact match import regex as re for pattern in self.exclude_modules: # Skip patterns that would be caught by exact matching if "*" in pattern or "." in pattern: regex_str = pattern.replace(".", r"\.").replace("*", r".*") if re.fullmatch(regex_str, prefix): return True return False def get_quant_method( self, layer: torch.nn.Module, prefix: str ) -> Optional["QuantizeMethodBase"]: from vllm.attention.layer import Attention # Avoid circular import skip_layer = self.is_layer_excluded(prefix) if isinstance(layer, LinearBase): if skip_layer: return UnquantizedLinearMethod() # Check if this is a vision model layer that should not be quantized if "vision_tower" in prefix or "vision_model" in prefix: return UnquantizedLinearMethod() return ModelOptNvFp4LinearMethod(self) elif isinstance(layer, Attention): return ModelOptFp8KVCacheMethod(self) elif isinstance(layer, FusedMoE): if skip_layer: return None return ModelOptNvFp4FusedMoE(self, layer.moe_config, layer) return None class ModelOptFp8KVCacheMethod(BaseKVCacheMethod): """ Supports loading kv-cache scaling factors from FP8 checkpoints. """ def __init__(self, quant_config: ModelOptFp8Config | ModelOptNvFp4Config): super().__init__(quant_config) class ModelOptNvFp4LinearMethod(LinearMethodBase): """Linear method for Model Optimizer NVFP4. Supports loading NVFP4 checkpoints with the following structure: input_scale: torch.float32, scalar , weight: NVFP4(represented as byte) Shape: [1, X, y/2] weight_scale: FP8-E4M3, Shape: [X, Y], aka per block scale, weight_scale_2: torch.float32, scalar, Args: quant_config: The ModelOpt quantization config. """ def __init__(self, quant_config: ModelOptNvFp4Config) -> None: self.quant_config = quant_config self.backend = "none" if envs.VLLM_NVFP4_GEMM_BACKEND is None: if has_flashinfer(): self.backend = "flashinfer-cutlass" elif cutlass_fp4_supported(): self.backend = "cutlass" elif is_fp4_marlin_supported(): self.backend = "marlin" elif envs.VLLM_NVFP4_GEMM_BACKEND.startswith("flashinfer-"): self.backend = envs.VLLM_NVFP4_GEMM_BACKEND assert has_flashinfer(), f"FlashInfer is required for {self.backend}" if self.backend == "none": raise ValueError( "No valid NVFP4 GEMM backend found. " "Please check your platform capability." ) logger.info_once(f"Using {self.backend} for NVFP4 GEMM") def create_weights( self, layer: torch.nn.Module, input_size_per_partition: int, output_partition_sizes: list[int], input_size: int, output_size: int, params_dtype: torch.dtype, **extra_weight_attrs, ): del input_size, output_size if not self.quant_config.is_checkpoint_nvfp4_serialized: raise ValueError( "NVFP4 quantization was selected, " " dynamic quantization is not supported." ) output_size_per_partition = sum(output_partition_sizes) weight_loader = extra_weight_attrs.get("weight_loader") layer.logical_widths = output_partition_sizes layer.input_size_per_partition = input_size_per_partition layer.output_size_per_partition = output_size_per_partition if input_size_per_partition % 16 != 0: raise ValueError( "Unsupported model when in features size is not multiple of 16" ) # The nvfp4 weight is still represented as weight_dtype = ( torch.float8_e4m3fn if self.quant_config.is_checkpoint_nvfp4_serialized else params_dtype ) # Weight weight = ModelWeightParameter( data=torch.empty( # 2 fp4 items are packed in the input dimension layer.output_size_per_partition, layer.input_size_per_partition // 2, dtype=torch.uint8, ), input_dim=1, output_dim=0, weight_loader=weight_loader, ) layer.register_parameter("weight", weight) # Input Weight Scale input_scale = PerTensorScaleParameter( data=torch.empty(len(output_partition_sizes), dtype=torch.float32), weight_loader=weight_loader, ) layer.register_parameter("input_scale", input_scale) # Global Weight Scale weight_scale_2 = PerTensorScaleParameter( data=torch.empty(len(output_partition_sizes), dtype=torch.float32), weight_loader=weight_loader, ) layer.register_parameter("weight_scale_2", weight_scale_2) # Per Block Weight Scale weight_scale = ModelWeightParameter( data=torch.empty( output_size_per_partition, input_size_per_partition // self.quant_config.group_size, dtype=weight_dtype, ), input_dim=1, output_dim=0, weight_loader=weight_loader, ) layer.register_parameter("weight_scale", weight_scale) def process_weights_after_loading(self, layer: Module) -> None: # global scales: input_scale_2 = layer.input_scale.max().to(torch.float32) layer.input_scale = Parameter(input_scale_2, requires_grad=False) weight_scale_2 = layer.weight_scale_2.max().to(torch.float32) layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False) layer.alpha = Parameter( layer.input_scale * layer.weight_scale_2, requires_grad=False ) # Calculate `1 / input_scale` so that we don't need to do so at runtime layer.input_scale_inv = Parameter( (1 / layer.input_scale).to(torch.float32), requires_grad=False ) # Swizzle the weight blockscale. # contracting dimension is input dimension # block_size = 16; assert layer.weight_scale.dtype == torch.float8_e4m3fn, ( "Weight Block scale must be represented as FP8-E4M3" ) if self.backend == "marlin": prepare_fp4_layer_for_marlin(layer) del layer.alpha del layer.input_scale elif self.backend == "flashinfer-trtllm": # FlashInfer TRTLLM FP4 GEMM requires a different weight layout. # FlashInfer provides nvfp4_quantize to quantize + shuffle the # layout but we use our own quantization so we have to call # shuffles ourselves. from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a weight = layer.weight.data weight_scale = layer.weight_scale.data epilogue_tile_m = 128 weight = shuffle_matrix_a(weight.view(torch.uint8), epilogue_tile_m) weight_scale = ( shuffle_matrix_sf_a(weight_scale.view(torch.uint8), epilogue_tile_m) .reshape(weight_scale.shape) .view(torch.float8_e4m3fn) ) layer.weight_scale = Parameter(weight_scale, requires_grad=False) layer.weight = Parameter(weight, requires_grad=False) else: swizzled_weight_scale = swizzle_blockscale(layer.weight_scale) layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False) layer.weight = Parameter(layer.weight.data, requires_grad=False) def apply( self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: if self.backend == "marlin": return apply_fp4_marlin_linear( input=x, weight=layer.weight, weight_scale=layer.weight_scale, weight_scale_2=layer.weight_scale_2, workspace=layer.workspace, size_n=layer.output_size_per_partition, size_k=layer.input_size_per_partition, bias=bias, ) output_dtype = x.dtype output_shape = [x.shape[0], layer.weight.shape[0]] # quantize BF16 or FP16 to (FP4 and interleaved block scale) x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_scale_inv) # validate dtypes of quantized input, input block scale, # weight and weight_blockscale assert x_fp4.dtype == torch.uint8 assert layer.weight.dtype == torch.uint8 assert x_blockscale.dtype == torch.float8_e4m3fn assert layer.weight_scale.dtype == torch.float8_e4m3fn assert layer.alpha.dtype == torch.float32 mm_args = ( x_fp4, layer.weight, x_blockscale, layer.weight_scale, layer.alpha, output_dtype, ) if self.backend.startswith("flashinfer-"): backend_name = self.backend[len("flashinfer-") :] out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name) else: assert self.backend == "cutlass" out = cutlass_scaled_fp4_mm(*mm_args) if bias is not None: out = out + bias return out.view(*output_shape) def _get_tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int: # Guess tokens per expert assuming perfect expert distribution first. num_tokens_per_expert = (num_tokens * top_k) // num_experts # And pad the number to the next power of 2. tile_tokens_dim = next_power_of_2(num_tokens_per_expert) # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel. tile_tokens_dim = min(max(tile_tokens_dim, 8), 64) return tile_tokens_dim class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): """ MoE Method for FP4 Quantization. Args: quant_config: NVFP4 Quant Config """ def __init__( self, quant_config: ModelOptNvFp4Config, moe: FusedMoEConfig, layer: torch.nn.Module, ) -> None: from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501 detect_nvfp4_moe_support, ) super().__init__(moe) self.quant_config = quant_config self.layer = layer _nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__) self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported self.allow_flashinfer = _nvfp4.allow_flashinfer self.use_marlin = _nvfp4.use_marlin self.flashinfer_moe_backend = None self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {} if self.allow_flashinfer: self.flashinfer_moe_backend = get_flashinfer_moe_backend() logger.info_once( f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels" " for ModelOptNvFp4FusedMoE." ) def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None: if self.use_marlin or ( self.allow_flashinfer and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM ): return None elif ( self.allow_flashinfer and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS ): # For now, fp4 moe only works with the flashinfer dispatcher. prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize( self.moe ) logger.debug_once("%s", prepare_finalize.__class__.__name__) return prepare_finalize else: return super().maybe_make_prepare_finalize() def select_gemm_impl( self, prepare_finalize: mk.FusedMoEPrepareAndFinalize, layer: torch.nn.Module, ) -> mk.FusedMoEPermuteExpertsUnpermute: assert self.moe_quant_config is not None experts = select_nvfp4_gemm_impl( self.moe, self.moe_quant_config, allow_flashinfer=self.allow_flashinfer, ) logger.debug_once("Using %s", experts.__class__.__name__) return experts def uses_weight_scale_2_pattern(self) -> bool: """ FP4 variants use 'weight_scale_2' pattern for per-tensor weight scales. """ return True def create_weights( self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs, ): if not self.quant_config.is_checkpoint_nvfp4_serialized: raise ValueError( "NVFP4 quantization was selected, " " dynamic quantization is not supported." ) layer.num_experts = num_experts layer.params_dtype = params_dtype layer.quant_config = self.quant_config weight_dtype = torch.uint8 weight_scale_dtype = torch.float8_e4m3fn weight_loader = extra_weight_attrs.get("weight_loader") # GEMM 1 w13_weight = ModelWeightParameter( data=torch.empty( num_experts, 2 * intermediate_size_per_partition, # 2 fp4 items are packed in the input dimension hidden_size // 2, dtype=weight_dtype, ), input_dim=1, output_dim=2, weight_loader=weight_loader, ) layer.register_parameter("w13_weight", w13_weight) # GEMM 2 w2_weight = ModelWeightParameter( data=torch.empty( num_experts, hidden_size, # 2 fp4 items are packed in the input dimension intermediate_size_per_partition // 2, dtype=weight_dtype, ), input_dim=1, output_dim=2, weight_loader=weight_loader, ) layer.register_parameter("w2_weight", w2_weight) w13_weight_scale = ModelWeightParameter( data=torch.empty( num_experts, 2 * intermediate_size_per_partition, # 2 fp4 items are packed in the input dimension hidden_size // self.quant_config.group_size, dtype=weight_scale_dtype, ), input_dim=1, output_dim=2, weight_loader=weight_loader, ) layer.register_parameter("w13_weight_scale", w13_weight_scale) w2_weight_scale = ModelWeightParameter( data=torch.empty( num_experts, hidden_size, # 2 fp4 items are packed in the input dimension intermediate_size_per_partition // self.quant_config.group_size, dtype=weight_scale_dtype, ), input_dim=1, output_dim=2, weight_loader=weight_loader, ) layer.register_parameter("w2_weight_scale", w2_weight_scale) extra_weight_attrs.update( {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value} ) w13_weight_scale_2 = PerTensorScaleParameter( data=torch.empty(num_experts, 2, dtype=torch.float32), weight_loader=weight_loader, ) layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2) w2_weight_scale_2 = PerTensorScaleParameter( data=torch.empty(num_experts, dtype=torch.float32), weight_loader=weight_loader, ) layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2) extra_weight_attrs.update( {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} ) w13_input_scale = PerTensorScaleParameter( data=torch.empty(num_experts, 2, dtype=torch.float32), weight_loader=weight_loader, ) layer.register_parameter("w13_input_scale", w13_input_scale) w2_input_scale = PerTensorScaleParameter( data=torch.empty(num_experts, dtype=torch.float32), weight_loader=weight_loader, ) layer.register_parameter("w2_input_scale", w2_input_scale) def prepare_static_weights_for_trtllm_fp4_moe( self, # args_dequant, # args, gemm1_weights, gemm2_weights, gemm1_scales_linear_fp4_bytes, gemm2_scales_linear_fp4_bytes, hidden_size, intermediate_size, num_experts, ): from flashinfer import nvfp4_block_scale_interleave from flashinfer.fused_moe.core import ( _maybe_get_cached_w2_permute_indices, _maybe_get_cached_w3_w1_permute_indices, ) """Prepare quantized weights for kernel (done offline with weights).""" epilogue_tile_m = 128 # FIXME: this depends on the kernel internals # Convert quantized weights to proper formats gemm1_weights_fp4 = gemm1_weights.view(torch.float8_e4m3fn).reshape( num_experts, 2 * intermediate_size, hidden_size // 2 ) # packed fp4 gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes.view( torch.float8_e4m3fn ).reshape( num_experts, 2 * intermediate_size, hidden_size // 16 ) # fp8 scaling factors gemm2_weights_fp4 = gemm2_weights.view(torch.float8_e4m3fn).reshape( num_experts, hidden_size, intermediate_size // 2 ) # packed fp4 gemm2_scales_linear_fp4 = gemm2_scales_linear_fp4_bytes.view( torch.float8_e4m3fn ).reshape( num_experts, hidden_size, intermediate_size // 16 ) # fp8 scaling factors gemm1_weights_fp4_shuffled = [] gemm1_scales_fp4_shuffled = [] gemm2_weights_fp4_shuffled = [] gemm2_scales_fp4_shuffled = [] for i in range(num_experts): # Calculate the permute indices for the following: # 1. Reorder rows of W1 and scales for fused gated activation # 2. Shuffle weights and scaling factors for transposed mma output # for both w3_w1 and w2 weights and scale factors permute_indices = _maybe_get_cached_w3_w1_permute_indices( self._cache_permute_indices, gemm1_weights_fp4[i].view(torch.uint8), epilogue_tile_m, ) gemm1_weights_fp4_shuffled.append( gemm1_weights_fp4[i] .view(torch.uint8)[permute_indices.to(gemm1_weights_fp4.device)] .contiguous() ) permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices( self._cache_permute_indices, gemm1_scales_linear_fp4[i].view(torch.uint8), epilogue_tile_m, num_elts_per_sf=16, ) gemm1_scales_fp4_shuffled.append( nvfp4_block_scale_interleave( gemm1_scales_linear_fp4[i] .view(torch.uint8)[ permute_sf_indices.to(gemm1_scales_linear_fp4.device) ] .contiguous() ) ) permute_indices = _maybe_get_cached_w2_permute_indices( self._cache_permute_indices, gemm2_weights_fp4[i].view(torch.uint8), epilogue_tile_m, ) gemm2_weights_fp4_shuffled.append( gemm2_weights_fp4[i] .view(torch.uint8)[permute_indices.to(gemm2_weights_fp4.device)] .contiguous() ) permute_sf_indices = _maybe_get_cached_w2_permute_indices( self._cache_permute_indices, gemm2_scales_linear_fp4[i].view(torch.uint8), epilogue_tile_m, num_elts_per_sf=16, ) gemm2_scales_fp4_shuffled.append( nvfp4_block_scale_interleave( gemm2_scales_linear_fp4[i] .view(torch.uint8)[ permute_sf_indices.to(gemm2_scales_linear_fp4.device) ] .contiguous() ) ) # Stack weights for all experts gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled) gemm1_scales_fp4_shuffled = ( torch.stack(gemm1_scales_fp4_shuffled) .view(torch.float8_e4m3fn) .reshape(num_experts, 2 * intermediate_size, hidden_size // 16) ) gemm2_weights_fp4_shuffled = torch.stack(gemm2_weights_fp4_shuffled) gemm2_scales_fp4_shuffled = ( torch.stack(gemm2_scales_fp4_shuffled) .view(torch.float8_e4m3fn) .reshape(num_experts, hidden_size, intermediate_size // 16) ) return ( gemm1_weights_fp4_shuffled, gemm1_scales_fp4_shuffled, gemm2_weights_fp4_shuffled, gemm2_scales_fp4_shuffled, ) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # GEMM 1 processing gemm1_weight = layer.w13_weight.data gemm1_weight_scale = layer.w13_weight_scale.data if self.allow_flashinfer: gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1( gemm1_weight, gemm1_weight_scale, dim=-2 ) layer.w13_weight = Parameter(gemm1_weight, requires_grad=False) layer.w13_weight_scale = Parameter(gemm1_weight_scale, requires_grad=False) # Common processing for w13_weight_scale_2 if not torch.allclose( layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1] ): logger.warning_once( "w1_weight_scale_2 must match w3_weight_scale_2. " "Accuracy may be affected." ) w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0] layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False) # Common processing for input scales and alphas w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32) layer.g1_alphas = Parameter( (w13_input_scale * w13_weight_scale_2).to(torch.float32), requires_grad=False, ) # This is for quantization, so we need to invert it. layer.w13_input_scale_quant = Parameter( (1 / w13_input_scale).to(torch.float32), requires_grad=False ) # GEMM 2 processing layer.g2_alphas = Parameter( (layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32), requires_grad=False, ) # This is for quantization, so we need to invert it. layer.w2_input_scale_quant = Parameter( (1 / layer.w2_input_scale).to(torch.float32), requires_grad=False ) # TensorRT-LLM specific processing if ( self.allow_flashinfer and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM ): # Prepare static weights for TRT-LLM kernel # alternate: prepare_static_weight_layouts_for_trtllm_moe ( gemm1_weights_fp4_shuffled, gemm1_scales_fp4_shuffled, gemm2_weights_fp4_shuffled, gemm2_scales_fp4_shuffled, ) = self.prepare_static_weights_for_trtllm_fp4_moe( layer.w13_weight, layer.w2_weight, layer.w13_weight_scale, layer.w2_weight_scale, layer.w2_weight.size(-2), # hidden_size layer.w13_weight.size(-2) // 2, # intermediate_size layer.w13_weight.size(0), # num_experts ) logger.debug_once("Finished shuffling weights for TRT-LLM MOE") layer.gemm1_weights_fp4_shuffled = Parameter( gemm1_weights_fp4_shuffled, requires_grad=False ) layer.gemm2_weights_fp4_shuffled = Parameter( gemm2_weights_fp4_shuffled, requires_grad=False ) layer.gemm1_scales_fp4_shuffled = Parameter( gemm1_scales_fp4_shuffled, requires_grad=False ) layer.gemm2_scales_fp4_shuffled = Parameter( gemm2_scales_fp4_shuffled, requires_grad=False ) # Additional parameter needed for TRT-LLM layer.g1_scale_c = Parameter( (layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32), requires_grad=False, ) # Clean up weights that won't be used by TRT-LLM del layer.w2_weight del layer.w2_weight_scale del layer.w13_weight del layer.w13_weight_scale elif self.use_marlin: # Marlin processing prepare_moe_fp4_layer_for_marlin(layer) del layer.g1_alphas del layer.g2_alphas del layer.w13_input_scale_quant del layer.w2_input_scale_quant else: # Non-TRT-LLM processing (Cutlass or non-flashinfer) w13_blockscale_swizzled = swizzle_blockscale(layer.w13_weight_scale) layer.w13_weight_scale = Parameter( w13_blockscale_swizzled, requires_grad=False ) w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale) layer.w2_weight_scale = Parameter( w2_blockscale_swizzled, requires_grad=False ) layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False) def get_fused_moe_quant_config( self, layer: torch.nn.Module ) -> FusedMoEQuantConfig | None: if ( self.use_marlin or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM ): return None return nvfp4_moe_quant_config( w1_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale, g1_alphas=layer.g1_alphas, g2_alphas=layer.g2_alphas, a1_gscale=layer.w13_input_scale_quant, a2_gscale=layer.w2_input_scale_quant, ) def apply( self, layer: torch.nn.Module, x: torch.Tensor, router_logits: torch.Tensor, top_k: int, renormalize: bool, use_grouped_topk: bool = False, topk_group: int | None = None, num_expert_group: int | None = None, global_num_experts: int = -1, expert_map: torch.Tensor | None = None, custom_routing_function: Callable | None = None, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, e_score_correction_bias: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, activation: str = "silu", enable_eplb: bool = False, expert_load_view: torch.Tensor | None = None, logical_to_physical_map: torch.Tensor | None = None, logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: if enable_eplb: raise NotImplementedError( "EPLB not supported for `ModelOptNvFp4FusedMoE` yet." ) assert activation == "silu", "Only SiLU activation is supported." if ( self.allow_flashinfer and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM ): import flashinfer from vllm.model_executor.models.llama4 import Llama4MoE assert self.fused_experts is None a1_gscale = layer.w13_input_scale_quant (hidden_states_fp4, hidden_states_scale_linear_fp4) = ( flashinfer.fp4_quantize( x, a1_gscale, is_sf_swizzled_layout=False, ) ) use_llama4_routing = ( custom_routing_function is Llama4MoE.custom_routing_function ) routing_method_type = flashinfer.RoutingMethodType.DeepSeekV3 if use_llama4_routing: routing_method_type = flashinfer.RoutingMethodType.Llama4 routing_bias = e_score_correction_bias if routing_bias is not None: routing_bias = routing_bias.to(torch.bfloat16) out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe( routing_logits=router_logits if use_llama4_routing else router_logits.to(torch.float32), routing_bias=routing_bias, hidden_states=hidden_states_fp4, hidden_states_scale=hidden_states_scale_linear_fp4.view( torch.float8_e4m3fn ).flatten(), gemm1_weights=layer.gemm1_weights_fp4_shuffled.data, gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view( torch.float8_e4m3fn ), gemm1_bias=None, gemm1_alpha=None, gemm1_beta=None, gemm1_clamp_limit=None, gemm2_weights=layer.gemm2_weights_fp4_shuffled.data, gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view( torch.float8_e4m3fn ), gemm2_bias=None, output1_scale_scalar=layer.g1_scale_c.data, output1_scale_gate_scalar=layer.g1_alphas.data, output2_scale_scalar=layer.g2_alphas.data, num_experts=global_num_experts, top_k=top_k, n_group=num_expert_group if num_expert_group is not None else 0, topk_group=topk_group if topk_group is not None else 0, intermediate_size=layer.intermediate_size_per_partition, local_expert_offset=layer.ep_rank * layer.local_num_experts, local_num_experts=layer.local_num_experts, routed_scaling_factor=None, tile_tokens_dim=_get_tile_tokens_dim( x.shape[0], top_k, layer.local_num_experts ), routing_method_type=routing_method_type, do_finalize=True, )[0] return out topk_weights, topk_ids, _ = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, use_grouped_topk=use_grouped_topk, top_k=top_k, renormalize=renormalize, topk_group=topk_group, num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, indices_type=self.topk_indices_dtype, ) # # Note: the order here is important. self.fused_experts can override # flashinfer cutlass, cutlass fp4 or fused_experts but not marlin or # trtllm. # if self.use_marlin: assert self.fused_experts is None return fused_marlin_moe( x, layer.w13_weight, layer.w2_weight, None, None, layer.w13_weight_scale, layer.w2_weight_scale, router_logits, topk_weights, topk_ids, global_scale1=layer.w13_weight_scale_2, global_scale2=layer.w2_weight_scale_2, quant_type_id=scalar_types.float4_e2m1f.id, apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, expert_map=expert_map, workspace=layer.workspace, ) elif self.fused_experts is not None: assert ( self.allow_flashinfer and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS ) assert is_valid_flashinfer_cutlass_fused_moe( x, layer.w13_weight, layer.w2_weight ), "Flashinfer CUTLASS Fused MoE not applicable!" return self.fused_experts( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, inplace=False, # TODO(shuw): fix later, now output is high prec activation=activation, global_num_experts=global_num_experts, expert_map=expert_map, apply_router_weight_on_input=apply_router_weight_on_input, ) elif ( self.allow_flashinfer and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS ): from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501 flashinfer_cutlass_moe_fp4, ) assert self.moe_quant_config is not None return flashinfer_cutlass_moe_fp4( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, quant_config=self.moe_quant_config, inplace=False, activation=activation, global_num_experts=global_num_experts, expert_map=expert_map, apply_router_weight_on_input=apply_router_weight_on_input, ) else: # If no modular kernel is provided, use cutlass_moe_fp4 for TP case # only (no EP). from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4 assert self.moe_quant_config is not None return cutlass_moe_fp4( a=x, w1_fp4=layer.w13_weight, w2_fp4=layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, quant_config=self.moe_quant_config, expert_map=expert_map, apply_router_weight_on_input=apply_router_weight_on_input, # TODO: derive from arguments m=x.shape[0], n=layer.w2_weight.shape[2] * 2, k=x.shape[1], e=layer.w13_weight.shape[0], )