import torch
import comfy.model_management
import comfy.utils
import comfy.lora
import folder_paths
import os
import logging
from tqdm import tqdm
import numpy as np

device = comfy.model_management.get_torch_device()

# Quantile used to clamp outliers in the extracted LoRA factors.
CLAMP_QUANTILE = 0.99


def _resolve_weight_from_patches(patches, key):
    """Materialize the effective float32 weight for a key from its ComfyUI patch list."""
    base_weight, convert_func = patches[0]
    weight_tensor = comfy.model_management.cast_to_device(
        base_weight, torch.device("cpu"), torch.float32, copy=True
    )
    # Older convert functions do not accept the inplace keyword.
    try:
        weight_tensor = convert_func(weight_tensor, inplace=True)
    except TypeError:
        weight_tensor = convert_func(weight_tensor)

    if len(patches) > 1:
        weight_tensor = comfy.lora.calculate_weight(
            patches[1:],
            weight_tensor,
            key,
            intermediate_dtype=torch.float32,
            original_weights={key: patches},
        )

    return weight_tensor
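
# Note on the patch-list layout assumed above (a sketch of the ComfyUI
# convention, not an authoritative spec): patches[0] holds the stored base
# weight together with a conversion function (e.g. dequantizing scaled fp8),
# and any remaining entries are weight patches (LoRA, diffs, ...) that
# comfy.lora.calculate_weight applies on top of the converted base weight.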


def _build_scaled_fp8_diff(finetuned_model, original_model, prefix, bias_diff):
    """Compute finetuned-minus-original weight diffs in float32 for scaled fp8 models."""
    finetuned_patches = finetuned_model.get_key_patches(prefix)
    original_patches = original_model.get_key_patches(prefix)

    common_keys = set(finetuned_patches.keys()).intersection(original_patches.keys())
    diff_sd = {}

    for key in common_keys:
        is_weight = key.endswith(".weight")
        is_bias = key.endswith(".bias")

        if not is_weight and not (bias_diff and is_bias):
            continue

        ft_tensor = _resolve_weight_from_patches(finetuned_patches[key], key)
        orig_tensor = _resolve_weight_from_patches(original_patches[key], key)

        diff_sd[key] = ft_tensor.sub(orig_tensor)

    return diff_sd
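
# Why a dedicated path: scaled fp8 checkpoints carry per-tensor scales, so
# subtracting raw fp8 weights would compound quantization error. Resolving
# both models to float32 first keeps the diff numerically meaningful before
# it is handed to the SVD.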


def extract_lora(diff, key, rank, algorithm, lora_type, lowrank_iters=7, adaptive_param=1.0, clamp_quantile=True):
    """
    Extracts LoRA weights from a weight difference tensor using SVD.
    Returns (lora_up, lora_down) such that diff ≈ lora_up @ lora_down.
    """
    conv2d = (len(diff.shape) == 4)
    kernel_size = None if not conv2d else diff.size()[2:4]
    conv2d_3x3 = conv2d and kernel_size != (1, 1)
    out_dim, in_dim = diff.size()[0:2]

    if conv2d:
        if conv2d_3x3:
            diff = diff.flatten(start_dim=1)
        else:
            diff = diff.squeeze()

    diff_float = diff.float()
    if algorithm == "svd_lowrank":
        # Randomized SVD already truncates to the requested rank, so fold the
        # singular values into U here and skip the slicing below.
        lora_rank = min(rank, in_dim, out_dim)
        U, S, V = torch.svd_lowrank(diff_float, q=lora_rank, niter=lowrank_iters)
        U = U @ torch.diag(S)
        Vh = V.t()
    else:
        U, S, Vh = torch.linalg.svd(diff_float)

        # Flexible rank selection logic like locon: https://github.com/KohakuBlueleaf/LyCORIS/blob/main/tools/extract_locon.py
        if "adaptive" in lora_type:
            if lora_type == "adaptive_ratio":
                # Keep singular values above a fraction of the largest one.
                min_s = torch.max(S) * adaptive_param
                lora_rank = torch.sum(S > min_s).item()
            elif lora_type == "adaptive_energy":
                # Keep enough singular values to retain the requested share of the energy (sum of squares).
                energy = torch.cumsum(S**2, dim=0)
                total_energy = torch.sum(S**2)
                threshold = adaptive_param * total_energy  # e.g., adaptive_param=0.95 for 95%
                lora_rank = torch.sum(energy < threshold).item() + 1
            elif lora_type == "adaptive_quantile":
                # Keep enough singular values to cover the requested share of their sum.
                s_cum = torch.cumsum(S, dim=0)
                min_cum_sum = adaptive_param * torch.sum(S)
                lora_rank = torch.sum(s_cum < min_cum_sum).item()
            elif lora_type == "adaptive_fro":
                # Keep enough singular values to retain the requested share of the Frobenius norm.
                S_squared = S.pow(2)
                S_fro_sq = float(torch.sum(S_squared))
                sum_S_squared = torch.cumsum(S_squared, dim=0) / S_fro_sq
                lora_rank = int(torch.searchsorted(sum_S_squared, adaptive_param**2)) + 1
                lora_rank = max(1, min(lora_rank, len(S)))

            # Cap the adaptive rank by the specified max rank.
            lora_rank = min(lora_rank, rank)

            # Report the Frobenius norm actually retained after capping.
            if lora_type == "adaptive_fro":
                S_squared = S.pow(2)
                s_fro = torch.sqrt(torch.sum(S_squared))
                s_red_fro = torch.sqrt(torch.sum(S_squared[:lora_rank]))
                fro_percent = float(s_red_fro / s_fro)
                print(f"{key} Extracted LoRA rank: {lora_rank}, Frobenius retained: {fro_percent:.1%}")
            else:
                print(f"{key} Extracted LoRA rank: {lora_rank}")
        else:
            lora_rank = rank

        lora_rank = max(1, lora_rank)
        lora_rank = min(out_dim, in_dim, lora_rank)

        U = U[:, :lora_rank]
        S = S[:lora_rank]
        U = U @ torch.diag(S)
        Vh = Vh[:lora_rank, :]

    if clamp_quantile:
        dist = torch.cat([U.flatten(), Vh.flatten()])
        if dist.numel() > 100_000:
            # Sample 100,000 elements for quantile estimation
            idx = torch.randperm(dist.numel(), device=dist.device)[:100_000]
            dist_sample = dist[idx]
            hi_val = torch.quantile(dist_sample, CLAMP_QUANTILE)
        else:
            hi_val = torch.quantile(dist, CLAMP_QUANTILE)
        low_val = -hi_val

        U = U.clamp(low_val, hi_val)
        Vh = Vh.clamp(low_val, hi_val)
    if conv2d:
        U = U.reshape(out_dim, lora_rank, 1, 1)
        Vh = Vh.reshape(lora_rank, in_dim, kernel_size[0], kernel_size[1])
    return (U, Vh)
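
# Illustrative sketch (hypothetical shapes and names, not executed by the node):
# for a 2-D weight diff W, extract_lora returns factors with W ≈ up @ down.
#
#   W = torch.randn(64, 32, device=device)
#   up, down = extract_lora(W, "demo.weight", rank=8,
#                           algorithm="svd_linalg", lora_type="standard")
#   rel_err = (W - up @ down).norm() / W.norm()
#
# rel_err shrinks toward zero as rank approaches min(out_dim, in_dim); the
# quantile clamp trades a little reconstruction accuracy for fewer outliers.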


def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora_type, algorithm, lowrank_iters, out_dtype, bias_diff=False, adaptive_param=1.0, clamp_quantile=True, sd_override=None):
    if sd_override is None:
        comfy.model_management.load_models_gpu([model_diff], force_patch_weights=True)
        model_diff.model.diffusion_model.cpu()
        sd = model_diff.model_state_dict(filter_prefix=prefix_model)
        del model_diff
        comfy.model_management.soft_empty_cache()
        for k, v in sd.items():
            if isinstance(v, torch.Tensor):
                sd[k] = v.cpu()
    else:
        sd = sd_override

    # Get total number of keys to process for progress bar
    total_keys = len([k for k in sd if k.endswith(".weight") or (bias_diff and k.endswith(".bias"))])

    # Create progress bars (terminal and ComfyUI frontend)
    progress_bar = tqdm(total=total_keys, desc=f"Extracting LoRA ({prefix_lora.strip('.')})")
    comfy_pbar = comfy.utils.ProgressBar(total_keys)

    for k in sd:
        if k.endswith(".weight"):
            weight_diff = sd[k]
            if weight_diff.ndim == 5:
                logging.info(f"Skipping 5D tensor for key {k}")  # skip patch embed
                progress_bar.update(1)
                comfy_pbar.update(1)
                continue
            if lora_type != "full":
                if weight_diff.ndim < 2:
                    if bias_diff:
                        output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().to(out_dtype).cpu()
                    progress_bar.update(1)
                    comfy_pbar.update(1)
                    continue
                try:
                    out = extract_lora(weight_diff.to(device), k, rank, algorithm, lora_type, lowrank_iters=lowrank_iters, adaptive_param=adaptive_param, clamp_quantile=clamp_quantile)
                    output_sd["{}{}.lora_up.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[0].contiguous().to(out_dtype).cpu()
                    output_sd["{}{}.lora_down.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[1].contiguous().to(out_dtype).cpu()
                except Exception as e:
                    logging.warning(f"Could not generate lora weights for key {k}, error {e}")
            else:
                output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().to(out_dtype).cpu()

            progress_bar.update(1)
            comfy_pbar.update(1)

        elif bias_diff and k.endswith(".bias"):
            output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().to(out_dtype).cpu()
            progress_bar.update(1)
            comfy_pbar.update(1)
    progress_bar.close()
    return output_sd
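
# Key naming sketch (hypothetical key; real names depend on the model): with
# prefix_model == prefix_lora == "diffusion_model.", a weight such as
#   diffusion_model.blocks.0.attn.qkv.weight
# is emitted as the pair
#   diffusion_model.blocks.0.attn.qkv.lora_up.weight
#   diffusion_model.blocks.0.attn.qkv.lora_down.weight
# (k[len(prefix_model):-7] strips the prefix and the ".weight" suffix), and a
# bias becomes a "...qkv.diff_b" entry when bias_diff is enabled.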


class LoraExtractKJ:
    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()

    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                {
                    "finetuned_model": ("MODEL",),
                    "original_model": ("MODEL",),
                    "filename_prefix": ("STRING", {"default": "loras/ComfyUI_extracted_lora"}),
                    "rank": ("INT", {"default": 8, "min": 1, "max": 4096, "step": 1, "tooltip": "The rank to use for standard LoRA, or maximum rank limit for adaptive methods."}),
                    "lora_type": (["standard", "full", "adaptive_ratio", "adaptive_quantile", "adaptive_energy", "adaptive_fro"],),
                    "algorithm": (["svd_linalg", "svd_lowrank"], {"default": "svd_linalg", "tooltip": "SVD algorithm to use; svd_lowrank is faster but less accurate."}),
                    "lowrank_iters": ("INT", {"default": 7, "min": 1, "max": 100, "step": 1, "tooltip": "The number of subspace iterations for the lowrank SVD algorithm."}),
                    "output_dtype": (["fp16", "bf16", "fp32"], {"default": "fp16"}),
                    "bias_diff": ("BOOLEAN", {"default": True}),
                    "adaptive_param": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "For ratio mode, this is the ratio of the maximum singular value. For quantile mode, this is the quantile of the singular values. For fro mode, this is the Frobenius norm retention ratio."}),
                    "clamp_quantile": ("BOOLEAN", {"default": True}),
                },
                }
    RETURN_TYPES = ()
    FUNCTION = "save"
    OUTPUT_NODE = True

    CATEGORY = "KJNodes/lora"

    def save(self, finetuned_model, original_model, filename_prefix, rank, lora_type, algorithm, lowrank_iters, output_dtype, bias_diff, adaptive_param, clamp_quantile):
        if algorithm == "svd_lowrank" and lora_type != "standard":
            raise ValueError("svd_lowrank algorithm is only supported for standard LoRA extraction.")

        dtype = {"fp8_e4m3fn": torch.float8_e4m3fn, "bf16": torch.bfloat16, "fp16": torch.float16, "fp16_fast": torch.float16, "fp32": torch.float32}[output_dtype]

        model_diff = None
        sd_override = None

        scaled_fp8_ft = getattr(getattr(finetuned_model.model, "model_config", None), "scaled_fp8", None)
        scaled_fp8_orig = getattr(getattr(original_model.model, "model_config", None), "scaled_fp8", None)
        scaled_fp8_present = scaled_fp8_ft is not None or scaled_fp8_orig is not None

        if scaled_fp8_present:
            comfy.model_management.load_models_gpu([finetuned_model, original_model], force_patch_weights=True)
            logging.info(
                "LoraExtractKJ: detected scaled fp8 weights (finetuned=%s, original=%s); using high-precision diff path.",
                scaled_fp8_ft is not None,
                scaled_fp8_orig is not None,
            )
            sd_override = _build_scaled_fp8_diff(
                finetuned_model, original_model, "diffusion_model.", bias_diff
            )
            comfy.model_management.soft_empty_cache()
        else:
            # Subtract the original weights from the finetuned model via patches.
            m = finetuned_model.clone()
            kp = original_model.get_key_patches("diffusion_model.")
            for k in kp:
                m.add_patches({k: kp[k]}, -1.0, 1.0)
            model_diff = m

        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)

        output_sd = {}
        if model_diff is not None:
            output_sd = calc_lora_model(model_diff, rank, "diffusion_model.", "diffusion_model.", output_sd, lora_type, algorithm, lowrank_iters, dtype, bias_diff=bias_diff, adaptive_param=adaptive_param, clamp_quantile=clamp_quantile)
        elif sd_override is not None:
            output_sd = calc_lora_model(None, rank, "diffusion_model.", "diffusion_model.", output_sd, lora_type, algorithm, lowrank_iters, dtype, bias_diff=bias_diff, adaptive_param=adaptive_param, clamp_quantile=clamp_quantile, sd_override=sd_override)
        if "adaptive" in lora_type:
            rank_str = f"{lora_type}_{adaptive_param:.2f}"
        else:
            rank_str = rank
        output_checkpoint = f"{filename}_rank_{rank_str}_{output_dtype}_{counter:05}_.safetensors"
        output_checkpoint = os.path.join(full_output_folder, output_checkpoint)

        comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=None)
        return {}


class LoraReduceRank:
    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()

    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                {
                    "lora_name": (folder_paths.get_filename_list("loras"), {"tooltip": "The name of the LoRA."}),
                    "new_rank": ("INT", {"default": 8, "min": 1, "max": 4096, "step": 1, "tooltip": "The new rank to resize the LoRA to. Acts as the max rank when using a dynamic_method."}),
                    "dynamic_method": (["disabled", "sv_ratio", "sv_cumulative", "sv_fro"], {"default": "disabled", "tooltip": "Method to use for dynamically determining new alphas and dims."}),
                    "dynamic_param": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Threshold parameter for the chosen dynamic method (singular value ratio, cumulative share, or Frobenius norm retention)."}),
                    "output_dtype": (["match_original", "fp16", "bf16", "fp32"], {"default": "match_original", "tooltip": "Data type to save the LoRA as."}),
                    "verbose": ("BOOLEAN", {"default": True}),
                },
                }
    RETURN_TYPES = ()
    FUNCTION = "save"
    OUTPUT_NODE = True
    EXPERIMENTAL = True
    DESCRIPTION = "Resize a LoRA model by reducing its rank. Based on kohya's sd-scripts: https://github.com/kohya-ss/sd-scripts/blob/main/networks/resize_lora.py"

    CATEGORY = "KJNodes/lora"

    def save(self, lora_name, new_rank, output_dtype, dynamic_method, dynamic_param, verbose):

        lora_path = folder_paths.get_full_path("loras", lora_name)
        lora_sd, metadata = comfy.utils.load_torch_file(lora_path, return_metadata=True)

        if output_dtype == "fp16":
            save_dtype = torch.float16
        elif output_dtype == "bf16":
            save_dtype = torch.bfloat16
        elif output_dtype == "fp32":
            save_dtype = torch.float32
        elif output_dtype == "match_original":
            first_weight_key = next(k for k in lora_sd if k.endswith(".weight") and isinstance(lora_sd[k], torch.Tensor))
            save_dtype = lora_sd[first_weight_key].dtype

        # Strip PEFT-style ".default" adapter suffixes from keys.
        new_lora_sd = {}
        for k, v in lora_sd.items():
            new_lora_sd[k.replace(".default", "")] = v
        del lora_sd
        print("Resizing LoRA...")
        output_sd, old_dim, new_alpha, rank_list = resize_lora_model(new_lora_sd, new_rank, save_dtype, device, dynamic_method, dynamic_param, verbose)

        # update metadata
        if metadata is None:
            metadata = {}

        comment = metadata.get("ss_training_comment", "")

        if dynamic_method == "disabled":
            metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {new_rank}; {comment}"
            metadata["ss_network_dim"] = str(new_rank)
            metadata["ss_network_alpha"] = str(new_alpha)
        else:
            metadata["ss_training_comment"] = f"Dynamic resize with {dynamic_method}: {dynamic_param} from {old_dim}; {comment}"
            metadata["ss_network_dim"] = "Dynamic"
            metadata["ss_network_alpha"] = "Dynamic"

        # cast to save_dtype before calculating hashes
        for key in list(output_sd.keys()):
            value = output_sd[key]
            if isinstance(value, torch.Tensor) and value.dtype.is_floating_point and value.dtype != save_dtype:
                output_sd[key] = value.to(save_dtype)

        output_filename_prefix = "loras/" + lora_name

        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(output_filename_prefix, self.output_dir)
        output_dtype_str = f"_{output_dtype}" if output_dtype != "match_original" else ""
        average_rank = str(int(np.mean(rank_list)))
        rank_str = new_rank if dynamic_method == "disabled" else f"dynamic_{average_rank}"
        output_checkpoint = f"{filename.replace('.safetensors', '')}_resized_from_{old_dim}_to_{rank_str}{output_dtype_str}_{counter:05}_.safetensors"
        output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
        print(f"Saving resized LoRA to {output_checkpoint}")

        comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=metadata)
        return {}


NODE_CLASS_MAPPINGS = {
    "LoraExtractKJ": LoraExtractKJ,
    "LoraReduceRank": LoraReduceRank,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "LoraExtractKJ": "LoraExtractKJ",
    "LoraReduceRank": "LoraReduceRank",
}


# Convert LoRA to different rank approximation (should only be used to go to lower rank)
# This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
# Thanks to cloneofsimo

# This version is based on
# https://github.com/kohya-ss/sd-scripts/blob/main/networks/resize_lora.py

# Singular values at or below this are treated as zero.
MIN_SV = 1e-6

LORA_DOWN_UP_FORMATS = [
    ("lora_down", "lora_up"),  # sd-scripts LoRA
    ("lora_A", "lora_B"),      # PEFT LoRA
    ("down", "up"),            # ControlLoRA
]
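
# Matching sketch (hypothetical keys): with the formats above, a key such as
#   lora_unet_blocks_0.lora_down.weight
# pairs with lora_unet_blocks_0.lora_up.weight, and a PEFT-style
#   base_model.model.to_q.lora_A.weight
# pairs with base_model.model.to_q.lora_B.weight. The per-block "alpha"
# entry, when present, is looked up at "<block name>.alpha".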


# Indexing functions
def index_sv_cumulative(S, target):
    original_sum = float(torch.sum(S))
    cumulative_sums = torch.cumsum(S, dim=0) / original_sum
    index = int(torch.searchsorted(cumulative_sums, target)) + 1
    index = max(1, min(index, len(S) - 1))

    return index


def index_sv_fro(S, target):
    S_squared = S.pow(2)
    S_fro_sq = float(torch.sum(S_squared))
    sum_S_squared = torch.cumsum(S_squared, dim=0) / S_fro_sq
    index = int(torch.searchsorted(sum_S_squared, target**2)) + 1
    index = max(1, min(index, len(S) - 1))

    return index


def index_sv_ratio(S, target):
    max_sv = S[0]
    min_sv = max_sv / target
    index = int(torch.sum(S > min_sv).item())
    index = max(1, min(index, len(S) - 1))

    return index
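
# Toy example (hand-computed, for orientation only): with
# S = torch.tensor([4.0, 2.0, 1.0, 0.5]):
#   index_sv_cumulative(S, 0.8) -> 2  (cumsum/sum = [0.53, 0.80, 0.93, 1.00])
#   index_sv_fro(S, 0.9)        -> 2  (cumsum(S^2)/sum(S^2) reaches 0.9^2 at k=2)
#   index_sv_ratio(S, 4.0)      -> 2  (two singular values exceed max(S)/4 = 1.0)
# rank_resize below adds 1 on top of these indices before capping at new_rank.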


# Modified from Kohaku-blueleaf's extract/merge functions
def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
    out_size, in_size, kernel_size, _ = weight.size()
    if weight.dtype != torch.float32:
        weight = weight.to(torch.float32)
    U, S, Vh = torch.linalg.svd(weight.reshape(out_size, -1).to(device))

    param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
    lora_rank = param_dict["new_rank"]

    U = U[:, :lora_rank]
    S = S[:lora_rank]
    U = U @ torch.diag(S)
    Vh = Vh[:lora_rank, :]

    param_dict["lora_down"] = Vh.reshape(lora_rank, in_size, kernel_size, kernel_size).cpu()
    param_dict["lora_up"] = U.reshape(out_size, lora_rank, 1, 1).cpu()
    del U, S, Vh, weight
    return param_dict


def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
    out_size, in_size = weight.size()

    if weight.dtype != torch.float32:
        weight = weight.to(torch.float32)
    U, S, Vh = torch.linalg.svd(weight.to(device))

    param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
    lora_rank = param_dict["new_rank"]

    U = U[:, :lora_rank]
    S = S[:lora_rank]
    U = U @ torch.diag(S)
    Vh = Vh[:lora_rank, :]

    param_dict["lora_down"] = Vh.reshape(lora_rank, in_size).cpu()
    param_dict["lora_up"] = U.reshape(out_size, lora_rank).cpu()
    del U, S, Vh, weight
    return param_dict


def merge_conv(lora_down, lora_up, device):
    in_rank, in_size, kernel_size, k_ = lora_down.shape
    out_size, out_rank, _, _ = lora_up.shape
    assert in_rank == out_rank and kernel_size == k_, f"rank {in_rank} {out_rank} or kernel {kernel_size} {k_} mismatch"

    lora_down = lora_down.to(device)
    lora_up = lora_up.to(device)

    merged = lora_up.reshape(out_size, -1) @ lora_down.reshape(in_rank, -1)
    weight = merged.reshape(out_size, in_size, kernel_size, kernel_size)
    del lora_up, lora_down
    return weight


def merge_linear(lora_down, lora_up, device):
    in_rank, in_size = lora_down.shape
    out_size, out_rank = lora_up.shape
    assert in_rank == out_rank, f"rank {in_rank} {out_rank} mismatch"

    lora_down = lora_down.to(device)
    lora_up = lora_up.to(device)

    weight = lora_up @ lora_down
    del lora_up, lora_down
    return weight
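
# merge_* rebuilds the full update W = up @ down that a LoRA pair represents;
# extract_* then re-factorizes W at the new (possibly dynamic) rank. A quick
# round-trip sketch (hypothetical shapes, not executed by the node):
#
#   down = torch.randn(16, 128)         # (rank, in_features)
#   up = torch.randn(256, 16)           # (out_features, rank)
#   W = merge_linear(down, up, device)  # (256, 128)
#   params = extract_linear(W, 8, "disabled", 0.0, device)
#
# By Eckart-Young, params["lora_up"] @ params["lora_down"] is the best rank-8
# approximation of W in the Frobenius norm (before any dtype casting).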


# Calculate new rank
def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
    param_dict = {}

    if dynamic_method == "sv_ratio":
        # Calculate new dim and alpha based off ratio
        new_rank = index_sv_ratio(S, dynamic_param) + 1
        new_alpha = float(scale * new_rank)

    elif dynamic_method == "sv_cumulative":
        # Calculate new dim and alpha based off cumulative sum
        new_rank = index_sv_cumulative(S, dynamic_param) + 1
        new_alpha = float(scale * new_rank)

    elif dynamic_method == "sv_fro":
        # Calculate new dim and alpha based off sqrt sum of squares
        new_rank = index_sv_fro(S, dynamic_param) + 1
        new_alpha = float(scale * new_rank)
    else:
        new_rank = rank
        new_alpha = float(scale * new_rank)

    if S[0] <= MIN_SV:  # Zero matrix, set dim to 1
        new_rank = 1
        new_alpha = float(scale * new_rank)
    elif new_rank > rank:  # cap max rank at rank
        new_rank = rank
        new_alpha = float(scale * new_rank)

    # Calculate resize info
    s_sum = torch.sum(torch.abs(S))
    s_rank = torch.sum(torch.abs(S[:new_rank]))

    S_squared = S.pow(2)
    s_fro = torch.sqrt(torch.sum(S_squared))
    s_red_fro = torch.sqrt(torch.sum(S_squared[:new_rank]))
    fro_percent = float(s_red_fro / s_fro)

    param_dict["new_rank"] = new_rank
    param_dict["new_alpha"] = new_alpha
    param_dict["sum_retained"] = (s_rank) / s_sum
    param_dict["fro_retained"] = fro_percent
    param_dict["max_ratio"] = S[0] / S[new_rank - 1]

    return param_dict
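
# rank_resize returns the chosen rank/alpha plus diagnostics, e.g.
# (hypothetical numbers):
#   {"new_rank": 8, "new_alpha": 8.0, "sum_retained": tensor(0.93),
#    "fro_retained": 0.99, "max_ratio": tensor(12.4)}
# "fro_retained" is the share of the Frobenius norm kept by the truncation;
# resize_lora_model averages it for the verbose summary at the end.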


def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
    max_old_rank = None
    new_alpha = None
    verbose_str = "\n"
    fro_list = []
    rank_list = []

    if dynamic_method != "disabled":
        print(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}")

    lora_down_weight = None
    lora_up_weight = None

    o_lora_sd = lora_sd.copy()
    block_down_name = None
    block_up_name = None

    total_keys = len([k for k in lora_sd if k.endswith(".weight")])

    pbar = comfy.utils.ProgressBar(total_keys)
    for key, value in tqdm(lora_sd.items(), leave=True, desc="Resizing LoRA weights"):
        key_parts = key.split(".")
        block_down_name = None
        for _format in LORA_DOWN_UP_FORMATS:
            # Currently we only match lora_down_name in the last two parts of the key
            # because ("down", "up") are general words and may appear in block_down_name
            if len(key_parts) >= 2 and _format[0] == key_parts[-2]:
                block_down_name = ".".join(key_parts[:-2])
                lora_down_name = "." + _format[0]
                lora_up_name = "." + _format[1]
                weight_name = "." + key_parts[-1]
                break
            if len(key_parts) >= 1 and _format[0] == key_parts[-1]:
                block_down_name = ".".join(key_parts[:-1])
                lora_down_name = "." + _format[0]
                lora_up_name = "." + _format[1]
                weight_name = ""
                break

        if block_down_name is None:
            # This parameter is not lora_down
            continue

        # Now weight_name can be ".weight" or ""
        # Find corresponding lora_up and alpha
        block_up_name = block_down_name
        lora_down_weight = value
        lora_up_weight = lora_sd.get(block_up_name + lora_up_name + weight_name, None)
        lora_alpha = lora_sd.get(block_down_name + ".alpha", None)

        weights_loaded = lora_down_weight is not None and lora_up_weight is not None

        if weights_loaded:

            conv2d = len(lora_down_weight.size()) == 4
            old_rank = lora_down_weight.size()[0]
            max_old_rank = max(max_old_rank or 0, old_rank)

            if lora_alpha is None:
                scale = 1.0
            else:
                scale = lora_alpha / old_rank

            if conv2d:
                full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
                param_dict = extract_conv(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
            else:
                full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device)
                param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)

            if verbose:
                max_ratio = param_dict["max_ratio"]
                sum_retained = param_dict["sum_retained"]
                fro_retained = param_dict["fro_retained"]
                if not np.isnan(fro_retained):
                    fro_list.append(float(fro_retained))

                log_str = f"{block_down_name:75} | sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}, new dim: {param_dict['new_rank']}"
                tqdm.write(log_str)
                verbose_str += log_str

            if verbose and dynamic_method != "disabled":
                verbose_str += f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n"
            else:
                verbose_str += "\n"

            new_alpha = param_dict["new_alpha"]
            o_lora_sd[block_down_name + lora_down_name + weight_name] = param_dict["lora_down"].to(save_dtype).contiguous()
            o_lora_sd[block_up_name + lora_up_name + weight_name] = param_dict["lora_up"].to(save_dtype).contiguous()
            o_lora_sd[block_down_name + ".alpha"] = torch.tensor(param_dict["new_alpha"]).to(save_dtype)

            block_down_name = None
            block_up_name = None
            lora_down_weight = None
            lora_up_weight = None
            weights_loaded = False
            rank_list.append(param_dict["new_rank"])
            del param_dict
            pbar.update(1)

    if verbose:
        print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
    return o_lora_sd, max_old_rank, new_alpha, rank_list