[Misc] update fp8 to use vLLMParameter (#7437)
parent 55d63b1211
commit 955b5191c9
@@ -15,3 +15,4 @@ compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
 compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
 awq, casperhansen/mixtral-instruct-awq, main
 awq_marlin, casperhansen/mixtral-instruct-awq, main
+fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main
@@ -22,7 +22,7 @@ logger = init_logger(__name__)

 WEIGHT_LOADER_V2_SUPPORTED = [
     "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
-    "AWQLinearMethod", "GPTQMarlinLinearMethod"
+    "AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod"
 ]

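Note (not part of the diff): WEIGHT_LOADER_V2_SUPPORTED gates which quantization methods take the vLLMParameter-based v2 loading path, and this change opts Fp8LinearMethod in. A standalone sketch of the selection this implies; the exact wiring in linear.py may differ:

def pick_weight_loader(layer):
    # Use the v2 loader only when the layer's quant method class name is in
    # the supported list; otherwise fall back to the legacy weight_loader.
    name = type(layer.quant_method).__name__
    return (layer.weight_loader_v2
            if name in WEIGHT_LOADER_V2_SUPPORTED else layer.weight_loader)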
@@ -349,6 +349,11 @@ class ColumnParallelLinear(LinearBase):
         param_data.copy_(loaded_weight)

     def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
+        # Special case for loading scales off disk, which often do not
+        # have a shape (such as in the case of AutoFP8).
+        if len(loaded_weight.shape) == 0:
+            assert loaded_weight.numel() == 1
+            loaded_weight = loaded_weight.reshape(1)
         param.load_column_parallel_weight(loaded_weight=loaded_weight)

     def forward(self, input_):
@@ -1021,6 +1026,13 @@ class RowParallelLinear(LinearBase):

     def weight_loader_v2(self, param: BasevLLMParameter,
                          loaded_weight: torch.Tensor):
+
+        # Special case for loading scales off disk, which often do not
+        # have a shape (such as in the case of AutoFP8).
+        if len(loaded_weight.shape) == 0:
+            assert loaded_weight.numel() == 1
+            loaded_weight = loaded_weight.reshape(1)
+
         param.load_row_parallel_weight(loaded_weight=loaded_weight)

     def forward(self, input_):
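Note (not part of the diff): both weight_loader_v2 hooks above get the same special case for per-tensor scales serialized without a shape. A minimal standalone PyTorch illustration of what the reshape does (the scale value is a dummy):

import torch

# AutoFP8-style checkpoints can store a per-tensor scale as a 0-dim tensor.
loaded_weight = torch.tensor(0.0213, dtype=torch.float32)
assert loaded_weight.shape == ()      # scalar: no dimensions
assert loaded_weight.numel() == 1

# Promote it to shape (1,) so it can be copied into a 1-D scale slot.
loaded_weight = loaded_weight.reshape(1)
assert loaded_weight.shape == (1,)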
@@ -19,9 +19,10 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
     is_layer_skipped)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d, apply_fp8_linear, convert_to_channelwise,
-    create_per_tensor_scale_param, cutlass_fp8_supported,
-    normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
+    cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
     requantize_with_max_scale)
+from vllm.model_executor.parameter import (ModelWeightParameter,
+                                           PerTensorScaleParameter)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.utils import is_hip, print_warning_once
@@ -137,6 +138,7 @@ class Fp8LinearMethod(LinearMethodBase):
     ):
         del input_size, output_size
         output_size_per_partition = sum(output_partition_sizes)
+        weight_loader = extra_weight_attrs.get("weight_loader")

         layer.logical_widths = output_partition_sizes

@@ -148,34 +150,41 @@ class Fp8LinearMethod(LinearMethodBase):
         weight_dtype = (torch.float8_e4m3fn
                         if self.quant_config.is_checkpoint_fp8_serialized else
                         params_dtype)
-        weight = Parameter(torch.empty(output_size_per_partition,
-                                       input_size_per_partition,
-                                       dtype=weight_dtype),
-                           requires_grad=False)
+
+        weight = ModelWeightParameter(data=torch.empty(
+            output_size_per_partition,
+            input_size_per_partition,
+            dtype=weight_dtype),
+                                      input_dim=1,
+                                      output_dim=0,
+                                      weight_loader=weight_loader)
         layer.register_parameter("weight", weight)
-        set_weight_attrs(weight, {
-            **extra_weight_attrs,
-            "input_dim": 1,
-            "output_dim": 0,
-        })

         # If checkpoint is serialized fp8, load them.
         # Otherwise, wait until process_weights_after_loading.
         if self.quant_config.is_checkpoint_fp8_serialized:
             # WEIGHT SCALE
-            scale = create_per_tensor_scale_param(output_partition_sizes,
-                                                  **extra_weight_attrs)
+            scale = PerTensorScaleParameter(data=torch.empty(
+                len(output_partition_sizes), dtype=torch.float32),
+                weight_loader=weight_loader)
+
+            scale[:] = torch.finfo(torch.float32).min
             layer.register_parameter("weight_scale", scale)

             # INPUT ACTIVATION SCALE
             if self.quant_config.activation_scheme == "static":
-                scale = create_per_tensor_scale_param(output_partition_sizes,
-                                                      **extra_weight_attrs)
+                scale = PerTensorScaleParameter(data=torch.empty(
+                    len(output_partition_sizes), dtype=torch.float32),
+                    weight_loader=weight_loader)
+
+                scale[:] = torch.finfo(torch.float32).min
                 layer.register_parameter("input_scale", scale)
             else:
                 layer.register_parameter("input_scale", None)

     def process_weights_after_loading(self, layer: Module) -> None:
+        layer.weight = torch.nn.Parameter(layer.weight.data,
+                                          requires_grad=False)
         # If checkpoint not serialized fp8, quantize the weights.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
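Note (not part of the diff): ModelWeightParameter now carries input_dim/output_dim directly instead of going through set_weight_attrs. A standalone sketch of the dimension convention assumed here: the weight is (output_size, input_size), so output_dim=0 is the axis sharded by column-parallel layers and input_dim=1 the axis sharded by row-parallel layers.

import torch

tp_size, tp_rank = 2, 0
weight = torch.empty(1024, 512)   # (output_size, input_size)

# Column-parallel: shard along output_dim (= 0).
col_shard = weight.chunk(tp_size, dim=0)[tp_rank]
assert col_shard.shape == (512, 512)

# Row-parallel: shard along input_dim (= 1).
row_shard = weight.chunk(tp_size, dim=1)[tp_rank]
assert row_shard.shape == (1024, 256)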
@@ -197,6 +206,11 @@ class Fp8LinearMethod(LinearMethodBase):
         # If checkpoint is fp8, handle that there are N scales for N
         # shards in a fused module
         else:
+            layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
+                                                    requires_grad=False)
+            if self.quant_config.activation_scheme == "static":
+                layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
+                                                       requires_grad=False)
             # If using marlin (w8a16), kernel uses channelwise weights,
             # so extend the weight scales to be channelwise.
             if self.use_marlin:
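Note (not part of the diff): the lines added to process_weights_after_loading rewrap the loaded tensors as plain frozen torch.nn.Parameters. A minimal sketch of that pattern, under the assumption that the custom parameter wrapper is only needed during loading; the rewrap shares storage rather than copying:

import torch

loaded = torch.nn.Parameter(torch.randn(4, 4))  # stands in for a vLLMParameter subclass
frozen = torch.nn.Parameter(loaded.data, requires_grad=False)

assert frozen.data_ptr() == loaded.data_ptr()   # same storage, no copy
assert not frozen.requires_grad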
@@ -208,10 +208,17 @@ class PerTensorScaleParameter(BasevLLMParameter):
         if isinstance(shard_id, int):
             return shard_id

+        # if not int, assume shard_id for qkv
+        # map to int and return
         assert isinstance(shard_id, str)
         assert shard_id in self.qkv_idxs
         return self.qkv_idxs[shard_id]

+    # For row parallel layers, no sharding needed
+    # load weight into parameter as is
+    def load_row_parallel_weight(self, *args, **kwargs):
+        super().load_row_parallel_weight(*args, **kwargs)
+
     def load_merged_column_weight(self, *args, **kwargs):
         self._load_into_shard_id(*args, **kwargs)

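Note (not part of the diff): a standalone illustration, in plain Python rather than the vLLM API, of the shard-id mapping PerTensorScaleParameter performs: one scale slot per logical shard of a fused layer, with string QKV ids mapped to integer indices.

import torch

qkv_idxs = {"q": 0, "k": 1, "v": 2}
# One per-tensor scale per shard, initialized to a sentinel (as in the diff).
scales = torch.full((3,), torch.finfo(torch.float32).min)

def load_scale(shard_id, value):
    idx = shard_id if isinstance(shard_id, int) else qkv_idxs[shard_id]
    scales[idx] = value

load_scale("k", 0.017)   # the "k" shard's scale lands in slot 1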
@@ -219,7 +226,7 @@ class PerTensorScaleParameter(BasevLLMParameter):
         self._load_into_shard_id(*args, **kwargs)

     def load_column_parallel_weight(self, *args, **kwargs):
-        self._load_into_shard_id(*args, **kwargs)
+        super().load_row_parallel_weight(*args, **kwargs)

     def _load_into_shard_id(self, loaded_weight: torch.Tensor,
                             shard_id: Union[str, int], **kwargs):