Allow disabling clamping on lora extract

This commit is contained in:
kijai 2025-08-23 16:21:30 +03:00
parent e2ce0843d1
commit ba9153cb06

View File

@ -10,7 +10,7 @@ device = comfy.model_management.get_torch_device()
CLAMP_QUANTILE = 0.99
def extract_lora(diff, key, rank, algorithm, lora_type, lowrank_iters=7, adaptive_param=1.0):
def extract_lora(diff, key, rank, algorithm, lora_type, lowrank_iters=7, adaptive_param=1.0, clamp_quantile=True):
"""
Extracts LoRA weights from a weight difference tensor using SVD.
"""
@ -59,25 +59,26 @@ def extract_lora(diff, key, rank, algorithm, lora_type, lowrank_iters=7, adaptiv
U = U @ torch.diag(S)
Vh = Vh[:lora_rank, :]
dist = torch.cat([U.flatten(), Vh.flatten()])
if dist.numel() > 100_000:
# Sample 100,000 elements for quantile estimation
idx = torch.randperm(dist.numel(), device=dist.device)[:100_000]
dist_sample = dist[idx]
hi_val = torch.quantile(dist_sample, CLAMP_QUANTILE)
else:
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
low_val = -hi_val
if clamp_quantile:
dist = torch.cat([U.flatten(), Vh.flatten()])
if dist.numel() > 100_000:
# Sample 100,000 elements for quantile estimation
idx = torch.randperm(dist.numel(), device=dist.device)[:100_000]
dist_sample = dist[idx]
hi_val = torch.quantile(dist_sample, CLAMP_QUANTILE)
else:
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
low_val = -hi_val
U = U.clamp(low_val, hi_val)
Vh = Vh.clamp(low_val, hi_val)
U = U.clamp(low_val, hi_val)
Vh = Vh.clamp(low_val, hi_val)
if conv2d:
U = U.reshape(out_dim, lora_rank, 1, 1)
Vh = Vh.reshape(lora_rank, in_dim, kernel_size[0], kernel_size[1])
return (U, Vh)
def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora_type, algorithm, lowrank_iters, out_dtype, bias_diff=False, adaptive_param=1.0):
def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora_type, algorithm, lowrank_iters, out_dtype, bias_diff=False, adaptive_param=1.0, clamp_quantile=True):
comfy.model_management.load_models_gpu([model_diff], force_patch_weights=True)
model_diff.model.diffusion_model.cpu()
sd = model_diff.model_state_dict(filter_prefix=prefix_model)
@ -110,7 +111,7 @@ def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora
comfy_pbar.update(1)
continue
try:
out = extract_lora(weight_diff.to(device), k, rank, algorithm, lora_type, lowrank_iters=lowrank_iters, adaptive_param=adaptive_param)
out = extract_lora(weight_diff.to(device), k, rank, algorithm, lora_type, lowrank_iters=lowrank_iters, adaptive_param=adaptive_param, clamp_quantile=clamp_quantile)
output_sd["{}{}.lora_up.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[0].contiguous().to(out_dtype).cpu()
output_sd["{}{}.lora_down.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[1].contiguous().to(out_dtype).cpu()
except Exception as e:
@ -146,6 +147,7 @@ class LoraExtractKJ:
"output_dtype": (["fp16", "bf16", "fp32"], {"default": "fp16"}),
"bias_diff": ("BOOLEAN", {"default": True}),
"adaptive_param": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "For ratio mode, this is the ratio of the maximum singular value. For quantile mode, this is the quantile of the singular values."}),
"clamp_quantile": ("BOOLEAN", {"default": True}),
},
}
@ -155,7 +157,7 @@ class LoraExtractKJ:
CATEGORY = "KJNodes/lora"
def save(self, finetuned_model, original_model, filename_prefix, rank, lora_type, algorithm, lowrank_iters, output_dtype, bias_diff, adaptive_param):
def save(self, finetuned_model, original_model, filename_prefix, rank, lora_type, algorithm, lowrank_iters, output_dtype, bias_diff, adaptive_param, clamp_quantile):
if algorithm == "svd_lowrank" and lora_type != "standard":
raise ValueError("svd_lowrank algorithm is only supported for standard LoRA extraction.")
@ -170,7 +172,7 @@ class LoraExtractKJ:
output_sd = {}
if model_diff is not None:
output_sd = calc_lora_model(model_diff, rank, "diffusion_model.", "diffusion_model.", output_sd, lora_type, algorithm, lowrank_iters, dtype, bias_diff=bias_diff, adaptive_param=adaptive_param)
output_sd = calc_lora_model(model_diff, rank, "diffusion_model.", "diffusion_model.", output_sd, lora_type, algorithm, lowrank_iters, dtype, bias_diff=bias_diff, adaptive_param=adaptive_param, clamp_quantile=clamp_quantile)
if "adaptive" in lora_type:
rank_str = f"{lora_type}_{adaptive_param:.2f}"
else: