From b1e5afc3e7843e993d3f7d40e57f0fecb9d137b5 Mon Sep 17 00:00:00 2001
From: Dipika Sikka
Date: Tue, 13 Aug 2024 17:08:20 -0400
Subject: [PATCH] [Misc] Update `awq` and `awq_marlin` to use `vLLMParameters`
 (#7422)

---
 tests/weight_loading/models.txt                   |  4 +-
 vllm/model_executor/layers/linear.py              |  3 +-
 .../model_executor/layers/quantization/awq.py     | 74 +++++++++----------
 .../layers/quantization/awq_marlin.py             | 74 +++++++++----------
 4 files changed, 73 insertions(+), 82 deletions(-)

diff --git a/tests/weight_loading/models.txt b/tests/weight_loading/models.txt
index 84ca8bcbd79e1..064dbb1feee83 100644
--- a/tests/weight_loading/models.txt
+++ b/tests/weight_loading/models.txt
@@ -12,4 +12,6 @@ compressed-tensors, nm-testing/tinyllama-oneshot-w4a16-group128-v2, main
 compressed-tensors, nm-testing/tinyllama-oneshot-w8a16-per-channel, main
 compressed-tensors, nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test, main
 compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
-compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
\ No newline at end of file
+compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
+awq, casperhansen/mixtral-instruct-awq, main
+awq_marlin, casperhansen/mixtral-instruct-awq, main
diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index cececea1c0af2..b4cc6daa3c41e 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -21,7 +21,8 @@ from vllm.model_executor.utils import set_weight_attrs
 
 logger = init_logger(__name__)
 
 WEIGHT_LOADER_V2_SUPPORTED = [
-    "CompressedTensorsLinearMethod", "GPTQMarlinLinearMethod"
+    "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
+    "AWQLinearMethod", "GPTQMarlinLinearMethod"
 ]
 
diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py
index ce2fa62ef565f..410b3cb5321cb 100644
--- a/vllm/model_executor/layers/quantization/awq.py
+++ b/vllm/model_executor/layers/quantization/awq.py
@@ -1,13 +1,13 @@
 from typing import Any, Dict, List, Optional
 
 import torch
-from torch.nn.parameter import Parameter
 
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
-from vllm.model_executor.utils import set_weight_attrs
+from vllm.model_executor.parameter import (GroupQuantScaleParameter,
+                                           PackedvLLMParameter)
 
 
 class AWQConfig(QuantizationConfig):
@@ -101,55 +101,51 @@ class AWQLinearMethod(LinearMethodBase):
                 "weight shape. This can be caused by too large "
                 "tensor parallel size.")
 
-        qweight = Parameter(
-            torch.empty(
+        weight_loader = extra_weight_attrs.get("weight_loader")
+        qweight = PackedvLLMParameter(
+            data=torch.empty(
                 input_size_per_partition,
                 output_size_per_partition // self.quant_config.pack_factor,
                 dtype=torch.int32,
             ),
-            requires_grad=False,
-        )
-        set_weight_attrs(
-            qweight, {
-                "input_dim": 0,
-                "output_dim": 1,
-                "packed_dim": 1,
-                "pack_factor": self.quant_config.pack_factor,
-            })
-        qzeros = Parameter(
-            torch.empty(
+            input_dim=0,
+            output_dim=1,
+            packed_dim=1,
+            packed_factor=self.quant_config.pack_factor,
+            weight_loader=weight_loader)
+
+        qzeros = PackedvLLMParameter(
+            data=torch.empty(
                 input_size_per_partition // self.quant_config.group_size,
                 output_size_per_partition // self.quant_config.pack_factor,
                 dtype=torch.int32,
             ),
-            requires_grad=False,
-        )
-        set_weight_attrs(
-            qzeros, {
-                "input_dim": 0,
-                "output_dim": 1,
-                "packed_dim": 1,
-                "pack_factor": self.quant_config.pack_factor,
-            })
-        scales = Parameter(
-            torch.empty(
-                input_size_per_partition // self.quant_config.group_size,
-                output_size_per_partition,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        set_weight_attrs(scales, {
-            "input_dim": 0,
-            "output_dim": 1,
-        })
+            input_dim=0,
+            output_dim=1,
+            packed_dim=1,
+            packed_factor=self.quant_config.pack_factor,
+            weight_loader=weight_loader)
+
+        scales = GroupQuantScaleParameter(data=torch.empty(
+            input_size_per_partition // self.quant_config.group_size,
+            output_size_per_partition,
+            dtype=params_dtype,
+        ),
+                                          input_dim=0,
+                                          output_dim=1,
+                                          weight_loader=weight_loader)
 
         layer.register_parameter("qweight", qweight)
-        set_weight_attrs(qweight, extra_weight_attrs)
         layer.register_parameter("qzeros", qzeros)
-        set_weight_attrs(qzeros, extra_weight_attrs)
         layer.register_parameter("scales", scales)
-        set_weight_attrs(scales, extra_weight_attrs)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        layer.qweight = torch.nn.Parameter(layer.qweight.data,
+                                           requires_grad=False)
+        layer.qzeros = torch.nn.Parameter(layer.qzeros.data,
+                                          requires_grad=False)
+        layer.scales = torch.nn.Parameter(layer.scales.data,
+                                          requires_grad=False)
 
     def apply(self,
               layer: torch.nn.Module,
diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py
index 1b0453c2bd6f8..eee6a8f7cff49 100644
--- a/vllm/model_executor/layers/quantization/awq_marlin.py
+++ b/vllm/model_executor/layers/quantization/awq_marlin.py
@@ -1,12 +1,10 @@
 from typing import Any, Dict, List, Optional
 
 import torch
-from torch.nn.parameter import Parameter
 
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
-from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
-                                               set_weight_attrs)
+from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
@@ -14,6 +12,8 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
     replace_tensor, verify_marlin_supported, verify_marlin_supports_shape)
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
+from vllm.model_executor.parameter import (GroupQuantScaleParameter,
+                                           PackedvLLMParameter)
 from vllm.scalar_type import scalar_types
 
 logger = init_logger(__name__)
@@ -151,6 +151,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
     ) -> None:
         del output_size
         output_size_per_partition = sum(output_partition_sizes)
+        weight_loader = extra_weight_attrs.get("weight_loader")
 
         # Normalize group_size
         if self.quant_config.group_size != -1:
@@ -164,59 +165,44 @@ class AWQMarlinLinearMethod(LinearMethodBase):
             input_size=input_size,
             group_size=group_size)
 
-        qweight = Parameter(
-            torch.empty(
+        qweight = PackedvLLMParameter(
+            data=torch.empty(
                 input_size_per_partition,
                 output_size_per_partition // self.quant_config.pack_factor,
                 dtype=torch.int32,
             ),
-            requires_grad=False,
-        )
-        set_weight_attrs(
-            qweight, {
-                "input_dim": 0,
-                "output_dim": 1,
-                "packed_dim": 1,
-                "pack_factor": self.quant_config.pack_factor,
-            })
+            input_dim=0,
+            output_dim=1,
+            packed_dim=1,
+            packed_factor=self.quant_config.pack_factor,
+            weight_loader=weight_loader)
 
         num_groups = input_size_per_partition // group_size
 
-        qzeros = Parameter(
-            torch.empty(
+        qzeros = PackedvLLMParameter(
+            data=torch.empty(
                 num_groups,
                 output_size_per_partition // self.quant_config.pack_factor,
                 dtype=torch.int32,
             ),
-            requires_grad=False,
-        )
-        set_weight_attrs(
-            qzeros, {
-                "input_dim": 0,
-                "output_dim": 1,
-                "packed_dim": 1,
-                "pack_factor": self.quant_config.pack_factor,
-            })
+            input_dim=0,
+            output_dim=1,
+            packed_dim=1,
+            packed_factor=self.quant_config.pack_factor,
+            weight_loader=weight_loader)
 
-        scales = Parameter(
-            torch.empty(
-                num_groups,
-                output_size_per_partition,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        set_weight_attrs(scales, {
-            "input_dim": 0,
-            "output_dim": 1,
-        })
+        scales = GroupQuantScaleParameter(data=torch.empty(
+            num_groups,
+            output_size_per_partition,
+            dtype=params_dtype,
+        ),
+                                          input_dim=0,
+                                          output_dim=1,
+                                          weight_loader=weight_loader)
 
         layer.register_parameter("qweight", qweight)
-        set_weight_attrs(qweight, extra_weight_attrs)
         layer.register_parameter("qzeros", qzeros)
-        set_weight_attrs(qzeros, extra_weight_attrs)
         layer.register_parameter("scales", scales)
-        set_weight_attrs(scales, extra_weight_attrs)
 
         layer.input_size_per_partition = input_size_per_partition
         layer.output_size_per_partition = output_size_per_partition
@@ -228,6 +214,12 @@ class AWQMarlinLinearMethod(LinearMethodBase):
     # Here, we handle the repacking
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         device = layer.qweight.device
+        layer.qweight = torch.nn.Parameter(layer.qweight.data,
+                                           requires_grad=False)
+        layer.qzeros = torch.nn.Parameter(layer.qzeros.data,
+                                          requires_grad=False)
+        layer.scales = torch.nn.Parameter(layer.scales.data,
+                                          requires_grad=False)
 
         # Allocate marlin workspace
         layer.workspace = marlin_make_workspace(
@@ -278,4 +270,4 @@ class AWQMarlinLinearMethod(LinearMethodBase):
             quant_type=self.quant_config.quant_type,
             output_size_per_partition=layer.output_size_per_partition,
             input_size_per_partition=layer.input_size_per_partition,
-            bias=bias)
+            bias=bias)
\ No newline at end of file
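
For reference, the sketch below contrasts the two parameter styles this patch migrates between. It is a minimal illustration, not part of the patch: the `PackedvLLMParameter` keyword arguments mirror the diff above, while the tensor shape, the pack factor of 8 (32-bit words holding eight 4-bit values), and the `demo_weight_loader` helper are illustrative placeholders.

# Minimal before/after sketch of the vLLMParameters migration (illustrative,
# not part of the commit).
import torch

from vllm.model_executor.parameter import PackedvLLMParameter
from vllm.model_executor.utils import set_weight_attrs


def demo_weight_loader(param: torch.nn.Parameter,
                       loaded_weight: torch.Tensor) -> None:
    # Placeholder loader; real v2 loaders shard and unpack the checkpoint
    # tensor using the metadata carried by the parameter object.
    param.data.copy_(loaded_weight)


# Illustrative packed int32 weight: 128 inputs x (128 // 8) packed outputs.
packed = torch.empty(128, 128 // 8, dtype=torch.int32)

# Old style: a bare Parameter, with sharding metadata bolted on afterwards
# via set_weight_attrs (and extra_weight_attrs applied in a second pass).
qweight_old = torch.nn.Parameter(packed.clone(), requires_grad=False)
set_weight_attrs(qweight_old, {
    "input_dim": 0,
    "output_dim": 1,
    "packed_dim": 1,
    "pack_factor": 8,
})

# New style: one vLLM parameter that carries both the metadata and the
# weight loader, which is what lets AWQLinearMethod and AWQMarlinLinearMethod
# be added to WEIGHT_LOADER_V2_SUPPORTED above.
qweight_new = PackedvLLMParameter(data=packed.clone(),
                                  input_dim=0,
                                  output_dim=1,
                                  packed_dim=1,
                                  packed_factor=8,
                                  weight_loader=demo_weight_loader)

Because the extra metadata is only needed while weights are being loaded, both methods also gain a `process_weights_after_loading` hook that rewraps the loaded tensors as plain `torch.nn.Parameter`s, as the hunks above show.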