From bb205d809b467307b8ec3bb1a22680a4873187f8 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Thu, 2 Oct 2025 01:31:04 +0300 Subject: [PATCH] Update lora_nodes.py --- nodes/lora_nodes.py | 30 +++++++++++++++++++++++++----- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/nodes/lora_nodes.py b/nodes/lora_nodes.py index afdce81..926c6a3 100644 --- a/nodes/lora_nodes.py +++ b/nodes/lora_nodes.py @@ -48,7 +48,27 @@ def extract_lora(diff, key, rank, algorithm, lora_type, lowrank_iters=7, adaptiv s_cum = torch.cumsum(S, dim=0) min_cum_sum = adaptive_param * torch.sum(S) lora_rank = torch.sum(s_cum < min_cum_sum).item() - print(f"{key} Extracted LoRA rank: {lora_rank}") + elif lora_type == "adaptive_fro": + S_squared = S.pow(2) + S_fro_sq = float(torch.sum(S_squared)) + sum_S_squared = torch.cumsum(S_squared, dim=0) / S_fro_sq + lora_rank = int(torch.searchsorted(sum_S_squared, adaptive_param**2)) + 1 + lora_rank = max(1, min(lora_rank, len(S))) + else: + pass # Will print after capping + + # Cap adaptive rank by the specified max rank + lora_rank = min(lora_rank, rank) + + # Calculate and print actual fro percentage retained after capping + if lora_type == "adaptive_fro": + S_squared = S.pow(2) + s_fro = torch.sqrt(torch.sum(S_squared)) + s_red_fro = torch.sqrt(torch.sum(S_squared[:lora_rank])) + fro_percent = float(s_red_fro / s_fro) + print(f"{key} Extracted LoRA rank: {lora_rank}, Frobenius retained: {fro_percent:.1%}") + else: + print(f"{key} Extracted LoRA rank: {lora_rank}") else: lora_rank = rank @@ -141,13 +161,13 @@ class LoraExtractKJ: "finetuned_model": ("MODEL",), "original_model": ("MODEL",), "filename_prefix": ("STRING", {"default": "loras/ComfyUI_extracted_lora"}), - "rank": ("INT", {"default": 8, "min": 1, "max": 4096, "step": 1}), - "lora_type": (["standard", "full", "adaptive_ratio", "adaptive_quantile", "adaptive_energy"],), + "rank": ("INT", {"default": 8, "min": 1, "max": 4096, "step": 1, "tooltip": "The rank to use for standard LoRA, or maximum rank limit for adaptive methods."}), + "lora_type": (["standard", "full", "adaptive_ratio", "adaptive_quantile", "adaptive_energy", "adaptive_fro"],), "algorithm": (["svd_linalg", "svd_lowrank"], {"default": "svd_linalg", "tooltip": "SVD algorithm to use, svd_lowrank is faster but less accurate."}), "lowrank_iters": ("INT", {"default": 7, "min": 1, "max": 100, "step": 1, "tooltip": "The number of subspace iterations for lowrank SVD algorithm."}), "output_dtype": (["fp16", "bf16", "fp32"], {"default": "fp16"}), "bias_diff": ("BOOLEAN", {"default": True}), - "adaptive_param": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "For ratio mode, this is the ratio of the maximum singular value. For quantile mode, this is the quantile of the singular values."}), + "adaptive_param": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "For ratio mode, this is the ratio of the maximum singular value. For quantile mode, this is the quantile of the singular values. For fro mode, this is the Frobenius norm retention ratio."}), "clamp_quantile": ("BOOLEAN", {"default": True}), }, @@ -520,7 +540,7 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn fro_retained = param_dict["fro_retained"] if not np.isnan(fro_retained): fro_list.append(float(fro_retained)) - log_str = f"{block_down_name:75} | sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}" + log_str = f"{block_down_name:75} | sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}, new dim: {param_dict['new_rank']}" tqdm.write(log_str) verbose_str += log_str