Consolidate Nvidia ModelOpt quant config handling for all quantization methods (#28076)
Signed-off-by: Shengliang Xu <shengliangx@nvidia.com>
This commit is contained in:
parent fcbcba6c70
commit a8c536829c
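
The core of the change is a new ModelOptQuantConfigBase that owns config parsing, layer-exclusion matching, and quant-method dispatch, while each scheme (FP8, NVFP4) only overrides _from_config and wires in its method classes through the LinearMethodCls, FusedMoEMethodCls, and KVCacheMethodCls class attributes. The snippet below is a minimal, standalone sketch of that class-attribute hook pattern, not code from the diff; every name in it is hypothetical.

# Standalone sketch of the class-attribute hook pattern (hypothetical names).
class DefaultMethod:
    pass


class Fp8LikeMethod:
    pass


class ConfigBase:
    # Subclasses override this hook instead of overriding get_quant_method().
    LinearMethodCls: type = DefaultMethod

    def get_quant_method(self):
        # The base class never names a concrete method class directly.
        return self.LinearMethodCls()


class Fp8LikeConfig(ConfigBase):
    pass


# Wiring, analogous to ModelOptFp8Config.LinearMethodCls = ModelOptFp8LinearMethod
Fp8LikeConfig.LinearMethodCls = Fp8LikeMethod

assert isinstance(Fp8LikeConfig().get_quant_method(), Fp8LikeMethod)
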
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Callable
from fnmatch import fnmatch
from typing import TYPE_CHECKING, Any, Optional

import torch
@@ -13,7 +14,6 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEQuantConfig,
    RoutingMethodType,
    fp8_w8a8_moe_quant_config,
@@ -86,45 +86,218 @@ QUANT_ALGOS = ["FP8", "NVFP4"]
KV_CACHE_QUANT_ALGOS = ["FP8"]


class ModelOptFp8Config(QuantizationConfig):
class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
    """

    def __init__(self, quant_config: "ModelOptQuantConfigBase"):
        super().__init__(quant_config)


class ModelOptQuantConfigBase(QuantizationConfig):
    LinearMethodCls: type = LinearMethodBase
    FusedMoEMethodCls: type = FusedMoEMethodBase
    KVCacheMethodCls: type = BaseKVCacheMethod

    def __init__(
        self,
        exclude_modules: list[str],
    ):
        super().__init__()
        self.exclude_modules: list[str] = exclude_modules

    def is_layer_excluded(self, prefix: str) -> bool:
        """
        Check if a layer should be excluded from quantization.

        Handles both exact matching (for fused layers) and ModelOpt wildcard matching.

        The ModelOpt exclude_modules list is a list of wildcards.
        """
        if len(self.exclude_modules) == 0:
            return False

        # First check exact matching with fused layer support
        if is_layer_skipped(prefix, self.exclude_modules, self.packed_modules_mapping):
            return True

        # TODO: This special hard-coded logic is not needed for quantized checkpoints
        # generated by ModelOpt >= 0.39.0, where these cases are handled naturally by
        # the exclude_modules config, but it must be kept to load quantized
        # checkpoints generated by older versions. Check substring matching for
        # patterns not caught by the exact match above.
        for exclude_module in self.exclude_modules:
            # Skip exact matches already handled above
            if exclude_module != prefix and (
                exclude_module in prefix
                or (
                    prefix.startswith("language_model.")
                    and exclude_module in prefix.removeprefix("language_model.")
                )
            ):
                return True

        # ModelOpt exclude modules are not simple strings, they are wildcards
        for wildcard_pattern in self.exclude_modules:
            if fnmatch(prefix, wildcard_pattern):
                return True

        return False
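
Since exclude_modules entries from ModelOpt checkpoints are shell-style wildcards, the final branch above defers to fnmatch. A small standalone illustration follows; the patterns and layer prefixes are made-up examples, not taken from any real checkpoint.

from fnmatch import fnmatch

# Hypothetical exclude_modules wildcards and layer prefixes, for illustration.
exclude_modules = ["lm_head", "model.layers.*.mlp.gate"]
prefixes = [
    "lm_head",                       # exact name, also matched by fnmatch
    "model.layers.7.mlp.gate",       # matched by the wildcard pattern
    "model.layers.7.mlp.down_proj",  # not excluded, stays quantized
]

for prefix in prefixes:
    excluded = any(fnmatch(prefix, pattern) for pattern in exclude_modules)
    print(f"{prefix}: {'excluded' if excluded else 'quantized'}")
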

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        from vllm.attention.layer import Attention  # Avoid circular import

        # Handle the kv-cache first so we can focus only on weight quantization
        # thereafter.
        if isinstance(layer, Attention):
            return self.KVCacheMethodCls(self)

        # Handle exclusion.
        if self.is_layer_excluded(prefix):
            if isinstance(layer, LinearBase):
                return UnquantizedLinearMethod()
            return None

        # TODO: This special hard-coded logic is not needed for quantized checkpoints
        # generated by ModelOpt >= 0.39.0, where these cases are handled naturally by
        # the exclude_modules config, but it must be kept to load quantized
        # checkpoints generated by older versions.
        if "vision_tower" in prefix or "vision_model" in prefix:
            return UnquantizedLinearMethod()

        # Now the layer is known to be quantized; handle it here.
        if isinstance(layer, LinearBase):
            return self.LinearMethodCls(self)
        elif isinstance(layer, FusedMoE):
            return self.FusedMoEMethodCls(quant_config=self, layer=layer)

        return None

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        if len(self.exclude_modules) > 0:
            self.exclude_modules = hf_to_vllm_mapper.apply_list(self.exclude_modules)

    @staticmethod
    def get_config_filenames() -> list[str]:
        return ["hf_quant_config.json"]

    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        group_size: int | None,
    ) -> "ModelOptQuantConfigBase":
        raise NotImplementedError("Please implement this function in subclasses")

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "ModelOptQuantConfigBase":
        # Handle both the ModelOpt format and the compressed-tensors style format
        if "quantization" in config:
            # Traditional ModelOpt format:
            # {"quantization": {"quant_algo": "..."}}
            quant_config = cls.get_from_keys(config, ["quantization"])
            if not isinstance(quant_config, dict):
                raise ValueError("Expected 'quantization' to be a dictionary in config")

            quant_method = quant_config.get("quant_algo")

            # Handle kv_cache_quant_algo with proper type validation
            kv_cache_quant_method = quant_config.get("kv_cache_quant_algo")

            # Handle group_size with proper type validation
            group_size_raw = quant_config.get("group_size")

            # "exclude_modules" is the key in the legacy hf_quant_config.json
            exclude_modules = quant_config.get("exclude_modules", [])
        else:
            # Compressed-tensors style format:
            # {"quant_algo": "...", "quant_method": "modelopt"}
            quant_method = config.get("quant_algo")
            kv_cache_quant_method = config.get("kv_cache_quant_algo")
            # "ignore" is the key in config.json
            exclude_modules = config.get("ignore", [])
            group_size_raw = config.get("group_size")

        if not quant_method:
            raise ValueError("Missing 'quant_algo' in quantization config")

        if kv_cache_quant_method is None:
            # No KV cache quantization; keep this branch just to have this comment
            pass
        elif not isinstance(kv_cache_quant_method, str):
            raise ValueError(
                f"kv_cache_quant_algo must be a string, got "
                f"{type(kv_cache_quant_method)}"
            )

        if not isinstance(exclude_modules, list):
            raise ValueError(
                f"exclude_modules must be a list, got {type(exclude_modules)}"
            )

        if group_size_raw is None:
            group_size = None
        elif isinstance(group_size_raw, int):
            group_size = group_size_raw
        else:
            try:
                group_size = int(group_size_raw)
            except (ValueError, TypeError):
                raise ValueError(
                    f"group_size must be an integer, got {type(group_size_raw)}"
                ) from None

        if quant_method not in QUANT_ALGOS:
            raise ValueError(
                f"ModelOpt currently only supports: {QUANT_ALGOS} "
                "quantizations in vLLM. Please check the "
                "`hf_quant_config.json` file for your model's "
                "quant configuration."
            )
        return cls._from_config(
            quant_method=quant_method,
            kv_cache_quant_method=kv_cache_quant_method,
            exclude_modules=exclude_modules,
            group_size=group_size,
            original_config=config,
        )
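
For reference, the shared from_config above accepts either the legacy hf_quant_config.json layout or the compressed-tensors style layout embedded directly in config.json. The two shapes are sketched below with illustrative field values, not taken from any specific model.

# Legacy ModelOpt layout from hf_quant_config.json (illustrative values).
legacy_style = {
    "quantization": {
        "quant_algo": "NVFP4",
        "kv_cache_quant_algo": "FP8",
        "group_size": 16,
        "exclude_modules": ["lm_head"],
    }
}

# Compressed-tensors style layout found directly in config.json.
compressed_tensors_style = {
    "quant_method": "modelopt",
    "quant_algo": "FP8",
    "kv_cache_quant_algo": "FP8",
    "ignore": ["lm_head"],
}
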


class ModelOptFp8Config(ModelOptQuantConfigBase):
    """Config class for ModelOpt FP8."""

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        kv_cache_quant_method: str | None = None,
        exclude_modules: list[str] | None = None,
        is_checkpoint_fp8_serialized: bool,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
    ) -> None:
        super().__init__()
        super().__init__(exclude_modules)
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        self.kv_cache_quant_method = kv_cache_quant_method
        self.exclude_modules = exclude_modules or []
        if is_checkpoint_fp8_serialized:
            logger.warning(
                "Detected ModelOpt fp8 checkpoint. Please note that"
                " the format is experimental and could change."
            )

    @classmethod
    def get_name(cls) -> QuantizationMethods:
    def get_name(self) -> QuantizationMethods:
        return "modelopt"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 89

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return ["hf_quant_config.json"]

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        if self.exclude_modules is not None:
            self.exclude_modules = hf_to_vllm_mapper.apply_list(self.exclude_modules)

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
@@ -158,88 +331,19 @@ class ModelOptFp8Config(QuantizationConfig):
        return None

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "ModelOptFp8Config":
        # Handle both ModelOpt format and compressed-tensors style format
        if "quantization" in config:
            # ModelOpt format: {"quantization": {"quant_algo": "..."}}
            quant_config = cls.get_from_keys(config, ["quantization"])
            if not isinstance(quant_config, dict):
                raise ValueError("Expected 'quantization' to be a dictionary in config")
            quant_method = quant_config.get("quant_algo", "")
            if not quant_method:
                raise ValueError("Missing 'quant_algo' in quantization config")
            kv_cache_quant_method = quant_config.get("kv_cache_quant_algo")
            # "exclude_modules" is the key in the legacy hf_quant_config.json
            exclude_modules = quant_config.get("exclude_modules")
        else:
            # Compressed-tensors style format:
            # {"quant_algo": "...", "quant_method": "modelopt"}
            quant_method = config.get("quant_algo", "")
            kv_cache_quant_method = config.get("kv_cache_quant_algo")
            # "ignore" is the key in config.json
            exclude_modules = config.get("ignore")

        if quant_method not in QUANT_ALGOS:
            raise ValueError(
                f"ModelOpt currently only supports: {QUANT_ALGOS} "
                "quantizations in vLLM. Please check the "
                "`hf_quant_config.json` file for your model's "
                "quant configuration."
            )
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        **kwargs: Any,
    ) -> "ModelOptFp8Config":
        is_checkpoint_fp8_serialized = "FP8" in quant_method

        return cls(is_checkpoint_fp8_serialized, kv_cache_quant_method, exclude_modules)

    def is_layer_excluded(self, prefix: str) -> bool:
        """
        Check if a layer should be excluded from quantization.
        Handles both exact matching (for fused layers) and substring matching.

        This method handles both regular models and multimodal models that use
        the language_model prefix. For multimodal models, it checks if the
        module name (without the language_model prefix) is in the exclude list.
        """
        if self.exclude_modules is None:
            return False

        # First check exact matching with fused layer support
        if is_layer_skipped(prefix, self.exclude_modules, self.packed_modules_mapping):
            return True

        # Then check substring matching for patterns not caught by exact match
        for module in self.exclude_modules:
            # Skip exact matches already handled above
            if module != prefix and (
                module in prefix
                or (
                    prefix.startswith("language_model.")
                    and module in prefix.removeprefix("language_model.")
                )
            ):
                return True
        return False

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        from vllm.attention.layer import (  # Avoid circular import
            Attention,
            MLAAttention,
        )

        if isinstance(layer, LinearBase):
            if self.is_layer_excluded(prefix):
                return UnquantizedLinearMethod()
            # Check if this is a vision model layer that should not be quantized
            if "vision_tower" in prefix or "vision_model" in prefix:
                return UnquantizedLinearMethod()
            return ModelOptFp8LinearMethod(self)
        elif isinstance(layer, (Attention, MLAAttention)):
            return ModelOptFp8KVCacheMethod(self)
        elif isinstance(layer, FusedMoE):
            return ModelOptFp8MoEMethod(self, layer)
        return None


class ModelOptFp8LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer static quantization.
@@ -344,7 +448,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
    def __init__(
        self,
        quant_config: ModelOptFp8Config,
        layer: torch.nn.Module,
        layer: FusedMoE,
    ) -> None:
        super().__init__(layer.moe_config)
        self.layer = layer
@@ -686,7 +790,12 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
        )


class ModelOptNvFp4Config(QuantizationConfig):
ModelOptFp8Config.LinearMethodCls = ModelOptFp8LinearMethod
ModelOptFp8Config.FusedMoEMethodCls = ModelOptFp8MoEMethod
ModelOptFp8Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod
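
With this wiring in place, the FP8 config inherits get_quant_method from ModelOptQuantConfigBase and no longer needs FP8-specific dispatch. A rough sketch of the resulting behavior, assuming the module has been imported:

# Sketch of the dispatch produced by the wiring above (not part of the diff).
assert ModelOptFp8Config.LinearMethodCls is ModelOptFp8LinearMethod
assert ModelOptFp8Config.FusedMoEMethodCls is ModelOptFp8MoEMethod
assert ModelOptFp8Config.KVCacheMethodCls is ModelOptFp8KVCacheMethod

# For a non-excluded layer, the inherited get_quant_method(layer, prefix) then
# returns ModelOptFp8KVCacheMethod for Attention layers,
# ModelOptFp8LinearMethod for LinearBase layers, and
# ModelOptFp8MoEMethod(quant_config=..., layer=...) for FusedMoE layers.
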


class ModelOptNvFp4Config(ModelOptQuantConfigBase):
    """Config class for ModelOpt FP4."""

    def __init__(
@@ -696,7 +805,7 @@ class ModelOptNvFp4Config(QuantizationConfig):
        exclude_modules: list[str],
        group_size: int = 16,
    ) -> None:
        super().__init__()
        super().__init__(exclude_modules)
        self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
        if is_checkpoint_nvfp4_serialized:
            logger.warning(
@@ -706,28 +815,17 @@ class ModelOptNvFp4Config(QuantizationConfig):

        self.group_size = group_size
        self.kv_cache_quant_algo = kv_cache_quant_algo
        self.exclude_modules = exclude_modules

    @classmethod
    def get_name(cls) -> QuantizationMethods:
    def get_name(self) -> QuantizationMethods:
        return "modelopt_fp4"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half, torch.float8_e4m3fn]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return ["hf_quant_config.json"]

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        if self.exclude_modules is not None:
            self.exclude_modules = hf_to_vllm_mapper.apply_list(self.exclude_modules)

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
@@ -761,105 +859,25 @@ class ModelOptNvFp4Config(QuantizationConfig):
        return None

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "ModelOptNvFp4Config":
        # Handle both traditional ModelOpt format and compressed-tensors
        # style format
        if "quantization" in config:
            # Traditional ModelOpt format:
            # {"quantization": {"quant_algo": "..."}}
            quant_config = cls.get_from_keys(config, ["quantization"])
            if not isinstance(quant_config, dict):
                raise ValueError("Expected 'quantization' to be a dictionary in config")

            quant_method = quant_config.get("quant_algo", "")
            if not quant_method:
                raise ValueError("Missing 'quant_algo' in quantization config")

            # Handle kv_cache_quant_algo with proper type validation
            kv_cache_quant_algo_raw = quant_config.get("kv_cache_quant_algo")
            if kv_cache_quant_algo_raw is None:
                # No KV cache quantization by default
                kv_cache_quant_algo = None
            elif isinstance(kv_cache_quant_algo_raw, str):
                kv_cache_quant_algo = kv_cache_quant_algo_raw
            else:
                raise ValueError(
                    f"kv_cache_quant_algo must be a string, got "
                    f"{type(kv_cache_quant_algo_raw)}"
                )

            # Handle group_size with proper type validation
            group_size_raw = quant_config.get("group_size")
            if group_size_raw is None:
                group_size = 16  # Default value
            elif isinstance(group_size_raw, int):
                group_size = group_size_raw
            else:
                try:
                    group_size = int(group_size_raw)
                except (ValueError, TypeError):
                    raise ValueError(
                        f"group_size must be an integer, got {type(group_size_raw)}"
                    ) from None

            # "exclude_modules" is the key in the legacy hf_quant_config.json
            exclude_modules = quant_config.get("exclude_modules", [])
            if not isinstance(exclude_modules, list):
                raise ValueError(
                    f"exclude_modules must be a list, got {type(exclude_modules)}"
                )
        else:
            # Compressed-tensors style format:
            # {"quant_algo": "...", "quant_method": "modelopt"}
            quant_method = config.get("quant_algo", "")

            # Handle kv_cache_quant_algo with proper type validation
            kv_cache_quant_algo_raw = config.get("kv_cache_quant_algo")
            if kv_cache_quant_algo_raw is None:
                # No KV cache quantization by default
                kv_cache_quant_algo = None
            elif isinstance(kv_cache_quant_algo_raw, str):
                kv_cache_quant_algo = kv_cache_quant_algo_raw
            else:
                raise ValueError(
                    f"kv_cache_quant_algo must be a string, got "
                    f"{type(kv_cache_quant_algo_raw)}"
                )

            # Handle group_size with proper type validation
            group_size_raw = config.get("group_size")
            if group_size_raw is None:
                group_size = 16  # Default value
            elif isinstance(group_size_raw, int):
                group_size = group_size_raw
            else:
                try:
                    group_size = int(group_size_raw)
                except (ValueError, TypeError):
                    raise ValueError(
                        f"group_size must be an integer, got {type(group_size_raw)}"
                    ) from None

            # "ignore" is the key in config.json
            exclude_modules = config.get("ignore", [])
            if not isinstance(exclude_modules, list):
                raise ValueError(
                    f"exclude_modules must be a list, got {type(exclude_modules)}"
                )

        if quant_method not in QUANT_ALGOS:
            raise ValueError(
                f"ModelOpt currently only supports: {QUANT_ALGOS} "
                "quantizations in vLLM. Please check the "
                "`hf_quant_config.json` file for your model's "
                "quant configuration."
            )
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        group_size: int | None,
        **kwargs: Any,
    ) -> "ModelOptNvFp4Config":
        is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method

        if group_size is None:
            group_size = 16  # Default value

        # For FP4, these fields are required
        if is_checkpoint_nvfp4_serialized and "quantization" in config:
        if is_checkpoint_nvfp4_serialized and "quantization" in original_config:
            # Check if required fields are present in the quantization config
            quant_config = config["quantization"]
            quant_config = original_config["quantization"]
            required_fields = ["group_size", "kv_cache_quant_algo", "exclude_modules"]
            missing_fields = [
                field for field in required_fields if field not in quant_config
@@ -872,64 +890,11 @@ class ModelOptNvFp4Config(QuantizationConfig):

        return cls(
            is_checkpoint_nvfp4_serialized,
            kv_cache_quant_algo,
            kv_cache_quant_method,
            exclude_modules,
            group_size,
        )

    def is_layer_excluded(self, prefix: str) -> bool:
        """
        Check if a layer should be excluded from quantization.
        Handles both exact matching (for fused layers) and pattern matching.
        """
        # First check exact matching with fused layer support
        if is_layer_skipped(prefix, self.exclude_modules, self.packed_modules_mapping):
            return True

        # Check regex pattern matching for patterns not caught by exact match
        import regex as re

        for pattern in self.exclude_modules:
            # Skip patterns that would be caught by exact matching
            if "*" in pattern or "." in pattern:
                regex_str = pattern.replace(".", r"\.").replace("*", r".*")
                if re.fullmatch(regex_str, prefix):
                    return True
        return False

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        from vllm.attention.layer import (  # Avoid circular import
            Attention,
            MLAAttention,
        )

        skip_layer = self.is_layer_excluded(prefix)
        if isinstance(layer, LinearBase):
            if skip_layer:
                return UnquantizedLinearMethod()
            # Check if this is a vision model layer that should not be quantized
            if "vision_tower" in prefix or "vision_model" in prefix:
                return UnquantizedLinearMethod()
            return ModelOptNvFp4LinearMethod(self)
        elif isinstance(layer, (Attention, MLAAttention)):
            return ModelOptFp8KVCacheMethod(self)
        elif isinstance(layer, FusedMoE):
            if skip_layer:
                return None
            return ModelOptNvFp4FusedMoE(self, layer.moe_config, layer)
        return None


class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
    """

    def __init__(self, quant_config: ModelOptFp8Config | ModelOptNvFp4Config):
        super().__init__(quant_config)


class ModelOptNvFp4LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer NVFP4.
@@ -1157,14 +1122,13 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
    def __init__(
        self,
        quant_config: ModelOptNvFp4Config,
        moe: FusedMoEConfig,
        layer: torch.nn.Module,
        layer: FusedMoE,
    ) -> None:
        from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import (
            detect_nvfp4_moe_support,  # noqa: E501
        )

        super().__init__(moe)
        super().__init__(layer.moe_config)
        self.quant_config = quant_config
        self.layer = layer
        _nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__)
@@ -1802,3 +1766,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
            k=x.shape[1],
            e=layer.w13_weight.shape[0],
        )


ModelOptNvFp4Config.LinearMethodCls = ModelOptNvFp4LinearMethod
ModelOptNvFp4Config.FusedMoEMethodCls = ModelOptNvFp4FusedMoE
ModelOptNvFp4Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod