fix encoding with cpu offloading

This commit is contained in:
kijai 2024-08-30 22:13:01 +03:00
parent 248428ccf8
commit c1efd95a03

View File

@ -194,7 +194,9 @@ class CogVideoImageEncode:
vae.enable_slicing()
else:
vae.disable_slicing()
vae.to(device)
if not pipeline["cpu_offloading"]:
vae.to(device)
input_image = image.clone() * 2.0 - 1.0
input_image = input_image.to(vae.dtype).to(device)
@ -226,8 +228,8 @@ class CogVideoImageEncode:
# Concatenate all the chunks along the temporal dimension
final_latents = torch.cat(latents_list, dim=1)
print("final latents: ", final_latents.shape)
vae.to(offload_device)
if not pipeline["cpu_offloading"]:
vae.to(offload_device)
return ({"samples": final_latents}, )