[Misc] Update gptq_marlin to use new vLLMParameters (#7281)

Dipika Sikka 2024-08-13 14:30:11 -04:00 committed by GitHub
parent 181abbc27d
commit fb377d7e74
8 changed files with 234 additions and 98 deletions


@ -314,6 +314,16 @@ steps:
  - export VLLM_WORKER_MULTIPROC_METHOD=spawn
  - pytest -v -s -x lora/test_long_context.py

- label: Weight Loading Multiple GPU Test
  working_dir: "/vllm-workspace/tests"
  num_gpus: 2
  source_file_dependencies:
  - vllm/
  - tests/weight_loading
  commands:
    - bash weight_loading/run_model_weight_loading_test.sh
##### multi gpus test #####
##### A100 test #####


@ -0,0 +1,15 @@
gptq_marlin, robertgshaw2/zephyr-7b-beta-channelwise-gptq, main
gptq_marlin, TheBloke/Llama-2-7B-GPTQ, main
gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, main
gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, gptq-8bit--1g-actorder_True
gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, gptq-8bit-32g-actorder_True
gptq_marlin, TechxGenus/gemma-1.1-2b-it-GPTQ, main
compressed-tensors, nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change, main
compressed-tensors, nm-testing/tinyllama-oneshot-w8-channel-a8-tensor, main
compressed-tensors, nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2, main
compressed-tensors, nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2, main
compressed-tensors, nm-testing/tinyllama-oneshot-w4a16-group128-v2, main
compressed-tensors, nm-testing/tinyllama-oneshot-w8a16-per-channel, main
compressed-tensors, nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test, main
compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main


@ -0,0 +1,32 @@
#!/bin/bash
SUCCESS=0

IFS=$'\n' read -d '' -r -a MODEL_CONFIGS < "weight_loading/models.txt"

for MODEL_CONFIG in "${MODEL_CONFIGS[@]}"
do
    LOCAL_SUCCESS=0
    IFS=', ' read -r -a array <<< "$MODEL_CONFIG"

    echo "=== RUNNING MODEL: $MODEL_CONFIG ==="

    export QUANTIZATION=${array[0]}
    export MODEL_NAME=${array[1]}
    export REVISION=${array[2]}
    pytest -s weight_loading/test_weight_loading.py || LOCAL_SUCCESS=$?

    if [[ $LOCAL_SUCCESS == 0 ]]; then
        echo "=== PASSED MODEL: ${MODEL_CONFIG} ==="
    else
        echo "=== FAILED MODEL: ${MODEL_CONFIG} ==="
    fi

    SUCCESS=$((SUCCESS + LOCAL_SUCCESS))
done

if [ "${SUCCESS}" -eq "0" ]; then
    exit 0
else
    exit 1
fi


@ -0,0 +1,20 @@
import os

MAX_MODEL_LEN = 1024

MODEL_NAME = os.environ.get("MODEL_NAME",
                            "robertgshaw2/zephyr-7b-beta-channelwise-gptq")
REVISION = os.environ.get("REVISION", "main")
QUANTIZATION = os.environ.get("QUANTIZATION", "gptq_marlin")


def test_weight_loading(vllm_runner):
    with vllm_runner(model_name=MODEL_NAME,
                     revision=REVISION,
                     dtype="auto",
                     quantization=QUANTIZATION,
                     max_model_len=MAX_MODEL_LEN,
                     tensor_parallel_size=2) as model:
        output = model.generate_greedy("Hello world!", max_tokens=20)
        print(output)
        assert output
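
Note: a single models.txt entry can also be exercised locally without the wrapper script by driving the test through pytest. The sketch below is illustrative only (the chosen model and the invocation are assumptions, not part of this commit) and needs 2 GPUs because the test builds the engine with tensor_parallel_size=2.

# Illustrative local run of one models.txt entry (not part of this commit).
import os
import pytest

os.environ["QUANTIZATION"] = "gptq_marlin"
os.environ["MODEL_NAME"] = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ"
os.environ["REVISION"] = "main"

# Requires 2 GPUs: the test constructs the model with tensor_parallel_size=2.
pytest.main(["-s", "weight_loading/test_weight_loading.py"])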


@ -20,7 +20,9 @@ from vllm.model_executor.utils import set_weight_attrs
logger = init_logger(__name__)
WEIGHT_LOADER_V2_SUPPORTED = ["CompressedTensorsLinearMethod"]
WEIGHT_LOADER_V2_SUPPORTED = [
    "CompressedTensorsLinearMethod", "GPTQMarlinLinearMethod"
]
def adjust_marlin_shard(param, shard_size, shard_offset):
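
GPTQMarlinLinearMethod joins the list that gates the new vLLMParameter-based weight loading. As a rough sketch of what membership implies (the actual dispatch lives elsewhere in linear.py; this helper is an assumption for illustration, not part of this hunk):

# Sketch (assumption): a layer uses the vLLMParameter-based loader only when
# its quant method's class name appears in WEIGHT_LOADER_V2_SUPPORTED;
# otherwise it keeps the legacy set_weight_attrs-based loader.
def uses_weight_loader_v2(quant_method) -> bool:
    return type(quant_method).__name__ in WEIGHT_LOADER_V2_SUPPORTED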


@ -105,7 +105,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
                dtype=params_dtype,
            )
        }

        if self.group_size == -1:
        if not partition_scales:
            weight_scale = ChannelQuantScaleParameter(output_dim=0,
                                                      **weight_scale_args)
        else:


@ -1,12 +1,11 @@
from typing import Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from torch.nn import Parameter
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               set_weight_attrs)
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
@ -15,6 +14,11 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx, replace_tensor,
    verify_marlin_supported, verify_marlin_supports_shape)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                           GroupQuantScaleParameter,
                                           PackedColumnParameter,
                                           PackedvLLMParameter,
                                           RowvLLMParameter)
from vllm.scalar_type import scalar_types
logger = init_logger(__name__)
@ -159,9 +163,11 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        del output_size
        output_size_per_partition = sum(output_partition_sizes)
        is_row_parallel = input_size != input_size_per_partition
        weight_loader = extra_weight_attrs.get("weight_loader")

        # Normalize group_size
        if self.quant_config.group_size != -1:
@ -190,79 +196,66 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
            scales_and_zp_size = input_size_per_partition // group_size

        # Quantized weights
        qweight = Parameter(
            torch.empty(
        qweight = PackedvLLMParameter(
            data=torch.empty(
                input_size_per_partition // self.quant_config.pack_factor,
                output_size_per_partition,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qweight,
            {
                **extra_weight_attrs,
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 0,
                "pack_factor": self.quant_config.pack_factor,
            },
        )
            input_dim=0,
            output_dim=1,
            packed_dim=0,
            packed_factor=self.quant_config.pack_factor,
            weight_loader=weight_loader)

        # Activation order
        g_idx = Parameter(
            torch.empty(
        g_idx = RowvLLMParameter(data=torch.empty(
            input_size_per_partition,
            dtype=torch.int32,
        ),
            requires_grad=False,
        )
        # Ignore warning from fused linear layers such as QKVParallelLinear.
        set_weight_attrs(
            g_idx,
            {
                **extra_weight_attrs, "input_dim": 0,
                "ignore_warning": True
            },
        )
                                 input_dim=0,
                                 weight_loader=weight_loader)

        # Scales
        scales = Parameter(
            torch.empty(
                scales_and_zp_size,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            scales,
            {
                **extra_weight_attrs,
                "input_dim": scales_and_zp_input_dim,
                "output_dim": 1,
            },
        )

        # Quantized zero-points
        qzeros = Parameter(
        qzeros_args = {
            "data":
            torch.empty(
                scales_and_zp_size,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qzeros,
            {
                **extra_weight_attrs,
                "input_dim": scales_and_zp_input_dim,
                "output_dim": 1,
                "packed_dim": 1,
                "pack_factor": self.quant_config.pack_factor,
            },
        )
            "weight_loader":
            weight_loader
        }
        weight_scale_args = {
            "data":
            torch.empty(
                scales_and_zp_size,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            "weight_loader":
            weight_loader
        }

        if scales_and_zp_input_dim is None:
            scales = ChannelQuantScaleParameter(output_dim=1,
                                                **weight_scale_args)
            qzeros = PackedColumnParameter(
                output_dim=1,
                packed_dim=1,
                packed_factor=self.quant_config.pack_factor,
                **qzeros_args)
        else:
            scales = GroupQuantScaleParameter(output_dim=1,
                                              input_dim=0,
                                              **weight_scale_args)
            qzeros = PackedvLLMParameter(
                input_dim=0,
                output_dim=1,
                packed_dim=1,
                packed_factor=self.quant_config.pack_factor,
                **qzeros_args)

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("g_idx", g_idx)
@ -280,6 +273,10 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        device = layer.qweight.device

        # required by torch.compile
        layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
        layer.scales = Parameter(layer.scales.data, requires_grad=False)

        # Allocate marlin workspace
        layer.workspace = marlin_make_workspace(
            layer.output_size_per_partition, device)


@ -9,7 +9,7 @@ from vllm.logger import init_logger
__all__ = [
    "BasevLLMParameter", "PackedvLLMParameter", "PerTensorScaleParameter",
    "ModelWeightParameter", "ChannelQuantScaleParameter",
    "GroupQuantScaleParameter"
    "GroupQuantScaleParameter", "PackedColumnParameter", "RowvLLMParameter"
]
logger = init_logger(__name__)
@ -92,7 +92,8 @@ class _ColumnvLLMParameter(BasevLLMParameter):
        shard_size = kwargs.get("shard_size")
        if isinstance(
                self,
                PackedvLLMParameter) and self.packed_dim == self.output_dim:
                (PackedColumnParameter,
                 PackedvLLMParameter)) and self.packed_dim == self.output_dim:
            shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
                shard_offset=shard_offset, shard_size=shard_size)
@ -115,7 +116,8 @@ class _ColumnvLLMParameter(BasevLLMParameter):
        if isinstance(
                self,
                PackedvLLMParameter) and self.output_dim == self.packed_dim:
                (PackedColumnParameter,
                 PackedvLLMParameter)) and self.output_dim == self.packed_dim:
            shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
                shard_offset=shard_offset, shard_size=shard_size)
@ -131,12 +133,12 @@ class _ColumnvLLMParameter(BasevLLMParameter):
        param_data.copy_(loaded_weight)


class ModelWeightParameter(_ColumnvLLMParameter):
class RowvLLMParameter(BasevLLMParameter):
    """
    Parameter class for linear layer weights. Extends the
    _ColumnvLLMParameter by adding loading functionality
    for linear layers with row parallel functionality.
    Requires an input dimension to be defined.
    Parameter class defining weight_loading functionality
    (load_row_parallel_weight) for parameters being loaded
    into linear layers with row parallel functionality.
    Requires an input_dim to be defined.
    """

    def __init__(self, input_dim: int, **kwargs):
@ -160,10 +162,18 @@ class ModelWeightParameter(_ColumnvLLMParameter):
        self.data.copy_(loaded_weight)


class GroupQuantScaleParameter(ModelWeightParameter):
class ModelWeightParameter(_ColumnvLLMParameter, RowvLLMParameter):
    """
    Parameter class for linear layer weights. Uses both column and
    row parallelism.
    """
    pass


class GroupQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):
    """
    Parameter class for weight scales loaded for weights with
    grouped quantization. Equivalent to ModelWeightParameter.
    grouped quantization. Uses both column and row parallelism.
    """
    pass
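
ModelWeightParameter is now simply the composition of the column-parallel and row-parallel mixins. A minimal construction sketch follows, assuming the kwargs compose as the hierarchy suggests (data and weight_loader from BasevLLMParameter, output_dim from _ColumnvLLMParameter, input_dim from RowvLLMParameter); the no-op loader is a placeholder for illustration, not vLLM API.

# Minimal sketch with assumed kwargs; not taken from this commit.
import torch
from vllm.model_executor.parameter import ModelWeightParameter

def _noop_loader(param, loaded_weight):    # placeholder weight_loader
    param.data.copy_(loaded_weight)

weight = ModelWeightParameter(
    data=torch.empty(1024, 4096, dtype=torch.float16),
    input_dim=1,     # shard along the input dim for row-parallel layers
    output_dim=0,    # shard along the output dim for column-parallel layers
    weight_loader=_noop_loader,
)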
@ -232,6 +242,43 @@ class PerTensorScaleParameter(BasevLLMParameter):
        param_data.copy_(loaded_weight)


class PackedColumnParameter(_ColumnvLLMParameter):
    """
    Parameter for model parameters which are packed on disk
    and support column parallelism only. See PackedvLLMParameter
    for more details on the packed properties.
    """

    def __init__(self,
                 packed_factor: int,
                 packed_dim: int,
                 marlin_tile_size: Optional[int] = None,
                 **kwargs):
        self._packed_factor = packed_factor
        self._packed_dim = packed_dim
        self._marlin_tile_size = marlin_tile_size
        super().__init__(**kwargs)

    @property
    def packed_dim(self):
        return self._packed_dim

    @property
    def packed_factor(self):
        return self._packed_factor

    @property
    def marlin_tile_size(self):
        return self._marlin_tile_size

    def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
        return _adjust_shard_indexes_for_packing(
            shard_size=shard_size,
            shard_offset=shard_offset,
            packed_factor=self.packed_factor,
            marlin_tile_size=self.marlin_tile_size)


class PackedvLLMParameter(ModelWeightParameter):
    """
    Parameter for model weights which are packed on disk.
@ -250,7 +297,7 @@ class PackedvLLMParameter(ModelWeightParameter):
                 **kwargs):
        self._packed_factor = packed_factor
        self._packed_dim = packed_dim
        self._marlin_tile = marlin_tile_size
        self._marlin_tile_size = marlin_tile_size
        super().__init__(**kwargs)

    @property
@ -262,16 +309,29 @@ class PackedvLLMParameter(ModelWeightParameter):
        return self._packed_factor

    @property
    def marlin_tile(self):
        return self._marlin_tile

    def _adjust_shard_indexes_for_marlin(self, shard_size, shard_offset):
        return shard_size * self.marlin_tile, shard_offset * self.marlin_tile
    def marlin_tile_size(self):
        return self._marlin_tile_size

    def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
        shard_size = shard_size // self.packed_factor
        shard_offset = shard_offset // self.packed_factor
        if self.marlin_tile is not None:
            return self._adjust_shard_indexes_for_marlin(
                shard_size, shard_offset)
        return _adjust_shard_indexes_for_packing(
            shard_size=shard_size,
            shard_offset=shard_offset,
            packed_factor=self.packed_factor,
            marlin_tile_size=self.marlin_tile_size)


def _adjust_shard_indexes_for_marlin(shard_size, shard_offset,
                                     marlin_tile_size):
    return shard_size * marlin_tile_size, shard_offset * marlin_tile_size


def _adjust_shard_indexes_for_packing(shard_size, shard_offset, packed_factor,
                                      marlin_tile_size):
    shard_size = shard_size // packed_factor
    shard_offset = shard_offset // packed_factor
    if marlin_tile_size is not None:
        return _adjust_shard_indexes_for_marlin(
            shard_size=shard_size,
            shard_offset=shard_offset,
            marlin_tile_size=marlin_tile_size)
    return shard_size, shard_offset
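
A quick sanity check of the new module-level helpers, with illustrative numbers (assuming the two functions above are in scope):

# Illustrative: a 4-bit shard (packed_factor=8) with and without marlin tiling.
size, offset = _adjust_shard_indexes_for_packing(
    shard_size=2048, shard_offset=4096, packed_factor=8, marlin_tile_size=16)
assert (size, offset) == (4096, 8192)   # (2048 // 8) * 16, (4096 // 8) * 16

size, offset = _adjust_shard_indexes_for_packing(
    shard_size=2048, shard_offset=4096, packed_factor=8, marlin_tile_size=None)
assert (size, offset) == (256, 512)     # only divided by the packed factor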