Allow disabling clamping on lora extract
This commit is contained in:
parent e2ce0843d1
commit ba9153cb06
@@ -10,7 +10,7 @@ device = comfy.model_management.get_torch_device()
 
 CLAMP_QUANTILE = 0.99
 
-def extract_lora(diff, key, rank, algorithm, lora_type, lowrank_iters=7, adaptive_param=1.0):
+def extract_lora(diff, key, rank, algorithm, lora_type, lowrank_iters=7, adaptive_param=1.0, clamp_quantile=True):
     """
     Extracts LoRA weights from a weight difference tensor using SVD.
     """
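For orientation: extract_lora factorizes the weight difference with SVD and keeps the top rank singular directions. A minimal sketch of the standard-SVD path, for illustration only (the real function also handles the svd_lowrank algorithm and the adaptive lora_types, omitted here):

# Illustrative sketch; not the repo's full implementation.
import torch

def extract_lora_sketch(diff: torch.Tensor, rank: int):
    # diff: (out_dim, in_dim) weight difference, finetuned minus original
    U, S, Vh = torch.linalg.svd(diff.float(), full_matrices=False)
    U, S, Vh = U[:, :rank], S[:rank], Vh[:rank, :]
    U = U @ torch.diag(S)  # fold singular values into the up-projection
    return U, Vh  # lora_up (out_dim, rank), lora_down (rank, in_dim); U @ Vh approximates diff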
@@ -59,25 +59,26 @@ def extract_lora(diff, key, rank, algorithm, lora_type, lowrank_iters=7, adaptiv
         U = U @ torch.diag(S)
         Vh = Vh[:lora_rank, :]
 
-    dist = torch.cat([U.flatten(), Vh.flatten()])
-    if dist.numel() > 100_000:
-        # Sample 100,000 elements for quantile estimation
-        idx = torch.randperm(dist.numel(), device=dist.device)[:100_000]
-        dist_sample = dist[idx]
-        hi_val = torch.quantile(dist_sample, CLAMP_QUANTILE)
-    else:
-        hi_val = torch.quantile(dist, CLAMP_QUANTILE)
-    low_val = -hi_val
-
-    U = U.clamp(low_val, hi_val)
-    Vh = Vh.clamp(low_val, hi_val)
+    if clamp_quantile:
+        dist = torch.cat([U.flatten(), Vh.flatten()])
+        if dist.numel() > 100_000:
+            # Sample 100,000 elements for quantile estimation
+            idx = torch.randperm(dist.numel(), device=dist.device)[:100_000]
+            dist_sample = dist[idx]
+            hi_val = torch.quantile(dist_sample, CLAMP_QUANTILE)
+        else:
+            hi_val = torch.quantile(dist, CLAMP_QUANTILE)
+        low_val = -hi_val
+
+        U = U.clamp(low_val, hi_val)
+        Vh = Vh.clamp(low_val, hi_val)
     if conv2d:
         U = U.reshape(out_dim, lora_rank, 1, 1)
         Vh = Vh.reshape(lora_rank, in_dim, kernel_size[0], kernel_size[1])
     return (U, Vh)
 
 
-def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora_type, algorithm, lowrank_iters, out_dtype, bias_diff=False, adaptive_param=1.0):
+def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora_type, algorithm, lowrank_iters, out_dtype, bias_diff=False, adaptive_param=1.0, clamp_quantile=True):
     comfy.model_management.load_models_gpu([model_diff], force_patch_weights=True)
     model_diff.model.diffusion_model.cpu()
     sd = model_diff.model_state_dict(filter_prefix=prefix_model)
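This hunk wraps the quantile clamp in the new flag. The clamp cuts both factors to the symmetric 0.99 quantile of their combined values, taming outlier weights at the cost of some reconstruction fidelity; sampling 100,000 elements keeps the estimate cheap on large tensors (torch.quantile also rejects very large inputs, above 2^24 elements, on recent PyTorch versions). A standalone sketch of the now-optional logic:

# Standalone sketch of the quantile clamp this commit makes optional.
import torch

CLAMP_QUANTILE = 0.99

def clamp_by_quantile(U, Vh):
    dist = torch.cat([U.flatten(), Vh.flatten()])
    if dist.numel() > 100_000:
        # Estimate the quantile from a random 100k sample instead of the full tensor
        idx = torch.randperm(dist.numel(), device=dist.device)[:100_000]
        dist = dist[idx]
    hi_val = torch.quantile(dist, CLAMP_QUANTILE)
    return U.clamp(-hi_val, hi_val), Vh.clamp(-hi_val, hi_val)

U, Vh = torch.randn(2048, 64), torch.randn(64, 2048)  # toy factors
U_c, Vh_c = clamp_by_quantile(U, Vh)
print(U_c.abs().max() <= U.abs().max())  # tensor(True): outliers were cut

With clamp_quantile=False the factors come back exactly as truncated SVD produced them.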
@@ -110,7 +111,7 @@ def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora
             comfy_pbar.update(1)
             continue
         try:
-            out = extract_lora(weight_diff.to(device), k, rank, algorithm, lora_type, lowrank_iters=lowrank_iters, adaptive_param=adaptive_param)
+            out = extract_lora(weight_diff.to(device), k, rank, algorithm, lora_type, lowrank_iters=lowrank_iters, adaptive_param=adaptive_param, clamp_quantile=clamp_quantile)
             output_sd["{}{}.lora_up.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[0].contiguous().to(out_dtype).cpu()
             output_sd["{}{}.lora_down.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[1].contiguous().to(out_dtype).cpu()
         except Exception as e:
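In the two output_sd lines, k[len(prefix_model):-7] strips the model prefix and the trailing ".weight" (7 characters) before rebuilding the LoRA key names. A toy illustration (the example key is made up):

# Toy illustration of the key renaming; the example key is hypothetical.
prefix_model = prefix_lora = "diffusion_model."
k = "diffusion_model.blocks.0.attn.qkv.weight"
base = k[len(prefix_model):-7]  # "blocks.0.attn.qkv"
print("{}{}.lora_up.weight".format(prefix_lora, base))    # diffusion_model.blocks.0.attn.qkv.lora_up.weight
print("{}{}.lora_down.weight".format(prefix_lora, base))  # diffusion_model.blocks.0.attn.qkv.lora_down.weight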
@@ -146,6 +147,7 @@ class LoraExtractKJ:
                 "output_dtype": (["fp16", "bf16", "fp32"], {"default": "fp16"}),
                 "bias_diff": ("BOOLEAN", {"default": True}),
                 "adaptive_param": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "For ratio mode, this is the ratio of the maximum singular value. For quantile mode, this is the quantile of the singular values."}),
+                "clamp_quantile": ("BOOLEAN", {"default": True}),
             },
 
         }
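The new "clamp_quantile" entry appears as a boolean toggle on the node, and ComfyUI passes each required input as a keyword argument to the method named by FUNCTION. An abridged sketch of that wiring (only the clamp_quantile line is verbatim from the diff; everything else in the real class is elided):

# Abridged sketch of the node wiring; most fields of the real class are omitted.
class LoraExtractKJ_Sketch:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "clamp_quantile": ("BOOLEAN", {"default": True}),  # the new toggle
            },
        }
    FUNCTION = "save"  # ComfyUI calls save(..., clamp_quantile=<toggle value>)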
@@ -155,7 +157,7 @@ class LoraExtractKJ:
 
     CATEGORY = "KJNodes/lora"
 
-    def save(self, finetuned_model, original_model, filename_prefix, rank, lora_type, algorithm, lowrank_iters, output_dtype, bias_diff, adaptive_param):
+    def save(self, finetuned_model, original_model, filename_prefix, rank, lora_type, algorithm, lowrank_iters, output_dtype, bias_diff, adaptive_param, clamp_quantile):
         if algorithm == "svd_lowrank" and lora_type != "standard":
             raise ValueError("svd_lowrank algorithm is only supported for standard LoRA extraction.")
 
@@ -170,7 +172,7 @@ class LoraExtractKJ:
 
         output_sd = {}
         if model_diff is not None:
-            output_sd = calc_lora_model(model_diff, rank, "diffusion_model.", "diffusion_model.", output_sd, lora_type, algorithm, lowrank_iters, dtype, bias_diff=bias_diff, adaptive_param=adaptive_param)
+            output_sd = calc_lora_model(model_diff, rank, "diffusion_model.", "diffusion_model.", output_sd, lora_type, algorithm, lowrank_iters, dtype, bias_diff=bias_diff, adaptive_param=adaptive_param, clamp_quantile=clamp_quantile)
         if "adaptive" in lora_type:
             rank_str = f"{lora_type}_{adaptive_param:.2f}"
         else:
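Taken together, the flag threads node input -> save -> calc_lora_model -> extract_lora. What it changes numerically: clamping can only move the factors away from the truncated-SVD optimum (Eckart-Young), so the unclamped extraction reconstructs the weight difference at least as closely. A hedged sketch:

# Sketch: reconstruction error with and without the 0.99-quantile clamp.
import torch

diff = torch.randn(1024, 1024)  # stand-in for a real weight difference
U, S, Vh = torch.linalg.svd(diff, full_matrices=False)
rank = 32
U, Vh = U[:, :rank] @ torch.diag(S[:rank]), Vh[:rank, :]

err_raw = (diff - U @ Vh).norm()  # optimal rank-32 Frobenius error

hi = torch.quantile(torch.cat([U.flatten(), Vh.flatten()]), 0.99)
err_clamped = (diff - U.clamp(-hi, hi) @ Vh.clamp(-hi, hi)).norm()

print(err_raw <= err_clamped)  # True: clamping trades fidelity for bounded weights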
|