diff --git a/lora_utils.py b/lora_utils.py index 4733082..acf8b19 100644 --- a/lora_utils.py +++ b/lora_utils.py @@ -414,12 +414,15 @@ def merge_lora(transformer, lora_path, multiplier, device='cpu', dtype=torch.flo alpha = 1.0 curr_layer.weight.data = curr_layer.weight.data.to(device) - if len(weight_up.shape) == 4: - curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up.squeeze(3).squeeze(2), - weight_down.squeeze(3).squeeze(2)).unsqueeze( - 2).unsqueeze(3) - else: - curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up, weight_down) + try: + if len(weight_up.shape) == 4: + curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up.squeeze(3).squeeze(2), + weight_down.squeeze(3).squeeze(2)).unsqueeze( + 2).unsqueeze(3) + else: + curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up, weight_down) + except: + print(f"Could not apply LoRA weight in layer {layer}") return transformer diff --git a/model_loading.py b/model_loading.py index 805b087..7a08057 100644 --- a/model_loading.py +++ b/model_loading.py @@ -589,7 +589,7 @@ class CogVideoXModelLoader: def INPUT_TYPES(s): return { "required": { - "model": (folder_paths.get_filename_list("diffusion_models"), {"tooltip": "The name of the checkpoint (model) to load.",}), + "model": (folder_paths.get_filename_list("diffusion_models"), {"tooltip": "These models are loaded from the 'ComfyUI/models/diffusion_models' -folder",}), "base_precision": (["fp16", "fp32", "bf16"], {"default": "bf16"}), "quantization": (['disabled', 'fp8_e4m3fn', 'fp8_e4m3fn_fast', 'torchao_fp8dq', "torchao_fp8dqrow", "torchao_int8dq", "torchao_fp6"], {"default": 'disabled', "tooltip": "optional quantization method"}), @@ -821,7 +821,7 @@ class CogVideoXVAELoader: def INPUT_TYPES(s): return { "required": { - "model_name": (folder_paths.get_filename_list("vae"), {"tooltip": "The name of the checkpoint (vae) to load."}), + "model_name": (folder_paths.get_filename_list("vae"), {"tooltip": "These models are loaded from 'ComfyUI/models/vae'"}), }, "optional": { "precision": (["fp16", "fp32", "bf16"],