Allow orbit LoRAs with Fun models as well

This commit is contained in:
kijai 2024-11-19 15:49:43 +02:00
parent f606d745e9
commit 6302e4b668
2 changed files with 40 additions and 45 deletions

View File

@ -84,7 +84,7 @@
}, },
"widgets_values": [ "widgets_values": [
49, 49,
50, 25,
6, 6,
458091243358272, 458091243358272,
"randomize", "randomize",
@ -268,7 +268,7 @@
}, },
"widgets_values": [ "widgets_values": [
49, 49,
false, true,
0 0
] ]
}, },

View File

@ -240,37 +240,37 @@ class DownloadAndLoadCogVideoModel:
#LoRAs #LoRAs
if lora is not None: if lora is not None:
from .lora_utils import merge_lora#, load_lora_into_transformer # from .lora_utils import merge_lora#, load_lora_into_transformer
if "fun" in model.lower(): # if "fun" in model.lower():
for l in lora: # for l in lora:
log.info(f"Merging LoRA weights from {l['path']} with strength {l['strength']}") # log.info(f"Merging LoRA weights from {l['path']} with strength {l['strength']}")
transformer = merge_lora(transformer, l["path"], l["strength"]) # transformer = merge_lora(transformer, l["path"], l["strength"])
else: #else:
adapter_list = [] adapter_list = []
adapter_weights = [] adapter_weights = []
for l in lora: for l in lora:
fuse = True if l["fuse_lora"] else False fuse = True if l["fuse_lora"] else False
lora_sd = load_torch_file(l["path"]) lora_sd = load_torch_file(l["path"])
for key, val in lora_sd.items(): for key, val in lora_sd.items():
if "lora_B" in key: if "lora_B" in key:
lora_rank = val.shape[1] lora_rank = val.shape[1]
break break
log.info(f"Merging rank {lora_rank} LoRA weights from {l['path']} with strength {l['strength']}") log.info(f"Merging rank {lora_rank} LoRA weights from {l['path']} with strength {l['strength']}")
adapter_name = l['path'].split("/")[-1].split(".")[0] adapter_name = l['path'].split("/")[-1].split(".")[0]
adapter_weight = l['strength'] adapter_weight = l['strength']
pipe.load_lora_weights(l['path'], weight_name=l['path'].split("/")[-1], lora_rank=lora_rank, adapter_name=adapter_name) pipe.load_lora_weights(l['path'], weight_name=l['path'].split("/")[-1], lora_rank=lora_rank, adapter_name=adapter_name)
#transformer = load_lora_into_transformer(lora, transformer) #transformer = load_lora_into_transformer(lora, transformer)
adapter_list.append(adapter_name) adapter_list.append(adapter_name)
adapter_weights.append(adapter_weight) adapter_weights.append(adapter_weight)
for l in lora: for l in lora:
pipe.set_adapters(adapter_list, adapter_weights=adapter_weights) pipe.set_adapters(adapter_list, adapter_weights=adapter_weights)
if fuse: if fuse:
lora_scale = 1 lora_scale = 1
dimension_loras = ["orbit", "dimensionx"] # for now dimensionx loras need scaling dimension_loras = ["orbit", "dimensionx"] # for now dimensionx loras need scaling
if any(item in lora[-1]["path"].lower() for item in dimension_loras): if any(item in lora[-1]["path"].lower() for item in dimension_loras):
lora_scale = lora_scale / lora_rank lora_scale = lora_scale / lora_rank
pipe.fuse_lora(lora_scale=lora_scale, components=["transformer"]) pipe.fuse_lora(lora_scale=lora_scale, components=["transformer"])
if "fused" in attention_mode: if "fused" in attention_mode:
from diffusers.models.attention import Attention from diffusers.models.attention import Attention
@ -653,27 +653,22 @@ class CogVideoXModelLoader:
with open(transformer_config_path) as f: with open(transformer_config_path) as f:
transformer_config = json.load(f) transformer_config = json.load(f)
with init_empty_weights():
if model_type in ["I2V", "I2V_5b", "fun_5b_pose", "5b_I2V_1_5"]: if model_type in ["I2V", "I2V_5b", "fun_5b_pose", "5b_I2V_1_5"]:
transformer_config["in_channels"] = 32 transformer_config["in_channels"] = 32
if "1_5" in model_type: if "1_5" in model_type:
transformer_config["ofs_embed_dim"] = 512 transformer_config["ofs_embed_dim"] = 512
elif "fun" in model_type:
transformer_config["in_channels"] = 33
else:
transformer_config["in_channels"] = 16
if "1_5" in model_type:
transformer_config["use_learned_positional_embeddings"] = False transformer_config["use_learned_positional_embeddings"] = False
transformer_config["patch_size_t"] = 2 transformer_config["patch_size_t"] = 2
transformer_config["patch_bias"] = False transformer_config["patch_bias"] = False
transformer_config["sample_height"] = 300 transformer_config["sample_height"] = 300
transformer_config["sample_width"] = 300 transformer_config["sample_width"] = 300
elif "fun" in model_type:
transformer_config["in_channels"] = 33 with init_empty_weights():
else:
if "1_5" in model_type:
transformer_config["use_learned_positional_embeddings"] = False
transformer_config["patch_size_t"] = 2
transformer_config["patch_bias"] = False
#transformer_config["sample_height"] = 300 todo: check if this is needed
#transformer_config["sample_width"] = 300
transformer_config["in_channels"] = 16
transformer = CogVideoXTransformer3DModel.from_config(transformer_config) transformer = CogVideoXTransformer3DModel.from_config(transformer_config)
#load weights #load weights