mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-30 14:38:44 +08:00
[Quantization] FP8 Weight Reloading for Quantized RL Rollout (#28480)
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
This commit is contained in:
parent
00e5cbb967
commit
fccd532587
@ -10,10 +10,14 @@ import torch
|
||||
|
||||
from tests.quantization.utils import is_quant_method_supported
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.quantization.fp8 import (
|
||||
Fp8Config,
|
||||
Fp8KVCacheMethod,
|
||||
Fp8LinearMethod,
|
||||
Fp8MoEMethod,
|
||||
)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
MODELS = [
|
||||
@ -261,3 +265,87 @@ def test_scaled_fp8_quant(dtype) -> None:
|
||||
torch.narrow(y_nc_pad, 0, 0, x_nc.shape[0]), inv_scale_nc, dtype
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("method_cls", [Fp8LinearMethod, Fp8MoEMethod])
|
||||
# FP8 weight reloading does not support online quantization
|
||||
@pytest.mark.parametrize("is_checkpoint_fp8_serialized", [True]) # skip False
|
||||
@pytest.mark.parametrize("weight_block_size", [None, [1, 1]])
|
||||
# any postprocessing that is applied to the weights such as padding and repacking
|
||||
# (excluding device sharding) must also be applied to the reloaded weights
|
||||
#
|
||||
# this is the case for marlin as well as per-tensor Fp8MoEMethod
|
||||
@pytest.mark.parametrize("use_marlin", [False]) # skip True
|
||||
def test_fp8_reloading(
|
||||
method_cls, is_checkpoint_fp8_serialized, weight_block_size, use_marlin, dist_init
|
||||
):
|
||||
if is_checkpoint_fp8_serialized is False:
|
||||
pytest.skip("FP8 weight reloading does not support online quantization")
|
||||
|
||||
if method_cls is Fp8MoEMethod and weight_block_size is None:
|
||||
pytest.skip(
|
||||
"FP8 Tensor weight reloading does not support fusing w13_weight_scale. "
|
||||
"If this is your use case, consider using a restore function like #26327"
|
||||
)
|
||||
|
||||
with torch.device("cuda:0"):
|
||||
config = Fp8Config(
|
||||
is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
|
||||
weight_block_size=weight_block_size,
|
||||
)
|
||||
|
||||
if method_cls is Fp8LinearMethod:
|
||||
layer = torch.nn.Linear(1, 1)
|
||||
method = method_cls(config)
|
||||
method.create_weights(
|
||||
layer=layer,
|
||||
input_size_per_partition=1,
|
||||
output_partition_sizes=[1],
|
||||
input_size=1,
|
||||
output_size=1,
|
||||
params_dtype=torch.bfloat16,
|
||||
weight_loader=default_weight_loader,
|
||||
)
|
||||
|
||||
else:
|
||||
layer = FusedMoE(
|
||||
num_experts=1,
|
||||
top_k=1,
|
||||
hidden_size=1,
|
||||
intermediate_size=1,
|
||||
)
|
||||
method = method_cls(config, layer)
|
||||
method.create_weights(
|
||||
layer=layer,
|
||||
num_experts=1,
|
||||
hidden_size=1,
|
||||
intermediate_size_per_partition=1,
|
||||
params_dtype=torch.bfloat16,
|
||||
weight_loader=default_weight_loader,
|
||||
)
|
||||
|
||||
method.use_marlin = use_marlin
|
||||
|
||||
# capture weights format during loading
|
||||
original_metadata = [
|
||||
(name, param.shape, getattr(param, "weight_loader", default_weight_loader))
|
||||
for name, param in layer.named_parameters()
|
||||
]
|
||||
|
||||
# test loading
|
||||
for name, shape, _ in original_metadata:
|
||||
param = getattr(layer, name)
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, torch.zeros(shape)) # cannot use empty
|
||||
|
||||
method.process_weights_after_loading(layer)
|
||||
|
||||
# test reloading works after loading
|
||||
# assuming that no reshaping occurred
|
||||
for name, shape, original_weight_loader in original_metadata:
|
||||
param = getattr(layer, name)
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
assert weight_loader is original_weight_loader
|
||||
weight_loader(param, torch.zeros(shape)) # cannot use empty
|
||||
|
||||
method.process_weights_after_loading(layer)
|
||||
|
||||
@ -94,7 +94,7 @@ from vllm.model_executor.parameter import (
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter,
|
||||
)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
from vllm.utils.deep_gemm import (
|
||||
@ -548,46 +548,50 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
assert not self.act_q_static
|
||||
size_k_first = False
|
||||
|
||||
weight, weight_scale = process_fp8_weight_block_strategy(
|
||||
weight, weight_scale_inv = process_fp8_weight_block_strategy(
|
||||
layer.weight, layer.weight_scale_inv
|
||||
)
|
||||
# Delete the weight_scale_inv parameter to avoid confusion
|
||||
# with the weight_scale parameter
|
||||
del layer.weight_scale_inv
|
||||
|
||||
# Update layer with new values
|
||||
replace_parameter(layer, "weight", weight.data)
|
||||
replace_parameter(layer, "weight_scale_inv", weight_scale_inv.data)
|
||||
|
||||
# If checkpoint not serialized fp8, quantize the weights.
|
||||
elif not self.quant_config.is_checkpoint_fp8_serialized:
|
||||
qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
|
||||
weight = qweight.t()
|
||||
|
||||
# If checkpoint is fp8 per-tensor, handle that there are N scales for N
|
||||
# shards in a fused module
|
||||
else:
|
||||
weight = layer.weight
|
||||
weight_scale = layer.weight_scale
|
||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
||||
qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
|
||||
weight = qweight.t()
|
||||
|
||||
# If using w8a8, torch._scaled_mm needs per tensor, so
|
||||
# requantize the logical shards as a single weight.
|
||||
if not self.use_marlin:
|
||||
weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy(
|
||||
weight,
|
||||
weight_scale,
|
||||
layer.logical_widths,
|
||||
getattr(layer, "input_scale", None),
|
||||
)
|
||||
if self.act_q_static:
|
||||
assert input_scale is not None
|
||||
input_scale = input_scale.max()
|
||||
weight = weight.t()
|
||||
# If checkpoint is fp8 per-tensor, handle that there are N scales for N
|
||||
# shards in a fused module
|
||||
else:
|
||||
weight = layer.weight
|
||||
weight_scale = layer.weight_scale
|
||||
|
||||
# Update layer with new values.
|
||||
layer.weight = Parameter(weight.data, requires_grad=False)
|
||||
layer.weight_scale = Parameter(weight_scale.data, requires_grad=False)
|
||||
layer.input_scale = (
|
||||
Parameter(input_scale, requires_grad=False)
|
||||
if input_scale is not None
|
||||
else None
|
||||
)
|
||||
# If using w8a8, torch._scaled_mm needs per tensor, so
|
||||
# requantize the logical shards as a single weight.
|
||||
if not self.use_marlin:
|
||||
weight, weight_scale, input_scale = (
|
||||
process_fp8_weight_tensor_strategy(
|
||||
weight,
|
||||
weight_scale,
|
||||
layer.logical_widths,
|
||||
getattr(layer, "input_scale", None),
|
||||
)
|
||||
)
|
||||
if self.act_q_static:
|
||||
assert input_scale is not None
|
||||
input_scale = input_scale.max()
|
||||
weight = weight.t()
|
||||
|
||||
# Update layer with new values.
|
||||
replace_parameter(layer, "weight", weight.data)
|
||||
replace_parameter(layer, "weight_scale", weight_scale.data)
|
||||
|
||||
if input_scale is not None:
|
||||
replace_parameter(layer, "input_scale", input_scale)
|
||||
else:
|
||||
layer.input_scale = None
|
||||
|
||||
if self.use_marlin:
|
||||
prepare_fp8_layer_for_marlin(
|
||||
@ -614,7 +618,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
return self.w8a8_block_fp8_linear.apply(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
weight_scale=layer.weight_scale_inv,
|
||||
input_scale=layer.input_scale,
|
||||
bias=bias,
|
||||
)
|
||||
@ -643,10 +647,15 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
return torch.nn.functional.linear(x, weight_bf16.t(), bias)
|
||||
|
||||
if self.use_marlin:
|
||||
if self.block_quant:
|
||||
weight_scale = layer.weight_scale_inv
|
||||
else:
|
||||
weight_scale = layer.weight_scale
|
||||
|
||||
return apply_fp8_marlin_linear(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
weight_scale=weight_scale,
|
||||
workspace=layer.workspace,
|
||||
size_n=layer.output_size_per_partition,
|
||||
size_k=layer.input_size_per_partition,
|
||||
@ -660,7 +669,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
return self.w8a8_block_fp8_linear.apply(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
weight_scale=layer.weight_scale_inv,
|
||||
input_scale=layer.input_scale,
|
||||
bias=bias,
|
||||
)
|
||||
@ -937,22 +946,18 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
w2_weight_scale_inv = layer.w2_weight_scale_inv
|
||||
|
||||
# torch.compile() cannot use Parameter subclasses.
|
||||
layer.w13_weight = Parameter(w13_weight, requires_grad=False)
|
||||
layer.w13_weight_scale_inv = Parameter(
|
||||
w13_weight_scale_inv, requires_grad=False
|
||||
)
|
||||
layer.w2_weight = Parameter(w2_weight, requires_grad=False)
|
||||
layer.w2_weight_scale_inv = Parameter(
|
||||
w2_weight_scale_inv, requires_grad=False
|
||||
)
|
||||
replace_parameter(layer, "w13_weight", w13_weight)
|
||||
replace_parameter(layer, "w13_weight_scale_inv", w13_weight_scale_inv)
|
||||
replace_parameter(layer, "w2_weight", w2_weight)
|
||||
replace_parameter(layer, "w2_weight_scale_inv", w2_weight_scale_inv)
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
# reshaping weights is required for aiter moe kernel.
|
||||
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
|
||||
layer.w13_weight.data, layer.w2_weight.data
|
||||
)
|
||||
|
||||
layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
|
||||
replace_parameter(layer, "w13_weight", shuffled_w13)
|
||||
replace_parameter(layer, "w2_weight", shuffled_w2)
|
||||
|
||||
# DeepGemm scales need to be transposed and aligned. We try to do
|
||||
# it ahead of time for performance reasons.
|
||||
@ -990,13 +995,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
# Re-initialize w13_scale because we directly quantize
|
||||
# merged w13 weights and generate a single scaling factor.
|
||||
layer.w13_weight_scale = torch.nn.Parameter(
|
||||
replace_parameter(
|
||||
layer,
|
||||
"w13_weight_scale",
|
||||
torch.ones(
|
||||
layer.local_num_experts,
|
||||
dtype=torch.float32,
|
||||
device=w13_weight.device,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
for expert in range(layer.local_num_experts):
|
||||
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
|
||||
@ -1005,16 +1011,17 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
|
||||
ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
|
||||
)
|
||||
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
|
||||
replace_parameter(layer, "w13_weight", w13_weight)
|
||||
replace_parameter(layer, "w2_weight", w2_weight)
|
||||
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
# reshaping weights is required for aiter moe kernel.
|
||||
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
|
||||
layer.w13_weight, layer.w2_weight
|
||||
)
|
||||
|
||||
layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
|
||||
replace_parameter(layer, "w13_weight", shuffled_w13)
|
||||
replace_parameter(layer, "w2_weight", shuffled_w2)
|
||||
# If checkpoint is fp8, we need to handle that the
|
||||
# MoE kernels require single activation scale and single weight
|
||||
# scale for w13 per expert.
|
||||
@ -1035,12 +1042,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
"fp8 MoE layer. Using the maximum across experts "
|
||||
"for each layer."
|
||||
)
|
||||
layer.w13_input_scale = torch.nn.Parameter(
|
||||
layer.w13_input_scale.max(), requires_grad=False
|
||||
)
|
||||
layer.w2_input_scale = torch.nn.Parameter(
|
||||
layer.w2_input_scale.max(), requires_grad=False
|
||||
)
|
||||
replace_parameter(layer, "w13_input_scale", layer.w13_input_scale.max())
|
||||
replace_parameter(layer, "w2_input_scale", layer.w2_input_scale.max())
|
||||
if current_platform.is_fp8_fnuz():
|
||||
# Normalize the weights and scales
|
||||
w13_weight, w13_weight_scale, w13_input_scale = (
|
||||
@ -1054,22 +1057,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
)
|
||||
# Reset the parameter
|
||||
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
|
||||
layer.w13_weight_scale = torch.nn.Parameter(
|
||||
w13_weight_scale, requires_grad=False
|
||||
)
|
||||
replace_parameter(layer, "w13_weight", w13_weight)
|
||||
replace_parameter(layer, "w13_weight_scale", w13_weight_scale)
|
||||
if w13_input_scale is not None:
|
||||
layer.w13_input_scale = torch.nn.Parameter(
|
||||
w13_input_scale, requires_grad=False
|
||||
)
|
||||
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
|
||||
layer.w2_weight_scale = torch.nn.Parameter(
|
||||
w2_weight_scale, requires_grad=False
|
||||
)
|
||||
replace_parameter(layer, "w13_input_scale", w13_input_scale)
|
||||
replace_parameter(layer, "w2_weight", w2_weight)
|
||||
replace_parameter(layer, "w2_weight_scale", w2_weight_scale)
|
||||
if w2_input_scale is not None:
|
||||
layer.w2_input_scale = torch.nn.Parameter(
|
||||
w2_input_scale, requires_grad=False
|
||||
)
|
||||
replace_parameter(layer, "w2_input_scale", w2_input_scale)
|
||||
|
||||
# Fp8 moe kernel needs single weight scale for w13 per expert.
|
||||
# We take the max then dequant and requant each expert.
|
||||
@ -1093,12 +1088,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
layer.w13_weight, layer.w2_weight
|
||||
)
|
||||
|
||||
layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
|
||||
replace_parameter(layer, "w13_weight", shuffled_w13)
|
||||
replace_parameter(layer, "w2_weight", shuffled_w2)
|
||||
|
||||
layer.w13_weight_scale = torch.nn.Parameter(
|
||||
max_w13_scales, requires_grad=False
|
||||
)
|
||||
replace_parameter(layer, "w13_weight_scale", max_w13_scales)
|
||||
|
||||
if self.flashinfer_moe_backend is not None:
|
||||
# NOTE: weights have to be swapped since the activation is
|
||||
|
||||
@ -45,6 +45,13 @@ class BaseKVCacheMethod(QuantizeMethodBase):
|
||||
raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.")
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# skip if there are no weights to process (for example, weight reloading)
|
||||
if not hasattr(layer, "q_scale"):
|
||||
assert not hasattr(layer, "k_scale")
|
||||
assert not hasattr(layer, "v_scale")
|
||||
assert not hasattr(layer, "prob_scale")
|
||||
return
|
||||
|
||||
# If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0
|
||||
# regardless whether the kv-scale is available in the checkpoint.
|
||||
# No need to process kv scales after loading if we are going to
|
||||
|
||||
@ -27,6 +27,7 @@ from vllm.model_executor.parameter import (
|
||||
ChannelQuantScaleParameter,
|
||||
PerTensorScaleParameter,
|
||||
)
|
||||
from vllm.model_executor.utils import replace_parameter
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.deep_gemm import (
|
||||
@ -1404,12 +1405,12 @@ def maybe_post_process_fp8_weight_block(layer: torch.nn.Module):
|
||||
if should_use_deepgemm:
|
||||
dg_weight, dg_weight_scale = deepgemm_post_process_fp8_weight_block(
|
||||
wq=layer.weight.data,
|
||||
ws=layer.weight_scale.data,
|
||||
ws=layer.weight_scale_inv.data,
|
||||
quant_block_shape=tuple(layer.weight_block_size),
|
||||
use_e8m0=is_deep_gemm_e8m0_used(),
|
||||
)
|
||||
layer.weight = torch.nn.Parameter(dg_weight, requires_grad=False)
|
||||
layer.weight_scale = torch.nn.Parameter(dg_weight_scale, requires_grad=False)
|
||||
replace_parameter(layer, "weight", dg_weight)
|
||||
replace_parameter(layer, "weight_scale_inv", dg_weight_scale)
|
||||
|
||||
|
||||
def expert_weight_is_col_major(x: torch.Tensor) -> bool:
|
||||
|
||||
@ -14,6 +14,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
marlin_quant_input,
|
||||
should_use_atomic_add_reduce,
|
||||
)
|
||||
from vllm.model_executor.utils import replace_parameter
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
@ -130,7 +131,7 @@ def prepare_fp8_layer_for_marlin(
|
||||
size_n=part_size_n,
|
||||
num_bits=8,
|
||||
)
|
||||
layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)
|
||||
replace_parameter(layer, "weight", marlin_qweight)
|
||||
|
||||
# WEIGHT SCALES
|
||||
# Permute scales
|
||||
@ -138,7 +139,6 @@ def prepare_fp8_layer_for_marlin(
|
||||
scales = layer.weight_scale.to(layer.orig_dtype)
|
||||
elif "weight_scale_inv" in dir(layer):
|
||||
scales = layer.weight_scale_inv.to(layer.orig_dtype)
|
||||
del layer.weight_scale_inv
|
||||
|
||||
group_size = -1 if weight_block_size is None else weight_block_size[1]
|
||||
|
||||
@ -177,12 +177,15 @@ def prepare_fp8_layer_for_marlin(
|
||||
)
|
||||
if input_dtype != torch.float8_e4m3fn:
|
||||
marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales)
|
||||
layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)
|
||||
if hasattr(layer, "weight_scale"):
|
||||
replace_parameter(layer, "weight_scale", marlin_scales)
|
||||
elif hasattr(layer, "weight_scale_inv"):
|
||||
replace_parameter(layer, "weight_scale_inv", marlin_scales)
|
||||
|
||||
if hasattr(layer, "bias") and layer.bias is not None:
|
||||
assert layer.bias.shape == (part_size_n,)
|
||||
bias = marlin_permute_bias(layer.bias)
|
||||
layer.bias = torch.nn.Parameter(bias, requires_grad=False)
|
||||
replace_parameter(layer, "bias", bias)
|
||||
|
||||
|
||||
def prepare_moe_fp8_layer_for_marlin(
|
||||
|
||||
@ -118,8 +118,11 @@ def requantize_with_max_scale(
|
||||
# from disk in this case. Skip requantization in this case (since)
|
||||
# we already are quantized with the single scale.
|
||||
# * Sample Model: nm-testing/Phi-3-mini-128k-instruct-FP8
|
||||
#
|
||||
# Extra note: upon weight reloading weight_scale.ndim == 0
|
||||
unfused_module_in_checkpoint = (
|
||||
weight_scale[-1] > torch.finfo(torch.float8_e4m3fn).min
|
||||
weight_scale.ndim != 0
|
||||
and weight_scale[-1] > torch.finfo(torch.float8_e4m3fn).min
|
||||
)
|
||||
|
||||
# If unfused checkpoint, need requanize with the single scale.
|
||||
|
||||
@ -50,6 +50,31 @@ def set_weight_attrs(
|
||||
setattr(weight, key, value)
|
||||
|
||||
|
||||
def replace_parameter(layer: torch.nn.Module, param_name: str, new_data: torch.Tensor):
|
||||
"""
|
||||
Replace a parameter of a layer while maintaining the ability to reload the weight.
|
||||
Called within implementations of the `process_weights_after_loading` method.
|
||||
|
||||
This function should not be called on weights which are tied/shared
|
||||
|
||||
Args:
|
||||
layer: Layer containing parameter to replace
|
||||
param_name: Name of parameter to replace
|
||||
new_data: New data of the new parameter
|
||||
"""
|
||||
# should not be used on a tied/shared param
|
||||
if isinstance(new_data, torch.nn.Parameter):
|
||||
new_data = new_data.data
|
||||
new_param = torch.nn.Parameter(new_data, requires_grad=False)
|
||||
|
||||
old_param: torch.nn.Parameter | None = getattr(layer, param_name, None)
|
||||
if old_param is not None and hasattr(old_param, "weight_loader"):
|
||||
weight_loader = old_param.weight_loader
|
||||
set_weight_attrs(new_param, {"weight_loader": weight_loader})
|
||||
|
||||
setattr(layer, param_name, new_param)
|
||||
|
||||
|
||||
def get_packed_modules_mapping(model: torch.nn.Module) -> dict[str, list[str]]:
|
||||
parent_map = getattr(model, "packed_modules_mapping", None)
|
||||
parent_map = copy.deepcopy(parent_map) if parent_map is not None else {}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user