diff --git a/lora_utils.py b/lora_utils.py
index acf8b19..a06c7e3 100644
--- a/lora_utils.py
+++ b/lora_utils.py
@@ -406,8 +406,8 @@ def merge_lora(transformer, lora_path, multiplier, device='cpu', dtype=torch.flo
             else:
                 temp_name = layer_infos.pop(0)
 
-        weight_up = elems['lora_up.weight'].to(dtype)
-        weight_down = elems['lora_down.weight'].to(dtype)
+        weight_up = elems['lora_up.weight'].to(dtype).to(device)
+        weight_down = elems['lora_down.weight'].to(dtype).to(device)
         if 'alpha' in elems.keys():
             alpha = elems['alpha'].item() / weight_up.shape[1]
         else:
diff --git a/model_loading.py b/model_loading.py
index c2883ff..dfa6d43 100644
--- a/model_loading.py
+++ b/model_loading.py
@@ -241,51 +241,51 @@ class DownloadAndLoadCogVideoModel:
 
         #LoRAs
         if lora is not None:
-            dimensionx_loras = ["orbit", "dimensionx"] # for now dimensionx loras need scaling
-            dimensionx_lora = False
-            adapter_list = []
-            adapter_weights = []
-            for l in lora:
-                if any(item in l["path"].lower() for item in dimensionx_loras):
-                    dimensionx_lora = True
-                fuse = True if l["fuse_lora"] else False
-                lora_sd = load_torch_file(l["path"])
-                lora_rank = None
-                for key, val in lora_sd.items():
-                    if "lora_B" in key:
-                        lora_rank = val.shape[1]
-                        break
-                if lora_rank is not None:
-                    log.info(f"Merging rank {lora_rank} LoRA weights from {l['path']} with strength {l['strength']}")
-                    adapter_name = l['path'].split("/")[-1].split(".")[0]
-                    adapter_weight = l['strength']
-                    pipe.load_lora_weights(l['path'], weight_name=l['path'].split("/")[-1], lora_rank=lora_rank, adapter_name=adapter_name)
-
-                    adapter_list.append(adapter_name)
-                    adapter_weights.append(adapter_weight)
-                else:
-                    try: #Fun trainer LoRAs are loaded differently
-                        from .lora_utils import merge_lora
-                        log.info(f"Merging LoRA weights from {l['path']} with strength {l['strength']}")
-                        transformer = merge_lora(transformer, l["path"], l["strength"])
-                    except:
-                        raise ValueError(f"Can't recognize LoRA {l['path']}")
-            if adapter_list:
-                pipe.set_adapters(adapter_list, adapter_weights=adapter_weights)
-                if fuse:
-                    lora_scale = 1
-                    if dimensionx_lora:
-                        lora_scale = lora_scale / lora_rank
-                    pipe.fuse_lora(lora_scale=lora_scale, components=["transformer"])
+            dimensionx_loras = ["orbit", "dimensionx"] # for now dimensionx loras need scaling
+            dimensionx_lora = False
+            adapter_list = []
+            adapter_weights = []
+            for l in lora:
+                if any(item in l["path"].lower() for item in dimensionx_loras):
+                    dimensionx_lora = True
+                fuse = True if l["fuse_lora"] else False
+                lora_sd = load_torch_file(l["path"])
+                lora_rank = None
+                for key, val in lora_sd.items():
+                    if "lora_B" in key:
+                        lora_rank = val.shape[1]
+                        break
+                if lora_rank is not None:
+                    log.info(f"Merging rank {lora_rank} LoRA weights from {l['path']} with strength {l['strength']}")
+                    adapter_name = l['path'].split("/")[-1].split(".")[0]
+                    adapter_weight = l['strength']
+                    pipe.load_lora_weights(l['path'], weight_name=l['path'].split("/")[-1], lora_rank=lora_rank, adapter_name=adapter_name)
+
+                    adapter_list.append(adapter_name)
+                    adapter_weights.append(adapter_weight)
+                else:
+                    try: #Fun trainer LoRAs are loaded differently
+                        from .lora_utils import merge_lora
+                        log.info(f"Merging LoRA weights from {l['path']} with strength {l['strength']}")
+                        pipe.transformer = merge_lora(pipe.transformer, l["path"], l["strength"], device=transformer_load_device, state_dict=lora_sd)
+                    except:
+                        raise ValueError(f"Can't recognize LoRA {l['path']}")
+            if adapter_list:
+                pipe.set_adapters(adapter_list, adapter_weights=adapter_weights)
+                if fuse:
+                    lora_scale = 1
+                    if dimensionx_lora:
+                        lora_scale = lora_scale / lora_rank
+                    pipe.fuse_lora(lora_scale=lora_scale, components=["transformer"])
 
         if "fused" in attention_mode:
             from diffusers.models.attention import Attention
-            transformer.fuse_qkv_projections = True
-            for module in transformer.modules():
+            pipe.transformer.fuse_qkv_projections = True
+            for module in pipe.transformer.modules():
                 if isinstance(module, Attention):
                     module.fuse_projections(fuse=True)
-            transformer.attention_mode = attention_mode
+            pipe.transformer.attention_mode = attention_mode
 
         if compile_args is not None:
             pipe.transformer.to(memory_format=torch.channels_last)
diff --git a/nodes.py b/nodes.py
index c9cc493..853b254 100644
--- a/nodes.py
+++ b/nodes.py
@@ -602,6 +602,7 @@ class CogVideoSampler:
 
     def process(self, model, positive, negative, steps, cfg, seed, scheduler, num_frames, samples=None, denoise_strength=1.0, image_cond_latents=None, context_options=None, controlnet=None, tora_trajectory=None, fastercache=None):
+        mm.unload_all_models()
         mm.soft_empty_cache()
 
        model_name = model.get("model_name", "")
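
Note on the lora_utils.py hunk: the fix casts the LoRA up/down matrices to the merge device in addition to the dtype, so merging against a transformer that already lives on GPU no longer fails with a device mismatch. The sketch below illustrates the underlying merge math under the standard LoRA composition W' = W + multiplier * alpha * (up @ down); merge_lora_delta and its arguments are illustrative stand-ins, not the signature of the repo's merge_lora:

    import torch

    def merge_lora_delta(base_weight, weight_up, weight_down, multiplier,
                         alpha, device="cpu", dtype=torch.float32):
        # Cast to the compute dtype first, then move to the merge device;
        # a CPU LoRA tensor matmul'd against a CUDA base weight raises a
        # device-mismatch RuntimeError, which the added .to(device) avoids.
        weight_up = weight_up.to(dtype).to(device)
        weight_down = weight_down.to(dtype).to(device)
        # Standard LoRA composition: W' = W + multiplier * alpha * (up @ down)
        return base_weight.to(device=device, dtype=dtype) + multiplier * alpha * (weight_up @ weight_down)

    # Usage: a rank-4 LoRA on a 16x16 linear weight at full strength, with
    # alpha normalized by rank as lora_utils.py does (alpha / weight_up.shape[1]).
    w = torch.randn(16, 16)
    up, down = torch.randn(16, 4), torch.randn(4, 16)
    merged = merge_lora_delta(w, up, down, multiplier=1.0, alpha=8 / up.shape[1])

Two related points from the other hunks: in model_loading.py, lora_scale is divided by lora_rank only for DimensionX-style LoRAs ("orbit"/"dimensionx" in the filename), per the in-line comment that those need scaling before fuse_lora; and the nodes.py hunk adds mm.unload_all_models() ahead of the existing mm.soft_empty_cache() call, so ComfyUI evicts resident models before sampling instead of only clearing cached allocator memory.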