Consolidate Nvidia ModelOpt quant config handling for all quantization methods (#28076)
Signed-off-by: Shengliang Xu <shengliangx@nvidia.com>
parent fcbcba6c70
commit a8c536829c
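
The diff below moves the config parsing and quant-method dispatch that ModelOptFp8Config and ModelOptNvFp4Config used to duplicate into a shared ModelOptQuantConfigBase; each subclass now only supplies its linear, fused-MoE, and KV-cache method classes as class attributes. A minimal, self-contained sketch of that dispatch pattern (the class and layer-kind names below are simplified stand-ins, not the actual vLLM classes):

# Simplified sketch of the consolidated dispatch pattern, not the vLLM code itself.
class LinearMethodBase: ...
class FusedMoEMethodBase: ...
class KVCacheMethodBase: ...


class QuantConfigBase:
    # Subclasses override these class attributes instead of re-implementing dispatch.
    LinearMethodCls: type = LinearMethodBase
    FusedMoEMethodCls: type = FusedMoEMethodBase
    KVCacheMethodCls: type = KVCacheMethodBase

    def get_quant_method(self, layer_kind: str):
        if layer_kind == "attention":
            return self.KVCacheMethodCls()
        if layer_kind == "linear":
            return self.LinearMethodCls()
        if layer_kind == "moe":
            return self.FusedMoEMethodCls()
        return None


class Fp8LinearMethod(LinearMethodBase): ...


class Fp8Config(QuantConfigBase):
    pass


# Bind the concrete method class after the definitions, as the diff does at module end.
Fp8Config.LinearMethodCls = Fp8LinearMethod

print(type(Fp8Config().get_quant_method("linear")).__name__)  # -> Fp8LinearMethod

Binding the method classes at the bottom of the module, as the diff does for both the FP8 and NVFP4 configs, avoids forward references to method classes that are defined later in the file.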
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from collections.abc import Callable
+from fnmatch import fnmatch
 from typing import TYPE_CHECKING, Any, Optional
 
 import torch
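
The newly imported fnmatch backs the wildcard handling of ModelOpt's exclude_modules entries in the base class introduced below. For reference, the stdlib behaviour looks like this (the module prefixes are illustrative, not taken from any checkpoint):

from fnmatch import fnmatch

# Illustrative exclude patterns and layer prefix; "*" matches across dots.
patterns = ["*.gate_proj", "model.layers.0.*"]
prefix = "model.layers.0.mlp.gate_proj"
print(any(fnmatch(prefix, p) for p in patterns))  # True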
@@ -13,7 +14,6 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import (
-    FusedMoEConfig,
     FusedMoEQuantConfig,
     RoutingMethodType,
     fp8_w8a8_moe_quant_config,
@@ -86,45 +86,218 @@ QUANT_ALGOS = ["FP8", "NVFP4"]
 KV_CACHE_QUANT_ALGOS = ["FP8"]
 
 
-class ModelOptFp8Config(QuantizationConfig):
+class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
+    """
+    Supports loading kv-cache scaling factors from FP8 checkpoints.
+    """
+
+    def __init__(self, quant_config: "ModelOptQuantConfigBase"):
+        super().__init__(quant_config)
+
+
+class ModelOptQuantConfigBase(QuantizationConfig):
+    LinearMethodCls: type = LinearMethodBase
+    FusedMoEMethodCls: type = FusedMoEMethodBase
+    KVCacheMethodCls: type = BaseKVCacheMethod
+
+    def __init__(
+        self,
+        exclude_modules: list[str],
+    ):
+        super().__init__()
+        self.exclude_modules: list[str] = exclude_modules
+
+    def is_layer_excluded(self, prefix: str) -> bool:
+        """
+        Check if a layer should be excluded from quantization.
+
+        Handles both exact matching (for fused layers) and ModelOpt wildcard matching.
+
+        The ModelOpt exclude_modules list is a list of wildcards.
+        """
+        if len(self.exclude_modules) == 0:
+            return False
+
+        # First check exact matching with fused layer support
+        if is_layer_skipped(prefix, self.exclude_modules, self.packed_modules_mapping):
+            return True
+
+        # TODO: This special hard coded logic is not needed for quantized checkpoints
+        # generated by ModelOpt >= 0.39.0 where they are handled natually by the
+        # exclude_modules config. But need to keep them for loading quantized
+        # checkpoints generated by older versions. Then check substring matching
+        # for patterns not caught by exact match
+        for exclude_module in self.exclude_modules:
+            # Skip exact matches already handled above
+            if exclude_module != prefix and (
+                exclude_module in prefix
+                or (
+                    prefix.startswith("language_model.")
+                    and exclude_module in prefix.removeprefix("language_model.")
+                )
+            ):
+                return True
+
+        # modelopt exclude modules are not simple strings, they are wildcards
+        for wildcard_pattern in self.exclude_modules:
+            if fnmatch(prefix, wildcard_pattern):
+                return True
+
+        return False
+
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional["QuantizeMethodBase"]:
+        from vllm.attention.layer import Attention  # Avoid circular import
+
+        # handle kv-cache first so we can focus only on weight quantization thereafter
+        if isinstance(layer, Attention):
+            return self.KVCacheMethodCls(self)
+
+        # handle exclusion
+        if self.is_layer_excluded(prefix):
+            if isinstance(layer, LinearBase):
+                return UnquantizedLinearMethod()
+            return None
+
+        # TODO: This special hard coded logic is not needed for quantized checkpoints
+        # generated by ModelOpt >= 0.39.0 where they are handled natually by the
+        # exclude_modules config. But need to keep them for loading quantized
+        # checkpoints generated by older versions. Then check substring matching
+        # for patterns not caught by exact match
+        if "vision_tower" in prefix or "vision_model" in prefix:
+            return UnquantizedLinearMethod()
+
+        # now, the layer is quantized, handle it here
+        if isinstance(layer, LinearBase):
+            return self.LinearMethodCls(self)
+        elif isinstance(layer, FusedMoE):
+            return self.FusedMoEMethodCls(quant_config=self, layer=layer)
+
+        return None
+
+    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
+        if len(self.exclude_modules) > 0:
+            self.exclude_modules = hf_to_vllm_mapper.apply_list(self.exclude_modules)
+
+    @staticmethod
+    def get_config_filenames() -> list[str]:
+        return ["hf_quant_config.json"]
+
+    @classmethod
+    def _from_config(
+        cls,
+        *,
+        quant_method: str,
+        kv_cache_quant_method: str | None,
+        exclude_modules: list[str],
+        original_config: dict[str, Any],
+        group_size: int | None,
+    ) -> "ModelOptQuantConfigBase":
+        raise NotImplementedError("Please implement this function in sub classes")
+
+    @classmethod
+    def from_config(cls, config: dict[str, Any]) -> "ModelOptQuantConfigBase":
+        # Handle both ModelOpt format and compressed-tensors style format
+        if "quantization" in config:
+            # Traditional ModelOpt format:
+            # {"quantization": {"quant_algo": "..."}}
+            quant_config = cls.get_from_keys(config, ["quantization"])
+            if not isinstance(quant_config, dict):
+                raise ValueError("Expected 'quantization' to be a dictionary in config")
+
+            quant_method = quant_config.get("quant_algo")
+
+            # Handle kv_cache_quant_algo with proper type validation
+            kv_cache_quant_method = quant_config.get("kv_cache_quant_algo")
+
+            # Handle group_size with proper type validation
+            group_size_raw = quant_config.get("group_size")
+
+            # "exclude_modules" is the key in the legacy hf_quant_config.json
+            exclude_modules = quant_config.get("exclude_modules", [])
+        else:
+            # Compressed-tensors style format:
+            # {"quant_algo": "...", "quant_method": "modelopt"}
+            quant_method = config.get("quant_algo")
+            kv_cache_quant_method = config.get("kv_cache_quant_algo")
+            # "ignore" is the key in config.json
+            exclude_modules = config.get("ignore", [])
+            group_size_raw = config.get("group_size")
+
+        if not quant_method:
+            raise ValueError("Missing 'quant_algo' in quantization config")
+
+        if kv_cache_quant_method is None:
+            # No KV cache quantization, keep this branch just to have this comment
+            pass
+        elif not isinstance(kv_cache_quant_method, str):
+            raise ValueError(
+                f"kv_cache_quant_algo must be a string, got "
+                f"{type(kv_cache_quant_method)}"
+            )
+
+        if not isinstance(exclude_modules, list):
+            raise ValueError(
+                f"exclude_modules must be a list, got {type(exclude_modules)}"
+            )
+
+        if group_size_raw is None:
+            group_size = None
+        elif isinstance(group_size_raw, int):
+            group_size = group_size_raw
+        else:
+            try:
+                group_size = int(group_size_raw)
+            except (ValueError, TypeError):
+                raise ValueError(
+                    f"group_size must be an integer, got {type(group_size_raw)}"
+                ) from None
+
+        if quant_method not in QUANT_ALGOS:
+            raise ValueError(
+                f"ModelOpt currently only supports: {QUANT_ALGOS} "
+                "quantizations in vLLM. Please check the "
+                "`hf_quant_config.json` file for your model's "
+                "quant configuration."
+            )
+        return cls._from_config(
+            quant_method=quant_method,
+            kv_cache_quant_method=kv_cache_quant_method,
+            exclude_modules=exclude_modules,
+            group_size=group_size,
+            original_config=config,
+        )
+
+
+class ModelOptFp8Config(ModelOptQuantConfigBase):
     """Config class for ModelOpt FP8."""
 
     def __init__(
         self,
-        is_checkpoint_fp8_serialized: bool = False,
-        kv_cache_quant_method: str | None = None,
-        exclude_modules: list[str] | None = None,
+        is_checkpoint_fp8_serialized: bool,
+        kv_cache_quant_method: str | None,
+        exclude_modules: list[str],
     ) -> None:
-        super().__init__()
+        super().__init__(exclude_modules)
         self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
         self.kv_cache_quant_method = kv_cache_quant_method
-        self.exclude_modules = exclude_modules or []
         if is_checkpoint_fp8_serialized:
             logger.warning(
                 "Detected ModelOpt fp8 checkpoint. Please note that"
                 " the format is experimental and could change."
             )
 
-    @classmethod
-    def get_name(cls) -> QuantizationMethods:
+    def get_name(self) -> QuantizationMethods:
         return "modelopt"
 
-    @classmethod
-    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
+    def get_supported_act_dtypes(self) -> list[torch.dtype]:
         return [torch.bfloat16, torch.half]
 
     @classmethod
     def get_min_capability(cls) -> int:
         return 89
 
-    @classmethod
-    def get_config_filenames(cls) -> list[str]:
-        return ["hf_quant_config.json"]
-
-    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
-        if self.exclude_modules is not None:
-            self.exclude_modules = hf_to_vllm_mapper.apply_list(self.exclude_modules)
-
     @classmethod
     def override_quantization_method(
         cls, hf_quant_cfg, user_quant
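
For reference, the two config layouts that the shared from_config above accepts look roughly like this; only the key layout is taken from the code, the field values and module names are illustrative:

# Legacy hf_quant_config.json layout ("quantization" wrapper, "exclude_modules" key).
hf_quant_config = {
    "quantization": {
        "quant_algo": "FP8",
        "kv_cache_quant_algo": "FP8",
        "exclude_modules": ["lm_head"],
    }
}

# compressed-tensors style layout embedded in config.json (flat keys, "ignore" key).
flat_config = {
    "quant_algo": "NVFP4",
    "kv_cache_quant_algo": None,
    "group_size": 16,
    "ignore": ["lm_head"],
}

Either shape ends up in the same validated _from_config call on the concrete subclass.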
@@ -158,88 +331,19 @@ class ModelOptFp8Config(QuantizationConfig):
         return None
 
     @classmethod
-    def from_config(cls, config: dict[str, Any]) -> "ModelOptFp8Config":
-        # Handle both ModelOpt format and compressed-tensors style format
-        if "quantization" in config:
-            # ModelOpt format: {"quantization": {"quant_algo": "..."}}
-            quant_config = cls.get_from_keys(config, ["quantization"])
-            if not isinstance(quant_config, dict):
-                raise ValueError("Expected 'quantization' to be a dictionary in config")
-            quant_method = quant_config.get("quant_algo", "")
-            if not quant_method:
-                raise ValueError("Missing 'quant_algo' in quantization config")
-            kv_cache_quant_method = quant_config.get("kv_cache_quant_algo")
-            # "exclude_modules" is the key in the legacy hf_quant_config.json
-            exclude_modules = quant_config.get("exclude_modules")
-        else:
-            # Compressed-tensors style format:
-            # {"quant_algo": "...", "quant_method": "modelopt"}
-            quant_method = config.get("quant_algo", "")
-            kv_cache_quant_method = config.get("kv_cache_quant_algo")
-            # "ignore" is the key in config.json
-            exclude_modules = config.get("ignore")
-
-        if quant_method not in QUANT_ALGOS:
-            raise ValueError(
-                f"ModelOpt currently only supports: {QUANT_ALGOS} "
-                "quantizations in vLLM. Please check the "
-                "`hf_quant_config.json` file for your model's "
-                "quant configuration."
-            )
+    def _from_config(
+        cls,
+        *,
+        quant_method: str,
+        kv_cache_quant_method: str | None,
+        exclude_modules: list[str],
+        original_config: dict[str, Any],
+        **kwargs: Any,
+    ) -> "ModelOptFp8Config":
         is_checkpoint_fp8_serialized = "FP8" in quant_method
 
         return cls(is_checkpoint_fp8_serialized, kv_cache_quant_method, exclude_modules)
 
-    def is_layer_excluded(self, prefix: str) -> bool:
-        """
-        Check if a layer should be excluded from quantization.
-        Handles both exact matching (for fused layers) and substring matching.
-
-        This method handles both regular models and multimodal models that use
-        the language_model prefix. For multimodal models, it checks if the
-        module name (without the language_model prefix) is in the exclude list.
-        """
-        if self.exclude_modules is None:
-            return False
-
-        # First check exact matching with fused layer support
-        if is_layer_skipped(prefix, self.exclude_modules, self.packed_modules_mapping):
-            return True
-
-        # Then check substring matching for patterns not caught by exact match
-        for module in self.exclude_modules:
-            # Skip exact matches already handled above
-            if module != prefix and (
-                module in prefix
-                or (
-                    prefix.startswith("language_model.")
-                    and module in prefix.removeprefix("language_model.")
-                )
-            ):
-                return True
-        return False
-
-    def get_quant_method(
-        self, layer: torch.nn.Module, prefix: str
-    ) -> Optional["QuantizeMethodBase"]:
-        from vllm.attention.layer import (  # Avoid circular import
-            Attention,
-            MLAAttention,
-        )
-
-        if isinstance(layer, LinearBase):
-            if self.is_layer_excluded(prefix):
-                return UnquantizedLinearMethod()
-            # Check if this is a vision model layer that should not be quantized
-            if "vision_tower" in prefix or "vision_model" in prefix:
-                return UnquantizedLinearMethod()
-            return ModelOptFp8LinearMethod(self)
-        elif isinstance(layer, (Attention, MLAAttention)):
-            return ModelOptFp8KVCacheMethod(self)
-        elif isinstance(layer, FusedMoE):
-            return ModelOptFp8MoEMethod(self, layer)
-        return None
-
 
 class ModelOptFp8LinearMethod(LinearMethodBase):
     """Linear method for Model Optimizer static quantization.
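
The keyword-only _from_config signature above ends in **kwargs so that the base class can always pass the full set of parsed fields while each subclass keeps only what it needs. A small sketch of that pattern under simplified, stand-in names:

from typing import Any


class Base:
    @classmethod
    def _from_config(cls, **kwargs: Any) -> "Base":
        raise NotImplementedError

    @classmethod
    def from_config(cls, parsed: dict[str, Any]) -> "Base":
        # One shared parser, one call signature for every subclass.
        return cls._from_config(**parsed)


class Fp8Like(Base):
    @classmethod
    def _from_config(cls, *, quant_method: str, **kwargs: Any) -> "Fp8Like":
        # Fields such as group_size are accepted via **kwargs but unused here.
        obj = cls()
        obj.is_fp8 = "FP8" in quant_method
        return obj


cfg = Fp8Like.from_config({"quant_method": "FP8", "group_size": None})
print(cfg.is_fp8)  # True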
@@ -344,7 +448,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
     def __init__(
         self,
         quant_config: ModelOptFp8Config,
-        layer: torch.nn.Module,
+        layer: FusedMoE,
     ) -> None:
         super().__init__(layer.moe_config)
         self.layer = layer
@@ -686,7 +790,12 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         )
 
 
-class ModelOptNvFp4Config(QuantizationConfig):
+ModelOptFp8Config.LinearMethodCls = ModelOptFp8LinearMethod
+ModelOptFp8Config.FusedMoEMethodCls = ModelOptFp8MoEMethod
+ModelOptFp8Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod
+
+
+class ModelOptNvFp4Config(ModelOptQuantConfigBase):
     """Config class for ModelOpt FP4."""
 
     def __init__(
@@ -696,7 +805,7 @@ class ModelOptNvFp4Config(QuantizationConfig):
         exclude_modules: list[str],
         group_size: int = 16,
     ) -> None:
-        super().__init__()
+        super().__init__(exclude_modules)
         self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
         if is_checkpoint_nvfp4_serialized:
             logger.warning(
@@ -706,28 +815,17 @@ class ModelOptNvFp4Config(QuantizationConfig):
 
         self.group_size = group_size
         self.kv_cache_quant_algo = kv_cache_quant_algo
-        self.exclude_modules = exclude_modules
 
-    @classmethod
-    def get_name(cls) -> QuantizationMethods:
+    def get_name(self) -> QuantizationMethods:
         return "modelopt_fp4"
 
-    @classmethod
-    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
+    def get_supported_act_dtypes(self) -> list[torch.dtype]:
         return [torch.bfloat16, torch.half, torch.float8_e4m3fn]
 
     @classmethod
     def get_min_capability(cls) -> int:
         return 80
 
-    @classmethod
-    def get_config_filenames(cls) -> list[str]:
-        return ["hf_quant_config.json"]
-
-    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
-        if self.exclude_modules is not None:
-            self.exclude_modules = hf_to_vllm_mapper.apply_list(self.exclude_modules)
-
     @classmethod
     def override_quantization_method(
         cls, hf_quant_cfg, user_quant
@@ -761,105 +859,25 @@ class ModelOptNvFp4Config(QuantizationConfig):
         return None
 
     @classmethod
-    def from_config(cls, config: dict[str, Any]) -> "ModelOptNvFp4Config":
-        # Handle both traditional ModelOpt format and compressed-tensors
-        # style format
-        if "quantization" in config:
-            # Traditional ModelOpt format:
-            # {"quantization": {"quant_algo": "..."}}
-            quant_config = cls.get_from_keys(config, ["quantization"])
-            if not isinstance(quant_config, dict):
-                raise ValueError("Expected 'quantization' to be a dictionary in config")
-
-            quant_method = quant_config.get("quant_algo", "")
-            if not quant_method:
-                raise ValueError("Missing 'quant_algo' in quantization config")
-
-            # Handle kv_cache_quant_algo with proper type validation
-            kv_cache_quant_algo_raw = quant_config.get("kv_cache_quant_algo")
-            if kv_cache_quant_algo_raw is None:
-                # No KV cache quantization by default
-                kv_cache_quant_algo = None
-            elif isinstance(kv_cache_quant_algo_raw, str):
-                kv_cache_quant_algo = kv_cache_quant_algo_raw
-            else:
-                raise ValueError(
-                    f"kv_cache_quant_algo must be a string, got "
-                    f"{type(kv_cache_quant_algo_raw)}"
-                )
-
-            # Handle group_size with proper type validation
-            group_size_raw = quant_config.get("group_size")
-            if group_size_raw is None:
-                group_size = 16  # Default value
-            elif isinstance(group_size_raw, int):
-                group_size = group_size_raw
-            else:
-                try:
-                    group_size = int(group_size_raw)
-                except (ValueError, TypeError):
-                    raise ValueError(
-                        f"group_size must be an integer, got {type(group_size_raw)}"
-                    ) from None
-
-            # "exclude_modules" is the key in the legacy hf_quant_config.json
-            exclude_modules = quant_config.get("exclude_modules", [])
-            if not isinstance(exclude_modules, list):
-                raise ValueError(
-                    f"exclude_modules must be a list, got {type(exclude_modules)}"
-                )
-        else:
-            # Compressed-tensors style format:
-            # {"quant_algo": "...", "quant_method": "modelopt"}
-            quant_method = config.get("quant_algo", "")
-
-            # Handle kv_cache_quant_algo with proper type validation
-            kv_cache_quant_algo_raw = config.get("kv_cache_quant_algo")
-            if kv_cache_quant_algo_raw is None:
-                # No KV cache quantization by default
-                kv_cache_quant_algo = None
-            elif isinstance(kv_cache_quant_algo_raw, str):
-                kv_cache_quant_algo = kv_cache_quant_algo_raw
-            else:
-                raise ValueError(
-                    f"kv_cache_quant_algo must be a string, got "
-                    f"{type(kv_cache_quant_algo_raw)}"
-                )
-
-            # Handle group_size with proper type validation
-            group_size_raw = config.get("group_size")
-            if group_size_raw is None:
-                group_size = 16  # Default value
-            elif isinstance(group_size_raw, int):
-                group_size = group_size_raw
-            else:
-                try:
-                    group_size = int(group_size_raw)
-                except (ValueError, TypeError):
-                    raise ValueError(
-                        f"group_size must be an integer, got {type(group_size_raw)}"
-                    ) from None
-
-            # "ignore" is the key in config.json
-            exclude_modules = config.get("ignore", [])
-            if not isinstance(exclude_modules, list):
-                raise ValueError(
-                    f"exclude_modules must be a list, got {type(exclude_modules)}"
-                )
-
-        if quant_method not in QUANT_ALGOS:
-            raise ValueError(
-                f"ModelOpt currently only supports: {QUANT_ALGOS} "
-                "quantizations in vLLM. Please check the "
-                "`hf_quant_config.json` file for your model's "
-                "quant configuration."
-            )
+    def _from_config(
+        cls,
+        *,
+        quant_method: str,
+        kv_cache_quant_method: str | None,
+        exclude_modules: list[str],
+        original_config: dict[str, Any],
+        group_size: int | None,
+        **kwargs: Any,
+    ) -> "ModelOptNvFp4Config":
         is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method
 
+        if group_size is None:
+            group_size = 16  # Default value
+
         # For FP4, these fields are required
-        if is_checkpoint_nvfp4_serialized and "quantization" in config:
+        if is_checkpoint_nvfp4_serialized and "quantization" in original_config:
             # Check if required fields are present in the quantization config
-            quant_config = config["quantization"]
+            quant_config = original_config["quantization"]
             required_fields = ["group_size", "kv_cache_quant_algo", "exclude_modules"]
             missing_fields = [
                 field for field in required_fields if field not in quant_config
@@ -872,64 +890,11 @@ class ModelOptNvFp4Config(QuantizationConfig):
 
         return cls(
             is_checkpoint_nvfp4_serialized,
-            kv_cache_quant_algo,
+            kv_cache_quant_method,
             exclude_modules,
             group_size,
         )
 
-    def is_layer_excluded(self, prefix: str) -> bool:
-        """
-        Check if a layer should be excluded from quantization.
-        Handles both exact matching (for fused layers) and pattern matching.
-        """
-        # First check exact matching with fused layer support
-        if is_layer_skipped(prefix, self.exclude_modules, self.packed_modules_mapping):
-            return True
-
-        # Check regex pattern matching for patterns not caught by exact match
-        import regex as re
-
-        for pattern in self.exclude_modules:
-            # Skip patterns that would be caught by exact matching
-            if "*" in pattern or "." in pattern:
-                regex_str = pattern.replace(".", r"\.").replace("*", r".*")
-                if re.fullmatch(regex_str, prefix):
-                    return True
-        return False
-
-    def get_quant_method(
-        self, layer: torch.nn.Module, prefix: str
-    ) -> Optional["QuantizeMethodBase"]:
-        from vllm.attention.layer import (  # Avoid circular import
-            Attention,
-            MLAAttention,
-        )
-
-        skip_layer = self.is_layer_excluded(prefix)
-        if isinstance(layer, LinearBase):
-            if skip_layer:
-                return UnquantizedLinearMethod()
-            # Check if this is a vision model layer that should not be quantized
-            if "vision_tower" in prefix or "vision_model" in prefix:
-                return UnquantizedLinearMethod()
-            return ModelOptNvFp4LinearMethod(self)
-        elif isinstance(layer, (Attention, MLAAttention)):
-            return ModelOptFp8KVCacheMethod(self)
-        elif isinstance(layer, FusedMoE):
-            if skip_layer:
-                return None
-            return ModelOptNvFp4FusedMoE(self, layer.moe_config, layer)
-        return None
-
-
-class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
-    """
-    Supports loading kv-cache scaling factors from FP8 checkpoints.
-    """
-
-    def __init__(self, quant_config: ModelOptFp8Config | ModelOptNvFp4Config):
-        super().__init__(quant_config)
-
 
 class ModelOptNvFp4LinearMethod(LinearMethodBase):
     """Linear method for Model Optimizer NVFP4.
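
The hand-rolled wildcard-to-regex translation removed above and the fnmatch call that replaces it in the base class agree on simple "*" patterns; a quick check, with stdlib re standing in for the regex package and a made-up module name:

import re
from fnmatch import fnmatch

pattern = "model.layers.*.proj"
prefix = "model.layers.7.proj"

# Old-style manual translation vs. stdlib wildcard matching.
regex_str = pattern.replace(".", r"\.").replace("*", r".*")
print(bool(re.fullmatch(regex_str, prefix)), fnmatch(prefix, pattern))  # True True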
@@ -1157,14 +1122,13 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
     def __init__(
         self,
         quant_config: ModelOptNvFp4Config,
-        moe: FusedMoEConfig,
-        layer: torch.nn.Module,
+        layer: FusedMoE,
     ) -> None:
         from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import (
             detect_nvfp4_moe_support,  # noqa: E501
         )
 
-        super().__init__(moe)
+        super().__init__(layer.moe_config)
         self.quant_config = quant_config
         self.layer = layer
         _nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__)
@@ -1802,3 +1766,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
             k=x.shape[1],
             e=layer.w13_weight.shape[0],
         )
+
+
+ModelOptNvFp4Config.LinearMethodCls = ModelOptNvFp4LinearMethod
+ModelOptNvFp4Config.FusedMoEMethodCls = ModelOptNvFp4FusedMoE
+ModelOptNvFp4Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod