[Quant] [Bugfix] Fix quantization config matching with hf_to_vllm_mapper (#20046)

Kyle Sayers 2025-07-01 06:20:34 -04:00 committed by GitHub
parent c05596f1a3
commit 9025a9a705
17 changed files with 107 additions and 29 deletions

View File

@ -53,6 +53,7 @@ class CustomQuantConfig(QuantizationConfig):
def __init__(self, num_bits: int = 8) -> None:
"""Initialize the quantization config."""
super().__init__()
self.num_bits = num_bits
def get_name(self) -> QuantizationMethods:

View File

@ -805,7 +805,7 @@ def create_lora_manager(
lora_manager_cls: type[LoRAModelManager] = LoRAModelManager,
**kwargs) -> LoRAModelManager:
"""Create a LoRA adapter for a given model."""
- if not hasattr(model, "packed_modules_mapping"):
+ if not isinstance(model, SupportsLoRA):
raise ValueError(f"Model {type(model)} is not supported for LoRA.")
lora_manager = lora_manager_cls(
model=model,

View File

@ -111,10 +111,7 @@ class WorkerLoRAManager(AbstractWorkerManager):
# For some models like Qwen2VL, we need to use hf_to_vllm_mapper
# to ensure correct loading of lora weights.
model = self._adapter_manager.model
- hf_to_vllm_mapper = None
- if (hasattr(model, "hf_to_vllm_mapper")
-         and model.hf_to_vllm_mapper is not None):
-     hf_to_vllm_mapper = model.hf_to_vllm_mapper
+ hf_to_vllm_mapper = getattr(model, "hf_to_vllm_mapper", None)
lora = self._lora_model_cls.from_local_checkpoint(
lora_path,

View File

@ -10,6 +10,7 @@ from torch import nn
if TYPE_CHECKING:
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.models.utils import WeightsMapper
else:
QuantizationMethods = str
@ -149,3 +150,15 @@ class QuantizationConfig(ABC):
def get_cache_scale(self, name: str) -> Optional[str]:
return None
def apply_vllm_mapper( # noqa: B027
self, hf_to_vllm_mapper: "WeightsMapper"):
"""
Interface for models to update module names referenced in
quantization configs in order to reflect the vllm model structure
:param hf_to_vllm_mapper: maps from hf model structure (the assumed
structure of the qconfig) to vllm model structure
"""
# TODO (@kylesayrs): add implementations for all subclasses
pass
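For orientation, here is a minimal sketch of how a config subclass is expected to use this hook. The class, its `ignored_layers` field, and the import path are illustrative assumptions; only `apply_vllm_mapper` and `WeightsMapper.apply_list` come from this change.

```python
# Sketch only: a hypothetical config that stores HF-style module names and
# rewrites them when the model supplies an hf_to_vllm_mapper.
from typing import TYPE_CHECKING, Optional

# import path assumed; QuantizationConfig is the base class shown above
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)

if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper


class ExampleQuantConfig(QuantizationConfig):
    # The remaining abstract methods of QuantizationConfig are omitted here.

    def __init__(self, ignored_layers: Optional[list[str]] = None) -> None:
        super().__init__()
        self.ignored_layers = ignored_layers or []

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        # Rewrite stored HF module names so that later matching against the
        # vLLM module tree succeeds.
        self.ignored_layers = hf_to_vllm_mapper.apply_list(self.ignored_layers)
```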

View File

@ -63,6 +63,7 @@ class BitBLASConfig(QuantizationConfig):
# (since we have only one group per output channel)
desc_act = False
super().__init__()
self.weight_bits = weight_bits
self.group_size = group_size
self.desc_act = desc_act

View File

@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import suppress
- from typing import Any, Literal, Optional, cast
+ from typing import TYPE_CHECKING, Any, Literal, Optional, cast
import torch
from compressed_tensors.config import (CompressionFormat,
@ -37,6 +37,9 @@ from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import
cutlass_fp4_supported)
from vllm.platforms import current_platform
if TYPE_CHECKING:
from vllm.model_executor.models.utils import WeightsMapper
logger = init_logger(__name__)
__all__ = ["CompressedTensorsLinearMethod"]
@ -80,6 +83,18 @@ class CompressedTensorsConfig(QuantizationConfig):
def get_name(self) -> QuantizationMethods:
return "compressed-tensors"
def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
self.target_scheme_map = hf_to_vllm_mapper.apply_dict(
self.target_scheme_map)
self.ignore = hf_to_vllm_mapper.apply_list(self.ignore)
self.sparsity_scheme_map = hf_to_vllm_mapper.apply_dict(
self.sparsity_scheme_map)
self.sparsity_ignore_list = hf_to_vllm_mapper.apply_list(
self.sparsity_ignore_list)
if self.kv_cache_scheme is not None:
self.kv_cache_scheme = hf_to_vllm_mapper.apply_dict(
self.kv_cache_scheme)
def get_quant_method(
self,
layer: torch.nn.Module,
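To make the remapping concrete, here is a hedged, self-contained illustration. The module names, scheme values, and the mapper below are invented for the example and assume `WeightsMapper`'s existing `orig_to_new_prefix` field; they are not taken from a real checkpoint.

```python
# Illustration only: hypothetical names showing how apply_dict/apply_list
# rewrite the keys of a scheme map and the entries of an ignore list.
from vllm.model_executor.models.utils import WeightsMapper

mapper = WeightsMapper(orig_to_new_prefix={"model.": "language_model.model."})

target_scheme_map = {"model.layers.0.mlp.down_proj": {"weights": "int8"}}
ignore = ["model.layers.0.self_attn.q_proj", "lm_head"]

print(mapper.apply_dict(target_scheme_map))
# {'language_model.model.layers.0.mlp.down_proj': {'weights': 'int8'}}
print(mapper.apply_list(ignore))
# ['language_model.model.layers.0.self_attn.q_proj', 'lm_head']
```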

View File

@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
- from typing import Any, Callable, Optional, Union
+ from typing import TYPE_CHECKING, Any, Callable, Optional, Union
import torch
import torch.nn.functional as F
@ -39,6 +39,9 @@ from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.utils import has_deep_gemm
if TYPE_CHECKING:
from vllm.model_executor.models.utils import WeightsMapper
ACTIVATION_SCHEMES = ["static", "dynamic"]
logger = init_logger(__name__)
@ -100,6 +103,11 @@ class Fp8Config(QuantizationConfig):
def get_config_filenames(cls) -> list[str]:
return []
def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
if self.ignored_layers is not None:
self.ignored_layers = hf_to_vllm_mapper.apply_list(
self.ignored_layers)
@classmethod
def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
quant_method = cls.get_from_keys(config, ["quant_method"])

View File

@ -81,6 +81,7 @@ class GPTQBitBLASConfig(QuantizationConfig):
# (since we have only one group per output channel)
desc_act = False
super().__init__()
self.weight_bits = weight_bits
self.group_size = group_size
self.desc_act = desc_act

View File

@ -32,6 +32,8 @@ class MarlinConfig(QuantizationConfig):
group_size: int,
lm_head_quantized: bool,
) -> None:
super().__init__()
# Group size for the quantization.
self.group_size = group_size
self.lm_head_quantized = lm_head_quantized

View File

@ -181,6 +181,7 @@ class ModelOptNvFp4Config(QuantizationConfig):
exclude_modules: list[str],
group_size: int = 16,
) -> None:
super().__init__()
self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
if is_checkpoint_nvfp4_serialized:
logger.warning(

View File

@ -55,6 +55,7 @@ class TorchAOConfig(QuantizationConfig):
os.environ["VLLM_DISABLE_COMPILE_CACHE"] = "1"
logger.info("Using TorchAO: Setting VLLM_DISABLE_COMPILE_CACHE=1")
"""
super().__init__()
self.torchao_config = torchao_config
self.skip_modules = skip_modules or []

View File

@ -24,6 +24,7 @@ from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.adapters import (as_classification_model,
as_embedding_model,
as_reward_model)
from vllm.model_executor.models.interfaces import SupportsQuant
from vllm.utils import is_pin_memory_available
logger = init_logger(__name__)
@ -294,13 +295,16 @@ def configure_quant_config(quant_config: QuantizationConfig,
Note that model attributes are passed by reference to quant_config,
enabling them to be updated by model_class.__new__ (e.g. chatglm, qwen).
Once the `SupportsQuant` mixin has been added to all models, this
function can be removed.
"""
- packed_mapping = getattr(model_class, "packed_modules_mapping", None)
- if packed_mapping is not None:
-     # pass packed_modules_mapping by reference to quant_config
-     quant_config.packed_modules_mapping = packed_mapping
- else:
-     logger.warning(
-         "The model class %s has not defined `packed_modules_mapping`, "
-         "this may lead to incorrect mapping of quantized or ignored "
-         "modules", model_class.__name__)
+ if not issubclass(model_class, SupportsQuant):
+     hf_to_vllm_mapper = getattr(model_class, "hf_to_vllm_mapper", None)
+     packed_mapping = getattr(model_class, "packed_modules_mapping", None)
+     # pass mappings by reference to quant_config
+     if hf_to_vllm_mapper is not None:
+         quant_config.apply_vllm_mapper(hf_to_vllm_mapper)
+     if packed_mapping is not None:
+         quant_config.packed_modules_mapping = packed_mapping
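Roughly, for a model class that has not yet adopted `SupportsQuant`, the fallback path above now behaves as sketched below. `LegacyModel` and its attribute values are hypothetical; the two `getattr` lookups and the two config hooks mirror the code in this hunk.

```python
# Sketch of the fallback path for models without the SupportsQuant mixin.
from torch import nn

from vllm.model_executor.models.utils import WeightsMapper


class LegacyModel(nn.Module):
    # class attributes that configure_quant_config looks up via getattr
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={"model.": "language_model.model."})
    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}


# configure_quant_config(quant_config, LegacyModel) would then:
#   1. call quant_config.apply_vllm_mapper(LegacyModel.hf_to_vllm_mapper),
#      so module names stored in the config use vLLM naming, and
#   2. assign LegacyModel.packed_modules_mapping to the config by reference,
#      so fused-layer membership is visible to the quantization method.
```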

View File

@ -18,6 +18,7 @@ from .interfaces_base import is_pooling_model
if TYPE_CHECKING:
from vllm.attention import AttentionMetadata
from vllm.model_executor.models.utils import WeightsMapper
from vllm.sequence import IntermediateTensors
logger = init_logger(__name__)
@ -566,20 +567,36 @@ def has_step_pooler(model: Union[type[object], object]) -> bool:
class SupportsQuant:
"""The interface required for all models that support quantization."""
- packed_modules_mapping: ClassVar[dict[str, list[str]]] = {}
+ hf_to_vllm_mapper: ClassVar[Optional["WeightsMapper"]] = None
+ packed_modules_mapping: ClassVar[Optional[dict[str, list[str]]]] = None
quant_config: Optional[QuantizationConfig] = None
def __new__(cls, *args, **kwargs) -> Self:
instance = super().__new__(cls)
# find config passed in arguments
quant_config = cls._find_quant_config(*args, **kwargs)
if quant_config is not None:
# attach config to model for general use
instance.quant_config = quant_config
- instance.quant_config.packed_modules_mapping.update(
-     cls.packed_modules_mapping)
+ # apply model mappings to config for proper config-model matching
+ # NOTE: `TransformersForCausalLM` is not supported due to how this
+ # class defines `hf_to_vllm_mapper` as a post-init `@property`.
+ # After this is fixed, get `instance.hf_to_vllm_mapper` directly
+ if getattr(instance, "hf_to_vllm_mapper", None) is not None:
+     instance.quant_config.apply_vllm_mapper(
+         instance.hf_to_vllm_mapper)
+ if getattr(instance, "packed_modules_mapping", None) is not None:
+     instance.quant_config.packed_modules_mapping.update(
+         instance.packed_modules_mapping)
return instance
@staticmethod
def _find_quant_config(*args, **kwargs) -> Optional[QuantizationConfig]:
"""Find quant config passed through model constructor args"""
from vllm.config import VllmConfig # avoid circular import
args_values = list(args) + list(kwargs.values())
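A hedged sketch of how a model opts in via the mixin. The class name and mapper contents are hypothetical; the two class attributes are exactly what `__new__` consumes.

```python
# Sketch only: declaring the mappings as class attributes lets
# SupportsQuant.__new__ remap and extend the quant config before __init__ runs.
from torch import nn

from vllm.model_executor.models.interfaces import SupportsQuant
from vllm.model_executor.models.utils import WeightsMapper


class MyVLModel(nn.Module, SupportsQuant):
    # Maps HF checkpoint names onto this model's vLLM module tree
    # (constructor field assumed to be orig_to_new_prefix).
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={"model.": "language_model.model."})
    # Fused projections that replace several HF linear layers.
    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

    # When MyVLModel(vllm_config=..., prefix=...) is constructed,
    # SupportsQuant.__new__ finds the QuantizationConfig among the constructor
    # arguments (e.g. inside the VllmConfig), attaches it as self.quant_config,
    # applies hf_to_vllm_mapper to it, and merges packed_modules_mapping into
    # it, all before __init__ runs.
```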

View File

@ -61,7 +61,7 @@ from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.config import uses_mrope
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
- SupportsMultiModal, SupportsPP)
+ SupportsMultiModal, SupportsPP, SupportsQuant)
from .qwen2_vl import Qwen2VLDummyInputsBuilder as Qwen2_5_VLDummyInputsBuilder
from .qwen2_vl import (Qwen2VLMultiModalProcessor, Qwen2VLProcessingInfo,
apply_rotary_pos_emb_vision)
@ -821,7 +821,8 @@ class Qwen2_5_VLMultiModalProcessor(Qwen2VLMultiModalProcessor):
info=Qwen2_5_VLProcessingInfo,
dummy_inputs=Qwen2_5_VLDummyInputsBuilder)
class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
- SupportsLoRA, SupportsPP):
+ SupportsLoRA, SupportsPP,
+ SupportsQuant):
# To ensure correct weight loading and mapping.
hf_to_vllm_mapper = WeightsMapper(
@ -837,7 +838,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config
- quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
@ -846,7 +846,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self.visual = Qwen2_5_VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
- quant_config=self._maybe_ignore_quant_config(quant_config),
+ quant_config=self._maybe_ignore_quant_config(self.quant_config),
prefix=maybe_prefix(prefix, "visual"),
)
@ -859,12 +859,12 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
- def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
+ def _maybe_ignore_quant_config(self, config: Optional[QuantizationConfig]):
# GPTQ configs do not have a list of ignored modules; however, AutoGPTQ
# seems to avoid vision encoder sections for some models.
- if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
+ if isinstance(config, (GPTQConfig, GPTQMarlinConfig)):
return None
- return quant_config
+ return config
def _validate_and_reshape_mm_tensor(self, mm_input: object,
name: str) -> torch.Tensor:

View File

@ -467,6 +467,7 @@ class TransformersForCausalLM(nn.Module, SupportsQuant, SupportsLoRA,
# FIXME(Isotr0py): Don't use any weights mapper for Transformers backend,
# this makes things complicated. We need to remove this mapper after
# refactoring `TransformersModel` in the future.
# NOTE: `SupportsQuant` can be updated after property decorator is removed
@property
def hf_to_vllm_mapper(self):
prefix_mapper = {

View File

@ -4,7 +4,7 @@
import itertools
from collections.abc import Iterable, Mapping
from dataclasses import dataclass, field
- from typing import Callable, Literal, Optional, Protocol, Union, overload
+ from typing import Any, Callable, Literal, Optional, Protocol, Union, overload
import torch
import torch.nn as nn
@ -64,6 +64,19 @@ class WeightsMapper:
return ((out_name, data) for name, data in weights
if (out_name := self._map_name(name)) is not None)
def apply_list(self, values: list[str]) -> list[str]:
return [
out_name for name in values
if (out_name := self._map_name(name)) is not None
]
def apply_dict(self, values: dict[str, Any]) -> dict[str, Any]:
return {
out_name: value
for name, value in values.items()
if (out_name := self._map_name(name)) is not None
}
class AutoWeightsLoader:
"""

View File

@ -58,7 +58,8 @@ def _make_synced_weight_loader(original_weight_loader):
def get_packed_modules_mapping(model: torch.nn.Module) -> dict[str, list[str]]:
- parent_map = copy.deepcopy(getattr(model, "packed_modules_mapping", {}))
+ parent_map = getattr(model, "packed_modules_mapping", None)
+ parent_map = copy.deepcopy(parent_map) if parent_map is not None else {}
# don't infer mapping if the model has defined it explicitly.
if parent_map:
@ -66,7 +67,9 @@ def get_packed_modules_mapping(model: torch.nn.Module) -> dict[str, list[str]]:
# We only check main components instead of whole model submodules
for child in model.children():
- child_map = getattr(child, "packed_modules_mapping", {})
+ child_map = getattr(child, "packed_modules_mapping", None)
+ child_map = copy.deepcopy(child_map) if child_map is not None else {}
if any((k in parent_map and parent_map[k] != v)
for k, v in child_map.items()):
raise ValueError(