From 7ff9ad9ea064a432709cdbd4216d62376c2263c1 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Tue, 18 Mar 2025 18:12:36 +0200
Subject: [PATCH] Support multiview model

---
 hy3dgen/shapegen/models/conditioner.py  | 125 +++++++++++++++++++-----
 hy3dgen/shapegen/models/hunyuan3ddit.py |  15 ++-
 hy3dgen/shapegen/pipelines.py           |  10 +-
 hy3dgen/shapegen/preprocessors.py       |  67 +++++++++++--
 nodes.py                                |  80 ++++++++++++++-
 5 files changed, 255 insertions(+), 42 deletions(-)

diff --git a/hy3dgen/shapegen/models/conditioner.py b/hy3dgen/shapegen/models/conditioner.py
index 431336a..11ffc51 100755
--- a/hy3dgen/shapegen/models/conditioner.py
+++ b/hy3dgen/shapegen/models/conditioner.py
@@ -22,6 +22,7 @@
 # fine-tuning enabling code and other elements of the foregoing made publicly available
 # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
 
+import numpy as np
 import torch
 import torch.nn as nn
 from torchvision import transforms
@@ -31,6 +32,26 @@ from transformers import (
     Dinov2Model,
     Dinov2Config,
 )
+from ....utils import log
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+    """
+    embed_dim: output dimension for each position
+    pos: a list of positions to be encoded: size (M,)
+    out: (M, D)
+    """
+    assert embed_dim % 2 == 0
+    omega = np.arange(embed_dim // 2, dtype=np.float64)
+    omega /= embed_dim / 2.
+    omega = 1. / 10000 ** omega  # (D/2,)
+
+    pos = pos.reshape(-1)  # (M,)
+    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
+
+    emb_sin = np.sin(out)  # (M, D/2)
+    emb_cos = np.cos(out)  # (M, D/2)
+
+    return np.concatenate([emb_sin, emb_cos], axis=1)
 
 
 class ImageEncoder(nn.Module):
@@ -68,36 +89,94 @@ class ImageEncoder(nn.Module):
             ]
         )
 
-    def forward(self, image, mask=None, value_range=(-1, 1)):
-        if value_range is not None:
-            low, high = value_range
-            image = (image - low) / (high - low)
-
-        image = image.to(self.model.device, dtype=self.model.dtype)
+        #MV
+        self.view_num = 4
+        pos = np.arange(self.view_num, dtype=np.float32)
+        view_embedding = torch.from_numpy(
+            get_1d_sincos_pos_embed_from_grid(self.model.config.hidden_size, pos)).float()
+        view_embedding = view_embedding.unsqueeze(1).repeat(1, self.num_patches, 1)
+        self.view_embed = view_embedding.unsqueeze(0)
 
-        if mask is not None:
-            mask = mask.to(image)
-            image = image * mask
-        supported_sizes = [518, 530]
-        if (image.shape[2] not in supported_sizes or image.shape[3] not in supported_sizes) and not self.has_guidance_embed:
-            print(f'Image shape {image.shape} not supported. Resizing to 518x518')
-            inputs = self.transform(image)
+        self.view2idx = {
+            'front': 0,
+            'left': 1,
+            'back': 2,
+            'right': 3
+        }
+
+    def forward(self, image, mask=None, value_range=(-1, 1), view_dict=None):
+        if view_dict is None:
+            self.view_num = 1
+            if value_range is not None:
+                low, high = value_range
+                image = (image - low) / (high - low)
+
+            image = image.to(self.model.device, dtype=self.model.dtype)
+
+            if mask is not None:
+                mask = mask.to(image)
+                image = image * mask
+            supported_sizes = [518, 530]
+            if (image.shape[2] not in supported_sizes or image.shape[3] not in supported_sizes) and not self.has_guidance_embed:
+                log.info(f'Image shape {image.shape} not supported. Resizing to 518x518')
+                inputs = self.transform(image)
+            else:
+                inputs = image
+            outputs = self.model(inputs)
+
+            last_hidden_state = outputs.last_hidden_state
+            if not self.use_cls_token:
+                last_hidden_state = last_hidden_state[:, 1:, :]
+
+            return last_hidden_state
         else:
-            inputs = image
-        outputs = self.model(inputs)
+            images = []
+            view_indexes = []
+            supported_sizes = [518, 530]
 
-        last_hidden_state = outputs.last_hidden_state
-        if not self.use_cls_token:
-            last_hidden_state = last_hidden_state[:, 1:, :]
+            for view_tag, view_image in view_dict.items():
+                if view_image is not None:
+                    if view_image.shape[1] == 4:
+                        log.info('received image with alpha channel, masking out background...')
+                        rgb = view_image[:, :3, :, :]
+                        alpha = view_image[:, 3:4, :, :]
+                        view_image = rgb * alpha
+                    view_image = view_image.to(self.model.device, dtype=self.model.dtype)
+                    if (view_image.shape[2] not in supported_sizes or view_image.shape[3] not in supported_sizes):
+                        log.info(f'view_image shape {view_image.shape} not supported. Resizing to 518x518')
+                        view_image = self.transform(view_image)
+                    images.append(view_image)
+                    view_indexes.append(self.view2idx[view_tag])
 
-        return last_hidden_state
+            image_tensors = torch.cat(images, 0)
+
+            outputs = self.model(image_tensors)
+
+            self.view_num = len(images)
+
+            last_hidden_state = outputs.last_hidden_state
+            last_hidden_state = last_hidden_state.view(
+                self.view_num,
+                -1,
+                last_hidden_state.shape[-1]
+            )
+            view_embedding = self.view_embed.to(last_hidden_state.dtype).to(last_hidden_state.device)
+            if self.view_num != 4:
+                view_embeddings = []
+                for idx in range(len(view_indexes)):
+                    view_embeddings.append(self.view_embed[:, idx, ...])
+                view_embedding = torch.cat(view_embeddings, 0).to(last_hidden_state.dtype).to(last_hidden_state.device)
+
+            last_hidden_state = last_hidden_state + view_embedding
+            last_hidden_state = last_hidden_state.view(-1, last_hidden_state.shape[-1])
+            return last_hidden_state.unsqueeze(0)
 
     def unconditional_embedding(self, batch_size):
         device = next(self.model.parameters()).device
         dtype = next(self.model.parameters()).dtype
         zero = torch.zeros(
             batch_size,
-            self.num_patches,
+            self.num_patches * self.view_num,
             self.model.config.hidden_size,
             device=device,
             dtype=dtype,
@@ -162,9 +241,9 @@ class SingleImageEncoder(nn.Module):
         super().__init__()
         self.main_image_encoder = build_image_encoder(main_image_encoder)
 
-    def forward(self, image, mask=None):
+    def forward(self, image=None, mask=None, view_dict=None):
         outputs = {
-            'main': self.main_image_encoder(image, mask=mask),
+            'main': self.main_image_encoder(image=image, mask=mask, view_dict=view_dict),
         }
         return outputs
 
@@ -172,4 +251,4 @@ class SingleImageEncoder(nn.Module):
         outputs = {
             'main': self.main_image_encoder.unconditional_embedding(batch_size),
         }
-        return outputs
+        return outputs
\ No newline at end of file
diff --git a/hy3dgen/shapegen/models/hunyuan3ddit.py b/hy3dgen/shapegen/models/hunyuan3ddit.py
index a128659..8cf03e3 100755
--- a/hy3dgen/shapegen/models/hunyuan3ddit.py
+++ b/hy3dgen/shapegen/models/hunyuan3ddit.py
@@ -71,6 +71,15 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
     return embedding
 
 
+class GELU(nn.Module):
+    def __init__(self, approximate='tanh'):
+        super().__init__()
+        self.approximate = approximate
+
+    def forward(self, x: Tensor) -> Tensor:
+        return nn.functional.gelu(x.contiguous(), approximate=self.approximate)
+
+
 class MLPEmbedder(nn.Module):
     def __init__(self, in_dim: int, hidden_dim: int):
         super().__init__()
@@ -178,7 +187,7 @@ class DoubleStreamBlock(nn.Module):
         self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
         self.img_mlp = nn.Sequential(
             nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
-            nn.GELU(approximate="tanh"),
+            GELU(approximate="tanh"),
             nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
         )
 
@@ -189,7 +198,7 @@ class DoubleStreamBlock(nn.Module):
         self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
         self.txt_mlp = nn.Sequential(
             nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
-            nn.GELU(approximate="tanh"),
+            GELU(approximate="tanh"),
             nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
         )
 
@@ -260,7 +269,7 @@ class SingleStreamBlock(nn.Module):
 
         self.hidden_size = hidden_size
         self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
-        self.mlp_act = nn.GELU(approximate="tanh")
+        self.mlp_act = GELU(approximate="tanh")
         self.modulation = Modulation(hidden_size, double=False)
 
     def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
diff --git a/hy3dgen/shapegen/pipelines.py b/hy3dgen/shapegen/pipelines.py
index 757e655..31716aa 100755
--- a/hy3dgen/shapegen/pipelines.py
+++ b/hy3dgen/shapegen/pipelines.py
@@ -311,10 +311,10 @@ class Hunyuan3DDiTPipeline:
         self.model.to(dtype=dtype)
         self.conditioner.to(dtype=dtype)
 
-    def encode_cond(self, image, mask, do_classifier_free_guidance, dual_guidance):
+    def encode_cond(self, image, mask, do_classifier_free_guidance, dual_guidance, view_dict=None):
         self.conditioner.to(self.main_device)
-        bsz = image.shape[0]
-        cond = self.conditioner(image=image, mask=mask)
+        bsz = 1
+        cond = self.conditioner(image=image, mask=mask, view_dict=view_dict)
 
         if do_classifier_free_guidance:
             un_cond = self.conditioner.unconditional_embedding(bsz)
@@ -575,6 +575,7 @@ class Hunyuan3DDiTFlowMatchingPipeline(Hunyuan3DDiTPipeline):
         # num_chunks=8000,
         # output_type: Optional[str] = "trimesh",
         enable_pbar=True,
+        view_dict=None,
         **kwargs,
     ) -> List[List[trimesh.Trimesh]]:
         callback = kwargs.pop("callback", None)
@@ -594,8 +595,9 @@ class Hunyuan3DDiTFlowMatchingPipeline(Hunyuan3DDiTPipeline):
             mask=mask,
             do_classifier_free_guidance=do_classifier_free_guidance,
             dual_guidance=False,
+            view_dict=view_dict
         )
-        batch_size = image.shape[0]
+        batch_size = 1
 
         # 5. Prepare timesteps
         # NOTE: this is slightly different from common usage, we start from 0.
diff --git a/hy3dgen/shapegen/preprocessors.py b/hy3dgen/shapegen/preprocessors.py
index 2bdaff2..065a604 100755
--- a/hy3dgen/shapegen/preprocessors.py
+++ b/hy3dgen/shapegen/preprocessors.py
@@ -87,7 +87,7 @@ class ImageProcessorV2:
                              interpolation=cv2.INTER_AREA)
 
         bg = np.ones((result.shape[0], result.shape[1], 3), dtype=np.uint8) * 255
-        # bg = np.zeros((result.shape[0], result.shape[1], 3), dtype=np.uint8) * 255
+
         mask = result[..., 3:].astype(np.float32) / 255
         result = result[..., :3] * mask + bg * (1 - mask)
 
@@ -96,15 +96,13 @@ class ImageProcessorV2:
         mask = mask.clip(0, 255).astype(np.uint8)
         return result, mask
 
-    def __call__(self, image, border_ratio=0.15, to_tensor=True, return_mask=False, **kwargs):
-        if self.border_ratio is not None:
-            border_ratio = self.border_ratio
-            print(f"Using border_ratio from init: {border_ratio}")
+    def load_image(self, image, border_ratio=0.15, to_tensor=True):
         if isinstance(image, str):
             image = cv2.imread(image, cv2.IMREAD_UNCHANGED)
             image, mask = self.recenter(image, border_ratio=border_ratio)
             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
         elif isinstance(image, Image.Image):
+            image = image.convert("RGBA")
             image = np.asarray(image)
             image, mask = self.recenter(image, border_ratio=border_ratio)
 
@@ -115,13 +113,64 @@ class ImageProcessorV2:
         if to_tensor:
             image = array_to_tensor(image)
             mask = array_to_tensor(mask)
-        if return_mask:
-            return image, mask
-        return image
+        return image, mask
+
+    def __call__(self, image, border_ratio=0.15, to_tensor=True, **kwargs):
+        if self.border_ratio is not None:
+            border_ratio = self.border_ratio
+        image, mask = self.load_image(image, border_ratio=border_ratio, to_tensor=to_tensor)
+        outputs = {
+            'image': image,
+            'mask': mask
+        }
+        return outputs
+
+
+class MVImageProcessorV2(ImageProcessorV2):
+    """
+    view order: front, front clockwise 90, back, front clockwise 270
+    """
+    return_view_idx = True
+
+    def __init__(self, size=512, border_ratio=None):
+        super().__init__(size, border_ratio)
+        self.view2idx = {
+            'front': 0,
+            'left': 1,
+            'back': 2,
+            'right': 3
+        }
+
+    def __call__(self, image_dict, border_ratio=0.15, to_tensor=True, **kwargs):
+        if self.border_ratio is not None:
+            border_ratio = self.border_ratio
+
+        images = []
+        masks = []
+        view_idxs = []
+        for idx, (view_tag, image) in enumerate(image_dict.items()):
+            view_idxs.append(self.view2idx[view_tag])
+            image, mask = self.load_image(image, border_ratio=border_ratio, to_tensor=to_tensor)
+            images.append(image)
+            masks.append(mask)
+
+        zipped_lists = zip(view_idxs, images, masks)
+        sorted_zipped_lists = sorted(zipped_lists)
+        view_idxs, images, masks = zip(*sorted_zipped_lists)
+
+        image = torch.cat(images, 0).unsqueeze(0)
+        mask = torch.cat(masks, 0).unsqueeze(0)
+        outputs = {
+            'image': image,
+            'mask': mask,
+            'view_idxs': view_idxs
+        }
+        return outputs
 
 
 IMAGE_PROCESSORS = {
     "v2": ImageProcessorV2,
+    'mv_v2': MVImageProcessorV2,
 }
 
-DEFAULT_IMAGEPROCESSOR = 'v2'
+DEFAULT_IMAGEPROCESSOR = 'v2'
\ No newline at end of file
diff --git a/nodes.py b/nodes.py
index 9d1a52a..4e10b19 100644
--- a/nodes.py
+++ b/nodes.py
@@ -84,7 +84,7 @@ class Hy3DTorchCompileSettings:
     RETURN_TYPES = ("HY3DCOMPILEARGS",)
     RETURN_NAMES = ("torch_compile_args",)
     FUNCTION = "loadmodel"
-    CATEGORY = "HunyuanVideoWrapper"
+    CATEGORY = "Hunyuan3DWrapper"
     DESCRIPTION = "torch.compile settings, when connected to the model loader, torch.compile of the selected layers is attempted. Requires Triton and torch 2.5.0 is recommended"
 
     def loadmodel(self, backend, fullgraph, mode, dynamic, dynamo_cache_size_limit, compile_transformer, compile_vae):
@@ -1034,7 +1034,7 @@ class Hy3DGenerateMesh:
     FUNCTION = "process"
     CATEGORY = "Hunyuan3DWrapper"
 
-    def process(self, pipeline, image, steps, guidance_scale, seed, mask=None):
+    def process(self, pipeline, image, steps, guidance_scale, seed, mask=None, front=None, back=None, left=None, right=None):
         mm.unload_all_models()
         mm.soft_empty_cache()
 
@@ -1046,7 +1046,7 @@ class Hy3DGenerateMesh:
         image = image * 2 - 1
 
         if mask is not None:
-            mask = mask.unsqueeze(0).to(device)
+            mask = mask.unsqueeze(1).repeat(1, 3, 1, 1).to(device)
             if mask.shape[2] != image.shape[2] or mask.shape[3] != image.shape[3]:
                 mask = F.interpolate(mask, size=(image.shape[2], image.shape[3]), mode='nearest')
 
@@ -1074,6 +1074,78 @@ class Hy3DGenerateMesh:
 
         return (latents, )
 
+class Hy3DGenerateMeshMultiView(Hy3DGenerateMesh):
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "pipeline": ("HY3DMODEL",),
+                "guidance_scale": ("FLOAT", {"default": 5.5, "min": 0.0, "max": 100.0, "step": 0.01}),
+                "steps": ("INT", {"default": 30, "min": 1}),
+                "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
+            },
+            "optional": {
+                "front": ("IMAGE", ),
+                "left": ("IMAGE", ),
+                "right": ("IMAGE", ),
+                "back": ("IMAGE", ),
+            }
+        }
+
+    RETURN_TYPES = ("HY3DLATENT",)
+    RETURN_NAMES = ("latents",)
+    FUNCTION = "process"
+    CATEGORY = "Hunyuan3DWrapper"
+
+    def process(self, pipeline, steps, guidance_scale, seed, mask=None, front=None, back=None, left=None, right=None):
+
+        mm.unload_all_models()
+        mm.soft_empty_cache()
+
+        device = mm.get_torch_device()
+        offload_device = mm.unet_offload_device()
+
+        pipeline.to(device)
+
+        if front is not None:
+            front = front.clone().permute(0, 3, 1, 2).to(device)
+        if back is not None:
+            back = back.clone().permute(0, 3, 1, 2).to(device)
+        if left is not None:
+            left = left.clone().permute(0, 3, 1, 2).to(device)
+        if right is not None:
+            right = right.clone().permute(0, 3, 1, 2).to(device)
+
+        view_dict = {
+            'front': front,
+            'left': left,
+            'right': right,
+            'back': back
+        }
+
+        try:
+            torch.cuda.reset_peak_memory_stats(device)
+        except:
+            pass
+
+        latents = pipeline(
+            image=None,
+            mask=mask,
+            num_inference_steps=steps,
+            guidance_scale=guidance_scale,
+            generator=torch.manual_seed(seed),
+            view_dict=view_dict)
+
+        print_memory(device)
+        try:
+            torch.cuda.reset_peak_memory_stats(device)
+        except:
+            pass
+
+        pipeline.to(offload_device)
+
+        return (latents, )
+
 class Hy3DVAEDecode:
     @classmethod
     def INPUT_TYPES(s):
@@ -1596,6 +1668,7 @@ class Hy3DNvdiffrastRenderer:
 NODE_CLASS_MAPPINGS = {
     "Hy3DModelLoader": Hy3DModelLoader,
     "Hy3DGenerateMesh": Hy3DGenerateMesh,
+    "Hy3DGenerateMeshMultiView": Hy3DGenerateMeshMultiView,
     "Hy3DExportMesh": Hy3DExportMesh,
     "DownloadAndLoadHy3DDelightModel": DownloadAndLoadHy3DDelightModel,
     "DownloadAndLoadHy3DPaintModel": DownloadAndLoadHy3DPaintModel,
@@ -1628,6 +1701,7 @@ NODE_CLASS_MAPPINGS = {
 NODE_DISPLAY_NAME_MAPPINGS = {
     "Hy3DModelLoader": "Hy3DModelLoader",
     "Hy3DGenerateMesh": "Hy3DGenerateMesh",
+    "Hy3DGenerateMeshMultiView": "Hy3DGenerateMeshMultiView",
     "Hy3DExportMesh": "Hy3DExportMesh",
     "DownloadAndLoadHy3DDelightModel": "(Down)Load Hy3D DelightModel",
     "DownloadAndLoadHy3DPaintModel": "(Down)Load Hy3D PaintModel",