Russell Bryant e489ad7a21
[Misc] Add SPDX-License-Identifier headers to python source files (#12628)
- **Add SPDX license headers to python source files**
- **Check for SPDX headers using pre-commit**

commit 9d7ef44c3cfb72ca4c32e1c677d99259d10d4745
Author: Russell Bryant <rbryant@redhat.com>
Date:   Fri Jan 31 14:18:24 2025 -0500

    Add SPDX license headers to python source files
    
This commit adds SPDX license headers to python source files as
recommended to
the project by the Linux Foundation. These headers provide a concise way
that is
both human and machine readable for communicating license information
for each
source file. It helps avoid any ambiguity about the license of the code
and can
    also be easily used by tools to help manage license compliance.
    
The Linux Foundation runs license scans against the codebase to help
ensure
    we are in compliance with the licenses of the code we use, including
dependencies. Having these headers in place helps that tool do its job.
    
    More information can be found on the SPDX site:
    
    - https://spdx.dev/learn/handling-license-info/
    
    Signed-off-by: Russell Bryant <rbryant@redhat.com>

commit 5a1cf1cb3b80759131c73f6a9dddebccac039dea
Author: Russell Bryant <rbryant@redhat.com>
Date:   Fri Jan 31 14:36:32 2025 -0500

    Check for SPDX headers using pre-commit
    
    Signed-off-by: Russell Bryant <rbryant@redhat.com>

---------

Signed-off-by: Russell Bryant <rbryant@redhat.com>
2025-02-02 11:58:18 -08:00

295 lines
11 KiB
Python

# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedvLLMParameter)
from vllm.scalar_type import scalar_types
logger = init_logger(__name__)
GPTQ_MARLIN_24_TILE = 16
GPTQ_MARLIN_24_MIN_THREAD_N = 128
GPTQ_MARLIN_24_MIN_THREAD_K = 128
GPTQ_MARLIN_24_MAX_PARALLEL = 64
GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [
scalar_types.uint4b8, scalar_types.uint8b128
]
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]
class GPTQMarlin24Config(QuantizationConfig):
"""Config class for Marlin24.
"""
def __init__(
self,
weight_bits: int,
group_size: int,
) -> None:
quant_type = {
4: scalar_types.uint4b8,
8: scalar_types.uint8b128,
}.get(weight_bits)
self.group_size = group_size
# Verify
if quant_type is None or \
quant_type not in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES:
raise ValueError(
f"Marlin_24 does not support quant_type = {quant_type}. "
f"Only weight_bits = {GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES} "
"are supported.")
if self.group_size not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES:
raise ValueError(
f"Marlin_24 does not support group_size = {self.group_size}. "
f"Only group_sizes = {GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES} "
"are supported.")
self.quant_type = quant_type
# 4 Bits packed into 32 bit datatype.
self.pack_factor = 32 // self.quant_type.size_bits
# Tile size used by marlin kernels.
self.tile_size = 16
# Min out_features dim
self.min_n_threads = GPTQ_MARLIN_24_MIN_THREAD_N
# Min in_features dim
self.min_k_threads = GPTQ_MARLIN_24_MIN_THREAD_K
# Max parallel problems to solve at once (improves large
# batch performance)
self.max_parallel = GPTQ_MARLIN_24_MAX_PARALLEL
# Permutation length used by the marlin kernels.
self.perm_len = 1024
def __repr__(self) -> str:
return "Marlin24Config(quant_type={}, group_size={})".format(
self.quant_type, self.group_size)
@classmethod
def get_name(cls) -> str:
return "gptq_marlin_24"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half]
@classmethod
# Need to figure it out
def get_min_capability(cls) -> int:
return 80
@classmethod
def get_config_filenames(cls) -> List[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlin24Config":
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
return cls(weight_bits, group_size)
@classmethod
def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]:
is_marlin_24_format = (
hf_quant_cfg.get("checkpoint_format") == "marlin_24")
is_valid_user_quant = (user_quant is None or user_quant == "gptq"
or user_quant == "gptq_marlin_24")
if is_marlin_24_format and is_valid_user_quant:
msg = ("The model is serialized in {} format. "
"Using {} kernel.".format(cls.get_name(), cls.get_name()))
logger.info(msg)
return cls.get_name()
return None
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["GPTQMarlin24LinearMethod"]:
if isinstance(layer, LinearBase):
return GPTQMarlin24LinearMethod(self)
return None
class GPTQMarlin24LinearMethod(LinearMethodBase):
"""Linear method for Marlin24.
Args:
quant_config: The Marlin24 quantization config.
"""
def __init__(self, quant_config: GPTQMarlin24Config):
self.quant_config = quant_config
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
del output_size # Unused.
weight_loader = extra_weight_attrs["weight_loader"]
if params_dtype != torch.float16:
raise ValueError(
f"The params dtype must be float16, but got {params_dtype}")
# Validate output_size_per_partition
output_size_per_partition = sum(output_partition_sizes)
if output_size_per_partition % self.quant_config.min_n_threads != 0:
raise ValueError(
f"Weight output_size_per_partition = "
f"{output_size_per_partition} is not divisible by "
f"min_n_threads = {self.quant_config.min_n_threads}.")
if output_size_per_partition % self.quant_config.pack_factor != 0:
raise ValueError(
f"Weight output_size_per_partition = "
f"{output_size_per_partition} is not divisible by "
f"pack_factor = {self.quant_config.pack_factor}.")
# Validate input_size_per_partition
if input_size_per_partition % self.quant_config.min_k_threads != 0:
raise ValueError(
f"Weight input_size_per_partition = "
f"{input_size_per_partition} is not divisible by "
f"min_k_threads = {self.quant_config.min_k_threads}.")
if (self.quant_config.group_size != -1 and
input_size_per_partition % self.quant_config.group_size != 0):
raise ValueError(f"Weight input_size_per_partition = "
f"{input_size_per_partition} is not divisible by "
f"group_size = {self.quant_config.group_size}.")
# Check that we have at least 4 tiles horizontally in the shard
num_tiles_per_perm = self.quant_config.perm_len // (
self.quant_config.tile_size**2)
if output_size_per_partition % num_tiles_per_perm != 0:
raise ValueError(
"Each permutation group must reside on the same gpu")
# Quantized 4Bit weights packed into Int32.
qweight = PackedvLLMParameter(
data=torch.empty(
input_size_per_partition // self.quant_config.tile_size // 2,
output_size_per_partition * self.quant_config.tile_size //
self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32,
),
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
marlin_tile_size=self.quant_config.tile_size,
weight_loader=weight_loader)
# Meta
meta = PackedvLLMParameter(data=torch.empty(
input_size_per_partition // 8 // 2 // 2,
output_size_per_partition * 2,
device="cuda",
dtype=torch.int16,
),
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=1,
marlin_tile_size=2,
weight_loader=weight_loader)
# Determine if channelwise or not
input_groups = (1 if self.quant_config.group_size == -1 else
input_size_per_partition //
self.quant_config.group_size)
weight_scale_args = {
"data":
torch.empty(
input_groups,
output_size_per_partition,
device="cuda",
dtype=params_dtype,
),
"weight_loader":
weight_loader
}
if input_groups == 1:
scales = ChannelQuantScaleParameter(output_dim=1,
**weight_scale_args)
else:
scales = GroupQuantScaleParameter(output_dim=1,
input_dim=0,
**weight_scale_args)
# Allocate workspace (Used for internal locking mechanism)
max_workspace_size = (
output_size_per_partition //
self.quant_config.min_n_threads) * self.quant_config.max_parallel
workspace = BasevLLMParameter(data=torch.zeros(max_workspace_size,
device="cuda",
dtype=torch.int),
weight_loader=weight_loader)
layer.register_parameter("B_24", qweight)
layer.register_parameter("B_meta", meta)
layer.register_parameter("s", scales)
layer.register_parameter("workspace", workspace)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# required by torch.compile
layer.B_24 = Parameter(layer.B_24.data, requires_grad=False)
layer.s = Parameter(layer.s.data, requires_grad=False)
layer.B_meta = Parameter(layer.B_meta.data, requires_grad=False)
layer.workspace = Parameter(layer.workspace.data, requires_grad=False)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
qweight = layer.B_24
meta = layer.B_meta
scales = layer.s
workspace = layer.workspace
x_2d = x.view(-1, x.shape[-1])
size_m = x_2d.shape[0]
size_k = x_2d.shape[1]
size_n = scales.shape[1]
output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales,
workspace,
self.quant_config.quant_type,
size_m, size_n, size_k)
output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))
if bias is not None:
output.add_(bias) # In-place add
return output