mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-14 04:07:11 +08:00
- **Add SPDX license headers to python source files**
- **Check for SPDX headers using pre-commit**
commit 9d7ef44c3cfb72ca4c32e1c677d99259d10d4745
Author: Russell Bryant <rbryant@redhat.com>
Date: Fri Jan 31 14:18:24 2025 -0500
Add SPDX license headers to python source files
This commit adds SPDX license headers to Python source files, as recommended
to the project by the Linux Foundation. These headers provide a concise,
human- and machine-readable way to communicate the license information for
each source file. They help avoid any ambiguity about the license of the code
and can also easily be used by tools that help manage license compliance.
The Linux Foundation runs license scans against the codebase to help ensure
that we comply with the licenses of the code we use, including dependencies.
Having these headers in place helps that tool do its job.
More information can be found on the SPDX site:
- https://spdx.dev/learn/handling-license-info/
Signed-off-by: Russell Bryant <rbryant@redhat.com>
commit 5a1cf1cb3b80759131c73f6a9dddebccac039dea
Author: Russell Bryant <rbryant@redhat.com>
Date: Fri Jan 31 14:36:32 2025 -0500
Check for SPDX headers using pre-commit
Signed-off-by: Russell Bryant <rbryant@redhat.com>
---------
Signed-off-by: Russell Bryant <rbryant@redhat.com>
229 lines
8.1 KiB
Python
229 lines
8.1 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
import gguf
|
|
import torch
|
|
from gguf import GGMLQuantizationType as WeightType
|
|
from torch.nn.parameter import Parameter, UninitializedParameter
|
|
|
|
from vllm import _custom_ops as ops
|
|
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
|
from vllm.model_executor.layers.quantization.base_config import (
|
|
QuantizationConfig, QuantizeMethodBase)
|
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|
VocabParallelEmbedding)
|
|
from vllm.model_executor.utils import set_weight_attrs
|
|
|
|
|
|
class GGUFConfig(QuantizationConfig):
    """Config class for GGUF.

    The config itself carries no state: ``from_config`` ignores its input
    and the constructor takes no arguments. Per-layer quantization types are
    tracked on each layer's ``qweight_type`` parameter instead.
    """

    def __init__(self) -> None:
        # Nothing to configure; kept explicit for readability.
        pass

    def __repr__(self) -> str:
        return "GGUFConfig()"

    def get_name(self) -> str:
        return "gguf"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        # Activation dtypes accepted by the GGUF linear/embedding methods.
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        # Minimum device capability reported to the quantization framework.
        return 60

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        # GGUF needs no extra config files beside the checkpoint.
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "GGUFConfig":
        # ``config`` is intentionally ignored: there is nothing to read.
        return cls()

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Return the GGUF method for ``layer``, or ``None`` if unsupported.

        Linear layers are checked before vocab embeddings; every other layer
        type is left unquantized (``None``). ``prefix`` is unused.
        """
        if isinstance(layer, LinearBase):
            return GGUFLinearMethod(self)
        if isinstance(layer, VocabParallelEmbedding):
            return GGUFEmbeddingMethod(self)
        return None
|
|
|
|
|
|
# Weight types stored as plain floating-point tensors (no kernel needed).
UNQUANTIZED_TYPES = {WeightType.F32, WeightType.F16, WeightType.BF16}

# Legacy block quantization formats.
STANDARD_QUANT_TYPES = {
    WeightType.Q4_0, WeightType.Q4_1,
    WeightType.Q5_0, WeightType.Q5_1,
    WeightType.Q8_0, WeightType.Q8_1,
}

# K-quant formats.
KQUANT_TYPES = {
    WeightType.Q2_K, WeightType.Q3_K, WeightType.Q4_K,
    WeightType.Q5_K, WeightType.Q6_K,
}

# Importance-matrix (I-quant) formats.
IMATRIX_QUANT_TYPES = {
    WeightType.IQ1_M, WeightType.IQ1_S,
    WeightType.IQ2_XXS, WeightType.IQ2_XS, WeightType.IQ2_S,
    WeightType.IQ3_XXS, WeightType.IQ3_S,
    WeightType.IQ4_XS, WeightType.IQ4_NL,
}

# Which kernel families can consume which formats. I-quants have no MMQ
# kernel yet, so they are absent from MMQ_QUANT_TYPES and fall back to
# dequantization for multi-row inputs.
# TODO(Isotr0py): Consolidate DEQUANT_TYPES, MMVQ_QUANT_TYPES and
# MMQ_QUANT_TYPES once an MMQ kernel for I-Matrix quantization exists.
DEQUANT_TYPES = STANDARD_QUANT_TYPES.union(KQUANT_TYPES, IMATRIX_QUANT_TYPES)
MMVQ_QUANT_TYPES = STANDARD_QUANT_TYPES.union(KQUANT_TYPES,
                                              IMATRIX_QUANT_TYPES)
MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES.union(KQUANT_TYPES)
|
|
|
|
|
|
def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor,
                  qweight_type: int) -> torch.Tensor:
    """Compute ``x @ W.T`` where ``W`` is a (possibly quantized) GGUF weight.

    Dispatch order:
      1. Unquantized weights (F32/F16/BF16): plain matmul.
      2. ``x`` with a single row and a type in ``MMVQ_QUANT_TYPES``:
         ``ops.ggml_mul_mat_vec_a8``.
      3. Type in ``MMQ_QUANT_TYPES``: ``ops.ggml_mul_mat_a8``.
      4. Any other type in ``DEQUANT_TYPES``: dequantize the whole weight,
         then matmul.

    Args:
        x: Activations. ``x.shape[0]`` is treated as the row/token count
            (assumed 2-D — TODO confirm with callers).
        qweight: Weight tensor; raw GGUF-packed data for quantized types.
        qweight_type: GGML quantization type id.

    Raises:
        NotImplementedError: If ``qweight_type`` is a known GGML type that
            none of the paths above supports.
        ValueError: If ``qweight_type`` is not a valid GGMLQuantizationType
            (raised by the ``WeightType`` conversion).
    """
    # fp16/bf16/fp32 weights need no kernel at all.
    if qweight_type in UNQUANTIZED_TYPES:
        return x @ qweight.T

    # Passed to the fused kernels as their last argument; presumably the
    # output feature count (number of weight rows) — confirm in ops.
    num_rows = qweight.shape[0]

    # MMVQ handles the batch_size == 1 case.
    if x.shape[0] == 1 and qweight_type in MMVQ_QUANT_TYPES:
        return ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, num_rows)

    # MMQ kernel exists for standard and k-quants.
    if qweight_type in MMQ_QUANT_TYPES:
        return ops.ggml_mul_mat_a8(qweight, x, qweight_type, num_rows)

    # No MMQ kernel available: dequantize, then do a dense matmul.
    if qweight_type in DEQUANT_TYPES:
        block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
        num_cols = qweight.shape[1] // type_size * block_size
        dense = ops.ggml_dequantize(qweight, qweight_type, num_rows, num_cols)
        return x @ dense.T

    # Unsupported type (e.g. one newly added by llama.cpp). Converting to
    # the IntEnum first validates the id and gives a readable name.
    raise NotImplementedError(
        f"Unsupported GGUF quantization type: {WeightType(qweight_type)}")
|
|
|
|
|
|
class GGUFLinearMethod(LinearMethodBase):
    """Linear method for GGUF.

    Args:
        quant_config: The GGUF quantization config.
    """

    def __init__(self, quant_config: GGUFConfig):
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Register ``qweight`` and ``qweight_type`` on ``layer``.

        * ``qweight`` is a ``GGUFUninitializedParameter``. Its logical shape
          ``(sum(output_partition_sizes), input_size_per_partition)`` is
          stored in the ``tensor_shape`` attribute. The data arrives later,
          presumably via the loader filling ``data_container`` /
          ``shard_id`` / ``shard_id_map`` — confirm against the weight
          loader.
        * ``qweight_type`` is a uint8 tensor with one entry per output
          partition. It also carries the ``weight_type`` /
          ``shard_weight_type`` attributes that ``apply`` reads.

        ``input_size``, ``output_size`` and ``params_dtype`` are unused.
        """
        logical_shape = (sum(output_partition_sizes), input_size_per_partition)

        qweight_attrs = {
            "input_dim": 1,
            "output_dim": 0,
            "tensor_shape": logical_shape,
            "is_gguf_weight": True,
            "data_container": [],
            "shard_id": [],
            "shard_id_map": {},
        }
        qweight = GGUFUninitializedParameter(requires_grad=False)
        set_weight_attrs(qweight, qweight_attrs)
        set_weight_attrs(qweight, extra_weight_attrs)
        layer.register_parameter("qweight", qweight)

        type_attrs = {
            "is_gguf_weight_type": True,
            "weight_type": 0,
            "shard_weight_type": {},
            "ignore_warning": True
        }
        type_storage = torch.empty(len(output_partition_sizes),
                                   dtype=torch.uint8)
        qweight_type = Parameter(type_storage, requires_grad=False)
        set_weight_attrs(qweight_type, type_attrs)
        set_weight_attrs(qweight_type, extra_weight_attrs)
        layer.register_parameter("qweight_type", qweight_type)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute ``x @ W.T (+ bias)`` for the layer's GGUF weight.

        If ``qweight`` has recorded shard ids, each shard is multiplied
        separately with its own quant type. The per-shard outputs are then
        concatenated along dim 1. Otherwise the whole weight is used with
        ``qweight_type.weight_type``. ``bias`` is added in place.
        """
        shard_ids = getattr(layer.qweight, "shard_id", None)

        if not shard_ids:
            out = _fuse_mul_mat(x, layer.qweight,
                                layer.qweight_type.weight_type)
        else:
            # Force q, k, v ordering for fused QKV weights (shards are
            # presumably recorded in load order — confirm in the loader).
            if "q" in shard_ids:
                shard_ids = ["q", "k", "v"]
            shard_weights = layer.qweight.unbind(0)
            index_of = layer.qweight.shard_id_map
            type_of = layer.qweight_type.shard_weight_type
            out = torch.cat([
                _fuse_mul_mat(x, shard_weights[index_of[sid]], type_of[sid])
                for sid in shard_ids
            ],
                            dim=1)

        if bias is not None:
            out.add_(bias)
        return out
|
|
|
|
|
|
class GGUFEmbeddingMethod(GGUFLinearMethod):
    """Embedding method for GGUF.

    Args:
        quant_config: The GGUF quantization config.
    """

    def embedding(self, layer: torch.nn.Module,
                  x: torch.Tensor) -> torch.Tensor:
        """Look up embeddings for the token ids ``x``.

        Unquantized weights (F32/F16/BF16) are indexed directly with
        ``torch.embedding``. For quantized weights, only the packed rows for
        the requested ids are gathered, and then those rows are dequantized.

        Args:
            layer: Module holding ``qweight`` and ``qweight_type``.
            x: Token ids of arbitrary shape.

        Returns:
            Tensor of shape ``(*x.shape, hidden_size)``.
        """
        qweight = layer.qweight
        qweight_type = layer.qweight_type.weight_type

        # Check against the full set of unquantized types. The previous
        # ``qweight_type < 2`` test only matched F32/F16, so BF16 weights
        # fell through to the dequantize kernel with a bogus hidden size.
        if qweight_type in UNQUANTIZED_TYPES:
            return torch.embedding(qweight, x)

        # Recover the logical hidden size from the packed row width
        # (assumes qweight rows hold raw GGUF blocks of type_size units,
        # each encoding block_size values — confirm against the loader).
        block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
        hidden_size = qweight.shape[1] // type_size * block_size

        x_flat = x.flatten()
        # Gather only the packed rows that are actually needed.
        quant = torch.index_select(qweight, dim=0, index=x_flat)
        # NOTE(review): the kernel is given (hidden_size, num_tokens); the
        # result is then reshaped with .view, which presumes contiguous
        # row-major output of num_tokens * hidden_size values — confirm.
        dequant = ops.ggml_dequantize(quant, qweight_type, hidden_size,
                                      x_flat.shape[0])
        return dequant.view(*x.shape, hidden_size)
|
|
|
|
|
|
class GGUFUninitializedParameter(UninitializedParameter):
    """Uninitialized parameter that accumulates tensors in ``data_container``.

    ``materialize_nested`` later packs the accumulated tensors into a single
    nested tensor wrapped as a ``Parameter``.
    """
    cls_to_become = Parameter
    # Tensors collected before materialization (presumably one per loaded
    # shard — confirm against the weight loader).
    data_container: List[torch.Tensor]

    def materialize_nested(self) -> Parameter:
        """Pack ``data_container`` into a nested tensor ``Parameter``.

        The container is cleared on success. Every instance attribute of
        ``self`` (e.g. those set via ``set_weight_attrs``) is copied onto
        the returned parameter.

        Returns:
            A ``Parameter`` (``require_grad=False``) wrapping a nested tensor
            on ``self.device``.

        Raises:
            ValueError: If ``data_container`` is empty, or if its tensors do
                not share one dtype. The container is left untouched in
                both cases.
        """
        dtypes = {data.dtype for data in self.data_container}
        # Explicit raises instead of ``assert`` so the check survives
        # ``python -O`` and actually raises the intended ValueError.
        if not dtypes:
            raise ValueError(
                "Data container is empty; nothing to materialize.")
        if len(dtypes) != 1:
            raise ValueError(f"Data container has mixed dtypes: {dtypes}")
        (dtype, ) = dtypes

        nested_data = torch.nested.nested_tensor(self.data_container,
                                                 device=self.device,
                                                 dtype=dtype)
        self.data_container.clear()
        param = torch.Tensor._make_subclass(self.cls_to_become,
                                            nested_data,
                                            require_grad=False)
        # Carry over weight attributes (loader hooks, shard maps, etc.).
        for k, v in self.__dict__.items():
            setattr(param, k, v)
        return param
|