diff --git a/nodes.py b/nodes.py index dea1dbc..129ea8a 100644 --- a/nodes.py +++ b/nodes.py @@ -393,7 +393,10 @@ class DownloadAndLoadCogVideoGGUFModel: transformer = CogVideoXTransformer3DModelFun.from_config(transformer_config) elif "I2V" in model: transformer_config["in_channels"] = 32 - transformer = CogVideoXTransformer3DModel.from_config(transformer_config) + if pab_config is not None: + transformer = CogVideoXTransformer3DModelPAB.from_config(transformer_config) + else: + transformer = CogVideoXTransformer3DModel.from_config(transformer_config) else: transformer_config["in_channels"] = 16 if pab_config is not None: