import torch
import comfy.model_management
import comfy.utils
import folder_paths
import os
import logging
from tqdm import tqdm
import numpy as np

device = comfy.model_management.get_torch_device()

CLAMP_QUANTILE = 0.99

def extract_lora(diff, key, rank, algorithm, lora_type, lowrank_iters=7, adaptive_param=1.0, clamp_quantile=True):
    """
    Extracts LoRA up/down weight matrices from a weight difference tensor using SVD.
    """
    conv2d = (len(diff.shape) == 4)
    kernel_size = None if not conv2d else diff.size()[2:4]
    conv2d_3x3 = conv2d and kernel_size != (1, 1)
    out_dim, in_dim = diff.size()[0:2]

    if conv2d:
        if conv2d_3x3:
            diff = diff.flatten(start_dim=1)
        else:
            diff = diff.squeeze()

    diff_float = diff.float()

    if algorithm == "svd_lowrank":
        lora_rank = min(rank, in_dim, out_dim)  # q already caps the achievable rank
        U, S, V = torch.svd_lowrank(diff_float, q=lora_rank, niter=lowrank_iters)
        U = U @ torch.diag(S)
        Vh = V.t()
    else:
        # Full (exact) SVD; slower than svd_lowrank but more accurate
        U, S, Vh = torch.linalg.svd(diff_float)

        # Flexible rank selection logic like locon: https://github.com/KohakuBlueleaf/LyCORIS/blob/main/tools/extract_locon.py
        if "adaptive" in lora_type:
            if lora_type == "adaptive_ratio":
                # Keep singular values above a fraction of the largest one
                min_s = torch.max(S) * adaptive_param
                lora_rank = torch.sum(S > min_s).item()
            elif lora_type == "adaptive_energy":
                # Keep enough singular values to retain the requested share of the total energy
                energy = torch.cumsum(S**2, dim=0)
                total_energy = torch.sum(S**2)
                threshold = adaptive_param * total_energy  # e.g., adaptive_param=0.95 for 95%
                lora_rank = torch.sum(energy < threshold).item() + 1
            elif lora_type == "adaptive_quantile":
                # Keep enough singular values to reach the requested share of their cumulative sum
                s_cum = torch.cumsum(S, dim=0)
                min_cum_sum = adaptive_param * torch.sum(S)
                lora_rank = torch.sum(s_cum < min_cum_sum).item()
            print(f"{key} Extracted LoRA rank: {lora_rank}")
        else:
            lora_rank = rank

        lora_rank = max(1, lora_rank)
        lora_rank = min(out_dim, in_dim, lora_rank)

        U = U[:, :lora_rank]
        S = S[:lora_rank]
        U = U @ torch.diag(S)
        Vh = Vh[:lora_rank, :]

    if clamp_quantile:
        dist = torch.cat([U.flatten(), Vh.flatten()])
        if dist.numel() > 100_000:
            # Sample 100,000 elements for quantile estimation
            idx = torch.randperm(dist.numel(), device=dist.device)[:100_000]
            dist_sample = dist[idx]
            hi_val = torch.quantile(dist_sample, CLAMP_QUANTILE)
        else:
            hi_val = torch.quantile(dist, CLAMP_QUANTILE)
        low_val = -hi_val

        U = U.clamp(low_val, hi_val)
        Vh = Vh.clamp(low_val, hi_val)

    if conv2d:
        U = U.reshape(out_dim, lora_rank, 1, 1)
        Vh = Vh.reshape(lora_rank, in_dim, kernel_size[0], kernel_size[1])

    return (U, Vh)
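
# Hedged usage sketch (illustrative, not part of the node): what extract_lora
# returns for a plain linear weight difference. The shapes, rank and key name
# here are arbitrary example values.
def _example_extract_lora():
    diff = torch.randn(64, 32)  # fake (out_dim, in_dim) weight difference
    up, down = extract_lora(diff, "example.weight", rank=8, algorithm="svd_linalg",
                            lora_type="standard", clamp_quantile=False)
    # Without clamping, up @ down is the best rank-8 approximation of diff
    # (Eckart-Young); up: (64, 8), down: (8, 32)
    approx = up @ down
    print("reconstruction error:", torch.norm(diff - approx).item())
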
def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora_type, algorithm, lowrank_iters, out_dtype, bias_diff=False, adaptive_param=1.0, clamp_quantile=True):
    comfy.model_management.load_models_gpu([model_diff], force_patch_weights=True)
    model_diff.model.diffusion_model.cpu()
    sd = model_diff.model_state_dict(filter_prefix=prefix_model)
    del model_diff
    comfy.model_management.soft_empty_cache()
    for k, v in sd.items():
        if isinstance(v, torch.Tensor):
            sd[k] = v.cpu()

    # Get total number of keys to process for progress bar
    total_keys = len([k for k in sd if k.endswith(".weight") or (bias_diff and k.endswith(".bias"))])

    # Create progress bar
    progress_bar = tqdm(total=total_keys, desc=f"Extracting LoRA ({prefix_lora.strip('.')})")
    comfy_pbar = comfy.utils.ProgressBar(total_keys)

    for k in sd:
        if k.endswith(".weight"):
            weight_diff = sd[k]
            if weight_diff.ndim == 5:
                logging.info(f"Skipping 5D tensor for key {k}")  # skip patch embed
                progress_bar.update(1)
                comfy_pbar.update(1)
                continue
            if lora_type != "full":
                if weight_diff.ndim < 2:
                    if bias_diff:
                        output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().to(out_dtype).cpu()
                    progress_bar.update(1)
                    comfy_pbar.update(1)
                    continue
                try:
                    out = extract_lora(weight_diff.to(device), k, rank, algorithm, lora_type, lowrank_iters=lowrank_iters, adaptive_param=adaptive_param, clamp_quantile=clamp_quantile)
                    output_sd["{}{}.lora_up.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[0].contiguous().to(out_dtype).cpu()
                    output_sd["{}{}.lora_down.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[1].contiguous().to(out_dtype).cpu()
                except Exception as e:
                    logging.warning(f"Could not generate lora weights for key {k}, error {e}")
            else:
                output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().to(out_dtype).cpu()
            progress_bar.update(1)
            comfy_pbar.update(1)
        elif bias_diff and k.endswith(".bias"):
            output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().to(out_dtype).cpu()
            progress_bar.update(1)
            comfy_pbar.update(1)
    progress_bar.close()
    return output_sd
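
# Sketch (illustrative, not part of the node): how the lora_up/lora_down tensors
# stored above are meant to recombine at load time. This assumes the common LoRA
# convention where the patched weight is W + strength * (lora_up @ lora_down);
# the parameter names here are hypothetical.
def _example_apply_extracted_lora(base_weight, lora_up, lora_down, strength=1.0):
    # base_weight: (out_dim, in_dim), lora_up: (out_dim, r), lora_down: (r, in_dim)
    return base_weight + strength * (lora_up.float() @ lora_down.float())
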
class LoraExtractKJ:
    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "finetuned_model": ("MODEL",),
            "original_model": ("MODEL",),
            "filename_prefix": ("STRING", {"default": "loras/ComfyUI_extracted_lora"}),
            "rank": ("INT", {"default": 8, "min": 1, "max": 4096, "step": 1}),
            "lora_type": (["standard", "full", "adaptive_ratio", "adaptive_quantile", "adaptive_energy"],),
            "algorithm": (["svd_linalg", "svd_lowrank"], {"default": "svd_linalg", "tooltip": "SVD algorithm to use, svd_lowrank is faster but less accurate."}),
            "lowrank_iters": ("INT", {"default": 7, "min": 1, "max": 100, "step": 1, "tooltip": "The number of subspace iterations for the lowrank SVD algorithm."}),
            "output_dtype": (["fp16", "bf16", "fp32"], {"default": "fp16"}),
            "bias_diff": ("BOOLEAN", {"default": True}),
            "adaptive_param": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "For ratio mode, the fraction of the maximum singular value to keep. For quantile mode, the quantile of the cumulative singular value sum. For energy mode, the fraction of total squared singular values to retain."}),
            "clamp_quantile": ("BOOLEAN", {"default": True}),
        },
        }

    RETURN_TYPES = ()
    FUNCTION = "save"
    OUTPUT_NODE = True
    CATEGORY = "KJNodes/lora"

    def save(self, finetuned_model, original_model, filename_prefix, rank, lora_type, algorithm, lowrank_iters, output_dtype, bias_diff, adaptive_param, clamp_quantile):
        if algorithm == "svd_lowrank" and lora_type != "standard":
            raise ValueError("svd_lowrank algorithm is only supported for standard LoRA extraction.")
        dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[output_dtype]

        m = finetuned_model.clone()
        kp = original_model.get_key_patches("diffusion_model.")
        for k in kp:
            # Subtract the original weights from the finetuned ones (patch strength -1.0)
            m.add_patches({k: kp[k]}, -1.0, 1.0)
        model_diff = m

        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)

        output_sd = {}
        if model_diff is not None:
            output_sd = calc_lora_model(model_diff, rank, "diffusion_model.", "diffusion_model.", output_sd, lora_type, algorithm, lowrank_iters, dtype, bias_diff=bias_diff, adaptive_param=adaptive_param, clamp_quantile=clamp_quantile)

        if "adaptive" in lora_type:
            rank_str = f"{lora_type}_{adaptive_param:.2f}"
        else:
            rank_str = rank
        output_checkpoint = f"{filename}_rank_{rank_str}_{output_dtype}_{counter:05}_.safetensors"
        output_checkpoint = os.path.join(full_output_folder, output_checkpoint)

        comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=None)
        return {}
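
# Note on the diff trick in LoraExtractKJ.save (sketch): add_patches with patch
# strength -1.0 and model strength 1.0 makes every patched weight evaluate to
# W_finetuned - W_original, which is exactly what calc_lora_model decomposes.
# Minimal tensor-level illustration (values are hypothetical):
def _example_weight_diff():
    w_finetuned = torch.randn(4, 4)
    w_original = torch.randn(4, 4)
    # equivalent arithmetic: 1.0 * w_finetuned + (-1.0) * w_original
    w_diff = w_finetuned - w_original
    return w_diff
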
class LoraReduceRank:
    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "lora_name": (folder_paths.get_filename_list("loras"), {"tooltip": "The name of the LoRA."}),
            "new_rank": ("INT", {"default": 8, "min": 1, "max": 4096, "step": 1, "tooltip": "The new rank to resize the LoRA to. Acts as the max rank when using a dynamic_method."}),
            "dynamic_method": (["disabled", "sv_ratio", "sv_cumulative", "sv_fro"], {"default": "disabled", "tooltip": "Method to use for dynamically determining new alphas and dims"}),
            "dynamic_param": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Parameter for the chosen dynamic method: singular value ratio, cumulative sum fraction, or Frobenius norm fraction."}),
            "output_dtype": (["match_original", "fp16", "bf16", "fp32"], {"default": "match_original", "tooltip": "Data type to save the LoRA as."}),
            "verbose": ("BOOLEAN", {"default": True}),
        },
        }

    RETURN_TYPES = ()
    FUNCTION = "save"
    OUTPUT_NODE = True
    EXPERIMENTAL = True
    DESCRIPTION = "Resize a LoRA model by reducing its rank. Based on kohya's sd-scripts: https://github.com/kohya-ss/sd-scripts/blob/main/networks/resize_lora.py"
    CATEGORY = "KJNodes/lora"

    def save(self, lora_name, new_rank, output_dtype, dynamic_method, dynamic_param, verbose):
        lora_path = folder_paths.get_full_path("loras", lora_name)
        lora_sd, metadata = comfy.utils.load_torch_file(lora_path, return_metadata=True)
        if output_dtype == "fp16":
            save_dtype = torch.float16
        elif output_dtype == "bf16":
            save_dtype = torch.bfloat16
        elif output_dtype == "fp32":
            save_dtype = torch.float32
        elif output_dtype == "match_original":
            first_weight_key = next(k for k in lora_sd if k.endswith(".weight") and isinstance(lora_sd[k], torch.Tensor))
            save_dtype = lora_sd[first_weight_key].dtype

        new_lora_sd = {}
        for k, v in lora_sd.items():
            new_lora_sd[k.replace(".default", "")] = v
        del lora_sd

        print("Resizing LoRA...")
        output_sd, old_dim, new_alpha, rank_list = resize_lora_model(new_lora_sd, new_rank, save_dtype, device, dynamic_method, dynamic_param, verbose)

        # update metadata
        if metadata is None:
            metadata = {}
        comment = metadata.get("ss_training_comment", "")

        if dynamic_method == "disabled":
            metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {new_rank}; {comment}"
            metadata["ss_network_dim"] = str(new_rank)
            metadata["ss_network_alpha"] = str(new_alpha)
        else:
            metadata["ss_training_comment"] = f"Dynamic resize with {dynamic_method}: {dynamic_param} from {old_dim}; {comment}"
            metadata["ss_network_dim"] = "Dynamic"
            metadata["ss_network_alpha"] = "Dynamic"

        # cast to save_dtype before calculating hashes
        for key in list(output_sd.keys()):
            value = output_sd[key]
            if type(value) == torch.Tensor and value.dtype.is_floating_point and value.dtype != save_dtype:
                output_sd[key] = value.to(save_dtype)

        output_filename_prefix = "loras/" + lora_name
        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(output_filename_prefix, self.output_dir)
        output_dtype_str = f"_{output_dtype}" if output_dtype != "match_original" else ""
        average_rank = str(int(np.mean(rank_list)))
        rank_str = new_rank if dynamic_method == "disabled" else average_rank
        output_checkpoint = f"{filename}_resized_from_{old_dim}_to_{rank_str}{output_dtype_str}_{counter:05}_.safetensors"
        output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
        print(f"Saving resized LoRA to {output_checkpoint}")
        comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=metadata)
        return {}

NODE_CLASS_MAPPINGS = {
    "LoraExtractKJ": LoraExtractKJ,
    "LoraReduceRank": LoraReduceRank,
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "LoraExtractKJ": "LoraExtractKJ",
    "LoraReduceRank": "LoraReduceRank",
}

# Convert LoRA to a different rank approximation (should only be used to go to a lower rank)
# This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
# Thanks to cloneofsimo
# This version is based on
# https://github.com/kohya-ss/sd-scripts/blob/main/networks/resize_lora.py

MIN_SV = 1e-6

LORA_DOWN_UP_FORMATS = [
    ("lora_down", "lora_up"),  # sd-scripts LoRA
    ("lora_A", "lora_B"),      # PEFT LoRA
    ("down", "up"),            # ControlLoRA
]

# Indexing functions

def index_sv_cumulative(S, target):
    original_sum = float(torch.sum(S))
    cumulative_sums = torch.cumsum(S, dim=0) / original_sum
    index = int(torch.searchsorted(cumulative_sums, target)) + 1
    index = max(1, min(index, len(S) - 1))
    return index

def index_sv_fro(S, target):
    S_squared = S.pow(2)
    S_fro_sq = float(torch.sum(S_squared))
    sum_S_squared = torch.cumsum(S_squared, dim=0) / S_fro_sq
    index = int(torch.searchsorted(sum_S_squared, target**2)) + 1
    index = max(1, min(index, len(S) - 1))
    return index

def index_sv_ratio(S, target):
    max_sv = S[0]
    min_sv = max_sv / target
    index = int(torch.sum(S > min_sv).item())
    index = max(1, min(index, len(S) - 1))
    return index
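
# Worked example (illustrative only): how the three criteria above pick a rank
# index for a small toy spectrum. None of these values come from a real model.
def _example_index_selection():
    S = torch.tensor([4.0, 2.0, 1.0, 0.5, 0.25])
    # sv_ratio, target=4: keep S > max(S)/4 = 1.0 -> 2 values (4.0 and 2.0)
    # sv_cumulative, target=0.9: smallest k whose cumulative sum reaches 90% of sum(S)
    # sv_fro, target=0.95: smallest k retaining 95% of the Frobenius norm sqrt(sum(S^2))
    return (index_sv_ratio(S, 4.0),
            index_sv_cumulative(S, 0.9),
            index_sv_fro(S, 0.95))
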
# Modified from Kohaku-blueleaf's extract/merge functions

def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
    out_size, in_size, kernel_size, _ = weight.size()
    if weight.dtype != torch.float32:
        weight = weight.to(torch.float32)
    U, S, Vh = torch.linalg.svd(weight.reshape(out_size, -1).to(device))

    param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
    lora_rank = param_dict["new_rank"]

    U = U[:, :lora_rank]
    S = S[:lora_rank]
    U = U @ torch.diag(S)
    Vh = Vh[:lora_rank, :]

    param_dict["lora_down"] = Vh.reshape(lora_rank, in_size, kernel_size, kernel_size).cpu()
    param_dict["lora_up"] = U.reshape(out_size, lora_rank, 1, 1).cpu()
    del U, S, Vh, weight
    return param_dict

def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
    out_size, in_size = weight.size()
    if weight.dtype != torch.float32:
        weight = weight.to(torch.float32)
    U, S, Vh = torch.linalg.svd(weight.to(device))

    param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
    lora_rank = param_dict["new_rank"]

    U = U[:, :lora_rank]
    S = S[:lora_rank]
    U = U @ torch.diag(S)
    Vh = Vh[:lora_rank, :]

    param_dict["lora_down"] = Vh.reshape(lora_rank, in_size).cpu()
    param_dict["lora_up"] = U.reshape(out_size, lora_rank).cpu()
    del U, S, Vh, weight
    return param_dict

def merge_conv(lora_down, lora_up, device):
    in_rank, in_size, kernel_size, k_ = lora_down.shape
    out_size, out_rank, _, _ = lora_up.shape
    assert in_rank == out_rank and kernel_size == k_, f"rank {in_rank} {out_rank} or kernel {kernel_size} {k_} mismatch"

    lora_down = lora_down.to(device)
    lora_up = lora_up.to(device)

    merged = lora_up.reshape(out_size, -1) @ lora_down.reshape(in_rank, -1)
    weight = merged.reshape(out_size, in_size, kernel_size, kernel_size)
    del lora_up, lora_down
    return weight

def merge_linear(lora_down, lora_up, device):
    in_rank, in_size = lora_down.shape
    out_size, out_rank = lora_up.shape
    assert in_rank == out_rank, f"rank {in_rank} {out_rank} mismatch"

    lora_down = lora_down.to(device)
    lora_up = lora_up.to(device)

    weight = lora_up @ lora_down
    del lora_up, lora_down
    return weight
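
# Round-trip sketch (illustrative): resize_lora_model below works per layer by
# merging a down/up pair back into a full matrix and re-extracting it at the
# new rank. Shapes here are arbitrary example values.
def _example_resize_roundtrip():
    down = torch.randn(16, 32)  # (old_rank, in_size)
    up = torch.randn(64, 16)    # (out_size, old_rank)
    full = merge_linear(down, up, device="cpu")  # (64, 32)
    param_dict = extract_linear(full, 8, "disabled", 0.0, device="cpu")
    # -> lora_down: (8, 32), lora_up: (64, 8)
    return param_dict["lora_down"].shape, param_dict["lora_up"].shape
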
# Calculate new rank

def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
    param_dict = {}

    if dynamic_method == "sv_ratio":
        # Calculate new dim and alpha based off ratio
        new_rank = index_sv_ratio(S, dynamic_param) + 1
        new_alpha = float(scale * new_rank)
    elif dynamic_method == "sv_cumulative":
        # Calculate new dim and alpha based off cumulative sum
        new_rank = index_sv_cumulative(S, dynamic_param) + 1
        new_alpha = float(scale * new_rank)
    elif dynamic_method == "sv_fro":
        # Calculate new dim and alpha based off sqrt sum of squares
        new_rank = index_sv_fro(S, dynamic_param) + 1
        new_alpha = float(scale * new_rank)
    else:
        new_rank = rank
        new_alpha = float(scale * new_rank)

    if S[0] <= MIN_SV:  # Zero matrix, set dim to 1
        new_rank = 1
        new_alpha = float(scale * new_rank)
    elif new_rank > rank:  # cap max rank at rank
        new_rank = rank
        new_alpha = float(scale * new_rank)

    # Calculate resize info
    s_sum = torch.sum(torch.abs(S))
    s_rank = torch.sum(torch.abs(S[:new_rank]))

    S_squared = S.pow(2)
    s_fro = torch.sqrt(torch.sum(S_squared))
    s_red_fro = torch.sqrt(torch.sum(S_squared[:new_rank]))
    fro_percent = float(s_red_fro / s_fro)

    param_dict["new_rank"] = new_rank
    param_dict["new_alpha"] = new_alpha
    param_dict["sum_retained"] = (s_rank) / s_sum
    param_dict["fro_retained"] = fro_percent
    param_dict["max_ratio"] = S[0] / S[new_rank - 1]

    return param_dict
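
# Note on alpha handling (sketch): the effective LoRA scale is alpha / rank, so
# rank_resize sets new_alpha = scale * new_rank (with scale = old_alpha / old_rank)
# to keep the merged weight magnitude unchanged after resizing. The numbers below
# are hypothetical.
def _example_alpha_scale():
    old_rank, old_alpha = 16, 8.0
    scale = old_alpha / old_rank   # 0.5
    new_rank = 8
    new_alpha = scale * new_rank   # 4.0, so new_alpha / new_rank is still 0.5
    assert new_alpha / new_rank == scale
    return new_alpha
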
def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
    max_old_rank = None
    new_alpha = None
    verbose_str = "\n"
    fro_list = []
    rank_list = []

    if dynamic_method != "disabled":
        print(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}")

    lora_down_weight = None
    lora_up_weight = None

    o_lora_sd = lora_sd.copy()
    block_down_name = None
    block_up_name = None

    total_keys = len([k for k in lora_sd if k.endswith(".weight")])
    pbar = comfy.utils.ProgressBar(total_keys)
    for key, value in tqdm(lora_sd.items()):
        key_parts = key.split(".")
        block_down_name = None
        for _format in LORA_DOWN_UP_FORMATS:
            # Currently we only match lora_down_name in the last two parts of key
            # because ("down", "up") are general words and may appear in block_down_name
            if len(key_parts) >= 2 and _format[0] == key_parts[-2]:
                block_down_name = ".".join(key_parts[:-2])
                lora_down_name = "." + _format[0]
                lora_up_name = "." + _format[1]
                weight_name = "." + key_parts[-1]
                break
            if len(key_parts) >= 1 and _format[0] == key_parts[-1]:
                block_down_name = ".".join(key_parts[:-1])
                lora_down_name = "." + _format[0]
                lora_up_name = "." + _format[1]
                weight_name = ""
                break
        if block_down_name is None:
            # This parameter is not a lora_down
            continue
        # Now weight_name can be ".weight" or ""
        # Find the corresponding lora_up and alpha
        block_up_name = block_down_name
        lora_down_weight = value
        lora_up_weight = lora_sd.get(block_up_name + lora_up_name + weight_name, None)
        lora_alpha = lora_sd.get(block_down_name + ".alpha", None)

        weights_loaded = lora_down_weight is not None and lora_up_weight is not None

        if weights_loaded:
            conv2d = len(lora_down_weight.size()) == 4
            old_rank = lora_down_weight.size()[0]
            max_old_rank = max(max_old_rank or 0, old_rank)

            if lora_alpha is None:
                scale = 1.0
            else:
                scale = lora_alpha / old_rank

            if conv2d:
                full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
                param_dict = extract_conv(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
            else:
                full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device)
                param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)

            if verbose:
                max_ratio = param_dict["max_ratio"]
                sum_retained = param_dict["sum_retained"]
                fro_retained = param_dict["fro_retained"]
                if not np.isnan(fro_retained):
                    fro_list.append(float(fro_retained))

                verbose_str += f"{block_down_name:75} | "
                verbose_str += f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}"

            if verbose and dynamic_method != "disabled":
                verbose_str += f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n"
            else:
                verbose_str += "\n"

            new_alpha = param_dict["new_alpha"]
            o_lora_sd[block_down_name + lora_down_name + weight_name] = param_dict["lora_down"].to(save_dtype).contiguous()
            o_lora_sd[block_up_name + lora_up_name + weight_name] = param_dict["lora_up"].to(save_dtype).contiguous()
            o_lora_sd[block_down_name + ".alpha"] = torch.tensor(param_dict["new_alpha"]).to(save_dtype)

            block_down_name = None
            block_up_name = None
            lora_down_weight = None
            lora_up_weight = None
            weights_loaded = False
            rank_list.append(param_dict["new_rank"])
            del param_dict
        pbar.update(1)

    if verbose:
        print(verbose_str)
        print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
    return o_lora_sd, max_old_rank, new_alpha, rank_list
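
# End-to-end sketch (illustrative): resizing a tiny synthetic LoRA state dict in
# sd-scripts naming. The key names, shapes and alpha are hypothetical example values.
def _example_resize_lora_model():
    sd = {
        "lora_unet_example.lora_down.weight": torch.randn(16, 32),
        "lora_unet_example.lora_up.weight": torch.randn(64, 16),
        "lora_unet_example.alpha": torch.tensor(8.0),
    }
    out_sd, old_dim, new_alpha, rank_list = resize_lora_model(
        sd, 8, torch.float16, "cpu", "disabled", 0.0, verbose=False)
    # old_dim == 16, new_alpha == 4.0, rank_list == [8]
    return out_sd["lora_unet_example.lora_down.weight"].shape  # (8, 32)
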