diff --git a/nodes.py b/nodes.py index fa2aebb..59fce7d 100644 --- a/nodes.py +++ b/nodes.py @@ -363,19 +363,10 @@ class DownloadAndLoadCogVideoModel: #fp8 if fp8_transformer == "enabled" or fp8_transformer == "fastmode": - if "2b" in model: - for name, param in transformer.named_parameters(): - if name != "pos_embedding": - param.data = param.data.to(torch.float8_e4m3fn) - elif "I2V" in model: - for name, param in transformer.named_parameters(): - if "patch_embed" not in name: - param.data = param.data.to(torch.float8_e4m3fn) - else: - #transformer.to(torch.float8_e4m3fn) - for name, param in transformer.named_parameters(): - if "lora" not in name: - param.data = param.data.to(torch.float8_e4m3fn) + for name, param in transformer.named_parameters(): + params_to_keep = {"patch_embed", "lora", "pos_embedding"} + if not any(keyword in name for keyword in params_to_keep): + param.data = param.data.to(torch.float8_e4m3fn) if fp8_transformer == "fastmode": from .fp8_optimization import convert_fp8_linear