[Misc] Update gptq_marlin to use new vLLMParameters (#7281)

Dipika Sikka 2024-08-13 14:30:11 -04:00 committed by GitHub
parent 181abbc27d
commit fb377d7e74
8 changed files with 234 additions and 98 deletions

View File

@@ -314,6 +314,16 @@ steps:
     - export VLLM_WORKER_MULTIPROC_METHOD=spawn
     - pytest -v -s -x lora/test_long_context.py
 
+- label: Weight Loading Multiple GPU Test
+  working_dir: "/vllm-workspace/tests"
+  num_gpus: 2
+  source_file_dependencies:
+  - vllm/
+  - tests/weight_loading
+  commands:
+    - bash weight_loading/run_model_weight_loading_test.sh
+
 ##### multi gpus test #####
 ##### A100 test #####

View File

@@ -0,0 +1,15 @@
gptq_marlin, robertgshaw2/zephyr-7b-beta-channelwise-gptq, main
gptq_marlin, TheBloke/Llama-2-7B-GPTQ, main
gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, main
gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, gptq-8bit--1g-actorder_True
gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, gptq-8bit-32g-actorder_True
gptq_marlin, TechxGenus/gemma-1.1-2b-it-GPTQ, main
compressed-tensors, nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change, main
compressed-tensors, nm-testing/tinyllama-oneshot-w8-channel-a8-tensor, main
compressed-tensors, nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2, main
compressed-tensors, nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2, main
compressed-tensors, nm-testing/tinyllama-oneshot-w4a16-group128-v2, main
compressed-tensors, nm-testing/tinyllama-oneshot-w8a16-per-channel, main
compressed-tensors, nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test, main
compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
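Each line of this file is one test configuration: the quantization method, the Hugging Face model ID, and the revision (branch) to load, in that order. A minimal sketch of how the driver script below consumes one line (the Python split/strip here loosely mirrors the script's `IFS=', ' read`):

# Illustrative parsing of one models.txt line; values come straight from the file above.
line = "gptq_marlin, TheBloke/Llama-2-7B-GPTQ, main"
quantization, model_name, revision = [field.strip() for field in line.split(",")]
print(quantization, model_name, revision)  # gptq_marlin TheBloke/Llama-2-7B-GPTQ main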

View File

@@ -0,0 +1,32 @@
#!/bin/bash
SUCCESS=0

IFS=$'\n' read -d '' -r -a MODEL_CONFIGS < "weight_loading/models.txt"

for MODEL_CONFIG in "${MODEL_CONFIGS[@]}"
do
    LOCAL_SUCCESS=0
    IFS=', ' read -r -a array <<< "$MODEL_CONFIG"

    echo "=== RUNNING MODEL: $MODEL_CONFIG ==="

    export QUANTIZATION=${array[0]}
    export MODEL_NAME=${array[1]}
    export REVISION=${array[2]}
    pytest -s weight_loading/test_weight_loading.py || LOCAL_SUCCESS=$?

    if [[ $LOCAL_SUCCESS == 0 ]]; then
        echo "=== PASSED MODEL: ${MODEL_CONFIG} ==="
    else
        echo "=== FAILED MODEL: ${MODEL_CONFIG} ==="
    fi

    SUCCESS=$((SUCCESS + LOCAL_SUCCESS))
done

if [ "${SUCCESS}" -eq "0" ]; then
    exit 0
else
    exit 1
fi

View File

@@ -0,0 +1,20 @@
import os

MAX_MODEL_LEN = 1024

MODEL_NAME = os.environ.get("MODEL_NAME",
                            "robertgshaw2/zephyr-7b-beta-channelwise-gptq")
REVISION = os.environ.get("REVISION", "main")
QUANTIZATION = os.environ.get("QUANTIZATION", "gptq_marlin")


def test_weight_loading(vllm_runner):
    with vllm_runner(model_name=MODEL_NAME,
                     revision=REVISION,
                     dtype="auto",
                     quantization=QUANTIZATION,
                     max_model_len=MAX_MODEL_LEN,
                     tensor_parallel_size=2) as model:
        output = model.generate_greedy("Hello world!", max_tokens=20)
        print(output)
        assert output
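For a single configuration, the test can also be driven directly instead of through the bash loop above. A minimal sketch, with the model, revision, and quantization values taken from weight_loading/models.txt and the working directory assumed to be the tests directory (the buildkite step's working_dir):

import os
import subprocess

# Mirror run_model_weight_loading_test.sh for one configuration: export the
# same three environment variables, then invoke the pytest file directly.
os.environ["QUANTIZATION"] = "gptq_marlin"
os.environ["MODEL_NAME"] = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ"
os.environ["REVISION"] = "main"

subprocess.run(["pytest", "-s", "weight_loading/test_weight_loading.py"],
               check=True)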

View File

@@ -20,7 +20,9 @@ from vllm.model_executor.utils import set_weight_attrs
 
 logger = init_logger(__name__)
 
-WEIGHT_LOADER_V2_SUPPORTED = ["CompressedTensorsLinearMethod"]
+WEIGHT_LOADER_V2_SUPPORTED = [
+    "CompressedTensorsLinearMethod", "GPTQMarlinLinearMethod"
+]
 
 
 def adjust_marlin_shard(param, shard_size, shard_offset):

View File

@@ -105,7 +105,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
                 dtype=params_dtype,
             )
         }
 
-        if self.group_size == -1:
+        if not partition_scales:
             weight_scale = ChannelQuantScaleParameter(output_dim=0,
                                                       **weight_scale_args)
         else:

View File

@@ -1,12 +1,11 @@
 from typing import Any, Dict, List, Optional
 
 import torch
-from torch.nn.parameter import Parameter
+from torch.nn import Parameter
 
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
-from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
-                                               set_weight_attrs)
+from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
@@ -15,6 +14,11 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx, replace_tensor,
     verify_marlin_supported, verify_marlin_supports_shape)
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
+from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
+                                           GroupQuantScaleParameter,
+                                           PackedColumnParameter,
+                                           PackedvLLMParameter,
+                                           RowvLLMParameter)
 from vllm.scalar_type import scalar_types
 
 logger = init_logger(__name__)
@@ -159,9 +163,11 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ) -> None:
         del output_size
         output_size_per_partition = sum(output_partition_sizes)
         is_row_parallel = input_size != input_size_per_partition
+        weight_loader = extra_weight_attrs.get("weight_loader")
 
         # Normalize group_size
         if self.quant_config.group_size != -1:
@@ -190,79 +196,66 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
             scales_and_zp_size = input_size_per_partition // group_size
 
         # Quantized weights
-        qweight = Parameter(
-            torch.empty(
+        qweight = PackedvLLMParameter(
+            data=torch.empty(
                 input_size_per_partition // self.quant_config.pack_factor,
                 output_size_per_partition,
                 dtype=torch.int32,
             ),
-            requires_grad=False,
-        )
-        set_weight_attrs(
-            qweight,
-            {
-                **extra_weight_attrs,
-                "input_dim": 0,
-                "output_dim": 1,
-                "packed_dim": 0,
-                "pack_factor": self.quant_config.pack_factor,
-            },
-        )
+            input_dim=0,
+            output_dim=1,
+            packed_dim=0,
+            packed_factor=self.quant_config.pack_factor,
+            weight_loader=weight_loader)
 
         # Activation order
-        g_idx = Parameter(
-            torch.empty(
-                input_size_per_partition,
-                dtype=torch.int32,
-            ),
-            requires_grad=False,
-        )
-        # Ignore warning from fused linear layers such as QKVParallelLinear.
-        set_weight_attrs(
-            g_idx,
-            {
-                **extra_weight_attrs, "input_dim": 0,
-                "ignore_warning": True
-            },
-        )
+        g_idx = RowvLLMParameter(data=torch.empty(
+            input_size_per_partition,
+            dtype=torch.int32,
+        ),
+                                 input_dim=0,
+                                 weight_loader=weight_loader)
 
-        # Scales
-        scales = Parameter(
-            torch.empty(
-                scales_and_zp_size,
-                output_size_per_partition,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        set_weight_attrs(
-            scales,
-            {
-                **extra_weight_attrs,
-                "input_dim": scales_and_zp_input_dim,
-                "output_dim": 1,
-            },
-        )
-
-        # Quantized zero-points
-        qzeros = Parameter(
-            torch.empty(
+        qzeros_args = {
+            "data":
+            torch.empty(
                 scales_and_zp_size,
                 output_size_per_partition // self.quant_config.pack_factor,
                 dtype=torch.int32,
             ),
-            requires_grad=False,
-        )
-        set_weight_attrs(
-            qzeros,
-            {
-                **extra_weight_attrs,
-                "input_dim": scales_and_zp_input_dim,
-                "output_dim": 1,
-                "packed_dim": 1,
-                "pack_factor": self.quant_config.pack_factor,
-            },
-        )
+            "weight_loader":
+            weight_loader
+        }
+        weight_scale_args = {
+            "data":
+            torch.empty(
+                scales_and_zp_size,
+                output_size_per_partition,
+                dtype=params_dtype,
+            ),
+            "weight_loader":
+            weight_loader
+        }
+
+        if scales_and_zp_input_dim is None:
+            scales = ChannelQuantScaleParameter(output_dim=1,
+                                                **weight_scale_args)
+            qzeros = PackedColumnParameter(
+                output_dim=1,
+                packed_dim=1,
+                packed_factor=self.quant_config.pack_factor,
+                **qzeros_args)
+        else:
+            scales = GroupQuantScaleParameter(output_dim=1,
+                                              input_dim=0,
+                                              **weight_scale_args)
+            qzeros = PackedvLLMParameter(
+                input_dim=0,
+                output_dim=1,
+                packed_dim=1,
+                packed_factor=self.quant_config.pack_factor,
+                **qzeros_args)
 
         layer.register_parameter("qweight", qweight)
         layer.register_parameter("g_idx", g_idx)
@@ -280,6 +273,10 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         device = layer.qweight.device
 
+        # required by torch.compile
+        layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
+        layer.scales = Parameter(layer.scales.data, requires_grad=False)
+
         # Allocate marlin workspace
         layer.workspace = marlin_make_workspace(
             layer.output_size_per_partition, device)
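A minimal sketch (not part of the PR) of the pattern this file now follows: shard metadata that was previously attached with set_weight_attrs travels on the parameter object itself, together with the weight_loader pulled from extra_weight_attrs. The shapes and the noop_weight_loader stub below are illustrative assumptions, and the snippet assumes a vLLM build that already contains this change.

import torch

from vllm.model_executor.parameter import PackedvLLMParameter


def noop_weight_loader(param, loaded_weight):
    # Stand-in for the loader a linear layer normally supplies through
    # extra_weight_attrs; it just copies the checkpoint tensor in.
    param.data.copy_(loaded_weight)


pack_factor = 8  # 32-bit storage / 4-bit GPTQ values (assumed for illustration)
qweight = PackedvLLMParameter(
    data=torch.empty(1024 // pack_factor, 4096, dtype=torch.int32),
    input_dim=0,
    output_dim=1,
    packed_dim=0,
    packed_factor=pack_factor,
    weight_loader=noop_weight_loader)

# Metadata is queried from the parameter object itself rather than from
# attributes set on a plain torch.nn.Parameter.
print(qweight.packed_dim, qweight.packed_factor)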

View File

@@ -9,7 +9,7 @@ from vllm.logger import init_logger
 __all__ = [
     "BasevLLMParameter", "PackedvLLMParameter", "PerTensorScaleParameter",
     "ModelWeightParameter", "ChannelQuantScaleParameter",
-    "GroupQuantScaleParameter"
+    "GroupQuantScaleParameter", "PackedColumnParameter", "RowvLLMParameter"
 ]
 
 logger = init_logger(__name__)
@@ -92,7 +92,8 @@ class _ColumnvLLMParameter(BasevLLMParameter):
         shard_size = kwargs.get("shard_size")
         if isinstance(
                 self,
-                PackedvLLMParameter) and self.packed_dim == self.output_dim:
+            (PackedColumnParameter,
+             PackedvLLMParameter)) and self.packed_dim == self.output_dim:
             shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
                 shard_offset=shard_offset, shard_size=shard_size)
@@ -115,7 +116,8 @@ class _ColumnvLLMParameter(BasevLLMParameter):
         if isinstance(
                 self,
-                PackedvLLMParameter) and self.output_dim == self.packed_dim:
+            (PackedColumnParameter,
+             PackedvLLMParameter)) and self.output_dim == self.packed_dim:
             shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
                 shard_offset=shard_offset, shard_size=shard_size)
@@ -131,12 +133,12 @@ class _ColumnvLLMParameter(BasevLLMParameter):
         param_data.copy_(loaded_weight)
 
 
-class ModelWeightParameter(_ColumnvLLMParameter):
+class RowvLLMParameter(BasevLLMParameter):
     """
-    Parameter class for linear layer weights. Extends the
-    _ColumnvLLMParameter by adding loading functionality
-    for linear layers with row parallel functionality.
-    Requires an input dimension to be defined.
+    Parameter class defining weight_loading functionality
+    (load_row_parallel_weight) for parameters being loaded
+    into linear layers with row parallel functionality.
+    Requires an input_dim to be defined.
     """
 
     def __init__(self, input_dim: int, **kwargs):
@@ -160,10 +162,18 @@ class ModelWeightParameter(_ColumnvLLMParameter):
         self.data.copy_(loaded_weight)
 
 
-class GroupQuantScaleParameter(ModelWeightParameter):
+class ModelWeightParameter(_ColumnvLLMParameter, RowvLLMParameter):
+    """
+    Parameter class for linear layer weights. Uses both column and
+    row parallelism.
+    """
+    pass
+
+
+class GroupQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):
     """
     Parameter class for weight scales loaded for weights with
-    grouped quantization. Equivalent to ModelWeightParameter.
+    grouped quantization. Uses both column and row parallelism.
     """
     pass
@@ -232,6 +242,43 @@ class PerTensorScaleParameter(BasevLLMParameter):
         param_data.copy_(loaded_weight)
 
 
+class PackedColumnParameter(_ColumnvLLMParameter):
+    """
+    Parameter for model parameters which are packed on disk
+    and support column parallelism only. See PackedvLLMParameter
+    for more details on the packed properties.
+    """
+
+    def __init__(self,
+                 packed_factor: int,
+                 packed_dim: int,
+                 marlin_tile_size: Optional[int] = None,
+                 **kwargs):
+        self._packed_factor = packed_factor
+        self._packed_dim = packed_dim
+        self._marlin_tile_size = marlin_tile_size
+        super().__init__(**kwargs)
+
+    @property
+    def packed_dim(self):
+        return self._packed_dim
+
+    @property
+    def packed_factor(self):
+        return self._packed_factor
+
+    @property
+    def marlin_tile_size(self):
+        return self._marlin_tile_size
+
+    def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
+        return _adjust_shard_indexes_for_packing(
+            shard_size=shard_size,
+            shard_offset=shard_offset,
+            packed_factor=self.packed_factor,
+            marlin_tile_size=self.marlin_tile_size)
+
+
 class PackedvLLMParameter(ModelWeightParameter):
     """
     Parameter for model weights which are packed on disk.
@@ -250,7 +297,7 @@ class PackedvLLMParameter(ModelWeightParameter):
                  **kwargs):
         self._packed_factor = packed_factor
         self._packed_dim = packed_dim
-        self._marlin_tile = marlin_tile_size
+        self._marlin_tile_size = marlin_tile_size
         super().__init__(**kwargs)
 
     @property
@@ -262,16 +309,29 @@ class PackedvLLMParameter(ModelWeightParameter):
         return self._packed_factor
 
     @property
-    def marlin_tile(self):
-        return self._marlin_tile
-
-    def _adjust_shard_indexes_for_marlin(self, shard_size, shard_offset):
-        return shard_size * self.marlin_tile, shard_offset * self.marlin_tile
+    def marlin_tile_size(self):
+        return self._marlin_tile_size
 
     def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
-        shard_size = shard_size // self.packed_factor
-        shard_offset = shard_offset // self.packed_factor
-        if self.marlin_tile is not None:
-            return self._adjust_shard_indexes_for_marlin(
-                shard_size, shard_offset)
-        return shard_size, shard_offset
+        return _adjust_shard_indexes_for_packing(
+            shard_size=shard_size,
+            shard_offset=shard_offset,
+            packed_factor=self.packed_factor,
+            marlin_tile_size=self.marlin_tile_size)
+
+
+def _adjust_shard_indexes_for_marlin(shard_size, shard_offset,
+                                     marlin_tile_size):
+    return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
+
+
+def _adjust_shard_indexes_for_packing(shard_size, shard_offset, packed_factor,
+                                      marlin_tile_size):
+    shard_size = shard_size // packed_factor
+    shard_offset = shard_offset // packed_factor
+    if marlin_tile_size is not None:
+        return _adjust_shard_indexes_for_marlin(
+            shard_size=shard_size,
+            shard_offset=shard_offset,
+            marlin_tile_size=marlin_tile_size)
+    return shard_size, shard_offset
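To make the packing adjustment above concrete, a small worked example with assumed values: 4-bit GPTQ values packed into int32 give a packed_factor of 8, and the Marlin tile size used here (16) is only an illustrative assumption.

# Worked example of the arithmetic in _adjust_shard_indexes_for_packing,
# with illustrative numbers rather than values from any real model.
shard_size, shard_offset = 4096, 8192
packed_factor = 8
marlin_tile_size = 16  # assumed tile size for illustration

shard_size //= packed_factor    # 4096 // 8  -> 512
shard_offset //= packed_factor  # 8192 // 8  -> 1024
if marlin_tile_size is not None:
    shard_size *= marlin_tile_size    # 512 * 16  -> 8192
    shard_offset *= marlin_tile_size  # 1024 * 16 -> 16384

print(shard_size, shard_offset)  # 8192 16384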