mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-19 06:25:01 +08:00
Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io> Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com> Signed-off-by: Huzaifa Sidhpurwala <huzaifas@redhat.com> Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Signed-off-by: Jee Jee Li <pandaleefree@gmail.com> Signed-off-by: mgoin <mgoin64@gmail.com> Signed-off-by: Animesh Jain <anijain@umich.edu> Signed-off-by: Rui Qiao <ruisearch42@gmail.com> Signed-off-by: Xiongfei Wei <isaacwxf23@gmail.com> Signed-off-by: Nick Hill <nhill@redhat.com> Signed-off-by: yewentao256 <zhyanwentao@126.com> Signed-off-by: kf <kuanfu.liu@embeddedllm.com> Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com> Signed-off-by: NickLucche <nlucches@redhat.com> Signed-off-by: Dipika Sikka <dipikasikka1@gmail.com> Signed-off-by: Sage Moore <sage@neuralmagic.com> Signed-off-by: tjtanaavllm <tunjian.tan@amd.com> Signed-off-by: Yong Hoon Shin <yhshin@meta.com> Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com> Signed-off-by: Roger Wang <hey@rogerw.me> Signed-off-by: Vadim Gimpelson <vadim.gimpelson@centml.ai> Signed-off-by: Isotr0py <2037008807@qq.com> Signed-off-by: zRzRzRzRzRzRzR <2448370773@qq.com> Signed-off-by: Chih-Chieh Yang <7364402+cyang49@users.noreply.github.com> Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com> Signed-off-by: yan <yan.ma@intel.com> Signed-off-by: Yan Ma <yan.ma@intel.com> Signed-off-by: Xiao Liu <xiszishu@gmail.com> Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn> Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com> Signed-off-by: LopezCastroRoberto <roberto.lopez.castro@udc.es> Signed-off-by: Andy Xie <andy.xning@gmail.com> Signed-off-by: Haibin Lin <haibin.lin@bytedance.com> Signed-off-by: David Ben-David <davidb@pliops.com> Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Signed-off-by: 
jiang1.li <jiang1.li@intel.com> Signed-off-by: Seiji Eicher <seiji@anyscale.com> Signed-off-by: zitian.zhao <zitian.zhao@tencentmusic.com> Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com> Signed-off-by: Abirdcfly <fp544037857@gmail.com> Signed-off-by: Giancarlo Delfin <gdelfin@meta.com> Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com> Signed-off-by: huangweixiao <huangweixiao@msh.team> Signed-off-by: alyosha-swamy <raghav@arcee.ai> Signed-off-by: Eric Hanley <ericehanley@google.com> Signed-off-by: Abatom <abzhonghua@gmail.com> Signed-off-by: CLFutureX <775523362@qq.com> Signed-off-by: Linkun Chen <github@lkchen.net> Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com> Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com> Signed-off-by: tlipoca9 <tlipoca9@gmail.com> Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com> Signed-off-by: zitian zhao <zitian.zhao@tencentmusic.com> Signed-off-by: mgoin <michael@neuralmagic.com> Signed-off-by: wang.yuqi <noooop@126.com> Signed-off-by: Benji Beck <benjibeck@meta.com> Signed-off-by: Siyuan Liu <lsiyuan@google.com> Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai> Signed-off-by: isotr0py <2037008807@qq.com> Signed-off-by: Chen Zhang <zhangch99@outlook.com> Signed-off-by: simon-mo <xmo@berkeley.edu> Signed-off-by: LucasWilkinson <lwilkinson@neuralmagic.com> Signed-off-by: Zhang Jason <ning.zhang2@amd.com> Signed-off-by: Yongye Zhu <zyy1102000@gmail.com> Signed-off-by: asafg <asafg@ai21.com> Signed-off-by: Siyuan Fu <siyuanf@nvidia.com> Signed-off-by: Lain <fusiyuan2000@hotmail.com> Signed-off-by: Max de Bayser <mbayser@br.ibm.com> Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Signed-off-by: Kunshang Ji <kunshang.ji@intel.com> Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com> Signed-off-by: Michael Goin <mgoin64@gmail.com> Signed-off-by: QscQ <qscqesze@gmail.com> Signed-off-by: qingjun <qingjun@minimaxi.com> Signed-off-by: Syed Muhammad Bin Asif 
<syedmba7@connect.hku.hk> Signed-off-by: Lionel Villard <villard@us.ibm.com> Signed-off-by: ycyaw66 <497410282@qq.com> Signed-off-by: David Chen <530634352@qq.com> Signed-off-by: Linkun <github@lkchen.net> Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com> Signed-off-by: Ming Yang <minos.future@gmail.com> Signed-off-by: Adrian Garcia <adrian.garcia@inceptionai.ai> Signed-off-by: shaojunqi <shaojunqi.sjq@alibaba-inc.com> Signed-off-by: Ricardo Decal <rdecal@anyscale.com> Signed-off-by: Andrew Chan <andrewkchan.akc@gmail.com> Signed-off-by: Felix Marty <Felix.Marty@amd.com> Signed-off-by: Andrew Sansom <andrew@protopia.ai> Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com> Signed-off-by: Shu Wang <shuw@nvidia.com> Signed-off-by: Po-Han Huang <pohanh@nvidia.com> Signed-off-by: Shu Wang. <shuw@nvidia.com> Signed-off-by: XIn Li <xinli@nvidia.com> Signed-off-by: Junhao Li <junhao@ubicloud.com> Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com> Signed-off-by: iAmir97 <Amir.balwel@embeddedllm.com> Signed-off-by: iAmir97 <71513472+iAmir97@users.noreply.github.com> Signed-off-by: <zyy1102000@gmail.com> Signed-off-by: Guy Stone <guys@spotify.com> Signed-off-by: <yyweiss@gmail.com> Signed-off-by: yyw <yyweiss@gmail.com> Signed-off-by: Russell Bryant <rbryant@redhat.com> Signed-off-by: Pradyun Ramadorai <pradyunr@amazon.com> Signed-off-by: Pradyun92 <142861237+Pradyun92@users.noreply.github.com> Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com> Co-authored-by: rongfu.leng <rongfu.leng@daocloud.io> Co-authored-by: Huzaifa Sidhpurwala <huzaifas@redhat.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Russell Bryant <rbryant@redhat.com> Co-authored-by: Varun Sundar Rabindranath <varunsundar08@gmail.com> Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: Jee Jee Li 
<pandaleefree@gmail.com> Co-authored-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: Animesh Jain <jainanimesh2305@yahoo.com> Co-authored-by: Rui Qiao <161574667+ruisearch42@users.noreply.github.com> Co-authored-by: XiongfeiWei <isaacwxf23@gmail.com> Co-authored-by: Nick Hill <nhill@redhat.com> Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Co-authored-by: JartX <sagformas@gmail.com> Co-authored-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com> Co-authored-by: kf <kuanfu.liu@embeddedllm.com> Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com> Co-authored-by: Dipika Sikka <dipikasikka1@gmail.com> Co-authored-by: Sage Moore <sage@neuralmagic.com> Co-authored-by: tjtanaavllm <tunjian.tan@amd.com> Co-authored-by: Yong Hoon Shin <48474650+sarckk@users.noreply.github.com> Co-authored-by: Chih-Chieh Yang <7364402+cyang49@users.noreply.github.com> Co-authored-by: Roger Wang <hey@rogerw.me> Co-authored-by: Vadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com> Co-authored-by: Yuxuan Zhang <2448370773@qq.com> Co-authored-by: Isotr0py <2037008807@qq.com> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk> Co-authored-by: Thomas Parnell <tpa@zurich.ibm.com> Co-authored-by: Yan Ma <yan.ma@intel.com> Co-authored-by: Xiao <xiszishu@gmail.com> Co-authored-by: jiahanc <173873397+jiahanc@users.noreply.github.com> Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn> Co-authored-by: Ye (Charlotte) Qi <yeq@meta.com> Co-authored-by: Roberto L. 
Castro <38211239+LopezCastroRoberto@users.noreply.github.com> Co-authored-by: Ning Xie <andy.xning@gmail.com> Co-authored-by: H <linhaibin.eric@gmail.com> Co-authored-by: David Ben-David <sdavidbd@gmail.com> Co-authored-by: David Ben-David <davidb@pliops.com> Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Co-authored-by: Li, Jiang <jiang1.li@intel.com> Co-authored-by: TankNee <nee@tanknee.cn> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Co-authored-by: Seiji Eicher <58963096+eicherseiji@users.noreply.github.com> Co-authored-by: ZiTian.Zhao <zitian.zhao@tencentmusic.com> Co-authored-by: 22quinn <33176974+22quinn@users.noreply.github.com> Co-authored-by: Abirdcfly <fp544037857@gmail.com> Co-authored-by: Giancarlo Delfin <32987265+TheEpicDolphin@users.noreply.github.com> Co-authored-by: Chenxi Yang <cxyang@cs.utexas.edu> Co-authored-by: Chenxi Yang <cxyang@meta.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com> Co-authored-by: Weixiao Huang <hwx.simle@gmail.com> Co-authored-by: Raghav Ravishankar <113712354+alyosha-swamy@users.noreply.github.com> Co-authored-by: ericehanley <ericehanley@google.com> Co-authored-by: Zhonghua Deng <abzhonghua@gmail.com> Co-authored-by: Po-Han Huang (NVIDIA) <53919306+nvpohanh@users.noreply.github.com> Co-authored-by: PiteXChen <44110731+CLFutureX@users.noreply.github.com> Co-authored-by: lkchen <github@lkchen.net> Co-authored-by: TJian <tunjian.tan@embeddedllm.com> Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Co-authored-by: tlipoca9 <160737620+tlipoca9@users.noreply.github.com> Co-authored-by: elvischenv <219235043+elvischenv@users.noreply.github.com> Co-authored-by: wang.yuqi <noooop@126.com> Co-authored-by: Benji Beck <benjibeck@meta.com> Co-authored-by: youkaichao <youkaichao@gmail.com> Co-authored-by: Siyuan Liu <lsiyuan@google.com> Co-authored-by: Benjamin Chislett <chislett.ben@gmail.com> Co-authored-by: LiuXiaoxuanPKU <lilyliupku@gmail.com> Co-authored-by: 
simon-mo <xmo@berkeley.edu> Co-authored-by: Chen Zhang <zhangch99@outlook.com> Co-authored-by: Hongxia Yang <62075498+hongxiayang@users.noreply.github.com> Co-authored-by: Minseok Lee <47620120+minseokl@users.noreply.github.com> Co-authored-by: Yongye Zhu <zyy1102000@gmail.com> Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com> Co-authored-by: Zhang Jason <ning.zhang2@amd.com> Co-authored-by: Asaf Joseph Gardin <39553475+Josephasafg@users.noreply.github.com> Co-authored-by: asafg <asafg@ai21.com> Co-authored-by: Lain <siyuanf@nvidia.com> Co-authored-by: tc-mb <157115220+tc-mb@users.noreply.github.com> Co-authored-by: imning3 <hbning@pku.edu.cn> Co-authored-by: Maximilien de Bayser <mbayser@br.ibm.com> Co-authored-by: Kunshang Ji <kunshang.ji@intel.com> Co-authored-by: Tao He <linzhu.ht@alibaba-inc.com> Co-authored-by: qscqesze <qingjun@minimaxi.com> Co-authored-by: Syed Muhammad Bin Asif <92625830+syedmba@users.noreply.github.com> Co-authored-by: Lionel Villard <villard@us.ibm.com> Co-authored-by: WeiQing Chen <40507679+david6666666@users.noreply.github.com> Co-authored-by: ycyaw66 <497410282@qq.com> Co-authored-by: Moritz Sanft <58110325+msanft@users.noreply.github.com> Co-authored-by: Ming Yang <minos.future@gmail.com> Co-authored-by: Adrián García García <adrigarvk8@gmail.com> Co-authored-by: Michael Goin <mgoin@redhat.com> Co-authored-by: JaceyShao <65159281+JaceyShao@users.noreply.github.com> Co-authored-by: shaojunqi <shaojunqi.sjq@alibaba-inc.com> Co-authored-by: Ricardo Decal <crypdick@users.noreply.github.com> Co-authored-by: Andrew Chan <andrewkchan.akc@gmail.com> Co-authored-by: fxmarty-amd <felmarty@amd.com> Co-authored-by: Andrew Sansom <andrew@protopia.ai> Co-authored-by: Zhiyu <zhiyuc@nvidia.com> Co-authored-by: Shu Wang <shuw@nvidia.com> Co-authored-by: XIn Li <xinli@nvidia.com> Co-authored-by: Junhao Li <streaver91@gmail.com> Co-authored-by: Chauncey <chaunceyjiang@gmail.com> Co-authored-by: iAmir97 
<71513472+iAmir97@users.noreply.github.com> Co-authored-by: iAmir97 <Amir.balwel@embeddedllm.com> Co-authored-by: Hong Hanh <hanh.usth@gmail.com> Co-authored-by: Daniel Serebrenik <74646983+pliops-daniels@users.noreply.github.com> Co-authored-by: yewentao256 <zhyanwentao@126.com> Co-authored-by: Guy Stone <guys@spotify.com> Co-authored-by: yyweiss <70619747+yyweiss@users.noreply.github.com> Co-authored-by: Pradyun92 <142861237+Pradyun92@users.noreply.github.com> Co-authored-by: Pradyun Ramadorai <pradyunr@amazon.com> Co-authored-by: Nicolò Lucchesi <nicolo.lucchesi@gmail.com>
334 lines
13 KiB
Python
334 lines
13 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
from typing import Any, Optional
|
|
|
|
import torch
|
|
|
|
from vllm import _custom_ops as ops
|
|
from vllm.logger import init_logger
|
|
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
|
UnquantizedLinearMethod)
|
|
from vllm.model_executor.layers.quantization import QuantizationMethods
|
|
from vllm.model_executor.layers.quantization.base_config import (
|
|
QuantizationConfig, QuantizeMethodBase)
|
|
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
|
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
|
|
marlin_make_empty_g_idx, marlin_permute_bias, marlin_permute_scales)
|
|
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
|
|
MarlinWorkspace)
|
|
from vllm.model_executor.layers.quantization.utils.quant_utils import gptq_pack
|
|
from vllm.model_executor.parameter import (BasevLLMParameter,
|
|
GroupQuantScaleParameter,
|
|
PackedvLLMParameter)
|
|
from vllm.scalar_type import scalar_types
|
|
|
|
# Module-level logger, named after this module.
logger = init_logger(__name__)
|
|
|
|
|
|
class HQQMarlinConfig(QuantizationConfig):
    """Config class for HQQ Marlin.

    Describes an HQQ-quantized checkpoint whose weights are repacked into
    the GPTQ layout while loading and then executed with the Marlin kernels.
    Only 4-bit weights with a group size of 64 are currently supported.
    """

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        skip_modules: Optional[list[str]] = None,
    ) -> None:
        """
        Args:
            weight_bits: Bits per quantized weight; must be 4.
            group_size: Number of input elements that share one scale and
                zero point; must be 64.
            skip_modules: Dot-separated prefix components whose linear
                layers are left unquantized (see ``is_layer_skipped``).
                ``None`` means no layer is skipped.
        """
        super().__init__()
        assert group_size == 64, ("The only supported HQQ group size is "
                                  "currently 64.")
        assert weight_bits == 4, ("The only supported HQQ quantization "
                                  "bitsize is currently 4.")

        self.weight_bits = weight_bits
        self.group_size = group_size
        self.pack_factor = 32 // weight_bits  # packed into int32 in GPTQ format
        self.quant_type = scalar_types.uint4
        self.skip_modules = skip_modules

    def __repr__(self) -> str:
        return (f"HQQMarlinConfig(quant_type={self.quant_type}, "
                f"group_size={self.group_size})")

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the quantization method name used to select this config."""
        return "hqq"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Activation dtypes accepted by this method (fp16 and bf16)."""
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        """Minimum GPU compute capability, encoded as major*10 + minor."""
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        """Checkpoint files that hold the HQQ quantization config."""
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "HQQMarlinConfig":
        """Build a config from the parsed ``quantize_config.json`` dict.

        ``nbits`` and ``group_size`` are read from
        ``config["quant_config"]["weight_quant_params"]``. The optional
        top-level ``skip_modules`` list defaults to ``None`` (skip nothing)
        when absent, matching the default of ``__init__``; previously a
        missing key raised ``KeyError``.
        """
        wq_params = config["quant_config"]["weight_quant_params"]
        weight_bits = cls.get_from_keys(wq_params, ["nbits"])
        group_size = cls.get_from_keys(wq_params, ["group_size"])
        skip_modules = config.get("skip_modules")
        return cls(weight_bits, group_size, skip_modules)

    def is_layer_skipped(self, prefix: str) -> bool:
        """Return True if any skip module equals a component of ``prefix``.

        Matching is done per dot-separated component, so ``"lm_head"``
        matches ``"model.lm_head"`` but not ``"model.lm_head_extra"``.
        """
        # Split the prefix into its dot-separated components
        components = prefix.split('.')

        # Check if any of the skip modules exactly matches any component
        return self.skip_modules is not None and any(
            module_name in components for module_name in self.skip_modules)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Pick the linear method for ``layer``.

        Linear layers are quantized with ``HQQMarlinMethod`` unless their
        prefix is listed in ``skip_modules``, in which case they stay
        unquantized. Any other layer type returns ``None``.
        """
        if isinstance(layer, LinearBase):
            if self.is_layer_skipped(prefix):
                return UnquantizedLinearMethod()
            return HQQMarlinMethod(self)
        return None
|
|
|
|
|
|
# Empty HQQ parameter, will be ignored during loading
class HQQEmptyParameter(BasevLLMParameter):
    """Placeholder for HQQ checkpoint metadata that vLLM does not use.

    Every shard-aware loader is overridden as a no-op, so the checkpoint
    tensor is discarded instead of being loaded into this (empty) parameter.
    """

    def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
        # Plain (non-merged) column-parallel layers load through this entry
        # point; ignore the tensor here too so metadata is discarded for
        # every linear layer type, not only merged/row/qkv ones.
        pass

    def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
        pass

    def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
        pass

    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
        pass
|
|
|
|
|
|
def error_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
    """Fallback weight loader that always fails.

    It is installed when no ``weight_loader`` is supplied, so any attempt to
    load an HQQ parameter through it is reported instead of silently ignored.

    Raises:
        ValueError: unconditionally.
    """
    message = "No loader provided for HQQ parameter!"
    raise ValueError(message)
|
|
|
|
|
|
# HQQ packing creates issues with sharding - therefore, prior to loading, we
# repack to GPTQ. We also reshape the weights to their proper GPTQ shape.
class HQQweightParameter(PackedvLLMParameter):
    """Packed weight parameter that converts HQQ 4-bit weights to GPTQ.

    Every loader first unpacks the HQQ ``uint8`` tensor (two 4-bit values
    per byte) and views it as a dense ``(output, input)`` matrix. It then
    transposes that to ``(input, output)`` and packs it into ``int32``
    GPTQ format. Only then does it hand the result to the parent loader.
    """

    # unpack function from https://github.com/mobiusml/hqq
    def unpack_4bit_u8(self,
                       W_q: torch.Tensor) -> torch.Tensor:  # uint8/2 > uint8
        """Split each byte of ``W_q`` into its two 4-bit values.

        High nibbles fill the first ``W_q.shape[0]`` rows of the result and
        low nibbles the remaining rows, so the output has twice as many rows
        as ``W_q`` and the same number of columns (dtype ``uint8``).
        """
        assert self.weight_bits == 4, "Unsupported quant bitsize (must be 4)"

        dtype = torch.uint8
        step = W_q.shape[0]
        tmp = torch.empty([2 * step, W_q.shape[1]],
                          dtype=dtype,
                          device=W_q.device)
        tmp[:step] = (W_q & 0b11110000) >> 4
        tmp[step:] = W_q & 0b00001111
        return tmp

    def __init__(self, packed_factor: int, packed_dim: int, weight_bits: int,
                 **kwargs):
        super().__init__(packed_factor, packed_dim, None, **kwargs)
        self.weight_bits = weight_bits
        # Unpacked sizes of this partition: the input dim is stored packed,
        # so multiply back by the pack factor.
        self.input_shape = self.shape[self.input_dim] * self.packed_factor
        self.output_shape = self.shape[self.output_dim]

    def _hqq_to_gptq(self, loaded_weight: torch.Tensor, rows: int,
                     cols: int) -> torch.Tensor:
        """Unpack HQQ data, view it as ``(rows, cols)``, GPTQ-pack the
        transpose."""
        unpacked = self.unpack_4bit_u8(loaded_weight)
        unpacked = unpacked.reshape(rows, cols).transpose(1, 0)
        return gptq_pack(unpacked, self.weight_bits, unpacked.shape[0],
                         unpacked.shape[1])

    def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
        # Same conversion as the merged-column path; without it a plain
        # column-parallel layer would receive the raw HQQ uint8 tensor.
        loaded_weight = self._hqq_to_gptq(loaded_weight, -1, self.input_shape)
        super().load_column_parallel_weight(loaded_weight)

    def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
        loaded_weight = self._hqq_to_gptq(loaded_weight, -1, self.input_shape)
        super().load_merged_column_weight(loaded_weight, **kwargs)

    def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
        # The input dim is the sharded one here, so fix the output size and
        # let the input size be inferred from the checkpoint tensor.
        loaded_weight = self._hqq_to_gptq(loaded_weight, self.output_shape, -1)
        super().load_row_parallel_weight(loaded_weight)

    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
        loaded_weight = self._hqq_to_gptq(loaded_weight, -1, self.input_shape)
        super().load_qkv_weight(loaded_weight, **kwargs)
|
|
|
|
|
|
# Zero points and scales in HQQ must also be reshaped to correspond to W_q's
# GPTQ shape (transposed - we transpose them too when processing weights).
class HQQZeroScaleParameter(GroupQuantScaleParameter):
    """Group scale / zero-point parameter that reshapes HQQ tensors.

    Each loader reshapes the incoming tensor to 2-D. It fixes the dimension
    that is not sharded for that layer type to ``self``'s size. It then
    delegates to the parent loader.
    """

    def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
        # Same reshape as the merged-column path; plain column-parallel
        # layers otherwise receive the tensor in its checkpoint shape.
        loaded_weight = loaded_weight.reshape(-1, self.shape[1])
        super().load_column_parallel_weight(loaded_weight)

    def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
        loaded_weight = loaded_weight.reshape(-1, self.shape[1])
        super().load_merged_column_weight(loaded_weight, **kwargs)

    def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
        loaded_weight = loaded_weight.reshape(self.shape[0], -1)
        super().load_row_parallel_weight(loaded_weight)

    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
        loaded_weight = loaded_weight.reshape(-1, self.shape[1])
        super().load_qkv_weight(loaded_weight, **kwargs)
|
|
|
|
|
|
class HQQMarlinMethod(LinearMethodBase):
    """Linear method for HQQ Marlin.

    Registers HQQ-shaped parameters on the layer and repacks them into the
    Marlin layout once loading finishes. The forward pass runs through
    ``ops.gptq_marlin_gemm`` with floating-point zero points.
    """

    def __init__(
        self,
        quant_config: HQQMarlinConfig,
    ):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        """Register the HQQ weight, zero and scale parameters on ``layer``.

        * ``W_q`` is GPTQ-packed ``int32``, shaped
          ``(input_size_per_partition // pack_factor,
          output_size_per_partition)``.
        * ``zero`` and ``scale`` are ``params_dtype``, shaped
          ``(output_size_per_partition,
          input_size_per_partition // group_size)``.
        * Extra HQQ metadata names are registered as ``HQQEmptyParameter``
          placeholders, so those checkpoint entries have a destination and
          are discarded.
        """
        self.output_size_per_partition = sum(output_partition_sizes)
        self.input_size_per_partition = input_size_per_partition

        weight_loader = extra_weight_attrs.get("weight_loader", error_loader)

        self.scales_and_zp_size = (input_size_per_partition //
                                   self.quant_config.group_size)

        qweight = HQQweightParameter(
            data=torch.empty(
                self.input_size_per_partition // self.quant_config.pack_factor,
                self.output_size_per_partition,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=0,
            packed_factor=self.quant_config.pack_factor,
            weight_bits=self.quant_config.weight_bits,
            weight_loader=weight_loader)

        zeros = HQQZeroScaleParameter(data=torch.empty(
            self.output_size_per_partition,
            self.scales_and_zp_size,
            dtype=params_dtype,
        ),
                                      input_dim=1,
                                      output_dim=0,
                                      weight_loader=weight_loader)

        scales = HQQZeroScaleParameter(data=torch.empty(
            self.output_size_per_partition,
            self.scales_and_zp_size,
            dtype=params_dtype,
        ),
                                       input_dim=1,
                                       output_dim=0,
                                       weight_loader=weight_loader)

        layer.register_parameter("W_q", qweight)
        layer.register_parameter("zero", zeros)
        layer.register_parameter("scale", scales)

        # Ignore extra parameters in the HQQ model.
        # To be added as needed.
        ignore_parameters = ("axis", "channel_wise", "compute_dtype",
                             "encoded_state_dict", "group_size", "nbits",
                             "offload_meta", "optimize", "packing",
                             "quant_scale", "quant_zero", "round_zero",
                             "shape", "stores_quant_config",
                             "unpack_view_dtype", "view_as_float")
        for name in ignore_parameters:
            layer.register_parameter(
                name,
                HQQEmptyParameter(data=torch.empty(0),
                                  weight_loader=weight_loader))

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Convert the loaded GPTQ-layout tensors into Marlin layout.

        Sets ``marlin_qweight``, ``marlin_scales`` and ``marlin_zeros`` on
        ``layer``, along with empty ``g_idx`` tensors, the Marlin workspace
        and, if present, a permuted bias. Scales and zeros are transposed
        before permutation because they were loaded as ``(output, groups)``.
        """
        dev = layer.W_q.device

        # Repack to Marlin
        sort_indices = torch.empty(0, dtype=torch.int, device=dev)
        marlin_w_q = ops.gptq_marlin_repack(
            layer.W_q,
            sort_indices,
            self.input_size_per_partition,
            self.output_size_per_partition,
            self.quant_config.weight_bits,
        ).to(dev)
        marlin_s = marlin_permute_scales(layer.scale.transpose(1, 0),
                                         self.input_size_per_partition,
                                         self.output_size_per_partition,
                                         self.quant_config.group_size).to(dev)
        marlin_zp = marlin_permute_scales(layer.zero.transpose(1, 0),
                                          self.input_size_per_partition,
                                          self.output_size_per_partition,
                                          self.quant_config.group_size).to(dev)

        layer.g_idx = marlin_make_empty_g_idx(dev)
        layer.g_idx_sort_indices = marlin_make_empty_g_idx(dev)

        layer.marlin_qweight = marlin_w_q
        layer.marlin_zeros = marlin_zp
        layer.marlin_scales = marlin_s

        # The workspace size depends only on the fixed output size, so
        # allocate it once here instead of on every forward call.
        layer.marlin_workspace = MarlinWorkspace(
            self.output_size_per_partition, GPTQ_MARLIN_MIN_THREAD_N,
            GPTQ_MARLIN_MAX_PARALLEL)

        if hasattr(layer, "bias") and layer.bias is not None:
            layer.bias.data = marlin_permute_bias(layer.bias)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute ``x @ W^T (+ bias)`` with the Marlin GEMM.

        ``x`` may have any number of leading dimensions; they are flattened
        for the kernel and restored on the output. The kernel is always run
        in float16. Non-fp16 activations, together with scales, zeros and
        bias, are cast to fp16, and the result is cast back to the input
        dtype.
        """
        workspace = layer.marlin_workspace

        scales = layer.marlin_scales
        zeros = layer.marlin_zeros
        orig_type = x.dtype

        # The kernel expects a 2-D (tokens, features) input.
        out_shape = x.shape[:-1] + (self.output_size_per_partition, )
        x_2d = x.reshape(-1, x.shape[-1])

        if orig_type != torch.float16:
            x_2d = x_2d.to(torch.float16)
            scales = scales.to(torch.float16)
            zeros = zeros.to(torch.float16)
            # The bias is consumed inside the kernel, so its dtype must match
            # the fp16 activations as well.
            if bias is not None:
                bias = bias.to(torch.float16)

        marlin_out = ops.gptq_marlin_gemm(
            x_2d,
            None,
            layer.marlin_qweight,
            bias,
            scales,
            None,
            zeros,
            layer.g_idx,
            layer.g_idx_sort_indices,
            workspace.scratch,
            scalar_types.uint4,
            x_2d.shape[0],
            self.output_size_per_partition,
            self.input_size_per_partition,
            True,  # is_k_full
            False,  # use atomic add
            True,  # use 32-bit reduce
            True,  # use float zp
        )

        if orig_type != torch.float16:
            marlin_out = marlin_out.to(orig_type)

        return marlin_out.reshape(out_shape)
|