Mirror of https://git.datalinker.icu/vllm-project/vllm.git, synced 2025-12-09 09:35:34 +08:00
[Feature] A calibration-free RTN-based quantization for accurate and accelerated INT4/INT8 inference (#18768)
Signed-off-by: Alex Kogan <alex.kogan@oracle.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
parent bd5038af07
commit 27949354fa
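For context, RTN is selected like any other vLLM quantization backend; no pre-quantized checkpoint or calibration data is needed. A minimal offline-inference sketch mirroring the test added below (model and prompt are illustrative):

from vllm import LLM, SamplingParams

# Weights are quantized on the fly at load time; the checkpoint stays fp16/bf16.
llm = LLM(model="microsoft/Phi-3-mini-4k-instruct",
          dtype="bfloat16",
          quantization="rtn")
out = llm.generate(["The capital of France is"],
                   SamplingParams(temperature=0.0, max_tokens=10))
print(out[0].outputs[0].text)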
tests/quantization/test_rtn.py (new file, 28 lines)
@@ -0,0 +1,28 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright © 2025, Oracle and/or its affiliates.
"""Tests RTN quantization startup and generation;
does not test correctness.
"""
import pytest

from tests.quantization.utils import is_quant_method_supported

MODELS = ["microsoft/Phi-3-mini-4k-instruct"]


@pytest.mark.skipif(not is_quant_method_supported("rtn"),
                    reason="RTN is not supported on this GPU type.")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [10])
def test_model_rtn_startup(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
) -> None:

    with vllm_runner(model, dtype=dtype, quantization="rtn") as vllm_model:
        vllm_model.generate_greedy(example_prompts, max_tokens)
vllm/model_executor/layers/quantization/__init__.py

@@ -35,6 +35,7 @@ QuantizationMethods = Literal[
    "moe_wna16",
    "torchao",
    "auto-round",
    "rtn",
]
QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))

@@ -110,6 +111,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
    from .neuron_quant import NeuronQuantConfig
    from .ptpc_fp8 import PTPCFp8Config
    from .qqq import QQQConfig
    from .rtn import RTNConfig
    from .torchao import TorchAOConfig
    from .tpu_int8 import Int8TpuConfig

@@ -142,6 +144,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
        "moe_wna16": MoeWNA16Config,
        "torchao": TorchAOConfig,
        "auto-round": AutoRoundConfig,
        "rtn": RTNConfig,
    }
    # Update the `method_to_config` with customized quantization methods.
    method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)
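A quick sketch of what these registry changes enable, using the module's own helpers:

from vllm.model_executor.layers.quantization import (
    QUANTIZATION_METHODS, get_quantization_config)

assert "rtn" in QUANTIZATION_METHODS
config_cls = get_quantization_config("rtn")  # returns RTNConfig
print(config_cls())  # RTNConfig(weight_bits=8, group_size=128) by default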
vllm/model_executor/layers/quantization/rtn.py (new file, 288 lines)
@@ -0,0 +1,288 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright © 2025, Oracle and/or its affiliates.

import os
from typing import Any, Optional

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               set_weight_attrs)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)

logger = init_logger(__name__)

# By default, use 8 bits as the target precision; this can be
# overridden by setting the RTN_NUM_BITS env var.
NUM_BITS = os.getenv('RTN_NUM_BITS', "8")
# By default, use a group size of 128 parameters; this can be
# overridden by setting the RTN_GROUP_SIZE env var.
GROUP_SIZE = os.getenv('RTN_GROUP_SIZE', "128")
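NUM_BITS and GROUP_SIZE are read once at import time, so these environment variables have to be set before vLLM is imported. A small sketch of selecting INT4 this way (values are illustrative):

import os

# Must happen before `import vllm`, since rtn.py reads these at import time.
os.environ["RTN_NUM_BITS"] = "4"
os.environ["RTN_GROUP_SIZE"] = "128"

from vllm import LLM

llm = LLM(model="microsoft/Phi-3-mini-4k-instruct", quantization="rtn")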
class RTNConfig(QuantizationConfig):
    """Config class for RTN."""

    def __init__(
        self,
        weight_bits: int = int(NUM_BITS),
        group_size: int = int(GROUP_SIZE),
    ) -> None:
        self.weight_bits = weight_bits
        self.group_size = group_size

        if self.weight_bits != 4 and self.weight_bits != 8:
            raise ValueError(
                "Currently, only 4-bit or 8-bit weight quantization is "
                f"supported for RTN, but got {self.weight_bits} bits.")

    def __repr__(self) -> str:
        return (f"RTNConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size})")

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "rtn"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return []

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "RTNConfig":
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        return cls(weight_bits, group_size)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["RTNLinearMethod"]:
        if isinstance(layer, LinearBase):
            return RTNLinearMethod(self)
        return None
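For reference, a sketch of how the config behaves at the boundaries (the dict keys match from_config above; values are illustrative):

cfg = RTNConfig.from_config({"bits": 4, "group_size": 128})
print(cfg)                # RTNConfig(weight_bits=4, group_size=128)
RTNConfig(weight_bits=2)  # raises ValueError: only 4- and 8-bit supported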
class RTNTensor:
    """A wrapper over Tensor that enables quantization on-the-fly by
    overloading the copy_ method.
    """

    def __init__(self, data: torch.Tensor, scale: torch.Tensor,
                 quant_config: RTNConfig) -> None:
        self.data = data
        self.scale = scale
        self.quant_config = quant_config

    def narrow(self, dim, start, length):
        factor = 1 if self.quant_config.weight_bits == 8 else 2
        return RTNTensor(
            self.data.narrow(dim, start // factor, length // factor),
            self.scale.narrow(dim, start, length), self.quant_config)

    @property
    def shape(self):
        shape = self.data.shape
        factor = 1 if self.quant_config.weight_bits == 8 else 2
        return torch.Size((shape[0] * factor, shape[1]))

    def copy_(self, loaded_weight: torch.Tensor) -> None:
        qweight, weight_scale = rtn_quantize(loaded_weight.cuda(),
                                             self.quant_config.weight_bits,
                                             self.quant_config.group_size)

        self.data.copy_(qweight)
        self.scale.data.copy_(weight_scale)
class RTNParameter(Parameter):
    """A wrapper over Parameter that returns RTNTensor (a wrapper over Tensor)
    when its data is accessed. We need this wrapper for the data loading phase
    only, so we can intercept a weight copying function (torch.Tensor.copy_)
    and apply quantization on-the-fly.
    """

    def __new__(cls, data: torch.Tensor, **kwargs):
        return super().__new__(cls, data=data, requires_grad=False)

    def __init__(self, data: torch.Tensor, scale: torch.Tensor,
                 quant_config: RTNConfig) -> None:
        self.scale = scale
        self.quant_config = quant_config

    @property
    def data(self):
        return RTNTensor(super().data, self.scale, self.quant_config)
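The RTNParameter/RTNTensor pair exists only to hook the weight-loading path. The pattern can be seen in isolation in this toy, CPU-only sketch (InterceptingParameter and the rounding transform are hypothetical stand-ins for RTNParameter and rtn_quantize):

import torch
from torch.nn.parameter import Parameter

class InterceptingTensor:
    """Stand-in for RTNTensor: transforms data as it is copied in."""

    def __init__(self, data):
        self._data = data

    def copy_(self, src):
        # Model loaders call `param.data.copy_(loaded_weight)`, so this is
        # the one place where the weight can be transformed (quantized).
        self._data.copy_(src.round())

class InterceptingParameter(Parameter):
    """Stand-in for RTNParameter: returns the wrapper when .data is read."""

    def __new__(cls, data):
        return super().__new__(cls, data=data, requires_grad=False)

    @property
    def data(self):
        return InterceptingTensor(super().data)

p = InterceptingParameter(torch.zeros(2, 2))
p.data.copy_(torch.full((2, 2), 1.4))             # rounded on the way in
print(torch.equal(p.detach(), torch.ones(2, 2)))  # True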
class RTNLinearMethod(LinearMethodBase):
    """Linear method for RTN.

    Args:
        quant_config: The RTN quantization config.
    """

    def __init__(self, quant_config: RTNConfig):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        output_size_per_partition = sum(output_partition_sizes)
        num_groups_per_col = (input_size_per_partition //
                              self.quant_config.group_size
                              if self.quant_config.group_size != -1 else 1)

        scale = Parameter(
            torch.empty(output_size_per_partition,
                        num_groups_per_col,
                        dtype=params_dtype),
            requires_grad=False,
        )
        factor = 1 if self.quant_config.weight_bits == 8 else 2

        weight = RTNParameter(data=torch.empty(output_size_per_partition //
                                               factor,
                                               input_size_per_partition,
                                               dtype=torch.int8),
                              scale=scale,
                              quant_config=self.quant_config)

        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, {
            **extra_weight_attrs,
            "input_dim": 1,
            "output_dim": 0,
        })

        layer.register_parameter("scale", scale)
        layer.output_size_per_partition = output_size_per_partition

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """torch.compile does not know how to deal with a Parameter subclass
        (aka RTNParameter). Since we don't really need RTNParameter for the
        forward pass, we replace it with an equivalent instance of Parameter.
        """
        old_weight = layer.weight
        assert isinstance(old_weight, RTNParameter)
        data = old_weight.data.data

        delattr(layer, "weight")

        new_weight = Parameter(data=data, requires_grad=False)
        layer.register_parameter("weight", new_weight)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        qweight = layer.weight
        scale = layer.scale

        weight = rtn_dequantize(qweight, scale)
        out = F.linear(x, weight)
        del weight
        if bias is not None:
            out.add_(bias)

        return out
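To make the factor bookkeeping in create_weights concrete, here are the shapes it allocates for a hypothetical 4-bit layer (the partition sizes are made up):

input_size_per_partition = 4096     # hypothetical
output_size_per_partition = 11008   # hypothetical
weight_bits, group_size = 4, 128

factor = 1 if weight_bits == 8 else 2    # two INT4 values share one int8 byte
num_groups_per_col = input_size_per_partition // group_size

# weight: int8, packed along the output dim -> (5504, 4096)
print((output_size_per_partition // factor, input_size_per_partition))
# scale: activation dtype, one factor per group -> (11008, 32)
print((output_size_per_partition, num_groups_per_col))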
def rtn_quantize(tensor: torch.Tensor, num_bits: int,
                 group_size: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize a tensor using per-group static scaling factors.

    Args:
        tensor: The input tensor.
        num_bits: Target precision for the result (supported values are
                  8 or 4).
        group_size: Quantization granularity.
                    If equal to -1, each row in the input tensor is treated
                    as one group.
    """
    q_range = 2**num_bits
    num_groups = (tensor.shape[0] * tensor.shape[1] //
                  group_size if group_size != -1 else tensor.shape[0])

    # Calculate a scaling factor per input group.
    input_flat = tensor.reshape(num_groups, -1)
    input_min = torch.min(input_flat, dim=1, keepdim=True)[0]
    input_max = torch.max(input_flat, dim=1, keepdim=True)[0]
    input_max_abs = torch.max(input_min.abs(), input_max.abs())
    scale = (input_max_abs * 2.0 / (q_range - 1))

    # Scale each input group, truncate and round to the nearest integer.
    scaled_input = input_flat / scale
    scaled_input = scaled_input.clamp(-q_range // 2, q_range // 2 - 1)
    scaled_input = scaled_input.round()

    scale = scale.reshape(tensor.shape[0], -1).contiguous()
    inputs_q = scaled_input.reshape(tensor.shape).to(torch.int8)
    inputs_q = inputs_q.contiguous()

    if num_bits == 4:
        # Pack two 4-bit values into each byte.
        inputs_q = (inputs_q[:, 1::2] << 4) | (inputs_q[:, ::2] & 0xf)
        inputs_q = inputs_q.reshape(tensor.shape[0] // 2, tensor.shape[1])
        inputs_q = inputs_q.contiguous()

    return inputs_q, scale
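The 4-bit packing above interleaves adjacent columns into low/high nibbles of one byte. A tiny worked example with toy values:

import torch

row = torch.tensor([[3, -2, 7, -8]], dtype=torch.int8)
# Element 2k keeps the low nibble, element 2k+1 the high nibble of byte k.
packed = (row[:, 1::2] << 4) | (row[:, ::2] & 0xf)
print(packed)              # tensor([[-29, -121]], dtype=torch.int8)
# Arithmetic shifts recover the signed nibbles.
print((packed << 4) >> 4)  # low nibbles:  3, 7
print(packed >> 4)         # high nibbles: -2, -8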
def rtn_dequantize(tensor: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Dequantize a tensor using per-group static scaling factors.

    Args:
        tensor: The input tensor.
        scale: The tensor with per-group scale factors.
    """
    num_groups = scale.size(0) * scale.size(1)
    input_dim, output_dim = tensor.shape

    num_bits = 8 if input_dim == scale.size(0) else 4
    if num_bits == 4:
        input_dim *= 2

    data = torch.empty((input_dim, output_dim),
                       dtype=scale.dtype,
                       device=tensor.device)

    if num_bits == 8:
        data.copy_(tensor)
    else:
        # Unpack two 4-bit values from each byte.
        tensor = tensor.reshape(input_dim, output_dim // 2)
        for i in range(2):
            data[:, i::2] = (tensor << 4 * (1 - i)) >> 4

    # Scale each input group with its scaling factor.
    scale = scale.reshape(num_groups, -1)
    data = data.reshape(num_groups, -1)
    data = torch.mul(data, scale)

    input_deq = data.reshape((input_dim, output_dim)).contiguous()
    return input_deq
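A quick sanity check of the quantize/dequantize pair is an 8-bit round trip on random data; within each group, the reconstruction error is roughly bounded by half a quantization step (shapes here are arbitrary):

import torch

w = torch.randn(256, 512, dtype=torch.bfloat16)
qweight, scale = rtn_quantize(w, num_bits=8, group_size=128)
w_hat = rtn_dequantize(qweight, scale)

# Rounding error is about scale/2, i.e. max|w| / 255 within each group.
print((w - w_hat).abs().max())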