diff --git a/nodes/lora_nodes.py b/nodes/lora_nodes.py index fe9959c..0ec86d6 100644 --- a/nodes/lora_nodes.py +++ b/nodes/lora_nodes.py @@ -10,7 +10,7 @@ device = comfy.model_management.get_torch_device() CLAMP_QUANTILE = 0.99 -def extract_lora(diff, key, rank, algorithm, lora_type, lowrank_iters=7, adaptive_param=1.0): +def extract_lora(diff, key, rank, algorithm, lora_type, lowrank_iters=7, adaptive_param=1.0, clamp_quantile=True): """ Extracts LoRA weights from a weight difference tensor using SVD. """ @@ -59,25 +59,26 @@ def extract_lora(diff, key, rank, algorithm, lora_type, lowrank_iters=7, adaptiv U = U @ torch.diag(S) Vh = Vh[:lora_rank, :] - dist = torch.cat([U.flatten(), Vh.flatten()]) - if dist.numel() > 100_000: - # Sample 100,000 elements for quantile estimation - idx = torch.randperm(dist.numel(), device=dist.device)[:100_000] - dist_sample = dist[idx] - hi_val = torch.quantile(dist_sample, CLAMP_QUANTILE) - else: - hi_val = torch.quantile(dist, CLAMP_QUANTILE) - low_val = -hi_val + if clamp_quantile: + dist = torch.cat([U.flatten(), Vh.flatten()]) + if dist.numel() > 100_000: + # Sample 100,000 elements for quantile estimation + idx = torch.randperm(dist.numel(), device=dist.device)[:100_000] + dist_sample = dist[idx] + hi_val = torch.quantile(dist_sample, CLAMP_QUANTILE) + else: + hi_val = torch.quantile(dist, CLAMP_QUANTILE) + low_val = -hi_val - U = U.clamp(low_val, hi_val) - Vh = Vh.clamp(low_val, hi_val) + U = U.clamp(low_val, hi_val) + Vh = Vh.clamp(low_val, hi_val) if conv2d: U = U.reshape(out_dim, lora_rank, 1, 1) Vh = Vh.reshape(lora_rank, in_dim, kernel_size[0], kernel_size[1]) return (U, Vh) -def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora_type, algorithm, lowrank_iters, out_dtype, bias_diff=False, adaptive_param=1.0): +def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora_type, algorithm, lowrank_iters, out_dtype, bias_diff=False, adaptive_param=1.0, clamp_quantile=True): comfy.model_management.load_models_gpu([model_diff], force_patch_weights=True) model_diff.model.diffusion_model.cpu() sd = model_diff.model_state_dict(filter_prefix=prefix_model) @@ -110,7 +111,7 @@ def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora comfy_pbar.update(1) continue try: - out = extract_lora(weight_diff.to(device), k, rank, algorithm, lora_type, lowrank_iters=lowrank_iters, adaptive_param=adaptive_param) + out = extract_lora(weight_diff.to(device), k, rank, algorithm, lora_type, lowrank_iters=lowrank_iters, adaptive_param=adaptive_param, clamp_quantile=clamp_quantile) output_sd["{}{}.lora_up.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[0].contiguous().to(out_dtype).cpu() output_sd["{}{}.lora_down.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[1].contiguous().to(out_dtype).cpu() except Exception as e: @@ -146,6 +147,7 @@ class LoraExtractKJ: "output_dtype": (["fp16", "bf16", "fp32"], {"default": "fp16"}), "bias_diff": ("BOOLEAN", {"default": True}), "adaptive_param": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "For ratio mode, this is the ratio of the maximum singular value. For quantile mode, this is the quantile of the singular values."}), + "clamp_quantile": ("BOOLEAN", {"default": True}), }, } @@ -155,7 +157,7 @@ class LoraExtractKJ: CATEGORY = "KJNodes/lora" - def save(self, finetuned_model, original_model, filename_prefix, rank, lora_type, algorithm, lowrank_iters, output_dtype, bias_diff, adaptive_param): + def save(self, finetuned_model, original_model, filename_prefix, rank, lora_type, algorithm, lowrank_iters, output_dtype, bias_diff, adaptive_param, clamp_quantile): if algorithm == "svd_lowrank" and lora_type != "standard": raise ValueError("svd_lowrank algorithm is only supported for standard LoRA extraction.") @@ -170,7 +172,7 @@ class LoraExtractKJ: output_sd = {} if model_diff is not None: - output_sd = calc_lora_model(model_diff, rank, "diffusion_model.", "diffusion_model.", output_sd, lora_type, algorithm, lowrank_iters, dtype, bias_diff=bias_diff, adaptive_param=adaptive_param) + output_sd = calc_lora_model(model_diff, rank, "diffusion_model.", "diffusion_model.", output_sd, lora_type, algorithm, lowrank_iters, dtype, bias_diff=bias_diff, adaptive_param=adaptive_param, clamp_quantile=clamp_quantile) if "adaptive" in lora_type: rank_str = f"{lora_type}_{adaptive_param:.2f}" else: