Allow loading the "Rewards" LoRAs into 1.5 as well (for what it's worth)

kijai 2024-11-20 19:18:40 +02:00
parent 573150de28
commit e187cfe22f
2 changed files with 11 additions and 8 deletions

@@ -414,12 +414,15 @@ def merge_lora(transformer, lora_path, multiplier, device='cpu', dtype=torch.flo
             alpha = 1.0
         curr_layer.weight.data = curr_layer.weight.data.to(device)
-        if len(weight_up.shape) == 4:
-            curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up.squeeze(3).squeeze(2),
-                                                                    weight_down.squeeze(3).squeeze(2)).unsqueeze(
-                2).unsqueeze(3)
-        else:
-            curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up, weight_down)
+        try:
+            if len(weight_up.shape) == 4:
+                curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up.squeeze(3).squeeze(2),
+                                                                        weight_down.squeeze(3).squeeze(2)).unsqueeze(
+                    2).unsqueeze(3)
+            else:
+                curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up, weight_down)
+        except:
+            print(f"Could not apply LoRA weight in layer {layer}")
     return transformer
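
For reference, the arithmetic being guarded above is the standard LoRA merge: the base weight is updated by multiplier * alpha * (up @ down), with a squeeze/unsqueeze round-trip for 1x1 conv weights. The sketch below is a minimal out-of-place version of that update; apply_lora_delta is a hypothetical helper for illustration, not a function from this repo. A LoRA built for a different base model (the "Rewards" LoRAs applied to 1.5, per the commit message) can contain layers whose shapes don't match the target, and the new try/except turns that per-layer failure into a warning instead of aborting the whole merge.

import torch

def apply_lora_delta(weight, weight_up, weight_down, multiplier=1.0, alpha=1.0):
    # Hypothetical helper mirroring the merge above: delta W = up @ down,
    # scaled by multiplier * alpha.
    if len(weight_up.shape) == 4:
        # 1x1 conv LoRA weights carry trailing spatial dims; drop them for
        # the matmul, then restore them so the delta matches the conv weight.
        delta = torch.mm(weight_up.squeeze(3).squeeze(2),
                         weight_down.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
    else:
        delta = torch.mm(weight_up, weight_down)
    return weight + multiplier * alpha * delta

# Matching shapes merge cleanly; mismatched shapes raise, which the commit
# now catches per layer rather than letting the whole merge fail.
w = torch.zeros(320, 64)
up, down = torch.randn(320, 8), torch.randn(8, 64)
w = apply_lora_delta(w, up, down, multiplier=0.8)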

@@ -589,7 +589,7 @@ class CogVideoXModelLoader:
     def INPUT_TYPES(s):
         return {
             "required": {
-                "model": (folder_paths.get_filename_list("diffusion_models"), {"tooltip": "The name of the checkpoint (model) to load.",}),
+                "model": (folder_paths.get_filename_list("diffusion_models"), {"tooltip": "These models are loaded from the 'ComfyUI/models/diffusion_models' -folder",}),
                 "base_precision": (["fp16", "fp32", "bf16"], {"default": "bf16"}),
                 "quantization": (['disabled', 'fp8_e4m3fn', 'fp8_e4m3fn_fast', 'torchao_fp8dq', "torchao_fp8dqrow", "torchao_int8dq", "torchao_fp6"], {"default": 'disabled', "tooltip": "optional quantization method"}),
@@ -821,7 +821,7 @@ class CogVideoXVAELoader:
     def INPUT_TYPES(s):
         return {
             "required": {
-                "model_name": (folder_paths.get_filename_list("vae"), {"tooltip": "The name of the checkpoint (vae) to load."}),
+                "model_name": (folder_paths.get_filename_list("vae"), {"tooltip": "These models are loaded from 'ComfyUI/models/vae'"}),
             },
             "optional": {
                 "precision": (["fp16", "fp32", "bf16"],