diff --git a/model_loading.py b/model_loading.py
index 1f2a288..c190b54 100644
--- a/model_loading.py
+++ b/model_loading.py
@@ -425,8 +425,7 @@ class DownloadAndLoadCogVideoGGUFModel:
             },
             "optional": {
                 "block_edit": ("TRANSFORMERBLOCKS", {"default": None}),
-                #"lora": ("COGLORA", {"default": None}),
-                "compile": (["disabled","torch"], {"tooltip": "compile the model for faster inference, these are advanced options only available on Linux, see readme for more info"}),
+                "compile_args":("COMPILEARGS", ),
                 "attention_mode": (["sdpa", "sageattn"], {"default": "sdpa"}),
             }
         }
@@ -437,7 +436,7 @@ class DownloadAndLoadCogVideoGGUFModel:
     CATEGORY = "CogVideoWrapper"
 
     def loadmodel(self, model, vae_precision, fp8_fastmode, load_device, enable_sequential_cpu_offload,
-                  block_edit=None, compile="disabled", attention_mode="sdpa"):
+                  block_edit=None, compile_args=None, attention_mode="sdpa"):
 
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
@@ -497,7 +496,7 @@ class DownloadAndLoadCogVideoGGUFModel:
         transformer_config["in_channels"] = 16
 
         transformer = CogVideoXTransformer3DModel.from_config(transformer_config)
-
+        cast_dtype = vae_dtype
         params_to_keep = {"patch_embed", "pos_embedding", "time_embedding"}
         if "2b" in model:
             cast_dtype = torch.float16
@@ -524,10 +523,12 @@ class DownloadAndLoadCogVideoGGUFModel:
             from .fp8_optimization import convert_fp8_linear
             convert_fp8_linear(transformer, vae_dtype, params_to_keep=params_to_keep)
 
-        if compile == "torch":
-            # compilation
+        if compile_args is not None:
+            torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"]
             for i, block in enumerate(transformer.transformer_blocks):
-                transformer.transformer_blocks[i] = torch.compile(block, fullgraph=False, dynamic=False, backend="inductor")
+                if "CogVideoXBlock" in str(block):
+                    transformer.transformer_blocks[i] = torch.compile(block, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"])
+
         with open(scheduler_path) as f:
             scheduler_config = json.load(f)
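
Note: a minimal sketch of the dictionary the new "compile_args" (COMPILEARGS) input is assumed to carry, inferred only from the keys the last hunk reads; the values below are placeholders, not the node's actual defaults.

    # Hypothetical COMPILEARGS payload consumed by loadmodel() above.
    # Keys match the lookups in the diff: backend, fullgraph, mode, dynamic,
    # dynamo_cache_size_limit. Values are illustrative only.
    example_compile_args = {
        "backend": "inductor",          # passed to torch.compile(backend=...)
        "fullgraph": False,             # passed to torch.compile(fullgraph=...)
        "mode": "default",              # e.g. "default" or "max-autotune"
        "dynamic": False,               # static shapes assumed
        "dynamo_cache_size_limit": 64,  # written to torch._dynamo.config.cache_size_limit
    }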