From 12913d17bab18d21e01962c1ad729ad7440d4c01 Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Sat, 15 Feb 2025 22:28:33 -0500
Subject: [PATCH] [Quant] Add `SupportsQuant` to phi3 and clip (#13104)

---
 .../layers/quantization/aqlm.py               |  1 +
 .../model_executor/layers/quantization/awq.py |  1 +
 .../layers/quantization/awq_marlin.py         |  1 +
 .../layers/quantization/base_config.py        |  8 +++--
 .../layers/quantization/bitsandbytes.py       |  2 +-
 .../compressed_tensors/compressed_tensors.py  |  2 +-
 .../layers/quantization/deepspeedfp.py        |  1 +
 .../layers/quantization/experts_int8.py       |  2 +-
 .../layers/quantization/fbgemm_fp8.py         |  1 +
 .../model_executor/layers/quantization/fp8.py |  1 +
 .../layers/quantization/gguf.py               |  2 +-
 .../layers/quantization/gptq.py               |  1 +
 .../layers/quantization/gptq_marlin.py        |  1 +
 .../layers/quantization/gptq_marlin_24.py     |  1 +
 .../layers/quantization/hqq_marlin.py         |  1 +
 .../layers/quantization/ipex_quant.py         |  1 +
 .../layers/quantization/modelopt.py           |  1 +
 .../layers/quantization/moe_wna16.py          |  1 +
 .../layers/quantization/neuron_quant.py       |  1 +
 .../model_executor/layers/quantization/qqq.py |  1 +
 .../layers/quantization/quark/quark.py        |  1 +
 .../layers/quantization/tpu_int8.py           |  1 +
 vllm/model_executor/models/clip.py            |  5 +--
 vllm/model_executor/models/interfaces.py      | 32 +++++++++++++++++++
 vllm/model_executor/models/phi3v.py           | 10 +++---
 25 files changed, 67 insertions(+), 13 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/aqlm.py b/vllm/model_executor/layers/quantization/aqlm.py
index 6c08d016c0f7b..10f5241f9a717 100644
--- a/vllm/model_executor/layers/quantization/aqlm.py
+++ b/vllm/model_executor/layers/quantization/aqlm.py
@@ -169,6 +169,7 @@ class AQLMConfig(QuantizationConfig):
         num_codebooks: int,
         out_group_size: int,
     ) -> None:
+        super().__init__()
         self.in_group_size = in_group_size
         self.nbits_per_codebook = nbits_per_codebook
         self.num_codebooks = num_codebooks
diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py
index ff77af44d7707..227be1497d0ec 100644
--- a/vllm/model_executor/layers/quantization/awq.py
+++ b/vllm/model_executor/layers/quantization/awq.py
@@ -26,6 +26,7 @@ class AWQConfig(QuantizationConfig):
         zero_point: bool,
         modules_to_not_convert: Optional[List[str]] = None,
     ) -> None:
+        super().__init__()
         self.weight_bits = weight_bits
         self.group_size = group_size
         self.zero_point = zero_point
diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py
index de4009d7d04ac..111b3f74d50e0 100644
--- a/vllm/model_executor/layers/quantization/awq_marlin.py
+++ b/vllm/model_executor/layers/quantization/awq_marlin.py
@@ -47,6 +47,7 @@ class AWQMarlinConfig(QuantizationConfig):
                  lm_head_quantized: bool,
                  modules_to_not_convert: Optional[List[str]],
                  full_config: Dict[str, Any]) -> None:
+        super().__init__()
         self.pack_factor = 32 // weight_bits  # packed into int32
         self.group_size = group_size
         self.zero_point = zero_point
diff --git a/vllm/model_executor/layers/quantization/base_config.py b/vllm/model_executor/layers/quantization/base_config.py
index c0d8553c0df1a..5ef11546fd41b 100644
--- a/vllm/model_executor/layers/quantization/base_config.py
+++ b/vllm/model_executor/layers/quantization/base_config.py
@@ -2,7 +2,7 @@
 
 import inspect
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Mapping, Optional, Type
+from typing import Any, Dict, List, Optional, Type
 
 import torch
 from torch import nn
@@ -59,7 +59,11 @@ def method_has_implemented_embedding(
 
 class QuantizationConfig(ABC):
     """Base class for quantization configs."""
-    packed_modules_mapping: Mapping[str, List[str]] = dict()
+
+    def __init__(self):
+        super().__init__()
+        # mapping is updated by models as they initialize
+        self.packed_modules_mapping: Dict[str, List[str]] = dict()
 
     @abstractmethod
     def get_name(self) -> str:
diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py
index 49d992d4cb071..33c2ca93ffa17 100644
--- a/vllm/model_executor/layers/quantization/bitsandbytes.py
+++ b/vllm/model_executor/layers/quantization/bitsandbytes.py
@@ -30,7 +30,7 @@ class BitsAndBytesConfig(QuantizationConfig):
         llm_int8_skip_modules: Optional[List[str]] = None,
         llm_int8_threshold: float = 6.0,
     ) -> None:
-
+        super().__init__()
         self.load_in_8bit = load_in_8bit
         self.load_in_4bit = load_in_4bit
         self.bnb_4bit_compute_dtype = bnb_4bit_compute_dtype
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
index 4c974d3131926..ce6c706fe3d27 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -51,7 +51,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         kv_cache_scheme: Optional[Dict[str, Any]] = None,
         config: Optional[Dict[str, Any]] = None,
     ):
-
+        super().__init__()
         self.ignore = ignore
         self.quant_format = quant_format
         # Map from [target -> scheme]
diff --git a/vllm/model_executor/layers/quantization/deepspeedfp.py b/vllm/model_executor/layers/quantization/deepspeedfp.py
index b4123650149f0..67934d37284e1 100644
--- a/vllm/model_executor/layers/quantization/deepspeedfp.py
+++ b/vllm/model_executor/layers/quantization/deepspeedfp.py
@@ -25,6 +25,7 @@ class DeepSpeedFPConfig(QuantizationConfig):
         weight_bits: int = 8,
         group_size: int = 512,
     ) -> None:
+        super().__init__()
         self.weight_bits = weight_bits
         self.group_size = group_size
         self.valid_types = [torch.bfloat16, torch.float16]
diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py
index 87fbcf62ac1ed..663fb8bf5b8e6 100644
--- a/vllm/model_executor/layers/quantization/experts_int8.py
+++ b/vllm/model_executor/layers/quantization/experts_int8.py
@@ -17,7 +17,7 @@ class ExpertsInt8Config(QuantizationConfig):
     """Config class for Int8 experts quantization."""
 
     def __init__(self) -> None:
-        pass
+        super().__init__()
 
     @classmethod
     def get_name(cls) -> str:
diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py
index da5ef36c51054..3bb8188f725c8 100644
--- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py
+++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py
@@ -29,6 +29,7 @@ class FBGEMMFp8Config(QuantizationConfig):
     """Config class for FBGEMM Fp8."""
 
     def __init__(self, ignore_list: List[str], input_scale_ub: float):
+        super().__init__()
         self.ignore_list = ignore_list if ignore_list else []
         self.input_scale_ub = input_scale_ub
 
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 86e025310f4ef..f928ea7e23ca8 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -47,6 +47,7 @@ class Fp8Config(QuantizationConfig):
         ignored_layers: Optional[List[str]] = None,
         weight_block_size: Optional[List[int]] = None,
     ) -> None:
+        super().__init__()
         self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
         if is_checkpoint_fp8_serialized:
             logger.warning("Detected fp8 checkpoint. Please note that the "
diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py
index 86e6dbb5a5fbe..b1fecb32f4d80 100644
--- a/vllm/model_executor/layers/quantization/gguf.py
+++ b/vllm/model_executor/layers/quantization/gguf.py
@@ -20,7 +20,7 @@ class GGUFConfig(QuantizationConfig):
     """Config class for GGUF."""
 
     def __init__(self, ) -> None:
-        pass
+        super().__init__()
 
     def __repr__(self) -> str:
         return ("GGUFConfig()")
diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py
index 6d1f0cc2eb4d5..09291c2bf1f0b 100644
--- a/vllm/model_executor/layers/quantization/gptq.py
+++ b/vllm/model_executor/layers/quantization/gptq.py
@@ -58,6 +58,7 @@ class GPTQConfig(QuantizationConfig):
         #     r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,},
         #     r"-:.*\.moe\..*": {},  # negative match (skip) all `moe` layers
         # }
+        super().__init__()
         self.dynamic = dynamic
 
         self.weight_bits = weight_bits
diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py
index f421dbd2ce2b3..9f960d9fd37f2 100644
--- a/vllm/model_executor/layers/quantization/gptq_marlin.py
+++ b/vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -46,6 +46,7 @@ class GPTQMarlinConfig(QuantizationConfig):
                  is_sym: bool, lm_head_quantized: bool,
                  dynamic: Dict[str, Dict[str, Union[int, bool]]],
                  full_config: Dict[str, Any]) -> None:
+        super().__init__()
         if desc_act and group_size == -1:
             # In this case, act_order == True is the same as act_order == False
             # (since we have only one group per output channel)
diff --git a/vllm/model_executor/layers/quantization/gptq_marlin_24.py b/vllm/model_executor/layers/quantization/gptq_marlin_24.py
index cec984483fd8c..dd747e182e289 100644
--- a/vllm/model_executor/layers/quantization/gptq_marlin_24.py
+++ b/vllm/model_executor/layers/quantization/gptq_marlin_24.py
@@ -38,6 +38,7 @@ class GPTQMarlin24Config(QuantizationConfig):
         weight_bits: int,
         group_size: int,
     ) -> None:
+        super().__init__()
         quant_type = {
             4: scalar_types.uint4b8,
             8: scalar_types.uint8b128,
diff --git a/vllm/model_executor/layers/quantization/hqq_marlin.py b/vllm/model_executor/layers/quantization/hqq_marlin.py
index 432f43688ff58..4edc9aa848a19 100644
--- a/vllm/model_executor/layers/quantization/hqq_marlin.py
+++ b/vllm/model_executor/layers/quantization/hqq_marlin.py
@@ -33,6 +33,7 @@ class HQQMarlinConfig(QuantizationConfig):
         group_size: int,
         skip_modules: Optional[List[str]] = None,
     ) -> None:
+        super().__init__()
         assert group_size == 64, ("The only supported HQQ group size is "
                                   "currently 64.")
         assert weight_bits == 4, ("The only supported HQQ quantization "
diff --git a/vllm/model_executor/layers/quantization/ipex_quant.py b/vllm/model_executor/layers/quantization/ipex_quant.py
index 2531170ececf9..c09cc13cb276b 100644
--- a/vllm/model_executor/layers/quantization/ipex_quant.py
+++ b/vllm/model_executor/layers/quantization/ipex_quant.py
@@ -35,6 +35,7 @@ class IPEXConfig(QuantizationConfig):
         desc_act: Optional[bool] = None,
         lm_head_quantized: Optional[bool] = None,
     ) -> None:
+        super().__init__()
         self.method = method
         self.weight_bits = weight_bits
         self.group_size = group_size
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index 348e9bccd9b0a..050130de1c0f3 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -28,6 +28,7 @@ class ModelOptFp8Config(QuantizationConfig):
         self,
         is_checkpoint_fp8_serialized: bool = False,
     ) -> None:
+        super().__init__()
         self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
         if is_checkpoint_fp8_serialized:
             logger.warning("Detected ModelOpt fp8 checkpoint. Please note that"
diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py
index 30eb04698d812..da06ca3f70ecc 100644
--- a/vllm/model_executor/layers/quantization/moe_wna16.py
+++ b/vllm/model_executor/layers/quantization/moe_wna16.py
@@ -24,6 +24,7 @@ class MoeWNA16Config(QuantizationConfig):
                  group_size: int, has_zp: bool, lm_head_quantized: bool,
                  modules_to_not_convert: Optional[List[str]],
                  full_config: Dict[str, Any]) -> None:
+        super().__init__()
         self.weight_bits = weight_bits
         self.group_size = group_size
         self.has_zp = has_zp
diff --git a/vllm/model_executor/layers/quantization/neuron_quant.py b/vllm/model_executor/layers/quantization/neuron_quant.py
index a8e8be207fd15..82954612fb2ad 100644
--- a/vllm/model_executor/layers/quantization/neuron_quant.py
+++ b/vllm/model_executor/layers/quantization/neuron_quant.py
@@ -20,6 +20,7 @@ class NeuronQuantConfig(QuantizationConfig):
         dequant_dtype: str = "f16",
         quantize_method: str = "vector_dynamic",
     ) -> None:
+        super().__init__()
         self.quant_dtype = os.getenv("NEURON_QUANT_DTYPE", "s8")
         if self.quant_dtype not in SUPPORTED_QUANT_DTYPE_LIST:
             raise ValueError(
diff --git a/vllm/model_executor/layers/quantization/qqq.py b/vllm/model_executor/layers/quantization/qqq.py
index 6e9d3dc6cb378..1e05917a5187b 100644
--- a/vllm/model_executor/layers/quantization/qqq.py
+++ b/vllm/model_executor/layers/quantization/qqq.py
@@ -39,6 +39,7 @@ class QQQConfig(QuantizationConfig):
         group_size: int,
         is_sym: bool = True,
     ) -> None:
+        super().__init__()
         self.weight_bits = weight_bits
         self.group_size = group_size
         self.is_sym = is_sym
diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py
index ba123565a0ecc..ca71da8b736a5 100644
--- a/vllm/model_executor/layers/quantization/quark/quark.py
+++ b/vllm/model_executor/layers/quantization/quark/quark.py
@@ -30,6 +30,7 @@ class QuarkConfig(QuantizationConfig):
                  kv_cache_group: Optional[List[str]] = None,
                  kv_cache_config: Optional[Dict[str, Any]] = None,
                  pack_method: str = "reorder"):
+        super().__init__()
         if kv_cache_group is None:
             kv_cache_group = []
         self.quant_config = quant_config
diff --git a/vllm/model_executor/layers/quantization/tpu_int8.py b/vllm/model_executor/layers/quantization/tpu_int8.py
index 3234fecaa3b35..14e5bcf6e5bbe 100644
--- a/vllm/model_executor/layers/quantization/tpu_int8.py
+++ b/vllm/model_executor/layers/quantization/tpu_int8.py
@@ -21,6 +21,7 @@ class Int8TpuConfig(QuantizationConfig):
         self,
         activation_scheme: str = "none",
     ) -> None:
+        super().__init__()
         if activation_scheme not in ACTIVATION_SCHEMES:
             raise ValueError(
                 f"Unsupported activation scheme {activation_scheme}")
diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py
index 547f624478162..73c109a27ac7a 100644
--- a/vllm/model_executor/models/clip.py
+++ b/vllm/model_executor/models/clip.py
@@ -15,6 +15,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.interfaces import SupportsQuant
 
 from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
 
@@ -335,10 +336,10 @@ class CLIPVisionTransformer(nn.Module):
         return encoder_outputs
 
 
-class CLIPVisionModel(nn.Module):
-
+class CLIPVisionModel(nn.Module, SupportsQuant):
     config_class = CLIPVisionConfig
     main_input_name = "pixel_values"
+    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
 
     def __init__(
         self,
diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py
index a0a1b69ad5027..bd6661d668d9f 100644
--- a/vllm/model_executor/models/interfaces.py
+++ b/vllm/model_executor/models/interfaces.py
@@ -7,6 +7,8 @@ import torch
 from typing_extensions import TypeIs, TypeVar
 
 from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig)
 from vllm.utils import supports_kw
 
 from .interfaces_base import is_pooling_model
@@ -443,6 +445,36 @@ def supports_cross_encoding(
     return is_pooling_model(model) and _supports_cross_encoding(model)
 
 
+class SupportsQuant:
+    """The interface required for all models that support quantization."""
+
+    packed_modules_mapping: ClassVar[Dict[str, List[str]]] = {}
+    quant_config: Optional[QuantizationConfig] = None
+
+    def __new__(cls, *args, **kwargs) -> "SupportsQuant":
+        instance = super().__new__(cls)
+        quant_config = cls._find_quant_config(*args, **kwargs)
+        if quant_config is not None:
+            instance.quant_config = quant_config
+            instance.quant_config.packed_modules_mapping.update(
+                cls.packed_modules_mapping)
+        return instance
+
+    @staticmethod
+    def _find_quant_config(*args, **kwargs) -> Optional[QuantizationConfig]:
+        from vllm.config import VllmConfig  # avoid circular import
+
+        args_values = list(args) + list(kwargs.values())
+        for arg in args_values:
+            if isinstance(arg, VllmConfig):
+                return arg.quant_config
+
+            if isinstance(arg, QuantizationConfig):
+                return arg
+
+        return None
+
+
 @runtime_checkable
 class SupportsTranscription(Protocol):
     """The interface required for all models that support transcription."""
diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py
index 053390c521fc2..6bbfa40beed1b 100644
--- a/vllm/model_executor/models/phi3v.py
+++ b/vllm/model_executor/models/phi3v.py
@@ -50,7 +50,7 @@ from vllm.sequence import IntermediateTensors
 from vllm.utils import is_list_of
 
 from .clip import CLIPVisionModel
-from .interfaces import SupportsMultiModal, SupportsPP
+from .interfaces import SupportsMultiModal, SupportsPP, SupportsQuant
 from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                     init_vllm_registered_model, maybe_prefix,
                     merge_multimodal_embeddings)
@@ -498,7 +498,8 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
 @MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor,
                                         info=Phi3VProcessingInfo,
                                         dummy_inputs=Phi3VDummyInputsBuilder)
-class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
+class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
+                       SupportsQuant):
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
             "model.vision_embed_tokens.wte": "embed_tokens",
@@ -510,7 +511,6 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
-        quant_config = vllm_config.quant_config
         multimodal_config = vllm_config.model_config.multimodal_config
         self.config = config
         self.multimodal_config = multimodal_config
@@ -520,14 +520,14 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
             config.vocab_size,
             config.hidden_size,
             org_num_embeddings=config.vocab_size,
-            quant_config=quant_config,
+            quant_config=self.quant_config,
             prefix=maybe_prefix(prefix, "model.embed_tokens"),
         )
 
         # TODO: Optionally initializes this for supporting input embeddings.
         self.vision_embed_tokens = Phi3HDImageEmbedding(
             config,
-            quant_config,
+            self.quant_config,
             prefix=maybe_prefix(prefix, "model.vision_embed_tokens"))
 
         self.language_model = init_vllm_registered_model(
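
Usage sketch (not part of the patch): the snippet below illustrates how a model is expected to consume the new `SupportsQuant` mixin. The class `ToyQuantModel` and its attributes are hypothetical, used only to show the pattern: inherit `SupportsQuant`, declare `packed_modules_mapping` as a class attribute, and accept a `VllmConfig` in the constructor; `SupportsQuant.__new__` then resolves `self.quant_config` before `__init__` runs, which is why the manual `quant_config = vllm_config.quant_config` line disappears from phi3v above.

    # Illustrative sketch only -- "ToyQuantModel" is a hypothetical model
    # class, not something added by this patch.
    import torch.nn as nn

    from vllm.config import VllmConfig
    from vllm.model_executor.models.interfaces import SupportsQuant


    class ToyQuantModel(nn.Module, SupportsQuant):
        # Merged into quant_config.packed_modules_mapping by
        # SupportsQuant.__new__ before __init__ runs (mirrors what
        # CLIPVisionModel declares above).
        packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

        def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
            super().__init__()
            # SupportsQuant.__new__ already scanned the constructor arguments,
            # found the VllmConfig, and attached its quant_config, so layers
            # can be built directly with quant_config=self.quant_config.
            self.is_quantized = self.quant_config is not None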