Consolidate Nvidia ModelOpt quant config handling for all quantization methods (#28076)

Signed-off-by: Shengliang Xu <shengliangx@nvidia.com>
Committed by Shengliang Xu on 2025-11-19 19:39:36 -08:00 (via GitHub)
parent fcbcba6c70
commit a8c536829c


@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from fnmatch import fnmatch
from typing import TYPE_CHECKING, Any, Optional
import torch
@@ -13,7 +14,6 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
RoutingMethodType,
fp8_w8a8_moe_quant_config,
@@ -86,45 +86,218 @@ QUANT_ALGOS = ["FP8", "NVFP4"]
KV_CACHE_QUANT_ALGOS = ["FP8"]
class ModelOptFp8Config(QuantizationConfig):
class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
"""
Supports loading kv-cache scaling factors from FP8 checkpoints.
"""
def __init__(self, quant_config: "ModelOptQuantConfigBase"):
super().__init__(quant_config)
class ModelOptQuantConfigBase(QuantizationConfig):
LinearMethodCls: type = LinearMethodBase
FusedMoEMethodCls: type = FusedMoEMethodBase
KVCacheMethodCls: type = BaseKVCacheMethod
def __init__(
self,
exclude_modules: list[str],
):
super().__init__()
self.exclude_modules: list[str] = exclude_modules
def is_layer_excluded(self, prefix: str) -> bool:
"""
Check if a layer should be excluded from quantization.
Handles both exact matching (for fused layers) and ModelOpt wildcard matching.
Entries in the ModelOpt exclude_modules list are wildcard patterns.
"""
if len(self.exclude_modules) == 0:
return False
# First check exact matching with fused layer support
if is_layer_skipped(prefix, self.exclude_modules, self.packed_modules_mapping):
return True
# TODO: This special hard-coded logic is not needed for quantized checkpoints
# generated by ModelOpt >= 0.39.0, where exclusions are handled naturally by
# the exclude_modules config, but it must be kept to load quantized
# checkpoints generated by older versions.
# Check substring matching for patterns not caught by the exact match above.
for exclude_module in self.exclude_modules:
# Skip exact matches already handled above
if exclude_module != prefix and (
exclude_module in prefix
or (
prefix.startswith("language_model.")
and exclude_module in prefix.removeprefix("language_model.")
)
):
return True
# ModelOpt exclude modules are not simple strings; they are wildcard patterns
for wildcard_pattern in self.exclude_modules:
if fnmatch(prefix, wildcard_pattern):
return True
return False
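
For context, a minimal standalone sketch of the matching order above, assuming only the two fallback paths (it deliberately omits the fused-layer exact matching done by is_layer_skipped); the function and pattern names here are illustrative and not part of the diff:

from fnmatch import fnmatch

def sketch_is_excluded(prefix: str, exclude_modules: list[str]) -> bool:
    # Simplified sketch: substring matching first (legacy checkpoints),
    # including the multimodal "language_model." prefix handling, then
    # wildcard matching (ModelOpt >= 0.39.0 exclude_modules entries).
    for pattern in exclude_modules:
        if pattern in prefix or pattern in prefix.removeprefix("language_model."):
            return True
        if fnmatch(prefix, pattern):
            return True
    return False

# "*.mlp.gate" excludes the router gate but not the fused gate_up projection.
assert sketch_is_excluded("model.layers.0.mlp.gate", ["*.mlp.gate"])
assert not sketch_is_excluded("model.layers.0.mlp.gate_up_proj", ["*.mlp.gate"])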
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import
# handle kv-cache first so we can focus only on weight quantization thereafter
if isinstance(layer, Attention):
return self.KVCacheMethodCls(self)
# handle exclusion
if self.is_layer_excluded(prefix):
if isinstance(layer, LinearBase):
return UnquantizedLinearMethod()
return None
# TODO: This special hard-coded logic is not needed for quantized checkpoints
# generated by ModelOpt >= 0.39.0, where exclusions are handled naturally by
# the exclude_modules config, but it must be kept to load quantized
# checkpoints generated by older versions.
# Treat vision encoder layers as unquantized.
if "vision_tower" in prefix or "vision_model" in prefix:
return UnquantizedLinearMethod()
# Now the layer is known to be quantized; handle it here
if isinstance(layer, LinearBase):
return self.LinearMethodCls(self)
elif isinstance(layer, FusedMoE):
return self.FusedMoEMethodCls(quant_config=self, layer=layer)
return None
def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
if len(self.exclude_modules) > 0:
self.exclude_modules = hf_to_vllm_mapper.apply_list(self.exclude_modules)
@staticmethod
def get_config_filenames() -> list[str]:
return ["hf_quant_config.json"]
@classmethod
def _from_config(
cls,
*,
quant_method: str,
kv_cache_quant_method: str | None,
exclude_modules: list[str],
original_config: dict[str, Any],
group_size: int | None,
) -> "ModelOptQuantConfigBase":
raise NotImplementedError("Please implement this function in subclasses")
@classmethod
def from_config(cls, config: dict[str, Any]) -> "ModelOptQuantConfigBase":
# Handle both ModelOpt format and compressed-tensors style format
if "quantization" in config:
# Traditional ModelOpt format:
# {"quantization": {"quant_algo": "..."}}
quant_config = cls.get_from_keys(config, ["quantization"])
if not isinstance(quant_config, dict):
raise ValueError("Expected 'quantization' to be a dictionary in config")
quant_method = quant_config.get("quant_algo")
# Handle kv_cache_quant_algo with proper type validation
kv_cache_quant_method = quant_config.get("kv_cache_quant_algo")
# Handle group_size with proper type validation
group_size_raw = quant_config.get("group_size")
# "exclude_modules" is the key in the legacy hf_quant_config.json
exclude_modules = quant_config.get("exclude_modules", [])
else:
# Compressed-tensors style format:
# {"quant_algo": "...", "quant_method": "modelopt"}
quant_method = config.get("quant_algo")
kv_cache_quant_method = config.get("kv_cache_quant_algo")
# "ignore" is the key in config.json
exclude_modules = config.get("ignore", [])
group_size_raw = config.get("group_size")
if not quant_method:
raise ValueError("Missing 'quant_algo' in quantization config")
if kv_cache_quant_method is None:
# No KV cache quantization; this branch exists only to document that case
pass
elif not isinstance(kv_cache_quant_method, str):
raise ValueError(
f"kv_cache_quant_algo must be a string, got "
f"{type(kv_cache_quant_method)}"
)
if not isinstance(exclude_modules, list):
raise ValueError(
f"exclude_modules must be a list, got {type(exclude_modules)}"
)
if group_size_raw is None:
group_size = None
elif isinstance(group_size_raw, int):
group_size = group_size_raw
else:
try:
group_size = int(group_size_raw)
except (ValueError, TypeError):
raise ValueError(
f"group_size must be an integer, got {type(group_size_raw)}"
) from None
if quant_method not in QUANT_ALGOS:
raise ValueError(
f"ModelOpt currently only supports: {QUANT_ALGOS} "
"quantizations in vLLM. Please check the "
"`hf_quant_config.json` file for your model's "
"quant configuration."
)
return cls._from_config(
quant_method=quant_method,
kv_cache_quant_method=kv_cache_quant_method,
exclude_modules=exclude_modules,
group_size=group_size,
original_config=config,
)
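
For reference, a sketch of the two config layouts this classmethod accepts; the key names come from the code above, while the concrete values are made up:

# Legacy ModelOpt layout (hf_quant_config.json), wrapped under "quantization".
legacy_cfg = {
    "quantization": {
        "quant_algo": "NVFP4",
        "kv_cache_quant_algo": "FP8",
        "group_size": 16,
        "exclude_modules": ["lm_head", "*.mlp.gate"],
    },
}

# Flattened, compressed-tensors style layout found directly in config.json;
# the exclusion key here is "ignore" rather than "exclude_modules".
flat_cfg = {
    "quant_method": "modelopt",
    "quant_algo": "FP8",
    "kv_cache_quant_algo": "FP8",
    "ignore": ["lm_head"],
}

# Both dicts go through the same base-class parsing before _from_config is
# called on the concrete subclass.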
class ModelOptFp8Config(ModelOptQuantConfigBase):
"""Config class for ModelOpt FP8."""
def __init__(
self,
is_checkpoint_fp8_serialized: bool = False,
kv_cache_quant_method: str | None = None,
exclude_modules: list[str] | None = None,
is_checkpoint_fp8_serialized: bool,
kv_cache_quant_method: str | None,
exclude_modules: list[str],
) -> None:
super().__init__()
super().__init__(exclude_modules)
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
self.kv_cache_quant_method = kv_cache_quant_method
self.exclude_modules = exclude_modules or []
if is_checkpoint_fp8_serialized:
logger.warning(
"Detected ModelOpt fp8 checkpoint. Please note that"
" the format is experimental and could change."
)
@classmethod
def get_name(cls) -> QuantizationMethods:
def get_name(self) -> QuantizationMethods:
return "modelopt"
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
def get_supported_act_dtypes(self) -> list[torch.dtype]:
return [torch.bfloat16, torch.half]
@classmethod
def get_min_capability(cls) -> int:
return 89
@classmethod
def get_config_filenames(cls) -> list[str]:
return ["hf_quant_config.json"]
def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
if self.exclude_modules is not None:
self.exclude_modules = hf_to_vllm_mapper.apply_list(self.exclude_modules)
@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant
@@ -158,88 +331,19 @@ class ModelOptFp8Config(QuantizationConfig):
return None
@classmethod
def from_config(cls, config: dict[str, Any]) -> "ModelOptFp8Config":
# Handle both ModelOpt format and compressed-tensors style format
if "quantization" in config:
# ModelOpt format: {"quantization": {"quant_algo": "..."}}
quant_config = cls.get_from_keys(config, ["quantization"])
if not isinstance(quant_config, dict):
raise ValueError("Expected 'quantization' to be a dictionary in config")
quant_method = quant_config.get("quant_algo", "")
if not quant_method:
raise ValueError("Missing 'quant_algo' in quantization config")
kv_cache_quant_method = quant_config.get("kv_cache_quant_algo")
# "exclude_modules" is the key in the legacy hf_quant_config.json
exclude_modules = quant_config.get("exclude_modules")
else:
# Compressed-tensors style format:
# {"quant_algo": "...", "quant_method": "modelopt"}
quant_method = config.get("quant_algo", "")
kv_cache_quant_method = config.get("kv_cache_quant_algo")
# "ignore" is the key in config.json
exclude_modules = config.get("ignore")
if quant_method not in QUANT_ALGOS:
raise ValueError(
f"ModelOpt currently only supports: {QUANT_ALGOS} "
"quantizations in vLLM. Please check the "
"`hf_quant_config.json` file for your model's "
"quant configuration."
)
def _from_config(
cls,
*,
quant_method: str,
kv_cache_quant_method: str | None,
exclude_modules: list[str],
original_config: dict[str, Any],
**kwargs: Any,
) -> "ModelOptFp8Config":
is_checkpoint_fp8_serialized = "FP8" in quant_method
return cls(is_checkpoint_fp8_serialized, kv_cache_quant_method, exclude_modules)
def is_layer_excluded(self, prefix: str) -> bool:
"""
Check if a layer should be excluded from quantization.
Handles both exact matching (for fused layers) and substring matching.
This method handles both regular models and multimodal models that use
the language_model prefix. For multimodal models, it checks if the
module name (without the language_model prefix) is in the exclude list.
"""
if self.exclude_modules is None:
return False
# First check exact matching with fused layer support
if is_layer_skipped(prefix, self.exclude_modules, self.packed_modules_mapping):
return True
# Then check substring matching for patterns not caught by exact match
for module in self.exclude_modules:
# Skip exact matches already handled above
if module != prefix and (
module in prefix
or (
prefix.startswith("language_model.")
and module in prefix.removeprefix("language_model.")
)
):
return True
return False
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import ( # Avoid circular import
Attention,
MLAAttention,
)
if isinstance(layer, LinearBase):
if self.is_layer_excluded(prefix):
return UnquantizedLinearMethod()
# Check if this is a vision model layer that should not be quantized
if "vision_tower" in prefix or "vision_model" in prefix:
return UnquantizedLinearMethod()
return ModelOptFp8LinearMethod(self)
elif isinstance(layer, (Attention, MLAAttention)):
return ModelOptFp8KVCacheMethod(self)
elif isinstance(layer, FusedMoE):
return ModelOptFp8MoEMethod(self, layer)
return None
class ModelOptFp8LinearMethod(LinearMethodBase):
"""Linear method for Model Optimizer static quantization.
@@ -344,7 +448,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
def __init__(
self,
quant_config: ModelOptFp8Config,
layer: torch.nn.Module,
layer: FusedMoE,
) -> None:
super().__init__(layer.moe_config)
self.layer = layer
@@ -686,7 +790,12 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
)
class ModelOptNvFp4Config(QuantizationConfig):
ModelOptFp8Config.LinearMethodCls = ModelOptFp8LinearMethod
ModelOptFp8Config.FusedMoEMethodCls = ModelOptFp8MoEMethod
ModelOptFp8Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod
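
These module-level assignments are what let the shared base class dispatch per layer while each scheme only swaps in its own method classes; a toy sketch of the pattern with placeholder names (not vLLM types):

class _ToyConfigBase:
    # Subclasses point this at their own method class after it is defined.
    LinearMethodCls: type = object

    def get_quant_method_for_linear(self):
        # Generic dispatch: only the class attribute differs per scheme.
        return self.LinearMethodCls(self)

class _ToyFp8Config(_ToyConfigBase):
    pass

class _ToyFp8LinearMethod:
    def __init__(self, quant_config):
        self.quant_config = quant_config

# Assigned after the method class exists, mirroring the assignments above.
_ToyFp8Config.LinearMethodCls = _ToyFp8LinearMethod

assert isinstance(_ToyFp8Config().get_quant_method_for_linear(), _ToyFp8LinearMethod)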
class ModelOptNvFp4Config(ModelOptQuantConfigBase):
"""Config class for ModelOpt FP4."""
def __init__(
@@ -696,7 +805,7 @@ class ModelOptNvFp4Config(QuantizationConfig):
exclude_modules: list[str],
group_size: int = 16,
) -> None:
super().__init__()
super().__init__(exclude_modules)
self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
if is_checkpoint_nvfp4_serialized:
logger.warning(
@@ -706,28 +815,17 @@ class ModelOptNvFp4Config(QuantizationConfig):
self.group_size = group_size
self.kv_cache_quant_algo = kv_cache_quant_algo
self.exclude_modules = exclude_modules
@classmethod
def get_name(cls) -> QuantizationMethods:
def get_name(self) -> QuantizationMethods:
return "modelopt_fp4"
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
def get_supported_act_dtypes(self) -> list[torch.dtype]:
return [torch.bfloat16, torch.half, torch.float8_e4m3fn]
@classmethod
def get_min_capability(cls) -> int:
return 80
@classmethod
def get_config_filenames(cls) -> list[str]:
return ["hf_quant_config.json"]
def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
if self.exclude_modules is not None:
self.exclude_modules = hf_to_vllm_mapper.apply_list(self.exclude_modules)
@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant
@@ -761,105 +859,25 @@ class ModelOptNvFp4Config(QuantizationConfig):
return None
@classmethod
def from_config(cls, config: dict[str, Any]) -> "ModelOptNvFp4Config":
# Handle both traditional ModelOpt format and compressed-tensors
# style format
if "quantization" in config:
# Traditional ModelOpt format:
# {"quantization": {"quant_algo": "..."}}
quant_config = cls.get_from_keys(config, ["quantization"])
if not isinstance(quant_config, dict):
raise ValueError("Expected 'quantization' to be a dictionary in config")
quant_method = quant_config.get("quant_algo", "")
if not quant_method:
raise ValueError("Missing 'quant_algo' in quantization config")
# Handle kv_cache_quant_algo with proper type validation
kv_cache_quant_algo_raw = quant_config.get("kv_cache_quant_algo")
if kv_cache_quant_algo_raw is None:
# No KV cache quantization by default
kv_cache_quant_algo = None
elif isinstance(kv_cache_quant_algo_raw, str):
kv_cache_quant_algo = kv_cache_quant_algo_raw
else:
raise ValueError(
f"kv_cache_quant_algo must be a string, got "
f"{type(kv_cache_quant_algo_raw)}"
)
# Handle group_size with proper type validation
group_size_raw = quant_config.get("group_size")
if group_size_raw is None:
group_size = 16 # Default value
elif isinstance(group_size_raw, int):
group_size = group_size_raw
else:
try:
group_size = int(group_size_raw)
except (ValueError, TypeError):
raise ValueError(
f"group_size must be an integer, got {type(group_size_raw)}"
) from None
# "exclude_modules" is the key in the legacy hf_quant_config.json
exclude_modules = quant_config.get("exclude_modules", [])
if not isinstance(exclude_modules, list):
raise ValueError(
f"exclude_modules must be a list, got {type(exclude_modules)}"
)
else:
# Compressed-tensors style format:
# {"quant_algo": "...", "quant_method": "modelopt"}
quant_method = config.get("quant_algo", "")
# Handle kv_cache_quant_algo with proper type validation
kv_cache_quant_algo_raw = config.get("kv_cache_quant_algo")
if kv_cache_quant_algo_raw is None:
# No KV cache quantization by default
kv_cache_quant_algo = None
elif isinstance(kv_cache_quant_algo_raw, str):
kv_cache_quant_algo = kv_cache_quant_algo_raw
else:
raise ValueError(
f"kv_cache_quant_algo must be a string, got "
f"{type(kv_cache_quant_algo_raw)}"
)
# Handle group_size with proper type validation
group_size_raw = config.get("group_size")
if group_size_raw is None:
group_size = 16 # Default value
elif isinstance(group_size_raw, int):
group_size = group_size_raw
else:
try:
group_size = int(group_size_raw)
except (ValueError, TypeError):
raise ValueError(
f"group_size must be an integer, got {type(group_size_raw)}"
) from None
# "ignore" is the key in config.json
exclude_modules = config.get("ignore", [])
if not isinstance(exclude_modules, list):
raise ValueError(
f"exclude_modules must be a list, got {type(exclude_modules)}"
)
if quant_method not in QUANT_ALGOS:
raise ValueError(
f"ModelOpt currently only supports: {QUANT_ALGOS} "
"quantizations in vLLM. Please check the "
"`hf_quant_config.json` file for your model's "
"quant configuration."
)
def _from_config(
cls,
*,
quant_method: str,
kv_cache_quant_method: str | None,
exclude_modules: list[str],
original_config: dict[str, Any],
group_size: int | None,
**kwargs: Any,
) -> "ModelOptNvFp4Config":
is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method
if group_size is None:
group_size = 16 # Default value
# For FP4, these fields are required
if is_checkpoint_nvfp4_serialized and "quantization" in config:
if is_checkpoint_nvfp4_serialized and "quantization" in original_config:
# Check if required fields are present in the quantization config
quant_config = config["quantization"]
quant_config = original_config["quantization"]
required_fields = ["group_size", "kv_cache_quant_algo", "exclude_modules"]
missing_fields = [
field for field in required_fields if field not in quant_config
@@ -872,64 +890,11 @@ class ModelOptNvFp4Config(QuantizationConfig):
return cls(
is_checkpoint_nvfp4_serialized,
kv_cache_quant_algo,
kv_cache_quant_method,
exclude_modules,
group_size,
)
def is_layer_excluded(self, prefix: str) -> bool:
"""
Check if a layer should be excluded from quantization.
Handles both exact matching (for fused layers) and pattern matching.
"""
# First check exact matching with fused layer support
if is_layer_skipped(prefix, self.exclude_modules, self.packed_modules_mapping):
return True
# Check regex pattern matching for patterns not caught by exact match
import regex as re
for pattern in self.exclude_modules:
# Skip patterns that would be caught by exact matching
if "*" in pattern or "." in pattern:
regex_str = pattern.replace(".", r"\.").replace("*", r".*")
if re.fullmatch(regex_str, prefix):
return True
return False
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import ( # Avoid circular import
Attention,
MLAAttention,
)
skip_layer = self.is_layer_excluded(prefix)
if isinstance(layer, LinearBase):
if skip_layer:
return UnquantizedLinearMethod()
# Check if this is a vision model layer that should not be quantized
if "vision_tower" in prefix or "vision_model" in prefix:
return UnquantizedLinearMethod()
return ModelOptNvFp4LinearMethod(self)
elif isinstance(layer, (Attention, MLAAttention)):
return ModelOptFp8KVCacheMethod(self)
elif isinstance(layer, FusedMoE):
if skip_layer:
return None
return ModelOptNvFp4FusedMoE(self, layer.moe_config, layer)
return None
class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
"""
Supports loading kv-cache scaling factors from FP8 checkpoints.
"""
def __init__(self, quant_config: ModelOptFp8Config | ModelOptNvFp4Config):
super().__init__(quant_config)
class ModelOptNvFp4LinearMethod(LinearMethodBase):
"""Linear method for Model Optimizer NVFP4.
@@ -1157,14 +1122,13 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
def __init__(
self,
quant_config: ModelOptNvFp4Config,
moe: FusedMoEConfig,
layer: torch.nn.Module,
layer: FusedMoE,
) -> None:
from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import (
detect_nvfp4_moe_support, # noqa: E501
)
super().__init__(moe)
super().__init__(layer.moe_config)
self.quant_config = quant_config
self.layer = layer
_nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__)
@@ -1802,3 +1766,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
k=x.shape[1],
e=layer.w13_weight.shape[0],
)
ModelOptNvFp4Config.LinearMethodCls = ModelOptNvFp4LinearMethod
ModelOptNvFp4Config.FusedMoEMethodCls = ModelOptNvFp4FusedMoE
ModelOptNvFp4Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod
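
A hedged usage sketch of the consolidated entry point; the config dicts below are illustrative, and the asserts simply restate the wiring above:

fp8_config = ModelOptFp8Config.from_config(
    {"quantization": {"quant_algo": "FP8", "exclude_modules": ["lm_head"]}}
)
nvfp4_config = ModelOptNvFp4Config.from_config(
    {
        "quantization": {
            "quant_algo": "NVFP4",
            "kv_cache_quant_algo": "FP8",
            "group_size": 16,
            "exclude_modules": ["lm_head"],
        }
    }
)

# Per-layer dispatch now lives in ModelOptQuantConfigBase.get_quant_method and
# consults the class attributes assigned at module scope.
assert fp8_config.LinearMethodCls is ModelOptFp8LinearMethod
assert nvfp4_config.FusedMoEMethodCls is ModelOptNvFp4FusedMoE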