fix LoRA for I2V model when using fp8

kijai 2024-10-14 01:05:35 +03:00
parent 0f5d67a9a9
commit 09ed641575


@@ -363,19 +363,10 @@ class DownloadAndLoadCogVideoModel:
         #fp8
         if fp8_transformer == "enabled" or fp8_transformer == "fastmode":
-            if "2b" in model:
-                for name, param in transformer.named_parameters():
-                    if name != "pos_embedding":
-                        param.data = param.data.to(torch.float8_e4m3fn)
-            elif "I2V" in model:
-                for name, param in transformer.named_parameters():
-                    if "patch_embed" not in name:
-                        param.data = param.data.to(torch.float8_e4m3fn)
-            else:
-                #transformer.to(torch.float8_e4m3fn)
-                for name, param in transformer.named_parameters():
-                    if "lora" not in name:
-                        param.data = param.data.to(torch.float8_e4m3fn)
+            for name, param in transformer.named_parameters():
+                params_to_keep = {"patch_embed", "lora", "pos_embedding"}
+                if not any(keyword in name for keyword in params_to_keep):
+                    param.data = param.data.to(torch.float8_e4m3fn)
             if fp8_transformer == "fastmode":
                 from .fp8_optimization import convert_fp8_linear
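
In the old code only the generic else branch excluded "lora" parameters from the fp8 cast; the I2V branch excluded just patch_embed, so LoRA weights loaded into an I2V transformer were also cast to float8_e4m3fn, which is what broke LoRA for that model. The new code replaces the three per-model branches with a single keyword filter that protects patch_embed, lora, and pos_embedding for every variant. Below is a minimal standalone sketch of that filter; ToyTransformer and cast_to_fp8_except are hypothetical stand-ins, and only the keyword set and torch.float8_e4m3fn (available in PyTorch >= 2.1) come from the commit itself.

# Standalone sketch of the selective fp8 cast; the module and helper
# are illustrative, not the actual CogVideoX transformer.
import torch
import torch.nn as nn

class ToyTransformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.patch_embed = nn.Linear(16, 32)              # must stay high precision
        self.block = nn.Linear(32, 32)                    # safe to cast to fp8
        self.block_lora_A = nn.Linear(32, 4, bias=False)  # stand-in LoRA weight
        self.block_lora_B = nn.Linear(4, 32, bias=False)  # stand-in LoRA weight

def cast_to_fp8_except(module, params_to_keep):
    # Cast every parameter to float8_e4m3fn unless its name contains
    # one of the protected keywords, mirroring the loop in the commit.
    for name, param in module.named_parameters():
        if not any(keyword in name for keyword in params_to_keep):
            param.data = param.data.to(torch.float8_e4m3fn)

model = ToyTransformer()
cast_to_fp8_except(model, {"patch_embed", "lora", "pos_embedding"})
for name, param in model.named_parameters():
    print(name, param.dtype)  # patch_embed.* and *_lora_* stay float32, block.* is fp8

Centralizing the exclusions in one set also means any future model variant gets the LoRA and embedding protection automatically, rather than depending on which if branch it happens to match.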