[Misc] GPTQ Activation Ordering (#8135)

Kyle Sayers 2024-09-09 16:27:26 -04:00 committed by GitHub
parent f9b4a2d415
commit c7cb5c3335
4 changed files with 64 additions and 15 deletions

View File

@@ -21,6 +21,7 @@ compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
 compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
 compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main
 compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
+compressed-tensors, nm-testing/TinyLlama-1.1B-Chat-v1.0-actorder-group, main
 awq, casperhansen/mixtral-instruct-awq, main
 awq_marlin, casperhansen/mixtral-instruct-awq, main
 fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main

View File

@@ -232,7 +232,8 @@ class CompressedTensorsConfig(QuantizationConfig):
             return CompressedTensorsWNA16(
                 num_bits=weight_quant.num_bits,
                 strategy=weight_quant.strategy,
-                group_size=weight_quant.group_size)
+                group_size=weight_quant.group_size,
+                actorder=weight_quant.actorder)
 
         # Detect If Activation Quantization.
         # TODO @dsikka: clean-up conditions
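
The `actorder` value parsed from the checkpoint's quantization config is now forwarded into the WNA16 scheme. As a purely illustrative sketch (the surrounding keys and values are assumptions about a typical compressed-tensors weight-quantization entry, not part of this commit), the relevant "weights" section of such a config could look like:

    # Illustrative only: a compressed-tensors "weights" entry that would
    # select group activation ordering for a 4-bit, group-size-128 checkpoint.
    example_weight_quant = {
        "num_bits": 4,
        "type": "int",
        "symmetric": True,
        "strategy": "group",
        "group_size": 128,
        "actorder": "group",  # also accepts True, "weight", or None (see utils.py diff below)
    }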

View File

@@ -5,14 +5,18 @@ import torch
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
+from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
+    ActivationOrdering)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     apply_gptq_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace,
-    marlin_permute_scales, replace_tensor, verify_marlin_supported,
+    marlin_permute_scales, marlin_repeat_scales_on_all_ranks,
+    marlin_sort_g_idx, replace_tensor, verify_marlin_supported,
     verify_marlin_supports_shape)
 from vllm.model_executor.parameter import (BasevLLMParameter,
                                            ChannelQuantScaleParameter,
                                            GroupQuantScaleParameter,
-                                           PackedvLLMParameter)
+                                           PackedvLLMParameter,
+                                           RowvLLMParameter)
 from vllm.scalar_type import scalar_types
 
 __all__ = ["CompressedTensorsWNA16"]
@@ -28,11 +32,13 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
     def __init__(self,
                  strategy: str,
                  num_bits: int,
-                 group_size: Optional[int] = None):
+                 group_size: Optional[int] = None,
+                 actorder: Optional[ActivationOrdering] = None):
 
         self.pack_factor = 32 // num_bits
         self.strategy = strategy
         self.group_size = -1 if group_size is None else group_size
+        self.has_g_idx = actorder == ActivationOrdering.GROUP
 
         if self.group_size == -1 and self.strategy != "channel":
             raise ValueError("Marlin kernels require group quantization or "
@@ -64,12 +70,10 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
         output_size_per_partition = sum(output_partition_sizes)
 
         # If group_size is -1, we are in channelwise case.
         channelwise = (self.group_size == -1)
-        group_size = self.group_size if self.group_size != -1 else input_size
         row_parallel = (input_size != input_size_per_partition)
-        # In the case of channelwise quantization, we need to replicate the
-        # scales across all gpus.
-        partition_scales = (row_parallel and not channelwise)
+        partition_scales = not marlin_repeat_scales_on_all_ranks(
+            self.has_g_idx, self.group_size, row_parallel)
 
         verify_marlin_supports_shape(
             output_size_per_partition=output_size_per_partition,
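
`marlin_repeat_scales_on_all_ranks` centralizes the decision that the inline `partition_scales` expression used to make: with activation reordering, the `g_idx` permutation can send any local input channel into any quantization group, so every tensor-parallel rank needs the full set of group scales. A minimal sketch of the helper's logic, assuming it takes the three arguments shown in the call above:

    def marlin_repeat_scales_on_all_ranks(act_order: bool, group_size: int,
                                          is_row_parallel: bool) -> bool:
        # Replicate scales on every rank when activation ordering is used,
        # or when channelwise scales (group_size == -1) meet a row-parallel
        # (input-partitioned) layer.
        is_channelwise = group_size == -1
        return act_order or (is_channelwise and is_row_parallel)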
@@ -123,6 +127,16 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
         layer.register_parameter("weight_scale", weight_scale)
         layer.register_parameter("weight_shape", weight_shape)
 
+        # group index (for activation reordering)
+        if self.has_g_idx:
+            weight_g_idx = RowvLLMParameter(data=torch.empty(
+                input_size_per_partition,
+                dtype=torch.int32,
+            ),
+                                            input_dim=0,
+                                            weight_loader=weight_loader)
+            layer.register_parameter("weight_g_idx", weight_g_idx)
+
         layer.input_size_per_partition = input_size_per_partition
         layer.output_size_per_partition = output_size_per_partition
         layer.input_size = input_size
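
`weight_g_idx` stores, for every input channel of this partition, the quantization group that channel belongs to. With GPTQ activation ordering the channels were quantized in decreasing order of activation magnitude, so group membership is no longer the contiguous `arange(k) // group_size` pattern. A toy illustration of how such a map arises (not vLLM code; the names are made up):

    import torch

    # With act-order, channels are grouped in the order they were quantized
    # (highest-activation channels first), so g_idx is a scrambled
    # group-membership map rather than a monotonic one.
    k, group_size = 8, 4
    act_magnitude = torch.tensor([0.1, 2.0, 0.5, 3.0, 0.2, 1.5, 0.9, 0.05])
    perm = torch.argsort(act_magnitude, descending=True)  # quantization order
    g_idx = torch.empty(k, dtype=torch.int32)
    g_idx[perm] = torch.arange(k, dtype=torch.int32) // group_size
    # g_idx[c] == group of input channel c; no longer monotonic in c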
@@ -137,9 +151,14 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
         layer.workspace = marlin_make_workspace(
             layer.output_size_per_partition, device)
 
-        # Act-order not supported in compressed-tensors yet, so set to empty.
-        layer.g_idx = marlin_make_empty_g_idx(device)
-        layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
+        # Handle sorting for activation reordering if needed.
+        if self.has_g_idx:
+            g_idx, g_idx_sort_indices = marlin_sort_g_idx(layer.weight_g_idx)
+            layer.g_idx_sort_indices = g_idx_sort_indices
+            replace_tensor(layer, "weight_g_idx", g_idx)
+        else:
+            layer.weight_g_idx = marlin_make_empty_g_idx(device)
+            layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
 
         # No zero-point
         layer.weight_zp = marlin_make_empty_g_idx(device)
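
The Marlin kernel wants same-group channels laid out contiguously, so the loaded `weight_g_idx` is sorted once at load time and the sort permutation is kept so the kernel can gather activations in the matching order. A sketch of what `marlin_sort_g_idx` presumably does (an argsort; not guaranteed to match the exact implementation):

    import torch

    def marlin_sort_g_idx(g_idx: torch.Tensor):
        # Order channels so channels of the same group become contiguous,
        # and return the permutation used so inputs can be gathered to match.
        g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
        return g_idx[g_idx_sort_indices], g_idx_sort_indices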
@@ -159,9 +178,11 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
         replace_tensor(layer, "weight_packed", marlin_qweight)
 
         # Permute scales from compressed-tensors format to marlin format.
+        # scale is required on all partitions if activation reordering
         marlin_scales = marlin_permute_scales(
             layer.weight_scale,
-            size_k=layer.input_size_per_partition,
+            size_k=(layer.input_size
+                    if self.has_g_idx else layer.input_size_per_partition),
             size_n=layer.output_size_per_partition,
             group_size=layer.group_size)
         replace_tensor(layer, "weight_scale", marlin_scales)
@@ -174,7 +195,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
             weight=layer.weight_packed,
             weight_scale=layer.weight_scale,
             weight_zp=layer.weight_zp,
-            g_idx=layer.g_idx,
+            g_idx=layer.weight_g_idx,
             g_idx_sort_indices=layer.g_idx_sort_indices,
             workspace=layer.workspace,
             wtype=self.quant_type,

View File

@@ -1,8 +1,8 @@
 import re
 from enum import Enum
-from typing import Any, Dict, Iterable, Optional
+from typing import Any, Dict, Iterable, Optional, Union
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, field_validator
 from torch.nn import Module
 
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
@@ -40,6 +40,19 @@ class QuantizationStrategy(str, Enum):
     TOKEN = "token"
 
 
+class ActivationOrdering(str, Enum):
+    """
+    Enum storing strategies for activation ordering
+
+    Group: reorder groups and weight\n
+    Weight: only reorder weight, not groups. Slightly lower latency and
+    accuracy compared to group actorder\n
+    """
+
+    GROUP = "group"
+    WEIGHT = "weight"
+
+
 class QuantizationArgs(BaseModel):
     """
     User facing arguments used to define a quantization config
@@ -58,6 +71,8 @@ class QuantizationArgs(BaseModel):
         observed with every sample. Defaults to False for static
         quantization. Note that enabling dynamic quantization
         will change the default observer to a memoryless one
+    :param actorder: whether to apply group quantization in decreasing order of
+        activation. Defaults to None for arbitrary ordering
     """
 
     num_bits: int = 8
@@ -67,6 +82,7 @@ class QuantizationArgs(BaseModel):
     strategy: Optional[QuantizationStrategy] = None
     block_structure: Optional[str] = None
     dynamic: bool = False
+    actorder: Union[ActivationOrdering, bool, None] = None
     observer: str = Field(
         default="minmax",
         description=("The class to use to compute the quantization param - "
@@ -79,6 +95,16 @@ class QuantizationArgs(BaseModel):
         "Observers constructor excluding quantization range or symmetry"),
     )
 
+    @field_validator("actorder", mode="before")
+    def validate_actorder(cls, value) -> Optional[ActivationOrdering]:
+        if isinstance(value, bool):
+            return ActivationOrdering.GROUP if value else None
+
+        if isinstance(value, str):
+            return ActivationOrdering(value.lower())
+
+        return value
+
 
 def is_activation_quantization_format(format: str) -> bool:
     _ACTIVATION_QUANTIZATION_FORMATS = [
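
The `mode="before"` validator above normalizes the looser user-facing values into the enum before pydantic's field validation runs. Roughly, assuming `QuantizationArgs` is constructed directly with the fields shown in this diff:

    # Sketch of the normalization behavior of validate_actorder.
    QuantizationArgs(num_bits=4, strategy="group", group_size=128,
                     actorder=True).actorder       # ActivationOrdering.GROUP
    QuantizationArgs(num_bits=4, strategy="group", group_size=128,
                     actorder="weight").actorder   # ActivationOrdering.WEIGHT
    QuantizationArgs(num_bits=4, strategy="group", group_size=128,
                     actorder=False).actorder      # None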