[Quant] Add SupportsQuant to phi3 and clip (#13104)

This commit is contained in:
Kyle Sayers 2025-02-15 22:28:33 -05:00 committed by GitHub
parent 80f63a3966
commit 12913d17ba
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
25 changed files with 67 additions and 13 deletions

View File

@ -169,6 +169,7 @@ class AQLMConfig(QuantizationConfig):
num_codebooks: int,
out_group_size: int,
) -> None:
super().__init__()
self.in_group_size = in_group_size
self.nbits_per_codebook = nbits_per_codebook
self.num_codebooks = num_codebooks

View File

@ -26,6 +26,7 @@ class AWQConfig(QuantizationConfig):
zero_point: bool,
modules_to_not_convert: Optional[List[str]] = None,
) -> None:
super().__init__()
self.weight_bits = weight_bits
self.group_size = group_size
self.zero_point = zero_point

View File

@ -47,6 +47,7 @@ class AWQMarlinConfig(QuantizationConfig):
lm_head_quantized: bool,
modules_to_not_convert: Optional[List[str]],
full_config: Dict[str, Any]) -> None:
super().__init__()
self.pack_factor = 32 // weight_bits # packed into int32
self.group_size = group_size
self.zero_point = zero_point

View File

@ -2,7 +2,7 @@
import inspect
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Mapping, Optional, Type
from typing import Any, Dict, List, Optional, Type
import torch
from torch import nn
@ -59,7 +59,11 @@ def method_has_implemented_embedding(
class QuantizationConfig(ABC):
"""Base class for quantization configs."""
packed_modules_mapping: Mapping[str, List[str]] = dict()
def __init__(self):
super().__init__()
# mapping is updated by models as they initialize
self.packed_modules_mapping: Dict[str, List[str]] = dict()
@abstractmethod
def get_name(self) -> str:

View File

@ -30,7 +30,7 @@ class BitsAndBytesConfig(QuantizationConfig):
llm_int8_skip_modules: Optional[List[str]] = None,
llm_int8_threshold: float = 6.0,
) -> None:
super().__init__()
self.load_in_8bit = load_in_8bit
self.load_in_4bit = load_in_4bit
self.bnb_4bit_compute_dtype = bnb_4bit_compute_dtype

View File

@ -51,7 +51,7 @@ class CompressedTensorsConfig(QuantizationConfig):
kv_cache_scheme: Optional[Dict[str, Any]] = None,
config: Optional[Dict[str, Any]] = None,
):
super().__init__()
self.ignore = ignore
self.quant_format = quant_format
# Map from [target -> scheme]

View File

@ -25,6 +25,7 @@ class DeepSpeedFPConfig(QuantizationConfig):
weight_bits: int = 8,
group_size: int = 512,
) -> None:
super().__init__()
self.weight_bits = weight_bits
self.group_size = group_size
self.valid_types = [torch.bfloat16, torch.float16]

View File

@ -17,7 +17,7 @@ class ExpertsInt8Config(QuantizationConfig):
"""Config class for Int8 experts quantization."""
def __init__(self) -> None:
pass
super().__init__()
@classmethod
def get_name(cls) -> str:

View File

@ -29,6 +29,7 @@ class FBGEMMFp8Config(QuantizationConfig):
"""Config class for FBGEMM Fp8."""
def __init__(self, ignore_list: List[str], input_scale_ub: float):
super().__init__()
self.ignore_list = ignore_list if ignore_list else []
self.input_scale_ub = input_scale_ub

View File

@ -47,6 +47,7 @@ class Fp8Config(QuantizationConfig):
ignored_layers: Optional[List[str]] = None,
weight_block_size: Optional[List[int]] = None,
) -> None:
super().__init__()
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
if is_checkpoint_fp8_serialized:
logger.warning("Detected fp8 checkpoint. Please note that the "

View File

@ -20,7 +20,7 @@ class GGUFConfig(QuantizationConfig):
"""Config class for GGUF."""
def __init__(self, ) -> None:
pass
super().__init__()
def __repr__(self) -> str:
return ("GGUFConfig()")

View File

@ -58,6 +58,7 @@ class GPTQConfig(QuantizationConfig):
# r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,},
# r"-:.*\.moe\..*": {}, # negative match (skip) all `moe` layers
# }
super().__init__()
self.dynamic = dynamic
self.weight_bits = weight_bits

View File

@ -46,6 +46,7 @@ class GPTQMarlinConfig(QuantizationConfig):
is_sym: bool, lm_head_quantized: bool,
dynamic: Dict[str, Dict[str, Union[int, bool]]],
full_config: Dict[str, Any]) -> None:
super().__init__()
if desc_act and group_size == -1:
# In this case, act_order == True is the same as act_order == False
# (since we have only one group per output channel)

View File

@ -38,6 +38,7 @@ class GPTQMarlin24Config(QuantizationConfig):
weight_bits: int,
group_size: int,
) -> None:
super().__init__()
quant_type = {
4: scalar_types.uint4b8,
8: scalar_types.uint8b128,

View File

@ -33,6 +33,7 @@ class HQQMarlinConfig(QuantizationConfig):
group_size: int,
skip_modules: Optional[List[str]] = None,
) -> None:
super().__init__()
assert group_size == 64, ("The only supported HQQ group size is "
"currently 64.")
assert weight_bits == 4, ("The only supported HQQ quantization "

View File

@ -35,6 +35,7 @@ class IPEXConfig(QuantizationConfig):
desc_act: Optional[bool] = None,
lm_head_quantized: Optional[bool] = None,
) -> None:
super().__init__()
self.method = method
self.weight_bits = weight_bits
self.group_size = group_size

View File

@ -28,6 +28,7 @@ class ModelOptFp8Config(QuantizationConfig):
self,
is_checkpoint_fp8_serialized: bool = False,
) -> None:
super().__init__()
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
if is_checkpoint_fp8_serialized:
logger.warning("Detected ModelOpt fp8 checkpoint. Please note that"

View File

@ -24,6 +24,7 @@ class MoeWNA16Config(QuantizationConfig):
group_size: int, has_zp: bool, lm_head_quantized: bool,
modules_to_not_convert: Optional[List[str]],
full_config: Dict[str, Any]) -> None:
super().__init__()
self.weight_bits = weight_bits
self.group_size = group_size
self.has_zp = has_zp

View File

@ -20,6 +20,7 @@ class NeuronQuantConfig(QuantizationConfig):
dequant_dtype: str = "f16",
quantize_method: str = "vector_dynamic",
) -> None:
super().__init__()
self.quant_dtype = os.getenv("NEURON_QUANT_DTYPE", "s8")
if self.quant_dtype not in SUPPORTED_QUANT_DTYPE_LIST:
raise ValueError(

View File

@ -39,6 +39,7 @@ class QQQConfig(QuantizationConfig):
group_size: int,
is_sym: bool = True,
) -> None:
super().__init__()
self.weight_bits = weight_bits
self.group_size = group_size
self.is_sym = is_sym

View File

@ -30,6 +30,7 @@ class QuarkConfig(QuantizationConfig):
kv_cache_group: Optional[List[str]] = None,
kv_cache_config: Optional[Dict[str, Any]] = None,
pack_method: str = "reorder"):
super().__init__()
if kv_cache_group is None:
kv_cache_group = []
self.quant_config = quant_config

View File

@ -21,6 +21,7 @@ class Int8TpuConfig(QuantizationConfig):
self,
activation_scheme: str = "none",
) -> None:
super().__init__()
if activation_scheme not in ACTIVATION_SCHEMES:
raise ValueError(
f"Unsupported activation scheme {activation_scheme}")

View File

@ -15,6 +15,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsQuant
from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
@ -335,10 +336,10 @@ class CLIPVisionTransformer(nn.Module):
return encoder_outputs
class CLIPVisionModel(nn.Module):
class CLIPVisionModel(nn.Module, SupportsQuant):
config_class = CLIPVisionConfig
main_input_name = "pixel_values"
packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
def __init__(
self,

View File

@ -7,6 +7,8 @@ import torch
from typing_extensions import TypeIs, TypeVar
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.utils import supports_kw
from .interfaces_base import is_pooling_model
@ -443,6 +445,36 @@ def supports_cross_encoding(
return is_pooling_model(model) and _supports_cross_encoding(model)
class SupportsQuant:
"""The interface required for all models that support quantization."""
packed_modules_mapping: ClassVar[Dict[str, List[str]]] = {}
quant_config: Optional[QuantizationConfig] = None
def __new__(cls, *args, **kwargs) -> "SupportsQuant":
instance = super().__new__(cls)
quant_config = cls._find_quant_config(*args, **kwargs)
if quant_config is not None:
instance.quant_config = quant_config
instance.quant_config.packed_modules_mapping.update(
cls.packed_modules_mapping)
return instance
@staticmethod
def _find_quant_config(*args, **kwargs) -> Optional[QuantizationConfig]:
from vllm.config import VllmConfig # avoid circular import
args_values = list(args) + list(kwargs.values())
for arg in args_values:
if isinstance(arg, VllmConfig):
return arg.quant_config
if isinstance(arg, QuantizationConfig):
return arg
return None
@runtime_checkable
class SupportsTranscription(Protocol):
"""The interface required for all models that support transcription."""

View File

@ -50,7 +50,7 @@ from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
from .clip import CLIPVisionModel
from .interfaces import SupportsMultiModal, SupportsPP
from .interfaces import SupportsMultiModal, SupportsPP, SupportsQuant
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
@ -498,7 +498,8 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
@MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor,
info=Phi3VProcessingInfo,
dummy_inputs=Phi3VDummyInputsBuilder)
class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
SupportsQuant):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"model.vision_embed_tokens.wte": "embed_tokens",
@ -510,7 +511,6 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.multimodal_config = multimodal_config
@ -520,14 +520,14 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
config.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
quant_config=quant_config,
quant_config=self.quant_config,
prefix=maybe_prefix(prefix, "model.embed_tokens"),
)
# TODO: Optionally initializes this for supporting input embeddings.
self.vision_embed_tokens = Phi3HDImageEmbedding(
config,
quant_config,
self.quant_config,
prefix=maybe_prefix(prefix, "model.vision_embed_tokens"))
self.language_model = init_vllm_registered_model(