diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index dedab33c1bdb7..6b5ed7762eb31 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from collections.abc import Callable
+from fnmatch import fnmatch
 from typing import TYPE_CHECKING, Any, Optional
 
 import torch
@@ -13,7 +14,6 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import (
-    FusedMoEConfig,
     FusedMoEQuantConfig,
     RoutingMethodType,
     fp8_w8a8_moe_quant_config,
@@ -86,45 +86,218 @@ QUANT_ALGOS = ["FP8", "NVFP4"]
 KV_CACHE_QUANT_ALGOS = ["FP8"]
 
 
-class ModelOptFp8Config(QuantizationConfig):
+class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
+    """
+    Supports loading kv-cache scaling factors from FP8 checkpoints.
+    """
+
+    def __init__(self, quant_config: "ModelOptQuantConfigBase"):
+        super().__init__(quant_config)
+
+
+class ModelOptQuantConfigBase(QuantizationConfig):
+    LinearMethodCls: type = LinearMethodBase
+    FusedMoEMethodCls: type = FusedMoEMethodBase
+    KVCacheMethodCls: type = BaseKVCacheMethod
+
+    def __init__(
+        self,
+        exclude_modules: list[str],
+    ):
+        super().__init__()
+        self.exclude_modules: list[str] = exclude_modules
+
+    def is_layer_excluded(self, prefix: str) -> bool:
+        """
+        Check if a layer should be excluded from quantization.
+
+        Handles both exact matching (for fused layers) and ModelOpt wildcard matching.
+
+        The ModelOpt exclude_modules list is a list of wildcards.
+        """
+        if len(self.exclude_modules) == 0:
+            return False
+
+        # First check exact matching with fused layer support
+        if is_layer_skipped(prefix, self.exclude_modules, self.packed_modules_mapping):
+            return True
+
+        # TODO: This special hard-coded logic is not needed for quantized checkpoints
+        # generated by ModelOpt >= 0.39.0, where it is handled naturally by the
+        # exclude_modules config, but it must be kept to load quantized checkpoints
+        # generated by older versions. Check substring matching for patterns not
+        # caught by the exact match above.
+        for exclude_module in self.exclude_modules:
+            # Skip exact matches already handled above
+            if exclude_module != prefix and (
+                exclude_module in prefix
+                or (
+                    prefix.startswith("language_model.")
+                    and exclude_module in prefix.removeprefix("language_model.")
+                )
+            ):
+                return True
+
+        # ModelOpt exclude modules are not simple strings; they are wildcard patterns
+        for wildcard_pattern in self.exclude_modules:
+            if fnmatch(prefix, wildcard_pattern):
+                return True
+
+        return False
+
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional["QuantizeMethodBase"]:
+        from vllm.attention.layer import Attention  # Avoid circular import
+
+        # Handle the kv-cache first so we can focus only on weight quantization below
+        if isinstance(layer, Attention):
+            return self.KVCacheMethodCls(self)
+
+        # Handle exclusion
+        if self.is_layer_excluded(prefix):
+            if isinstance(layer, LinearBase):
+                return UnquantizedLinearMethod()
+            return None
+
+        # TODO: This special hard-coded logic is not needed for quantized checkpoints
+        # generated by ModelOpt >= 0.39.0, where vision layers are excluded naturally
+        # by the exclude_modules config, but it must be kept to load quantized
+        # checkpoints generated by older versions: vision encoder layers are never
+        # quantized.
+        if "vision_tower" in prefix or "vision_model" in prefix:
+            return UnquantizedLinearMethod()
+
+        # The layer is quantized; handle it here
+        if isinstance(layer, LinearBase):
+            return self.LinearMethodCls(self)
+        elif isinstance(layer, FusedMoE):
+            return self.FusedMoEMethodCls(quant_config=self, layer=layer)
+
+        return None
+
+    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
+        if len(self.exclude_modules) > 0:
+            self.exclude_modules = hf_to_vllm_mapper.apply_list(self.exclude_modules)
+
+    @staticmethod
+    def get_config_filenames() -> list[str]:
+        return ["hf_quant_config.json"]
+
+    @classmethod
+    def _from_config(
+        cls,
+        *,
+        quant_method: str,
+        kv_cache_quant_method: str | None,
+        exclude_modules: list[str],
+        original_config: dict[str, Any],
+        group_size: int | None,
+    ) -> "ModelOptQuantConfigBase":
+        raise NotImplementedError("Please implement this function in subclasses")
+
+    @classmethod
+    def from_config(cls, config: dict[str, Any]) -> "ModelOptQuantConfigBase":
+        # Handle both the ModelOpt format and the compressed-tensors style format
+        if "quantization" in config:
+            # Traditional ModelOpt format:
+            # {"quantization": {"quant_algo": "..."}}
+            quant_config = cls.get_from_keys(config, ["quantization"])
+            if not isinstance(quant_config, dict):
+                raise ValueError("Expected 'quantization' to be a dictionary in config")
+
+            quant_method = quant_config.get("quant_algo")
+
+            # Handle kv_cache_quant_algo with proper type validation
+            kv_cache_quant_method = quant_config.get("kv_cache_quant_algo")
+
+            # Handle group_size with proper type validation
+            group_size_raw = quant_config.get("group_size")
+
+            # "exclude_modules" is the key in the legacy hf_quant_config.json
+            exclude_modules = quant_config.get("exclude_modules", [])
+        else:
+            # Compressed-tensors style format:
+            # {"quant_algo": "...", "quant_method": "modelopt"}
+            quant_method = config.get("quant_algo")
+            kv_cache_quant_method = config.get("kv_cache_quant_algo")
+            # "ignore" is the key in config.json
+            exclude_modules = config.get("ignore", [])
+            group_size_raw = config.get("group_size")
+
+        if not quant_method:
+            raise ValueError("Missing 'quant_algo' in quantization config")
+
+        if kv_cache_quant_method is None:
+            # No KV cache quantization; nothing further to validate
+            pass
+        elif not isinstance(kv_cache_quant_method, str):
+            raise ValueError(
+                f"kv_cache_quant_algo must be a string, got "
+                f"{type(kv_cache_quant_method)}"
+            )
+
+        if not isinstance(exclude_modules, list):
+            raise ValueError(
+                f"exclude_modules must be a list, got {type(exclude_modules)}"
+            )
+
+        if group_size_raw is None:
+            group_size = None
+        elif isinstance(group_size_raw, int):
+            group_size = group_size_raw
+        else:
+            try:
+                group_size = int(group_size_raw)
+            except (ValueError, TypeError):
+                raise ValueError(
+                    f"group_size must be an integer, got {type(group_size_raw)}"
+                ) from None
+
+        if quant_method not in QUANT_ALGOS:
+            raise ValueError(
+                f"ModelOpt currently only supports: {QUANT_ALGOS} "
+                "quantizations in vLLM. Please check the "
+                "`hf_quant_config.json` file for your model's "
+                "quant configuration."
+            )
+        return cls._from_config(
+            quant_method=quant_method,
+            kv_cache_quant_method=kv_cache_quant_method,
+            exclude_modules=exclude_modules,
+            group_size=group_size,
+            original_config=config,
+        )
+
+
+class ModelOptFp8Config(ModelOptQuantConfigBase):
     """Config class for ModelOpt FP8."""
 
     def __init__(
         self,
-        is_checkpoint_fp8_serialized: bool = False,
-        kv_cache_quant_method: str | None = None,
-        exclude_modules: list[str] | None = None,
+        is_checkpoint_fp8_serialized: bool,
+        kv_cache_quant_method: str | None,
+        exclude_modules: list[str],
     ) -> None:
-        super().__init__()
+        super().__init__(exclude_modules)
         self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
         self.kv_cache_quant_method = kv_cache_quant_method
-        self.exclude_modules = exclude_modules or []
         if is_checkpoint_fp8_serialized:
             logger.warning(
                 "Detected ModelOpt fp8 checkpoint. Please note that"
                 " the format is experimental and could change."
             )
 
-    @classmethod
-    def get_name(cls) -> QuantizationMethods:
+    def get_name(self) -> QuantizationMethods:
         return "modelopt"
 
-    @classmethod
-    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
+    def get_supported_act_dtypes(self) -> list[torch.dtype]:
         return [torch.bfloat16, torch.half]
 
     @classmethod
     def get_min_capability(cls) -> int:
         return 89
 
-    @classmethod
-    def get_config_filenames(cls) -> list[str]:
-        return ["hf_quant_config.json"]
-
-    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
-        if self.exclude_modules is not None:
-            self.exclude_modules = hf_to_vllm_mapper.apply_list(self.exclude_modules)
-
     @classmethod
     def override_quantization_method(
         cls, hf_quant_cfg, user_quant
@@ -158,88 +331,19 @@ class ModelOptFp8Config(QuantizationConfig):
         return None
 
     @classmethod
-    def from_config(cls, config: dict[str, Any]) -> "ModelOptFp8Config":
-        # Handle both ModelOpt format and compressed-tensors style format
-        if "quantization" in config:
-            # ModelOpt format: {"quantization": {"quant_algo": "..."}}
-            quant_config = cls.get_from_keys(config, ["quantization"])
-            if not isinstance(quant_config, dict):
-                raise ValueError("Expected 'quantization' to be a dictionary in config")
-            quant_method = quant_config.get("quant_algo", "")
-            if not quant_method:
-                raise ValueError("Missing 'quant_algo' in quantization config")
-            kv_cache_quant_method = quant_config.get("kv_cache_quant_algo")
-            # "exclude_modules" is the key in the legacy hf_quant_config.json
-            exclude_modules = quant_config.get("exclude_modules")
-        else:
-            # Compressed-tensors style format:
-            # {"quant_algo": "...", "quant_method": "modelopt"}
-            quant_method = config.get("quant_algo", "")
-            kv_cache_quant_method = config.get("kv_cache_quant_algo")
-            # "ignore" is the key in config.json
-            exclude_modules = config.get("ignore")
-
-        if quant_method not in QUANT_ALGOS:
-            raise ValueError(
-                f"ModelOpt currently only supports: {QUANT_ALGOS} "
-                "quantizations in vLLM. Please check the "
-                "`hf_quant_config.json` file for your model's "
-                "quant configuration."
-            )
+    def _from_config(
+        cls,
+        *,
+        quant_method: str,
+        kv_cache_quant_method: str | None,
+        exclude_modules: list[str],
+        original_config: dict[str, Any],
+        **kwargs: Any,
+    ) -> "ModelOptFp8Config":
         is_checkpoint_fp8_serialized = "FP8" in quant_method
         return cls(is_checkpoint_fp8_serialized, kv_cache_quant_method, exclude_modules)
 
-    def is_layer_excluded(self, prefix: str) -> bool:
-        """
-        Check if a layer should be excluded from quantization.
-        Handles both exact matching (for fused layers) and substring matching.
-
-        This method handles both regular models and multimodal models that use
-        the language_model prefix. For multimodal models, it checks if the
-        module name (without the language_model prefix) is in the exclude list.
-        """
-        if self.exclude_modules is None:
-            return False
-
-        # First check exact matching with fused layer support
-        if is_layer_skipped(prefix, self.exclude_modules, self.packed_modules_mapping):
-            return True
-
-        # Then check substring matching for patterns not caught by exact match
-        for module in self.exclude_modules:
-            # Skip exact matches already handled above
-            if module != prefix and (
-                module in prefix
-                or (
-                    prefix.startswith("language_model.")
-                    and module in prefix.removeprefix("language_model.")
-                )
-            ):
-                return True
-        return False
-
-    def get_quant_method(
-        self, layer: torch.nn.Module, prefix: str
-    ) -> Optional["QuantizeMethodBase"]:
-        from vllm.attention.layer import (  # Avoid circular import
-            Attention,
-            MLAAttention,
-        )
-
-        if isinstance(layer, LinearBase):
-            if self.is_layer_excluded(prefix):
-                return UnquantizedLinearMethod()
-            # Check if this is a vision model layer that should not be quantized
-            if "vision_tower" in prefix or "vision_model" in prefix:
-                return UnquantizedLinearMethod()
-            return ModelOptFp8LinearMethod(self)
-        elif isinstance(layer, (Attention, MLAAttention)):
-            return ModelOptFp8KVCacheMethod(self)
-        elif isinstance(layer, FusedMoE):
-            return ModelOptFp8MoEMethod(self, layer)
-        return None
-
 
 class ModelOptFp8LinearMethod(LinearMethodBase):
     """Linear method for Model Optimizer static quantization.
@@ -344,7 +448,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
     def __init__(
         self,
         quant_config: ModelOptFp8Config,
-        layer: torch.nn.Module,
+        layer: FusedMoE,
     ) -> None:
         super().__init__(layer.moe_config)
         self.layer = layer
@@ -686,7 +790,12 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         )
 
 
-class ModelOptNvFp4Config(QuantizationConfig):
+ModelOptFp8Config.LinearMethodCls = ModelOptFp8LinearMethod
+ModelOptFp8Config.FusedMoEMethodCls = ModelOptFp8MoEMethod
+ModelOptFp8Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod
+
+
+class ModelOptNvFp4Config(ModelOptQuantConfigBase):
     """Config class for ModelOpt FP4."""
 
     def __init__(
@@ -696,7 +805,7 @@ class ModelOptNvFp4Config(QuantizationConfig):
         exclude_modules: list[str],
         group_size: int = 16,
     ) -> None:
-        super().__init__()
+        super().__init__(exclude_modules)
        self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
         if is_checkpoint_nvfp4_serialized:
             logger.warning(
@@ -706,28 +815,17 @@ class ModelOptNvFp4Config(QuantizationConfig):
         self.group_size = group_size
         self.kv_cache_quant_algo = kv_cache_quant_algo
-        self.exclude_modules = exclude_modules
 
-    @classmethod
-    def get_name(cls) -> QuantizationMethods:
+    def get_name(self) -> QuantizationMethods:
         return "modelopt_fp4"
 
-    @classmethod
-    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
+    def get_supported_act_dtypes(self) -> list[torch.dtype]:
         return [torch.bfloat16, torch.half, torch.float8_e4m3fn]
 
     @classmethod
     def get_min_capability(cls) -> int:
         return 80
 
-    @classmethod
-    def get_config_filenames(cls) -> list[str]:
-        return ["hf_quant_config.json"]
-
-    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
-        if self.exclude_modules is not None:
-            self.exclude_modules = hf_to_vllm_mapper.apply_list(self.exclude_modules)
-
     @classmethod
     def override_quantization_method(
         cls, hf_quant_cfg, user_quant
@@ -761,105 +859,25 @@ class ModelOptNvFp4Config(QuantizationConfig):
         return None
 
     @classmethod
-    def from_config(cls, config: dict[str, Any]) -> "ModelOptNvFp4Config":
-        # Handle both traditional ModelOpt format and compressed-tensors
-        # style format
-        if "quantization" in config:
-            # Traditional ModelOpt format:
-            # {"quantization": {"quant_algo": "..."}}
-            quant_config = cls.get_from_keys(config, ["quantization"])
-            if not isinstance(quant_config, dict):
-                raise ValueError("Expected 'quantization' to be a dictionary in config")
-
-            quant_method = quant_config.get("quant_algo", "")
-            if not quant_method:
-                raise ValueError("Missing 'quant_algo' in quantization config")
-
-            # Handle kv_cache_quant_algo with proper type validation
-            kv_cache_quant_algo_raw = quant_config.get("kv_cache_quant_algo")
-            if kv_cache_quant_algo_raw is None:
-                # No KV cache quantization by default
-                kv_cache_quant_algo = None
-            elif isinstance(kv_cache_quant_algo_raw, str):
-                kv_cache_quant_algo = kv_cache_quant_algo_raw
-            else:
-                raise ValueError(
-                    f"kv_cache_quant_algo must be a string, got "
-                    f"{type(kv_cache_quant_algo_raw)}"
-                )
-
-            # Handle group_size with proper type validation
-            group_size_raw = quant_config.get("group_size")
-            if group_size_raw is None:
-                group_size = 16  # Default value
-            elif isinstance(group_size_raw, int):
-                group_size = group_size_raw
-            else:
-                try:
-                    group_size = int(group_size_raw)
-                except (ValueError, TypeError):
-                    raise ValueError(
-                        f"group_size must be an integer, got {type(group_size_raw)}"
-                    ) from None
-
-            # "exclude_modules" is the key in the legacy hf_quant_config.json
-            exclude_modules = quant_config.get("exclude_modules", [])
-            if not isinstance(exclude_modules, list):
-                raise ValueError(
-                    f"exclude_modules must be a list, got {type(exclude_modules)}"
-                )
-        else:
-            # Compressed-tensors style format:
-            # {"quant_algo": "...", "quant_method": "modelopt"}
-            quant_method = config.get("quant_algo", "")
-
-            # Handle kv_cache_quant_algo with proper type validation
-            kv_cache_quant_algo_raw = config.get("kv_cache_quant_algo")
-            if kv_cache_quant_algo_raw is None:
-                # No KV cache quantization by default
-                kv_cache_quant_algo = None
-            elif isinstance(kv_cache_quant_algo_raw, str):
-                kv_cache_quant_algo = kv_cache_quant_algo_raw
-            else:
-                raise ValueError(
-                    f"kv_cache_quant_algo must be a string, got "
-                    f"{type(kv_cache_quant_algo_raw)}"
-                )
-
-            # Handle group_size with proper type validation
-            group_size_raw = config.get("group_size")
-            if group_size_raw is None:
-                group_size = 16  # Default value
-            elif isinstance(group_size_raw, int):
-                group_size = group_size_raw
-            else:
-                try:
-                    group_size = int(group_size_raw)
-                except (ValueError, TypeError):
-                    raise ValueError(
-                        f"group_size must be an integer, got {type(group_size_raw)}"
-                    ) from None
-
-            # "ignore" is the key in config.json
-            exclude_modules = config.get("ignore", [])
-            if not isinstance(exclude_modules, list):
-                raise ValueError(
-                    f"exclude_modules must be a list, got {type(exclude_modules)}"
-                )
-
-        if quant_method not in QUANT_ALGOS:
-            raise ValueError(
-                f"ModelOpt currently only supports: {QUANT_ALGOS} "
-                "quantizations in vLLM. Please check the "
-                "`hf_quant_config.json` file for your model's "
-                "quant configuration."
-            )
+    def _from_config(
+        cls,
+        *,
+        quant_method: str,
+        kv_cache_quant_method: str | None,
+        exclude_modules: list[str],
+        original_config: dict[str, Any],
+        group_size: int | None,
+        **kwargs: Any,
+    ) -> "ModelOptNvFp4Config":
         is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method
+        if group_size is None:
+            group_size = 16  # Default value
+
         # For FP4, these fields are required
-        if is_checkpoint_nvfp4_serialized and "quantization" in config:
+        if is_checkpoint_nvfp4_serialized and "quantization" in original_config:
             # Check if required fields are present in the quantization config
-            quant_config = config["quantization"]
+            quant_config = original_config["quantization"]
             required_fields = ["group_size", "kv_cache_quant_algo", "exclude_modules"]
             missing_fields = [
                 field for field in required_fields if field not in quant_config
@@ -872,64 +890,11 @@ class ModelOptNvFp4Config(QuantizationConfig):
         return cls(
             is_checkpoint_nvfp4_serialized,
-            kv_cache_quant_algo,
+            kv_cache_quant_method,
             exclude_modules,
             group_size,
         )
 
-    def is_layer_excluded(self, prefix: str) -> bool:
-        """
-        Check if a layer should be excluded from quantization.
-        Handles both exact matching (for fused layers) and pattern matching.
-        """
-        # First check exact matching with fused layer support
-        if is_layer_skipped(prefix, self.exclude_modules, self.packed_modules_mapping):
-            return True
-
-        # Check regex pattern matching for patterns not caught by exact match
-        import regex as re
-
-        for pattern in self.exclude_modules:
-            # Skip patterns that would be caught by exact matching
-            if "*" in pattern or "." in pattern:
-                regex_str = pattern.replace(".", r"\.").replace("*", r".*")
-                if re.fullmatch(regex_str, prefix):
-                    return True
-        return False
-
-    def get_quant_method(
-        self, layer: torch.nn.Module, prefix: str
-    ) -> Optional["QuantizeMethodBase"]:
-        from vllm.attention.layer import (  # Avoid circular import
-            Attention,
-            MLAAttention,
-        )
-
-        skip_layer = self.is_layer_excluded(prefix)
-        if isinstance(layer, LinearBase):
-            if skip_layer:
-                return UnquantizedLinearMethod()
-            # Check if this is a vision model layer that should not be quantized
-            if "vision_tower" in prefix or "vision_model" in prefix:
-                return UnquantizedLinearMethod()
-            return ModelOptNvFp4LinearMethod(self)
-        elif isinstance(layer, (Attention, MLAAttention)):
-            return ModelOptFp8KVCacheMethod(self)
-        elif isinstance(layer, FusedMoE):
-            if skip_layer:
-                return None
-            return ModelOptNvFp4FusedMoE(self, layer.moe_config, layer)
-        return None
-
-
-class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
-    """
-    Supports loading kv-cache scaling factors from FP8 checkpoints.
-    """
-
-    def __init__(self, quant_config: ModelOptFp8Config | ModelOptNvFp4Config):
-        super().__init__(quant_config)
-
 
 class ModelOptNvFp4LinearMethod(LinearMethodBase):
     """Linear method for Model Optimizer NVFP4.
@@ -1157,14 +1122,13 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
     def __init__(
         self,
         quant_config: ModelOptNvFp4Config,
-        moe: FusedMoEConfig,
-        layer: torch.nn.Module,
+        layer: FusedMoE,
     ) -> None:
         from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import (
            detect_nvfp4_moe_support,  # noqa: E501
         )
 
-        super().__init__(moe)
+        super().__init__(layer.moe_config)
         self.quant_config = quant_config
         self.layer = layer
         _nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__)
@@ -1802,3 +1766,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
             k=x.shape[1],
             e=layer.w13_weight.shape[0],
         )
+
+
+ModelOptNvFp4Config.LinearMethodCls = ModelOptNvFp4LinearMethod
+ModelOptNvFp4Config.FusedMoEMethodCls = ModelOptNvFp4FusedMoE
+ModelOptNvFp4Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod
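
# Not part of the patch above: a minimal, standalone sketch of the dispatch
# pattern this diff introduces, where a shared config base class resolves the
# per-layer quantize method through overridable class attributes that are bound
# after the concrete method classes are defined. All Toy* names are hypothetical
# stand-ins, and the string-based `kind` check replaces vLLM's isinstance checks
# on Attention/LinearBase/FusedMoE layer types.
from typing import Optional


class ToyLinearMethod:
    def __init__(self, config: "ToyQuantConfigBase") -> None:
        self.config = config


class ToyKVCacheMethod(ToyLinearMethod):
    pass


class ToyQuantConfigBase:
    # Subclasses swap these class attributes instead of overriding
    # get_quant_method(); the base class owns exclusion and dispatch.
    LinearMethodCls: type = ToyLinearMethod
    KVCacheMethodCls: type = ToyKVCacheMethod

    def __init__(self, exclude_modules: list[str]) -> None:
        self.exclude_modules = exclude_modules

    def get_quant_method(self, kind: str, prefix: str) -> Optional[object]:
        # KV-cache layers are handled first, mirroring the Attention branch.
        if kind == "attention":
            return self.KVCacheMethodCls(self)
        # Excluded layers stay unquantized.
        if any(prefix.startswith(pattern) for pattern in self.exclude_modules):
            return None
        if kind == "linear":
            return self.LinearMethodCls(self)
        return None


class ToyFp8LinearMethod(ToyLinearMethod):
    pass


class ToyFp8Config(ToyQuantConfigBase):
    pass


# Bound after the method class exists, mirroring the
# `ModelOptFp8Config.LinearMethodCls = ModelOptFp8LinearMethod` assignments.
ToyFp8Config.LinearMethodCls = ToyFp8LinearMethod

config = ToyFp8Config(exclude_modules=["lm_head"])
assert isinstance(config.get_quant_method("linear", "layers.0.qkv"), ToyFp8LinearMethod)
assert config.get_quant_method("linear", "lm_head") is None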