diff --git a/nodes.py b/nodes.py index 62f827b..403fd5f 100644 --- a/nodes.py +++ b/nodes.py @@ -76,7 +76,7 @@ class DownloadAndLoadCogVideoModel: torch._inductor.config.coordinate_descent_tuning = True torch._inductor.config.epilogue_fusion = False torch._inductor.config.coordinate_descent_check_all_directions = True - pipe.transformer.to(memory_format=torch.channels_last) + pipe.transformer.to(device).to(memory_format=torch.channels_last) pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True) pipeline = {