diff --git a/nodes.py b/nodes.py index 82c593b..e1caadd 100644 --- a/nodes.py +++ b/nodes.py @@ -372,7 +372,10 @@ class DownloadAndLoadCogVideoModel: if "patch_embed" not in name: param.data = param.data.to(torch.float8_e4m3fn) else: - transformer.to(torch.float8_e4m3fn) + #transformer.to(torch.float8_e4m3fn) + for name, param in transformer.named_parameters(): + if "lora" not in name: + param.data = param.data.to(torch.float8_e4m3fn) if fp8_transformer == "fastmode": from .fp8_optimization import convert_fp8_linear