Add multiview example

This commit is contained in:
kijai 2025-03-18 18:48:52 +02:00
parent 7ff9ad9ea0
commit fdd685ddf2
2 changed files with 3765 additions and 17 deletions

File diff suppressed because it is too large Load Diff

View File

@ -222,7 +222,7 @@ class Hy3DDelightImage:
device = mm.get_torch_device() device = mm.get_torch_device()
offload_device = mm.unet_offload_device() offload_device = mm.unet_offload_device()
print("image in shape", image.shape)
if scheduler is not None: if scheduler is not None:
if not hasattr(self, "default_scheduler"): if not hasattr(self, "default_scheduler"):
self.default_scheduler = delight_pipe.scheduler self.default_scheduler = delight_pipe.scheduler
@ -234,9 +234,12 @@ class Hy3DDelightImage:
image = image.permute(0, 3, 1, 2).to(device) image = image.permute(0, 3, 1, 2).to(device)
image = common_upscale(image, width, height, "lanczos", "disabled") image = common_upscale(image, width, height, "lanczos", "disabled")
image = delight_pipe(
images_list = []
for img in image:
out = delight_pipe(
prompt="", prompt="",
image=image, image=img,
generator=torch.manual_seed(seed), generator=torch.manual_seed(seed),
height=height, height=height,
width=width, width=width,
@ -246,8 +249,9 @@ class Hy3DDelightImage:
output_type="pt", output_type="pt",
).images[0] ).images[0]
images_list.append(out)
out_tensor = image.unsqueeze(0).permute(0, 2, 3, 1).cpu().float() out_tensor = torch.stack(images_list).permute(0, 2, 3, 1).cpu().float()
return (out_tensor, ) return (out_tensor, )
@ -1092,8 +1096,8 @@ class Hy3DGenerateMeshMultiView(Hy3DGenerateMesh):
} }
} }
RETURN_TYPES = ("HY3DLATENT",) RETURN_TYPES = ("HY3DLATENT", "IMAGE", "MASK",)
RETURN_NAMES = ("latents",) RETURN_NAMES = ("latents", "image", "mask")
FUNCTION = "process" FUNCTION = "process"
CATEGORY = "Hunyuan3DWrapper" CATEGORY = "Hunyuan3DWrapper"
@ -1142,9 +1146,23 @@ class Hy3DGenerateMeshMultiView(Hy3DGenerateMesh):
except: except:
pass pass
images = []
masks = []
for view_tag, view_image in view_dict.items():
if view_image is not None:
if view_image.shape[1] == 4:
rgb = view_image[:, :3, :, :]
alpha = view_image[:, 3:4, :, :]
mask = alpha
masks.append(mask)
images.append(rgb)
image_tensors = torch.cat(images, 0).permute(0, 2, 3, 1).cpu().float()
mask_tensors = torch.cat(masks, 0).squeeze(1).cpu().float()
pipeline.to(offload_device) pipeline.to(offload_device)
return (latents, ) return (latents, image_tensors, mask_tensors)
class Hy3DVAEDecode: class Hy3DVAEDecode:
@classmethod @classmethod