mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-10 03:55:40 +08:00
[Feat]: Add support for Dynamic Quant 4 bit CPU kleidiai kernels (#17112)
Signed-off-by: Nikhil Gupta <nikhil.gupta2@arm.com> Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
parent
c6f36cfa26
commit
89ac266b26
@ -26,9 +26,10 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
|
|||||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||||
W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, CompressedTensors24,
|
W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, CompressedTensors24,
|
||||||
CompressedTensorsScheme, CompressedTensorsW4A4Fp4,
|
CompressedTensorsScheme, CompressedTensorsW4A4Fp4,
|
||||||
CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24,
|
CompressedTensorsW4A8Int, CompressedTensorsW4A16Fp4,
|
||||||
CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
|
CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
|
||||||
CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
|
CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
|
||||||
|
CompressedTensorsWNA16)
|
||||||
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
|
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
|
||||||
find_matched_target, is_activation_quantization_format,
|
find_matched_target, is_activation_quantization_format,
|
||||||
should_ignore_layer)
|
should_ignore_layer)
|
||||||
@ -74,7 +75,7 @@ class CompressedTensorsConfig(QuantizationConfig):
|
|||||||
return CompressedTensorsLinearMethod(self)
|
return CompressedTensorsLinearMethod(self)
|
||||||
|
|
||||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||||
return [torch.float16, torch.bfloat16]
|
return [torch.float32, torch.float16, torch.bfloat16]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_min_capability(cls) -> int:
|
def get_min_capability(cls) -> int:
|
||||||
@ -299,6 +300,22 @@ class CompressedTensorsConfig(QuantizationConfig):
|
|||||||
# Only symmetric weight quantization supported.
|
# Only symmetric weight quantization supported.
|
||||||
return is_8_bits and is_token and weight_quant.symmetric and is_dynamic
|
return is_8_bits and is_token and weight_quant.symmetric and is_dynamic
|
||||||
|
|
||||||
|
def _is_dynamic_token_w4a8_int(self, weight_quant: BaseModel,
                               input_quant: BaseModel) -> bool:
    """Tell whether the configs describe the W4A8 dynamic int scheme.

    The result is truthy only when every condition below holds:

    * weights use 4 bits and activations use 8 bits;
    * the weight strategy is group-wise or channel-wise;
    * the activation strategy is per-token;
    * weights are static (not dynamic) while activations are dynamic;
    * weights are quantized symmetrically.

    Activation symmetry is deliberately not inspected: both symmetric
    and asymmetric input quantization are accepted.
    """
    group_value = QuantizationStrategy.GROUP.value
    channel_value = QuantizationStrategy.CHANNEL.value
    token_value = QuantizationStrategy.TOKEN.value

    weights_are_4bit = weight_quant.num_bits == 4
    inputs_are_8bit = input_quant.num_bits == 8

    # Weights may be grouped or per-channel; inputs must be per-token.
    weights_grouped_or_channelwise = weight_quant.strategy in (
        group_value, channel_value)
    strategies_match = (weights_grouped_or_channelwise
                        and input_quant.strategy == token_value)

    # Static weights paired with dynamically quantized activations.
    static_weights_dynamic_inputs = (not weight_quant.dynamic
                                     and input_quant.dynamic)

    # Only symmetric weight quantization is supported.
    return (weights_are_4bit and inputs_are_8bit and strategies_match
            and weight_quant.symmetric and static_weights_dynamic_inputs)
|
||||||
|
|
||||||
def _is_fp8_w8a8(self, weight_quant: BaseModel,
|
def _is_fp8_w8a8(self, weight_quant: BaseModel,
|
||||||
input_quant: BaseModel) -> bool:
|
input_quant: BaseModel) -> bool:
|
||||||
# Confirm weights and activations quantized.
|
# Confirm weights and activations quantized.
|
||||||
@ -374,7 +391,6 @@ class CompressedTensorsConfig(QuantizationConfig):
|
|||||||
def _get_scheme_from_parts(
|
def _get_scheme_from_parts(
|
||||||
self, weight_quant: BaseModel,
|
self, weight_quant: BaseModel,
|
||||||
input_quant: BaseModel) -> "CompressedTensorsScheme":
|
input_quant: BaseModel) -> "CompressedTensorsScheme":
|
||||||
|
|
||||||
# Detect If Mixed Precision
|
# Detect If Mixed Precision
|
||||||
if self._is_fp4a16_nvfp4(weight_quant, input_quant):
|
if self._is_fp4a16_nvfp4(weight_quant, input_quant):
|
||||||
return CompressedTensorsW4A16Fp4()
|
return CompressedTensorsW4A16Fp4()
|
||||||
@ -443,6 +459,16 @@ class CompressedTensorsConfig(QuantizationConfig):
|
|||||||
is_static_input_scheme=False,
|
is_static_input_scheme=False,
|
||||||
input_symmetric=input_quant.symmetric)
|
input_symmetric=input_quant.symmetric)
|
||||||
|
|
||||||
|
if self._is_dynamic_token_w4a8_int(weight_quant, input_quant):
|
||||||
|
is_static_input_scheme = (input_quant
|
||||||
|
and not input_quant.dynamic)
|
||||||
|
return CompressedTensorsW4A8Int(
|
||||||
|
num_bits=weight_quant.num_bits,
|
||||||
|
strategy=weight_quant.strategy,
|
||||||
|
group_size=weight_quant.group_size,
|
||||||
|
is_static_input_scheme=is_static_input_scheme,
|
||||||
|
input_symmetric=input_quant.symmetric)
|
||||||
|
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"No compressed-tensors compatible scheme was found.")
|
"No compressed-tensors compatible scheme was found.")
|
||||||
|
|
||||||
|
|||||||
@ -3,6 +3,7 @@
|
|||||||
|
|
||||||
from .compressed_tensors_scheme import CompressedTensorsScheme
|
from .compressed_tensors_scheme import CompressedTensorsScheme
|
||||||
from .compressed_tensors_w4a4_nvfp4 import CompressedTensorsW4A4Fp4
|
from .compressed_tensors_w4a4_nvfp4 import CompressedTensorsW4A4Fp4
|
||||||
|
from .compressed_tensors_w4a8_int import CompressedTensorsW4A8Int
|
||||||
from .compressed_tensors_w4a16_24 import (W4A16SPARSE24_SUPPORTED_BITS,
|
from .compressed_tensors_w4a16_24 import (W4A16SPARSE24_SUPPORTED_BITS,
|
||||||
CompressedTensorsW4A16Sparse24)
|
CompressedTensorsW4A16Sparse24)
|
||||||
from .compressed_tensors_w4a16_nvfp4 import CompressedTensorsW4A16Fp4
|
from .compressed_tensors_w4a16_nvfp4 import CompressedTensorsW4A16Fp4
|
||||||
@ -20,5 +21,5 @@ __all__ = [
|
|||||||
"CompressedTensorsW8A8Int8", "CompressedTensorsW8A8Fp8",
|
"CompressedTensorsW8A8Int8", "CompressedTensorsW8A8Fp8",
|
||||||
"WNA16_SUPPORTED_BITS", "W4A16SPARSE24_SUPPORTED_BITS",
|
"WNA16_SUPPORTED_BITS", "W4A16SPARSE24_SUPPORTED_BITS",
|
||||||
"CompressedTensors24", "CompressedTensorsW4A16Fp4",
|
"CompressedTensors24", "CompressedTensorsW4A16Fp4",
|
||||||
"CompressedTensorsW4A4Fp4"
|
"CompressedTensorsW4A4Fp4", "CompressedTensorsW4A8Int"
|
||||||
]
|
]
|
||||||
|
|||||||
@ -0,0 +1,135 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from typing import Callable, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||||
|
CompressedTensorsScheme)
|
||||||
|
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
|
||||||
|
MPLinearLayerConfig, choose_mp_linear_kernel)
|
||||||
|
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
|
||||||
|
GroupQuantScaleParameter,
|
||||||
|
ModelWeightParameter)
|
||||||
|
from vllm.scalar_type import scalar_types
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
__all__ = ["CompressedTensorsW4A8Int"]
|
||||||
|
W4A8_SUPPORTED_TYPES_MAP = {
|
||||||
|
4: scalar_types.int4,
|
||||||
|
}
|
||||||
|
W4A8_SUPPORTED_BITS = list(W4A8_SUPPORTED_TYPES_MAP.keys())
|
||||||
|
|
||||||
|
|
||||||
|
class CompressedTensorsW4A8Int(CompressedTensorsScheme):
    """Compressed-tensors scheme for int4 weights with int8 activations.

    ``create_weights`` allocates the weight and weight-scale parameters and
    selects a mixed-precision kernel through ``choose_mp_linear_kernel``;
    post-load processing and the forward computation are both delegated to
    that kernel.
    """

    # Names of kernel classes that have already been logged. The set lives
    # on the class and is mutated in place, so it is shared by all instances
    # and each backend is reported only once.
    _kernel_backends_being_used: set[str] = set()

    def __init__(self,
                 strategy: str,
                 num_bits: int,
                 group_size: Optional[int] = None,
                 is_static_input_scheme: bool = False,
                 input_symmetric: bool = True):
        """Store the quantization settings and resolve the weight type.

        Args:
            strategy: Weight quantization strategy name (stored as given).
            num_bits: Weight bit width; must be a key of
                ``W4A8_SUPPORTED_TYPES_MAP``.
            group_size: Quantization group size. ``None`` is stored as
                ``-1``, which ``create_weights`` treats as channel-wise.
            is_static_input_scheme: Stored on the instance; not read by
                any method of this class.
            input_symmetric: Stored on the instance; not read by any
                method of this class.

        Raises:
            ValueError: If ``num_bits`` is not supported.
        """
        self.strategy = strategy
        self.group_size = -1 if group_size is None else group_size
        self.is_static_input_scheme = is_static_input_scheme
        self.input_symmetric = input_symmetric

        if num_bits not in W4A8_SUPPORTED_TYPES_MAP:
            # Render the supported widths as a plain list and keep the two
            # sentences separated by a space.
            raise ValueError(
                f"Unsupported num_bits = {num_bits}. "
                f"Supported num_bits = {list(W4A8_SUPPORTED_TYPES_MAP)}")
        self.quant_type = W4A8_SUPPORTED_TYPES_MAP[num_bits]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability required (``1``)."""
        return 1

    def create_weights(self, layer: torch.nn.Module, output_size: int,
                       input_size: int, output_partition_sizes: list[int],
                       input_size_per_partition: int,
                       params_dtype: torch.dtype, weight_loader: Callable,
                       **kwargs):
        """Register weight/scale parameters on ``layer`` and pick a kernel.

        Args:
            layer: Module that receives the ``weight``, ``weight_packed``
                and ``weight_scale`` parameters.
            output_size: Full (unpartitioned) number of output features.
            input_size: Full (unpartitioned) number of input features.
            output_partition_sizes: Output sizes held by this partition;
                their sum is the partition's output dimension.
            input_size_per_partition: Input features held by this partition.
            params_dtype: Dtype used for the scales and passed to the kernel
                config as the activation type.
            weight_loader: Loader callable attached to every parameter.
            **kwargs: Accepted for interface compatibility and ignored.

        Raises:
            AssertionError: If the effective group size does not divide
                ``input_size_per_partition``.
        """
        output_size_per_partition = sum(output_partition_sizes)
        row_parallel = (input_size != input_size_per_partition)
        is_channelwise = (self.group_size == -1)

        # Channel-wise quantization uses a single group spanning this
        # partition's input dimension. When the layer is not row-parallel
        # that dimension equals ``input_size``, so no separate case is
        # needed.
        if is_channelwise:
            effective_group_size = input_size_per_partition
        else:
            effective_group_size = self.group_size

        # Ensure group_size divides input_size_per_partition
        assert input_size_per_partition % effective_group_size == 0, (
            f"input_size_per_partition {input_size_per_partition}"
            f" not divisible by group_size {effective_group_size}")

        # Channel-wise scales of a row-parallel layer are repeated instead
        # of partitioned along the input dimension.
        repeat_scales = (is_channelwise and row_parallel)
        partition_scales = not repeat_scales

        mp_linear_kernel_config = MPLinearLayerConfig(
            full_weight_shape=(input_size, output_size),
            partition_weight_shape=(input_size_per_partition,
                                    output_size_per_partition),
            weight_type=self.quant_type,
            act_type=params_dtype,
            group_size=effective_group_size,
            zero_points=False,
            has_g_idx=False,
        )

        kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)
        if kernel_type.__name__ not in self._kernel_backends_being_used:
            logger.info("Using %s for CompressedTensorsW4A8Int",
                        kernel_type.__name__)
            self._kernel_backends_being_used.add(kernel_type.__name__)

        # Number of scale entries per output row.
        scales_and_zp_size = input_size_per_partition // effective_group_size

        # One int8 element is allocated per weight value.
        weight = ModelWeightParameter(data=torch.empty(
            output_size_per_partition,
            input_size_per_partition,
            dtype=torch.int8),
                                      input_dim=1,
                                      output_dim=0,
                                      weight_loader=weight_loader)
        layer.register_parameter("weight", weight)

        weight_scale_args = {
            "weight_loader":
            weight_loader,
            "data":
            torch.empty(output_size_per_partition,
                        scales_and_zp_size,
                        dtype=params_dtype)
        }

        if partition_scales:
            weight_scale = GroupQuantScaleParameter(output_dim=0,
                                                    input_dim=1,
                                                    **weight_scale_args)
        else:
            weight_scale = ChannelQuantScaleParameter(output_dim=0,
                                                      **weight_scale_args)

        # The same parameter object is also exposed as "weight_packed",
        # which is the name handed to the kernel below.
        # NOTE(review): presumably "weight" is the name the checkpoint
        # loader targets -- confirm before removing either registration.
        layer.register_parameter("weight_packed", weight)
        layer.register_parameter("weight_scale", weight_scale)

        self.kernel = kernel_type(mp_linear_kernel_config,
                                  w_q_param_name="weight_packed",
                                  w_s_param_name="weight_scale",
                                  w_zp_param_name=None,
                                  w_gidx_param_name=None)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Let the selected kernel post-process the loaded parameters."""
        self.kernel.process_weights_after_loading(layer)

    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Run the quantized linear operation through the selected kernel."""
        return self.kernel.apply_weights(layer, x, bias)
|
||||||
@ -10,6 +10,8 @@ from vllm.model_executor.layers.quantization.kernels.mixed_precision.bitblas imp
|
|||||||
BitBLASLinearKernel)
|
BitBLASLinearKernel)
|
||||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision.conch import ( # noqa: E501
|
from vllm.model_executor.layers.quantization.kernels.mixed_precision.conch import ( # noqa: E501
|
||||||
ConchLinearKernel)
|
ConchLinearKernel)
|
||||||
|
from vllm.model_executor.layers.quantization.kernels.mixed_precision.dynamic_4bit import ( # noqa: E501
|
||||||
|
Dynamic4bitLinearKernel)
|
||||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import ( # noqa: E501
|
from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import ( # noqa: E501
|
||||||
ExllamaLinearKernel)
|
ExllamaLinearKernel)
|
||||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision.machete import ( # noqa: E501
|
from vllm.model_executor.layers.quantization.kernels.mixed_precision.machete import ( # noqa: E501
|
||||||
@ -25,6 +27,7 @@ _POSSIBLE_KERNELS: list[type[MPLinearKernel]] = [
|
|||||||
MacheteLinearKernel,
|
MacheteLinearKernel,
|
||||||
AllSparkLinearKernel,
|
AllSparkLinearKernel,
|
||||||
MarlinLinearKernel,
|
MarlinLinearKernel,
|
||||||
|
Dynamic4bitLinearKernel,
|
||||||
BitBLASLinearKernel,
|
BitBLASLinearKernel,
|
||||||
ConchLinearKernel,
|
ConchLinearKernel,
|
||||||
ExllamaLinearKernel,
|
ExllamaLinearKernel,
|
||||||
@ -56,7 +59,8 @@ def choose_mp_linear_kernel(
|
|||||||
if current_platform is None:
|
if current_platform is None:
|
||||||
raise ValueError("Cannot determine compute capability")
|
raise ValueError("Cannot determine compute capability")
|
||||||
_cc = current_platform.get_device_capability()
|
_cc = current_platform.get_device_capability()
|
||||||
compute_capability = _cc[0] * 10 + _cc[1]
|
if _cc is not None:
|
||||||
|
compute_capability = _cc[0] * 10 + _cc[1]
|
||||||
|
|
||||||
failure_reasons = []
|
failure_reasons = []
|
||||||
for kernel in _POSSIBLE_KERNELS:
|
for kernel in _POSSIBLE_KERNELS:
|
||||||
@ -64,12 +68,12 @@ def choose_mp_linear_kernel(
|
|||||||
failure_reasons.append(
|
failure_reasons.append(
|
||||||
f' {kernel.__name__} disabled by environment variable')
|
f' {kernel.__name__} disabled by environment variable')
|
||||||
continue
|
continue
|
||||||
|
if (compute_capability is not None
|
||||||
if kernel.get_min_capability() > compute_capability:
|
and kernel.get_min_capability() > compute_capability):
|
||||||
failure_reasons.append(
|
failure_reasons.append(
|
||||||
f"{kernel.__name__} requires capability "
|
f"{kernel.__name__} requires capability "
|
||||||
f"{kernel.get_min_capability()}, current compute capability "
|
f"{kernel.get_min_capability()}, current compute "
|
||||||
f"is {compute_capability}")
|
f" capability is {compute_capability}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
can_implement, failure_reason = kernel.can_implement(config)
|
can_implement, failure_reason = kernel.can_implement(config)
|
||||||
|
|||||||
@ -0,0 +1,92 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||||
|
from vllm.platforms import CpuArchEnum, current_platform
|
||||||
|
from vllm.scalar_type import scalar_types
|
||||||
|
|
||||||
|
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
|
||||||
|
|
||||||
|
|
||||||
|
class Dynamic4bitLinearKernel(MPLinearKernel):
    """CPU linear kernel for int4 weights using the aten 4-bit dyn-quant ops.

    Weights are repacked once after loading with
    ``torch.ops.aten._dyn_quant_pack_4bit_weight`` and the forward pass
    calls ``torch.ops.aten._dyn_quant_matmul_4bit``.
    """

    SUPPORTED_QUANT_TYPES = [scalar_types.int4]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability required (``1``)."""
        return 1

    @classmethod
    def can_implement(cls,
                      c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
        """Check whether this kernel can serve the given layer config.

        Returns:
            ``(True, None)`` when supported, otherwise ``(False, reason)``.
        """
        if not current_platform.is_cpu():
            return False, "Only CPU is supported"
        if c.weight_type not in cls.SUPPORTED_QUANT_TYPES:
            return False, f"Unsupported quant type {c.weight_type}"

        is_arm = current_platform.get_cpu_architecture() == CpuArchEnum.ARM
        if is_arm and c.act_type != torch.float32:
            return False, "Dynamic4bitLinearKernel on Arm requires"\
                          " Float32 activations"
        if c.full_weight_shape[0] % c.group_size != 0:
            return False, f"Group size ({c.group_size}) does not evenly divide"\
                          " the number of input features "\
                          f"({c.full_weight_shape[0]})"
        if is_arm:
            try:
                # Attribute lookup fails when the installed PyTorch does
                # not provide the operator.
                _ = torch.ops.aten._dyn_quant_matmul_4bit
            except AttributeError:
                return False, f"PyTorch {torch.__version__} does not support"\
                              " _dyn_quant_matmul_4bit. Install a newer version"
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module):
        """Repack the loaded weights, scales and bias for the 4-bit op.

        The quantized-weight parameter is replaced by the packed tensor
        returned from ``_dyn_quant_pack_4bit_weight`` and the scale
        parameter is set to ``None`` afterwards.
        """
        c = self.config

        # Shift every value by +8 before packing.
        # NOTE(review): assumes one signed int4 value in [-8, 7] per
        # element, so the shifted values fit in a nibble -- confirm.
        shifted = getattr(layer, self.w_q_name).add(8)
        # Two neighbouring columns share one byte: the odd column goes to
        # the high nibble and the even column to the low nibble.
        uint8_packed = ((shifted[::, 1::2] << 4)
                        | shifted[::, ::2]).to(torch.uint8)

        scales = getattr(layer, self.w_s_name)
        block_size = c.group_size

        # Handle scaling factors for partitioned weights
        if block_size == c.partition_weight_shape[0]:
            # Float32 & Bfloat16 variants require float32 scales
            scales = scales.to(torch.float32)
            scales = scales.view(-1, 1)  # Channel-wise scales
            bias = layer.bias
            if bias is not None and bias.dtype != torch.float32:
                # Float32 & Bfloat16 variants require float32 bias.
                # Converting a Parameter yields a plain Tensor, which
                # nn.Module refuses to store in a registered parameter
                # slot, so wrap it again when the bias was a Parameter.
                bias_fp32 = bias.to(torch.float32)
                if isinstance(bias, torch.nn.Parameter):
                    bias_fp32 = torch.nn.Parameter(bias_fp32,
                                                   requires_grad=False)
                layer.bias = bias_fp32
        else:
            # KleidiAI kernel requires bfloat16 scales with groupwise scheme
            scales = scales.to(torch.bfloat16)

        # Repack weights as per kernel requirement
        w = torch.ops.aten._dyn_quant_pack_4bit_weight(
            uint8_packed, scales, layer.bias, block_size,
            c.partition_weight_shape[0], c.partition_weight_shape[1])
        replace_parameter(layer, self.w_q_name,
                          torch.nn.Parameter(w, requires_grad=False))
        setattr(layer, self.w_s_name, None)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Multiply ``x`` by the packed weights of ``layer``.

        ``x`` is flattened to two dimensions for the op and the result is
        reshaped to the leading dimensions of ``x`` plus the partition's
        output size.

        ``bias`` is not read here; ``layer.bias`` was already passed to
        the packing op in ``process_weights_after_loading``.
        NOTE(review): presumably the packed tensor carries the bias --
        confirm against the operator's documentation.
        """
        c = self.config
        x_2d = x.reshape(-1, x.shape[-1])
        out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )

        w_q = getattr(layer, self.w_q_name)
        output = torch.ops.aten._dyn_quant_matmul_4bit(
            x_2d, w_q, c.group_size, c.partition_weight_shape[0],
            c.partition_weight_shape[1])
        return output.reshape(out_shape)
|
||||||
Loading…
x
Reference in New Issue
Block a user