mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-27 07:34:25 +08:00
[Quantization] FP8 Weight Reloading for Quantized RL Rollout (#28480)
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
This commit is contained in:
parent
00e5cbb967
commit
fccd532587
@ -10,10 +10,14 @@ import torch
|
|||||||
|
|
||||||
from tests.quantization.utils import is_quant_method_supported
|
from tests.quantization.utils import is_quant_method_supported
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
|
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||||
from vllm.model_executor.layers.quantization.fp8 import (
|
from vllm.model_executor.layers.quantization.fp8 import (
|
||||||
|
Fp8Config,
|
||||||
Fp8KVCacheMethod,
|
Fp8KVCacheMethod,
|
||||||
Fp8LinearMethod,
|
Fp8LinearMethod,
|
||||||
|
Fp8MoEMethod,
|
||||||
)
|
)
|
||||||
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
MODELS = [
|
MODELS = [
|
||||||
@ -261,3 +265,87 @@ def test_scaled_fp8_quant(dtype) -> None:
|
|||||||
torch.narrow(y_nc_pad, 0, 0, x_nc.shape[0]), inv_scale_nc, dtype
|
torch.narrow(y_nc_pad, 0, 0, x_nc.shape[0]), inv_scale_nc, dtype
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("method_cls", [Fp8LinearMethod, Fp8MoEMethod])
# FP8 weight reloading does not support online quantization
@pytest.mark.parametrize("is_checkpoint_fp8_serialized", [True])  # skip False
@pytest.mark.parametrize("weight_block_size", [None, [1, 1]])
# Any postprocessing applied to the weights, such as padding and repacking
# (device sharding excluded), must also be applied to the reloaded weights.
#
# This is the case for marlin as well as per-tensor Fp8MoEMethod.
@pytest.mark.parametrize("use_marlin", [False])  # skip True
def test_fp8_reloading(
    method_cls, is_checkpoint_fp8_serialized, weight_block_size, use_marlin, dist_init
):
    """Check that an FP8 layer can be loaded, post-processed, and loaded again.

    The test builds a minimal layer for ``method_cls``, records the name, shape
    and weight loader of every parameter, then runs two load +
    ``process_weights_after_loading`` rounds. On the second round each parameter
    must still expose the very same weight loader that was recorded up front.

    ``dist_init`` is a fixture requested only for its side effects; it is not
    used directly in the body (presumably distributed setup -- defined
    elsewhere).
    """
    if is_checkpoint_fp8_serialized is False:
        pytest.skip("FP8 weight reloading does not support online quantization")

    per_tensor_moe = method_cls is Fp8MoEMethod and weight_block_size is None
    if per_tensor_moe:
        pytest.skip(
            "FP8 Tensor weight reloading does not support fusing w13_weight_scale. "
            "If this is your use case, consider using a restore function like #26327"
        )

    with torch.device("cuda:0"):
        quant_config = Fp8Config(
            is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
            weight_block_size=weight_block_size,
        )

        # Build the smallest possible layer for the method under test.
        if method_cls is not Fp8LinearMethod:
            layer = FusedMoE(
                num_experts=1,
                top_k=1,
                hidden_size=1,
                intermediate_size=1,
            )
            method = method_cls(quant_config, layer)
            method.create_weights(
                layer=layer,
                num_experts=1,
                hidden_size=1,
                intermediate_size_per_partition=1,
                params_dtype=torch.bfloat16,
                weight_loader=default_weight_loader,
            )
        else:
            layer = torch.nn.Linear(1, 1)
            method = method_cls(quant_config)
            method.create_weights(
                layer=layer,
                input_size_per_partition=1,
                output_partition_sizes=[1],
                input_size=1,
                output_size=1,
                params_dtype=torch.bfloat16,
                weight_loader=default_weight_loader,
            )

        method.use_marlin = use_marlin

        # Record the checkpoint-facing format of every parameter before any
        # post-processing has had a chance to replace it.
        recorded = []
        for pname, p in layer.named_parameters():
            loader = getattr(p, "weight_loader", default_weight_loader)
            recorded.append((pname, p.shape, loader))

        # First load. Zeros are used on purpose: uninitialized (empty) data
        # cannot be used here.
        for pname, pshape, _ in recorded:
            target = getattr(layer, pname)
            load_fn = getattr(target, "weight_loader", default_weight_loader)
            load_fn(target, torch.zeros(pshape))

        method.process_weights_after_loading(layer)

        # Reload after post-processing. This assumes no parameter was reshaped,
        # so the shapes recorded above are still what the loaders expect.
        for pname, pshape, expected_loader in recorded:
            target = getattr(layer, pname)
            load_fn = getattr(target, "weight_loader", default_weight_loader)
            assert load_fn is expected_loader
            load_fn(target, torch.zeros(pshape))

        method.process_weights_after_loading(layer)
|
||||||
|
|||||||
@ -94,7 +94,7 @@ from vllm.model_executor.parameter import (
|
|||||||
ModelWeightParameter,
|
ModelWeightParameter,
|
||||||
PerTensorScaleParameter,
|
PerTensorScaleParameter,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.utils import set_weight_attrs
|
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.scalar_type import scalar_types
|
from vllm.scalar_type import scalar_types
|
||||||
from vllm.utils.deep_gemm import (
|
from vllm.utils.deep_gemm import (
|
||||||
@ -548,46 +548,50 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
assert not self.act_q_static
|
assert not self.act_q_static
|
||||||
size_k_first = False
|
size_k_first = False
|
||||||
|
|
||||||
weight, weight_scale = process_fp8_weight_block_strategy(
|
weight, weight_scale_inv = process_fp8_weight_block_strategy(
|
||||||
layer.weight, layer.weight_scale_inv
|
layer.weight, layer.weight_scale_inv
|
||||||
)
|
)
|
||||||
# Delete the weight_scale_inv parameter to avoid confusion
|
|
||||||
# with the weight_scale parameter
|
# Update layer with new values
|
||||||
del layer.weight_scale_inv
|
replace_parameter(layer, "weight", weight.data)
|
||||||
|
replace_parameter(layer, "weight_scale_inv", weight_scale_inv.data)
|
||||||
|
|
||||||
# If checkpoint not serialized fp8, quantize the weights.
|
# If checkpoint not serialized fp8, quantize the weights.
|
||||||
elif not self.quant_config.is_checkpoint_fp8_serialized:
|
|
||||||
qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
|
|
||||||
weight = qweight.t()
|
|
||||||
|
|
||||||
# If checkpoint is fp8 per-tensor, handle that there are N scales for N
|
|
||||||
# shards in a fused module
|
|
||||||
else:
|
else:
|
||||||
weight = layer.weight
|
if not self.quant_config.is_checkpoint_fp8_serialized:
|
||||||
weight_scale = layer.weight_scale
|
qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
|
||||||
|
weight = qweight.t()
|
||||||
|
|
||||||
# If using w8a8, torch._scaled_mm needs per tensor, so
|
# If checkpoint is fp8 per-tensor, handle that there are N scales for N
|
||||||
# requantize the logical shards as a single weight.
|
# shards in a fused module
|
||||||
if not self.use_marlin:
|
else:
|
||||||
weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy(
|
weight = layer.weight
|
||||||
weight,
|
weight_scale = layer.weight_scale
|
||||||
weight_scale,
|
|
||||||
layer.logical_widths,
|
|
||||||
getattr(layer, "input_scale", None),
|
|
||||||
)
|
|
||||||
if self.act_q_static:
|
|
||||||
assert input_scale is not None
|
|
||||||
input_scale = input_scale.max()
|
|
||||||
weight = weight.t()
|
|
||||||
|
|
||||||
# Update layer with new values.
|
# If using w8a8, torch._scaled_mm needs per tensor, so
|
||||||
layer.weight = Parameter(weight.data, requires_grad=False)
|
# requantize the logical shards as a single weight.
|
||||||
layer.weight_scale = Parameter(weight_scale.data, requires_grad=False)
|
if not self.use_marlin:
|
||||||
layer.input_scale = (
|
weight, weight_scale, input_scale = (
|
||||||
Parameter(input_scale, requires_grad=False)
|
process_fp8_weight_tensor_strategy(
|
||||||
if input_scale is not None
|
weight,
|
||||||
else None
|
weight_scale,
|
||||||
)
|
layer.logical_widths,
|
||||||
|
getattr(layer, "input_scale", None),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if self.act_q_static:
|
||||||
|
assert input_scale is not None
|
||||||
|
input_scale = input_scale.max()
|
||||||
|
weight = weight.t()
|
||||||
|
|
||||||
|
# Update layer with new values.
|
||||||
|
replace_parameter(layer, "weight", weight.data)
|
||||||
|
replace_parameter(layer, "weight_scale", weight_scale.data)
|
||||||
|
|
||||||
|
if input_scale is not None:
|
||||||
|
replace_parameter(layer, "input_scale", input_scale)
|
||||||
|
else:
|
||||||
|
layer.input_scale = None
|
||||||
|
|
||||||
if self.use_marlin:
|
if self.use_marlin:
|
||||||
prepare_fp8_layer_for_marlin(
|
prepare_fp8_layer_for_marlin(
|
||||||
@ -614,7 +618,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
return self.w8a8_block_fp8_linear.apply(
|
return self.w8a8_block_fp8_linear.apply(
|
||||||
input=x,
|
input=x,
|
||||||
weight=layer.weight,
|
weight=layer.weight,
|
||||||
weight_scale=layer.weight_scale,
|
weight_scale=layer.weight_scale_inv,
|
||||||
input_scale=layer.input_scale,
|
input_scale=layer.input_scale,
|
||||||
bias=bias,
|
bias=bias,
|
||||||
)
|
)
|
||||||
@ -643,10 +647,15 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
return torch.nn.functional.linear(x, weight_bf16.t(), bias)
|
return torch.nn.functional.linear(x, weight_bf16.t(), bias)
|
||||||
|
|
||||||
if self.use_marlin:
|
if self.use_marlin:
|
||||||
|
if self.block_quant:
|
||||||
|
weight_scale = layer.weight_scale_inv
|
||||||
|
else:
|
||||||
|
weight_scale = layer.weight_scale
|
||||||
|
|
||||||
return apply_fp8_marlin_linear(
|
return apply_fp8_marlin_linear(
|
||||||
input=x,
|
input=x,
|
||||||
weight=layer.weight,
|
weight=layer.weight,
|
||||||
weight_scale=layer.weight_scale,
|
weight_scale=weight_scale,
|
||||||
workspace=layer.workspace,
|
workspace=layer.workspace,
|
||||||
size_n=layer.output_size_per_partition,
|
size_n=layer.output_size_per_partition,
|
||||||
size_k=layer.input_size_per_partition,
|
size_k=layer.input_size_per_partition,
|
||||||
@ -660,7 +669,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
return self.w8a8_block_fp8_linear.apply(
|
return self.w8a8_block_fp8_linear.apply(
|
||||||
input=x,
|
input=x,
|
||||||
weight=layer.weight,
|
weight=layer.weight,
|
||||||
weight_scale=layer.weight_scale,
|
weight_scale=layer.weight_scale_inv,
|
||||||
input_scale=layer.input_scale,
|
input_scale=layer.input_scale,
|
||||||
bias=bias,
|
bias=bias,
|
||||||
)
|
)
|
||||||
@ -937,22 +946,18 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
w2_weight_scale_inv = layer.w2_weight_scale_inv
|
w2_weight_scale_inv = layer.w2_weight_scale_inv
|
||||||
|
|
||||||
# torch.compile() cannot use Parameter subclasses.
|
# torch.compile() cannot use Parameter subclasses.
|
||||||
layer.w13_weight = Parameter(w13_weight, requires_grad=False)
|
replace_parameter(layer, "w13_weight", w13_weight)
|
||||||
layer.w13_weight_scale_inv = Parameter(
|
replace_parameter(layer, "w13_weight_scale_inv", w13_weight_scale_inv)
|
||||||
w13_weight_scale_inv, requires_grad=False
|
replace_parameter(layer, "w2_weight", w2_weight)
|
||||||
)
|
replace_parameter(layer, "w2_weight_scale_inv", w2_weight_scale_inv)
|
||||||
layer.w2_weight = Parameter(w2_weight, requires_grad=False)
|
|
||||||
layer.w2_weight_scale_inv = Parameter(
|
|
||||||
w2_weight_scale_inv, requires_grad=False
|
|
||||||
)
|
|
||||||
if self.rocm_aiter_moe_enabled:
|
if self.rocm_aiter_moe_enabled:
|
||||||
# reshaping weights is required for aiter moe kernel.
|
# reshaping weights is required for aiter moe kernel.
|
||||||
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
|
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
|
||||||
layer.w13_weight.data, layer.w2_weight.data
|
layer.w13_weight.data, layer.w2_weight.data
|
||||||
)
|
)
|
||||||
|
|
||||||
layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
|
replace_parameter(layer, "w13_weight", shuffled_w13)
|
||||||
layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
|
replace_parameter(layer, "w2_weight", shuffled_w2)
|
||||||
|
|
||||||
# DeepGemm scales need to be transposed and aligned. We try to do
|
# DeepGemm scales need to be transposed and aligned. We try to do
|
||||||
# it ahead of time for performance reasons.
|
# it ahead of time for performance reasons.
|
||||||
@ -990,13 +995,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
|
|
||||||
# Re-initialize w13_scale because we directly quantize
|
# Re-initialize w13_scale because we directly quantize
|
||||||
# merged w13 weights and generate a single scaling factor.
|
# merged w13 weights and generate a single scaling factor.
|
||||||
layer.w13_weight_scale = torch.nn.Parameter(
|
replace_parameter(
|
||||||
|
layer,
|
||||||
|
"w13_weight_scale",
|
||||||
torch.ones(
|
torch.ones(
|
||||||
layer.local_num_experts,
|
layer.local_num_experts,
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
device=w13_weight.device,
|
device=w13_weight.device,
|
||||||
),
|
),
|
||||||
requires_grad=False,
|
|
||||||
)
|
)
|
||||||
for expert in range(layer.local_num_experts):
|
for expert in range(layer.local_num_experts):
|
||||||
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
|
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
|
||||||
@ -1005,16 +1011,17 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
|
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
|
||||||
ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
|
ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
|
||||||
)
|
)
|
||||||
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
|
replace_parameter(layer, "w13_weight", w13_weight)
|
||||||
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
|
replace_parameter(layer, "w2_weight", w2_weight)
|
||||||
|
|
||||||
if self.rocm_aiter_moe_enabled:
|
if self.rocm_aiter_moe_enabled:
|
||||||
# reshaping weights is required for aiter moe kernel.
|
# reshaping weights is required for aiter moe kernel.
|
||||||
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
|
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
|
||||||
layer.w13_weight, layer.w2_weight
|
layer.w13_weight, layer.w2_weight
|
||||||
)
|
)
|
||||||
|
|
||||||
layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
|
replace_parameter(layer, "w13_weight", shuffled_w13)
|
||||||
layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
|
replace_parameter(layer, "w2_weight", shuffled_w2)
|
||||||
# If checkpoint is fp8, we need to handle that the
|
# If checkpoint is fp8, we need to handle that the
|
||||||
# MoE kernels require single activation scale and single weight
|
# MoE kernels require single activation scale and single weight
|
||||||
# scale for w13 per expert.
|
# scale for w13 per expert.
|
||||||
@ -1035,12 +1042,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
"fp8 MoE layer. Using the maximum across experts "
|
"fp8 MoE layer. Using the maximum across experts "
|
||||||
"for each layer."
|
"for each layer."
|
||||||
)
|
)
|
||||||
layer.w13_input_scale = torch.nn.Parameter(
|
replace_parameter(layer, "w13_input_scale", layer.w13_input_scale.max())
|
||||||
layer.w13_input_scale.max(), requires_grad=False
|
replace_parameter(layer, "w2_input_scale", layer.w2_input_scale.max())
|
||||||
)
|
|
||||||
layer.w2_input_scale = torch.nn.Parameter(
|
|
||||||
layer.w2_input_scale.max(), requires_grad=False
|
|
||||||
)
|
|
||||||
if current_platform.is_fp8_fnuz():
|
if current_platform.is_fp8_fnuz():
|
||||||
# Normalize the weights and scales
|
# Normalize the weights and scales
|
||||||
w13_weight, w13_weight_scale, w13_input_scale = (
|
w13_weight, w13_weight_scale, w13_input_scale = (
|
||||||
@ -1054,22 +1057,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
# Reset the parameter
|
# Reset the parameter
|
||||||
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
|
replace_parameter(layer, "w13_weight", w13_weight)
|
||||||
layer.w13_weight_scale = torch.nn.Parameter(
|
replace_parameter(layer, "w13_weight_scale", w13_weight_scale)
|
||||||
w13_weight_scale, requires_grad=False
|
|
||||||
)
|
|
||||||
if w13_input_scale is not None:
|
if w13_input_scale is not None:
|
||||||
layer.w13_input_scale = torch.nn.Parameter(
|
replace_parameter(layer, "w13_input_scale", w13_input_scale)
|
||||||
w13_input_scale, requires_grad=False
|
replace_parameter(layer, "w2_weight", w2_weight)
|
||||||
)
|
replace_parameter(layer, "w2_weight_scale", w2_weight_scale)
|
||||||
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
|
|
||||||
layer.w2_weight_scale = torch.nn.Parameter(
|
|
||||||
w2_weight_scale, requires_grad=False
|
|
||||||
)
|
|
||||||
if w2_input_scale is not None:
|
if w2_input_scale is not None:
|
||||||
layer.w2_input_scale = torch.nn.Parameter(
|
replace_parameter(layer, "w2_input_scale", w2_input_scale)
|
||||||
w2_input_scale, requires_grad=False
|
|
||||||
)
|
|
||||||
|
|
||||||
# Fp8 moe kernel needs single weight scale for w13 per expert.
|
# Fp8 moe kernel needs single weight scale for w13 per expert.
|
||||||
# We take the max then dequant and requant each expert.
|
# We take the max then dequant and requant each expert.
|
||||||
@ -1093,12 +1088,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
layer.w13_weight, layer.w2_weight
|
layer.w13_weight, layer.w2_weight
|
||||||
)
|
)
|
||||||
|
|
||||||
layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
|
replace_parameter(layer, "w13_weight", shuffled_w13)
|
||||||
layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
|
replace_parameter(layer, "w2_weight", shuffled_w2)
|
||||||
|
|
||||||
layer.w13_weight_scale = torch.nn.Parameter(
|
replace_parameter(layer, "w13_weight_scale", max_w13_scales)
|
||||||
max_w13_scales, requires_grad=False
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.flashinfer_moe_backend is not None:
|
if self.flashinfer_moe_backend is not None:
|
||||||
# NOTE: weights have to be swapped since the activation is
|
# NOTE: weights have to be swapped since the activation is
|
||||||
|
|||||||
@ -45,6 +45,13 @@ class BaseKVCacheMethod(QuantizeMethodBase):
|
|||||||
raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.")
|
raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.")
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
|
# skip if there are no weights to process (for example, weight reloading)
|
||||||
|
if not hasattr(layer, "q_scale"):
|
||||||
|
assert not hasattr(layer, "k_scale")
|
||||||
|
assert not hasattr(layer, "v_scale")
|
||||||
|
assert not hasattr(layer, "prob_scale")
|
||||||
|
return
|
||||||
|
|
||||||
# If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0
|
# If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0
|
||||||
# regardless whether the kv-scale is available in the checkpoint.
|
# regardless whether the kv-scale is available in the checkpoint.
|
||||||
# No need to process kv scales after loading if we are going to
|
# No need to process kv scales after loading if we are going to
|
||||||
|
|||||||
@ -27,6 +27,7 @@ from vllm.model_executor.parameter import (
|
|||||||
ChannelQuantScaleParameter,
|
ChannelQuantScaleParameter,
|
||||||
PerTensorScaleParameter,
|
PerTensorScaleParameter,
|
||||||
)
|
)
|
||||||
|
from vllm.model_executor.utils import replace_parameter
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.triton_utils import tl, triton
|
from vllm.triton_utils import tl, triton
|
||||||
from vllm.utils.deep_gemm import (
|
from vllm.utils.deep_gemm import (
|
||||||
@ -1404,12 +1405,12 @@ def maybe_post_process_fp8_weight_block(layer: torch.nn.Module):
|
|||||||
if should_use_deepgemm:
|
if should_use_deepgemm:
|
||||||
dg_weight, dg_weight_scale = deepgemm_post_process_fp8_weight_block(
|
dg_weight, dg_weight_scale = deepgemm_post_process_fp8_weight_block(
|
||||||
wq=layer.weight.data,
|
wq=layer.weight.data,
|
||||||
ws=layer.weight_scale.data,
|
ws=layer.weight_scale_inv.data,
|
||||||
quant_block_shape=tuple(layer.weight_block_size),
|
quant_block_shape=tuple(layer.weight_block_size),
|
||||||
use_e8m0=is_deep_gemm_e8m0_used(),
|
use_e8m0=is_deep_gemm_e8m0_used(),
|
||||||
)
|
)
|
||||||
layer.weight = torch.nn.Parameter(dg_weight, requires_grad=False)
|
replace_parameter(layer, "weight", dg_weight)
|
||||||
layer.weight_scale = torch.nn.Parameter(dg_weight_scale, requires_grad=False)
|
replace_parameter(layer, "weight_scale_inv", dg_weight_scale)
|
||||||
|
|
||||||
|
|
||||||
def expert_weight_is_col_major(x: torch.Tensor) -> bool:
|
def expert_weight_is_col_major(x: torch.Tensor) -> bool:
|
||||||
|
|||||||
@ -14,6 +14,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
|||||||
marlin_quant_input,
|
marlin_quant_input,
|
||||||
should_use_atomic_add_reduce,
|
should_use_atomic_add_reduce,
|
||||||
)
|
)
|
||||||
|
from vllm.model_executor.utils import replace_parameter
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.scalar_type import scalar_types
|
from vllm.scalar_type import scalar_types
|
||||||
|
|
||||||
@ -130,7 +131,7 @@ def prepare_fp8_layer_for_marlin(
|
|||||||
size_n=part_size_n,
|
size_n=part_size_n,
|
||||||
num_bits=8,
|
num_bits=8,
|
||||||
)
|
)
|
||||||
layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)
|
replace_parameter(layer, "weight", marlin_qweight)
|
||||||
|
|
||||||
# WEIGHT SCALES
|
# WEIGHT SCALES
|
||||||
# Permute scales
|
# Permute scales
|
||||||
@ -138,7 +139,6 @@ def prepare_fp8_layer_for_marlin(
|
|||||||
scales = layer.weight_scale.to(layer.orig_dtype)
|
scales = layer.weight_scale.to(layer.orig_dtype)
|
||||||
elif "weight_scale_inv" in dir(layer):
|
elif "weight_scale_inv" in dir(layer):
|
||||||
scales = layer.weight_scale_inv.to(layer.orig_dtype)
|
scales = layer.weight_scale_inv.to(layer.orig_dtype)
|
||||||
del layer.weight_scale_inv
|
|
||||||
|
|
||||||
group_size = -1 if weight_block_size is None else weight_block_size[1]
|
group_size = -1 if weight_block_size is None else weight_block_size[1]
|
||||||
|
|
||||||
@ -177,12 +177,15 @@ def prepare_fp8_layer_for_marlin(
|
|||||||
)
|
)
|
||||||
if input_dtype != torch.float8_e4m3fn:
|
if input_dtype != torch.float8_e4m3fn:
|
||||||
marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales)
|
marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales)
|
||||||
layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)
|
if hasattr(layer, "weight_scale"):
|
||||||
|
replace_parameter(layer, "weight_scale", marlin_scales)
|
||||||
|
elif hasattr(layer, "weight_scale_inv"):
|
||||||
|
replace_parameter(layer, "weight_scale_inv", marlin_scales)
|
||||||
|
|
||||||
if hasattr(layer, "bias") and layer.bias is not None:
|
if hasattr(layer, "bias") and layer.bias is not None:
|
||||||
assert layer.bias.shape == (part_size_n,)
|
assert layer.bias.shape == (part_size_n,)
|
||||||
bias = marlin_permute_bias(layer.bias)
|
bias = marlin_permute_bias(layer.bias)
|
||||||
layer.bias = torch.nn.Parameter(bias, requires_grad=False)
|
replace_parameter(layer, "bias", bias)
|
||||||
|
|
||||||
|
|
||||||
def prepare_moe_fp8_layer_for_marlin(
|
def prepare_moe_fp8_layer_for_marlin(
|
||||||
|
|||||||
@ -118,8 +118,11 @@ def requantize_with_max_scale(
|
|||||||
# from disk in this case. Skip requantization in this case (since)
|
# from disk in this case. Skip requantization in this case (since)
|
||||||
# we already are quantized with the single scale.
|
# we already are quantized with the single scale.
|
||||||
# * Sample Model: nm-testing/Phi-3-mini-128k-instruct-FP8
|
# * Sample Model: nm-testing/Phi-3-mini-128k-instruct-FP8
|
||||||
|
#
|
||||||
|
# Extra note: upon weight reloading weight_scale.ndim == 0
|
||||||
unfused_module_in_checkpoint = (
|
unfused_module_in_checkpoint = (
|
||||||
weight_scale[-1] > torch.finfo(torch.float8_e4m3fn).min
|
weight_scale.ndim != 0
|
||||||
|
and weight_scale[-1] > torch.finfo(torch.float8_e4m3fn).min
|
||||||
)
|
)
|
||||||
|
|
||||||
# If unfused checkpoint, need to requantize with the single scale.
|
# If unfused checkpoint, need to requantize with the single scale.
|
||||||
|
|||||||
@ -50,6 +50,31 @@ def set_weight_attrs(
|
|||||||
setattr(weight, key, value)
|
setattr(weight, key, value)
|
||||||
|
|
||||||
|
|
||||||
|
def replace_parameter(layer: torch.nn.Module, param_name: str, new_data: torch.Tensor):
    """
    Install a fresh parameter on `layer` under `param_name`, keeping it reloadable.

    The replacement is a plain `torch.nn.Parameter` with `requires_grad=False`
    wrapping `new_data`. If the attribute currently stored under `param_name`
    has a `weight_loader`, that loader is attached to the replacement so the
    weight can still be loaded again later.

    Meant to be called from `process_weights_after_loading` implementations.

    Do not use this on tied/shared weights: only `layer.<param_name>` is
    rebound, so any other holder of the old parameter object keeps the old one.

    Args:
        layer: Layer containing parameter to replace
        param_name: Name of parameter to replace
        new_data: Data for the new parameter; if a `torch.nn.Parameter` is
            given, its underlying `.data` tensor is used
    """
    # Look up whatever is currently stored (None when the attribute is missing).
    previous = getattr(layer, param_name, None)

    # Unwrap an incoming Parameter so the replacement wraps the raw tensor.
    tensor = new_data.data if isinstance(new_data, torch.nn.Parameter) else new_data
    replacement = torch.nn.Parameter(tensor, requires_grad=False)

    # Carry the loader over so the weight remains reloadable.
    if hasattr(previous, "weight_loader"):
        set_weight_attrs(replacement, {"weight_loader": previous.weight_loader})

    setattr(layer, param_name, replacement)
|
||||||
|
|
||||||
|
|
||||||
def get_packed_modules_mapping(model: torch.nn.Module) -> dict[str, list[str]]:
|
def get_packed_modules_mapping(model: torch.nn.Module) -> dict[str, list[str]]:
|
||||||
parent_map = getattr(model, "packed_modules_mapping", None)
|
parent_map = getattr(model, "packed_modules_mapping", None)
|
||||||
parent_map = copy.deepcopy(parent_map) if parent_map is not None else {}
|
parent_map = copy.deepcopy(parent_map) if parent_map is not None else {}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user