Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-14 16:27:27 +08:00)
[Misc] GPTQ Activation Ordering (#8135)
commit c7cb5c3335
parent f9b4a2d415
@@ -21,6 +21,7 @@ compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
 compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
 compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main
 compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
+compressed-tensors, nm-testing/TinyLlama-1.1B-Chat-v1.0-actorder-group, main
 awq, casperhansen/mixtral-instruct-awq, main
 awq_marlin, casperhansen/mixtral-instruct-awq, main
 fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main
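The added entry above gives the weight-loading tests a GPTQ act-order (group) checkpoint. A minimal usage sketch (prompt and sampling settings are illustrative; the quantization scheme, including activation ordering, is picked up automatically from the checkpoint's compressed-tensors config):

from vllm import LLM, SamplingParams

# Load the act-order test checkpoint listed above; no extra quantization flags
# are needed because the scheme is read from the model's config.
llm = LLM(model="nm-testing/TinyLlama-1.1B-Chat-v1.0-actorder-group")

outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)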
@@ -232,7 +232,8 @@ class CompressedTensorsConfig(QuantizationConfig):
             return CompressedTensorsWNA16(
                 num_bits=weight_quant.num_bits,
                 strategy=weight_quant.strategy,
-                group_size=weight_quant.group_size)
+                group_size=weight_quant.group_size,
+                actorder=weight_quant.actorder)

         # Detect If Activation Quantization.
         # TODO @dsikka: clean-up conditions
@@ -5,14 +5,18 @@ import torch
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
+from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
+    ActivationOrdering)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     apply_gptq_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace,
-    marlin_permute_scales, replace_tensor, verify_marlin_supported,
+    marlin_permute_scales, marlin_repeat_scales_on_all_ranks,
+    marlin_sort_g_idx, replace_tensor, verify_marlin_supported,
     verify_marlin_supports_shape)
 from vllm.model_executor.parameter import (BasevLLMParameter,
                                            ChannelQuantScaleParameter,
                                            GroupQuantScaleParameter,
-                                           PackedvLLMParameter)
+                                           PackedvLLMParameter,
+                                           RowvLLMParameter)
 from vllm.scalar_type import scalar_types

 __all__ = ["CompressedTensorsWNA16"]
@@ -28,11 +32,13 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
     def __init__(self,
                  strategy: str,
                  num_bits: int,
-                 group_size: Optional[int] = None):
+                 group_size: Optional[int] = None,
+                 actorder: Optional[ActivationOrdering] = None):

         self.pack_factor = 32 // num_bits
         self.strategy = strategy
         self.group_size = -1 if group_size is None else group_size
+        self.has_g_idx = actorder == ActivationOrdering.GROUP

         if self.group_size == -1 and self.strategy != "channel":
             raise ValueError("Marlin kernels require group quantization or "
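A quick sketch of the constructor change above (values are illustrative, and a CUDA-capable vLLM install is assumed since the scheme verifies Marlin support at construction). Only ActivationOrdering.GROUP sets has_g_idx; weight-only reordering is baked into the checkpoint and needs no group index at load time:

from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
    CompressedTensorsWNA16)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
    ActivationOrdering)

# Illustrative: 4-bit grouped weights with group activation ordering.
scheme = CompressedTensorsWNA16(strategy="group",
                                num_bits=4,
                                group_size=128,
                                actorder=ActivationOrdering.GROUP)
assert scheme.has_g_idx         # GROUP ordering introduces a group index
assert scheme.pack_factor == 8  # 32 // num_bits

# WEIGHT-only reordering keeps no g_idx at runtime.
assert not CompressedTensorsWNA16(strategy="group",
                                  num_bits=4,
                                  group_size=128,
                                  actorder=ActivationOrdering.WEIGHT).has_g_idx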
@@ -64,12 +70,10 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
         output_size_per_partition = sum(output_partition_sizes)

         # If group_size is -1, we are in channelwise case.
-        channelwise = (self.group_size == -1)
         group_size = self.group_size if self.group_size != -1 else input_size
         row_parallel = (input_size != input_size_per_partition)
-        # In the case of channelwise quantization, we need to replicate the
-        # scales across all gpus.
-        partition_scales = (row_parallel and not channelwise)
+        partition_scales = not marlin_repeat_scales_on_all_ranks(
+            self.has_g_idx, self.group_size, row_parallel)

         verify_marlin_supports_shape(
             output_size_per_partition=output_size_per_partition,
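The helper marlin_repeat_scales_on_all_ranks now decides whether group scales are sharded or replicated across tensor-parallel ranks. A sketch of the decision it encodes, inferred from how it is used here rather than copied from marlin_utils:

def repeat_scales_on_all_ranks_sketch(act_order: bool, group_size: int,
                                      is_row_parallel: bool) -> bool:
    # With activation ordering, g_idx may route any input row to any group,
    # so every rank must hold the full set of group scales. Channelwise
    # (group_size == -1) row-parallel layers likewise replicate their scales.
    is_channelwise = group_size == -1
    return act_order or (is_channelwise and is_row_parallel)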
@@ -123,6 +127,16 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
         layer.register_parameter("weight_scale", weight_scale)
         layer.register_parameter("weight_shape", weight_shape)

+        # group index (for activation reordering)
+        if self.has_g_idx:
+            weight_g_idx = RowvLLMParameter(data=torch.empty(
+                input_size_per_partition,
+                dtype=torch.int32,
+            ),
+                                            input_dim=0,
+                                            weight_loader=weight_loader)
+            layer.register_parameter("weight_g_idx", weight_g_idx)
+
         layer.input_size_per_partition = input_size_per_partition
         layer.output_size_per_partition = output_size_per_partition
         layer.input_size = input_size
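The weight_g_idx parameter registered above stores, for each input row of this partition, the quantization group that row belongs to. A small illustration with made-up sizes; without reordering the mapping is simply row // group_size, while actorder="group" ships a permuted mapping in the checkpoint (groups formed in decreasing order of activation, per the actorder docstring further down):

import torch

input_size_per_partition, group_size = 1024, 128  # illustrative sizes

# Default (no activation reordering): row i belongs to group i // group_size.
default_g_idx = torch.arange(input_size_per_partition,
                             dtype=torch.int32) // group_size
# With actorder="group" the checkpoint instead provides a permuted g_idx
# holding the same group ids, so consecutive rows need not share a group.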
@@ -137,8 +151,13 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
         layer.workspace = marlin_make_workspace(
             layer.output_size_per_partition, device)

-        # Act-order not supported in compressed-tensors yet, so set to empty.
-        layer.g_idx = marlin_make_empty_g_idx(device)
-        layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
+        # Handle sorting for activation reordering if needed.
+        if self.has_g_idx:
+            g_idx, g_idx_sort_indices = marlin_sort_g_idx(layer.weight_g_idx)
+            layer.g_idx_sort_indices = g_idx_sort_indices
+            replace_tensor(layer, "weight_g_idx", g_idx)
+        else:
+            layer.weight_g_idx = marlin_make_empty_g_idx(device)
+            layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)

         # No zero-point
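marlin_sort_g_idx (from marlin_utils) makes the loaded group index contiguous per group and returns the permutation, which the kernel later uses to reorder activations to match the repacked weights. A behavioural sketch, not necessarily the verbatim implementation:

import torch

def sort_g_idx_sketch(g_idx: torch.Tensor):
    # Sort rows by group id and keep the permutation that achieves it.
    g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
    return g_idx[g_idx_sort_indices], g_idx_sort_indices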
@@ -159,9 +178,11 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
         replace_tensor(layer, "weight_packed", marlin_qweight)

         # Permute scales from compressed-tensors format to marlin format.
+        # scale is required on all partitions if activation reordering
         marlin_scales = marlin_permute_scales(
             layer.weight_scale,
-            size_k=layer.input_size_per_partition,
+            size_k=(layer.input_size
+                    if self.has_g_idx else layer.input_size_per_partition),
             size_n=layer.output_size_per_partition,
             group_size=layer.group_size)
         replace_tensor(layer, "weight_scale", marlin_scales)
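The size_k change above reflects that, with activation reordering, each rank permutes scales for the full (unpartitioned) K dimension, since its local rows may fall into any group. Counting groups with made-up sizes:

input_size, input_size_per_partition, group_size = 4096, 2048, 128  # illustrative

# With g_idx: every rank keeps scales for all groups along K (replicated).
num_groups_with_g_idx = input_size // group_size                    # 32
# Without g_idx: a row-parallel rank only needs groups for its own K shard.
num_groups_without_g_idx = input_size_per_partition // group_size   # 16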
@@ -174,7 +195,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
             weight=layer.weight_packed,
             weight_scale=layer.weight_scale,
             weight_zp=layer.weight_zp,
-            g_idx=layer.g_idx,
+            g_idx=layer.weight_g_idx,
             g_idx_sort_indices=layer.g_idx_sort_indices,
             workspace=layer.workspace,
             wtype=self.quant_type,
@@ -1,8 +1,8 @@
 import re
 from enum import Enum
-from typing import Any, Dict, Iterable, Optional
+from typing import Any, Dict, Iterable, Optional, Union

-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, field_validator
 from torch.nn import Module

 from vllm.model_executor.layers.quantization.utils.quant_utils import (
@@ -40,6 +40,19 @@ class QuantizationStrategy(str, Enum):
     TOKEN = "token"


+class ActivationOrdering(str, Enum):
+    """
+    Enum storing strategies for activation ordering
+
+    Group: reorder groups and weight\n
+    Weight: only reorder weight, not groups. Slightly lower latency and
+    accuracy compared to group actorder\n
+    """
+
+    GROUP = "group"
+    WEIGHT = "weight"
+
+
 class QuantizationArgs(BaseModel):
     """
     User facing arguments used to define a quantization config
@@ -58,6 +71,8 @@ class QuantizationArgs(BaseModel):
         observed with every sample. Defaults to False for static
         quantization. Note that enabling dynamic quantization
         will change the default observer to a memoryless one
+    :param actorder: whether to apply group quantization in decreasing order of
+        activation. Defaults to None for arbitrary ordering
     """

     num_bits: int = 8
@@ -67,6 +82,7 @@ class QuantizationArgs(BaseModel):
     strategy: Optional[QuantizationStrategy] = None
     block_structure: Optional[str] = None
     dynamic: bool = False
+    actorder: Union[ActivationOrdering, bool, None] = None
     observer: str = Field(
         default="minmax",
         description=("The class to use to compute the quantization param - "
@@ -79,6 +95,16 @@ class QuantizationArgs(BaseModel):
         "Observers constructor excluding quantization range or symmetry"),
     )

+    @field_validator("actorder", mode="before")
+    def validate_actorder(cls, value) -> Optional[ActivationOrdering]:
+        if isinstance(value, bool):
+            return ActivationOrdering.GROUP if value else None
+
+        if isinstance(value, str):
+            return ActivationOrdering(value.lower())
+
+        return value
+
+
 def is_activation_quantization_format(format: str) -> bool:
     _ACTIVATION_QUANTIZATION_FORMATS = [
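Because validate_actorder runs in "before" mode, checkpoint configs may spell actorder as a boolean, a (case-insensitive) string, or the enum itself; every form normalizes to an ActivationOrdering member or None. Assuming the remaining QuantizationArgs fields keep their defaults:

from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
    ActivationOrdering, QuantizationArgs)

assert QuantizationArgs(actorder=True).actorder is ActivationOrdering.GROUP
assert QuantizationArgs(actorder=False).actorder is None
assert QuantizationArgs(actorder="GROUP").actorder is ActivationOrdering.GROUP
assert QuantizationArgs(actorder="weight").actorder is ActivationOrdering.WEIGHT
assert QuantizationArgs().actorder is None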