# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import inspect
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional

import torch
from torch import nn

if TYPE_CHECKING:
    from vllm.model_executor.layers.quantization import QuantizationMethods
    from vllm.model_executor.models.utils import WeightsMapper
else:
    QuantizationMethods = str


class QuantizeMethodBase(ABC):
    """Base class for different quantized methods."""

    @abstractmethod
    def create_weights(self, layer: torch.nn.Module, *weight_args,
                       **extra_weight_attrs):
        """Create weights for a layer.

        The weights will be set as attributes of the layer."""
        raise NotImplementedError

    @abstractmethod
    def apply(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor:
        """Apply the weights in layer to the input tensor.

        Expects create_weights to have been called before on the layer."""
        raise NotImplementedError

    # Not required functions
    def embedding(self, layer: torch.nn.Module, *args,
                  **kwargs) -> torch.Tensor:
        """Gather embeddings in the layer based on indices in the input tensor.

        Expects create_weights to have been called before on the layer."""
        raise NotImplementedError

    def process_weights_after_loading(self, layer: nn.Module) -> None:
        """Process the weight after loading.

        This can be used, for example, to transpose weights for computation.
        """
        return


def method_has_implemented_embedding(
        method_class: type[QuantizeMethodBase]) -> bool:
    """
    Not all quant methods have embedding implemented, so we need to check that
    it exists for our given method. We check this by making sure the function
    has been changed from the base implementation.
    """
    base_embedding = inspect.getattr_static(QuantizeMethodBase, "embedding",
                                            None)
    class_embedding = inspect.getattr_static(method_class, "embedding", None)

    return (class_embedding is not None
            and class_embedding is not base_embedding)
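

# Illustrative sketch (not part of vLLM): ``method_has_implemented_embedding``
# counts only a genuine override of ``embedding`` as support for embedding
# lookups. The hypothetical class below would make the check return True,
# while ``QuantizeMethodBase`` itself (or a subclass that leaves ``embedding``
# untouched) makes it return False.
class _ExampleEmbeddingMethod(QuantizeMethodBase):

    def create_weights(self, layer: torch.nn.Module, *weight_args,
                       **extra_weight_attrs):
        # A plain embedding table, just enough to keep the sketch coherent.
        layer.weight = nn.Parameter(torch.empty(16, 8), requires_grad=False)

    def apply(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor:
        # Not relevant to this illustration.
        raise NotImplementedError

    def embedding(self, layer: torch.nn.Module,
                  input_: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.embedding(input_, layer.weight)


# method_has_implemented_embedding(_ExampleEmbeddingMethod)  -> True
# method_has_implemented_embedding(QuantizeMethodBase)       -> False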


class QuantizationConfig(ABC):
    """Base class for quantization configs."""

    def __init__(self):
        super().__init__()
        # mapping is updated by models as they initialize
        self.packed_modules_mapping: dict[str, list[str]] = dict()

    @abstractmethod
    def get_name(self) -> QuantizationMethods:
        """Name of the quantization method."""
        raise NotImplementedError

    @abstractmethod
    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        """List of supported activation dtypes."""
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def get_min_capability(cls) -> int:
        """Minimum GPU capability to support the quantization method.

        E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
        This requirement is due to the custom CUDA kernels used by the
        quantization method.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_config_filenames() -> list[str]:
        """List of filenames to search for in the model directory."""
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def from_config(cls, config: dict[str, Any]) -> "QuantizationConfig":
        """Create a config class from the model's quantization config."""
        raise NotImplementedError

    @classmethod
    def override_quantization_method(
            cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
        """
        Detects if this quantization method can support a given checkpoint
        format by overriding the user-specified quantization method.

        This method should only be overwritten by subclasses in exceptional
        circumstances.
        """
        return None
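    # Illustrative (not from the source): a subclass that recognizes its own
    # checkpoint format in ``hf_quant_cfg`` would return its method name from
    # the override above, taking precedence over ``user_quant``.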

    @staticmethod
    def get_from_keys(config: dict[str, Any], keys: list[str]) -> Any:
        """Get a value from the model's quantization config."""
        for key in keys:
            if key in config:
                return config[key]
        raise ValueError(f"Cannot find any of {keys} in the model's "
                         "quantization config.")

    @staticmethod
    def get_from_keys_or(config: dict[str, Any], keys: list[str],
                         default: Any) -> Any:
        """Get an optional value from the model's quantization config."""
        try:
            return QuantizationConfig.get_from_keys(config, keys)
        except ValueError:
            return default
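    # Illustrative usage (not from the source): checkpoints spell the same
    # field in different ways, so lookups typically pass several candidate
    # keys, e.g.
    #     weight_bits = cls.get_from_keys(config, ["bits", "wbits"])
    #     group_size = cls.get_from_keys_or(config, ["group_size"], -1)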

    @abstractmethod
    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional[QuantizeMethodBase]:
        """Get the quantize method to use for the quantized layer.

        Args:
            layer: The layer for the quant method.
            prefix: The full name of the layer in the state dict.
        Returns:
            The quantize method, or None if the given layer does not support
            a quant method.
        """
        raise NotImplementedError

    def get_cache_scale(self, name: str) -> Optional[str]:
        """Map a checkpoint parameter name to the vLLM parameter name of the
        corresponding KV cache scaling factor, or return None if ``name`` is
        not a cache scale. Subclasses override this for their own format."""
        return None

    def apply_vllm_mapper(  # noqa: B027
            self, hf_to_vllm_mapper: "WeightsMapper"):
        """
        Interface for models to update module names referenced in
        quantization configs in order to reflect the vllm model structure.

        :param hf_to_vllm_mapper: maps from hf model structure (the assumed
            structure of the qconfig) to vllm model structure
        """
        # TODO (@kylesayrs): add implementations for all subclasses
        pass
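

# --- Illustrative sketch (not part of vLLM): a minimal config/method pair ---
# The classes below only demonstrate how the two base classes above fit
# together: ``from_config`` pulls fields out of the checkpoint's quantization
# config via ``get_from_keys``/``get_from_keys_or``, and ``get_quant_method``
# hands back a ``QuantizeMethodBase`` that creates and applies the layer's
# weights. Every name here (``_SketchLinearMethod``, ``_SketchConfig``, the
# config keys, the "sketch" method name) is hypothetical.
class _SketchLinearMethod(QuantizeMethodBase):

    def create_weights(self, layer: torch.nn.Module, input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        # A plain (unquantized) weight, just to keep the sketch runnable.
        layer.weight = nn.Parameter(torch.empty(output_size,
                                                input_size,
                                                dtype=params_dtype),
                                    requires_grad=False)

    def apply(self, layer: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.linear(x, layer.weight)


class _SketchConfig(QuantizationConfig):

    def __init__(self, weight_bits: int, group_size: int):
        super().__init__()
        self.weight_bits = weight_bits
        self.group_size = group_size

    def get_name(self) -> QuantizationMethods:
        return "sketch"  # hypothetical method name

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80  # e.g. Ampere or newer

    @staticmethod
    def get_config_filenames() -> list[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "_SketchConfig":
        weight_bits = cls.get_from_keys(config, ["bits", "wbits"])
        group_size = cls.get_from_keys_or(config, ["group_size"], -1)
        return cls(weight_bits, group_size)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional[QuantizeMethodBase]:
        if isinstance(layer, nn.Linear):
            return _SketchLinearMethod()
        return None


# A loader would then do something along these lines:
#     quant_config = _SketchConfig.from_config({"bits": 4, "group_size": 128})
#     method = quant_config.get_quant_method(nn.Linear(16, 16), prefix="")
#     method.create_weights(layer, 16, 16, torch.float16)
#     y = method.apply(layer, x)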