mirror of
https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2025-12-08 20:34:23 +08:00
Update model_loading.py
This commit is contained in:
parent
4a597f1955
commit
4c2ce52f57
@ -182,16 +182,7 @@ class DownloadAndLoadCogVideoModel:
|
|||||||
if block_edit is not None:
|
if block_edit is not None:
|
||||||
transformer = remove_specific_blocks(transformer, block_edit)
|
transformer = remove_specific_blocks(transformer, block_edit)
|
||||||
|
|
||||||
#fp8
|
|
||||||
if fp8_transformer == "enabled" or fp8_transformer == "fastmode":
|
|
||||||
for name, param in transformer.named_parameters():
|
|
||||||
params_to_keep = {"patch_embed", "lora", "pos_embedding"}
|
|
||||||
if not any(keyword in name for keyword in params_to_keep):
|
|
||||||
param.data = param.data.to(torch.float8_e4m3fn)
|
|
||||||
|
|
||||||
if fp8_transformer == "fastmode":
|
|
||||||
from .fp8_optimization import convert_fp8_linear
|
|
||||||
convert_fp8_linear(transformer, dtype)
|
|
||||||
|
|
||||||
with open(scheduler_path) as f:
|
with open(scheduler_path) as f:
|
||||||
scheduler_config = json.load(f)
|
scheduler_config = json.load(f)
|
||||||
@ -221,8 +212,7 @@ class DownloadAndLoadCogVideoModel:
|
|||||||
adapter_list = []
|
adapter_list = []
|
||||||
adapter_weights = []
|
adapter_weights = []
|
||||||
for l in lora:
|
for l in lora:
|
||||||
if l["fuse_lora"]:
|
fuse = True if l["fuse_lora"] else False
|
||||||
fuse = True
|
|
||||||
lora_sd = load_torch_file(l["path"])
|
lora_sd = load_torch_file(l["path"])
|
||||||
for key, val in lora_sd.items():
|
for key, val in lora_sd.items():
|
||||||
if "lora_B" in key:
|
if "lora_B" in key:
|
||||||
@ -241,7 +231,16 @@ class DownloadAndLoadCogVideoModel:
|
|||||||
if fuse:
|
if fuse:
|
||||||
pipe.fuse_lora(lora_scale=1 / lora_rank, components=["transformer"])
|
pipe.fuse_lora(lora_scale=1 / lora_rank, components=["transformer"])
|
||||||
|
|
||||||
|
#fp8
|
||||||
|
if fp8_transformer == "enabled" or fp8_transformer == "fastmode":
|
||||||
|
for name, param in pipe.transformer.named_parameters():
|
||||||
|
params_to_keep = {"patch_embed", "lora", "pos_embedding"}
|
||||||
|
if not any(keyword in name for keyword in params_to_keep):
|
||||||
|
param.data = param.data.to(torch.float8_e4m3fn)
|
||||||
|
|
||||||
|
if fp8_transformer == "fastmode":
|
||||||
|
from .fp8_optimization import convert_fp8_linear
|
||||||
|
convert_fp8_linear(pipe.transformer, dtype)
|
||||||
|
|
||||||
if enable_sequential_cpu_offload:
|
if enable_sequential_cpu_offload:
|
||||||
pipe.enable_sequential_cpu_offload()
|
pipe.enable_sequential_cpu_offload()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user