Update model_loading.py

This commit is contained in:
kijai 2024-11-08 17:30:51 +02:00
parent 4a597f1955
commit 4c2ce52f57

View File

@ -182,16 +182,7 @@ class DownloadAndLoadCogVideoModel:
if block_edit is not None: if block_edit is not None:
transformer = remove_specific_blocks(transformer, block_edit) transformer = remove_specific_blocks(transformer, block_edit)
#fp8
if fp8_transformer == "enabled" or fp8_transformer == "fastmode":
for name, param in transformer.named_parameters():
params_to_keep = {"patch_embed", "lora", "pos_embedding"}
if not any(keyword in name for keyword in params_to_keep):
param.data = param.data.to(torch.float8_e4m3fn)
if fp8_transformer == "fastmode":
from .fp8_optimization import convert_fp8_linear
convert_fp8_linear(transformer, dtype)
with open(scheduler_path) as f: with open(scheduler_path) as f:
scheduler_config = json.load(f) scheduler_config = json.load(f)
@ -221,8 +212,7 @@ class DownloadAndLoadCogVideoModel:
adapter_list = [] adapter_list = []
adapter_weights = [] adapter_weights = []
for l in lora: for l in lora:
if l["fuse_lora"]: fuse = True if l["fuse_lora"] else False
fuse = True
lora_sd = load_torch_file(l["path"]) lora_sd = load_torch_file(l["path"])
for key, val in lora_sd.items(): for key, val in lora_sd.items():
if "lora_B" in key: if "lora_B" in key:
@ -240,8 +230,17 @@ class DownloadAndLoadCogVideoModel:
pipe.set_adapters(adapter_list, adapter_weights=adapter_weights) pipe.set_adapters(adapter_list, adapter_weights=adapter_weights)
if fuse: if fuse:
pipe.fuse_lora(lora_scale=1 / lora_rank, components=["transformer"]) pipe.fuse_lora(lora_scale=1 / lora_rank, components=["transformer"])
#fp8
if fp8_transformer == "enabled" or fp8_transformer == "fastmode":
for name, param in pipe.transformer.named_parameters():
params_to_keep = {"patch_embed", "lora", "pos_embedding"}
if not any(keyword in name for keyword in params_to_keep):
param.data = param.data.to(torch.float8_e4m3fn)
if fp8_transformer == "fastmode":
from .fp8_optimization import convert_fp8_linear
convert_fp8_linear(pipe.transformer, dtype)
if enable_sequential_cpu_offload: if enable_sequential_cpu_offload:
pipe.enable_sequential_cpu_offload() pipe.enable_sequential_cpu_offload()