Support multiview model

kijai 2025-03-18 18:12:36 +02:00
parent 6090a9051f
commit 7ff9ad9ea0
5 changed files with 255 additions and 42 deletions

View File

@@ -22,6 +22,7 @@
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms
@@ -31,6 +32,26 @@ from transformers import (
Dinov2Model,
Dinov2Config,
)
from ....utils import log
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float64)
omega /= embed_dim / 2.
omega = 1. / 10000 ** omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
return np.concatenate([emb_sin, emb_cos], axis=1)
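A quick usage sketch of the helper above (the 1024-dim hidden size is only an example value, not read from the model config): each of the four view positions gets its own sin/cos row, which the encoder later repeats across all image patches of that view.

import numpy as np

view_num = 4
hidden_size = 1024  # example value, not taken from the DINOv2 config
pos = np.arange(view_num, dtype=np.float32)
view_embed = get_1d_sincos_pos_embed_from_grid(hidden_size, pos)
print(view_embed.shape)                       # (4, 1024): one row per view position
print(view_embed[0, :3], view_embed[0, -3:])  # position 0: sin half is 0.0, cos half is 1.0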
class ImageEncoder(nn.Module):
@@ -68,36 +89,94 @@ class ImageEncoder(nn.Module):
]
)
def forward(self, image, mask=None, value_range=(-1, 1)):
if value_range is not None:
low, high = value_range
image = (image - low) / (high - low)
# MV: precompute per-view sin/cos positional embeddings and view-name indices
self.view_num = 4
pos = np.arange(self.view_num, dtype=np.float32)
view_embedding = torch.from_numpy(
get_1d_sincos_pos_embed_from_grid(self.model.config.hidden_size, pos)).float()
view_embedding = view_embedding.unsqueeze(1).repeat(1, self.num_patches, 1)
self.view_embed = view_embedding.unsqueeze(0)
image = image.to(self.model.device, dtype=self.model.dtype)
self.view2idx = {
'front': 0,
'left': 1,
'back': 2,
'right': 3
}
if mask is not None:
mask = mask.to(image)
image = image * mask
supported_sizes = [518, 530]
if (image.shape[2] not in supported_sizes or image.shape[3] not in supported_sizes) and not self.has_guidance_embed:
print(f'Image shape {image.shape} not supported. Resizing to 518x518')
inputs = self.transform(image)
def forward(self, image, mask=None, value_range=(-1, 1), view_dict=None):
if view_dict is None:
self.view_num = 1
if value_range is not None:
low, high = value_range
image = (image - low) / (high - low)
image = image.to(self.model.device, dtype=self.model.dtype)
if mask is not None:
mask = mask.to(image)
image = image * mask
supported_sizes = [518, 530]
if (image.shape[2] not in supported_sizes or image.shape[3] not in supported_sizes) and not self.has_guidance_embed:
log.info(f'Image shape {image.shape} not supported. Resizing to 518x518')
inputs = self.transform(image)
else:
inputs = image
outputs = self.model(inputs)
last_hidden_state = outputs.last_hidden_state
if not self.use_cls_token:
last_hidden_state = last_hidden_state[:, 1:, :]
return last_hidden_state
else:
inputs = image
outputs = self.model(inputs)
images = []
view_indexes = []
supported_sizes = [518, 530]
last_hidden_state = outputs.last_hidden_state
if not self.use_cls_token:
last_hidden_state = last_hidden_state[:, 1:, :]
for view_tag, view_image in view_dict.items():
if view_image is not None:
if view_image.shape[1] == 4:
log.info('received image with alpha channel, masking out background...')
rgb = view_image[:, :3, :, :]
alpha = view_image[:, 3:4, :, :]
view_image = rgb * alpha
view_image = view_image.to(self.model.device, dtype=self.model.dtype)
if (view_image.shape[2] not in supported_sizes or view_image.shape[3] not in supported_sizes):
log.info(f'view_image shape {view_image.shape} not supported. Resizing to 518x518')
view_image = self.transform(view_image)
images.append(view_image)
view_indexes.append(self.view2idx[view_tag])
return last_hidden_state
image_tensors = torch.cat(images, 0)
outputs = self.model(image_tensors)
self.view_num = len(images)
last_hidden_state = outputs.last_hidden_state
last_hidden_state = last_hidden_state.view(
self.view_num,
-1,
last_hidden_state.shape[-1]
)
view_embedding = self.view_embed.to(last_hidden_state.dtype).to(last_hidden_state.device)
if self.view_num != 4:
view_embeddings = []
for view_idx in view_indexes:
view_embeddings.append(self.view_embed[:, view_idx, ...])
view_embedding = torch.cat(view_embeddings, 0).to(last_hidden_state.dtype).to(last_hidden_state.device)
last_hidden_state = last_hidden_state + view_embedding
last_hidden_state = last_hidden_state.view(-1, last_hidden_state.shape[-1])
return last_hidden_state.unsqueeze(0)
def unconditional_embedding(self, batch_size):
device = next(self.model.parameters()).device
dtype = next(self.model.parameters()).dtype
zero = torch.zeros(
batch_size,
self.num_patches,
self.num_patches * self.view_num,
self.model.config.hidden_size,
device=device,
dtype=dtype,
@@ -162,9 +241,9 @@ class SingleImageEncoder(nn.Module):
super().__init__()
self.main_image_encoder = build_image_encoder(main_image_encoder)
def forward(self, image, mask=None):
def forward(self, image=None, mask=None, view_dict=None):
outputs = {
'main': self.main_image_encoder(image, mask=mask),
'main': self.main_image_encoder(image=image, mask=mask, view_dict=view_dict),
}
return outputs
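To summarize the multiview path added above: every provided view is encoded by DINOv2, the sin/cos embedding of its view index is added to its patch tokens, and all views are flattened into one conditioning sequence of view_num * num_patches tokens. A minimal shape sketch of that add-and-flatten step, with illustrative sizes (the real ones come from the DINOv2 config):

import torch

view_num, num_patches, hidden = 4, 1369, 1024            # illustrative sizes only
tokens = torch.randn(view_num, num_patches, hidden)      # per-view patch features
view_embed = torch.randn(view_num, num_patches, hidden)  # precomputed sin/cos view embeddings
tokens = tokens + view_embed                              # tag each view's tokens with its view index
cond = tokens.reshape(-1, hidden).unsqueeze(0)            # pack everything into one conditioning sequence
print(cond.shape)                                         # torch.Size([1, 5476, 1024])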

View File

@@ -71,6 +71,15 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
return embedding
class GELU(nn.Module):
def __init__(self, approximate='tanh'):
super().__init__()
self.approximate = approximate
def forward(self, x: Tensor) -> Tensor:
return nn.functional.gelu(x.contiguous(), approximate=self.approximate)
class MLPEmbedder(nn.Module):
def __init__(self, in_dim: int, hidden_dim: int):
super().__init__()
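The GELU wrapper introduced above only differs from nn.GELU by calling .contiguous() on the input before the activation, so the two should agree numerically. A small check (a sketch, assuming the GELU class from this file is in scope):

import torch
import torch.nn as nn

x = torch.randn(2, 8, 16).transpose(1, 2)  # deliberately non-contiguous input
ref = nn.GELU(approximate='tanh')(x)
out = GELU(approximate='tanh')(x)           # the wrapper defined above
print(torch.allclose(ref, out))             # True: only the memory layout handling differs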
@@ -178,7 +187,7 @@ class DoubleStreamBlock(nn.Module):
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.img_mlp = nn.Sequential(
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
nn.GELU(approximate="tanh"),
GELU(approximate="tanh"),
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
)
@@ -189,7 +198,7 @@
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.txt_mlp = nn.Sequential(
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
nn.GELU(approximate="tanh"),
GELU(approximate="tanh"),
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
)
@@ -260,7 +269,7 @@ class SingleStreamBlock(nn.Module):
self.hidden_size = hidden_size
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.mlp_act = nn.GELU(approximate="tanh")
self.mlp_act = GELU(approximate="tanh")
self.modulation = Modulation(hidden_size, double=False)
def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:

View File

@@ -311,10 +311,10 @@ class Hunyuan3DDiTPipeline:
self.model.to(dtype=dtype)
self.conditioner.to(dtype=dtype)
def encode_cond(self, image, mask, do_classifier_free_guidance, dual_guidance):
def encode_cond(self, image, mask, do_classifier_free_guidance, dual_guidance, view_dict=None):
self.conditioner.to(self.main_device)
bsz = image.shape[0]
cond = self.conditioner(image=image, mask=mask)
bsz = 1
cond = self.conditioner(image=image, mask=mask, view_dict=view_dict)
if do_classifier_free_guidance:
un_cond = self.conditioner.unconditional_embedding(bsz)
@@ -575,6 +575,7 @@ class Hunyuan3DDiTFlowMatchingPipeline(Hunyuan3DDiTPipeline):
# num_chunks=8000,
# output_type: Optional[str] = "trimesh",
enable_pbar=True,
view_dict=None,
**kwargs,
) -> List[List[trimesh.Trimesh]]:
callback = kwargs.pop("callback", None)
@@ -594,8 +595,9 @@ class Hunyuan3DDiTFlowMatchingPipeline(Hunyuan3DDiTPipeline):
mask=mask,
do_classifier_free_guidance=do_classifier_free_guidance,
dual_guidance=False,
view_dict=view_dict
)
batch_size = image.shape[0]
batch_size = 1
# 5. Prepare timesteps
# NOTE: this is slightly different from common usage, we start from 0.
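Because the multiview conditioner returns one packed sequence for all views (batch dimension 1), encode_cond now fixes bsz to 1, and the unconditional embedding from ImageEncoder.unconditional_embedding has to cover num_patches * view_num tokens. A rough shape sketch with illustrative sizes; the exact way encode_cond combines the two tensors is outside this hunk:

import torch

view_num, num_patches, hidden = 4, 1369, 1024            # illustrative sizes only
cond = torch.randn(1, num_patches * view_num, hidden)     # packed multiview conditioning
un_cond = torch.zeros(1, num_patches * view_num, hidden)  # zero embedding used for CFG
print(cond.shape == un_cond.shape)                        # True: shapes must match for guidance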

View File

@@ -87,7 +87,7 @@ class ImageProcessorV2:
interpolation=cv2.INTER_AREA)
bg = np.ones((result.shape[0], result.shape[1], 3), dtype=np.uint8) * 255
# bg = np.zeros((result.shape[0], result.shape[1], 3), dtype=np.uint8) * 255
mask = result[..., 3:].astype(np.float32) / 255
result = result[..., :3] * mask + bg * (1 - mask)
@@ -96,15 +96,13 @@ class ImageProcessorV2:
mask = mask.clip(0, 255).astype(np.uint8)
return result, mask
def __call__(self, image, border_ratio=0.15, to_tensor=True, return_mask=False, **kwargs):
if self.border_ratio is not None:
border_ratio = self.border_ratio
print(f"Using border_ratio from init: {border_ratio}")
def load_image(self, image, border_ratio=0.15, to_tensor=True):
if isinstance(image, str):
image = cv2.imread(image, cv2.IMREAD_UNCHANGED)
image, mask = self.recenter(image, border_ratio=border_ratio)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
elif isinstance(image, Image.Image):
image = image.convert("RGBA")
image = np.asarray(image)
image, mask = self.recenter(image, border_ratio=border_ratio)
@@ -115,13 +113,64 @@ class ImageProcessorV2:
if to_tensor:
image = array_to_tensor(image)
mask = array_to_tensor(mask)
if return_mask:
return image, mask
return image
return image, mask
def __call__(self, image, border_ratio=0.15, to_tensor=True, **kwargs):
if self.border_ratio is not None:
border_ratio = self.border_ratio
image, mask = self.load_image(image, border_ratio=border_ratio, to_tensor=to_tensor)
outputs = {
'image': image,
'mask': mask
}
return outputs
class MVImageProcessorV2(ImageProcessorV2):
"""
view order: front, left (front clockwise 90), back, right (front clockwise 270)
"""
return_view_idx = True
def __init__(self, size=512, border_ratio=None):
super().__init__(size, border_ratio)
self.view2idx = {
'front': 0,
'left': 1,
'back': 2,
'right': 3
}
def __call__(self, image_dict, border_ratio=0.15, to_tensor=True, **kwargs):
if self.border_ratio is not None:
border_ratio = self.border_ratio
images = []
masks = []
view_idxs = []
for idx, (view_tag, image) in enumerate(image_dict.items()):
view_idxs.append(self.view2idx[view_tag])
image, mask = self.load_image(image, border_ratio=border_ratio, to_tensor=to_tensor)
images.append(image)
masks.append(mask)
zipped_lists = zip(view_idxs, images, masks)
sorted_zipped_lists = sorted(zipped_lists)
view_idxs, images, masks = zip(*sorted_zipped_lists)
image = torch.cat(images, 0).unsqueeze(0)
mask = torch.cat(masks, 0).unsqueeze(0)
outputs = {
'image': image,
'mask': mask,
'view_idxs': view_idxs
}
return outputs
IMAGE_PROCESSORS = {
"v2": ImageProcessorV2,
'mv_v2': MVImageProcessorV2,
}
DEFAULT_IMAGEPROCESSOR = 'v2'
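A hypothetical usage sketch of the new processor ('front.png' and 'back.png' are placeholder paths to RGBA cutouts, since recenter reads the alpha channel): views may be passed in any order and are sorted by their view index before being stacked.

processor = MVImageProcessorV2(size=512)
out = processor({'back': 'back.png', 'front': 'front.png'})
print(out['view_idxs'])                       # (0, 2): sorted front/back indices
print(out['image'].shape, out['mask'].shape)  # stacked, recentered views and their masks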

View File

@@ -84,7 +84,7 @@ class Hy3DTorchCompileSettings:
RETURN_TYPES = ("HY3DCOMPILEARGS",)
RETURN_NAMES = ("torch_compile_args",)
FUNCTION = "loadmodel"
CATEGORY = "HunyuanVideoWrapper"
CATEGORY = "Hunyuan3DWrapper"
DESCRIPTION = "torch.compile settings, when connected to the model loader, torch.compile of the selected layers is attempted. Requires Triton and torch 2.5.0 is recommended"
def loadmodel(self, backend, fullgraph, mode, dynamic, dynamo_cache_size_limit, compile_transformer, compile_vae):
@@ -1034,7 +1034,7 @@ class Hy3DGenerateMesh:
FUNCTION = "process"
CATEGORY = "Hunyuan3DWrapper"
def process(self, pipeline, image, steps, guidance_scale, seed, mask=None):
def process(self, pipeline, image, steps, guidance_scale, seed, mask=None, front=None, back=None, left=None, right=None):
mm.unload_all_models()
mm.soft_empty_cache()
@@ -1046,7 +1046,7 @@ class Hy3DGenerateMesh:
image = image * 2 - 1
if mask is not None:
mask = mask.unsqueeze(0).to(device)
mask = mask.unsqueeze(1).repeat(1, 3, 1, 1).to(device)
if mask.shape[2] != image.shape[2] or mask.shape[3] != image.shape[3]:
mask = F.interpolate(mask, size=(image.shape[2], image.shape[3]), mode='nearest')
@@ -1074,6 +1074,78 @@ class Hy3DGenerateMesh:
return (latents, )
class Hy3DGenerateMeshMultiView(Hy3DGenerateMesh):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"pipeline": ("HY3DMODEL",),
"guidance_scale": ("FLOAT", {"default": 5.5, "min": 0.0, "max": 100.0, "step": 0.01}),
"steps": ("INT", {"default": 30, "min": 1}),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
},
"optional": {
"front": ("IMAGE", ),
"left": ("IMAGE", ),
"right": ("IMAGE", ),
"back": ("IMAGE", ),
}
}
RETURN_TYPES = ("HY3DLATENT",)
RETURN_NAMES = ("latents",)
FUNCTION = "process"
CATEGORY = "Hunyuan3DWrapper"
def process(self, pipeline, steps, guidance_scale, seed, mask=None, front=None, back=None, left=None, right=None):
mm.unload_all_models()
mm.soft_empty_cache()
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
pipeline.to(device)
if front is not None:
front = front.clone().permute(0, 3, 1, 2).to(device)
if back is not None:
back = back.clone().permute(0, 3, 1, 2).to(device)
if left is not None:
left = left.clone().permute(0, 3, 1, 2).to(device)
if right is not None:
right = right.clone().permute(0, 3, 1, 2).to(device)
view_dict = {
'front': front,
'left': left,
'right': right,
'back': back
}
try:
torch.cuda.reset_peak_memory_stats(device)
except:
pass
latents = pipeline(
image=None,
mask=mask,
num_inference_steps=steps,
guidance_scale=guidance_scale,
generator=torch.manual_seed(seed),
view_dict=view_dict)
print_memory(device)
try:
torch.cuda.reset_peak_memory_stats(device)
except:
pass
pipeline.to(offload_device)
return (latents, )
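Outside of ComfyUI the same path can be exercised directly, assuming pipeline is a loaded Hunyuan3DDiTFlowMatchingPipeline; missing views are simply passed as None, exactly as the node does. A rough sketch with dummy tensors that only illustrate the expected NCHW layout:

import torch

front = torch.rand(1, 3, 518, 518)  # dummy data; the node feeds real images permuted to NCHW
back = torch.rand(1, 3, 518, 518)
view_dict = {'front': front, 'left': None, 'back': back, 'right': None}

latents = pipeline(
    image=None,
    mask=None,
    num_inference_steps=30,
    guidance_scale=5.5,
    generator=torch.manual_seed(0),
    view_dict=view_dict,
)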
class Hy3DVAEDecode:
@classmethod
def INPUT_TYPES(s):
@@ -1596,6 +1668,7 @@ class Hy3DNvdiffrastRenderer:
NODE_CLASS_MAPPINGS = {
"Hy3DModelLoader": Hy3DModelLoader,
"Hy3DGenerateMesh": Hy3DGenerateMesh,
"Hy3DGenerateMeshMultiView": Hy3DGenerateMeshMultiView,
"Hy3DExportMesh": Hy3DExportMesh,
"DownloadAndLoadHy3DDelightModel": DownloadAndLoadHy3DDelightModel,
"DownloadAndLoadHy3DPaintModel": DownloadAndLoadHy3DPaintModel,
@@ -1628,6 +1701,7 @@ NODE_CLASS_MAPPINGS = {
NODE_DISPLAY_NAME_MAPPINGS = {
"Hy3DModelLoader": "Hy3DModelLoader",
"Hy3DGenerateMesh": "Hy3DGenerateMesh",
"Hy3DGenerateMeshMultiView": "Hy3DGenerateMeshMultiView",
"Hy3DExportMesh": "Hy3DExportMesh",
"DownloadAndLoadHy3DDelightModel": "(Down)Load Hy3D DelightModel",
"DownloadAndLoadHy3DPaintModel": "(Down)Load Hy3D PaintModel",