Update model_loading.py

kijai 2024-11-21 03:05:51 +02:00
parent 276a045a57
commit 895d3b83a4


@@ -717,35 +717,40 @@ class CogVideoXModelLoader:
         #LoRAs
         if lora is not None:
-            from .lora_utils import merge_lora#, load_lora_into_transformer
-            if "fun" in model.lower():
-                for l in lora:
-                    log.info(f"Merging LoRA weights from {l['path']} with strength {l['strength']}")
-                    transformer = merge_lora(transformer, l["path"], l["strength"])
-            else:
-                adapter_list = []
-                adapter_weights = []
-                for l in lora:
-                    fuse = True if l["fuse_lora"] else False
-                    lora_sd = load_torch_file(l["path"])
-                    for key, val in lora_sd.items():
-                        if "lora_B" in key:
-                            lora_rank = val.shape[1]
-                            break
+            dimensionx_loras = ["orbit", "dimensionx"] # for now dimensionx loras need scaling
+            dimensionx_lora = False
+            adapter_list = []
+            adapter_weights = []
+            for l in lora:
+                if any(item in l["path"].lower() for item in dimensionx_loras):
+                    dimensionx_lora = True
+                fuse = True if l["fuse_lora"] else False
+                lora_sd = load_torch_file(l["path"])
+                lora_rank = None
+                for key, val in lora_sd.items():
+                    if "lora_B" in key:
+                        lora_rank = val.shape[1]
+                        break
+                if lora_rank is not None:
                     log.info(f"Merging rank {lora_rank} LoRA weights from {l['path']} with strength {l['strength']}")
                     adapter_name = l['path'].split("/")[-1].split(".")[0]
                     adapter_weight = l['strength']
                     pipe.load_lora_weights(l['path'], weight_name=l['path'].split("/")[-1], lora_rank=lora_rank, adapter_name=adapter_name)
-                    #transformer = load_lora_into_transformer(lora, transformer)
                     adapter_list.append(adapter_name)
                     adapter_weights.append(adapter_weight)
-                for l in lora:
-                    pipe.set_adapters(adapter_list, adapter_weights=adapter_weights)
-                    if fuse:
-                        lora_scale = 1
-                        dimension_loras = ["orbit", "dimensionx"] # for now dimensionx loras need scaling
-                        if any(item in lora[-1]["path"].lower() for item in dimension_loras):
-                            lora_scale = lora_scale / lora_rank
-                        pipe.fuse_lora(lora_scale=lora_scale, components=["transformer"])
+                else:
+                    try: #Fun trainer LoRAs are loaded differently
+                        from .lora_utils import merge_lora
+                        log.info(f"Merging LoRA weights from {l['path']} with strength {l['strength']}")
+                        pipe.transformer = merge_lora(pipe.transformer, l["path"], l["strength"], device=transformer_load_device, state_dict=lora_sd)
+                    except:
+                        raise ValueError(f"Can't recognize LoRA {l['path']}")
+            if adapter_list:
+                pipe.set_adapters(adapter_list, adapter_weights=adapter_weights)
+                if fuse:
+                    lora_scale = 1
+                    if dimensionx_lora:
+                        lora_scale = lora_scale / lora_rank
+                    pipe.fuse_lora(lora_scale=lora_scale, components=["transformer"])
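
For reference, the rank detection and fuse-scale logic introduced in this hunk can be exercised in isolation. The sketch below is a standalone approximation and not part of the commit: it fabricates a tiny diffusers-style LoRA state dict instead of calling load_torch_file, and reimplements the shape[1] rank lookup plus the 1/rank scaling applied when a DimensionX/orbit LoRA is detected.

# Standalone sketch of the rank-detection and fuse-scale logic above.
# The state dict here is fabricated for illustration; in the node it comes
# from load_torch_file(l["path"]).
import torch

def detect_lora_rank(lora_sd):
    """Return the rank from the first lora_B weight (its shape[1]), or None."""
    for key, val in lora_sd.items():
        if "lora_B" in key:
            return val.shape[1]
    return None

def fuse_scale(lora_rank, dimensionx_lora):
    """DimensionX/orbit LoRAs get their fuse scale divided by the rank."""
    lora_scale = 1
    if dimensionx_lora and lora_rank is not None:
        lora_scale = lora_scale / lora_rank
    return lora_scale

# Fabricated checkpoint: lora_A is (rank, in_features), lora_B is (out_features, rank).
fake_sd = {
    "transformer.blocks.0.attn1.to_q.lora_A.weight": torch.zeros(64, 3072),
    "transformer.blocks.0.attn1.to_q.lora_B.weight": torch.zeros(3072, 64),
}

rank = detect_lora_rank(fake_sd)
print(rank)                                      # 64
print(fuse_scale(rank, dimensionx_lora=True))    # 0.015625 (1/64)
print(fuse_scale(rank, dimensionx_lora=False))   # 1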