# Mirror of https://git.datalinker.icu/vllm-project/vllm.git
# Synced 2025-12-18 21:05:58 +08:00
from abc import ABC, abstractmethod
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
import torch
|
|
from torch import nn
|
|
|
|
|
|
class QuantizeMethodBase(ABC):
    """Abstract interface implemented by every quantization method.

    Concrete subclasses define how a layer's quantized weights are
    created and how those weights are applied to an input tensor.
    """

    @abstractmethod
    def create_weights(self, layer: torch.nn.Module, *weight_args,
                       **extra_weight_attrs):
        """Create weights for a layer.

        The created weights are attached to ``layer`` as attributes.
        """
        raise NotImplementedError

    @abstractmethod
    def apply(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor:
        """Apply the weights stored on ``layer`` to the input tensor.

        ``create_weights`` must have been called on ``layer`` beforehand.
        """
        raise NotImplementedError

    def process_weights_after_loading(self, layer: nn.Module) -> None:
        """Hook run once the layer's weights have been loaded.

        Subclasses may override this, e.g. to transpose weights into the
        layout their compute kernels expect. The default is a no-op.
        """
        return None
class QuantizationConfig(ABC):
    """Abstract base class describing one quantization scheme's config."""

    @abstractmethod
    def get_name(self) -> str:
        """Return the canonical name of this quantization method."""
        raise NotImplementedError

    @abstractmethod
    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        """Return the activation dtypes this method supports."""
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def get_min_capability(cls) -> int:
        """Return the minimum GPU capability required by this method.

        E.g., 70 for Volta, 75 for Turing, 80 for Ampere. The floor
        comes from the custom CUDA kernels the quantization method uses.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_config_filenames() -> List[str]:
        """Return candidate config filenames to search for in the model dir."""
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig":
        """Build a config instance from the model's quantization config."""
        raise NotImplementedError

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg,
                                     user_quant) -> Optional[str]:
        """Optionally override the user-selected quantization method.

        Lets a method detect that it can serve a given checkpoint format
        and take precedence over the user's choice. Subclasses should
        override this only in exceptional circumstances; the default
        never overrides.
        """
        return None

    @staticmethod
    def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
        """Return the first value found in ``config`` under any of ``keys``.

        Raises ValueError when none of the keys is present.
        """
        _missing = object()  # sentinel so falsy config values are still found
        for candidate in keys:
            value = config.get(candidate, _missing)
            if value is not _missing:
                return value
        raise ValueError(f"Cannot find any of {keys} in the model's "
                         "quantization config.")

    @abstractmethod
    def get_quant_method(
            self, layer: torch.nn.Module) -> Optional[QuantizeMethodBase]:
        """Get the quantize method to use for the quantized layer.

        Args:
            layer: The layer for the quant method.
        Returns:
            The quantize method. None if the given layer doesn't support quant
            method.
        """
        raise NotImplementedError

    @abstractmethod
    def get_scaled_act_names(self) -> List[str]:
        """Return the activation function names that need post-scaling.

        For now, this is only used by AWQ.
        """
        raise NotImplementedError