mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-27 14:46:00 +08:00
[Feat]: Add support for Dynamic Quant 4 bit CPU kleidiai kernels (#17112)
Signed-off-by: Nikhil Gupta <nikhil.gupta2@arm.com> Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
parent
c6f36cfa26
commit
89ac266b26
@ -26,9 +26,10 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, CompressedTensors24,
|
||||
CompressedTensorsScheme, CompressedTensorsW4A4Fp4,
|
||||
CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24,
|
||||
CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
|
||||
CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
|
||||
CompressedTensorsW4A8Int, CompressedTensorsW4A16Fp4,
|
||||
CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
|
||||
CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
|
||||
CompressedTensorsWNA16)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
|
||||
find_matched_target, is_activation_quantization_format,
|
||||
should_ignore_layer)
|
||||
@ -74,7 +75,7 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
return CompressedTensorsLinearMethod(self)
|
||||
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.float16, torch.bfloat16]
|
||||
return [torch.float32, torch.float16, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
@ -299,6 +300,22 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
# Only symmetric weight quantization supported.
|
||||
return is_8_bits and is_token and weight_quant.symmetric and is_dynamic
|
||||
|
||||
def _is_dynamic_token_w4a8_int(self, weight_quant: BaseModel,
|
||||
input_quant: BaseModel) -> bool:
|
||||
is_weight_4_bits = weight_quant.num_bits == 4
|
||||
is_activation_8_bits = input_quant.num_bits == 8
|
||||
weight_strategy = (
|
||||
weight_quant.strategy == QuantizationStrategy.GROUP.value
|
||||
or weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
|
||||
is_token = (weight_strategy and input_quant.strategy
|
||||
== QuantizationStrategy.TOKEN.value)
|
||||
is_dynamic = not weight_quant.dynamic and input_quant.dynamic
|
||||
|
||||
# Both symmetric and asymmetric input quantization supported.
|
||||
# Only symmetric weight quantization supported.
|
||||
return (is_weight_4_bits and is_activation_8_bits and is_token
|
||||
and weight_quant.symmetric and is_dynamic)
|
||||
|
||||
def _is_fp8_w8a8(self, weight_quant: BaseModel,
|
||||
input_quant: BaseModel) -> bool:
|
||||
# Confirm weights and activations quantized.
|
||||
@ -374,7 +391,6 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
def _get_scheme_from_parts(
|
||||
self, weight_quant: BaseModel,
|
||||
input_quant: BaseModel) -> "CompressedTensorsScheme":
|
||||
|
||||
# Detect If Mixed Precision
|
||||
if self._is_fp4a16_nvfp4(weight_quant, input_quant):
|
||||
return CompressedTensorsW4A16Fp4()
|
||||
@ -443,6 +459,16 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
is_static_input_scheme=False,
|
||||
input_symmetric=input_quant.symmetric)
|
||||
|
||||
if self._is_dynamic_token_w4a8_int(weight_quant, input_quant):
|
||||
is_static_input_scheme = (input_quant
|
||||
and not input_quant.dynamic)
|
||||
return CompressedTensorsW4A8Int(
|
||||
num_bits=weight_quant.num_bits,
|
||||
strategy=weight_quant.strategy,
|
||||
group_size=weight_quant.group_size,
|
||||
is_static_input_scheme=is_static_input_scheme,
|
||||
input_symmetric=input_quant.symmetric)
|
||||
|
||||
raise NotImplementedError(
|
||||
"No compressed-tensors compatible scheme was found.")
|
||||
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
|
||||
from .compressed_tensors_scheme import CompressedTensorsScheme
|
||||
from .compressed_tensors_w4a4_nvfp4 import CompressedTensorsW4A4Fp4
|
||||
from .compressed_tensors_w4a8_int import CompressedTensorsW4A8Int
|
||||
from .compressed_tensors_w4a16_24 import (W4A16SPARSE24_SUPPORTED_BITS,
|
||||
CompressedTensorsW4A16Sparse24)
|
||||
from .compressed_tensors_w4a16_nvfp4 import CompressedTensorsW4A16Fp4
|
||||
@ -20,5 +21,5 @@ __all__ = [
|
||||
"CompressedTensorsW8A8Int8", "CompressedTensorsW8A8Fp8",
|
||||
"WNA16_SUPPORTED_BITS", "W4A16SPARSE24_SUPPORTED_BITS",
|
||||
"CompressedTensors24", "CompressedTensorsW4A16Fp4",
|
||||
"CompressedTensorsW4A4Fp4"
|
||||
"CompressedTensorsW4A4Fp4", "CompressedTensorsW4A8Int"
|
||||
]
|
||||
|
||||
@ -0,0 +1,135 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme)
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
|
||||
MPLinearLayerConfig, choose_mp_linear_kernel)
|
||||
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
ModelWeightParameter)
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
__all__ = ["CompressedTensorsW4A8Int"]
|
||||
W4A8_SUPPORTED_TYPES_MAP = {
|
||||
4: scalar_types.int4,
|
||||
}
|
||||
W4A8_SUPPORTED_BITS = list(W4A8_SUPPORTED_TYPES_MAP.keys())
|
||||
|
||||
|
||||
class CompressedTensorsW4A8Int(CompressedTensorsScheme):
|
||||
_kernel_backends_being_used: set[str] = set()
|
||||
|
||||
def __init__(self,
|
||||
strategy: str,
|
||||
num_bits: int,
|
||||
group_size: Optional[int] = None,
|
||||
is_static_input_scheme: bool = False,
|
||||
input_symmetric: bool = True):
|
||||
self.strategy = strategy
|
||||
self.group_size = -1 if group_size is None else group_size
|
||||
self.is_static_input_scheme = is_static_input_scheme
|
||||
self.input_symmetric = input_symmetric
|
||||
|
||||
if num_bits not in W4A8_SUPPORTED_TYPES_MAP:
|
||||
raise ValueError(
|
||||
f"Unsupported num_bits = {num_bits}."
|
||||
f"Supported num_bits = {W4A8_SUPPORTED_TYPES_MAP.keys()}")
|
||||
self.quant_type = W4A8_SUPPORTED_TYPES_MAP[num_bits]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 1
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, output_size: int,
|
||||
input_size: int, output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype, weight_loader: Callable,
|
||||
**kwargs):
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
row_parallel = (input_size != input_size_per_partition)
|
||||
|
||||
# Compute effective group_size
|
||||
if self.group_size == -1:
|
||||
effective_group_size = (input_size_per_partition
|
||||
if row_parallel else input_size)
|
||||
else:
|
||||
effective_group_size = self.group_size
|
||||
|
||||
# Ensure group_size divides input_size_per_partition
|
||||
assert input_size_per_partition % effective_group_size == 0, (
|
||||
f"input_size_per_partition {input_size_per_partition}"
|
||||
f" not divisible by group_size {effective_group_size}")
|
||||
|
||||
# Determine scale partitioning
|
||||
is_channelwise = (self.group_size == -1)
|
||||
repeat_scales = (is_channelwise and row_parallel)
|
||||
partition_scales = not repeat_scales
|
||||
|
||||
mp_linear_kernel_config = MPLinearLayerConfig(
|
||||
full_weight_shape=(input_size, output_size),
|
||||
partition_weight_shape=(input_size_per_partition,
|
||||
output_size_per_partition),
|
||||
weight_type=self.quant_type,
|
||||
act_type=params_dtype,
|
||||
group_size=effective_group_size,
|
||||
zero_points=False,
|
||||
has_g_idx=False,
|
||||
)
|
||||
|
||||
kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)
|
||||
if kernel_type.__name__ not in self._kernel_backends_being_used:
|
||||
logger.info("Using %s for CompressedTensorsW4A8Int",
|
||||
kernel_type.__name__)
|
||||
self._kernel_backends_being_used.add(kernel_type.__name__)
|
||||
|
||||
scales_and_zp_size = input_size_per_partition // effective_group_size
|
||||
|
||||
weight = ModelWeightParameter(data=torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition,
|
||||
dtype=torch.int8),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
weight_scale_args = {
|
||||
"weight_loader":
|
||||
weight_loader,
|
||||
"data":
|
||||
torch.empty(output_size_per_partition,
|
||||
scales_and_zp_size,
|
||||
dtype=params_dtype)
|
||||
}
|
||||
|
||||
if partition_scales:
|
||||
weight_scale = GroupQuantScaleParameter(output_dim=0,
|
||||
input_dim=1,
|
||||
**weight_scale_args)
|
||||
else:
|
||||
weight_scale = ChannelQuantScaleParameter(output_dim=0,
|
||||
**weight_scale_args)
|
||||
|
||||
layer.register_parameter("weight_packed", weight)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
self.kernel = kernel_type(mp_linear_kernel_config,
|
||||
w_q_param_name="weight_packed",
|
||||
w_s_param_name="weight_scale",
|
||||
w_zp_param_name=None,
|
||||
w_gidx_param_name=None)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
self.kernel.process_weights_after_loading(layer)
|
||||
|
||||
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
return self.kernel.apply_weights(layer, x, bias)
|
||||
@ -10,6 +10,8 @@ from vllm.model_executor.layers.quantization.kernels.mixed_precision.bitblas imp
|
||||
BitBLASLinearKernel)
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision.conch import ( # noqa: E501
|
||||
ConchLinearKernel)
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision.dynamic_4bit import ( # noqa: E501
|
||||
Dynamic4bitLinearKernel)
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import ( # noqa: E501
|
||||
ExllamaLinearKernel)
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision.machete import ( # noqa: E501
|
||||
@ -25,6 +27,7 @@ _POSSIBLE_KERNELS: list[type[MPLinearKernel]] = [
|
||||
MacheteLinearKernel,
|
||||
AllSparkLinearKernel,
|
||||
MarlinLinearKernel,
|
||||
Dynamic4bitLinearKernel,
|
||||
BitBLASLinearKernel,
|
||||
ConchLinearKernel,
|
||||
ExllamaLinearKernel,
|
||||
@ -56,7 +59,8 @@ def choose_mp_linear_kernel(
|
||||
if current_platform is None:
|
||||
raise ValueError("Cannot determine compute capability")
|
||||
_cc = current_platform.get_device_capability()
|
||||
compute_capability = _cc[0] * 10 + _cc[1]
|
||||
if _cc is not None:
|
||||
compute_capability = _cc[0] * 10 + _cc[1]
|
||||
|
||||
failure_reasons = []
|
||||
for kernel in _POSSIBLE_KERNELS:
|
||||
@ -64,12 +68,12 @@ def choose_mp_linear_kernel(
|
||||
failure_reasons.append(
|
||||
f' {kernel.__name__} disabled by environment variable')
|
||||
continue
|
||||
|
||||
if kernel.get_min_capability() > compute_capability:
|
||||
if (compute_capability is not None
|
||||
and kernel.get_min_capability() > compute_capability):
|
||||
failure_reasons.append(
|
||||
f"{kernel.__name__} requires capability "
|
||||
f"{kernel.get_min_capability()}, current compute capability "
|
||||
f"is {compute_capability}")
|
||||
f"{kernel.get_min_capability()}, current compute "
|
||||
f" capability is {compute_capability}")
|
||||
continue
|
||||
|
||||
can_implement, failure_reason = kernel.can_implement(config)
|
||||
|
||||
@ -0,0 +1,92 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
from vllm.platforms import CpuArchEnum, current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
|
||||
|
||||
|
||||
class Dynamic4bitLinearKernel(MPLinearKernel):
|
||||
SUPPORTED_QUANT_TYPES = [scalar_types.int4]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 1
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls,
|
||||
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
if not current_platform.is_cpu():
|
||||
return False, "Only CPU is supported"
|
||||
if c.weight_type not in cls.SUPPORTED_QUANT_TYPES:
|
||||
return False, f"Unsupported quant type {c.weight_type}"
|
||||
if current_platform.get_cpu_architecture(
|
||||
) == CpuArchEnum.ARM and c.act_type not in [
|
||||
torch.float32,
|
||||
]:
|
||||
return False, "Dynamic4bitLinearKernel on Arm requires"\
|
||||
" Float32 activations"
|
||||
if c.full_weight_shape[0] % c.group_size != 0:
|
||||
return False, f"Group size ({c.group_size}) does not evenly divide"\
|
||||
" the number of input features "\
|
||||
f"({c.full_weight_shape[0]})"
|
||||
if current_platform.get_cpu_architecture() == CpuArchEnum.ARM:
|
||||
try:
|
||||
# Attempt to retrieve the operation
|
||||
_ = torch.ops.aten._dyn_quant_matmul_4bit
|
||||
except AttributeError:
|
||||
return False, f"PyTorch {torch.__version__} does not support"\
|
||||
" _dyn_quant_matmul_4bit. Install a newer version"
|
||||
return True, None
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module):
|
||||
c = self.config
|
||||
packed_weight = getattr(layer, self.w_q_name)
|
||||
packed_weight = packed_weight.add(8)
|
||||
uint8_packed = (packed_weight[::, 1::2] << 4
|
||||
| packed_weight[::, ::2]).to(torch.uint8)
|
||||
|
||||
scales = getattr(layer, self.w_s_name)
|
||||
block_size = c.group_size
|
||||
|
||||
# Handle scaling factors for partitioned weights
|
||||
if block_size == c.partition_weight_shape[0]:
|
||||
scales = scales.to(
|
||||
torch.float32
|
||||
) # Float32 & Bfloat16 variants requires float32 scales
|
||||
scales = scales.view(-1, 1) # Channel-wise scales
|
||||
if layer.bias is not None:
|
||||
layer.bias = layer.bias.to(
|
||||
torch.float32
|
||||
) # Float32 & Bfloat16 variants requires float32 bias
|
||||
else:
|
||||
# KleidiAI kernel requires bfloat16 scales with groupwise scheme
|
||||
scales = scales.to(torch.bfloat16)
|
||||
|
||||
# Repack weights as per kernel requirement
|
||||
w = torch.ops.aten._dyn_quant_pack_4bit_weight(
|
||||
uint8_packed, scales, layer.bias, block_size,
|
||||
c.partition_weight_shape[0], c.partition_weight_shape[1])
|
||||
replace_parameter(layer, self.w_q_name,
|
||||
torch.nn.Parameter(w, requires_grad=False))
|
||||
setattr(layer, self.w_s_name, None)
|
||||
|
||||
def apply_weights(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
c = self.config
|
||||
x_2d = x.reshape(-1, x.shape[-1])
|
||||
out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )
|
||||
|
||||
w_q = getattr(layer, self.w_q_name)
|
||||
output = torch.ops.aten._dyn_quant_matmul_4bit(
|
||||
x_2d, w_q, c.group_size, c.partition_weight_shape[0],
|
||||
c.partition_weight_shape[1])
|
||||
return output.reshape(out_shape)
|
||||
Loading…
x
Reference in New Issue
Block a user