diff --git a/nodes.py b/nodes.py index e97237d..e33c956 100644 --- a/nodes.py +++ b/nodes.py @@ -194,7 +194,9 @@ class CogVideoImageEncode: vae.enable_slicing() else: vae.disable_slicing() - vae.to(device) + + if not pipeline["cpu_offloading"]: + vae.to(device) input_image = image.clone() * 2.0 - 1.0 input_image = input_image.to(vae.dtype).to(device) @@ -226,8 +228,8 @@ class CogVideoImageEncode: # Concatenate all the chunks along the temporal dimension final_latents = torch.cat(latents_list, dim=1) print("final latents: ", final_latents.shape) - - vae.to(offload_device) + if not pipeline["cpu_offloading"]: + vae.to(offload_device) return ({"samples": final_latents}, )