diff --git a/model_loading.py b/model_loading.py index e54f3e2..041625d 100644 --- a/model_loading.py +++ b/model_loading.py @@ -190,6 +190,8 @@ class DownloadAndLoadCogVideoModel: # compilation if compile == "torch": torch._dynamo.config.suppress_errors = True + torch._dynamo.config.cache_size_limit = 64 + pipe.transformer.to(memory_format=torch.channels_last) #pipe.transformer = torch.compile(pipe.transformer, mode="default", fullgraph=False, backend="inductor") for i, block in enumerate(pipe.transformer.transformer_blocks):