From bb441cbbeda30f65c2b7754b224db34fc73d0804 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Wed, 29 Jan 2025 02:08:08 +0200
Subject: [PATCH] decode multiview outputs one by one for less of a VRAM spike

---
 hy3dgen/texgen/hunyuanpaint/pipeline.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/hy3dgen/texgen/hunyuanpaint/pipeline.py b/hy3dgen/texgen/hunyuanpaint/pipeline.py
index e15950d..16920e6 100755
--- a/hy3dgen/texgen/hunyuanpaint/pipeline.py
+++ b/hy3dgen/texgen/hunyuanpaint/pipeline.py
@@ -530,9 +530,12 @@ class HunyuanPaintPipeline(StableDiffusionPipeline):
                     callback(step_idx, t, latents)
 
         if not output_type == "latent":
-            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
-                0
-            ]
+            image_list = []
+            for img in latents:
+                image = self.vae.decode(img / self.vae.config.scaling_factor, return_dict=False, generator=generator)[0]
+                image_list.append(image)
+            image = torch.cat(image_list, dim=0)
+
             image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
         else:
             image = latents
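
The hunk above swaps a single full-batch VAE decode for a per-view loop, so only one view's decoder activations occupy VRAM at a time before the decoded images are concatenated back into a batch. Below is a minimal standalone sketch of the same idea against diffusers' AutoencoderKL; the helper name decode_latents_one_by_one, the shapes, and the example checkpoint are assumptions for illustration, not part of the patch. The sketch slices with i:i+1 so each chunk keeps the batch dimension that AutoencoderKL.decode expects.

    # Minimal sketch (assumed helper name and checkpoint; not from the patch itself).
    import torch
    from diffusers import AutoencoderKL


    @torch.no_grad()
    def decode_latents_one_by_one(vae: AutoencoderKL, latents: torch.Tensor) -> torch.Tensor:
        """Decode latents one sample at a time so only a single view's
        decoder activations are resident on the GPU at any moment."""
        images = []
        for i in range(latents.shape[0]):
            # Slice with i:i+1 to keep the batch dimension that AutoencoderKL.decode expects.
            latent = latents[i : i + 1] / vae.config.scaling_factor
            images.append(vae.decode(latent, return_dict=False)[0])
        return torch.cat(images, dim=0)


    # Example usage (assumed checkpoint and shapes, for illustration only):
    # vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to("cuda")
    # latents = torch.randn(6, 4, 64, 64, device="cuda")  # e.g. 6 views
    # images = decode_latents_one_by_one(vae, latents)    # -> [6, 3, 512, 512]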