[Misc] Update qqq to use vLLMParameters (#7805)

Dipika Sikka 2024-08-26 15:16:15 -04:00 committed by GitHub
parent 2deb029d11
commit 665304092d
3 changed files with 54 additions and 64 deletions
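The core of the change is mechanical: QQQ's `create_weights` stops attaching loading metadata to bare `torch.nn.Parameter`s via `set_weight_attrs` and instead constructs vLLM's typed parameter classes, which carry their metadata and `weight_loader` with them. A condensed sketch of the before/after pattern (the shapes and the no-op loader are illustrative, not from the diff; the keyword arguments are the ones the diff itself uses):

```python
import torch
from torch.nn import Parameter

from vllm.model_executor.parameter import PackedvLLMParameter
from vllm.model_executor.utils import set_weight_attrs


def noop_loader(param, loaded_weight):
    # Stand-in for the shard-aware loader the linear layer normally provides.
    param.data.copy_(loaded_weight)


data = torch.empty(256, 8192, dtype=torch.int32)

# Before: metadata is bolted on after construction via an attrs dict.
old = Parameter(data, requires_grad=False)
set_weight_attrs(old, {"input_dim": 0, "output_dim": 1, "packed_dim": 1,
                       "pack_factor": 8, "marlin_tile_size": 16})

# After: metadata and the loader travel with the parameter itself.
new = PackedvLLMParameter(data=data, input_dim=0, output_dim=1,
                          packed_dim=1, packed_factor=8,
                          marlin_tile_size=16, weight_loader=noop_loader)
```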

tests/weight_loading/models.txt

@@ -17,4 +17,6 @@ awq, casperhansen/mixtral-instruct-awq, main
awq_marlin, casperhansen/mixtral-instruct-awq, main
fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main
marlin, nm-testing/zephyr-beta-7b-marlin-g128, main
marlin, robertgshaw2/zephyr-7b-beta-channelwise-marlin, main
+qqq, HandH1998/QQQ-Llama-3-8b-g128, main
+qqq, HandH1998/QQQ-Llama-3-8b, main
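Each row in this test manifest is a `quant_method, model, revision` triple exercised by the weight-loading tests. As a quick sanity check outside the harness, one of the newly added QQQ entries can be loaded directly (a minimal sketch; `quantization="qqq"` is the name QQQ registers, and `dtype="float16"` matches the dtype QQQ enforces below):

```python
from vllm import LLM

# Load one of the newly added QQQ checkpoints; vLLM infers the
# quantization method from the checkpoint config, but pinning it
# explicitly fails fast on a mismatch.
llm = LLM(model="HandH1998/QQQ-Llama-3-8b-g128",
          quantization="qqq",
          dtype="float16")
print(llm.generate("The capital of France is")[0].outputs[0].text)
```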

vllm/model_executor/layers/linear.py

@@ -23,7 +23,7 @@ logger = init_logger(__name__)
WEIGHT_LOADER_V2_SUPPORTED = [
    "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
    "AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod",
-    "MarlinLinearMethod"
+    "MarlinLinearMethod", "QQQLinearMethod"
]
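For context, this list is how the linear layers decide which quantization methods receive the new loading path: when a quant method's class name appears here, the layer hands the v2 loader to `create_weights` instead of the legacy attrs-dict loader. A self-contained sketch of that dispatch (mock classes and loaders; names mirror the vLLM convention, but this is not the verbatim linear.py code):

```python
WEIGHT_LOADER_V2_SUPPORTED = ["MarlinLinearMethod", "QQQLinearMethod"]

class QQQLinearMethod:  # stand-in for the real quant method class
    def create_weights(self, weight_loader):
        print("creating weights with", weight_loader.__name__)

def weight_loader(param, loaded_weight):
    """Legacy loader driven by per-parameter attrs dicts."""

def weight_loader_v2(param, loaded_weight):
    """Loader that defers to the vLLMParameter's own loading logic."""

method = QQQLinearMethod()
loader = (weight_loader_v2
          if type(method).__name__ in WEIGHT_LOADER_V2_SUPPORTED
          else weight_loader)
method.create_weights(weight_loader=loader)  # -> weight_loader_v2
```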

vllm/model_executor/layers/quantization/qqq.py

@@ -8,7 +8,10 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
-from vllm.model_executor.utils import set_weight_attrs
+from vllm.model_executor.parameter import (BasevLLMParameter,
+                                           ChannelQuantScaleParameter,
+                                           GroupQuantScaleParameter,
+                                           PackedvLLMParameter)

logger = init_logger(__name__)
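The packed-weight shapes used in the hunk below follow the Marlin tiling: 4-bit values are packed eight to an int32 (QQQ's `pack_factor = 32 // 4`) and the kernel operates on 16x16 tiles (`tile_size = 16`). A quick check of that arithmetic, using an illustrative 4096x4096 partition (the sizes are made up; the shape formula is the one from the diff):

```python
# Illustrative shape check for the packed QQQ/Marlin weight "B".
pack_factor = 32 // 4   # eight 4-bit values per int32 word
tile_size = 16          # Marlin tile size

input_size_per_partition, output_size_per_partition = 4096, 4096
qweight_shape = (
    input_size_per_partition // tile_size,
    output_size_per_partition * tile_size // pack_factor,
)
assert qweight_shape == (256, 8192)
# Total packed int32 elements * pack_factor == total int4 weights.
assert qweight_shape[0] * qweight_shape[1] * pack_factor == 4096 * 4096
```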
@@ -133,6 +136,7 @@ class QQQLinearMethod(LinearMethodBase):
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
+        weight_loader = extra_weight_attrs["weight_loader"]
        if params_dtype != torch.float16:
            raise ValueError(
                f"The params dtype must be float16, but got {params_dtype}")
@@ -170,90 +174,74 @@ class QQQLinearMethod(LinearMethodBase):
                "Each permutation group must reside on the same gpu")

        # Quantized 4Bit weights packed into Int32.
-        qweight = Parameter(
-            torch.empty(
+        qweight = PackedvLLMParameter(
+            data=torch.empty(
                input_size_per_partition // self.quant_config.tile_size,
                output_size_per_partition * self.quant_config.tile_size //
                self.quant_config.pack_factor,
                device="cuda",
                dtype=torch.int32,
            ),
-            requires_grad=False,
-        )
-        set_weight_attrs(
-            qweight,
-            {
-                "input_dim": 0,
-                "output_dim": 1,
-                "packed_dim": 1,
-                "pack_factor": self.quant_config.pack_factor,
-                "marlin_tile_size": self.quant_config.tile_size,
-            },
-        )
+            input_dim=0,
+            output_dim=1,
+            packed_dim=1,
+            packed_factor=self.quant_config.pack_factor,
+            marlin_tile_size=self.quant_config.tile_size,
+            weight_loader=weight_loader)

-        s_channel = Parameter(
-            torch.empty(
-                1,
-                output_size_per_partition,
-                device="cuda",
-                dtype=torch.float,
-            ),
-            requires_grad=False,
-        )
-        set_weight_attrs(
-            s_channel,
-            {
-                "input_dim": None,
-                "output_dim": 1,
-            },
-        )
+        s_channel = ChannelQuantScaleParameter(data=torch.empty(
+            1,
+            output_size_per_partition,
+            device="cuda",
+            dtype=torch.float,
+        ),
+                                               weight_loader=weight_loader,
+                                               output_dim=1)

        if self.quant_config.group_size == -1:
-            s_group = Parameter(
-                torch.tensor(
-                    [],
-                    device="cuda",
-                    dtype=torch.half,
-                ),
-                requires_grad=False,
+            s_group_data = torch.tensor(
+                [],
+                device="cuda",
+                dtype=torch.half,
            )
        else:
-            s_group = Parameter(
-                torch.empty(
-                    input_size_per_partition // self.quant_config.group_size,
-                    output_size_per_partition,
-                    device="cuda",
-                    dtype=torch.half,
-                ),
-                requires_grad=False,
+            s_group_data = torch.empty(
+                input_size_per_partition // self.quant_config.group_size,
+                output_size_per_partition,
+                device="cuda",
+                dtype=torch.half,
            )
-        set_weight_attrs(
-            s_group,
-            {
-                "input_dim": None if self.quant_config.group_size == -1 else 0,
-                "output_dim":
-                None if self.quant_config.group_size == -1 else 1,
-            },
-        )
+
+        s_group_attr = {"data": s_group_data, "weight_loader": weight_loader}
+
+        if self.quant_config.group_size == -1:
+            s_group = BasevLLMParameter(**s_group_attr)
+        else:
+            s_group = GroupQuantScaleParameter(output_dim=1,
+                                               input_dim=0,
+                                               **s_group_attr)

        # Allocate workspace (Used for internal locking mechanism)
        max_workspace_size = (
            output_size_per_partition //
            self.quant_config.min_n_threads) * self.quant_config.max_parallel
-        workspace = Parameter(torch.zeros(max_workspace_size,
-                                          device="cuda",
-                                          dtype=torch.int),
-                              requires_grad=False)
+        workspace = BasevLLMParameter(data=torch.zeros(max_workspace_size,
+                                                       device="cuda",
+                                                       dtype=torch.int),
+                                      weight_loader=weight_loader)

        layer.register_parameter("B", qweight)
-        set_weight_attrs(qweight, extra_weight_attrs)
        layer.register_parameter("s_channel", s_channel)
-        set_weight_attrs(s_channel, extra_weight_attrs)
        layer.register_parameter("s_group", s_group)
-        set_weight_attrs(s_group, extra_weight_attrs)
        layer.register_parameter("workspace", workspace)
-        set_weight_attrs(workspace, extra_weight_attrs)

+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        # required by torch.compile
+        layer.B = Parameter(layer.B.data, requires_grad=False)
+        layer.s_channel = Parameter(layer.s_channel.data, requires_grad=False)
+        layer.s_group = Parameter(layer.s_group.data, requires_grad=False)
+        layer.workspace = Parameter(layer.workspace.data, requires_grad=False)
+
    def apply(
        self,
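The new `process_weights_after_loading` hook re-wraps each loaded parameter in a plain `torch.nn.Parameter`. Per the in-line comment, the motivation is torch.compile: the vLLM parameter subclasses are only needed while weights are being loaded, and swapping back to a vanilla `Parameter` keeps tensor subclasses out of the compiled graph. A minimal standalone illustration of the re-wrapping (the subclass here is hypothetical; only plain torch is used):

```python
import torch
from torch.nn import Parameter

class LoadTimeParam(Parameter):
    """Hypothetical stand-in for a vLLM parameter subclass."""

p = LoadTimeParam(torch.zeros(4), requires_grad=False)
assert type(p) is LoadTimeParam

# After loading: keep the data, drop the subclass. The storage is
# shared, so no weights are copied.
plain = Parameter(p.data, requires_grad=False)
assert type(plain) is Parameter and plain.data_ptr() == p.data_ptr()
```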