[Misc] update fp8 to use vLLMParameter (#7437)

This commit is contained in:
Dipika Sikka 2024-08-22 08:36:18 -04:00 committed by GitHub
parent 55d63b1211
commit 955b5191c9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 51 additions and 17 deletions

View File

@ -15,3 +15,4 @@ compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
awq, casperhansen/mixtral-instruct-awq, main
awq_marlin, casperhansen/mixtral-instruct-awq, main
fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main

View File

@ -22,7 +22,7 @@ logger = init_logger(__name__)
WEIGHT_LOADER_V2_SUPPORTED = [
"CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
"AWQLinearMethod", "GPTQMarlinLinearMethod"
"AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod"
]
@ -349,6 +349,11 @@ class ColumnParallelLinear(LinearBase):
param_data.copy_(loaded_weight)
def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
# Special case for loading scales off disk, which often do not
# have a shape (such as in the case of AutoFP8).
if len(loaded_weight.shape) == 0:
assert loaded_weight.numel() == 1
loaded_weight = loaded_weight.reshape(1)
param.load_column_parallel_weight(loaded_weight=loaded_weight)
def forward(self, input_):
@ -1021,6 +1026,13 @@ class RowParallelLinear(LinearBase):
def weight_loader_v2(self, param: BasevLLMParameter,
loaded_weight: torch.Tensor):
# Special case for loading scales off disk, which often do not
# have a shape (such as in the case of AutoFP8).
if len(loaded_weight.shape) == 0:
assert loaded_weight.numel() == 1
loaded_weight = loaded_weight.reshape(1)
param.load_row_parallel_weight(loaded_weight=loaded_weight)
def forward(self, input_):

View File

@ -19,9 +19,10 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d, apply_fp8_linear, convert_to_channelwise,
create_per_tensor_scale_param, cutlass_fp8_supported,
normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
requantize_with_max_scale)
from vllm.model_executor.parameter import (ModelWeightParameter,
PerTensorScaleParameter)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import is_hip, print_warning_once
@ -137,6 +138,7 @@ class Fp8LinearMethod(LinearMethodBase):
):
del input_size, output_size
output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader")
layer.logical_widths = output_partition_sizes
@ -148,34 +150,41 @@ class Fp8LinearMethod(LinearMethodBase):
weight_dtype = (torch.float8_e4m3fn
if self.quant_config.is_checkpoint_fp8_serialized else
params_dtype)
weight = Parameter(torch.empty(output_size_per_partition,
input_size_per_partition,
dtype=weight_dtype),
requires_grad=False)
weight = ModelWeightParameter(data=torch.empty(
output_size_per_partition,
input_size_per_partition,
dtype=weight_dtype),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight", weight)
set_weight_attrs(weight, {
**extra_weight_attrs,
"input_dim": 1,
"output_dim": 0,
})
# If checkpoint is serialized fp8, load them.
# Otherwise, wait until process_weights_after_loading.
if self.quant_config.is_checkpoint_fp8_serialized:
# WEIGHT SCALE
scale = create_per_tensor_scale_param(output_partition_sizes,
**extra_weight_attrs)
scale = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("weight_scale", scale)
# INPUT ACTIVATION SCALE
if self.quant_config.activation_scheme == "static":
scale = create_per_tensor_scale_param(output_partition_sizes,
**extra_weight_attrs)
scale = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("input_scale", scale)
else:
layer.register_parameter("input_scale", None)
def process_weights_after_loading(self, layer: Module) -> None:
layer.weight = torch.nn.Parameter(layer.weight.data,
requires_grad=False)
# If checkpoint not serialized fp8, quantize the weights.
if not self.quant_config.is_checkpoint_fp8_serialized:
qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
@ -197,6 +206,11 @@ class Fp8LinearMethod(LinearMethodBase):
# If checkpoint is fp8, handle that there are N scales for N
# shards in a fused module
else:
layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
requires_grad=False)
if self.quant_config.activation_scheme == "static":
layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
requires_grad=False)
# If using marlin (w8a16), kernel uses channelwise weights,
# so extend the weight scales to be channelwise.
if self.use_marlin:

View File

@ -208,10 +208,17 @@ class PerTensorScaleParameter(BasevLLMParameter):
if isinstance(shard_id, int):
return shard_id
# if not int, assume shard_id for qkv
# map to int and return
assert isinstance(shard_id, str)
assert shard_id in self.qkv_idxs
return self.qkv_idxs[shard_id]
# For row parallel layers, no sharding needed
# load weight into parameter as is
def load_row_parallel_weight(self, *args, **kwargs):
super().load_row_parallel_weight(*args, **kwargs)
def load_merged_column_weight(self, *args, **kwargs):
self._load_into_shard_id(*args, **kwargs)
@ -219,7 +226,7 @@ class PerTensorScaleParameter(BasevLLMParameter):
self._load_into_shard_id(*args, **kwargs)
def load_column_parallel_weight(self, *args, **kwargs):
self._load_into_shard_id(*args, **kwargs)
super().load_row_parallel_weight(*args, **kwargs)
def _load_into_shard_id(self, loaded_weight: torch.Tensor,
shard_id: Union[str, int], **kwargs):