# ComfyUI-KJNodes/nodes/lora_nodes.py

import torch
import comfy.model_management
import comfy.utils
import comfy.lora
import folder_paths
import os
import logging
from tqdm import tqdm
import numpy as np
device = comfy.model_management.get_torch_device()
CLAMP_QUANTILE = 0.99
def _resolve_weight_from_patches(patches, key):
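    """Materialize the patched weight for `key` on CPU in fp32.

    The first entry of `patches` holds the stored base weight and its conversion
    function (e.g. for scaled fp8 storage); any remaining entries are patches
    applied through comfy.lora.calculate_weight.
    """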
    base_weight, convert_func = patches[0]
    weight_tensor = comfy.model_management.cast_to_device(
        base_weight, torch.device("cpu"), torch.float32, copy=True
    )
    try:
        weight_tensor = convert_func(weight_tensor, inplace=True)
    except TypeError:
        weight_tensor = convert_func(weight_tensor)
    if len(patches) > 1:
        weight_tensor = comfy.lora.calculate_weight(
            patches[1:],
            weight_tensor,
            key,
            intermediate_dtype=torch.float32,
            original_weights={key: patches},
        )
    return weight_tensor

def _build_scaled_fp8_diff(finetuned_model, original_model, prefix, bias_diff):
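    """Build an fp32 state-dict diff (finetuned - original) for keys shared by both models.

    Used instead of the ModelPatcher.add_patches path when either model stores
    scaled fp8 weights, so the subtraction happens on properly dequantized tensors.
    """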
    finetuned_patches = finetuned_model.get_key_patches(prefix)
    original_patches = original_model.get_key_patches(prefix)
    common_keys = set(finetuned_patches.keys()).intersection(original_patches.keys())

    diff_sd = {}
    for key in common_keys:
        is_weight = key.endswith(".weight")
        is_bias = key.endswith(".bias")
        if not is_weight and not (bias_diff and is_bias):
            continue
        ft_tensor = _resolve_weight_from_patches(finetuned_patches[key], key)
        orig_tensor = _resolve_weight_from_patches(original_patches[key], key)
        diff_sd[key] = ft_tensor.sub(orig_tensor)
    return diff_sd

def extract_lora(diff, key, rank, algorithm, lora_type, lowrank_iters=7, adaptive_param=1.0, clamp_quantile=True):
"""
Extracts LoRA weights from a weight difference tensor using SVD.
"""
conv2d = (len(diff.shape) == 4)
kernel_size = None if not conv2d else diff.size()[2:4]
conv2d_3x3 = conv2d and kernel_size != (1, 1)
out_dim, in_dim = diff.size()[0:2]
if conv2d:
if conv2d_3x3:
diff = diff.flatten(start_dim=1)
else:
diff = diff.squeeze()
diff_float = diff.float()
    if algorithm == "svd_lowrank":
        U, S, V = torch.svd_lowrank(diff_float, q=min(rank, in_dim, out_dim), niter=lowrank_iters)
        U = U @ torch.diag(S)
        Vh = V.t()
        # svd_lowrank already returns at most q components; record that rank so the
        # conv reshape at the end of this function also works on this code path.
        lora_rank = min(rank, in_dim, out_dim)
    else:
        #torch.linalg.svdvals()
        U, S, Vh = torch.linalg.svd(diff_float)

        # Flexible rank selection logic like locon: https://github.com/KohakuBlueleaf/LyCORIS/blob/main/tools/extract_locon.py
        if "adaptive" in lora_type:
            if lora_type == "adaptive_ratio":
                min_s = torch.max(S) * adaptive_param
                lora_rank = torch.sum(S > min_s).item()
            elif lora_type == "adaptive_energy":
                energy = torch.cumsum(S**2, dim=0)
                total_energy = torch.sum(S**2)
                threshold = adaptive_param * total_energy  # e.g., adaptive_param=0.95 for 95%
                lora_rank = torch.sum(energy < threshold).item() + 1
            elif lora_type == "adaptive_quantile":
                s_cum = torch.cumsum(S, dim=0)
                min_cum_sum = adaptive_param * torch.sum(S)
                lora_rank = torch.sum(s_cum < min_cum_sum).item()
            elif lora_type == "adaptive_fro":
                S_squared = S.pow(2)
                S_fro_sq = float(torch.sum(S_squared))
                sum_S_squared = torch.cumsum(S_squared, dim=0) / S_fro_sq
                lora_rank = int(torch.searchsorted(sum_S_squared, adaptive_param**2)) + 1
                lora_rank = max(1, min(lora_rank, len(S)))
            else:
                # Unknown adaptive type: fall back to the requested max rank; printing happens after capping
                lora_rank = rank

            # Cap adaptive rank by the specified max rank
            lora_rank = min(lora_rank, rank)

            # Calculate and print actual fro percentage retained after capping
            if lora_type == "adaptive_fro":
                S_squared = S.pow(2)
                s_fro = torch.sqrt(torch.sum(S_squared))
                s_red_fro = torch.sqrt(torch.sum(S_squared[:lora_rank]))
                fro_percent = float(s_red_fro / s_fro)
                print(f"{key} Extracted LoRA rank: {lora_rank}, Frobenius retained: {fro_percent:.1%}")
            else:
                print(f"{key} Extracted LoRA rank: {lora_rank}")
        else:
            lora_rank = rank

        lora_rank = max(1, lora_rank)
        lora_rank = min(out_dim, in_dim, lora_rank)

        U = U[:, :lora_rank]
        S = S[:lora_rank]
        U = U @ torch.diag(S)
        Vh = Vh[:lora_rank, :]

    if clamp_quantile:
        dist = torch.cat([U.flatten(), Vh.flatten()])
        if dist.numel() > 100_000:
            # Sample 100,000 elements for quantile estimation
            idx = torch.randperm(dist.numel(), device=dist.device)[:100_000]
            dist_sample = dist[idx]
            hi_val = torch.quantile(dist_sample, CLAMP_QUANTILE)
        else:
            hi_val = torch.quantile(dist, CLAMP_QUANTILE)
        low_val = -hi_val

        U = U.clamp(low_val, hi_val)
        Vh = Vh.clamp(low_val, hi_val)

    if conv2d:
        U = U.reshape(out_dim, lora_rank, 1, 1)
        Vh = Vh.reshape(lora_rank, in_dim, kernel_size[0], kernel_size[1])
    return (U, Vh)

def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora_type, algorithm, lowrank_iters, out_dtype, bias_diff=False, adaptive_param=1.0, clamp_quantile=True, sd_override=None):
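    """Run SVD extraction over every matching key and collect the results in output_sd.

    `model_diff` is a ModelPatcher whose patches encode (finetuned - original);
    alternatively a precomputed diff state dict can be passed via `sd_override`.
    Keys are written as `{prefix_lora}{name}.lora_up.weight` / `.lora_down.weight`,
    or `.diff` (full diffs and 1D weights) and `.diff_b` (biases).
    """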
    if sd_override is None:
        comfy.model_management.load_models_gpu([model_diff], force_patch_weights=True)
        model_diff.model.diffusion_model.cpu()
        sd = model_diff.model_state_dict(filter_prefix=prefix_model)
        del model_diff
        comfy.model_management.soft_empty_cache()
        for k, v in sd.items():
            if isinstance(v, torch.Tensor):
                sd[k] = v.cpu()
    else:
        sd = sd_override

    # Get total number of keys to process for progress bar
    total_keys = len([k for k in sd if k.endswith(".weight") or (bias_diff and k.endswith(".bias"))])

    # Create progress bar
    progress_bar = tqdm(total=total_keys, desc=f"Extracting LoRA ({prefix_lora.strip('.')})")
    comfy_pbar = comfy.utils.ProgressBar(total_keys)

    for k in sd:
        if k.endswith(".weight"):
            weight_diff = sd[k]
            if weight_diff.ndim == 5:
                logging.info(f"Skipping 5D tensor for key {k}") #skip patch embed
                progress_bar.update(1)
                comfy_pbar.update(1)
                continue
            if lora_type != "full":
                if weight_diff.ndim < 2:
                    if bias_diff:
                        output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().to(out_dtype).cpu()
                    progress_bar.update(1)
                    comfy_pbar.update(1)
                    continue
                try:
                    out = extract_lora(weight_diff.to(device), k, rank, algorithm, lora_type, lowrank_iters=lowrank_iters, adaptive_param=adaptive_param, clamp_quantile=clamp_quantile)
                    output_sd["{}{}.lora_up.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[0].contiguous().to(out_dtype).cpu()
                    output_sd["{}{}.lora_down.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[1].contiguous().to(out_dtype).cpu()
                except Exception as e:
                    logging.warning(f"Could not generate lora weights for key {k}, error {e}")
            else:
                output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().to(out_dtype).cpu()
            progress_bar.update(1)
            comfy_pbar.update(1)
        elif bias_diff and k.endswith(".bias"):
            output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().to(out_dtype).cpu()
            progress_bar.update(1)
            comfy_pbar.update(1)

    progress_bar.close()
    return output_sd

class LoraExtractKJ:
    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()

    @classmethod
    def INPUT_TYPES(s):
        return {"required":
            {
                "finetuned_model": ("MODEL",),
                "original_model": ("MODEL",),
                "filename_prefix": ("STRING", {"default": "loras/ComfyUI_extracted_lora"}),
                "rank": ("INT", {"default": 8, "min": 1, "max": 4096, "step": 1, "tooltip": "The rank to use for standard LoRA, or maximum rank limit for adaptive methods."}),
                "lora_type": (["standard", "full", "adaptive_ratio", "adaptive_quantile", "adaptive_energy", "adaptive_fro"],),
                "algorithm": (["svd_linalg", "svd_lowrank"], {"default": "svd_linalg", "tooltip": "SVD algorithm to use, svd_lowrank is faster but less accurate."}),
                "lowrank_iters": ("INT", {"default": 7, "min": 1, "max": 100, "step": 1, "tooltip": "The number of subspace iterations for lowrank SVD algorithm."}),
                "output_dtype": (["fp16", "bf16", "fp32"], {"default": "fp16"}),
                "bias_diff": ("BOOLEAN", {"default": True}),
                "adaptive_param": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "For ratio mode, this is the ratio of the maximum singular value. For quantile mode, this is the quantile of the singular values. For fro mode, this is the Frobenius norm retention ratio."}),
                "clamp_quantile": ("BOOLEAN", {"default": True}),
            },
        }

    RETURN_TYPES = ()
    FUNCTION = "save"
    OUTPUT_NODE = True
    CATEGORY = "KJNodes/lora"

    def save(self, finetuned_model, original_model, filename_prefix, rank, lora_type, algorithm, lowrank_iters, output_dtype, bias_diff, adaptive_param, clamp_quantile):
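        """Diff the two models (directly, or via a dequantized state-dict diff when scaled fp8
        weights are detected), extract a LoRA from the difference, and save it as .safetensors
        in the ComfyUI output folder."""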
        if algorithm == "svd_lowrank" and lora_type != "standard":
            raise ValueError("svd_lowrank algorithm is only supported for standard LoRA extraction.")

        dtype = {"fp8_e4m3fn": torch.float8_e4m3fn, "bf16": torch.bfloat16, "fp16": torch.float16, "fp16_fast": torch.float16, "fp32": torch.float32}[output_dtype]

        model_diff = None
        sd_override = None

        scaled_fp8_ft = getattr(getattr(finetuned_model.model, "model_config", None), "scaled_fp8", None)
        scaled_fp8_orig = getattr(getattr(original_model.model, "model_config", None), "scaled_fp8", None)
        scaled_fp8_present = scaled_fp8_ft is not None or scaled_fp8_orig is not None

        if scaled_fp8_present:
            comfy.model_management.load_models_gpu([finetuned_model, original_model], force_patch_weights=True)
            logging.info(
                "LoraExtractKJ: detected scaled fp8 weights (finetuned=%s, original=%s); using high-precision diff path.",
                scaled_fp8_ft is not None,
                scaled_fp8_orig is not None,
            )
            sd_override = _build_scaled_fp8_diff(
                finetuned_model, original_model, "diffusion_model.", bias_diff
            )
            comfy.model_management.soft_empty_cache()
        else:
            m = finetuned_model.clone()
            kp = original_model.get_key_patches("diffusion_model.")
            for k in kp:
                m.add_patches({k: kp[k]}, -1.0, 1.0)
            model_diff = m

        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)

        output_sd = {}
        if model_diff is not None:
            output_sd = calc_lora_model(model_diff, rank, "diffusion_model.", "diffusion_model.", output_sd, lora_type, algorithm, lowrank_iters, dtype, bias_diff=bias_diff, adaptive_param=adaptive_param, clamp_quantile=clamp_quantile)
        elif sd_override is not None:
            output_sd = calc_lora_model(None, rank, "diffusion_model.", "diffusion_model.", output_sd, lora_type, algorithm, lowrank_iters, dtype, bias_diff=bias_diff, adaptive_param=adaptive_param, clamp_quantile=clamp_quantile, sd_override=sd_override)

        if "adaptive" in lora_type:
            rank_str = f"{lora_type}_{adaptive_param:.2f}"
        else:
            rank_str = rank

        output_checkpoint = f"{filename}_rank_{rank_str}_{output_dtype}_{counter:05}_.safetensors"
        output_checkpoint = os.path.join(full_output_folder, output_checkpoint)

        comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=None)
        return {}

class LoraReduceRank:
    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()

    @classmethod
    def INPUT_TYPES(s):
        return {"required":
            {
                "lora_name": (folder_paths.get_filename_list("loras"), {"tooltip": "The name of the LoRA."}),
                "new_rank": ("INT", {"default": 8, "min": 1, "max": 4096, "step": 1, "tooltip": "The new rank to resize the LoRA. Acts as max rank when using dynamic_method."}),
                "dynamic_method": (["disabled", "sv_ratio", "sv_cumulative", "sv_fro"], {"default": "disabled", "tooltip": "Method to use for dynamically determining new alphas and dims"}),
                "dynamic_param": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Parameter for the selected dynamic_method: singular value ratio (sv_ratio), cumulative sum fraction (sv_cumulative), or Frobenius norm retention (sv_fro)."}),
                "output_dtype": (["match_original", "fp16", "bf16", "fp32"], {"default": "match_original", "tooltip": "Data type to save the LoRA as."}),
                "verbose": ("BOOLEAN", {"default": True}),
            },
        }

    RETURN_TYPES = ()
    FUNCTION = "save"
    OUTPUT_NODE = True
    EXPERIMENTAL = True
    DESCRIPTION = "Resize a LoRA model by reducing its rank. Based on kohya's sd-scripts: https://github.com/kohya-ss/sd-scripts/blob/main/networks/resize_lora.py"
    CATEGORY = "KJNodes/lora"

    def save(self, lora_name, new_rank, output_dtype, dynamic_method, dynamic_param, verbose):
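        """Load the LoRA, merge each down/up pair back to a full matrix, re-extract it at the new
        (or dynamically chosen) rank, and save the resized LoRA with updated kohya metadata."""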
        lora_path = folder_paths.get_full_path("loras", lora_name)
        lora_sd, metadata = comfy.utils.load_torch_file(lora_path, return_metadata=True)

        if output_dtype == "fp16":
            save_dtype = torch.float16
        elif output_dtype == "bf16":
            save_dtype = torch.bfloat16
        elif output_dtype == "fp32":
            save_dtype = torch.float32
        elif output_dtype == "match_original":
            first_weight_key = next(k for k in lora_sd if k.endswith(".weight") and isinstance(lora_sd[k], torch.Tensor))
            save_dtype = lora_sd[first_weight_key].dtype

        new_lora_sd = {}
        for k, v in lora_sd.items():
            new_lora_sd[k.replace(".default", "")] = v
        del lora_sd

        print("Resizing Lora...")
        output_sd, old_dim, new_alpha, rank_list = resize_lora_model(new_lora_sd, new_rank, save_dtype, device, dynamic_method, dynamic_param, verbose)

        # update metadata
        if metadata is None:
            metadata = {}
        comment = metadata.get("ss_training_comment", "")

        if dynamic_method == "disabled":
            metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {new_rank}; {comment}"
            metadata["ss_network_dim"] = str(new_rank)
            metadata["ss_network_alpha"] = str(new_alpha)
        else:
            metadata["ss_training_comment"] = f"Dynamic resize with {dynamic_method}: {dynamic_param} from {old_dim}; {comment}"
            metadata["ss_network_dim"] = "Dynamic"
            metadata["ss_network_alpha"] = "Dynamic"

        # cast to save_dtype before calculating hashes
        for key in list(output_sd.keys()):
            value = output_sd[key]
            if type(value) == torch.Tensor and value.dtype.is_floating_point and value.dtype != save_dtype:
                output_sd[key] = value.to(save_dtype)

        output_filename_prefix = "loras/" + lora_name
        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(output_filename_prefix, self.output_dir)

        output_dtype_str = f"_{output_dtype}" if output_dtype != "match_original" else ""
        average_rank = str(int(np.mean(rank_list)))
        rank_str = new_rank if dynamic_method == "disabled" else f"dynamic_{average_rank}"
        output_checkpoint = f"{filename.replace('.safetensors', '')}_resized_from_{old_dim}_to_{rank_str}{output_dtype_str}_{counter:05}_.safetensors"
        output_checkpoint = os.path.join(full_output_folder, output_checkpoint)

        print(f"Saving resized LoRA to {output_checkpoint}")
        comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=metadata)
        return {}

NODE_CLASS_MAPPINGS = {
    "LoraExtractKJ": LoraExtractKJ,
    "LoraReduceRank": LoraReduceRank,
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "LoraExtractKJ": "LoraExtractKJ",
    "LoraReduceRank": "LoraReduceRank",
}
# Convert LoRA to different rank approximation (should only be used to go to lower rank)
# This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
# Thanks to cloneofsimo
# This version is based on
# https://github.com/kohya-ss/sd-scripts/blob/main/networks/resize_lora.py
MIN_SV = 1e-6
LORA_DOWN_UP_FORMATS = [
("lora_down", "lora_up"), # sd-scripts LoRA
("lora_A", "lora_B"), # PEFT LoRA
("down", "up"), # ControlLoRA
]
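# Example key layouts these formats are meant to match (illustrative only, not taken from a specific file):
#   sd-scripts:  "<block_name>.lora_down.weight" / "<block_name>.lora_up.weight" plus "<block_name>.alpha"
#   PEFT:        "<block_name>.lora_A.weight" / "<block_name>.lora_B.weight"
#   ControlLoRA: keys ending in ".down" / ".up" with no trailing ".weight"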
# Indexing functions
def index_sv_cumulative(S, target):
    original_sum = float(torch.sum(S))
    cumulative_sums = torch.cumsum(S, dim=0) / original_sum
    index = int(torch.searchsorted(cumulative_sums, target)) + 1
    index = max(1, min(index, len(S) - 1))
    return index

def index_sv_fro(S, target):
    S_squared = S.pow(2)
    S_fro_sq = float(torch.sum(S_squared))
    sum_S_squared = torch.cumsum(S_squared, dim=0) / S_fro_sq
    index = int(torch.searchsorted(sum_S_squared, target**2)) + 1
    index = max(1, min(index, len(S) - 1))
    return index

def index_sv_ratio(S, target):
    max_sv = S[0]
    min_sv = max_sv / target
    index = int(torch.sum(S > min_sv).item())
    index = max(1, min(index, len(S) - 1))
    return index

# Modified from Kohaku-blueleaf's extract/merge functions
def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
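    """SVD-decompose a merged conv weight into a (possibly lower rank) down/up pair.

    The kernel is folded into the "down" tensor (rank, in, k, k) while the "up" tensor
    stays 1x1 (out, rank, 1, 1), matching the layout merge_conv expects.
    """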
    out_size, in_size, kernel_size, _ = weight.size()

    if weight.dtype != torch.float32:
        weight = weight.to(torch.float32)

    U, S, Vh = torch.linalg.svd(weight.reshape(out_size, -1).to(device))

    param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
    lora_rank = param_dict["new_rank"]

    U = U[:, :lora_rank]
    S = S[:lora_rank]
    U = U @ torch.diag(S)
    Vh = Vh[:lora_rank, :]

    param_dict["lora_down"] = Vh.reshape(lora_rank, in_size, kernel_size, kernel_size).cpu()
    param_dict["lora_up"] = U.reshape(out_size, lora_rank, 1, 1).cpu()

    del U, S, Vh, weight
    return param_dict

def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
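    """SVD-decompose a merged linear weight into a (possibly lower rank) down/up pair."""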
    out_size, in_size = weight.size()

    if weight.dtype != torch.float32:
        weight = weight.to(torch.float32)

    U, S, Vh = torch.linalg.svd(weight.to(device))

    param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
    lora_rank = param_dict["new_rank"]

    U = U[:, :lora_rank]
    S = S[:lora_rank]
    U = U @ torch.diag(S)
    Vh = Vh[:lora_rank, :]

    param_dict["lora_down"] = Vh.reshape(lora_rank, in_size).cpu()
    param_dict["lora_up"] = U.reshape(out_size, lora_rank).cpu()

    del U, S, Vh, weight
    return param_dict

def merge_conv(lora_down, lora_up, device):
    in_rank, in_size, kernel_size, k_ = lora_down.shape
    out_size, out_rank, _, _ = lora_up.shape
    assert in_rank == out_rank and kernel_size == k_, f"rank {in_rank} {out_rank} or kernel {kernel_size} {k_} mismatch"

    lora_down = lora_down.to(device)
    lora_up = lora_up.to(device)

    merged = lora_up.reshape(out_size, -1) @ lora_down.reshape(in_rank, -1)
    weight = merged.reshape(out_size, in_size, kernel_size, kernel_size)
    del lora_up, lora_down
    return weight

def merge_linear(lora_down, lora_up, device):
    in_rank, in_size = lora_down.shape
    out_size, out_rank = lora_up.shape
    assert in_rank == out_rank, f"rank {in_rank} {out_rank} mismatch"

    lora_down = lora_down.to(device)
    lora_up = lora_up.to(device)

    weight = lora_up @ lora_down
    del lora_up, lora_down
    return weight

# Calculate new rank
def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
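    """Pick the new rank and alpha for one layer and report how much singular value mass it keeps.

    `dynamic_method` selects the criterion (sv_ratio / sv_cumulative / sv_fro) with `rank` as the
    upper bound; new_alpha is scale * new_rank. Returns a dict with new_rank, new_alpha,
    sum_retained, fro_retained and max_ratio.
    """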
    param_dict = {}

    if dynamic_method == "sv_ratio":
        # Calculate new dim and alpha based off ratio
        new_rank = index_sv_ratio(S, dynamic_param) + 1
        new_alpha = float(scale * new_rank)
    elif dynamic_method == "sv_cumulative":
        # Calculate new dim and alpha based off cumulative sum
        new_rank = index_sv_cumulative(S, dynamic_param) + 1
        new_alpha = float(scale * new_rank)
    elif dynamic_method == "sv_fro":
        # Calculate new dim and alpha based off sqrt sum of squares
        new_rank = index_sv_fro(S, dynamic_param) + 1
        new_alpha = float(scale * new_rank)
    else:
        new_rank = rank
        new_alpha = float(scale * new_rank)

    if S[0] <= MIN_SV:  # Zero matrix, set dim to 1
        new_rank = 1
        new_alpha = float(scale * new_rank)
    elif new_rank > rank:  # cap max rank at rank
        new_rank = rank
        new_alpha = float(scale * new_rank)

    # Calculate resize info
    s_sum = torch.sum(torch.abs(S))
    s_rank = torch.sum(torch.abs(S[:new_rank]))

    S_squared = S.pow(2)
    s_fro = torch.sqrt(torch.sum(S_squared))
    s_red_fro = torch.sqrt(torch.sum(S_squared[:new_rank]))
    fro_percent = float(s_red_fro / s_fro)

    param_dict["new_rank"] = new_rank
    param_dict["new_alpha"] = new_alpha
    param_dict["sum_retained"] = (s_rank) / s_sum
    param_dict["fro_retained"] = fro_percent
    param_dict["max_ratio"] = S[0] / S[new_rank - 1]

    return param_dict

def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
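    """Iterate over all down/up pairs in `lora_sd`, merge each pair, re-extract it at the target
    rank, and return (new state dict, largest original rank seen, last new alpha, list of new ranks)."""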
    max_old_rank = None
    new_alpha = None
    verbose_str = "\n"
    fro_list = []
    rank_list = []

    if dynamic_method != "disabled":
        print(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}")

    lora_down_weight = None
    lora_up_weight = None

    o_lora_sd = lora_sd.copy()
    block_down_name = None
    block_up_name = None

    total_keys = len([k for k in lora_sd if k.endswith(".weight")])
    pbar = comfy.utils.ProgressBar(total_keys)

    for key, value in tqdm(lora_sd.items(), leave=True, desc="Resizing LoRA weights"):
        key_parts = key.split(".")

        block_down_name = None
        for _format in LORA_DOWN_UP_FORMATS:
            # Currently we only match lora_down_name in the last two parts of key
            # because ("down", "up") are general words and may appear in block_down_name
            if len(key_parts) >= 2 and _format[0] == key_parts[-2]:
                block_down_name = ".".join(key_parts[:-2])
                lora_down_name = "." + _format[0]
                lora_up_name = "." + _format[1]
                weight_name = "." + key_parts[-1]
                break
            if len(key_parts) >= 1 and _format[0] == key_parts[-1]:
                block_down_name = ".".join(key_parts[:-1])
                lora_down_name = "." + _format[0]
                lora_up_name = "." + _format[1]
                weight_name = ""
                break
        if block_down_name is None:
            # This parameter is not lora_down
            continue

        # Now weight_name can be ".weight" or ""
        # Find corresponding lora_up and alpha
        block_up_name = block_down_name
        lora_down_weight = value
        lora_up_weight = lora_sd.get(block_up_name + lora_up_name + weight_name, None)
        lora_alpha = lora_sd.get(block_down_name + ".alpha", None)

        weights_loaded = lora_down_weight is not None and lora_up_weight is not None

        if weights_loaded:
            conv2d = len(lora_down_weight.size()) == 4
            old_rank = lora_down_weight.size()[0]
            max_old_rank = max(max_old_rank or 0, old_rank)

            if lora_alpha is None:
                scale = 1.0
            else:
                scale = lora_alpha / old_rank

            if conv2d:
                full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
                param_dict = extract_conv(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
            else:
                full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device)
                param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)

            if verbose:
                max_ratio = param_dict["max_ratio"]
                sum_retained = param_dict["sum_retained"]
                fro_retained = param_dict["fro_retained"]
                if not np.isnan(fro_retained):
                    fro_list.append(float(fro_retained))

                log_str = f"{block_down_name:75} | sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}, new dim: {param_dict['new_rank']}"
                tqdm.write(log_str)
                verbose_str += log_str

            if verbose and dynamic_method != "disabled":
                verbose_str += f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n"
            else:
                verbose_str += "\n"

            new_alpha = param_dict["new_alpha"]

            o_lora_sd[block_down_name + lora_down_name + weight_name] = param_dict["lora_down"].to(save_dtype).contiguous()
            o_lora_sd[block_up_name + lora_up_name + weight_name] = param_dict["lora_up"].to(save_dtype).contiguous()
            o_lora_sd[block_down_name + ".alpha"] = torch.tensor(param_dict["new_alpha"]).to(save_dtype)

            block_down_name = None
            block_up_name = None
            lora_down_weight = None
            lora_up_weight = None
            weights_loaded = False
            rank_list.append(param_dict["new_rank"])
            del param_dict
            pbar.update(1)

    if verbose:
        print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
    return o_lora_sd, max_old_rank, new_alpha, rank_list