mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-11 07:45:01 +08:00
[Misc] Update fbgemmfp8 to use vLLMParameters (#7972)
Co-authored-by: Michael Goin <michael@neuralmagic.com>
This commit is contained in:
parent
61f4a93d14
commit
e16fa99a6a
@ -26,7 +26,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
|
|||||||
"CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
|
"CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
|
||||||
"AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod",
|
"AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod",
|
||||||
"MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod",
|
"MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod",
|
||||||
"TPUInt8LinearMethod", "GPTQLinearMethod"
|
"TPUInt8LinearMethod", "GPTQLinearMethod", "FBGEMMFp8LinearMethod"
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -15,8 +15,9 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
|||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
is_layer_skipped)
|
is_layer_skipped)
|
||||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||||
apply_fp8_linear, create_per_channel_scale_param)
|
apply_fp8_linear)
|
||||||
from vllm.model_executor.utils import set_weight_attrs
|
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
|
||||||
|
ModelWeightParameter)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -85,6 +86,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
|
|||||||
params_dtype: torch.dtype,
|
params_dtype: torch.dtype,
|
||||||
**extra_weight_attrs,
|
**extra_weight_attrs,
|
||||||
):
|
):
|
||||||
|
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||||
del input_size, output_size
|
del input_size, output_size
|
||||||
output_size_per_partition = sum(output_partition_sizes)
|
output_size_per_partition = sum(output_partition_sizes)
|
||||||
|
|
||||||
@ -95,20 +97,21 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
|
|||||||
layer.orig_dtype = params_dtype
|
layer.orig_dtype = params_dtype
|
||||||
|
|
||||||
# WEIGHT
|
# WEIGHT
|
||||||
weight = Parameter(torch.empty(output_size_per_partition,
|
weight = ModelWeightParameter(data=torch.empty(
|
||||||
input_size_per_partition,
|
output_size_per_partition,
|
||||||
dtype=torch.float8_e4m3fn),
|
input_size_per_partition,
|
||||||
requires_grad=False)
|
dtype=torch.float8_e4m3fn),
|
||||||
|
input_dim=1,
|
||||||
|
output_dim=0,
|
||||||
|
weight_loader=weight_loader)
|
||||||
layer.register_parameter("weight", weight)
|
layer.register_parameter("weight", weight)
|
||||||
set_weight_attrs(weight, {
|
|
||||||
"input_dim": 1,
|
|
||||||
"output_dim": 0,
|
|
||||||
**extra_weight_attrs,
|
|
||||||
})
|
|
||||||
|
|
||||||
# WEIGHT SCALE
|
# WEIGHT SCALE
|
||||||
weight_scale = create_per_channel_scale_param(output_partition_sizes,
|
weight_scale = ChannelQuantScaleParameter(data=torch.empty(
|
||||||
**extra_weight_attrs)
|
(sum(output_partition_sizes), 1), dtype=torch.float32),
|
||||||
|
output_dim=0,
|
||||||
|
weight_loader=weight_loader)
|
||||||
|
weight_scale[:] = torch.finfo(torch.float32).min
|
||||||
layer.register_parameter("weight_scale", weight_scale)
|
layer.register_parameter("weight_scale", weight_scale)
|
||||||
|
|
||||||
# INPUT SCALE UPPER BOUND
|
# INPUT SCALE UPPER BOUND
|
||||||
@ -118,6 +121,11 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
|
|||||||
layer.input_scale_ub = input_scale_ub
|
layer.input_scale_ub = input_scale_ub
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: Module) -> None:
|
def process_weights_after_loading(self, layer: Module) -> None:
|
||||||
|
# required by torch.compile
|
||||||
|
layer.weight_scale = Parameter(layer.weight_scale.data,
|
||||||
|
requires_grad=False)
|
||||||
|
layer.weight = Parameter(layer.weight.data, requires_grad=False)
|
||||||
|
|
||||||
weight = layer.weight
|
weight = layer.weight
|
||||||
layer.weight = Parameter(weight.t(), requires_grad=False)
|
layer.weight = Parameter(weight.t(), requires_grad=False)
|
||||||
|
|
||||||
|
|||||||
@ -1,10 +1,8 @@
|
|||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.nn import Parameter
|
|
||||||
|
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.model_executor.utils import set_weight_attrs
|
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import is_hip
|
from vllm.utils import is_hip
|
||||||
|
|
||||||
@ -38,31 +36,6 @@ def all_close_1d(x: torch.Tensor) -> bool:
|
|||||||
return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
|
return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
|
||||||
|
|
||||||
|
|
||||||
def create_per_tensor_scale_param(
|
|
||||||
output_partition_sizes: List[int],
|
|
||||||
**extra_weight_attrs,
|
|
||||||
) -> Parameter:
|
|
||||||
scale = Parameter(torch.empty(len(output_partition_sizes),
|
|
||||||
dtype=torch.float32),
|
|
||||||
requires_grad=False)
|
|
||||||
scale[:] = torch.finfo(torch.float32).min
|
|
||||||
set_weight_attrs(scale, {
|
|
||||||
"needs_scalar_to_array": True,
|
|
||||||
**extra_weight_attrs
|
|
||||||
})
|
|
||||||
return scale
|
|
||||||
|
|
||||||
|
|
||||||
def create_per_channel_scale_param(output_partition_sizes: List[int],
|
|
||||||
**extra_weight_attrs) -> Parameter:
|
|
||||||
scale = Parameter(torch.empty((sum(output_partition_sizes), 1),
|
|
||||||
dtype=torch.float32),
|
|
||||||
requires_grad=False)
|
|
||||||
scale[:] = torch.finfo(torch.float32).min
|
|
||||||
set_weight_attrs(scale, {"output_dim": 0, **extra_weight_attrs})
|
|
||||||
return scale
|
|
||||||
|
|
||||||
|
|
||||||
def convert_to_channelwise(
|
def convert_to_channelwise(
|
||||||
weight_scale: torch.Tensor,
|
weight_scale: torch.Tensor,
|
||||||
logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
|
logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user