[Bugfix] Fix GGUF inference with FP16 unquantized checkpoint (#10675)
Signed-off-by: Isotr0py <2037008807@qq.com>
parent c411def234
commit b98c62ba49
@@ -2,6 +2,7 @@ from typing import Any, Dict, List, Optional
 
 import gguf
 import torch
+from gguf import GGMLQuantizationType as WeightType
 from torch.nn.parameter import Parameter, UninitializedParameter
 
 from vllm import _custom_ops as ops
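The new alias matters because the quantization type stored on a layer is a plain integer: GGMLQuantizationType is an IntEnum, so raw type codes can be tested for membership in the sets introduced below and wrapped back into a readable member for error messages. A quick standalone check (assumes only the gguf package; not part of the diff):

# Standalone sketch: GGMLQuantizationType is an IntEnum, so the integer type
# codes stored on a layer round-trip to named members and compare equal.
from gguf import GGMLQuantizationType as WeightType

print(repr(WeightType.F16), int(WeightType.F16))  # <GGMLQuantizationType.F16: 1> 1
print(WeightType(1) is WeightType.F16)            # True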
@@ -49,19 +50,65 @@ class GGUFConfig(QuantizationConfig):
         return None
 
 
+UNQUANTIZED_TYPES = {WeightType.F32, WeightType.F16, WeightType.BF16}
+STANDARD_QUANT_TYPES = {
+    WeightType.Q4_0,
+    WeightType.Q4_1,
+    WeightType.Q5_0,
+    WeightType.Q5_1,
+    WeightType.Q8_0,
+    WeightType.Q8_1,
+}
+KQUANT_TYPES = {
+    WeightType.Q2_K,
+    WeightType.Q3_K,
+    WeightType.Q4_K,
+    WeightType.Q5_K,
+    WeightType.Q6_K,
+}
+IMATRIX_QUANT_TYPES = {
+    WeightType.IQ1_M,
+    WeightType.IQ1_S,
+    WeightType.IQ2_XXS,
+    WeightType.IQ2_XS,
+    WeightType.IQ2_S,
+    WeightType.IQ3_XXS,
+    WeightType.IQ3_S,
+    WeightType.IQ4_XS,
+    WeightType.IQ4_NL,
+}
+# TODO(Isotr0py): Currently, we don't have MMQ kernel for I-Matrix quantization.
+# Consolidate DEQUANT_TYPES, MMVQ_QUANT_TYPES and MMQ_QUANT_TYPES after we add
+# MMQ kernel for I-Matrix quantization.
+DEQUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
+MMVQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
+MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES
+
+
 def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor,
                   qweight_type: int) -> torch.Tensor:
-    # use dequantize mulmat for IQmatrix, mmq for k-quants
-    if x.shape[0] == 1:
-        # enable mmvq in contiguous batching
+    # there is no need to call any kernel for fp16/bf16
+    if qweight_type in UNQUANTIZED_TYPES:
+        return x @ qweight.T
+    # enable MMVQ in contiguous batching with batch_size=1
+    if x.shape[0] == 1 and qweight_type in MMVQ_QUANT_TYPES:
         y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
-    elif qweight_type >= 16:
+    # Use MMQ Kernel if it's available (standard + k-quants)
+    elif qweight_type in MMQ_QUANT_TYPES:
+        y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
+    # If there is no available MMQ kernel, fallback to dequantize
+    elif qweight_type in DEQUANT_TYPES:
         block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
         shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
         weight = ops.ggml_dequantize(qweight, qweight_type, *shape)
         y = x @ weight.T
     else:
-        y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
+        # Raise an error if the quantization type is not supported.
+        # Might be useful if llama.cpp adds a new quantization type.
+        # Wrap to GGMLQuantizationType IntEnum to make sure it's a valid type.
+        qweight_type = WeightType(qweight_type)
+        raise NotImplementedError(
+            f"Unsupported GGUF quantization type: {qweight_type}")
     return y
 
 
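The sketch below (standalone; the dispatch_kernel helper is hypothetical and not part of vLLM) mirrors the branch order of the new _fuse_mul_mat so the routing can be checked without vLLM's custom ops: unquantized F32/F16/BF16 weights fall through to a plain torch matmul, which is the fix for FP16 GGUF checkpoints, while quantized types go to MMVQ, MMQ, or the dequantize fallback.

# Hypothetical helper (not part of the diff) that mirrors the branch order of
# the new _fuse_mul_mat; it needs only the gguf package, not vLLM's kernels.
from gguf import GGMLQuantizationType as WeightType

UNQUANTIZED_TYPES = {WeightType.F32, WeightType.F16, WeightType.BF16}
STANDARD_QUANT_TYPES = {
    WeightType.Q4_0, WeightType.Q4_1, WeightType.Q5_0,
    WeightType.Q5_1, WeightType.Q8_0, WeightType.Q8_1,
}
KQUANT_TYPES = {
    WeightType.Q2_K, WeightType.Q3_K, WeightType.Q4_K,
    WeightType.Q5_K, WeightType.Q6_K,
}
IMATRIX_QUANT_TYPES = {
    WeightType.IQ1_M, WeightType.IQ1_S, WeightType.IQ2_XXS,
    WeightType.IQ2_XS, WeightType.IQ2_S, WeightType.IQ3_XXS,
    WeightType.IQ3_S, WeightType.IQ4_XS, WeightType.IQ4_NL,
}
DEQUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
MMVQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES


def dispatch_kernel(qweight_type: int, batch_size: int) -> str:
    """Report which kernel family the dispatcher would pick."""
    if qweight_type in UNQUANTIZED_TYPES:
        return "plain torch matmul (x @ qweight.T)"
    if batch_size == 1 and qweight_type in MMVQ_QUANT_TYPES:
        return "MMVQ (ggml_mul_mat_vec_a8)"
    if qweight_type in MMQ_QUANT_TYPES:
        return "MMQ (ggml_mul_mat_a8)"
    if qweight_type in DEQUANT_TYPES:
        return "dequantize + torch matmul"
    raise NotImplementedError(
        f"Unsupported GGUF quantization type: {WeightType(qweight_type)}")


print(dispatch_kernel(WeightType.F16, batch_size=8))     # plain torch matmul
print(dispatch_kernel(WeightType.Q4_K, batch_size=1))    # MMVQ
print(dispatch_kernel(WeightType.IQ2_XS, batch_size=8))  # dequantize fallback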
@@ -121,9 +168,9 @@ class GGUFLinearMethod(LinearMethodBase):
             shard_id = ["q", "k", "v"] if "q" in shard_id else shard_id
             qweight = layer.qweight.unbind(0)
             result = []
-            for id in shard_id:
-                q_idx = layer.qweight.shard_id_map[id]
-                qweight_type = layer.qweight_type.shard_weight_type[id]
+            for idx in shard_id:
+                q_idx = layer.qweight.shard_id_map[idx]
+                qweight_type = layer.qweight_type.shard_weight_type[idx]
                 result.append(_fuse_mul_mat(x, qweight[q_idx], qweight_type))
             out = torch.cat(result, axis=1)
         else:
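Besides renaming the loop variable from id to idx (so it no longer shadows Python's builtin id()), the loop is unchanged: each fused shard is multiplied separately and the results are concatenated along the output dimension. A toy, unquantized illustration of that pattern (hypothetical shapes; x @ w.T stands in for _fuse_mul_mat on real GGUF shards):

# Toy illustration of the per-shard loop; shapes are made up for the example.
import torch

x = torch.randn(4, 64)              # [batch, hidden]
shard_weights = {                   # per-shard [out_features, hidden] weights
    "q": torch.randn(128, 64),
    "k": torch.randn(32, 64),
    "v": torch.randn(32, 64),
}

result = []
for idx in ["q", "k", "v"]:         # "idx", not the builtin-shadowing "id"
    w = shard_weights[idx]
    result.append(x @ w.T)          # stands in for _fuse_mul_mat(x, ...)
out = torch.cat(result, dim=1)      # fused qkv output
print(out.shape)                    # torch.Size([4, 192])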
@@ -163,9 +210,13 @@ class GGUFUninitializedParameter(UninitializedParameter):
     data_container: List[torch.Tensor]
 
     def materialize_nested(self) -> Parameter:
+        dtype = {data.dtype for data in self.data_container}
+        assert len(dtype) == 1, ValueError(
+            f"Data container has mixed dtypes: {dtype}")
+        dtype = next(iter(dtype))
         nested_data = torch.nested.nested_tensor(self.data_container,
                                                  device=self.device,
-                                                 dtype=torch.uint8)
+                                                 dtype=dtype)
         self.data_container.clear()
         param = torch.Tensor._make_subclass(self.cls_to_become,
                                             nested_data,
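materialize_nested previously hard-coded dtype=torch.uint8, which only holds for quantized GGUF blocks; an FP16 checkpoint stores float16 tensors, so the nested parameter must take its dtype from the data it wraps. A standalone sketch of the added logic (not the vLLM class itself):

# Standalone sketch of the added dtype handling: infer a single dtype from the
# container instead of hard-coding torch.uint8, so unquantized FP16/BF16
# shards keep their dtype when materialized.
import torch

data_container = [torch.randn(3, 4, dtype=torch.float16),
                  torch.randn(5, 4, dtype=torch.float16)]

dtypes = {t.dtype for t in data_container}
assert len(dtypes) == 1, f"Data container has mixed dtypes: {dtypes}"
dtype = next(iter(dtypes))

nested = torch.nested.nested_tensor(data_container,
                                    device="cpu",
                                    dtype=dtype)
print(nested.dtype)  # torch.float16 (was forced to torch.uint8 before the fix)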