Support multiview model
commit 7ff9ad9ea0
parent 6090a9051f
@@ -22,6 +22,7 @@
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.

import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms
@@ -31,6 +32,26 @@ from transformers import (
    Dinov2Model,
    Dinov2Config,
)
from ....utils import log

def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000 ** omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    return np.concatenate([emb_sin, emb_cos], axis=1)


class ImageEncoder(nn.Module):
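For orientation (not part of the diff): a minimal sketch of what the helper above produces for the four canonical views, assuming a DINOv2 hidden size of 1024 and that get_1d_sincos_pos_embed_from_grid is in scope.

import numpy as np

pos = np.arange(4, dtype=np.float32)                  # one slot per view: front, left, back, right
emb = get_1d_sincos_pos_embed_from_grid(1024, pos)    # 1024 is an assumed hidden size
print(emb.shape)   # (4, 1024): first half sin, second half cos for each view position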
@@ -68,36 +89,94 @@ class ImageEncoder(nn.Module):
            ]
        )

    def forward(self, image, mask=None, value_range=(-1, 1)):
        if value_range is not None:
            low, high = value_range
            image = (image - low) / (high - low)
        #MV
        self.view_num = 4
        pos = np.arange(self.view_num, dtype=np.float32)
        view_embedding = torch.from_numpy(
            get_1d_sincos_pos_embed_from_grid(self.model.config.hidden_size, pos)).float()
        view_embedding = view_embedding.unsqueeze(1).repeat(1, self.num_patches, 1)
        self.view_embed = view_embedding.unsqueeze(0)

        image = image.to(self.model.device, dtype=self.model.dtype)
        self.view2idx = {
            'front': 0,
            'left': 1,
            'back': 2,
            'right': 3
        }

        if mask is not None:
            mask = mask.to(image)
            image = image * mask
        supported_sizes = [518, 530]
        if (image.shape[2] not in supported_sizes or image.shape[3] not in supported_sizes) and not self.has_guidance_embed:
            print(f'Image shape {image.shape} not supported. Resizing to 518x518')
            inputs = self.transform(image)
    def forward(self, image, mask=None, value_range=(-1, 1), view_dict=None):
        if view_dict is None:
            self.view_num = 1
            if value_range is not None:
                low, high = value_range
                image = (image - low) / (high - low)

            image = image.to(self.model.device, dtype=self.model.dtype)

            if mask is not None:
                mask = mask.to(image)
                image = image * mask
            supported_sizes = [518, 530]
            if (image.shape[2] not in supported_sizes or image.shape[3] not in supported_sizes) and not self.has_guidance_embed:
                log.info(f'Image shape {image.shape} not supported. Resizing to 518x518')
                inputs = self.transform(image)
            else:
                inputs = image
            outputs = self.model(inputs)

            last_hidden_state = outputs.last_hidden_state
            if not self.use_cls_token:
                last_hidden_state = last_hidden_state[:, 1:, :]

            return last_hidden_state
        else:
            inputs = image
            outputs = self.model(inputs)
            images = []
            view_indexes = []
            supported_sizes = [518, 530]

            last_hidden_state = outputs.last_hidden_state
            if not self.use_cls_token:
                last_hidden_state = last_hidden_state[:, 1:, :]
            for view_tag, view_image in view_dict.items():
                if view_image is not None:
                    if view_image.shape[1] == 4:
                        log.info('received image with alpha channel, masking out background...')
                        rgb = view_image[:, :3, :, :]
                        alpha = view_image[:, 3:4, :, :]
                        view_image = rgb * alpha
                    view_image = view_image.to(self.model.device, dtype=self.model.dtype)
                    if (view_image.shape[2] not in supported_sizes or view_image.shape[3] not in supported_sizes):
                        log.info(f'view_image shape {view_image.shape} not supported. Resizing to 518x518')
                        view_image = self.transform(view_image)
                    images.append(view_image)
                    view_indexes.append(self.view2idx[view_tag])

            return last_hidden_state
            image_tensors = torch.cat(images, 0)

            outputs = self.model(image_tensors)

            self.view_num = len(images)

            last_hidden_state = outputs.last_hidden_state
            last_hidden_state = last_hidden_state.view(
                self.view_num,
                -1,
                last_hidden_state.shape[-1]
            )
            view_embedding = self.view_embed.to(last_hidden_state.dtype).to(last_hidden_state.device)
            if self.view_num != 4:
                view_embeddings = []
                for idx in range(len(view_indexes)):
                    view_embeddings.append(self.view_embed[:, idx, ...])
                view_embedding = torch.cat(view_embeddings, 0).to(last_hidden_state.dtype).to(last_hidden_state.device)

            last_hidden_state = last_hidden_state + view_embedding
            last_hidden_state = last_hidden_state.view(-1, last_hidden_state.shape[-1])
            return last_hidden_state.unsqueeze(0)

    def unconditional_embedding(self, batch_size):
        device = next(self.model.parameters()).device
        dtype = next(self.model.parameters()).dtype
        zero = torch.zeros(
            batch_size,
            self.num_patches,
            self.num_patches * self.view_num,
            self.model.config.hidden_size,
            device=device,
            dtype=dtype,
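A hedged sketch (shapes assumed, not taken from the diff) of how the precomputed per-view sin/cos table is added to the encoder tokens and flattened into a single conditioning sequence, mirroring the multiview branch above.

import torch

view_num, num_patches, hidden = 4, 1369, 1024               # assumed DINOv2 token count and width
tokens = torch.randn(view_num, num_patches, hidden)         # per-view patch tokens from the encoder
view_embed = torch.randn(1, view_num, num_patches, hidden)  # stand-in for self.view_embed

tokens = tokens + view_embed[0]                             # tag every token with its view index
tokens = tokens.view(-1, hidden).unsqueeze(0)               # one flat sequence for the DiT conditioner
print(tokens.shape)   # torch.Size([1, 4 * 1369, 1024])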
@@ -162,9 +241,9 @@ class SingleImageEncoder(nn.Module):
        super().__init__()
        self.main_image_encoder = build_image_encoder(main_image_encoder)

    def forward(self, image, mask=None):
    def forward(self, image=None, mask=None, view_dict=None):
        outputs = {
            'main': self.main_image_encoder(image, mask=mask),
            'main': self.main_image_encoder(image=image, mask=mask, view_dict=view_dict),
        }
        return outputs
@@ -71,6 +71,15 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
    return embedding


class GELU(nn.Module):
    def __init__(self, approximate='tanh'):
        super().__init__()
        self.approximate = approximate

    def forward(self, x: Tensor) -> Tensor:
        return nn.functional.gelu(x.contiguous(), approximate=self.approximate)


class MLPEmbedder(nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
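As a quick sanity check (illustrative only): the custom GELU is a drop-in replacement for nn.GELU with tanh approximation that merely forces a contiguous input before the activation, so outputs match.

import torch
import torch.nn as nn

class GELU(nn.Module):  # same definition as in the hunk above
    def __init__(self, approximate='tanh'):
        super().__init__()
        self.approximate = approximate

    def forward(self, x):
        return nn.functional.gelu(x.contiguous(), approximate=self.approximate)

x = torch.randn(2, 8, 16).transpose(1, 2)          # deliberately non-contiguous input
assert torch.allclose(GELU()(x), nn.GELU(approximate="tanh")(x))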
@@ -178,7 +187,7 @@ class DoubleStreamBlock(nn.Module):
        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )
@@ -189,7 +198,7 @@ class DoubleStreamBlock(nn.Module):
        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )
@@ -260,7 +269,7 @@ class SingleStreamBlock(nn.Module):
        self.hidden_size = hidden_size
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.mlp_act = nn.GELU(approximate="tanh")
        self.mlp_act = GELU(approximate="tanh")
        self.modulation = Modulation(hidden_size, double=False)

    def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
@@ -311,10 +311,10 @@ class Hunyuan3DDiTPipeline:
        self.model.to(dtype=dtype)
        self.conditioner.to(dtype=dtype)

    def encode_cond(self, image, mask, do_classifier_free_guidance, dual_guidance):
    def encode_cond(self, image, mask, do_classifier_free_guidance, dual_guidance, view_dict=None):
        self.conditioner.to(self.main_device)
        bsz = image.shape[0]
        cond = self.conditioner(image=image, mask=mask)
        bsz = 1
        cond = self.conditioner(image=image, mask=mask, view_dict=view_dict)

        if do_classifier_free_guidance:
            un_cond = self.conditioner.unconditional_embedding(bsz)
@@ -575,6 +575,7 @@ class Hunyuan3DDiTFlowMatchingPipeline(Hunyuan3DDiTPipeline):
        # num_chunks=8000,
        # output_type: Optional[str] = "trimesh",
        enable_pbar=True,
        view_dict=None,
        **kwargs,
    ) -> List[List[trimesh.Trimesh]]:
        callback = kwargs.pop("callback", None)
@@ -594,8 +595,9 @@ class Hunyuan3DDiTFlowMatchingPipeline(Hunyuan3DDiTPipeline):
            mask=mask,
            do_classifier_free_guidance=do_classifier_free_guidance,
            dual_guidance=False,
            view_dict=view_dict
        )
        batch_size = image.shape[0]
        batch_size = 1

        # 5. Prepare timesteps
        # NOTE: this is slightly different from common usage, we start from 0.
@@ -87,7 +87,7 @@ class ImageProcessorV2:
                            interpolation=cv2.INTER_AREA)

        bg = np.ones((result.shape[0], result.shape[1], 3), dtype=np.uint8) * 255
        # bg = np.zeros((result.shape[0], result.shape[1], 3), dtype=np.uint8) * 255

        mask = result[..., 3:].astype(np.float32) / 255
        result = result[..., :3] * mask + bg * (1 - mask)
@@ -96,15 +96,13 @@ class ImageProcessorV2:
        mask = mask.clip(0, 255).astype(np.uint8)
        return result, mask

    def __call__(self, image, border_ratio=0.15, to_tensor=True, return_mask=False, **kwargs):
        if self.border_ratio is not None:
            border_ratio = self.border_ratio
            print(f"Using border_ratio from init: {border_ratio}")
    def load_image(self, image, border_ratio=0.15, to_tensor=True):
        if isinstance(image, str):
            image = cv2.imread(image, cv2.IMREAD_UNCHANGED)
            image, mask = self.recenter(image, border_ratio=border_ratio)
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        elif isinstance(image, Image.Image):
            image = image.convert("RGBA")
            image = np.asarray(image)
            image, mask = self.recenter(image, border_ratio=border_ratio)
@@ -115,13 +113,64 @@ class ImageProcessorV2:
        if to_tensor:
            image = array_to_tensor(image)
            mask = array_to_tensor(mask)
        if return_mask:
            return image, mask
        return image
        return image, mask

    def __call__(self, image, border_ratio=0.15, to_tensor=True, **kwargs):
        if self.border_ratio is not None:
            border_ratio = self.border_ratio
        image, mask = self.load_image(image, border_ratio=border_ratio, to_tensor=to_tensor)
        outputs = {
            'image': image,
            'mask': mask
        }
        return outputs


class MVImageProcessorV2(ImageProcessorV2):
    """
    view order: front, front clockwise 90, back, front clockwise 270
    """
    return_view_idx = True

    def __init__(self, size=512, border_ratio=None):
        super().__init__(size, border_ratio)
        self.view2idx = {
            'front': 0,
            'left': 1,
            'back': 2,
            'right': 3
        }

    def __call__(self, image_dict, border_ratio=0.15, to_tensor=True, **kwargs):
        if self.border_ratio is not None:
            border_ratio = self.border_ratio

        images = []
        masks = []
        view_idxs = []
        for idx, (view_tag, image) in enumerate(image_dict.items()):
            view_idxs.append(self.view2idx[view_tag])
            image, mask = self.load_image(image, border_ratio=border_ratio, to_tensor=to_tensor)
            images.append(image)
            masks.append(mask)

        zipped_lists = zip(view_idxs, images, masks)
        sorted_zipped_lists = sorted(zipped_lists)
        view_idxs, images, masks = zip(*sorted_zipped_lists)

        image = torch.cat(images, 0).unsqueeze(0)
        mask = torch.cat(masks, 0).unsqueeze(0)
        outputs = {
            'image': image,
            'mask': mask,
            'view_idxs': view_idxs
        }
        return outputs


IMAGE_PROCESSORS = {
    "v2": ImageProcessorV2,
    'mv_v2': MVImageProcessorV2,
}

DEFAULT_IMAGEPROCESSOR = 'v2'
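A hedged usage sketch of the multiview processor registered above; the file paths are placeholders, and any subset of the four canonical views may be supplied.

processor = MVImageProcessorV2(size=512)
outputs = processor({
    'back': 'renders/back.png',      # str inputs go through cv2.imread + recenter
    'front': 'renders/front.png',    # placeholder paths, not from the diff
})
# views come back sorted by view index, so 'front' (0) precedes 'back' (2)
image, mask, view_idxs = outputs['image'], outputs['mask'], outputs['view_idxs']
print(view_idxs)   # (0, 2)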
nodes.py
@@ -84,7 +84,7 @@ class Hy3DTorchCompileSettings:
    RETURN_TYPES = ("HY3DCOMPILEARGS",)
    RETURN_NAMES = ("torch_compile_args",)
    FUNCTION = "loadmodel"
    CATEGORY = "HunyuanVideoWrapper"
    CATEGORY = "Hunyuan3DWrapper"
    DESCRIPTION = "torch.compile settings, when connected to the model loader, torch.compile of the selected layers is attempted. Requires Triton and torch 2.5.0 is recommended"

    def loadmodel(self, backend, fullgraph, mode, dynamic, dynamo_cache_size_limit, compile_transformer, compile_vae):
@@ -1034,7 +1034,7 @@ class Hy3DGenerateMesh:
    FUNCTION = "process"
    CATEGORY = "Hunyuan3DWrapper"

    def process(self, pipeline, image, steps, guidance_scale, seed, mask=None):
    def process(self, pipeline, image, steps, guidance_scale, seed, mask=None, front=None, back=None, left=None, right=None):

        mm.unload_all_models()
        mm.soft_empty_cache()
@@ -1046,7 +1046,7 @@ class Hy3DGenerateMesh:
        image = image * 2 - 1

        if mask is not None:
            mask = mask.unsqueeze(0).to(device)
            mask = mask.unsqueeze(1).repeat(1, 3, 1, 1).to(device)
            if mask.shape[2] != image.shape[2] or mask.shape[3] != image.shape[3]:
                mask = F.interpolate(mask, size=(image.shape[2], image.shape[3]), mode='nearest')
@@ -1074,6 +1074,78 @@ class Hy3DGenerateMesh:

        return (latents, )

class Hy3DGenerateMeshMultiView(Hy3DGenerateMesh):
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "pipeline": ("HY3DMODEL",),
                "guidance_scale": ("FLOAT", {"default": 5.5, "min": 0.0, "max": 100.0, "step": 0.01}),
                "steps": ("INT", {"default": 30, "min": 1}),
                "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
            },
            "optional": {
                "front": ("IMAGE", ),
                "left": ("IMAGE", ),
                "right": ("IMAGE", ),
                "back": ("IMAGE", ),
            }
        }

    RETURN_TYPES = ("HY3DLATENT",)
    RETURN_NAMES = ("latents",)
    FUNCTION = "process"
    CATEGORY = "Hunyuan3DWrapper"

    def process(self, pipeline, steps, guidance_scale, seed, mask=None, front=None, back=None, left=None, right=None):

        mm.unload_all_models()
        mm.soft_empty_cache()

        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()

        pipeline.to(device)

        if front is not None:
            front = front.clone().permute(0, 3, 1, 2).to(device)
        if back is not None:
            back = back.clone().permute(0, 3, 1, 2).to(device)
        if left is not None:
            left = left.clone().permute(0, 3, 1, 2).to(device)
        if right is not None:
            right = right.clone().permute(0, 3, 1, 2).to(device)

        view_dict = {
            'front': front,
            'left': left,
            'right': right,
            'back': back
        }

        try:
            torch.cuda.reset_peak_memory_stats(device)
        except:
            pass

        latents = pipeline(
            image=None,
            mask=mask,
            num_inference_steps=steps,
            guidance_scale=guidance_scale,
            generator=torch.manual_seed(seed),
            view_dict=view_dict)

        print_memory(device)
        try:
            torch.cuda.reset_peak_memory_stats(device)
        except:
            pass

        pipeline.to(offload_device)

        return (latents, )

class Hy3DVAEDecode:
    @classmethod
    def INPUT_TYPES(s):
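A brief aside on the permute calls above (general ComfyUI convention, not stated in the diff): IMAGE tensors arrive as batch-height-width-channel floats in 0..1, so each supplied view is converted to NCHW before entering the pipeline.

import torch

front = torch.rand(1, 518, 518, 3)             # ComfyUI IMAGE layout: (B, H, W, C)
front = front.clone().permute(0, 3, 1, 2)      # -> (B, C, H, W) as the DINOv2 encoder expects
print(front.shape)   # torch.Size([1, 3, 518, 518])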
@@ -1596,6 +1668,7 @@ class Hy3DNvdiffrastRenderer:
NODE_CLASS_MAPPINGS = {
    "Hy3DModelLoader": Hy3DModelLoader,
    "Hy3DGenerateMesh": Hy3DGenerateMesh,
    "Hy3DGenerateMeshMultiView": Hy3DGenerateMeshMultiView,
    "Hy3DExportMesh": Hy3DExportMesh,
    "DownloadAndLoadHy3DDelightModel": DownloadAndLoadHy3DDelightModel,
    "DownloadAndLoadHy3DPaintModel": DownloadAndLoadHy3DPaintModel,
@@ -1628,6 +1701,7 @@ NODE_CLASS_MAPPINGS = {
NODE_DISPLAY_NAME_MAPPINGS = {
    "Hy3DModelLoader": "Hy3DModelLoader",
    "Hy3DGenerateMesh": "Hy3DGenerateMesh",
    "Hy3DGenerateMeshMultiView": "Hy3DGenerateMeshMultiView",
    "Hy3DExportMesh": "Hy3DExportMesh",
    "DownloadAndLoadHy3DDelightModel": "(Down)Load Hy3D DelightModel",
    "DownloadAndLoadHy3DPaintModel": "(Down)Load Hy3D PaintModel",