From 280ab598342f6b25648fbb13efe97a971474c8b6 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Wed, 19 Mar 2025 13:49:07 +0200 Subject: [PATCH] Add FlashVDM vae ...and it earned that name, whoa! https://github.com/Tencent/FlashVDM/ --- hy3dgen/shapegen/models/__init__.py | 2 +- .../shapegen/models/autoencoders/__init__.py | 20 + .../models/autoencoders/attention_blocks.py | 493 ++++++++++++++++++ .../autoencoders/attention_processors.py | 96 ++++ hy3dgen/shapegen/models/autoencoders/model.py | 189 +++++++ .../models/autoencoders/surface_extractors.py | 100 ++++ .../models/autoencoders/volume_decoders.py | 435 ++++++++++++++++ .../shapegen/models/{vae.py => vae_old.py} | 0 hy3dgen/shapegen/pipelines.py | 2 +- hy3dgen/shapegen/postprocessors.py | 62 ++- hy3dgen/shapegen/utils.py | 123 +++++ nodes.py | 11 +- 12 files changed, 1517 insertions(+), 16 deletions(-) create mode 100644 hy3dgen/shapegen/models/autoencoders/__init__.py create mode 100644 hy3dgen/shapegen/models/autoencoders/attention_blocks.py create mode 100644 hy3dgen/shapegen/models/autoencoders/attention_processors.py create mode 100644 hy3dgen/shapegen/models/autoencoders/model.py create mode 100644 hy3dgen/shapegen/models/autoencoders/surface_extractors.py create mode 100644 hy3dgen/shapegen/models/autoencoders/volume_decoders.py rename hy3dgen/shapegen/models/{vae.py => vae_old.py} (100%) mode change 100755 => 100644 create mode 100644 hy3dgen/shapegen/utils.py diff --git a/hy3dgen/shapegen/models/__init__.py b/hy3dgen/shapegen/models/__init__.py index 684b3e3..46bf894 100755 --- a/hy3dgen/shapegen/models/__init__.py +++ b/hy3dgen/shapegen/models/__init__.py @@ -25,4 +25,4 @@ from .conditioner import DualImageEncoder, SingleImageEncoder, DinoImageEncoder, CLIPImageEncoder from .hunyuan3ddit import Hunyuan3DDiT -from .vae import ShapeVAE +from .autoencoders import ShapeVAE diff --git a/hy3dgen/shapegen/models/autoencoders/__init__.py b/hy3dgen/shapegen/models/autoencoders/__init__.py new file mode 100644 index 0000000..20bbf8d --- /dev/null +++ b/hy3dgen/shapegen/models/autoencoders/__init__.py @@ -0,0 +1,20 @@ +# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT +# except for the third-party components listed below. +# Hunyuan 3D does not impose any additional limitations beyond what is outlined +# in the repsective licenses of these third-party components. +# Users must comply with all terms and conditions of original licenses of these third-party +# components and must ensure that the usage of the third party components adheres to +# all relevant laws and regulations. + +# For avoidance of doubts, Hunyuan 3D means the large language models and +# their software and algorithms, including trained model weights, parameters (including +# optimizer states), machine-learning model code, inference-enabling code, training-enabling code, +# fine-tuning enabling code and other elements of the foregoing made publicly available +# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT. 
+ +from .attention_blocks import CrossAttentionDecoder +from .attention_processors import FlashVDMCrossAttentionProcessor, CrossAttentionProcessor, \ + FlashVDMTopMCrossAttentionProcessor +from .model import ShapeVAE, VectsetVAE +from .surface_extractors import SurfaceExtractors, MCSurfaceExtractor, DMCSurfaceExtractor, Latent2MeshOutput +from .volume_decoders import HierarchicalVolumeDecoding, FlashVDMVolumeDecoding, VanillaVolumeDecoder diff --git a/hy3dgen/shapegen/models/autoencoders/attention_blocks.py b/hy3dgen/shapegen/models/autoencoders/attention_blocks.py new file mode 100644 index 0000000..ab34eeb --- /dev/null +++ b/hy3dgen/shapegen/models/autoencoders/attention_blocks.py @@ -0,0 +1,493 @@ +# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT +# except for the third-party components listed below. +# Hunyuan 3D does not impose any additional limitations beyond what is outlined +# in the repsective licenses of these third-party components. +# Users must comply with all terms and conditions of original licenses of these third-party +# components and must ensure that the usage of the third party components adheres to +# all relevant laws and regulations. + +# For avoidance of doubts, Hunyuan 3D means the large language models and +# their software and algorithms, including trained model weights, parameters (including +# optimizer states), machine-learning model code, inference-enabling code, training-enabling code, +# fine-tuning enabling code and other elements of the foregoing made publicly available +# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT. + + +import os +from typing import Optional + +import torch +import torch.nn as nn +from einops import rearrange + +from .attention_processors import CrossAttentionProcessor +from ...utils import logger + +scaled_dot_product_attention = nn.functional.scaled_dot_product_attention + +if os.environ.get('USE_SAGEATTN', '0') == '1': + try: + from sageattention import sageattn + except ImportError: + raise ImportError('Please install the package "sageattention" to use this USE_SAGEATTN.') + scaled_dot_product_attention = sageattn + + +class FourierEmbedder(nn.Module): + """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts + each feature dimension of `x[..., i]` into: + [ + sin(x[..., i]), + sin(f_1*x[..., i]), + sin(f_2*x[..., i]), + ... + sin(f_N * x[..., i]), + cos(x[..., i]), + cos(f_1*x[..., i]), + cos(f_2*x[..., i]), + ... + cos(f_N * x[..., i]), + x[..., i] # only present if include_input is True. + ], here f_i is the frequency. + + Denote the space is [0 / num_freqs, 1 / num_freqs, 2 / num_freqs, 3 / num_freqs, ..., (num_freqs - 1) / num_freqs]. + If logspace is True, then the frequency f_i is [2^(0 / num_freqs), ..., 2^(i / num_freqs), ...]; + Otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)]. + + Args: + num_freqs (int): the number of frequencies, default is 6; + logspace (bool): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...], + otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)]; + input_dim (int): the input dimension, default is 3; + include_input (bool): include the input tensor or not, default is True. 
+ + Attributes: + frequencies (torch.Tensor): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...], + otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1); + + out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1), + otherwise, it is input_dim * num_freqs * 2. + + """ + + def __init__(self, + num_freqs: int = 6, + logspace: bool = True, + input_dim: int = 3, + include_input: bool = True, + include_pi: bool = True) -> None: + + """The initialization""" + + super().__init__() + + if logspace: + frequencies = 2.0 ** torch.arange( + num_freqs, + dtype=torch.float32 + ) + else: + frequencies = torch.linspace( + 1.0, + 2.0 ** (num_freqs - 1), + num_freqs, + dtype=torch.float32 + ) + + if include_pi: + frequencies *= torch.pi + + self.register_buffer("frequencies", frequencies, persistent=False) + self.include_input = include_input + self.num_freqs = num_freqs + + self.out_dim = self.get_dims(input_dim) + + def get_dims(self, input_dim): + temp = 1 if self.include_input or self.num_freqs == 0 else 0 + out_dim = input_dim * (self.num_freqs * 2 + temp) + + return out_dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ Forward process. + + Args: + x: tensor of shape [..., dim] + + Returns: + embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)] + where temp is 1 if include_input is True and 0 otherwise. + """ + + if self.num_freqs > 0: + embed = (x[..., None].contiguous() * self.frequencies).view(*x.shape[:-1], -1) + if self.include_input: + return torch.cat((x, embed.sin(), embed.cos()), dim=-1) + else: + return torch.cat((embed.sin(), embed.cos()), dim=-1) + else: + return x + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + + def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + + """ + if self.drop_prob == 0. or not self.training: + return x + keep_prob = 1 - self.drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and self.scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + def extra_repr(self): + return f'drop_prob={round(self.drop_prob, 3):0.3f}' + + +class MLP(nn.Module): + def __init__( + self, *, + width: int, + expand_ratio: int = 4, + output_width: int = None, + drop_path_rate: float = 0.0 + ): + super().__init__() + self.width = width + self.c_fc = nn.Linear(width, width * expand_ratio) + self.c_proj = nn.Linear(width * expand_ratio, output_width if output_width is not None else width) + self.gelu = nn.GELU() + self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. 
else nn.Identity() + + def forward(self, x): + return self.drop_path(self.c_proj(self.gelu(self.c_fc(x)))) + + +class QKVMultiheadCrossAttention(nn.Module): + def __init__( + self, + *, + heads: int, + n_data: Optional[int] = None, + width=None, + qk_norm=False, + norm_layer=nn.LayerNorm + ): + super().__init__() + self.heads = heads + self.n_data = n_data + self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() + self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() + + self.attn_processor = CrossAttentionProcessor() + + def forward(self, q, kv): + _, n_ctx, _ = q.shape + bs, n_data, width = kv.shape + attn_ch = width // self.heads // 2 + q = q.view(bs, n_ctx, self.heads, -1) + kv = kv.view(bs, n_data, self.heads, -1) + k, v = torch.split(kv, attn_ch, dim=-1) + + q = self.q_norm(q) + k = self.k_norm(k) + q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v)) + out = self.attn_processor(self, q, k, v) + out = out.transpose(1, 2).reshape(bs, n_ctx, -1) + return out + + +class MultiheadCrossAttention(nn.Module): + def __init__( + self, + *, + width: int, + heads: int, + qkv_bias: bool = True, + n_data: Optional[int] = None, + data_width: Optional[int] = None, + norm_layer=nn.LayerNorm, + qk_norm: bool = False, + kv_cache: bool = False, + ): + super().__init__() + self.n_data = n_data + self.width = width + self.heads = heads + self.data_width = width if data_width is None else data_width + self.c_q = nn.Linear(width, width, bias=qkv_bias) + self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias) + self.c_proj = nn.Linear(width, width) + self.attention = QKVMultiheadCrossAttention( + heads=heads, + n_data=n_data, + width=width, + norm_layer=norm_layer, + qk_norm=qk_norm + ) + self.kv_cache = kv_cache + self.data = None + + def forward(self, x, data): + x = self.c_q(x) + if self.kv_cache: + if self.data is None: + self.data = self.c_kv(data) + logger.info('Save kv cache,this should be called only once for one mesh') + data = self.data + else: + data = self.c_kv(data) + x = self.attention(x, data) + x = self.c_proj(x) + return x + + +class ResidualCrossAttentionBlock(nn.Module): + def __init__( + self, + *, + n_data: Optional[int] = None, + width: int, + heads: int, + mlp_expand_ratio: int = 4, + data_width: Optional[int] = None, + qkv_bias: bool = True, + norm_layer=nn.LayerNorm, + qk_norm: bool = False + ): + super().__init__() + + if data_width is None: + data_width = width + + self.attn = MultiheadCrossAttention( + n_data=n_data, + width=width, + heads=heads, + data_width=data_width, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + qk_norm=qk_norm + ) + self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6) + self.ln_2 = norm_layer(data_width, elementwise_affine=True, eps=1e-6) + self.ln_3 = norm_layer(width, elementwise_affine=True, eps=1e-6) + self.mlp = MLP(width=width, expand_ratio=mlp_expand_ratio) + + def forward(self, x: torch.Tensor, data: torch.Tensor): + x = x + self.attn(self.ln_1(x), self.ln_2(data)) + x = x + self.mlp(self.ln_3(x)) + return x + + +class QKVMultiheadAttention(nn.Module): + def __init__( + self, + *, + heads: int, + n_ctx: int, + width=None, + qk_norm=False, + norm_layer=nn.LayerNorm + ): + super().__init__() + self.heads = heads + self.n_ctx = n_ctx + self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() + self.k_norm = norm_layer(width // heads, 
elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() + + def forward(self, qkv): + bs, n_ctx, width = qkv.shape + attn_ch = width // self.heads // 3 + qkv = qkv.view(bs, n_ctx, self.heads, -1) + q, k, v = torch.split(qkv, attn_ch, dim=-1) + + q = self.q_norm(q) + k = self.k_norm(k) + + q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v)) + out = scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1) + return out + + +class MultiheadAttention(nn.Module): + def __init__( + self, + *, + n_ctx: int, + width: int, + heads: int, + qkv_bias: bool, + norm_layer=nn.LayerNorm, + qk_norm: bool = False, + drop_path_rate: float = 0.0 + ): + super().__init__() + self.n_ctx = n_ctx + self.width = width + self.heads = heads + self.c_qkv = nn.Linear(width, width * 3, bias=qkv_bias) + self.c_proj = nn.Linear(width, width) + self.attention = QKVMultiheadAttention( + heads=heads, + n_ctx=n_ctx, + width=width, + norm_layer=norm_layer, + qk_norm=qk_norm + ) + self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() + + def forward(self, x): + x = self.c_qkv(x) + x = self.attention(x) + x = self.drop_path(self.c_proj(x)) + return x + + +class ResidualAttentionBlock(nn.Module): + def __init__( + self, + *, + n_ctx: int, + width: int, + heads: int, + qkv_bias: bool = True, + norm_layer=nn.LayerNorm, + qk_norm: bool = False, + drop_path_rate: float = 0.0, + ): + super().__init__() + self.attn = MultiheadAttention( + n_ctx=n_ctx, + width=width, + heads=heads, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + qk_norm=qk_norm, + drop_path_rate=drop_path_rate + ) + self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6) + self.mlp = MLP(width=width, drop_path_rate=drop_path_rate) + self.ln_2 = norm_layer(width, elementwise_affine=True, eps=1e-6) + + def forward(self, x: torch.Tensor): + x = x + self.attn(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + def __init__( + self, + *, + n_ctx: int, + width: int, + layers: int, + heads: int, + qkv_bias: bool = True, + norm_layer=nn.LayerNorm, + qk_norm: bool = False, + drop_path_rate: float = 0.0 + ): + super().__init__() + self.n_ctx = n_ctx + self.width = width + self.layers = layers + self.resblocks = nn.ModuleList( + [ + ResidualAttentionBlock( + n_ctx=n_ctx, + width=width, + heads=heads, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + qk_norm=qk_norm, + drop_path_rate=drop_path_rate + ) + for _ in range(layers) + ] + ) + + def forward(self, x: torch.Tensor): + for block in self.resblocks: + x = block(x) + return x + + +class CrossAttentionDecoder(nn.Module): + + def __init__( + self, + *, + num_latents: int, + out_channels: int, + fourier_embedder: FourierEmbedder, + width: int, + heads: int, + mlp_expand_ratio: int = 4, + downsample_ratio: int = 1, + enable_ln_post: bool = True, + qkv_bias: bool = True, + qk_norm: bool = False, + label_type: str = "binary" + ): + super().__init__() + + self.enable_ln_post = enable_ln_post + self.fourier_embedder = fourier_embedder + self.downsample_ratio = downsample_ratio + self.query_proj = nn.Linear(self.fourier_embedder.out_dim, width) + if self.downsample_ratio != 1: + self.latents_proj = nn.Linear(width * downsample_ratio, width) + if self.enable_ln_post == False: + qk_norm = False + self.cross_attn_decoder = ResidualCrossAttentionBlock( + n_data=num_latents, + width=width, + mlp_expand_ratio=mlp_expand_ratio, + heads=heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm + ) + + if 
self.enable_ln_post:
+            self.ln_post = nn.LayerNorm(width)
+        self.output_proj = nn.Linear(width, out_channels)
+        self.label_type = label_type
+        self.count = 0
+
+    def set_cross_attention_processor(self, processor):
+        self.cross_attn_decoder.attn.attention.attn_processor = processor
+
+    def set_default_cross_attention_processor(self):
+        self.cross_attn_decoder.attn.attention.attn_processor = CrossAttentionProcessor()
+
+    def forward(self, queries=None, query_embeddings=None, latents=None):
+        if query_embeddings is None:
+            query_embeddings = self.query_proj(self.fourier_embedder(queries).to(latents.dtype))
+        self.count += query_embeddings.shape[1]
+        if self.downsample_ratio != 1:
+            latents = self.latents_proj(latents)
+        x = self.cross_attn_decoder(query_embeddings, latents)
+        if self.enable_ln_post:
+            x = self.ln_post(x)
+        occ = self.output_proj(x)
+        return occ
diff --git a/hy3dgen/shapegen/models/autoencoders/attention_processors.py b/hy3dgen/shapegen/models/autoencoders/attention_processors.py
new file mode 100644
index 0000000..f7b232e
--- /dev/null
+++ b/hy3dgen/shapegen/models/autoencoders/attention_processors.py
@@ -0,0 +1,96 @@
+# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
+# except for the third-party components listed below.
+# Hunyuan 3D does not impose any additional limitations beyond what is outlined
+# in the repsective licenses of these third-party components.
+# Users must comply with all terms and conditions of original licenses of these third-party
+# components and must ensure that the usage of the third party components adheres to
+# all relevant laws and regulations.
+
+# For avoidance of doubts, Hunyuan 3D means the large language models and
+# their software and algorithms, including trained model weights, parameters (including
+# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
+# fine-tuning enabling code and other elements of the foregoing made publicly available
+# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
+ +import os + +import torch +import torch.nn.functional as F + +scaled_dot_product_attention = F.scaled_dot_product_attention +if os.environ.get('CA_USE_SAGEATTN', '0') == '1': + try: + from sageattention import sageattn + except ImportError: + raise ImportError('Please install the package "sageattention" to use this USE_SAGEATTN.') + scaled_dot_product_attention = sageattn + + +class CrossAttentionProcessor: + def __call__(self, attn, q, k, v): + out = scaled_dot_product_attention(q, k, v) + return out + + +class FlashVDMCrossAttentionProcessor: + def __init__(self, topk=None): + self.topk = topk + + def __call__(self, attn, q, k, v): + if k.shape[-2] == 3072: + topk = 1024 + elif k.shape[-2] == 512: + topk = 256 + else: + topk = k.shape[-2] // 3 + + if self.topk is True: + q1 = q[:, :, ::100, :] + sim = q1 @ k.transpose(-1, -2) + sim = torch.mean(sim, -2) + topk_ind = torch.topk(sim, dim=-1, k=topk).indices.squeeze(-2).unsqueeze(-1) + topk_ind = topk_ind.expand(-1, -1, -1, v.shape[-1]) + v0 = torch.gather(v, dim=-2, index=topk_ind) + k0 = torch.gather(k, dim=-2, index=topk_ind) + out = scaled_dot_product_attention(q, k0, v0) + elif self.topk is False: + out = scaled_dot_product_attention(q, k, v) + else: + idx, counts = self.topk + start = 0 + outs = [] + for grid_coord, count in zip(idx, counts): + end = start + count + q_chunk = q[:, :, start:end, :] + k0, v0 = self.select_topkv(q_chunk, k, v, topk) + out = scaled_dot_product_attention(q_chunk, k0, v0) + outs.append(out) + start += count + out = torch.cat(outs, dim=-2) + self.topk = False + return out + + def select_topkv(self, q_chunk, k, v, topk): + q1 = q_chunk[:, :, ::50, :] + sim = q1 @ k.transpose(-1, -2) + sim = torch.mean(sim, -2) + topk_ind = torch.topk(sim, dim=-1, k=topk).indices.squeeze(-2).unsqueeze(-1) + topk_ind = topk_ind.expand(-1, -1, -1, v.shape[-1]) + v0 = torch.gather(v, dim=-2, index=topk_ind) + k0 = torch.gather(k, dim=-2, index=topk_ind) + return k0, v0 + + +class FlashVDMTopMCrossAttentionProcessor(FlashVDMCrossAttentionProcessor): + def select_topkv(self, q_chunk, k, v, topk): + q1 = q_chunk[:, :, ::30, :] + sim = q1 @ k.transpose(-1, -2) + # sim = sim.to(torch.float32) + sim = sim.softmax(-1) + sim = torch.mean(sim, 1) + activated_token = torch.where(sim > 1e-6)[2] + index = torch.unique(activated_token, return_counts=True)[0].unsqueeze(0).unsqueeze(0).unsqueeze(-1) + index = index.expand(-1, v.shape[1], -1, v.shape[-1]) + v0 = torch.gather(v, dim=-2, index=index) + k0 = torch.gather(k, dim=-2, index=index) + return k0, v0 diff --git a/hy3dgen/shapegen/models/autoencoders/model.py b/hy3dgen/shapegen/models/autoencoders/model.py new file mode 100644 index 0000000..76f78da --- /dev/null +++ b/hy3dgen/shapegen/models/autoencoders/model.py @@ -0,0 +1,189 @@ +# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT +# except for the third-party components listed below. +# Hunyuan 3D does not impose any additional limitations beyond what is outlined +# in the repsective licenses of these third-party components. +# Users must comply with all terms and conditions of original licenses of these third-party +# components and must ensure that the usage of the third party components adheres to +# all relevant laws and regulations. 
+ +# For avoidance of doubts, Hunyuan 3D means the large language models and +# their software and algorithms, including trained model weights, parameters (including +# optimizer states), machine-learning model code, inference-enabling code, training-enabling code, +# fine-tuning enabling code and other elements of the foregoing made publicly available +# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT. + +import os + +import torch +import torch.nn as nn +import yaml + +from .attention_blocks import FourierEmbedder, Transformer, CrossAttentionDecoder +from .surface_extractors import MCSurfaceExtractor, SurfaceExtractors +from .volume_decoders import VanillaVolumeDecoder, FlashVDMVolumeDecoding, HierarchicalVolumeDecoding +from ...utils import logger, synchronize_timer, smart_load_model + + +class VectsetVAE(nn.Module): + + @classmethod + @synchronize_timer('VectsetVAE Model Loading') + def from_single_file( + cls, + ckpt_path, + config_path, + device='cuda', + dtype=torch.float16, + use_safetensors=None, + **kwargs, + ): + # load config + with open(config_path, 'r') as f: + config = yaml.safe_load(f) + + # load ckpt + if use_safetensors: + ckpt_path = ckpt_path.replace('.ckpt', '.safetensors') + if not os.path.exists(ckpt_path): + raise FileNotFoundError(f"Model file {ckpt_path} not found") + + logger.info(f"Loading model from {ckpt_path}") + if use_safetensors: + import safetensors.torch + ckpt = safetensors.torch.load_file(ckpt_path, device='cpu') + else: + ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True) + + model_kwargs = config['params'] + model_kwargs.update(kwargs) + + model = cls(**model_kwargs) + model.load_state_dict(ckpt) + model.to(device=device, dtype=dtype) + return model + + @classmethod + def from_pretrained( + cls, + model_path, + device='cuda', + dtype=torch.float16, + use_safetensors=True, + variant='fp16', + subfolder='hunyuan3d-vae-v2-0', + **kwargs, + ): + config_path, ckpt_path = smart_load_model( + model_path, + subfolder=subfolder, + use_safetensors=use_safetensors, + variant=variant + ) + + return cls.from_single_file( + ckpt_path, + config_path, + device=device, + dtype=dtype, + use_safetensors=use_safetensors, + **kwargs + ) + + def __init__( + self, + volume_decoder=None, + surface_extractor=None + ): + super().__init__() + if volume_decoder is None: + volume_decoder = VanillaVolumeDecoder() + if surface_extractor is None: + surface_extractor = MCSurfaceExtractor() + self.volume_decoder = volume_decoder + self.surface_extractor = surface_extractor + + def latents2mesh(self, latents: torch.FloatTensor, **kwargs): + with synchronize_timer('Volume decoding'): + grid_logits = self.volume_decoder(latents, self.geo_decoder, **kwargs) + with synchronize_timer('Surface extraction'): + outputs = self.surface_extractor(grid_logits, **kwargs) + return outputs + + def enable_flashvdm_decoder( + self, + enabled: bool = True, + adaptive_kv_selection=True, + topk_mode='mean', + mc_algo='dmc', + ): + if enabled: + if adaptive_kv_selection: + self.volume_decoder = FlashVDMVolumeDecoding(topk_mode) + else: + self.volume_decoder = HierarchicalVolumeDecoding() + if mc_algo not in SurfaceExtractors.keys(): + raise ValueError(f'Unsupported mc_algo {mc_algo}, available: {list(SurfaceExtractors.keys())}') + self.surface_extractor = SurfaceExtractors[mc_algo]() + else: + self.volume_decoder = VanillaVolumeDecoder() + self.surface_extractor = MCSurfaceExtractor() + + +class ShapeVAE(VectsetVAE): + def __init__( + self, + *, + num_latents: 
int, + embed_dim: int, + width: int, + heads: int, + num_decoder_layers: int, + geo_decoder_downsample_ratio: int = 1, + geo_decoder_mlp_expand_ratio: int = 4, + geo_decoder_ln_post: bool = True, + num_freqs: int = 8, + include_pi: bool = True, + qkv_bias: bool = True, + qk_norm: bool = False, + label_type: str = "binary", + drop_path_rate: float = 0.0, + scale_factor: float = 1.0, + ): + super().__init__() + self.geo_decoder_ln_post = geo_decoder_ln_post + + self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi) + + self.post_kl = nn.Linear(embed_dim, width) + + self.transformer = Transformer( + n_ctx=num_latents, + width=width, + layers=num_decoder_layers, + heads=heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + drop_path_rate=drop_path_rate + ) + + self.geo_decoder = CrossAttentionDecoder( + fourier_embedder=self.fourier_embedder, + out_channels=1, + num_latents=num_latents, + mlp_expand_ratio=geo_decoder_mlp_expand_ratio, + downsample_ratio=geo_decoder_downsample_ratio, + enable_ln_post=self.geo_decoder_ln_post, + width=width // geo_decoder_downsample_ratio, + heads=heads // geo_decoder_downsample_ratio, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + label_type=label_type, + ) + + self.scale_factor = scale_factor + self.latent_shape = (num_latents, embed_dim) + + def forward(self, latents): + latents = self.post_kl(latents) + latents = self.transformer(latents) + return latents diff --git a/hy3dgen/shapegen/models/autoencoders/surface_extractors.py b/hy3dgen/shapegen/models/autoencoders/surface_extractors.py new file mode 100644 index 0000000..f4d8f63 --- /dev/null +++ b/hy3dgen/shapegen/models/autoencoders/surface_extractors.py @@ -0,0 +1,100 @@ +# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT +# except for the third-party components listed below. +# Hunyuan 3D does not impose any additional limitations beyond what is outlined +# in the repsective licenses of these third-party components. +# Users must comply with all terms and conditions of original licenses of these third-party +# components and must ensure that the usage of the third party components adheres to +# all relevant laws and regulations. + +# For avoidance of doubts, Hunyuan 3D means the large language models and +# their software and algorithms, including trained model weights, parameters (including +# optimizer states), machine-learning model code, inference-enabling code, training-enabling code, +# fine-tuning enabling code and other elements of the foregoing made publicly available +# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT. 
+ +from typing import Union, Tuple, List + +import numpy as np +import torch +from skimage import measure + + +class Latent2MeshOutput: + + def __init__(self, mesh_v=None, mesh_f=None): + self.mesh_v = mesh_v + self.mesh_f = mesh_f + + +def center_vertices(vertices): + """Translate the vertices so that bounding box is centered at zero.""" + vert_min = vertices.min(dim=0)[0] + vert_max = vertices.max(dim=0)[0] + vert_center = 0.5 * (vert_min + vert_max) + return vertices - vert_center + + +class SurfaceExtractor: + def _compute_box_stat(self, bounds: Union[Tuple[float], List[float], float], octree_resolution: int): + if isinstance(bounds, float): + bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds] + + bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6]) + bbox_size = bbox_max - bbox_min + grid_size = [int(octree_resolution) + 1, int(octree_resolution) + 1, int(octree_resolution) + 1] + return grid_size, bbox_min, bbox_size + + def run(self, *args, **kwargs): + return NotImplementedError + + def __call__(self, grid_logits, **kwargs): + outputs = [] + for i in range(grid_logits.shape[0]): + try: + vertices, faces = self.run(grid_logits[i], **kwargs) + vertices = vertices.astype(np.float32) + faces = np.ascontiguousarray(faces) + outputs.append(Latent2MeshOutput(mesh_v=vertices, mesh_f=faces)) + + except Exception: + import traceback + traceback.print_exc() + outputs.append(None) + + return outputs + + +class MCSurfaceExtractor(SurfaceExtractor): + def run(self, grid_logit, *, mc_level, bounds, octree_resolution, **kwargs): + vertices, faces, normals, _ = measure.marching_cubes( + grid_logit.cpu().numpy(), + mc_level, + method="lewiner" + ) + grid_size, bbox_min, bbox_size = self._compute_box_stat(bounds, octree_resolution) + vertices = vertices / grid_size * bbox_size + bbox_min + return vertices, faces + + +class DMCSurfaceExtractor(SurfaceExtractor): + def run(self, grid_logit, *, octree_resolution, **kwargs): + device = grid_logit.device + if not hasattr(self, 'dmc'): + try: + from diso import DiffDMC + except: + raise ImportError("Please install diso via `pip install diso`, or set mc_algo to 'mc'") + self.dmc = DiffDMC(dtype=torch.float32).to(device) + sdf = -grid_logit / octree_resolution + sdf = sdf.to(torch.float32).contiguous() + verts, faces = self.dmc(sdf, deform=None, return_quads=False, normalize=True) + verts = center_vertices(verts) + vertices = verts.detach().cpu().numpy() + faces = faces.detach().cpu().numpy()[:, ::-1] + return vertices, faces + + +SurfaceExtractors = { + 'mc': MCSurfaceExtractor, + 'dmc': DMCSurfaceExtractor, +} diff --git a/hy3dgen/shapegen/models/autoencoders/volume_decoders.py b/hy3dgen/shapegen/models/autoencoders/volume_decoders.py new file mode 100644 index 0000000..d7bfd84 --- /dev/null +++ b/hy3dgen/shapegen/models/autoencoders/volume_decoders.py @@ -0,0 +1,435 @@ +# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT +# except for the third-party components listed below. +# Hunyuan 3D does not impose any additional limitations beyond what is outlined +# in the repsective licenses of these third-party components. +# Users must comply with all terms and conditions of original licenses of these third-party +# components and must ensure that the usage of the third party components adheres to +# all relevant laws and regulations. 
+
+# For avoidance of doubts, Hunyuan 3D means the large language models and
+# their software and algorithms, including trained model weights, parameters (including
+# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
+# fine-tuning enabling code and other elements of the foregoing made publicly available
+# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
+
+from typing import Union, Tuple, List, Callable
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import repeat
+from tqdm import tqdm
+
+from .attention_blocks import CrossAttentionDecoder
+from .attention_processors import FlashVDMCrossAttentionProcessor, FlashVDMTopMCrossAttentionProcessor
+from ...utils import logger
+
+
+def extract_near_surface_volume_fn(input_tensor: torch.Tensor, alpha: float):
+    device = input_tensor.device
+    D = input_tensor.shape[0]
+    signed_val = 0.0
+
+    # Add the offset and mask out invalid values
+    val = input_tensor + alpha
+    valid_mask = val > -9000  # -9000 is assumed to mark invalid values
+
+    # Neighbor-fetching helper (keeps dimensions consistent)
+    def get_neighbor(t, shift, axis):
+        """Shift along the specified axis while keeping dimensions consistent."""
+        if shift == 0:
+            return t.clone()
+
+        # Determine the padding axis (the [D, D, D] input corresponds to the z, y, x axes)
+        pad_dims = [0, 0, 0, 0, 0, 0]  # format: [x_before, x_after, y_before, y_after, z_before, z_after]
+
+        # Set the padding according to the axis
+        if axis == 0:  # x axis (last dimension)
+            pad_idx = 0 if shift > 0 else 1
+            pad_dims[pad_idx] = abs(shift)
+        elif axis == 1:  # y axis (middle dimension)
+            pad_idx = 2 if shift > 0 else 3
+            pad_dims[pad_idx] = abs(shift)
+        elif axis == 2:  # z axis (first dimension)
+            pad_idx = 4 if shift > 0 else 5
+            pad_dims[pad_idx] = abs(shift)
+
+        # Apply the padding (add batch and channel dimensions to fit F.pad)
+        padded = F.pad(t.unsqueeze(0).unsqueeze(0), pad_dims[::-1], mode='replicate')  # reverse the order to fit F.pad
+
+        # Build dynamic slice indices
+        slice_dims = [slice(None)] * 3  # initialize to full slices
+        if axis == 0:  # x axis (dim=2)
+            if shift > 0:
+                slice_dims[0] = slice(shift, None)
+            else:
+                slice_dims[0] = slice(None, shift)
+        elif axis == 1:  # y axis (dim=1)
+            if shift > 0:
+                slice_dims[1] = slice(shift, None)
+            else:
+                slice_dims[1] = slice(None, shift)
+        elif axis == 2:  # z axis (dim=0)
+            if shift > 0:
+                slice_dims[2] = slice(shift, None)
+            else:
+                slice_dims[2] = slice(None, shift)
+
+        # Apply the slice and restore dimensions
+        padded = padded.squeeze(0).squeeze(0)
+        sliced = padded[slice_dims]
+        return sliced
+
+    # Fetch neighbors in every direction (keeping dimensions consistent)
+    left = get_neighbor(val, 1, axis=0)  # x direction
+    right = get_neighbor(val, -1, axis=0)
+    back = get_neighbor(val, 1, axis=1)  # y direction
+    front = get_neighbor(val, -1, axis=1)
+    down = get_neighbor(val, 1, axis=2)  # z direction
+    up = get_neighbor(val, -1, axis=2)
+
+    # Handle invalid boundary values (use where to keep dimensions consistent)
+    def safe_where(neighbor):
+        return torch.where(neighbor > -9000, neighbor, val)
+
+    left = safe_where(left)
+    right = safe_where(right)
+    back = safe_where(back)
+    front = safe_where(front)
+    down = safe_where(down)
+    up = safe_where(up)
+
+    # Compute sign consistency (cast to float32 for precision)
+    sign = torch.sign(val.to(torch.float32))
+    neighbors_sign = torch.stack([
+        torch.sign(left.to(torch.float32)),
+        torch.sign(right.to(torch.float32)),
+        torch.sign(back.to(torch.float32)),
+        torch.sign(front.to(torch.float32)),
+        torch.sign(down.to(torch.float32)),
+        torch.sign(up.to(torch.float32))
+    ], dim=0)
+
+    # Check whether all neighbor signs agree with the center sign
+    same_sign = torch.all(neighbors_sign == sign, dim=0)
+
+    # Build the final mask
+    mask = (~same_sign).to(torch.int32)
+    return mask * valid_mask.to(torch.int32)
+
+
+def generate_dense_grid_points(
+    bbox_min: np.ndarray,
+    bbox_max: np.ndarray,
+    octree_resolution: int,
+    indexing: str = "ij",
+):
+    length = bbox_max - bbox_min
+    num_cells = octree_resolution
+
+    x = np.linspace(bbox_min[0], bbox_max[0],
int(num_cells) + 1, dtype=np.float32) + y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32) + z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32) + [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing) + xyz = np.stack((xs, ys, zs), axis=-1) + grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1] + + return xyz, grid_size, length + + +class VanillaVolumeDecoder: + @torch.no_grad() + def __call__( + self, + latents: torch.FloatTensor, + geo_decoder: Callable, + bounds: Union[Tuple[float], List[float], float] = 1.01, + num_chunks: int = 10000, + octree_resolution: int = None, + enable_pbar: bool = True, + **kwargs, + ): + device = latents.device + dtype = latents.dtype + batch_size = latents.shape[0] + + # 1. generate query points + if isinstance(bounds, float): + bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds] + + bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6]) + xyz_samples, grid_size, length = generate_dense_grid_points( + bbox_min=bbox_min, + bbox_max=bbox_max, + octree_resolution=octree_resolution, + indexing="ij" + ) + xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3) + + # 2. latents to 3d volume + batch_logits = [] + for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), desc=f"Volume Decoding", + disable=not enable_pbar): + chunk_queries = xyz_samples[start: start + num_chunks, :] + chunk_queries = repeat(chunk_queries, "p c -> b p c", b=batch_size) + logits = geo_decoder(queries=chunk_queries, latents=latents) + batch_logits.append(logits) + + grid_logits = torch.cat(batch_logits, dim=1) + grid_logits = grid_logits.view((batch_size, *grid_size)).float() + + return grid_logits + + +class HierarchicalVolumeDecoding: + @torch.no_grad() + def __call__( + self, + latents: torch.FloatTensor, + geo_decoder: Callable, + bounds: Union[Tuple[float], List[float], float] = 1.01, + num_chunks: int = 10000, + mc_level: float = 0.0, + octree_resolution: int = None, + min_resolution: int = 63, + enable_pbar: bool = True, + **kwargs, + ): + device = latents.device + dtype = latents.dtype + + resolutions = [] + if octree_resolution < min_resolution: + resolutions.append(octree_resolution) + while octree_resolution >= min_resolution: + resolutions.append(octree_resolution) + octree_resolution = octree_resolution // 2 + resolutions.reverse() + + # 1. generate query points + if isinstance(bounds, float): + bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds] + bbox_min = np.array(bounds[0:3]) + bbox_max = np.array(bounds[3:6]) + bbox_size = bbox_max - bbox_min + + xyz_samples, grid_size, length = generate_dense_grid_points( + bbox_min=bbox_min, + bbox_max=bbox_max, + octree_resolution=resolutions[0], + indexing="ij" + ) + + dilate = nn.Conv3d(1, 1, 3, padding=1, bias=False, device=device, dtype=dtype) + dilate.weight = torch.nn.Parameter(torch.ones(dilate.weight.shape, dtype=dtype, device=device)) + + grid_size = np.array(grid_size) + xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3) + + # 2. 
latents to 3d volume + batch_logits = [] + batch_size = latents.shape[0] + for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), + desc=f"Hierarchical Volume Decoding [r{resolutions[0] + 1}]"): + queries = xyz_samples[start: start + num_chunks, :] + batch_queries = repeat(queries, "p c -> b p c", b=batch_size) + logits = geo_decoder(queries=batch_queries, latents=latents) + batch_logits.append(logits) + + grid_logits = torch.cat(batch_logits, dim=1).view((batch_size, grid_size[0], grid_size[1], grid_size[2])) + + for octree_depth_now in resolutions[1:]: + grid_size = np.array([octree_depth_now + 1] * 3) + resolution = bbox_size / octree_depth_now + next_index = torch.zeros(tuple(grid_size), dtype=dtype, device=device) + next_logits = torch.full(next_index.shape, -10000., dtype=dtype, device=device) + curr_points = extract_near_surface_volume_fn(grid_logits.squeeze(0), mc_level) + curr_points += grid_logits.squeeze(0).abs() < 0.95 + + if octree_depth_now == resolutions[-1]: + expand_num = 0 + else: + expand_num = 1 + for i in range(expand_num): + curr_points = dilate(curr_points.unsqueeze(0).to(dtype)).squeeze(0) + (cidx_x, cidx_y, cidx_z) = torch.where(curr_points > 0) + next_index[cidx_x * 2, cidx_y * 2, cidx_z * 2] = 1 + for i in range(2 - expand_num): + next_index = dilate(next_index.unsqueeze(0)).squeeze(0) + nidx = torch.where(next_index > 0) + + next_points = torch.stack(nidx, dim=1) + next_points = (next_points * torch.tensor(resolution, dtype=next_points.dtype, device=device) + + torch.tensor(bbox_min, dtype=next_points.dtype, device=device)) + batch_logits = [] + for start in tqdm(range(0, next_points.shape[0], num_chunks), + desc=f"Hierarchical Volume Decoding [r{octree_depth_now + 1}]"): + queries = next_points[start: start + num_chunks, :] + batch_queries = repeat(queries, "p c -> b p c", b=batch_size) + logits = geo_decoder(queries=batch_queries.to(latents.dtype), latents=latents) + batch_logits.append(logits) + grid_logits = torch.cat(batch_logits, dim=1) + next_logits[nidx] = grid_logits[0, ..., 0] + grid_logits = next_logits.unsqueeze(0) + grid_logits[grid_logits == -10000.] = float('nan') + + return grid_logits + + +class FlashVDMVolumeDecoding: + def __init__(self, topk_mode='mean'): + if topk_mode not in ['mean', 'merge']: + raise ValueError(f'Unsupported topk_mode {topk_mode}, available: {["mean", "merge"]}') + + if topk_mode == 'mean': + self.processor = FlashVDMCrossAttentionProcessor() + else: + self.processor = FlashVDMTopMCrossAttentionProcessor() + + @torch.no_grad() + def __call__( + self, + latents: torch.FloatTensor, + geo_decoder: CrossAttentionDecoder, + bounds: Union[Tuple[float], List[float], float] = 1.01, + num_chunks: int = 10000, + mc_level: float = 0.0, + octree_resolution: int = None, + min_resolution: int = 63, + mini_grid_num: int = 4, + enable_pbar: bool = True, + **kwargs, + ): + processor = self.processor + geo_decoder.set_cross_attention_processor(processor) + + device = latents.device + dtype = latents.dtype + + resolutions = [] + if octree_resolution < min_resolution: + resolutions.append(octree_resolution) + while octree_resolution >= min_resolution: + resolutions.append(octree_resolution) + octree_resolution = octree_resolution // 2 + resolutions.reverse() + resolutions[0] = round(resolutions[0] / mini_grid_num) * mini_grid_num - 1 + for i, resolution in enumerate(resolutions[1:]): + resolutions[i + 1] = resolutions[0] * 2 ** (i + 1) + + logger.info(f"FlashVDMVolumeDecoding Resolution: {resolutions}") + + # 1. 
generate query points + if isinstance(bounds, float): + bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds] + bbox_min = np.array(bounds[0:3]) + bbox_max = np.array(bounds[3:6]) + bbox_size = bbox_max - bbox_min + + xyz_samples, grid_size, length = generate_dense_grid_points( + bbox_min=bbox_min, + bbox_max=bbox_max, + octree_resolution=resolutions[0], + indexing="ij" + ) + + dilate = nn.Conv3d(1, 1, 3, padding=1, bias=False, device=device, dtype=dtype) + dilate.weight = torch.nn.Parameter(torch.ones(dilate.weight.shape, dtype=dtype, device=device)) + + grid_size = np.array(grid_size) + + # 2. latents to 3d volume + xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype) + batch_size = latents.shape[0] + mini_grid_size = xyz_samples.shape[0] // mini_grid_num + xyz_samples = xyz_samples.view( + mini_grid_num, mini_grid_size, + mini_grid_num, mini_grid_size, + mini_grid_num, mini_grid_size, 3 + ).permute( + 0, 2, 4, 1, 3, 5, 6 + ).reshape( + -1, mini_grid_size * mini_grid_size * mini_grid_size, 3 + ) + batch_logits = [] + num_batchs = max(num_chunks // xyz_samples.shape[1], 1) + for start in tqdm(range(0, xyz_samples.shape[0], num_batchs), + desc=f"FlashVDM Volume Decoding", disable=not enable_pbar): + queries = xyz_samples[start: start + num_batchs, :] + batch = queries.shape[0] + batch_latents = repeat(latents.squeeze(0), "p c -> b p c", b=batch) + processor.topk = True + logits = geo_decoder(queries=queries, latents=batch_latents) + batch_logits.append(logits) + grid_logits = torch.cat(batch_logits, dim=0).reshape( + mini_grid_num, mini_grid_num, mini_grid_num, + mini_grid_size, mini_grid_size, + mini_grid_size + ).permute(0, 3, 1, 4, 2, 5).contiguous().view( + (batch_size, grid_size[0], grid_size[1], grid_size[2]) + ) + + for octree_depth_now in resolutions[1:]: + grid_size = np.array([octree_depth_now + 1] * 3) + resolution = bbox_size / octree_depth_now + next_index = torch.zeros(tuple(grid_size), dtype=dtype, device=device) + next_logits = torch.full(next_index.shape, -10000., dtype=dtype, device=device) + curr_points = extract_near_surface_volume_fn(grid_logits.squeeze(0), mc_level) + curr_points += grid_logits.squeeze(0).abs() < 0.95 + + if octree_depth_now == resolutions[-1]: + expand_num = 0 + else: + expand_num = 1 + for i in range(expand_num): + curr_points = dilate(curr_points.unsqueeze(0).to(dtype)).squeeze(0) + (cidx_x, cidx_y, cidx_z) = torch.where(curr_points > 0) + + next_index[cidx_x * 2, cidx_y * 2, cidx_z * 2] = 1 + for i in range(2 - expand_num): + next_index = dilate(next_index.unsqueeze(0)).squeeze(0) + nidx = torch.where(next_index > 0) + + next_points = torch.stack(nidx, dim=1) + next_points = (next_points * torch.tensor(resolution, dtype=torch.float32, device=device) + + torch.tensor(bbox_min, dtype=torch.float32, device=device)) + + query_grid_num = 6 + min_val = next_points.min(axis=0).values + max_val = next_points.max(axis=0).values + vol_queries_index = (next_points - min_val) / (max_val - min_val) * (query_grid_num - 0.001) + index = torch.floor(vol_queries_index).long() + index = index[..., 0] * (query_grid_num ** 2) + index[..., 1] * query_grid_num + index[..., 2] + index = index.sort() + next_points = next_points[index.indices].unsqueeze(0).contiguous() + unique_values = torch.unique(index.values, return_counts=True) + grid_logits = torch.zeros((next_points.shape[1]), dtype=latents.dtype, device=latents.device) + input_grid = [[], []] + logits_grid_list = [] + start_num = 0 + sum_num = 0 + for grid_index, count in 
zip(unique_values[0].cpu().tolist(), unique_values[1].cpu().tolist()): + if sum_num + count < num_chunks or sum_num == 0: + sum_num += count + input_grid[0].append(grid_index) + input_grid[1].append(count) + else: + processor.topk = input_grid + logits_grid = geo_decoder(queries=next_points[:, start_num:start_num + sum_num], latents=latents) + start_num = start_num + sum_num + logits_grid_list.append(logits_grid) + input_grid = [[grid_index], [count]] + sum_num = count + if sum_num > 0: + processor.topk = input_grid + logits_grid = geo_decoder(queries=next_points[:, start_num:start_num + sum_num], latents=latents) + logits_grid_list.append(logits_grid) + logits_grid = torch.cat(logits_grid_list, dim=1) + grid_logits[index.indices] = logits_grid.squeeze(0).squeeze(-1) + next_logits[nidx] = grid_logits + grid_logits = next_logits.unsqueeze(0) + + grid_logits[grid_logits == -10000.] = float('nan') + + return grid_logits diff --git a/hy3dgen/shapegen/models/vae.py b/hy3dgen/shapegen/models/vae_old.py old mode 100755 new mode 100644 similarity index 100% rename from hy3dgen/shapegen/models/vae.py rename to hy3dgen/shapegen/models/vae_old.py diff --git a/hy3dgen/shapegen/pipelines.py b/hy3dgen/shapegen/pipelines.py index 9b043a1..462dd28 100755 --- a/hy3dgen/shapegen/pipelines.py +++ b/hy3dgen/shapegen/pipelines.py @@ -204,7 +204,7 @@ class Hunyuan3DDiTPipeline: config['model']['params']['guidance_embed'] = True config['conditioner']['params']['main_image_encoder']['kwargs']['has_guidance_embed'] = True config['model']['params']['attention_mode'] = attention_mode - config['vae']['params']['attention_mode'] = attention_mode + #config['vae']['params']['attention_mode'] = attention_mode if cublas_ops: config['vae']['params']['cublas_ops'] = True diff --git a/hy3dgen/shapegen/postprocessors.py b/hy3dgen/shapegen/postprocessors.py index 477628d..55b5ffc 100755 --- a/hy3dgen/shapegen/postprocessors.py +++ b/hy3dgen/shapegen/postprocessors.py @@ -1,13 +1,3 @@ -# Open Source Model Licensed under the Apache License Version 2.0 -# and Other Licenses of the Third-Party Components therein: -# The below Model in this distribution may have been modified by THL A29 Limited -# ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited. - -# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. -# The below software and/or models in this distribution may have been -# modified by THL A29 Limited ("Tencent Modifications"). -# All Tencent Modifications are Copyright (C) THL A29 Limited. - # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT # except for the third-party components listed below. 
# Hunyuan 3D does not impose any additional limitations beyond what is outlined @@ -25,11 +15,12 @@ import tempfile import os from typing import Union - +import torch +import numpy as np import pymeshlab import trimesh -from .models.vae import Latent2MeshOutput +from .models.autoencoders import Latent2MeshOutput import folder_paths @@ -44,6 +35,9 @@ def load_mesh(path): def reduce_face(mesh: pymeshlab.MeshSet, max_facenum: int = 200000): + if max_facenum > mesh.current_mesh().face_number(): + return mesh + mesh.apply_filter( "meshing_decimation_quadric_edge_collapse", targetfacenum=max_facenum, @@ -287,3 +281,47 @@ class DegenerateFaceRemover: mesh = export_mesh(mesh, ms) return mesh + + +def mesh_normalize(mesh): + """ + Normalize mesh vertices to sphere + """ + scale_factor = 1.2 + vtx_pos = np.asarray(mesh.vertices) + max_bb = (vtx_pos - 0).max(0)[0] + min_bb = (vtx_pos - 0).min(0)[0] + + center = (max_bb + min_bb) / 2 + + scale = torch.norm(torch.tensor(vtx_pos - center, dtype=torch.float32), dim=1).max() * 2.0 + + vtx_pos = (vtx_pos - center) * (scale_factor / float(scale)) + mesh.vertices = vtx_pos + + return mesh + + +class MeshSimplifier: + def __init__(self, executable: str = None): + if executable is None: + CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) + executable = os.path.join(CURRENT_DIR, "mesh_simplifier.bin") + self.executable = executable + + def __call__( + self, + mesh: Union[trimesh.Trimesh], + ) -> Union[trimesh.Trimesh]: + with tempfile.NamedTemporaryFile(suffix='.obj', delete=False) as temp_input: + with tempfile.NamedTemporaryFile(suffix='.obj', delete=False) as temp_output: + mesh.export(temp_input.name) + os.system(f'{self.executable} {temp_input.name} {temp_output.name}') + ms = trimesh.load(temp_output.name, process=False) + if isinstance(ms, trimesh.Scene): + combined_mesh = trimesh.Trimesh() + for geom in ms.geometry.values(): + combined_mesh = trimesh.util.concatenate([combined_mesh, geom]) + ms = combined_mesh + ms = mesh_normalize(ms) + return ms diff --git a/hy3dgen/shapegen/utils.py b/hy3dgen/shapegen/utils.py new file mode 100644 index 0000000..3d215bd --- /dev/null +++ b/hy3dgen/shapegen/utils.py @@ -0,0 +1,123 @@ +# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT +# except for the third-party components listed below. +# Hunyuan 3D does not impose any additional limitations beyond what is outlined +# in the repsective licenses of these third-party components. +# Users must comply with all terms and conditions of original licenses of these third-party +# components and must ensure that the usage of the third party components adheres to +# all relevant laws and regulations. + +# For avoidance of doubts, Hunyuan 3D means the large language models and +# their software and algorithms, including trained model weights, parameters (including +# optimizer states), machine-learning model code, inference-enabling code, training-enabling code, +# fine-tuning enabling code and other elements of the foregoing made publicly available +# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT. 
+ +import logging +import os +from functools import wraps + +import torch + + +def get_logger(name): + logger = logging.getLogger(name) + logger.setLevel(logging.INFO) + + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.INFO) + + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + console_handler.setFormatter(formatter) + logger.addHandler(console_handler) + return logger + + +logger = get_logger('hy3dgen.shapgen') + + +class synchronize_timer: + """ Synchronized timer to count the inference time of `nn.Module.forward`. + + Supports both context manager and decorator usage. + + Example as context manager: + ```python + with synchronize_timer('name') as t: + run() + ``` + + Example as decorator: + ```python + @synchronize_timer('Export to trimesh') + def export_to_trimesh(mesh_output): + pass + ``` + """ + + def __init__(self, name=None): + self.name = name + + def __enter__(self): + """Context manager entry: start timing.""" + if os.environ.get('HY3DGEN_DEBUG', '0') == '1': + self.start = torch.cuda.Event(enable_timing=True) + self.end = torch.cuda.Event(enable_timing=True) + self.start.record() + return lambda: self.time + + def __exit__(self, exc_type, exc_value, exc_tb): + """Context manager exit: stop timing and log results.""" + if os.environ.get('HY3DGEN_DEBUG', '0') == '1': + self.end.record() + torch.cuda.synchronize() + self.time = self.start.elapsed_time(self.end) + if self.name is not None: + logger.info(f'{self.name} takes {self.time} ms') + + def __call__(self, func): + """Decorator: wrap the function to time its execution.""" + + @wraps(func) + def wrapper(*args, **kwargs): + with self: + result = func(*args, **kwargs) + return result + + return wrapper + + +def smart_load_model( + model_path, + subfolder, + use_safetensors, + variant, +): + original_model_path = model_path + # try local path + base_dir = os.environ.get('HY3DGEN_MODELS', '~/.cache/hy3dgen') + model_path = os.path.expanduser(os.path.join(base_dir, model_path, subfolder)) + logger.info(f'Try to load model from local path: {model_path}') + if not os.path.exists(model_path): + logger.info('Model path not exists, try to download from huggingface') + try: + import huggingface_hub + # download from huggingface + path = huggingface_hub.snapshot_download(repo_id=original_model_path) + model_path = os.path.join(path, subfolder) + except ImportError: + logger.warning( + "You need to install HuggingFace Hub to load models from the hub." 
+ ) + raise RuntimeError(f"Model path {model_path} not found") + except Exception as e: + raise e + + if not os.path.exists(model_path): + raise FileNotFoundError(f"Model path {original_model_path} not found") + + extension = 'ckpt' if not use_safetensors else 'safetensors' + variant = '' if variant is None else f'.{variant}' + ckpt_name = f'model{variant}.{extension}' + config_path = os.path.join(model_path, 'config.yaml') + ckpt_path = os.path.join(model_path, ckpt_name) + return config_path, ckpt_path diff --git a/nodes.py b/nodes.py index dfecc3f..e7f6821 100644 --- a/nodes.py +++ b/nodes.py @@ -1194,8 +1194,12 @@ class Hy3DVAEDecode: "octree_resolution": ("INT", {"default": 384, "min": 8, "max": 4096, "step": 8}), "num_chunks": ("INT", {"default": 8000, "min": 1, "max": 10000000, "step": 1, "tooltip": "Number of chunks to process at once, higher values use more memory, but make the process faster"}), "mc_level": ("FLOAT", {"default": 0, "min": -1.0, "max": 1.0, "step": 0.0001}), - "mc_algo": (["mc", "dmc", "odc", "none"], {"default": "mc"}), + #"mc_algo": (["mc", "dmc", "odc", "none"], {"default": "mc"}), + "mc_algo": (["mc", "dmc"], {"default": "mc"}), }, + "optional": { + "enable_flash_vdm": ("BOOLEAN", {"default": True}), + } } RETURN_TYPES = ("TRIMESH",) @@ -1203,11 +1207,14 @@ class Hy3DVAEDecode: FUNCTION = "process" CATEGORY = "Hunyuan3DWrapper" - def process(self, vae, latents, box_v, octree_resolution, mc_level, num_chunks, mc_algo): + def process(self, vae, latents, box_v, octree_resolution, mc_level, num_chunks, mc_algo, enable_flash_vdm=True): device = mm.get_torch_device() offload_device = mm.unet_offload_device() vae.to(device) + + vae.enable_flashvdm_decoder(enabled=enable_flash_vdm) + latents = 1. / vae.scale_factor * latents latents = vae(latents)
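
For orientation, here is a minimal usage sketch of the decode path this patch adds, mirroring what the Hy3DVAEDecode node does above. It assumes the usual `tencent/Hunyuan3D-2` weights layout resolved by `smart_load_model`; the randomly initialized latents are a stand-in for the output of the diffusion sampler.

    import torch
    from hy3dgen.shapegen.models.autoencoders import ShapeVAE

    # Load the VAE (downloads via huggingface_hub if not cached under HY3DGEN_MODELS).
    vae = ShapeVAE.from_pretrained(
        'tencent/Hunyuan3D-2',              # assumed repo id / local cache key
        subfolder='hunyuan3d-vae-v2-0',
        device='cuda',
        dtype=torch.float16,
    )

    # Swap in the FlashVDM volume decoder (adaptive top-k KV selection);
    # mc_algo='dmc' would additionally require the `diso` package.
    vae.enable_flashvdm_decoder(enabled=True, adaptive_kv_selection=True,
                                topk_mode='mean', mc_algo='mc')

    # Stand-in latents; normally these come from the shape-diffusion sampler.
    latents = torch.randn(1, *vae.latent_shape, device='cuda', dtype=torch.float16)
    latents = 1. / vae.scale_factor * latents
    latents = vae(latents)                  # post_kl + transformer

    meshes = vae.latents2mesh(
        latents,
        bounds=1.01,
        mc_level=0.0,
        octree_resolution=384,
        num_chunks=8000,
        enable_pbar=True,
    )
    out = meshes[0]                         # Latent2MeshOutput with mesh_v / mesh_f

Calling `enable_flashvdm_decoder(enabled=False)` restores the vanilla volume decoder and plain marching cubes, which is what the node does when enable_flash_vdm is switched off.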
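
The speed-up inside FlashVDMCrossAttentionProcessor comes from scoring a strided sample of the query points against all latent tokens, keeping only the top-k keys/values, and running exact attention against that reduced set. A self-contained sketch of that selection step, with illustrative tensor shapes rather than values taken from the patch, looks like this:

    import torch
    import torch.nn.functional as F

    def topk_cross_attention(q, k, v, topk=1024, stride=100):
        """Approximate cross-attention: attend only to the top-k keys/values.

        q: [B, H, Nq, D] query points, k/v: [B, H, Nk, D] latent tokens.
        """
        # Score a strided subset of the queries against every key, then average
        # over the sampled queries to get one relevance score per key.
        q_sample = q[:, :, ::stride, :]
        sim = q_sample @ k.transpose(-1, -2)                  # [B, H, Nq/stride, Nk]
        sim = sim.mean(dim=-2)                                # [B, H, Nk]

        # Gather the top-k keys/values and run exact attention on that subset.
        idx = sim.topk(k=topk, dim=-1).indices.unsqueeze(-1)  # [B, H, topk, 1]
        idx = idx.expand(-1, -1, -1, v.shape[-1])             # [B, H, topk, D]
        k_sel = torch.gather(k, dim=-2, index=idx)
        v_sel = torch.gather(v, dim=-2, index=idx)
        return F.scaled_dot_product_attention(q, k_sel, v_sel)

    # Illustrative shapes: 8 heads, 4096 query points, 3072 latent tokens.
    q = torch.randn(1, 8, 4096, 64)
    k = torch.randn(1, 8, 3072, 64)
    v = torch.randn(1, 8, 3072, 64)
    out = topk_cross_attention(q, k, v)
    print(out.shape)  # torch.Size([1, 8, 4096, 64])

The 'merge' topk_mode variant (FlashVDMTopMCrossAttentionProcessor) replaces the mean-score ranking with a softmax-based activation threshold, but the gather-then-attend structure is the same.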