diff --git a/tests/weight_loading/models.txt b/tests/weight_loading/models.txt
index 064dbb1feee8..bfbceb24aef9 100644
--- a/tests/weight_loading/models.txt
+++ b/tests/weight_loading/models.txt
@@ -15,3 +15,4 @@ compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
 compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
 awq, casperhansen/mixtral-instruct-awq, main
 awq_marlin, casperhansen/mixtral-instruct-awq, main
+fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main
\ No newline at end of file
diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index 50a50d98e9cc..4af954b74e8b 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -22,7 +22,7 @@
 logger = init_logger(__name__)
 
 WEIGHT_LOADER_V2_SUPPORTED = [
     "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
-    "AWQLinearMethod", "GPTQMarlinLinearMethod"
+    "AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod"
 ]
 
@@ -349,6 +349,11 @@ class ColumnParallelLinear(LinearBase):
         param_data.copy_(loaded_weight)
 
     def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
+        # Special case for loading scales off disk, which often do not
+        # have a shape (such as in the case of AutoFP8).
+        if len(loaded_weight.shape) == 0:
+            assert loaded_weight.numel() == 1
+            loaded_weight = loaded_weight.reshape(1)
         param.load_column_parallel_weight(loaded_weight=loaded_weight)
 
     def forward(self, input_):
@@ -1021,6 +1026,13 @@ class RowParallelLinear(LinearBase):
 
     def weight_loader_v2(self, param: BasevLLMParameter,
                          loaded_weight: torch.Tensor):
+
+        # Special case for loading scales off disk, which often do not
+        # have a shape (such as in the case of AutoFP8).
+        if len(loaded_weight.shape) == 0:
+            assert loaded_weight.numel() == 1
+            loaded_weight = loaded_weight.reshape(1)
+
         param.load_row_parallel_weight(loaded_weight=loaded_weight)
 
     def forward(self, input_):
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index fd7682a1c0f5..b10988b992ae 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -19,9 +19,10 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
     is_layer_skipped)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d, apply_fp8_linear, convert_to_channelwise,
-    create_per_tensor_scale_param, cutlass_fp8_supported,
-    normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
+    cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
     requantize_with_max_scale)
+from vllm.model_executor.parameter import (ModelWeightParameter,
+                                           PerTensorScaleParameter)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.utils import is_hip, print_warning_once
@@ -137,6 +138,7 @@ class Fp8LinearMethod(LinearMethodBase):
     ):
         del input_size, output_size
         output_size_per_partition = sum(output_partition_sizes)
+        weight_loader = extra_weight_attrs.get("weight_loader")
 
         layer.logical_widths = output_partition_sizes
 
@@ -148,34 +150,41 @@ class Fp8LinearMethod(LinearMethodBase):
         weight_dtype = (torch.float8_e4m3fn
                         if self.quant_config.is_checkpoint_fp8_serialized else
                         params_dtype)
-        weight = Parameter(torch.empty(output_size_per_partition,
-                                       input_size_per_partition,
-                                       dtype=weight_dtype),
-                           requires_grad=False)
+
+        weight = ModelWeightParameter(data=torch.empty(
+            output_size_per_partition,
+            input_size_per_partition,
+            dtype=weight_dtype),
+                                      input_dim=1,
+                                      output_dim=0,
+                                      weight_loader=weight_loader)
         layer.register_parameter("weight", weight)
-        set_weight_attrs(weight, {
-            **extra_weight_attrs,
-            "input_dim": 1,
-            "output_dim": 0,
-        })
 
         # If checkpoint is serialized fp8, load them.
         # Otherwise, wait until process_weights_after_loading.
         if self.quant_config.is_checkpoint_fp8_serialized:
             # WEIGHT SCALE
-            scale = create_per_tensor_scale_param(output_partition_sizes,
-                                                  **extra_weight_attrs)
+            scale = PerTensorScaleParameter(data=torch.empty(
+                len(output_partition_sizes), dtype=torch.float32),
+                                            weight_loader=weight_loader)
+
+            scale[:] = torch.finfo(torch.float32).min
             layer.register_parameter("weight_scale", scale)
 
             # INPUT ACTIVATION SCALE
             if self.quant_config.activation_scheme == "static":
-                scale = create_per_tensor_scale_param(output_partition_sizes,
-                                                      **extra_weight_attrs)
+                scale = PerTensorScaleParameter(data=torch.empty(
+                    len(output_partition_sizes), dtype=torch.float32),
+                                                weight_loader=weight_loader)
+
+                scale[:] = torch.finfo(torch.float32).min
                 layer.register_parameter("input_scale", scale)
             else:
                 layer.register_parameter("input_scale", None)
 
     def process_weights_after_loading(self, layer: Module) -> None:
+        layer.weight = torch.nn.Parameter(layer.weight.data,
+                                          requires_grad=False)
         # If checkpoint not serialized fp8, quantize the weights.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
@@ -197,6 +206,11 @@ class Fp8LinearMethod(LinearMethodBase):
         # If checkpoint is fp8, handle that there are N scales for N
         # shards in a fused module
         else:
+            layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
+                                                    requires_grad=False)
+            if self.quant_config.activation_scheme == "static":
+                layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
+                                                       requires_grad=False)
             # If using marlin (w8a16), kernel uses channelwise weights,
             # so extend the weight scales to be channelwise.
             if self.use_marlin:
diff --git a/vllm/model_executor/parameter.py b/vllm/model_executor/parameter.py
index c6cfab7892ef..326b6ae8fee6 100644
--- a/vllm/model_executor/parameter.py
+++ b/vllm/model_executor/parameter.py
@@ -208,10 +208,17 @@ class PerTensorScaleParameter(BasevLLMParameter):
         if isinstance(shard_id, int):
             return shard_id
 
+        # if not int, assume shard_id for qkv
+        # map to int and return
         assert isinstance(shard_id, str)
         assert shard_id in self.qkv_idxs
         return self.qkv_idxs[shard_id]
 
+    # For row parallel layers, no sharding needed
+    # load weight into parameter as is
+    def load_row_parallel_weight(self, *args, **kwargs):
+        super().load_row_parallel_weight(*args, **kwargs)
+
     def load_merged_column_weight(self, *args, **kwargs):
         self._load_into_shard_id(*args, **kwargs)
 
@@ -219,7 +226,7 @@ class PerTensorScaleParameter(BasevLLMParameter):
         self._load_into_shard_id(*args, **kwargs)
 
     def load_column_parallel_weight(self, *args, **kwargs):
-        self._load_into_shard_id(*args, **kwargs)
+        super().load_row_parallel_weight(*args, **kwargs)
 
     def _load_into_shard_id(self, loaded_weight: torch.Tensor,
                             shard_id: Union[str, int], **kwargs):
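
Note on the scalar-scale special case added to weight_loader_v2 above: AutoFP8-style checkpoints serialize per-tensor scales as 0-dimensional tensors, so the loader promotes them to shape (1,) before passing them to the parameter's load method. The sketch below is a minimal standalone illustration of that guard; the variable names and the plain torch.empty destination are illustrative only and not part of this diff.

import torch

# A per-tensor scale as an AutoFP8-style checkpoint may store it: 0-dim.
loaded_weight = torch.tensor(0.05, dtype=torch.float32)

# Guard added in weight_loader_v2: promote the scalar to shape (1,)
# so the downstream copy logic sees a 1-D tensor.
if len(loaded_weight.shape) == 0:
    assert loaded_weight.numel() == 1
    loaded_weight = loaded_weight.reshape(1)

# Destination buffer, initialized the way create_weights initializes the
# scale parameters (float32 min), then overwritten by the loaded value.
param = torch.empty(1, dtype=torch.float32)
param[:] = torch.finfo(torch.float32).min
param.copy_(loaded_weight)
print(param)  # tensor([0.0500])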