Add FlashVDM vae

...and it earned that name, whoa!

https://github.com/Tencent/FlashVDM/
This commit is contained in:
kijai 2025-03-19 13:49:07 +02:00
parent 1ecc8e3195
commit 280ab59834
12 changed files with 1517 additions and 16 deletions

View File

@@ -25,4 +25,4 @@
from .conditioner import DualImageEncoder, SingleImageEncoder, DinoImageEncoder, CLIPImageEncoder
from .hunyuan3ddit import Hunyuan3DDiT
from .vae import ShapeVAE
from .autoencoders import ShapeVAE

View File

@@ -0,0 +1,20 @@
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.
# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
from .attention_blocks import CrossAttentionDecoder
from .attention_processors import FlashVDMCrossAttentionProcessor, CrossAttentionProcessor, \
FlashVDMTopMCrossAttentionProcessor
from .model import ShapeVAE, VectsetVAE
from .surface_extractors import SurfaceExtractors, MCSurfaceExtractor, DMCSurfaceExtractor, Latent2MeshOutput
from .volume_decoders import HierarchicalVolumeDecoding, FlashVDMVolumeDecoding, VanillaVolumeDecoder

View File

@@ -0,0 +1,493 @@
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.
# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
import os
from typing import Optional
import torch
import torch.nn as nn
from einops import rearrange
from .attention_processors import CrossAttentionProcessor
from ...utils import logger
scaled_dot_product_attention = nn.functional.scaled_dot_product_attention
if os.environ.get('USE_SAGEATTN', '0') == '1':
try:
from sageattention import sageattn
except ImportError:
        raise ImportError('Please install the "sageattention" package to use USE_SAGEATTN.')
scaled_dot_product_attention = sageattn
class FourierEmbedder(nn.Module):
"""The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
each feature dimension of `x[..., i]` into:
[
sin(x[..., i]),
sin(f_1*x[..., i]),
sin(f_2*x[..., i]),
...
sin(f_N * x[..., i]),
cos(x[..., i]),
cos(f_1*x[..., i]),
cos(f_2*x[..., i]),
...
cos(f_N * x[..., i]),
x[..., i] # only present if include_input is True.
], here f_i is the frequency.
    If logspace is True, the frequencies are [2^0, 2^1, ..., 2^(num_freqs - 1)];
    otherwise, they are linearly spaced between 1.0 and 2^(num_freqs - 1).
    When include_pi is True, every frequency is additionally multiplied by pi.
Args:
num_freqs (int): the number of frequencies, default is 6;
        logspace (bool): if True, the frequencies are [2^0, 2^1, ..., 2^(num_freqs - 1)];
            otherwise, they are linearly spaced between 1.0 and 2^(num_freqs - 1);
input_dim (int): the input dimension, default is 3;
include_input (bool): include the input tensor or not, default is True.
Attributes:
        frequencies (torch.Tensor): [2^0, 2^1, ..., 2^(num_freqs - 1)] if logspace is True,
            otherwise linearly spaced between 1.0 and 2^(num_freqs - 1);
out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1),
otherwise, it is input_dim * num_freqs * 2.
"""
def __init__(self,
num_freqs: int = 6,
logspace: bool = True,
input_dim: int = 3,
include_input: bool = True,
include_pi: bool = True) -> None:
"""The initialization"""
super().__init__()
if logspace:
frequencies = 2.0 ** torch.arange(
num_freqs,
dtype=torch.float32
)
else:
frequencies = torch.linspace(
1.0,
2.0 ** (num_freqs - 1),
num_freqs,
dtype=torch.float32
)
if include_pi:
frequencies *= torch.pi
self.register_buffer("frequencies", frequencies, persistent=False)
self.include_input = include_input
self.num_freqs = num_freqs
self.out_dim = self.get_dims(input_dim)
def get_dims(self, input_dim):
temp = 1 if self.include_input or self.num_freqs == 0 else 0
out_dim = input_dim * (self.num_freqs * 2 + temp)
return out_dim
def forward(self, x: torch.Tensor) -> torch.Tensor:
""" Forward process.
Args:
x: tensor of shape [..., dim]
Returns:
embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)]
where temp is 1 if include_input is True and 0 otherwise.
"""
if self.num_freqs > 0:
embed = (x[..., None].contiguous() * self.frequencies).view(*x.shape[:-1], -1)
if self.include_input:
return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
else:
return torch.cat((embed.sin(), embed.cos()), dim=-1)
else:
return x
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
self.scale_by_keep = scale_by_keep
def forward(self, x):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
"""
if self.drop_prob == 0. or not self.training:
return x
keep_prob = 1 - self.drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
if keep_prob > 0.0 and self.scale_by_keep:
random_tensor.div_(keep_prob)
return x * random_tensor
def extra_repr(self):
return f'drop_prob={round(self.drop_prob, 3):0.3f}'
class MLP(nn.Module):
def __init__(
self, *,
width: int,
expand_ratio: int = 4,
output_width: int = None,
drop_path_rate: float = 0.0
):
super().__init__()
self.width = width
self.c_fc = nn.Linear(width, width * expand_ratio)
self.c_proj = nn.Linear(width * expand_ratio, output_width if output_width is not None else width)
self.gelu = nn.GELU()
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
def forward(self, x):
return self.drop_path(self.c_proj(self.gelu(self.c_fc(x))))
class QKVMultiheadCrossAttention(nn.Module):
def __init__(
self,
*,
heads: int,
n_data: Optional[int] = None,
width=None,
qk_norm=False,
norm_layer=nn.LayerNorm
):
super().__init__()
self.heads = heads
self.n_data = n_data
self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
self.attn_processor = CrossAttentionProcessor()
def forward(self, q, kv):
_, n_ctx, _ = q.shape
bs, n_data, width = kv.shape
attn_ch = width // self.heads // 2
q = q.view(bs, n_ctx, self.heads, -1)
kv = kv.view(bs, n_data, self.heads, -1)
k, v = torch.split(kv, attn_ch, dim=-1)
q = self.q_norm(q)
k = self.k_norm(k)
q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
out = self.attn_processor(self, q, k, v)
out = out.transpose(1, 2).reshape(bs, n_ctx, -1)
return out
class MultiheadCrossAttention(nn.Module):
def __init__(
self,
*,
width: int,
heads: int,
qkv_bias: bool = True,
n_data: Optional[int] = None,
data_width: Optional[int] = None,
norm_layer=nn.LayerNorm,
qk_norm: bool = False,
kv_cache: bool = False,
):
super().__init__()
self.n_data = n_data
self.width = width
self.heads = heads
self.data_width = width if data_width is None else data_width
self.c_q = nn.Linear(width, width, bias=qkv_bias)
self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias)
self.c_proj = nn.Linear(width, width)
self.attention = QKVMultiheadCrossAttention(
heads=heads,
n_data=n_data,
width=width,
norm_layer=norm_layer,
qk_norm=qk_norm
)
self.kv_cache = kv_cache
self.data = None
def forward(self, x, data):
x = self.c_q(x)
if self.kv_cache:
if self.data is None:
self.data = self.c_kv(data)
                logger.info('Saving kv cache; this should be called only once per mesh')
data = self.data
else:
data = self.c_kv(data)
x = self.attention(x, data)
x = self.c_proj(x)
return x
class ResidualCrossAttentionBlock(nn.Module):
def __init__(
self,
*,
n_data: Optional[int] = None,
width: int,
heads: int,
mlp_expand_ratio: int = 4,
data_width: Optional[int] = None,
qkv_bias: bool = True,
norm_layer=nn.LayerNorm,
qk_norm: bool = False
):
super().__init__()
if data_width is None:
data_width = width
self.attn = MultiheadCrossAttention(
n_data=n_data,
width=width,
heads=heads,
data_width=data_width,
qkv_bias=qkv_bias,
norm_layer=norm_layer,
qk_norm=qk_norm
)
self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
self.ln_2 = norm_layer(data_width, elementwise_affine=True, eps=1e-6)
self.ln_3 = norm_layer(width, elementwise_affine=True, eps=1e-6)
self.mlp = MLP(width=width, expand_ratio=mlp_expand_ratio)
def forward(self, x: torch.Tensor, data: torch.Tensor):
x = x + self.attn(self.ln_1(x), self.ln_2(data))
x = x + self.mlp(self.ln_3(x))
return x
class QKVMultiheadAttention(nn.Module):
def __init__(
self,
*,
heads: int,
n_ctx: int,
width=None,
qk_norm=False,
norm_layer=nn.LayerNorm
):
super().__init__()
self.heads = heads
self.n_ctx = n_ctx
self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
def forward(self, qkv):
bs, n_ctx, width = qkv.shape
attn_ch = width // self.heads // 3
qkv = qkv.view(bs, n_ctx, self.heads, -1)
q, k, v = torch.split(qkv, attn_ch, dim=-1)
q = self.q_norm(q)
k = self.k_norm(k)
q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
out = scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
return out
class MultiheadAttention(nn.Module):
def __init__(
self,
*,
n_ctx: int,
width: int,
heads: int,
qkv_bias: bool,
norm_layer=nn.LayerNorm,
qk_norm: bool = False,
drop_path_rate: float = 0.0
):
super().__init__()
self.n_ctx = n_ctx
self.width = width
self.heads = heads
self.c_qkv = nn.Linear(width, width * 3, bias=qkv_bias)
self.c_proj = nn.Linear(width, width)
self.attention = QKVMultiheadAttention(
heads=heads,
n_ctx=n_ctx,
width=width,
norm_layer=norm_layer,
qk_norm=qk_norm
)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
def forward(self, x):
x = self.c_qkv(x)
x = self.attention(x)
x = self.drop_path(self.c_proj(x))
return x
class ResidualAttentionBlock(nn.Module):
def __init__(
self,
*,
n_ctx: int,
width: int,
heads: int,
qkv_bias: bool = True,
norm_layer=nn.LayerNorm,
qk_norm: bool = False,
drop_path_rate: float = 0.0,
):
super().__init__()
self.attn = MultiheadAttention(
n_ctx=n_ctx,
width=width,
heads=heads,
qkv_bias=qkv_bias,
norm_layer=norm_layer,
qk_norm=qk_norm,
drop_path_rate=drop_path_rate
)
self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
self.mlp = MLP(width=width, drop_path_rate=drop_path_rate)
self.ln_2 = norm_layer(width, elementwise_affine=True, eps=1e-6)
def forward(self, x: torch.Tensor):
x = x + self.attn(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x
class Transformer(nn.Module):
def __init__(
self,
*,
n_ctx: int,
width: int,
layers: int,
heads: int,
qkv_bias: bool = True,
norm_layer=nn.LayerNorm,
qk_norm: bool = False,
drop_path_rate: float = 0.0
):
super().__init__()
self.n_ctx = n_ctx
self.width = width
self.layers = layers
self.resblocks = nn.ModuleList(
[
ResidualAttentionBlock(
n_ctx=n_ctx,
width=width,
heads=heads,
qkv_bias=qkv_bias,
norm_layer=norm_layer,
qk_norm=qk_norm,
drop_path_rate=drop_path_rate
)
for _ in range(layers)
]
)
def forward(self, x: torch.Tensor):
for block in self.resblocks:
x = block(x)
return x
class CrossAttentionDecoder(nn.Module):
def __init__(
self,
*,
num_latents: int,
out_channels: int,
fourier_embedder: FourierEmbedder,
width: int,
heads: int,
mlp_expand_ratio: int = 4,
downsample_ratio: int = 1,
enable_ln_post: bool = True,
qkv_bias: bool = True,
qk_norm: bool = False,
label_type: str = "binary"
):
super().__init__()
self.enable_ln_post = enable_ln_post
self.fourier_embedder = fourier_embedder
self.downsample_ratio = downsample_ratio
self.query_proj = nn.Linear(self.fourier_embedder.out_dim, width)
if self.downsample_ratio != 1:
self.latents_proj = nn.Linear(width * downsample_ratio, width)
        if not self.enable_ln_post:
qk_norm = False
self.cross_attn_decoder = ResidualCrossAttentionBlock(
n_data=num_latents,
width=width,
mlp_expand_ratio=mlp_expand_ratio,
heads=heads,
qkv_bias=qkv_bias,
qk_norm=qk_norm
)
if self.enable_ln_post:
self.ln_post = nn.LayerNorm(width)
self.output_proj = nn.Linear(width, out_channels)
self.label_type = label_type
self.count = 0
def set_cross_attention_processor(self, processor):
self.cross_attn_decoder.attn.attention.attn_processor = processor
def set_default_cross_attention_processor(self):
        self.cross_attn_decoder.attn.attention.attn_processor = CrossAttentionProcessor()
def forward(self, queries=None, query_embeddings=None, latents=None):
if query_embeddings is None:
query_embeddings = self.query_proj(self.fourier_embedder(queries).to(latents.dtype))
self.count += query_embeddings.shape[1]
if self.downsample_ratio != 1:
latents = self.latents_proj(latents)
x = self.cross_attn_decoder(query_embeddings, latents)
if self.enable_ln_post:
x = self.ln_post(x)
occ = self.output_proj(x)
return occ
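For reference, a minimal usage sketch of the FourierEmbedder defined above (the shapes and values are illustrative assumptions, not part of the module): with ShapeVAE's default `num_freqs=8` and `include_input=True`, each 3-D query point becomes 3 * (2 * 8 + 1) = 51 channels, which is exactly the width that `query_proj` consumes.

```python
import torch

embedder = FourierEmbedder(num_freqs=8, input_dim=3, include_input=True)
points = torch.rand(2, 10000, 3) * 2.0 - 1.0   # hypothetical query points in [-1, 1]^3
emb = embedder(points)
print(emb.shape, embedder.out_dim)             # torch.Size([2, 10000, 51]) 51
```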

View File

@@ -0,0 +1,96 @@
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.
# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
import os
import torch
import torch.nn.functional as F
scaled_dot_product_attention = F.scaled_dot_product_attention
if os.environ.get('CA_USE_SAGEATTN', '0') == '1':
try:
from sageattention import sageattn
except ImportError:
        raise ImportError('Please install the "sageattention" package to use CA_USE_SAGEATTN.')
scaled_dot_product_attention = sageattn
class CrossAttentionProcessor:
def __call__(self, attn, q, k, v):
out = scaled_dot_product_attention(q, k, v)
return out
class FlashVDMCrossAttentionProcessor:
def __init__(self, topk=None):
self.topk = topk
def __call__(self, attn, q, k, v):
if k.shape[-2] == 3072:
topk = 1024
elif k.shape[-2] == 512:
topk = 256
else:
topk = k.shape[-2] // 3
if self.topk is True:
q1 = q[:, :, ::100, :]
sim = q1 @ k.transpose(-1, -2)
sim = torch.mean(sim, -2)
topk_ind = torch.topk(sim, dim=-1, k=topk).indices.squeeze(-2).unsqueeze(-1)
topk_ind = topk_ind.expand(-1, -1, -1, v.shape[-1])
v0 = torch.gather(v, dim=-2, index=topk_ind)
k0 = torch.gather(k, dim=-2, index=topk_ind)
out = scaled_dot_product_attention(q, k0, v0)
elif self.topk is False:
out = scaled_dot_product_attention(q, k, v)
else:
idx, counts = self.topk
start = 0
outs = []
for grid_coord, count in zip(idx, counts):
end = start + count
q_chunk = q[:, :, start:end, :]
k0, v0 = self.select_topkv(q_chunk, k, v, topk)
out = scaled_dot_product_attention(q_chunk, k0, v0)
outs.append(out)
start += count
out = torch.cat(outs, dim=-2)
self.topk = False
return out
def select_topkv(self, q_chunk, k, v, topk):
q1 = q_chunk[:, :, ::50, :]
sim = q1 @ k.transpose(-1, -2)
sim = torch.mean(sim, -2)
topk_ind = torch.topk(sim, dim=-1, k=topk).indices.squeeze(-2).unsqueeze(-1)
topk_ind = topk_ind.expand(-1, -1, -1, v.shape[-1])
v0 = torch.gather(v, dim=-2, index=topk_ind)
k0 = torch.gather(k, dim=-2, index=topk_ind)
return k0, v0
class FlashVDMTopMCrossAttentionProcessor(FlashVDMCrossAttentionProcessor):
def select_topkv(self, q_chunk, k, v, topk):
q1 = q_chunk[:, :, ::30, :]
sim = q1 @ k.transpose(-1, -2)
# sim = sim.to(torch.float32)
sim = sim.softmax(-1)
sim = torch.mean(sim, 1)
activated_token = torch.where(sim > 1e-6)[2]
index = torch.unique(activated_token, return_counts=True)[0].unsqueeze(0).unsqueeze(0).unsqueeze(-1)
index = index.expand(-1, v.shape[1], -1, v.shape[-1])
v0 = torch.gather(v, dim=-2, index=index)
k0 = torch.gather(k, dim=-2, index=index)
return k0, v0
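A minimal, self-contained sketch of the idea behind the FlashVDM processors above: score each key by its mean similarity to a strided subset of the queries, keep only the top-k key/value pairs, and run standard scaled-dot-product attention on the reduced set. All tensor sizes below are made-up examples.

```python
import torch
import torch.nn.functional as F

b, h, n_q, n_kv, d, topk = 1, 8, 4096, 3072, 64, 1024
q = torch.randn(b, h, n_q, d)
k = torch.randn(b, h, n_kv, d)
v = torch.randn(b, h, n_kv, d)

# mean similarity of every key to a strided sample of the queries
sim = (q[:, :, ::100, :] @ k.transpose(-1, -2)).mean(-2)           # [b, h, n_kv]
idx = sim.topk(topk, dim=-1).indices.unsqueeze(-1).expand(-1, -1, -1, d)
k0 = torch.gather(k, -2, idx)                                       # [b, h, topk, d]
v0 = torch.gather(v, -2, idx)
out = F.scaled_dot_product_attention(q, k0, v0)                     # [b, h, n_q, d]
```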

View File

@@ -0,0 +1,189 @@
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.
# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
import os
import torch
import torch.nn as nn
import yaml
from .attention_blocks import FourierEmbedder, Transformer, CrossAttentionDecoder
from .surface_extractors import MCSurfaceExtractor, SurfaceExtractors
from .volume_decoders import VanillaVolumeDecoder, FlashVDMVolumeDecoding, HierarchicalVolumeDecoding
from ...utils import logger, synchronize_timer, smart_load_model
class VectsetVAE(nn.Module):
@classmethod
@synchronize_timer('VectsetVAE Model Loading')
def from_single_file(
cls,
ckpt_path,
config_path,
device='cuda',
dtype=torch.float16,
use_safetensors=None,
**kwargs,
):
# load config
with open(config_path, 'r') as f:
config = yaml.safe_load(f)
# load ckpt
if use_safetensors:
ckpt_path = ckpt_path.replace('.ckpt', '.safetensors')
if not os.path.exists(ckpt_path):
raise FileNotFoundError(f"Model file {ckpt_path} not found")
logger.info(f"Loading model from {ckpt_path}")
if use_safetensors:
import safetensors.torch
ckpt = safetensors.torch.load_file(ckpt_path, device='cpu')
else:
ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True)
model_kwargs = config['params']
model_kwargs.update(kwargs)
model = cls(**model_kwargs)
model.load_state_dict(ckpt)
model.to(device=device, dtype=dtype)
return model
@classmethod
def from_pretrained(
cls,
model_path,
device='cuda',
dtype=torch.float16,
use_safetensors=True,
variant='fp16',
subfolder='hunyuan3d-vae-v2-0',
**kwargs,
):
config_path, ckpt_path = smart_load_model(
model_path,
subfolder=subfolder,
use_safetensors=use_safetensors,
variant=variant
)
return cls.from_single_file(
ckpt_path,
config_path,
device=device,
dtype=dtype,
use_safetensors=use_safetensors,
**kwargs
)
def __init__(
self,
volume_decoder=None,
surface_extractor=None
):
super().__init__()
if volume_decoder is None:
volume_decoder = VanillaVolumeDecoder()
if surface_extractor is None:
surface_extractor = MCSurfaceExtractor()
self.volume_decoder = volume_decoder
self.surface_extractor = surface_extractor
def latents2mesh(self, latents: torch.FloatTensor, **kwargs):
with synchronize_timer('Volume decoding'):
grid_logits = self.volume_decoder(latents, self.geo_decoder, **kwargs)
with synchronize_timer('Surface extraction'):
outputs = self.surface_extractor(grid_logits, **kwargs)
return outputs
def enable_flashvdm_decoder(
self,
enabled: bool = True,
adaptive_kv_selection=True,
topk_mode='mean',
mc_algo='dmc',
):
if enabled:
if adaptive_kv_selection:
self.volume_decoder = FlashVDMVolumeDecoding(topk_mode)
else:
self.volume_decoder = HierarchicalVolumeDecoding()
if mc_algo not in SurfaceExtractors.keys():
raise ValueError(f'Unsupported mc_algo {mc_algo}, available: {list(SurfaceExtractors.keys())}')
self.surface_extractor = SurfaceExtractors[mc_algo]()
else:
self.volume_decoder = VanillaVolumeDecoder()
self.surface_extractor = MCSurfaceExtractor()
class ShapeVAE(VectsetVAE):
def __init__(
self,
*,
num_latents: int,
embed_dim: int,
width: int,
heads: int,
num_decoder_layers: int,
geo_decoder_downsample_ratio: int = 1,
geo_decoder_mlp_expand_ratio: int = 4,
geo_decoder_ln_post: bool = True,
num_freqs: int = 8,
include_pi: bool = True,
qkv_bias: bool = True,
qk_norm: bool = False,
label_type: str = "binary",
drop_path_rate: float = 0.0,
scale_factor: float = 1.0,
):
super().__init__()
self.geo_decoder_ln_post = geo_decoder_ln_post
self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
self.post_kl = nn.Linear(embed_dim, width)
self.transformer = Transformer(
n_ctx=num_latents,
width=width,
layers=num_decoder_layers,
heads=heads,
qkv_bias=qkv_bias,
qk_norm=qk_norm,
drop_path_rate=drop_path_rate
)
self.geo_decoder = CrossAttentionDecoder(
fourier_embedder=self.fourier_embedder,
out_channels=1,
num_latents=num_latents,
mlp_expand_ratio=geo_decoder_mlp_expand_ratio,
downsample_ratio=geo_decoder_downsample_ratio,
enable_ln_post=self.geo_decoder_ln_post,
width=width // geo_decoder_downsample_ratio,
heads=heads // geo_decoder_downsample_ratio,
qkv_bias=qkv_bias,
qk_norm=qk_norm,
label_type=label_type,
)
self.scale_factor = scale_factor
self.latent_shape = (num_latents, embed_dim)
def forward(self, latents):
latents = self.post_kl(latents)
latents = self.transformer(latents)
return latents
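A hypothetical end-to-end decode using the classes above; the repository id, decode parameters and the random latents are assumptions for illustration (real latents come from the shape-diffusion model):

```python
import torch

vae = ShapeVAE.from_pretrained('tencent/Hunyuan3D-2', subfolder='hunyuan3d-vae-v2-0')
vae.enable_flashvdm_decoder(enabled=True, adaptive_kv_selection=True,
                            topk_mode='mean', mc_algo='mc')
latents = torch.randn(1, *vae.latent_shape, dtype=torch.float16, device='cuda')
latents = vae(1.0 / vae.scale_factor * latents)        # post_kl + transformer
meshes = vae.latents2mesh(latents, bounds=1.01, mc_level=0.0,
                          octree_resolution=384, num_chunks=8000)
mesh = meshes[0]                                        # Latent2MeshOutput
```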

View File

@@ -0,0 +1,100 @@
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.
# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
from typing import Union, Tuple, List
import numpy as np
import torch
from skimage import measure
class Latent2MeshOutput:
def __init__(self, mesh_v=None, mesh_f=None):
self.mesh_v = mesh_v
self.mesh_f = mesh_f
def center_vertices(vertices):
"""Translate the vertices so that bounding box is centered at zero."""
vert_min = vertices.min(dim=0)[0]
vert_max = vertices.max(dim=0)[0]
vert_center = 0.5 * (vert_min + vert_max)
return vertices - vert_center
class SurfaceExtractor:
def _compute_box_stat(self, bounds: Union[Tuple[float], List[float], float], octree_resolution: int):
if isinstance(bounds, float):
bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
bbox_size = bbox_max - bbox_min
grid_size = [int(octree_resolution) + 1, int(octree_resolution) + 1, int(octree_resolution) + 1]
return grid_size, bbox_min, bbox_size
def run(self, *args, **kwargs):
        raise NotImplementedError
def __call__(self, grid_logits, **kwargs):
outputs = []
for i in range(grid_logits.shape[0]):
try:
vertices, faces = self.run(grid_logits[i], **kwargs)
vertices = vertices.astype(np.float32)
faces = np.ascontiguousarray(faces)
outputs.append(Latent2MeshOutput(mesh_v=vertices, mesh_f=faces))
except Exception:
import traceback
traceback.print_exc()
outputs.append(None)
return outputs
class MCSurfaceExtractor(SurfaceExtractor):
def run(self, grid_logit, *, mc_level, bounds, octree_resolution, **kwargs):
vertices, faces, normals, _ = measure.marching_cubes(
grid_logit.cpu().numpy(),
mc_level,
method="lewiner"
)
grid_size, bbox_min, bbox_size = self._compute_box_stat(bounds, octree_resolution)
vertices = vertices / grid_size * bbox_size + bbox_min
return vertices, faces
class DMCSurfaceExtractor(SurfaceExtractor):
def run(self, grid_logit, *, octree_resolution, **kwargs):
device = grid_logit.device
if not hasattr(self, 'dmc'):
try:
from diso import DiffDMC
            except ImportError:
raise ImportError("Please install diso via `pip install diso`, or set mc_algo to 'mc'")
self.dmc = DiffDMC(dtype=torch.float32).to(device)
sdf = -grid_logit / octree_resolution
sdf = sdf.to(torch.float32).contiguous()
verts, faces = self.dmc(sdf, deform=None, return_quads=False, normalize=True)
verts = center_vertices(verts)
vertices = verts.detach().cpu().numpy()
faces = faces.detach().cpu().numpy()[:, ::-1]
return vertices, faces
SurfaceExtractors = {
'mc': MCSurfaceExtractor,
'dmc': DMCSurfaceExtractor,
}
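A brief usage sketch for the registry above; the random grid is only a stand-in to show the expected input, a `[batch, R + 1, R + 1, R + 1]` tensor of occupancy logits:

```python
import torch

grid_logits = torch.randn(1, 65, 65, 65)               # stand-in for a decoded volume (R = 64)
extractor = SurfaceExtractors['mc']()                   # or 'dmc' (requires the `diso` package)
outputs = extractor(grid_logits, mc_level=0.0, bounds=1.01, octree_resolution=64)
mesh = outputs[0]                                       # Latent2MeshOutput with .mesh_v / .mesh_f
```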

View File

@@ -0,0 +1,435 @@
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.
# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
from typing import Union, Tuple, List, Callable
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import repeat
from tqdm import tqdm
from .attention_blocks import CrossAttentionDecoder
from .attention_processors import FlashVDMCrossAttentionProcessor, FlashVDMTopMCrossAttentionProcessor
from ...utils import logger
def extract_near_surface_volume_fn(input_tensor: torch.Tensor, alpha: float):
device = input_tensor.device
D = input_tensor.shape[0]
signed_val = 0.0
    # Add the offset and handle invalid values
    val = input_tensor + alpha
    valid_mask = val > -9000  # assume -9000 marks an invalid value
    # Neighbor-fetching helper (keeps dimensions consistent)
def get_neighbor(t, shift, axis):
"""根据指定轴进行位移并保持维度一致"""
if shift == 0:
return t.clone()
        # Determine the padding axis (the [D, D, D] input corresponds to the z, y, x axes)
        pad_dims = [0, 0, 0, 0, 0, 0]  # format: [x_before, x_after, y_before, y_after, z_before, z_after]
        # Set the padding according to the axis
        if axis == 0:  # x axis (last dimension)
pad_idx = 0 if shift > 0 else 1
pad_dims[pad_idx] = abs(shift)
        elif axis == 1:  # y axis (middle dimension)
pad_idx = 2 if shift > 0 else 3
pad_dims[pad_idx] = abs(shift)
        elif axis == 2:  # z axis (first dimension)
pad_idx = 4 if shift > 0 else 5
pad_dims[pad_idx] = abs(shift)
        # Apply the padding (add batch and channel dims so F.pad can be used)
        padded = F.pad(t.unsqueeze(0).unsqueeze(0), pad_dims[::-1], mode='replicate')  # order reversed to match F.pad
        # Build dynamic slice indices
        slice_dims = [slice(None)] * 3  # initialize to full slices
        if axis == 0:  # x axis (dim=2)
if shift > 0:
slice_dims[0] = slice(shift, None)
else:
slice_dims[0] = slice(None, shift)
        elif axis == 1:  # y axis (dim=1)
if shift > 0:
slice_dims[1] = slice(shift, None)
else:
slice_dims[1] = slice(None, shift)
        elif axis == 2:  # z axis (dim=0)
if shift > 0:
slice_dims[2] = slice(shift, None)
else:
slice_dims[2] = slice(None, shift)
        # Apply the slices and restore the dimensions
padded = padded.squeeze(0).squeeze(0)
sliced = padded[slice_dims]
return sliced
    # Fetch the neighbors in each direction (ensuring consistent dimensions)
    left = get_neighbor(val, 1, axis=0)  # x direction
right = get_neighbor(val, -1, axis=0)
    back = get_neighbor(val, 1, axis=1)  # y direction
front = get_neighbor(val, -1, axis=1)
    down = get_neighbor(val, 1, axis=2)  # z direction
up = get_neighbor(val, -1, axis=2)
    # Handle invalid boundary values (use where to keep dimensions consistent)
def safe_where(neighbor):
return torch.where(neighbor > -9000, neighbor, val)
left = safe_where(left)
right = safe_where(right)
back = safe_where(back)
front = safe_where(front)
down = safe_where(down)
up = safe_where(up)
    # Compute sign consistency (convert to float32 to ensure precision)
sign = torch.sign(val.to(torch.float32))
neighbors_sign = torch.stack([
torch.sign(left.to(torch.float32)),
torch.sign(right.to(torch.float32)),
torch.sign(back.to(torch.float32)),
torch.sign(front.to(torch.float32)),
torch.sign(down.to(torch.float32)),
torch.sign(up.to(torch.float32))
], dim=0)
    # Check whether all neighbor signs agree
same_sign = torch.all(neighbors_sign == sign, dim=0)
    # Build the final mask
mask = (~same_sign).to(torch.int32)
return mask * valid_mask.to(torch.int32)
def generate_dense_grid_points(
bbox_min: np.ndarray,
bbox_max: np.ndarray,
octree_resolution: int,
indexing: str = "ij",
):
length = bbox_max - bbox_min
num_cells = octree_resolution
x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
[xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
xyz = np.stack((xs, ys, zs), axis=-1)
grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]
return xyz, grid_size, length
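# Illustrative note (not part of the module): for octree_resolution R the helper above
# returns an (R + 1, R + 1, R + 1, 3) array of xyz samples spanning the bounding box, e.g.
#   xyz, grid_size, length = generate_dense_grid_points(
#       bbox_min=np.array([-1.01] * 3), bbox_max=np.array([1.01] * 3), octree_resolution=64)
#   gives xyz.shape == (65, 65, 65, 3) and grid_size == [65, 65, 65].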
class VanillaVolumeDecoder:
@torch.no_grad()
def __call__(
self,
latents: torch.FloatTensor,
geo_decoder: Callable,
bounds: Union[Tuple[float], List[float], float] = 1.01,
num_chunks: int = 10000,
octree_resolution: int = None,
enable_pbar: bool = True,
**kwargs,
):
device = latents.device
dtype = latents.dtype
batch_size = latents.shape[0]
# 1. generate query points
if isinstance(bounds, float):
bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
xyz_samples, grid_size, length = generate_dense_grid_points(
bbox_min=bbox_min,
bbox_max=bbox_max,
octree_resolution=octree_resolution,
indexing="ij"
)
xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3)
# 2. latents to 3d volume
batch_logits = []
for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), desc=f"Volume Decoding",
disable=not enable_pbar):
chunk_queries = xyz_samples[start: start + num_chunks, :]
chunk_queries = repeat(chunk_queries, "p c -> b p c", b=batch_size)
logits = geo_decoder(queries=chunk_queries, latents=latents)
batch_logits.append(logits)
grid_logits = torch.cat(batch_logits, dim=1)
grid_logits = grid_logits.view((batch_size, *grid_size)).float()
return grid_logits
class HierarchicalVolumeDecoding:
@torch.no_grad()
def __call__(
self,
latents: torch.FloatTensor,
geo_decoder: Callable,
bounds: Union[Tuple[float], List[float], float] = 1.01,
num_chunks: int = 10000,
mc_level: float = 0.0,
octree_resolution: int = None,
min_resolution: int = 63,
enable_pbar: bool = True,
**kwargs,
):
device = latents.device
dtype = latents.dtype
resolutions = []
if octree_resolution < min_resolution:
resolutions.append(octree_resolution)
while octree_resolution >= min_resolution:
resolutions.append(octree_resolution)
octree_resolution = octree_resolution // 2
resolutions.reverse()
# 1. generate query points
if isinstance(bounds, float):
bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
bbox_min = np.array(bounds[0:3])
bbox_max = np.array(bounds[3:6])
bbox_size = bbox_max - bbox_min
xyz_samples, grid_size, length = generate_dense_grid_points(
bbox_min=bbox_min,
bbox_max=bbox_max,
octree_resolution=resolutions[0],
indexing="ij"
)
dilate = nn.Conv3d(1, 1, 3, padding=1, bias=False, device=device, dtype=dtype)
dilate.weight = torch.nn.Parameter(torch.ones(dilate.weight.shape, dtype=dtype, device=device))
grid_size = np.array(grid_size)
xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3)
# 2. latents to 3d volume
batch_logits = []
batch_size = latents.shape[0]
for start in tqdm(range(0, xyz_samples.shape[0], num_chunks),
desc=f"Hierarchical Volume Decoding [r{resolutions[0] + 1}]"):
queries = xyz_samples[start: start + num_chunks, :]
batch_queries = repeat(queries, "p c -> b p c", b=batch_size)
logits = geo_decoder(queries=batch_queries, latents=latents)
batch_logits.append(logits)
grid_logits = torch.cat(batch_logits, dim=1).view((batch_size, grid_size[0], grid_size[1], grid_size[2]))
for octree_depth_now in resolutions[1:]:
grid_size = np.array([octree_depth_now + 1] * 3)
resolution = bbox_size / octree_depth_now
next_index = torch.zeros(tuple(grid_size), dtype=dtype, device=device)
next_logits = torch.full(next_index.shape, -10000., dtype=dtype, device=device)
curr_points = extract_near_surface_volume_fn(grid_logits.squeeze(0), mc_level)
curr_points += grid_logits.squeeze(0).abs() < 0.95
if octree_depth_now == resolutions[-1]:
expand_num = 0
else:
expand_num = 1
for i in range(expand_num):
curr_points = dilate(curr_points.unsqueeze(0).to(dtype)).squeeze(0)
(cidx_x, cidx_y, cidx_z) = torch.where(curr_points > 0)
next_index[cidx_x * 2, cidx_y * 2, cidx_z * 2] = 1
for i in range(2 - expand_num):
next_index = dilate(next_index.unsqueeze(0)).squeeze(0)
nidx = torch.where(next_index > 0)
next_points = torch.stack(nidx, dim=1)
next_points = (next_points * torch.tensor(resolution, dtype=next_points.dtype, device=device) +
torch.tensor(bbox_min, dtype=next_points.dtype, device=device))
batch_logits = []
for start in tqdm(range(0, next_points.shape[0], num_chunks),
desc=f"Hierarchical Volume Decoding [r{octree_depth_now + 1}]"):
queries = next_points[start: start + num_chunks, :]
batch_queries = repeat(queries, "p c -> b p c", b=batch_size)
logits = geo_decoder(queries=batch_queries.to(latents.dtype), latents=latents)
batch_logits.append(logits)
grid_logits = torch.cat(batch_logits, dim=1)
next_logits[nidx] = grid_logits[0, ..., 0]
grid_logits = next_logits.unsqueeze(0)
grid_logits[grid_logits == -10000.] = float('nan')
return grid_logits
class FlashVDMVolumeDecoding:
def __init__(self, topk_mode='mean'):
if topk_mode not in ['mean', 'merge']:
raise ValueError(f'Unsupported topk_mode {topk_mode}, available: {["mean", "merge"]}')
if topk_mode == 'mean':
self.processor = FlashVDMCrossAttentionProcessor()
else:
self.processor = FlashVDMTopMCrossAttentionProcessor()
@torch.no_grad()
def __call__(
self,
latents: torch.FloatTensor,
geo_decoder: CrossAttentionDecoder,
bounds: Union[Tuple[float], List[float], float] = 1.01,
num_chunks: int = 10000,
mc_level: float = 0.0,
octree_resolution: int = None,
min_resolution: int = 63,
mini_grid_num: int = 4,
enable_pbar: bool = True,
**kwargs,
):
processor = self.processor
geo_decoder.set_cross_attention_processor(processor)
device = latents.device
dtype = latents.dtype
resolutions = []
if octree_resolution < min_resolution:
resolutions.append(octree_resolution)
while octree_resolution >= min_resolution:
resolutions.append(octree_resolution)
octree_resolution = octree_resolution // 2
resolutions.reverse()
resolutions[0] = round(resolutions[0] / mini_grid_num) * mini_grid_num - 1
for i, resolution in enumerate(resolutions[1:]):
resolutions[i + 1] = resolutions[0] * 2 ** (i + 1)
logger.info(f"FlashVDMVolumeDecoding Resolution: {resolutions}")
# 1. generate query points
if isinstance(bounds, float):
bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
bbox_min = np.array(bounds[0:3])
bbox_max = np.array(bounds[3:6])
bbox_size = bbox_max - bbox_min
xyz_samples, grid_size, length = generate_dense_grid_points(
bbox_min=bbox_min,
bbox_max=bbox_max,
octree_resolution=resolutions[0],
indexing="ij"
)
dilate = nn.Conv3d(1, 1, 3, padding=1, bias=False, device=device, dtype=dtype)
dilate.weight = torch.nn.Parameter(torch.ones(dilate.weight.shape, dtype=dtype, device=device))
grid_size = np.array(grid_size)
# 2. latents to 3d volume
xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype)
batch_size = latents.shape[0]
mini_grid_size = xyz_samples.shape[0] // mini_grid_num
xyz_samples = xyz_samples.view(
mini_grid_num, mini_grid_size,
mini_grid_num, mini_grid_size,
mini_grid_num, mini_grid_size, 3
).permute(
0, 2, 4, 1, 3, 5, 6
).reshape(
-1, mini_grid_size * mini_grid_size * mini_grid_size, 3
)
batch_logits = []
num_batchs = max(num_chunks // xyz_samples.shape[1], 1)
for start in tqdm(range(0, xyz_samples.shape[0], num_batchs),
desc=f"FlashVDM Volume Decoding", disable=not enable_pbar):
queries = xyz_samples[start: start + num_batchs, :]
batch = queries.shape[0]
batch_latents = repeat(latents.squeeze(0), "p c -> b p c", b=batch)
processor.topk = True
logits = geo_decoder(queries=queries, latents=batch_latents)
batch_logits.append(logits)
grid_logits = torch.cat(batch_logits, dim=0).reshape(
mini_grid_num, mini_grid_num, mini_grid_num,
mini_grid_size, mini_grid_size,
mini_grid_size
).permute(0, 3, 1, 4, 2, 5).contiguous().view(
(batch_size, grid_size[0], grid_size[1], grid_size[2])
)
for octree_depth_now in resolutions[1:]:
grid_size = np.array([octree_depth_now + 1] * 3)
resolution = bbox_size / octree_depth_now
next_index = torch.zeros(tuple(grid_size), dtype=dtype, device=device)
next_logits = torch.full(next_index.shape, -10000., dtype=dtype, device=device)
curr_points = extract_near_surface_volume_fn(grid_logits.squeeze(0), mc_level)
curr_points += grid_logits.squeeze(0).abs() < 0.95
if octree_depth_now == resolutions[-1]:
expand_num = 0
else:
expand_num = 1
for i in range(expand_num):
curr_points = dilate(curr_points.unsqueeze(0).to(dtype)).squeeze(0)
(cidx_x, cidx_y, cidx_z) = torch.where(curr_points > 0)
next_index[cidx_x * 2, cidx_y * 2, cidx_z * 2] = 1
for i in range(2 - expand_num):
next_index = dilate(next_index.unsqueeze(0)).squeeze(0)
nidx = torch.where(next_index > 0)
next_points = torch.stack(nidx, dim=1)
next_points = (next_points * torch.tensor(resolution, dtype=torch.float32, device=device) +
torch.tensor(bbox_min, dtype=torch.float32, device=device))
query_grid_num = 6
min_val = next_points.min(axis=0).values
max_val = next_points.max(axis=0).values
vol_queries_index = (next_points - min_val) / (max_val - min_val) * (query_grid_num - 0.001)
index = torch.floor(vol_queries_index).long()
index = index[..., 0] * (query_grid_num ** 2) + index[..., 1] * query_grid_num + index[..., 2]
index = index.sort()
next_points = next_points[index.indices].unsqueeze(0).contiguous()
unique_values = torch.unique(index.values, return_counts=True)
grid_logits = torch.zeros((next_points.shape[1]), dtype=latents.dtype, device=latents.device)
input_grid = [[], []]
logits_grid_list = []
start_num = 0
sum_num = 0
for grid_index, count in zip(unique_values[0].cpu().tolist(), unique_values[1].cpu().tolist()):
if sum_num + count < num_chunks or sum_num == 0:
sum_num += count
input_grid[0].append(grid_index)
input_grid[1].append(count)
else:
processor.topk = input_grid
logits_grid = geo_decoder(queries=next_points[:, start_num:start_num + sum_num], latents=latents)
start_num = start_num + sum_num
logits_grid_list.append(logits_grid)
input_grid = [[grid_index], [count]]
sum_num = count
if sum_num > 0:
processor.topk = input_grid
logits_grid = geo_decoder(queries=next_points[:, start_num:start_num + sum_num], latents=latents)
logits_grid_list.append(logits_grid)
logits_grid = torch.cat(logits_grid_list, dim=1)
grid_logits[index.indices] = logits_grid.squeeze(0).squeeze(-1)
next_logits[nidx] = grid_logits
grid_logits = next_logits.unsqueeze(0)
grid_logits[grid_logits == -10000.] = float('nan')
return grid_logits
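To make the coarse-to-fine behaviour of the two hierarchical decoders above concrete, here is the resolution schedule both of them build; this is a standalone restatement of the loop at the top of their `__call__` methods, with an assumed target resolution (`FlashVDMVolumeDecoding` additionally snaps the coarsest level to a multiple of `mini_grid_num`):

```python
def resolution_schedule(octree_resolution: int, min_resolution: int = 63):
    # Halve the target resolution until it drops below min_resolution,
    # then decode from coarsest to finest.
    resolutions = []
    if octree_resolution < min_resolution:
        resolutions.append(octree_resolution)
    while octree_resolution >= min_resolution:
        resolutions.append(octree_resolution)
        octree_resolution //= 2
    resolutions.reverse()
    return resolutions

print(resolution_schedule(384))   # [96, 192, 384]
```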

View File

View File

@@ -204,7 +204,7 @@ class Hunyuan3DDiTPipeline:
config['model']['params']['guidance_embed'] = True
config['conditioner']['params']['main_image_encoder']['kwargs']['has_guidance_embed'] = True
config['model']['params']['attention_mode'] = attention_mode
config['vae']['params']['attention_mode'] = attention_mode
#config['vae']['params']['attention_mode'] = attention_mode
if cublas_ops:
config['vae']['params']['cublas_ops'] = True

View File

@@ -1,13 +1,3 @@
# Open Source Model Licensed under the Apache License Version 2.0
# and Other Licenses of the Third-Party Components therein:
# The below Model in this distribution may have been modified by THL A29 Limited
# ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
# The below software and/or models in this distribution may have been
# modified by THL A29 Limited ("Tencent Modifications").
# All Tencent Modifications are Copyright (C) THL A29 Limited.
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
@@ -25,11 +15,12 @@
import tempfile
import os
from typing import Union
import torch
import numpy as np
import pymeshlab
import trimesh
from .models.vae import Latent2MeshOutput
from .models.autoencoders import Latent2MeshOutput
import folder_paths
@@ -44,6 +35,9 @@ def load_mesh(path):
def reduce_face(mesh: pymeshlab.MeshSet, max_facenum: int = 200000):
if max_facenum > mesh.current_mesh().face_number():
return mesh
mesh.apply_filter(
"meshing_decimation_quadric_edge_collapse",
targetfacenum=max_facenum,
@@ -287,3 +281,47 @@ class DegenerateFaceRemover:
mesh = export_mesh(mesh, ms)
return mesh
def mesh_normalize(mesh):
"""
Normalize mesh vertices to sphere
"""
scale_factor = 1.2
vtx_pos = np.asarray(mesh.vertices)
    max_bb = vtx_pos.max(axis=0)
    min_bb = vtx_pos.min(axis=0)
center = (max_bb + min_bb) / 2
scale = torch.norm(torch.tensor(vtx_pos - center, dtype=torch.float32), dim=1).max() * 2.0
vtx_pos = (vtx_pos - center) * (scale_factor / float(scale))
mesh.vertices = vtx_pos
return mesh
class MeshSimplifier:
def __init__(self, executable: str = None):
if executable is None:
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
executable = os.path.join(CURRENT_DIR, "mesh_simplifier.bin")
self.executable = executable
def __call__(
self,
mesh: Union[trimesh.Trimesh],
) -> Union[trimesh.Trimesh]:
with tempfile.NamedTemporaryFile(suffix='.obj', delete=False) as temp_input:
with tempfile.NamedTemporaryFile(suffix='.obj', delete=False) as temp_output:
mesh.export(temp_input.name)
os.system(f'{self.executable} {temp_input.name} {temp_output.name}')
ms = trimesh.load(temp_output.name, process=False)
if isinstance(ms, trimesh.Scene):
combined_mesh = trimesh.Trimesh()
for geom in ms.geometry.values():
combined_mesh = trimesh.util.concatenate([combined_mesh, geom])
ms = combined_mesh
ms = mesh_normalize(ms)
return ms
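A hypothetical usage sketch for the two helpers above; the input path is an assumption, the input is assumed to be a single-geometry mesh, and the bundled `mesh_simplifier.bin` executable must exist next to this file:

```python
import trimesh

mesh = trimesh.load('input.obj', process=False)
simplifier = MeshSimplifier()        # or MeshSimplifier('/path/to/mesh_simplifier.bin')
simplified = simplifier(mesh)        # simplified, re-centered and re-scaled trimesh.Trimesh
simplified.export('simplified.obj')
```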

hy3dgen/shapegen/utils.py (new file, 123 lines)
View File

@@ -0,0 +1,123 @@
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.
# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
import logging
import os
from functools import wraps
import torch
def get_logger(name):
logger = logging.getLogger(name)
logger.setLevel(logging.INFO)
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
return logger
logger = get_logger('hy3dgen.shapegen')
class synchronize_timer:
""" Synchronized timer to count the inference time of `nn.Module.forward`.
Supports both context manager and decorator usage.
Example as context manager:
```python
with synchronize_timer('name') as t:
run()
```
Example as decorator:
```python
@synchronize_timer('Export to trimesh')
def export_to_trimesh(mesh_output):
pass
```
"""
def __init__(self, name=None):
self.name = name
def __enter__(self):
"""Context manager entry: start timing."""
if os.environ.get('HY3DGEN_DEBUG', '0') == '1':
self.start = torch.cuda.Event(enable_timing=True)
self.end = torch.cuda.Event(enable_timing=True)
self.start.record()
return lambda: self.time
def __exit__(self, exc_type, exc_value, exc_tb):
"""Context manager exit: stop timing and log results."""
if os.environ.get('HY3DGEN_DEBUG', '0') == '1':
self.end.record()
torch.cuda.synchronize()
self.time = self.start.elapsed_time(self.end)
if self.name is not None:
logger.info(f'{self.name} takes {self.time} ms')
def __call__(self, func):
"""Decorator: wrap the function to time its execution."""
@wraps(func)
def wrapper(*args, **kwargs):
with self:
result = func(*args, **kwargs)
return result
return wrapper
def smart_load_model(
model_path,
subfolder,
use_safetensors,
variant,
):
original_model_path = model_path
# try local path
base_dir = os.environ.get('HY3DGEN_MODELS', '~/.cache/hy3dgen')
model_path = os.path.expanduser(os.path.join(base_dir, model_path, subfolder))
    logger.info(f'Trying to load model from local path: {model_path}')
    if not os.path.exists(model_path):
        logger.info('Local model path does not exist, trying to download from Hugging Face')
try:
import huggingface_hub
# download from huggingface
path = huggingface_hub.snapshot_download(repo_id=original_model_path)
model_path = os.path.join(path, subfolder)
except ImportError:
logger.warning(
"You need to install HuggingFace Hub to load models from the hub."
)
raise RuntimeError(f"Model path {model_path} not found")
except Exception as e:
raise e
if not os.path.exists(model_path):
raise FileNotFoundError(f"Model path {original_model_path} not found")
extension = 'ckpt' if not use_safetensors else 'safetensors'
variant = '' if variant is None else f'.{variant}'
ckpt_name = f'model{variant}.{extension}'
config_path = os.path.join(model_path, 'config.yaml')
ckpt_path = os.path.join(model_path, ckpt_name)
return config_path, ckpt_path
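For reference, a sketch of how the loader above resolves files; the repository id and variant are assumptions. With `HY3DGEN_MODELS` unset, paths resolve under `~/.cache/hy3dgen`, falling back to a Hugging Face download when the local directory is missing:

```python
config_path, ckpt_path = smart_load_model(
    'tencent/Hunyuan3D-2',
    subfolder='hunyuan3d-vae-v2-0',
    use_safetensors=True,
    variant='fp16',
)
# -> .../hunyuan3d-vae-v2-0/config.yaml and .../hunyuan3d-vae-v2-0/model.fp16.safetensors
```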

View File

@@ -1194,8 +1194,12 @@ class Hy3DVAEDecode:
"octree_resolution": ("INT", {"default": 384, "min": 8, "max": 4096, "step": 8}),
"num_chunks": ("INT", {"default": 8000, "min": 1, "max": 10000000, "step": 1, "tooltip": "Number of chunks to process at once, higher values use more memory, but make the process faster"}),
"mc_level": ("FLOAT", {"default": 0, "min": -1.0, "max": 1.0, "step": 0.0001}),
"mc_algo": (["mc", "dmc", "odc", "none"], {"default": "mc"}),
#"mc_algo": (["mc", "dmc", "odc", "none"], {"default": "mc"}),
"mc_algo": (["mc", "dmc"], {"default": "mc"}),
},
"optional": {
"enable_flash_vdm": ("BOOLEAN", {"default": True}),
}
}
RETURN_TYPES = ("TRIMESH",)
@@ -1203,11 +1207,14 @@
FUNCTION = "process"
CATEGORY = "Hunyuan3DWrapper"
def process(self, vae, latents, box_v, octree_resolution, mc_level, num_chunks, mc_algo):
def process(self, vae, latents, box_v, octree_resolution, mc_level, num_chunks, mc_algo, enable_flash_vdm=True):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
vae.to(device)
vae.enable_flashvdm_decoder(enabled=enable_flash_vdm)
latents = 1. / vae.scale_factor * latents
latents = vae(latents)