[Misc] update fp8 to use vLLMParameter (#7437)
This commit is contained in:
parent 55d63b1211
commit 955b5191c9
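
This change moves the fp8 linear method onto the new vLLMParameter-based weight loading path. Fp8LinearMethod is added to WEIGHT_LOADER_V2_SUPPORTED, and Fp8LinearMethod.create_weights now builds ModelWeightParameter and PerTensorScaleParameter objects (passing the weight_loader explicitly) instead of plain Parameters configured via set_weight_attrs. The weight_loader_v2 methods of ColumnParallelLinear and RowParallelLinear gain a special case for 0-dim scale tensors such as those stored by AutoFP8 checkpoints, and an fp8 model is added to the weight-loading test list.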
@@ -15,3 +15,4 @@ compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
 compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
 awq, casperhansen/mixtral-instruct-awq, main
 awq_marlin, casperhansen/mixtral-instruct-awq, main
+fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main
@@ -22,7 +22,7 @@ logger = init_logger(__name__)

 WEIGHT_LOADER_V2_SUPPORTED = [
     "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
-    "AWQLinearMethod", "GPTQMarlinLinearMethod"
+    "AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod"
 ]

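
Note: the list above is how linear quantization methods opt into the vLLMParameter loading path. A minimal sketch of the assumed dispatch (not the exact vLLM wiring), showing how a method whose class name appears in WEIGHT_LOADER_V2_SUPPORTED would be routed to weight_loader_v2:

WEIGHT_LOADER_V2_SUPPORTED = [
    "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
    "AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod"
]

def pick_weight_loader(layer, quant_method):
    # Methods migrated to vLLMParameter use the v2 loader; everything else
    # keeps the legacy per-layer weight_loader.
    if type(quant_method).__name__ in WEIGHT_LOADER_V2_SUPPORTED:
        return layer.weight_loader_v2
    return layer.weight_loader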
@@ -349,6 +349,11 @@ class ColumnParallelLinear(LinearBase):
         param_data.copy_(loaded_weight)

     def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
+        # Special case for loading scales off disk, which often do not
+        # have a shape (such as in the case of AutoFP8).
+        if len(loaded_weight.shape) == 0:
+            assert loaded_weight.numel() == 1
+            loaded_weight = loaded_weight.reshape(1)
         param.load_column_parallel_weight(loaded_weight=loaded_weight)

     def forward(self, input_):
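
Note: the scalar-scale branch added above can be tried in isolation. A minimal sketch, assuming an AutoFP8-style checkpoint that stores a per-tensor scale as a 0-dim tensor (the 0.025 value is illustrative):

import torch

loaded_weight = torch.tensor(0.025)           # 0-dim scale as read off disk
if len(loaded_weight.shape) == 0:             # same check as in weight_loader_v2
    assert loaded_weight.numel() == 1
    loaded_weight = loaded_weight.reshape(1)  # now shape (1,)
print(loaded_weight.shape)                    # torch.Size([1])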
@@ -1021,6 +1026,13 @@ class RowParallelLinear(LinearBase):

     def weight_loader_v2(self, param: BasevLLMParameter,
                          loaded_weight: torch.Tensor):
+
+        # Special case for loading scales off disk, which often do not
+        # have a shape (such as in the case of AutoFP8).
+        if len(loaded_weight.shape) == 0:
+            assert loaded_weight.numel() == 1
+            loaded_weight = loaded_weight.reshape(1)
+
         param.load_row_parallel_weight(loaded_weight=loaded_weight)

     def forward(self, input_):

@@ -19,9 +19,10 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
     is_layer_skipped)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d, apply_fp8_linear, convert_to_channelwise,
-    create_per_tensor_scale_param, cutlass_fp8_supported,
-    normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
+    cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
     requantize_with_max_scale)
+from vllm.model_executor.parameter import (ModelWeightParameter,
+                                           PerTensorScaleParameter)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.utils import is_hip, print_warning_once
@@ -137,6 +138,7 @@ class Fp8LinearMethod(LinearMethodBase):
     ):
         del input_size, output_size
         output_size_per_partition = sum(output_partition_sizes)
+        weight_loader = extra_weight_attrs.get("weight_loader")

         layer.logical_widths = output_partition_sizes

@@ -148,34 +150,41 @@ class Fp8LinearMethod(LinearMethodBase):
         weight_dtype = (torch.float8_e4m3fn
                         if self.quant_config.is_checkpoint_fp8_serialized else
                         params_dtype)
-        weight = Parameter(torch.empty(output_size_per_partition,
-                                       input_size_per_partition,
-                                       dtype=weight_dtype),
-                           requires_grad=False)
+
+        weight = ModelWeightParameter(data=torch.empty(
+            output_size_per_partition,
+            input_size_per_partition,
+            dtype=weight_dtype),
+                                      input_dim=1,
+                                      output_dim=0,
+                                      weight_loader=weight_loader)
         layer.register_parameter("weight", weight)
-        set_weight_attrs(weight, {
-            **extra_weight_attrs,
-            "input_dim": 1,
-            "output_dim": 0,
-        })

         # If checkpoint is serialized fp8, load them.
         # Otherwise, wait until process_weights_after_loading.
         if self.quant_config.is_checkpoint_fp8_serialized:
             # WEIGHT SCALE
-            scale = create_per_tensor_scale_param(output_partition_sizes,
-                                                  **extra_weight_attrs)
+            scale = PerTensorScaleParameter(data=torch.empty(
+                len(output_partition_sizes), dtype=torch.float32),
+                                            weight_loader=weight_loader)
+
+            scale[:] = torch.finfo(torch.float32).min
             layer.register_parameter("weight_scale", scale)

             # INPUT ACTIVATION SCALE
             if self.quant_config.activation_scheme == "static":
-                scale = create_per_tensor_scale_param(output_partition_sizes,
-                                                      **extra_weight_attrs)
+                scale = PerTensorScaleParameter(data=torch.empty(
+                    len(output_partition_sizes), dtype=torch.float32),
+                                                weight_loader=weight_loader)
+
+                scale[:] = torch.finfo(torch.float32).min
                 layer.register_parameter("input_scale", scale)
             else:
                 layer.register_parameter("input_scale", None)

     def process_weights_after_loading(self, layer: Module) -> None:
+        layer.weight = torch.nn.Parameter(layer.weight.data,
+                                          requires_grad=False)
         # If checkpoint not serialized fp8, quantize the weights.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
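
Note: a minimal sketch of the parameter-creation pattern introduced above, under assumed shapes (a fused projection with two 4096-wide output partitions) and a no-op stand-in for the real weight_loader:

import torch

from vllm.model_executor.parameter import (ModelWeightParameter,
                                           PerTensorScaleParameter)

output_partition_sizes = [4096, 4096]  # assumed fused-layer shard widths
input_size_per_partition = 4096        # assumed input width per partition

def noop_loader(param, loaded_weight, *args, **kwargs):
    pass  # stand-in; the real loader copies checkpoint tensors into param

# fp8 weight, tagged with which dims are sharded for column/row parallelism.
weight = ModelWeightParameter(data=torch.empty(sum(output_partition_sizes),
                                               input_size_per_partition,
                                               dtype=torch.float8_e4m3fn),
                              input_dim=1,
                              output_dim=0,
                              weight_loader=noop_loader)

# One fp32 per-tensor scale slot per logical shard, pre-filled with the
# float32 minimum as a sentinel until real scales are loaded.
weight_scale = PerTensorScaleParameter(data=torch.empty(
    len(output_partition_sizes), dtype=torch.float32),
                                       weight_loader=noop_loader)
weight_scale[:] = torch.finfo(torch.float32).min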
@@ -197,6 +206,11 @@ class Fp8LinearMethod(LinearMethodBase):
         # If checkpoint is fp8, handle that there are N scales for N
         # shards in a fused module
         else:
+            layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
+                                                    requires_grad=False)
+            if self.quant_config.activation_scheme == "static":
+                layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
+                                                       requires_grad=False)
             # If using marlin (w8a16), kernel uses channelwise weights,
             # so extend the weight scales to be channelwise.
             if self.use_marlin:
@@ -208,10 +208,17 @@ class PerTensorScaleParameter(BasevLLMParameter):
         if isinstance(shard_id, int):
             return shard_id

         # if not int, assume shard_id for qkv
         # map to int and return
         assert isinstance(shard_id, str)
         assert shard_id in self.qkv_idxs
         return self.qkv_idxs[shard_id]

+    # For row parallel layers, no sharding needed
+    # load weight into parameter as is
+    def load_row_parallel_weight(self, *args, **kwargs):
+        super().load_row_parallel_weight(*args, **kwargs)
+
     def load_merged_column_weight(self, *args, **kwargs):
         self._load_into_shard_id(*args, **kwargs)

@@ -219,7 +226,7 @@ class PerTensorScaleParameter(BasevLLMParameter):
         self._load_into_shard_id(*args, **kwargs)

     def load_column_parallel_weight(self, *args, **kwargs):
-        self._load_into_shard_id(*args, **kwargs)
+        super().load_row_parallel_weight(*args, **kwargs)

     def _load_into_shard_id(self, loaded_weight: torch.Tensor,
                             shard_id: Union[str, int], **kwargs):
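
Note: the body of _load_into_shard_id is not shown in this diff. A hedged sketch of what the shard-id mapping above implies for a fused QKV layer, assuming qkv_idxs maps "q"/"k"/"v" to slots 0/1/2 (the scale value is illustrative):

import torch

qkv_idxs = {"q": 0, "k": 1, "v": 2}  # assumed string-to-slot mapping
scales = torch.full((3,), torch.finfo(torch.float32).min)

def load_scale(shard_id, loaded_weight):
    # Mirror of the int/str handling shown above: ints index directly,
    # strings go through qkv_idxs.
    if not isinstance(shard_id, int):
        assert shard_id in qkv_idxs
        shard_id = qkv_idxs[shard_id]
    scales[shard_id] = loaded_weight

load_scale("k", torch.tensor(0.05))  # the "k" scale lands in slot 1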