allow fused loras with fp8

This commit is contained in:
kijai 2024-11-11 23:36:48 +02:00
parent ea0273c8ec
commit 00fde5ebce

View File

@ -202,21 +202,6 @@ class DownloadAndLoadCogVideoModel:
if block_edit is not None:
transformer = remove_specific_blocks(transformer, block_edit)
#fp8
if fp8_transformer == "enabled" or fp8_transformer == "fastmode":
params_to_keep = {"patch_embed", "lora", "pos_embedding", "time_embedding"}
if "1.5" in model:
params_to_keep.update({"norm1.linear.weight", "norm_k", "norm_q","ofs_embedding", "norm_final", "norm_out", "proj_out"})
for name, param in transformer.named_parameters():
if not any(keyword in name for keyword in params_to_keep):
param.data = param.data.to(torch.float8_e4m3fn)
if fp8_transformer == "fastmode":
from .fp8_optimization import convert_fp8_linear
if "1.5" in model:
params_to_keep.update({"ff"}) #otherwise NaNs
convert_fp8_linear(transformer, dtype, params_to_keep=params_to_keep)
with open(scheduler_path) as f:
scheduler_config = json.load(f)
scheduler = CogVideoXDDIMScheduler.from_config(scheduler_config)
@ -268,6 +253,21 @@ class DownloadAndLoadCogVideoModel:
lora_scale = lora_scale / lora_rank
pipe.fuse_lora(lora_scale=lora_scale, components=["transformer"])
#fp8
if fp8_transformer == "enabled" or fp8_transformer == "fastmode":
params_to_keep = {"patch_embed", "lora", "pos_embedding", "time_embedding"}
if "1.5" in model:
params_to_keep.update({"norm1.linear.weight", "norm_k", "norm_q","ofs_embedding", "norm_final", "norm_out", "proj_out"})
for name, param in pipe.transformer.named_parameters():
if not any(keyword in name for keyword in params_to_keep):
param.data = param.data.to(torch.float8_e4m3fn)
if fp8_transformer == "fastmode":
from .fp8_optimization import convert_fp8_linear
if "1.5" in model:
params_to_keep.update({"ff"}) #otherwise NaNs
convert_fp8_linear(pipe.transformer, dtype, params_to_keep=params_to_keep)
if enable_sequential_cpu_offload:
pipe.enable_sequential_cpu_offload()