diff --git a/tests/weight_loading/models.txt b/tests/weight_loading/models.txt
index 98a66b6701ea..70d6ffc70367 100644
--- a/tests/weight_loading/models.txt
+++ b/tests/weight_loading/models.txt
@@ -17,4 +17,6 @@ awq, casperhansen/mixtral-instruct-awq, main
 awq_marlin, casperhansen/mixtral-instruct-awq, main
 fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main
 marlin, nm-testing/zephyr-beta-7b-marlin-g128, main
-marlin, robertgshaw2/zephyr-7b-beta-channelwise-marlin, main
\ No newline at end of file
+marlin, robertgshaw2/zephyr-7b-beta-channelwise-marlin, main
+qqq, HandH1998/QQQ-Llama-3-8b-g128, main
+qqq, HandH1998/QQQ-Llama-3-8b, main
\ No newline at end of file
diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index e5b40a64abc4..5f4ca90dd791 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -23,7 +23,7 @@ logger = init_logger(__name__)
 WEIGHT_LOADER_V2_SUPPORTED = [
     "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
     "AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod",
-    "MarlinLinearMethod"
+    "MarlinLinearMethod", "QQQLinearMethod"
 ]
 
 
diff --git a/vllm/model_executor/layers/quantization/qqq.py b/vllm/model_executor/layers/quantization/qqq.py
index be10cee2cf68..c3434214a1cd 100644
--- a/vllm/model_executor/layers/quantization/qqq.py
+++ b/vllm/model_executor/layers/quantization/qqq.py
@@ -8,7 +8,10 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
-from vllm.model_executor.utils import set_weight_attrs
+from vllm.model_executor.parameter import (BasevLLMParameter,
+                                           ChannelQuantScaleParameter,
+                                           GroupQuantScaleParameter,
+                                           PackedvLLMParameter)
 
 logger = init_logger(__name__)
 
@@ -133,6 +136,7 @@ class QQQLinearMethod(LinearMethodBase):
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
+        weight_loader = extra_weight_attrs["weight_loader"]
         if params_dtype != torch.float16:
             raise ValueError(
                 f"The params dtype must be float16, but got {params_dtype}")
@@ -170,90 +174,74 @@ class QQQLinearMethod(LinearMethodBase):
                 "Each permutation group must reside on the same gpu")
 
         # Quantized 4Bit weights packed into Int32.
-        qweight = Parameter(
-            torch.empty(
+        qweight = PackedvLLMParameter(
+            data=torch.empty(
                 input_size_per_partition // self.quant_config.tile_size,
                 output_size_per_partition * self.quant_config.tile_size //
                 self.quant_config.pack_factor,
                 device="cuda",
                 dtype=torch.int32,
             ),
-            requires_grad=False,
-        )
-        set_weight_attrs(
-            qweight,
-            {
-                "input_dim": 0,
-                "output_dim": 1,
-                "packed_dim": 1,
-                "pack_factor": self.quant_config.pack_factor,
-                "marlin_tile_size": self.quant_config.tile_size,
-            },
-        )
+            input_dim=0,
+            output_dim=1,
+            packed_dim=1,
+            packed_factor=self.quant_config.pack_factor,
+            marlin_tile_size=self.quant_config.tile_size,
+            weight_loader=weight_loader)
 
-        s_channel = Parameter(
-            torch.empty(
-                1,
-                output_size_per_partition,
-                device="cuda",
-                dtype=torch.float,
-            ),
-            requires_grad=False,
-        )
-        set_weight_attrs(
-            s_channel,
-            {
-                "input_dim": None,
-                "output_dim": 1,
-            },
-        )
+        s_channel = ChannelQuantScaleParameter(data=torch.empty(
+            1,
+            output_size_per_partition,
+            device="cuda",
+            dtype=torch.float,
+        ),
+                                               weight_loader=weight_loader,
+                                               output_dim=1)
 
         if self.quant_config.group_size == -1:
-            s_group = Parameter(
-                torch.tensor(
-                    [],
-                    device="cuda",
-                    dtype=torch.half,
-                ),
-                requires_grad=False,
+            s_group_data = torch.tensor(
+                [],
+                device="cuda",
+                dtype=torch.half,
             )
         else:
-            s_group = Parameter(
-                torch.empty(
-                    input_size_per_partition // self.quant_config.group_size,
-                    output_size_per_partition,
-                    device="cuda",
-                    dtype=torch.half,
-                ),
-                requires_grad=False,
+            s_group_data = torch.empty(
+                input_size_per_partition // self.quant_config.group_size,
+                output_size_per_partition,
+                device="cuda",
+                dtype=torch.half,
             )
-        set_weight_attrs(
-            s_group,
-            {
-                "input_dim": None if self.quant_config.group_size == -1 else 0,
-                "output_dim":
-                None if self.quant_config.group_size == -1 else 1,
-            },
-        )
+        s_group_attr = {"data": s_group_data, "weight_loader": weight_loader}
+
+        if self.quant_config.group_size == -1:
+            s_group = BasevLLMParameter(**s_group_attr)
+        else:
+            s_group = GroupQuantScaleParameter(output_dim=1,
+                                               input_dim=0,
+                                               **s_group_attr)
 
         # Allocate workspace (Used for internal locking mechanism)
         max_workspace_size = (
            output_size_per_partition //
            self.quant_config.min_n_threads) * self.quant_config.max_parallel
-        workspace = Parameter(torch.zeros(max_workspace_size,
-                                          device="cuda",
-                                          dtype=torch.int),
-                              requires_grad=False)
+
+        workspace = BasevLLMParameter(data=torch.zeros(max_workspace_size,
+                                                       device="cuda",
+                                                       dtype=torch.int),
+                                      weight_loader=weight_loader)
 
         layer.register_parameter("B", qweight)
-        set_weight_attrs(qweight, extra_weight_attrs)
         layer.register_parameter("s_channel", s_channel)
-        set_weight_attrs(s_channel, extra_weight_attrs)
         layer.register_parameter("s_group", s_group)
-        set_weight_attrs(s_group, extra_weight_attrs)
         layer.register_parameter("workspace", workspace)
-        set_weight_attrs(workspace, extra_weight_attrs)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        # required by torch.compile
+        layer.B = Parameter(layer.B.data, requires_grad=False)
+        layer.s_channel = Parameter(layer.s_channel.data, requires_grad=False)
+        layer.s_group = Parameter(layer.s_group.data, requires_grad=False)
+        layer.workspace = Parameter(layer.workspace.data, requires_grad=False)
 
     def apply(
         self,
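For context, the qqq.py changes follow the weight-loader-v2 pattern used by the other methods in WEIGHT_LOADER_V2_SUPPORTED: loading metadata that used to be attached to a bare `Parameter` via `set_weight_attrs` now travels as attributes of typed parameter subclasses, and `process_weights_after_loading` re-wraps everything as plain `Parameter`s so torch.compile sees standard parameter types. Below is a minimal, self-contained sketch of that pattern under stated assumptions; `PackedParameterSketch` and `toy_loader` are illustrative stand-ins, not vLLM's actual classes, and the tensor sizes are made up for the example.

```python
import torch


class PackedParameterSketch(torch.nn.Parameter):
    """Toy analogue of PackedvLLMParameter: a Parameter that carries its
    own loading metadata instead of relying on set_weight_attrs."""

    def __new__(cls, data, *, input_dim, output_dim, packed_dim,
                packed_factor, marlin_tile_size, weight_loader):
        self = super().__new__(cls, data, requires_grad=False)
        self.input_dim = input_dim                # dim sharded over input size
        self.output_dim = output_dim              # dim sharded over output size
        self.packed_dim = packed_dim              # dim holding packed values
        self.packed_factor = packed_factor        # values per storage element
        self.marlin_tile_size = marlin_tile_size  # marlin tile reordering
        self.weight_loader = weight_loader        # called per checkpoint shard
        return self


def toy_loader(param, loaded_weight):
    # A loader can branch on typed attributes rather than probing a metadata
    # dict; here we just copy the shard wholesale.
    param.data.copy_(loaded_weight)


# Hypothetical 4-bit weight packed 8-per-int32, mirroring the QQQ layout
# above (64 x 256 is an arbitrary example shape).
qweight = PackedParameterSketch(
    torch.empty(64, 256, dtype=torch.int32),
    input_dim=0,
    output_dim=1,
    packed_dim=1,
    packed_factor=8,
    marlin_tile_size=16,
    weight_loader=toy_loader,
)
qweight.weight_loader(qweight, torch.zeros(64, 256, dtype=torch.int32))

# After loading, hand torch.compile a plain Parameter, mirroring what
# process_weights_after_loading does in the diff.
qweight = torch.nn.Parameter(qweight.data, requires_grad=False)
```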