Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-01-21 11:04:29 +08:00)
[Quant] Add SupportsQuant to phi3 and clip (#13104)
This commit is contained in:
parent 80f63a3966
commit 12913d17ba

@@ -169,6 +169,7 @@ class AQLMConfig(QuantizationConfig):
         num_codebooks: int,
         out_group_size: int,
     ) -> None:
+        super().__init__()
         self.in_group_size = in_group_size
         self.nbits_per_codebook = nbits_per_codebook
         self.num_codebooks = num_codebooks

@@ -26,6 +26,7 @@ class AWQConfig(QuantizationConfig):
         zero_point: bool,
         modules_to_not_convert: Optional[List[str]] = None,
     ) -> None:
+        super().__init__()
         self.weight_bits = weight_bits
         self.group_size = group_size
         self.zero_point = zero_point

@@ -47,6 +47,7 @@ class AWQMarlinConfig(QuantizationConfig):
                  lm_head_quantized: bool,
                  modules_to_not_convert: Optional[List[str]],
                  full_config: Dict[str, Any]) -> None:
+        super().__init__()
         self.pack_factor = 32 // weight_bits  # packed into int32
         self.group_size = group_size
         self.zero_point = zero_point

@@ -2,7 +2,7 @@

 import inspect
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Mapping, Optional, Type
+from typing import Any, Dict, List, Optional, Type

 import torch
 from torch import nn

@@ -59,7 +59,11 @@ def method_has_implemented_embedding(

 class QuantizationConfig(ABC):
     """Base class for quantization configs."""
-    packed_modules_mapping: Mapping[str, List[str]] = dict()
+
+    def __init__(self):
+        super().__init__()
+        # mapping is updated by models as they initialize
+        self.packed_modules_mapping: Dict[str, List[str]] = dict()

     @abstractmethod
     def get_name(self) -> str:

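Note on the hunk above: packed_modules_mapping moves from a class-level Mapping to a per-instance Dict that models update as they initialize. A standalone sketch of the difference (class names hypothetical, not vllm code):

# Hypothetical illustration: why the mapping is now created per instance.
from typing import Dict, List


class OldStyleConfig:
    # class attribute: every instance (and subclass) shares one dict object
    packed_modules_mapping: Dict[str, List[str]] = {}


class NewStyleConfig:
    def __init__(self) -> None:
        # instance attribute: each config gets its own mapping to update
        self.packed_modules_mapping: Dict[str, List[str]] = {}


a, b = NewStyleConfig(), NewStyleConfig()
a.packed_modules_mapping["qkv_proj"] = ["q_proj", "k_proj", "v_proj"]
assert b.packed_modules_mapping == {}  # b is unaffected
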
@@ -30,7 +30,7 @@ class BitsAndBytesConfig(QuantizationConfig):
         llm_int8_skip_modules: Optional[List[str]] = None,
         llm_int8_threshold: float = 6.0,
     ) -> None:
-
+        super().__init__()
         self.load_in_8bit = load_in_8bit
         self.load_in_4bit = load_in_4bit
         self.bnb_4bit_compute_dtype = bnb_4bit_compute_dtype

@@ -51,7 +51,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         kv_cache_scheme: Optional[Dict[str, Any]] = None,
         config: Optional[Dict[str, Any]] = None,
     ):
-
+        super().__init__()
         self.ignore = ignore
         self.quant_format = quant_format
         # Map from [target -> scheme]

@@ -25,6 +25,7 @@ class DeepSpeedFPConfig(QuantizationConfig):
         weight_bits: int = 8,
         group_size: int = 512,
     ) -> None:
+        super().__init__()
         self.weight_bits = weight_bits
         self.group_size = group_size
         self.valid_types = [torch.bfloat16, torch.float16]

@@ -17,7 +17,7 @@ class ExpertsInt8Config(QuantizationConfig):
     """Config class for Int8 experts quantization."""

     def __init__(self) -> None:
-        pass
+        super().__init__()

     @classmethod
     def get_name(cls) -> str:

@@ -29,6 +29,7 @@ class FBGEMMFp8Config(QuantizationConfig):
     """Config class for FBGEMM Fp8."""

     def __init__(self, ignore_list: List[str], input_scale_ub: float):
+        super().__init__()
         self.ignore_list = ignore_list if ignore_list else []
         self.input_scale_ub = input_scale_ub

@@ -47,6 +47,7 @@ class Fp8Config(QuantizationConfig):
         ignored_layers: Optional[List[str]] = None,
         weight_block_size: Optional[List[int]] = None,
     ) -> None:
+        super().__init__()
         self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
         if is_checkpoint_fp8_serialized:
             logger.warning("Detected fp8 checkpoint. Please note that the "

@@ -20,7 +20,7 @@ class GGUFConfig(QuantizationConfig):
     """Config class for GGUF."""

     def __init__(self, ) -> None:
-        pass
+        super().__init__()

     def __repr__(self) -> str:
         return ("GGUFConfig()")

@@ -58,6 +58,7 @@ class GPTQConfig(QuantizationConfig):
         # r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,},
         # r"-:.*\.moe\..*": {},  # negative match (skip) all `moe` layers
         # }
+        super().__init__()
         self.dynamic = dynamic

         self.weight_bits = weight_bits

@@ -46,6 +46,7 @@ class GPTQMarlinConfig(QuantizationConfig):
                  is_sym: bool, lm_head_quantized: bool,
                  dynamic: Dict[str, Dict[str, Union[int, bool]]],
                  full_config: Dict[str, Any]) -> None:
+        super().__init__()
         if desc_act and group_size == -1:
             # In this case, act_order == True is the same as act_order == False
             # (since we have only one group per output channel)

@@ -38,6 +38,7 @@ class GPTQMarlin24Config(QuantizationConfig):
         weight_bits: int,
         group_size: int,
     ) -> None:
+        super().__init__()
         quant_type = {
             4: scalar_types.uint4b8,
             8: scalar_types.uint8b128,

@@ -33,6 +33,7 @@ class HQQMarlinConfig(QuantizationConfig):
         group_size: int,
         skip_modules: Optional[List[str]] = None,
     ) -> None:
+        super().__init__()
         assert group_size == 64, ("The only supported HQQ group size is "
                                   "currently 64.")
         assert weight_bits == 4, ("The only supported HQQ quantization "

@@ -35,6 +35,7 @@ class IPEXConfig(QuantizationConfig):
         desc_act: Optional[bool] = None,
         lm_head_quantized: Optional[bool] = None,
     ) -> None:
+        super().__init__()
         self.method = method
         self.weight_bits = weight_bits
         self.group_size = group_size

@@ -28,6 +28,7 @@ class ModelOptFp8Config(QuantizationConfig):
         self,
         is_checkpoint_fp8_serialized: bool = False,
     ) -> None:
+        super().__init__()
         self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
         if is_checkpoint_fp8_serialized:
             logger.warning("Detected ModelOpt fp8 checkpoint. Please note that"

@@ -24,6 +24,7 @@ class MoeWNA16Config(QuantizationConfig):
                  group_size: int, has_zp: bool, lm_head_quantized: bool,
                  modules_to_not_convert: Optional[List[str]],
                  full_config: Dict[str, Any]) -> None:
+        super().__init__()
         self.weight_bits = weight_bits
         self.group_size = group_size
         self.has_zp = has_zp

@@ -20,6 +20,7 @@ class NeuronQuantConfig(QuantizationConfig):
         dequant_dtype: str = "f16",
         quantize_method: str = "vector_dynamic",
     ) -> None:
+        super().__init__()
         self.quant_dtype = os.getenv("NEURON_QUANT_DTYPE", "s8")
         if self.quant_dtype not in SUPPORTED_QUANT_DTYPE_LIST:
             raise ValueError(

@@ -39,6 +39,7 @@ class QQQConfig(QuantizationConfig):
         group_size: int,
         is_sym: bool = True,
     ) -> None:
+        super().__init__()
         self.weight_bits = weight_bits
         self.group_size = group_size
         self.is_sym = is_sym

@@ -30,6 +30,7 @@ class QuarkConfig(QuantizationConfig):
                  kv_cache_group: Optional[List[str]] = None,
                  kv_cache_config: Optional[Dict[str, Any]] = None,
                  pack_method: str = "reorder"):
+        super().__init__()
         if kv_cache_group is None:
             kv_cache_group = []
         self.quant_config = quant_config

@@ -21,6 +21,7 @@ class Int8TpuConfig(QuantizationConfig):
         self,
         activation_scheme: str = "none",
     ) -> None:
+        super().__init__()
         if activation_scheme not in ACTIVATION_SCHEMES:
             raise ValueError(
                 f"Unsupported activation scheme {activation_scheme}")

@@ -15,6 +15,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.interfaces import SupportsQuant

 from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs

@@ -335,10 +336,10 @@ class CLIPVisionTransformer(nn.Module):
         return encoder_outputs


-class CLIPVisionModel(nn.Module):
-
+class CLIPVisionModel(nn.Module, SupportsQuant):
     config_class = CLIPVisionConfig
     main_input_name = "pixel_values"
+    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

     def __init__(
         self,

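For readers unfamiliar with the mapping declared above: quantization code uses a packed-modules mapping to translate a fused module name (here CLIP's qkv_proj) back into the per-shard names found in checkpoints and ignore lists. A standalone sketch of that lookup (the helper below is hypothetical, not vllm API):

# Hypothetical helper showing how a packed-modules mapping is consumed.
from typing import Dict, List

packed_modules_mapping: Dict[str, List[str]] = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
}


def expand_packed(module_name: str) -> List[str]:
    # a fused module expands to its shards; anything else maps to itself
    return packed_modules_mapping.get(module_name, [module_name])


assert expand_packed("qkv_proj") == ["q_proj", "k_proj", "v_proj"]
assert expand_packed("fc1") == ["fc1"]
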
@@ -7,6 +7,8 @@ import torch
 from typing_extensions import TypeIs, TypeVar

 from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig)
 from vllm.utils import supports_kw

 from .interfaces_base import is_pooling_model

@@ -443,6 +445,36 @@ def supports_cross_encoding(
     return is_pooling_model(model) and _supports_cross_encoding(model)


+class SupportsQuant:
+    """The interface required for all models that support quantization."""
+
+    packed_modules_mapping: ClassVar[Dict[str, List[str]]] = {}
+    quant_config: Optional[QuantizationConfig] = None
+
+    def __new__(cls, *args, **kwargs) -> "SupportsQuant":
+        instance = super().__new__(cls)
+        quant_config = cls._find_quant_config(*args, **kwargs)
+        if quant_config is not None:
+            instance.quant_config = quant_config
+            instance.quant_config.packed_modules_mapping.update(
+                cls.packed_modules_mapping)
+        return instance
+
+    @staticmethod
+    def _find_quant_config(*args, **kwargs) -> Optional[QuantizationConfig]:
+        from vllm.config import VllmConfig  # avoid circular import
+
+        args_values = list(args) + list(kwargs.values())
+        for arg in args_values:
+            if isinstance(arg, VllmConfig):
+                return arg.quant_config
+
+            if isinstance(arg, QuantizationConfig):
+                return arg
+
+        return None
+
+
 @runtime_checkable
 class SupportsTranscription(Protocol):
     """The interface required for all models that support transcription."""

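A hedged usage sketch of the mixin defined above (ToyModel is hypothetical; the real adopters in this commit are CLIPVisionModel and Phi3VForCausalLM). SupportsQuant.__new__ scans the constructor arguments for a VllmConfig or a bare QuantizationConfig, stores it on the instance, and merges the class-level packed_modules_mapping into the config's per-instance mapping, so __init__ can simply read self.quant_config:

# Hypothetical model, assuming vllm is importable; not part of the diff.
import torch.nn as nn

from vllm.config import VllmConfig
from vllm.model_executor.models.interfaces import SupportsQuant


class ToyModel(nn.Module, SupportsQuant):
    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        # SupportsQuant.__new__ has already set self.quant_config from
        # vllm_config.quant_config (it stays None when quantization is off),
        # so no local quant_config plumbing is needed here.
        self.proj = nn.Linear(16, 16)
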
@@ -50,7 +50,7 @@ from vllm.sequence import IntermediateTensors
 from vllm.utils import is_list_of

 from .clip import CLIPVisionModel
-from .interfaces import SupportsMultiModal, SupportsPP
+from .interfaces import SupportsMultiModal, SupportsPP, SupportsQuant
 from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                     init_vllm_registered_model, maybe_prefix,
                     merge_multimodal_embeddings)

@@ -498,7 +498,8 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
 @MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor,
                                         info=Phi3VProcessingInfo,
                                         dummy_inputs=Phi3VDummyInputsBuilder)
-class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
+class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
+                       SupportsQuant):
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
             "model.vision_embed_tokens.wte": "embed_tokens",

@@ -510,7 +511,6 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
-        quant_config = vllm_config.quant_config
         multimodal_config = vllm_config.model_config.multimodal_config
         self.config = config
         self.multimodal_config = multimodal_config

@@ -520,14 +520,14 @@
             config.vocab_size,
             config.hidden_size,
             org_num_embeddings=config.vocab_size,
-            quant_config=quant_config,
+            quant_config=self.quant_config,
             prefix=maybe_prefix(prefix, "model.embed_tokens"),
         )

         # TODO: Optionally initializes this for supporting input embeddings.
         self.vision_embed_tokens = Phi3HDImageEmbedding(
             config,
-            quant_config,
+            self.quant_config,
             prefix=maybe_prefix(prefix, "model.vision_embed_tokens"))

         self.language_model = init_vllm_registered_model(