Support multiview model

kijai 2025-03-18 18:12:36 +02:00
parent 6090a9051f
commit 7ff9ad9ea0
5 changed files with 255 additions and 42 deletions

View File

@@ -22,6 +22,7 @@
 # fine-tuning enabling code and other elements of the foregoing made publicly available
 # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
+import numpy as np
 import torch
 import torch.nn as nn
 from torchvision import transforms
@@ -31,6 +32,26 @@ from transformers import (
     Dinov2Model,
     Dinov2Config,
 )
+from ....utils import log
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+    """
+    embed_dim: output dimension for each position
+    pos: a list of positions to be encoded: size (M,)
+    out: (M, D)
+    """
+    assert embed_dim % 2 == 0
+    omega = np.arange(embed_dim // 2, dtype=np.float64)
+    omega /= embed_dim / 2.
+    omega = 1. / 10000 ** omega  # (D/2,)
+
+    pos = pos.reshape(-1)  # (M,)
+    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
+
+    emb_sin = np.sin(out)  # (M, D/2)
+    emb_cos = np.cos(out)  # (M, D/2)
+
+    return np.concatenate([emb_sin, emb_cos], axis=1)
+
 class ImageEncoder(nn.Module):
@@ -68,36 +89,94 @@ class ImageEncoder(nn.Module):
             ]
         )
 
+        #MV
+        self.view_num = 4
+        pos = np.arange(self.view_num, dtype=np.float32)
+        view_embedding = torch.from_numpy(
+            get_1d_sincos_pos_embed_from_grid(self.model.config.hidden_size, pos)).float()
+
+        view_embedding = view_embedding.unsqueeze(1).repeat(1, self.num_patches, 1)
+        self.view_embed = view_embedding.unsqueeze(0)
+
+        self.view2idx = {
+            'front': 0,
+            'left': 1,
+            'back': 2,
+            'right': 3
+        }
+
-    def forward(self, image, mask=None, value_range=(-1, 1)):
-        if value_range is not None:
-            low, high = value_range
-            image = (image - low) / (high - low)
-
-        image = image.to(self.model.device, dtype=self.model.dtype)
-
-        if mask is not None:
-            mask = mask.to(image)
-            image = image * mask
-        supported_sizes = [518, 530]
-        if (image.shape[2] not in supported_sizes or image.shape[3] not in supported_sizes) and not self.has_guidance_embed:
-            print(f'Image shape {image.shape} not supported. Resizing to 518x518')
-            inputs = self.transform(image)
-        else:
-            inputs = image
-        outputs = self.model(inputs)
-
-        last_hidden_state = outputs.last_hidden_state
-        if not self.use_cls_token:
-            last_hidden_state = last_hidden_state[:, 1:, :]
-
-        return last_hidden_state
+    def forward(self, image, mask=None, value_range=(-1, 1), view_dict=None):
+        if view_dict is None:
+            self.view_num = 1
+            if value_range is not None:
+                low, high = value_range
+                image = (image - low) / (high - low)
+
+            image = image.to(self.model.device, dtype=self.model.dtype)
+
+            if mask is not None:
+                mask = mask.to(image)
+                image = image * mask
+            supported_sizes = [518, 530]
+            if (image.shape[2] not in supported_sizes or image.shape[3] not in supported_sizes) and not self.has_guidance_embed:
+                log.info(f'Image shape {image.shape} not supported. Resizing to 518x518')
+                inputs = self.transform(image)
+            else:
+                inputs = image
+            outputs = self.model(inputs)
+
+            last_hidden_state = outputs.last_hidden_state
+            if not self.use_cls_token:
+                last_hidden_state = last_hidden_state[:, 1:, :]
+
+            return last_hidden_state
+        else:
+            images = []
+            view_indexes = []
+            supported_sizes = [518, 530]
+
+            for view_tag, view_image in view_dict.items():
+                if view_image is not None:
+                    if view_image.shape[1] == 4:
+                        log.info('received image with alpha channel, masking out background...')
+                        rgb = view_image[:, :3, :, :]
+                        alpha = view_image[:, 3:4, :, :]
+                        view_image = rgb * alpha
+                    view_image = view_image.to(self.model.device, dtype=self.model.dtype)
+                    if (view_image.shape[2] not in supported_sizes or view_image.shape[3] not in supported_sizes):
+                        log.info(f'view_image shape {view_image.shape} not supported. Resizing to 518x518')
+                        view_image = self.transform(view_image)
+                    images.append(view_image)
+                    view_indexes.append(self.view2idx[view_tag])
+
+            image_tensors = torch.cat(images, 0)
+            outputs = self.model(image_tensors)
+            self.view_num = len(images)
+
+            last_hidden_state = outputs.last_hidden_state
+            last_hidden_state = last_hidden_state.view(
+                self.view_num,
+                -1,
+                last_hidden_state.shape[-1]
+            )
+
+            view_embedding = self.view_embed.to(last_hidden_state.dtype).to(last_hidden_state.device)
+            if self.view_num != 4:
+                view_embeddings = []
+                for idx in range(len(view_indexes)):
+                    view_embeddings.append(self.view_embed[:, idx, ...])
+                view_embedding = torch.cat(view_embeddings, 0).to(last_hidden_state.dtype).to(last_hidden_state.device)
+
+            last_hidden_state = last_hidden_state + view_embedding
+            last_hidden_state = last_hidden_state.view(-1, last_hidden_state.shape[-1])
+            return last_hidden_state.unsqueeze(0)
 
     def unconditional_embedding(self, batch_size):
         device = next(self.model.parameters()).device
         dtype = next(self.model.parameters()).dtype
         zero = torch.zeros(
             batch_size,
-            self.num_patches,
+            self.num_patches * self.view_num,
             self.model.config.hidden_size,
             device=device,
             dtype=dtype,
@@ -162,9 +241,9 @@ class SingleImageEncoder(nn.Module):
         super().__init__()
         self.main_image_encoder = build_image_encoder(main_image_encoder)
 
-    def forward(self, image, mask=None):
+    def forward(self, image=None, mask=None, view_dict=None):
         outputs = {
-            'main': self.main_image_encoder(image, mask=mask),
+            'main': self.main_image_encoder(image=image, mask=mask, view_dict=view_dict),
         }
         return outputs

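A minimal sketch of the conditioning scheme implemented above, for orientation: each view gets one sinusoidal embedding vector, broadcast over that view's patch tokens and added to the DINOv2 features, and the per-view sequences are then flattened into a single conditioning sequence. The sizes below are illustrative placeholders, not values taken from this commit.

import numpy as np
import torch

hidden_size = 1024   # assumed encoder hidden size, for illustration only
num_patches = 1370   # assumed tokens per view, for illustration only
view_num = 4

def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    # same construction as the helper added in this commit
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000 ** omega
    out = np.einsum('m,d->md', pos.reshape(-1), omega)
    return np.concatenate([np.sin(out), np.cos(out)], axis=1)

# one embedding vector per view, repeated across that view's patch tokens
pos = np.arange(view_num, dtype=np.float32)
view_embed = torch.from_numpy(get_1d_sincos_pos_embed_from_grid(hidden_size, pos)).float()
view_embed = view_embed.unsqueeze(1).repeat(1, num_patches, 1)   # (views, patches, hidden)

# stand-in for the DINOv2 last_hidden_state of the four views
tokens = torch.randn(view_num, num_patches, hidden_size)
cond = (tokens + view_embed).view(1, -1, hidden_size)            # (1, views*patches, hidden)
print(cond.shape)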
View File

@@ -71,6 +71,15 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
     return embedding
 
+class GELU(nn.Module):
+    def __init__(self, approximate='tanh'):
+        super().__init__()
+        self.approximate = approximate
+
+    def forward(self, x: Tensor) -> Tensor:
+        return nn.functional.gelu(x.contiguous(), approximate=self.approximate)
+
 class MLPEmbedder(nn.Module):
     def __init__(self, in_dim: int, hidden_dim: int):
         super().__init__()
@@ -178,7 +187,7 @@ class DoubleStreamBlock(nn.Module):
         self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
         self.img_mlp = nn.Sequential(
             nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
-            nn.GELU(approximate="tanh"),
+            GELU(approximate="tanh"),
             nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
         )
@@ -189,7 +198,7 @@ class DoubleStreamBlock(nn.Module):
         self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
         self.txt_mlp = nn.Sequential(
             nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
-            nn.GELU(approximate="tanh"),
+            GELU(approximate="tanh"),
             nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
         )
@@ -260,7 +269,7 @@ class SingleStreamBlock(nn.Module):
         self.hidden_size = hidden_size
         self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
 
-        self.mlp_act = nn.GELU(approximate="tanh")
+        self.mlp_act = GELU(approximate="tanh")
         self.modulation = Modulation(hidden_size, double=False)
 
     def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:

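The change in this file swaps nn.GELU for a small wrapper that forces the input contiguous before the tanh-approximate GELU, presumably to avoid problems with non-contiguous activations; the math is unchanged. A quick equivalence check, with arbitrary shapes:

import torch
import torch.nn as nn

x = torch.randn(2, 8, 16).transpose(1, 2)                    # deliberately non-contiguous
a = nn.functional.gelu(x.contiguous(), approximate='tanh')   # what the new GELU module computes
b = nn.GELU(approximate='tanh')(x)                           # stock module, same values
print(x.is_contiguous(), torch.allclose(a, b))               # False True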
View File

@@ -311,10 +311,10 @@ class Hunyuan3DDiTPipeline:
         self.model.to(dtype=dtype)
         self.conditioner.to(dtype=dtype)
 
-    def encode_cond(self, image, mask, do_classifier_free_guidance, dual_guidance):
+    def encode_cond(self, image, mask, do_classifier_free_guidance, dual_guidance, view_dict=None):
         self.conditioner.to(self.main_device)
-        bsz = image.shape[0]
-        cond = self.conditioner(image=image, mask=mask)
+        bsz = 1
+        cond = self.conditioner(image=image, mask=mask, view_dict=view_dict)
 
         if do_classifier_free_guidance:
             un_cond = self.conditioner.unconditional_embedding(bsz)
@@ -575,6 +575,7 @@ class Hunyuan3DDiTFlowMatchingPipeline(Hunyuan3DDiTPipeline):
         # num_chunks=8000,
         # output_type: Optional[str] = "trimesh",
         enable_pbar=True,
+        view_dict=None,
         **kwargs,
     ) -> List[List[trimesh.Trimesh]]:
         callback = kwargs.pop("callback", None)
@@ -594,8 +595,9 @@ class Hunyuan3DDiTFlowMatchingPipeline(Hunyuan3DDiTPipeline):
             mask=mask,
             do_classifier_free_guidance=do_classifier_free_guidance,
             dual_guidance=False,
+            view_dict=view_dict
         )
-        batch_size = image.shape[0]
+        batch_size = 1
 
         # 5. Prepare timesteps
         # NOTE: this is slightly different from common usage, we start from 0.

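With these pipeline changes, generation can be driven by a dict of per-view tensors instead of a single image, with the batch size pinned to 1. A rough call sketch mirroring the node code later in this commit; the pipeline object and the tensor contents are placeholders, not part of the diff:

import torch

view_dict = {
    'front': torch.rand(1, 3, 518, 518),   # (B, C, H, W) channels-first view renders
    'left': None,                          # missing views may simply be None
    'back': torch.rand(1, 3, 518, 518),
    'right': None,
}

# `pipeline` is assumed to be an already loaded Hunyuan3DDiTFlowMatchingPipeline
latents = pipeline(
    image=None,
    mask=None,
    num_inference_steps=30,
    guidance_scale=5.5,
    generator=torch.manual_seed(0),
    view_dict=view_dict,
)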
View File

@@ -87,7 +87,7 @@ class ImageProcessorV2:
                             interpolation=cv2.INTER_AREA)
 
         bg = np.ones((result.shape[0], result.shape[1], 3), dtype=np.uint8) * 255
+        # bg = np.zeros((result.shape[0], result.shape[1], 3), dtype=np.uint8) * 255
         mask = result[..., 3:].astype(np.float32) / 255
         result = result[..., :3] * mask + bg * (1 - mask)
@@ -96,15 +96,13 @@ class ImageProcessorV2:
         mask = mask.clip(0, 255).astype(np.uint8)
         return result, mask
 
-    def __call__(self, image, border_ratio=0.15, to_tensor=True, return_mask=False, **kwargs):
-        if self.border_ratio is not None:
-            border_ratio = self.border_ratio
-            print(f"Using border_ratio from init: {border_ratio}")
+    def load_image(self, image, border_ratio=0.15, to_tensor=True):
         if isinstance(image, str):
             image = cv2.imread(image, cv2.IMREAD_UNCHANGED)
             image, mask = self.recenter(image, border_ratio=border_ratio)
             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
         elif isinstance(image, Image.Image):
+            image = image.convert("RGBA")
             image = np.asarray(image)
             image, mask = self.recenter(image, border_ratio=border_ratio)
@@ -115,13 +113,64 @@ class ImageProcessorV2:
         if to_tensor:
             image = array_to_tensor(image)
             mask = array_to_tensor(mask)
-        if return_mask:
-            return image, mask
-        return image
+        return image, mask
+
+    def __call__(self, image, border_ratio=0.15, to_tensor=True, **kwargs):
+        if self.border_ratio is not None:
+            border_ratio = self.border_ratio
+        image, mask = self.load_image(image, border_ratio=border_ratio, to_tensor=to_tensor)
+        outputs = {
+            'image': image,
+            'mask': mask
+        }
+        return outputs
+
+
+class MVImageProcessorV2(ImageProcessorV2):
+    """
+    view order: front, front clockwise 90, back, front clockwise 270
+    """
+    return_view_idx = True
+
+    def __init__(self, size=512, border_ratio=None):
+        super().__init__(size, border_ratio)
+        self.view2idx = {
+            'front': 0,
+            'left': 1,
+            'back': 2,
+            'right': 3
+        }
+
+    def __call__(self, image_dict, border_ratio=0.15, to_tensor=True, **kwargs):
+        if self.border_ratio is not None:
+            border_ratio = self.border_ratio
+
+        images = []
+        masks = []
+        view_idxs = []
+        for idx, (view_tag, image) in enumerate(image_dict.items()):
+            view_idxs.append(self.view2idx[view_tag])
+            image, mask = self.load_image(image, border_ratio=border_ratio, to_tensor=to_tensor)
+            images.append(image)
+            masks.append(mask)
+
+        zipped_lists = zip(view_idxs, images, masks)
+        sorted_zipped_lists = sorted(zipped_lists)
+        view_idxs, images, masks = zip(*sorted_zipped_lists)
+
+        image = torch.cat(images, 0).unsqueeze(0)
+        mask = torch.cat(masks, 0).unsqueeze(0)
+        outputs = {
+            'image': image,
+            'mask': mask,
+            'view_idxs': view_idxs
+        }
+        return outputs
+
 
 IMAGE_PROCESSORS = {
     "v2": ImageProcessorV2,
+    'mv_v2': MVImageProcessorV2,
 }
 
 DEFAULT_IMAGEPROCESSOR = 'v2'

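MVImageProcessorV2 consumes a dict of view images keyed by 'front'/'left'/'back'/'right', recenters each one, and returns the views stacked together with their view indices in canonical order. A usage sketch, assuming the class above is importable; the solid-colour PIL images are stand-ins for real view renders:

from PIL import Image

processor = MVImageProcessorV2(size=512)
front = Image.new('RGBA', (640, 640), (200, 80, 80, 255))   # placeholder front view
back = Image.new('RGBA', (640, 640), (80, 80, 200, 255))    # placeholder back view

outputs = processor({'front': front, 'back': back})
# outputs['image'] / outputs['mask'] stack the provided views; outputs['view_idxs']
# reports which views were given, in canonical order (0=front, 1=left, 2=back, 3=right)
print(outputs['view_idxs'])   # (0, 2)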
View File

@@ -84,7 +84,7 @@ class Hy3DTorchCompileSettings:
     RETURN_TYPES = ("HY3DCOMPILEARGS",)
     RETURN_NAMES = ("torch_compile_args",)
     FUNCTION = "loadmodel"
-    CATEGORY = "HunyuanVideoWrapper"
+    CATEGORY = "Hunyuan3DWrapper"
     DESCRIPTION = "torch.compile settings, when connected to the model loader, torch.compile of the selected layers is attempted. Requires Triton and torch 2.5.0 is recommended"
 
     def loadmodel(self, backend, fullgraph, mode, dynamic, dynamo_cache_size_limit, compile_transformer, compile_vae):
@@ -1034,7 +1034,7 @@ class Hy3DGenerateMesh:
     FUNCTION = "process"
     CATEGORY = "Hunyuan3DWrapper"
 
-    def process(self, pipeline, image, steps, guidance_scale, seed, mask=None):
+    def process(self, pipeline, image, steps, guidance_scale, seed, mask=None, front=None, back=None, left=None, right=None):
         mm.unload_all_models()
         mm.soft_empty_cache()
@@ -1046,7 +1046,7 @@ class Hy3DGenerateMesh:
         image = image * 2 - 1
 
         if mask is not None:
-            mask = mask.unsqueeze(0).to(device)
+            mask = mask.unsqueeze(1).repeat(1, 3, 1, 1).to(device)
             if mask.shape[2] != image.shape[2] or mask.shape[3] != image.shape[3]:
                 mask = F.interpolate(mask, size=(image.shape[2], image.shape[3]), mode='nearest')
@@ -1074,6 +1074,78 @@ class Hy3DGenerateMesh:
 
         return (latents, )
 
+class Hy3DGenerateMeshMultiView(Hy3DGenerateMesh):
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "pipeline": ("HY3DMODEL",),
+                "guidance_scale": ("FLOAT", {"default": 5.5, "min": 0.0, "max": 100.0, "step": 0.01}),
+                "steps": ("INT", {"default": 30, "min": 1}),
+                "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
+            },
+            "optional": {
+                "front": ("IMAGE", ),
+                "left": ("IMAGE", ),
+                "right": ("IMAGE", ),
+                "back": ("IMAGE", ),
+            }
+        }
+
+    RETURN_TYPES = ("HY3DLATENT",)
+    RETURN_NAMES = ("latents",)
+    FUNCTION = "process"
+    CATEGORY = "Hunyuan3DWrapper"
+
+    def process(self, pipeline, steps, guidance_scale, seed, mask=None, front=None, back=None, left=None, right=None):
+        mm.unload_all_models()
+        mm.soft_empty_cache()
+
+        device = mm.get_torch_device()
+        offload_device = mm.unet_offload_device()
+
+        pipeline.to(device)
+
+        if front is not None:
+            front = front.clone().permute(0, 3, 1, 2).to(device)
+        if back is not None:
+            back = back.clone().permute(0, 3, 1, 2).to(device)
+        if left is not None:
+            left = left.clone().permute(0, 3, 1, 2).to(device)
+        if right is not None:
+            right = right.clone().permute(0, 3, 1, 2).to(device)
+
+        view_dict = {
+            'front': front,
+            'left': left,
+            'right': right,
+            'back': back
+        }
+
+        try:
+            torch.cuda.reset_peak_memory_stats(device)
+        except:
+            pass
+
+        latents = pipeline(
+            image=None,
+            mask=mask,
+            num_inference_steps=steps,
+            guidance_scale=guidance_scale,
+            generator=torch.manual_seed(seed),
+            view_dict=view_dict)
+
+        print_memory(device)
+        try:
+            torch.cuda.reset_peak_memory_stats(device)
+        except:
+            pass
+
+        pipeline.to(offload_device)
+
+        return (latents, )
+
 class Hy3DVAEDecode:
     @classmethod
     def INPUT_TYPES(s):
@@ -1596,6 +1668,7 @@ class Hy3DNvdiffrastRenderer:
 NODE_CLASS_MAPPINGS = {
     "Hy3DModelLoader": Hy3DModelLoader,
     "Hy3DGenerateMesh": Hy3DGenerateMesh,
+    "Hy3DGenerateMeshMultiView": Hy3DGenerateMeshMultiView,
     "Hy3DExportMesh": Hy3DExportMesh,
     "DownloadAndLoadHy3DDelightModel": DownloadAndLoadHy3DDelightModel,
     "DownloadAndLoadHy3DPaintModel": DownloadAndLoadHy3DPaintModel,
@@ -1628,6 +1701,7 @@ NODE_CLASS_MAPPINGS = {
 NODE_DISPLAY_NAME_MAPPINGS = {
     "Hy3DModelLoader": "Hy3DModelLoader",
     "Hy3DGenerateMesh": "Hy3DGenerateMesh",
+    "Hy3DGenerateMeshMultiView": "Hy3DGenerateMeshMultiView",
     "Hy3DExportMesh": "Hy3DExportMesh",
     "DownloadAndLoadHy3DDelightModel": "(Down)Load Hy3D DelightModel",
     "DownloadAndLoadHy3DPaintModel": "(Down)Load Hy3D PaintModel",