[Misc] Support register quantization method out-of-tree (#11969)
This commit is contained in:
parent 6d0e3d3724
commit 32eb0da808
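In short, this change adds a `register_quantization_config` decorator so quantization methods can be plugged into vLLM out-of-tree, without modifying the hard-coded registry. A minimal sketch of the flow (the `MyOOTQuantConfig` class and `my_oot_quant` name are illustrative only, not part of the diff):

from vllm.model_executor.layers.quantization import (
    get_quantization_config, register_quantization_config)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)

# Hypothetical out-of-tree config; a real one would implement all of
# QuantizationConfig's abstract methods, as the test file below does.
@register_quantization_config("my_oot_quant")
class MyOOTQuantConfig(QuantizationConfig):
    ...

# The registry now resolves the custom name like any built-in method.
assert get_quantization_config("my_oot_quant") is MyOOTQuantConfig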
tests/quantization/test_register_quantization_config.py (new file, +117)
@@ -0,0 +1,117 @@
"""Tests for registering a custom quantization config.

See https://github.com/vllm-project/vllm/issues/11926 for more details.

Run `pytest tests/quantization/test_register_quantization_config.py`.
"""
from typing import Any, Dict, List, Optional

import pytest
import torch
import torch.nn.functional as F

from vllm.model_executor.layers.linear import LinearBase  # noqa: E501
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization import (
    get_quantization_config, register_quantization_config)
from vllm.model_executor.layers.quantization.base_config import (  # noqa: E501
    QuantizationConfig)


class FakeQuantLinearMethod(UnquantizedLinearMethod):
    """Fake quantization linear method for per-token dynamic quantization."""

    def __init__(self, num_bits: int = 8) -> None:
        """Initialize the quantization method."""
        super().__init__()
        self.num_bits = num_bits

    def apply(self,
              layer: "torch.nn.Module",
              x: "torch.Tensor",
              bias: Optional["torch.Tensor"] = None) -> "torch.Tensor":
        """Perform fake quantization before the linear layer."""

        # Calculate the scales dynamically.
        max_val = torch.amax(x, dim=(0, -1), keepdims=True)
        min_val = torch.amin(x, dim=(0, -1), keepdims=True)
        scales = (max_val - min_val) / (2**self.num_bits - 1)

        # Fake-quantize the input: round to the integer grid, clamp to the
        # signed range, then dequantize back to floating point.
        quant_x = torch.clamp(torch.round(x / scales), -2**(self.num_bits - 1),
                              2**(self.num_bits - 1) - 1)
        dequant_x = quant_x * scales

        return F.linear(dequant_x, layer.weight, bias)


@register_quantization_config("custom_quant")
class CustomQuantConfig(QuantizationConfig):
    """Custom quantization config for per-token dynamic fake quantization."""

    def __init__(self, num_bits: int = 8) -> None:
        """Initialize the quantization config."""
        self.num_bits = num_bits

    def get_name(self) -> str:
        """Name of the quantization method."""
        return "custom_quant"

    def get_supported_act_dtypes(self) -> List["torch.dtype"]:
        """List of supported activation dtypes."""
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        """Minimum GPU capability to support the quantization method."""
        return -1

    @staticmethod
    def get_config_filenames() -> List[str]:
        """List of filenames to search for in the model directory."""
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "CustomQuantConfig":
        """Create a config class from the model's quantization config."""
        return CustomQuantConfig(num_bits=config.get("num_bits", 8))

    def get_quant_method(self, layer: "torch.nn.Module",
                         prefix: str) -> Optional["FakeQuantLinearMethod"]:
        """Get the quantize method to use for the quantized layer."""
        if isinstance(layer, LinearBase):
            return FakeQuantLinearMethod(num_bits=self.num_bits)
        return None


def test_register_quantization_config():
    """Test registering a custom quantization config."""

    # The quantization method `custom_quant` should be registered.
    assert get_quantization_config("custom_quant") == CustomQuantConfig

    # The quantization method `custom_quant` already exists, so registering
    # it again should raise an error.
    with pytest.raises(ValueError):
        register_quantization_config("custom_quant")(CustomQuantConfig)


@pytest.mark.parametrize(argnames="model",
                         argvalues=[
                             "meta-llama/Meta-Llama-3-8B-Instruct",
                         ])
def test_custom_quant(vllm_runner, model):
    """Test inference with the custom quantization method."""
    with vllm_runner(model_name=model,
                     quantization="custom_quant",
                     enforce_eager=True) as llm:

        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
        layer = model.model.layers[0]
        qkv_proj = layer.self_attn.qkv_proj

        # Check that the quantization method is FakeQuantLinearMethod.
        assert isinstance(qkv_proj.quant_method, FakeQuantLinearMethod)

        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        assert output
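For intuition about what `FakeQuantLinearMethod.apply` computes, here is a standalone round-trip of the same scale/round/clamp/dequantize arithmetic (not part of the diff; input values chosen for illustration):

import torch

num_bits = 8
x = torch.tensor([[-10.0, -1.0, 0.0, 10.0]])

max_val = torch.amax(x, dim=(0, -1), keepdims=True)  # 10.0
min_val = torch.amin(x, dim=(0, -1), keepdims=True)  # -10.0
scales = (max_val - min_val) / (2**num_bits - 1)     # 20 / 255 ~= 0.0784

quant_x = torch.clamp(torch.round(x / scales),
                      -2**(num_bits - 1), 2**(num_bits - 1) - 1)
dequant_x = quant_x * scales

# With roughly symmetric inputs the round-trip error stays within about
# one scale step; strongly one-sided inputs can saturate the signed clamp.
print((dequant_x - x).abs().max())  # ~0.04 here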
vllm/model_executor/layers/quantization/__init__.py
@@ -29,6 +29,45 @@ QUANTIZATION_METHODS: List[str] = [
    "quark"
]

# The customized quantization methods which will be added to this dict.
_CUSTOMIZED_METHOD_TO_QUANT_CONFIG = {}


def register_quantization_config(quantization: str):
    """Register a customized vllm quantization config.

    When a quantization method is not supported by vllm, you can register a customized
    quantization config to support it.

    Args:
        quantization (str): The quantization method name.

    Examples:
        >>> from vllm.model_executor.layers.quantization import register_quantization_config
        >>> from vllm.model_executor.layers.quantization import get_quantization_config
        >>> from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
        >>>
        >>> @register_quantization_config("my_quant")
        ... class MyQuantConfig(QuantizationConfig):
        ...     pass
        >>>
        >>> get_quantization_config("my_quant")
        <class 'MyQuantConfig'>
    """  # noqa: E501

    def _wrapper(quant_config_cls):
        if quantization in QUANTIZATION_METHODS:
            raise ValueError(
                f"The quantization method `{quantization}` already exists.")
        if not issubclass(quant_config_cls, QuantizationConfig):
            raise ValueError("The quantization config must be a subclass of "
                             "`QuantizationConfig`.")
        _CUSTOMIZED_METHOD_TO_QUANT_CONFIG[quantization] = quant_config_cls
        QUANTIZATION_METHODS.append(quantization)
        return quant_config_cls

    return _wrapper


def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
    if quantization not in QUANTIZATION_METHODS:
@@ -84,6 +123,8 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
        "ipex": IPEXConfig,
        "quark": QuarkConfig
    }
    # Update `method_to_config` with the customized quantization methods.
    method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)

    return method_to_config[quantization]
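Once the module defining the custom config has been imported, the registered name can be passed to the engine like any built-in method. A minimal usage sketch assuming the offline `LLM` API (this mirrors what `vllm_runner` does in the test above; it is not part of the diff):

from vllm import LLM, SamplingParams

# The module containing the @register_quantization_config class must be
# imported before engine construction so the name is in the registry.
llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct",
          quantization="custom_quant",
          enforce_eager=True)
outputs = llm.generate("Hello my name is",
                       SamplingParams(temperature=0.0, max_tokens=20))
print(outputs[0].outputs[0].text)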