[Misc] Update gptq_marlin to use new vLLMParameters (#7281)

Dipika Sikka 2024-08-13 14:30:11 -04:00 committed by GitHub
parent 181abbc27d
commit fb377d7e74
8 changed files with 234 additions and 98 deletions
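For orientation, the change below moves gptq_marlin's create_weights from plain torch Parameters annotated via set_weight_attrs to the typed vLLM parameter classes, which carry the sharding metadata and the weight loader themselves. A minimal before/after sketch for the packed qweight, using hypothetical sizes and a stand-in loader (only the class and keyword names are taken from the diff):

import torch
from torch.nn import Parameter

from vllm.model_executor.parameter import PackedvLLMParameter
from vllm.model_executor.utils import set_weight_attrs

# Hypothetical sizes: a 4-bit GPTQ layer packed into int32 (pack factor 8).
input_size_per_partition = 4096
output_size_per_partition = 4096
pack_factor = 8


def dummy_weight_loader(param, loaded_weight, *args, **kwargs):
    # Stand-in for the loader that the linear layer normally provides.
    param.data.copy_(loaded_weight)


# Before: plain Parameter, metadata attached afterwards via set_weight_attrs.
qweight_old = Parameter(torch.empty(input_size_per_partition // pack_factor,
                                    output_size_per_partition,
                                    dtype=torch.int32),
                        requires_grad=False)
set_weight_attrs(
    qweight_old, {
        "input_dim": 0,
        "output_dim": 1,
        "packed_dim": 0,
        "pack_factor": pack_factor,
        "weight_loader": dummy_weight_loader,
    })

# After: the same metadata travels with the typed parameter itself.
qweight_new = PackedvLLMParameter(data=torch.empty(
    input_size_per_partition // pack_factor,
    output_size_per_partition,
    dtype=torch.int32),
                                  input_dim=0,
                                  output_dim=1,
                                  packed_dim=0,
                                  packed_factor=pack_factor,
                                  weight_loader=dummy_weight_loader)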

View File

@@ -314,6 +314,16 @@ steps:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -v -s -x lora/test_long_context.py
- label: Weight Loading Multiple GPU Test
working_dir: "/vllm-workspace/tests"
num_gpus: 2
source_file_dependencies:
- vllm/
- tests/weight_loading
commands:
- bash weight_loading/run_model_weight_loading_test.sh
##### multi gpus test #####
##### A100 test #####

View File

@@ -0,0 +1,15 @@
gptq_marlin, robertgshaw2/zephyr-7b-beta-channelwise-gptq, main
gptq_marlin, TheBloke/Llama-2-7B-GPTQ, main
gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, main
gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, gptq-8bit--1g-actorder_True
gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, gptq-8bit-32g-actorder_True
gptq_marlin, TechxGenus/gemma-1.1-2b-it-GPTQ, main
compressed-tensors, nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change, main
compressed-tensors, nm-testing/tinyllama-oneshot-w8-channel-a8-tensor, main
compressed-tensors, nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2, main
compressed-tensors, nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2, main
compressed-tensors, nm-testing/tinyllama-oneshot-w4a16-group128-v2, main
compressed-tensors, nm-testing/tinyllama-oneshot-w8a16-per-channel, main
compressed-tensors, nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test, main
compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main

View File

@@ -0,0 +1,32 @@
#!/bin/bash
SUCCESS=0
IFS=$'\n' read -d '' -r -a MODEL_CONFIGS < "weight_loading/models.txt"
for MODEL_CONFIG in "${MODEL_CONFIGS[@]}"
do
LOCAL_SUCCESS=0
IFS=', ' read -r -a array <<< "$MODEL_CONFIG"
echo "=== RUNNING MODEL: $MODEL_CONFIG ==="
export QUANTIZATION=${array[0]}
export MODEL_NAME=${array[1]}
export REVISION=${array[2]}
pytest -s weight_loading/test_weight_loading.py || LOCAL_SUCCESS=$?
if [[ $LOCAL_SUCCESS == 0 ]]; then
echo "=== PASSED MODEL: ${MODEL_CONFIG} ==="
else
echo "=== FAILED MODEL: ${MODEL_CONFIG} ==="
fi
SUCCESS=$((SUCCESS + LOCAL_SUCCESS))
done
if [ "${SUCCESS}" -eq "0" ]; then
exit 0
else
exit 1
fi

View File

@@ -0,0 +1,20 @@
import os
MAX_MODEL_LEN = 1024
MODEL_NAME = os.environ.get("MODEL_NAME",
"robertgshaw2/zephyr-7b-beta-channelwise-gptq")
REVISION = os.environ.get("REVISION", "main")
QUANTIZATION = os.environ.get("QUANTIZATION", "gptq_marlin")
def test_weight_loading(vllm_runner):
with vllm_runner(model_name=MODEL_NAME,
revision=REVISION,
dtype="auto",
quantization=QUANTIZATION,
max_model_len=MAX_MODEL_LEN,
tensor_parallel_size=2) as model:
output = model.generate_greedy("Hello world!", max_tokens=20)
print(output)
assert output

View File

@@ -20,7 +20,9 @@ from vllm.model_executor.utils import set_weight_attrs
logger = init_logger(__name__)
WEIGHT_LOADER_V2_SUPPORTED = ["CompressedTensorsLinearMethod"]
WEIGHT_LOADER_V2_SUPPORTED = [
"CompressedTensorsLinearMethod", "GPTQMarlinLinearMethod"
]
def adjust_marlin_shard(param, shard_size, shard_offset):

View File

@@ -105,7 +105,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
dtype=params_dtype,
)
}
if self.group_size == -1:
if not partition_scales:
weight_scale = ChannelQuantScaleParameter(output_dim=0,
**weight_scale_args)
else:

View File

@@ -1,12 +1,11 @@
from typing import Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from torch.nn import Parameter
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
@@ -15,6 +14,11 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx, replace_tensor,
verify_marlin_supported, verify_marlin_supports_shape)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedColumnParameter,
PackedvLLMParameter,
RowvLLMParameter)
from vllm.scalar_type import scalar_types
logger = init_logger(__name__)
@@ -159,9 +163,11 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
params_dtype: torch.dtype,
**extra_weight_attrs,
) -> None:
del output_size
output_size_per_partition = sum(output_partition_sizes)
is_row_parallel = input_size != input_size_per_partition
weight_loader = extra_weight_attrs.get("weight_loader")
# Normalize group_size
if self.quant_config.group_size != -1:
@@ -190,79 +196,66 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
scales_and_zp_size = input_size_per_partition // group_size
# Quantized weights
qweight = Parameter(
torch.empty(
qweight = PackedvLLMParameter(
data=torch.empty(
input_size_per_partition // self.quant_config.pack_factor,
output_size_per_partition,
dtype=torch.int32,
),
requires_grad=False,
)
set_weight_attrs(
qweight,
{
**extra_weight_attrs,
"input_dim": 0,
"output_dim": 1,
"packed_dim": 0,
"pack_factor": self.quant_config.pack_factor,
},
)
input_dim=0,
output_dim=1,
packed_dim=0,
packed_factor=self.quant_config.pack_factor,
weight_loader=weight_loader)
# Activation order
g_idx = Parameter(
torch.empty(
input_size_per_partition,
dtype=torch.int32,
),
requires_grad=False,
)
# Ignore warning from fused linear layers such as QKVParallelLinear.
set_weight_attrs(
g_idx,
{
**extra_weight_attrs, "input_dim": 0,
"ignore_warning": True
},
)
g_idx = RowvLLMParameter(data=torch.empty(
input_size_per_partition,
dtype=torch.int32,
),
input_dim=0,
weight_loader=weight_loader)
# Scales
scales = Parameter(
torch.empty(
scales_and_zp_size,
output_size_per_partition,
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(
scales,
{
**extra_weight_attrs,
"input_dim": scales_and_zp_input_dim,
"output_dim": 1,
},
)
# Quantized zero-points
qzeros = Parameter(
qzeros_args = {
"data":
torch.empty(
scales_and_zp_size,
output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32,
),
requires_grad=False,
)
set_weight_attrs(
qzeros,
{
**extra_weight_attrs,
"input_dim": scales_and_zp_input_dim,
"output_dim": 1,
"packed_dim": 1,
"pack_factor": self.quant_config.pack_factor,
},
)
"weight_loader":
weight_loader
}
weight_scale_args = {
"data":
torch.empty(
scales_and_zp_size,
output_size_per_partition,
dtype=params_dtype,
),
"weight_loader":
weight_loader
}
if scales_and_zp_input_dim is None:
scales = ChannelQuantScaleParameter(output_dim=1,
**weight_scale_args)
qzeros = PackedColumnParameter(
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
**qzeros_args)
else:
scales = GroupQuantScaleParameter(output_dim=1,
input_dim=0,
**weight_scale_args)
qzeros = PackedvLLMParameter(
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
**qzeros_args)
layer.register_parameter("qweight", qweight)
layer.register_parameter("g_idx", g_idx)
@@ -280,6 +273,10 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
device = layer.qweight.device
# required by torch.compile
layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
layer.scales = Parameter(layer.scales.data, requires_grad=False)
# Allocate marlin workspace
layer.workspace = marlin_make_workspace(
layer.output_size_per_partition, device)
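As a reference for the branch above, the scale and zero-point parameter types are chosen by whether quantization is channelwise (group_size == -1, a single group spanning the whole input) or grouped. A minimal standalone sketch of that selection, again with hypothetical sizes and a stand-in loader:

import torch

from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                           GroupQuantScaleParameter)

input_size_per_partition = 4096
output_size_per_partition = 4096
params_dtype = torch.float16
group_size = -1  # -1 means channelwise; e.g. 128 for grouped quantization


def dummy_weight_loader(param, loaded_weight, *args, **kwargs):
    param.data.copy_(loaded_weight)


if group_size == -1:
    # Channelwise: one scale per output channel, no input_dim to shard on.
    scales = ChannelQuantScaleParameter(
        data=torch.empty(1, output_size_per_partition, dtype=params_dtype),
        output_dim=1,
        weight_loader=dummy_weight_loader)
else:
    # Grouped: one scale per group of input channels, sharded on input_dim.
    scales = GroupQuantScaleParameter(
        data=torch.empty(input_size_per_partition // group_size,
                         output_size_per_partition,
                         dtype=params_dtype),
        input_dim=0,
        output_dim=1,
        weight_loader=dummy_weight_loader)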

View File

@@ -9,7 +9,7 @@ from vllm.logger import init_logger
__all__ = [
"BasevLLMParameter", "PackedvLLMParameter", "PerTensorScaleParameter",
"ModelWeightParameter", "ChannelQuantScaleParameter",
"GroupQuantScaleParameter"
"GroupQuantScaleParameter", "PackedColumnParameter", "RowvLLMParameter"
]
logger = init_logger(__name__)
@@ -92,7 +92,8 @@ class _ColumnvLLMParameter(BasevLLMParameter):
shard_size = kwargs.get("shard_size")
if isinstance(
self,
PackedvLLMParameter) and self.packed_dim == self.output_dim:
(PackedColumnParameter,
PackedvLLMParameter)) and self.packed_dim == self.output_dim:
shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
shard_offset=shard_offset, shard_size=shard_size)
@@ -115,7 +116,8 @@ class _ColumnvLLMParameter(BasevLLMParameter):
if isinstance(
self,
PackedvLLMParameter) and self.output_dim == self.packed_dim:
(PackedColumnParameter,
PackedvLLMParameter)) and self.output_dim == self.packed_dim:
shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
shard_offset=shard_offset, shard_size=shard_size)
@@ -131,12 +133,12 @@ class _ColumnvLLMParameter(BasevLLMParameter):
param_data.copy_(loaded_weight)
class ModelWeightParameter(_ColumnvLLMParameter):
class RowvLLMParameter(BasevLLMParameter):
"""
Parameter class for linear layer weights. Extends the
_ColumnvLLMParameter by adding loading functionality
for linear layers with row parallel functionality.
Requires an input dimension to be defined.
Parameter class defining weight_loading functionality
(load_row_parallel_weight) for parameters being loaded
into linear layers with row parallel functionality.
Requires an input_dim to be defined.
"""
def __init__(self, input_dim: int, **kwargs):
@@ -160,10 +162,18 @@ class ModelWeightParameter(_ColumnvLLMParameter):
self.data.copy_(loaded_weight)
class GroupQuantScaleParameter(ModelWeightParameter):
class ModelWeightParameter(_ColumnvLLMParameter, RowvLLMParameter):
"""
Parameter class for linear layer weights. Uses both column and
row parallelism.
"""
pass
class GroupQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):
"""
Parameter class for weight scales loaded for weights with
grouped quantization. Equivalent to ModelWeightParameter.
grouped quantization. Uses both column and row parallelism.
"""
pass
@@ -171,7 +181,7 @@ class GroupQuantScaleParameter(ModelWeightParameter):
class ChannelQuantScaleParameter(_ColumnvLLMParameter):
"""
Parameter class for weight scales loaded for weights with
channel-wise quantization. Equivalent to _ColumnvLLMParameter.
"""
pass
@@ -181,7 +191,7 @@ class PerTensorScaleParameter(BasevLLMParameter):
Parameter class for scales where the number of scales is
equivalent to the number of logical matrices in fused linear
layers (e.g. for QKV, there are 3 scales loaded from disk).
This is relevant to weights with per-tensor quantization.
Adds functionality to map the scalers to a shard during
weight loading.
@@ -232,15 +242,11 @@ class PerTensorScaleParameter(BasevLLMParameter):
param_data.copy_(loaded_weight)
class PackedvLLMParameter(ModelWeightParameter):
class PackedColumnParameter(_ColumnvLLMParameter):
"""
Parameter for model weights which are packed on disk.
Example: GPTQ Marlin weights are int4 or int8, packed into int32.
Extends the ModelWeightParameter to take in the
packed factor, the packed dimension, and optionally, marlin
tile size for marlin kernels. Adjusts the shard_size and
shard_offset for fused linear layers model weight loading
by accounting for packing and optionally, marlin tile size.
Parameter for model parameters which are packed on disk
and support column parallelism only. See PackedvLLMParameter
for more details on the packed properties.
"""
def __init__(self,
@@ -250,7 +256,7 @@ class PackedvLLMParameter(ModelWeightParameter):
**kwargs):
self._packed_factor = packed_factor
self._packed_dim = packed_dim
self._marlin_tile = marlin_tile_size
self._marlin_tile_size = marlin_tile_size
super().__init__(**kwargs)
@property
@@ -262,16 +268,70 @@ class PackedvLLMParameter(ModelWeightParameter):
return self._packed_factor
@property
def marlin_tile(self):
return self._marlin_tile
def _adjust_shard_indexes_for_marlin(self, shard_size, shard_offset):
return shard_size * self.marlin_tile, shard_offset * self.marlin_tile
def marlin_tile_size(self):
return self._marlin_tile_size
def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
shard_size = shard_size // self.packed_factor
shard_offset = shard_offset // self.packed_factor
if self.marlin_tile is not None:
return self._adjust_shard_indexes_for_marlin(
shard_size, shard_offset)
return shard_size, shard_offset
return _adjust_shard_indexes_for_packing(
shard_size=shard_size,
shard_offset=shard_offset,
packed_factor=self.packed_factor,
marlin_tile_size=self.marlin_tile_size)
class PackedvLLMParameter(ModelWeightParameter):
"""
Parameter for model weights which are packed on disk.
Example: GPTQ Marlin weights are int4 or int8, packed into int32.
Extends the ModelWeightParameter to take in the
packed factor, the packed dimension, and optionally, marlin
tile size for marlin kernels. Adjusts the shard_size and
shard_offset for fused linear layers model weight loading
by accounting for packing and optionally, marlin tile size.
"""
def __init__(self,
packed_factor: int,
packed_dim: int,
marlin_tile_size: Optional[int] = None,
**kwargs):
self._packed_factor = packed_factor
self._packed_dim = packed_dim
self._marlin_tile_size = marlin_tile_size
super().__init__(**kwargs)
@property
def packed_dim(self):
return self._packed_dim
@property
def packed_factor(self):
return self._packed_factor
@property
def marlin_tile_size(self):
return self._marlin_tile_size
def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
return _adjust_shard_indexes_for_packing(
shard_size=shard_size,
shard_offset=shard_offset,
packed_factor=self.packed_factor,
marlin_tile_size=self.marlin_tile_size)
def _adjust_shard_indexes_for_marlin(shard_size, shard_offset,
marlin_tile_size):
return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
def _adjust_shard_indexes_for_packing(shard_size, shard_offset, packed_factor,
marlin_tile_size):
shard_size = shard_size // packed_factor
shard_offset = shard_offset // packed_factor
if marlin_tile_size is not None:
return _adjust_shard_indexes_for_marlin(
shard_size=shard_size,
shard_offset=shard_offset,
marlin_tile_size=marlin_tile_size)
return shard_size, shard_offset
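To make the packing arithmetic above concrete, here is a small hypothetical check of adjust_shard_indexes_for_packing for a parameter packed along its output dimension (4-bit values packed into int32, packed_factor 8); when marlin_tile_size is also set, both results are additionally multiplied by the tile size:

import torch

from vllm.model_executor.parameter import PackedvLLMParameter


def dummy_weight_loader(param, loaded_weight, *args, **kwargs):
    param.data.copy_(loaded_weight)


# Hypothetical qzeros-like parameter: packed along the output dimension.
qzeros = PackedvLLMParameter(data=torch.empty(32, 512, dtype=torch.int32),
                             input_dim=0,
                             output_dim=1,
                             packed_dim=1,
                             packed_factor=8,
                             weight_loader=dummy_weight_loader)

# A logical shard of 4096 output channels at offset 4096 maps to
# 4096 // 8 = 512 packed int32 columns at offset 512.
print(qzeros.adjust_shard_indexes_for_packing(shard_size=4096,
                                              shard_offset=4096))
# -> (512, 512)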