Add multiview example

This commit is contained in:
kijai 2025-03-18 18:48:52 +02:00
parent 7ff9ad9ea0
commit fdd685ddf2
2 changed files with 3765 additions and 17 deletions

File diff suppressed because it is too large Load Diff

View File

@ -222,7 +222,7 @@ class Hy3DDelightImage:
device = mm.get_torch_device() device = mm.get_torch_device()
offload_device = mm.unet_offload_device() offload_device = mm.unet_offload_device()
print("image in shape", image.shape)
if scheduler is not None: if scheduler is not None:
if not hasattr(self, "default_scheduler"): if not hasattr(self, "default_scheduler"):
self.default_scheduler = delight_pipe.scheduler self.default_scheduler = delight_pipe.scheduler
@ -233,21 +233,25 @@ class Hy3DDelightImage:
image = image.permute(0, 3, 1, 2).to(device) image = image.permute(0, 3, 1, 2).to(device)
image = common_upscale(image, width, height, "lanczos", "disabled") image = common_upscale(image, width, height, "lanczos", "disabled")
image = delight_pipe( images_list = []
prompt="", for img in image:
image=image, out = delight_pipe(
generator=torch.manual_seed(seed), prompt="",
height=height, image=img,
width=width, generator=torch.manual_seed(seed),
num_inference_steps=steps, height=height,
image_guidance_scale=cfg_image, width=width,
guidance_scale=1.0 if cfg_image == 1.0 else 1.01, #enable cfg for image, value doesn't matter as it do anything for text anyway num_inference_steps=steps,
output_type="pt", image_guidance_scale=cfg_image,
guidance_scale=1.0 if cfg_image == 1.0 else 1.01, #enable cfg for image, value doesn't matter as it do anything for text anyway
).images[0] output_type="pt",
).images[0]
images_list.append(out)
out_tensor = image.unsqueeze(0).permute(0, 2, 3, 1).cpu().float() out_tensor = torch.stack(images_list).permute(0, 2, 3, 1).cpu().float()
return (out_tensor, ) return (out_tensor, )
@ -1092,8 +1096,8 @@ class Hy3DGenerateMeshMultiView(Hy3DGenerateMesh):
} }
} }
RETURN_TYPES = ("HY3DLATENT",) RETURN_TYPES = ("HY3DLATENT", "IMAGE", "MASK",)
RETURN_NAMES = ("latents",) RETURN_NAMES = ("latents", "image", "mask")
FUNCTION = "process" FUNCTION = "process"
CATEGORY = "Hunyuan3DWrapper" CATEGORY = "Hunyuan3DWrapper"
@ -1141,10 +1145,24 @@ class Hy3DGenerateMeshMultiView(Hy3DGenerateMesh):
torch.cuda.reset_peak_memory_stats(device) torch.cuda.reset_peak_memory_stats(device)
except: except:
pass pass
images = []
masks = []
for view_tag, view_image in view_dict.items():
if view_image is not None:
if view_image.shape[1] == 4:
rgb = view_image[:, :3, :, :]
alpha = view_image[:, 3:4, :, :]
mask = alpha
masks.append(mask)
images.append(rgb)
image_tensors = torch.cat(images, 0).permute(0, 2, 3, 1).cpu().float()
mask_tensors = torch.cat(masks, 0).squeeze(1).cpu().float()
pipeline.to(offload_device) pipeline.to(offload_device)
return (latents, ) return (latents, image_tensors, mask_tensors)
class Hy3DVAEDecode: class Hy3DVAEDecode:
@classmethod @classmethod