diff --git a/nodes.py b/nodes.py index fa88214..e1b90ba 100644 --- a/nodes.py +++ b/nodes.py @@ -482,7 +482,7 @@ class CogVideoXFunSampler: pipe = pipeline["pipe"] dtype = pipeline["dtype"] - pipe.enable_model_cpu_offload() + pipe.enable_model_cpu_offload(device=device) mm.soft_empty_cache() @@ -594,7 +594,7 @@ class CogVideoXFunVid2VidSampler: pipe = pipeline["pipe"] dtype = pipeline["dtype"] - pipe.enable_model_cpu_offload() + pipe.enable_model_cpu_offload(device=device) mm.soft_empty_cache()