decode multiview outputs one by one to reduce the VRAM spike

This commit is contained in:
kijai 2025-01-29 02:08:08 +02:00
parent 448597bc89
commit bb441cbbed

View File

@ -530,9 +530,12 @@ class HunyuanPaintPipeline(StableDiffusionPipeline):
callback(step_idx, t, latents)
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
0
]
image_list = []
for img in latents:
image = self.vae.decode(img / self.vae.config.scaling_factor, return_dict=False, generator=generator)[0]
image_list.append(image)
image = torch.cat(image_list, dim=0)
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
image = latents