Mirror of https://git.datalinker.icu/kijai/ComfyUI-Hunyuan3DWrapper.git
Add FlashVDM vae
...and it earned that name, whoa! https://github.com/Tencent/FlashVDM/
This commit is contained in:
parent
1ecc8e3195
commit
280ab59834
@@ -25,4 +25,4 @@
|
||||
|
||||
from .conditioner import DualImageEncoder, SingleImageEncoder, DinoImageEncoder, CLIPImageEncoder
|
||||
from .hunyuan3ddit import Hunyuan3DDiT
|
||||
from .vae import ShapeVAE
|
||||
from .autoencoders import ShapeVAE
|
||||
|
||||
20
hy3dgen/shapegen/models/autoencoders/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
||||
# except for the third-party components listed below.
|
||||
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
||||
# in the respective licenses of these third-party components.
|
||||
# Users must comply with all terms and conditions of original licenses of these third-party
|
||||
# components and must ensure that the usage of the third party components adheres to
|
||||
# all relevant laws and regulations.
|
||||
|
||||
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
||||
# their software and algorithms, including trained model weights, parameters (including
|
||||
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
||||
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
||||
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
||||
|
||||
from .attention_blocks import CrossAttentionDecoder
|
||||
from .attention_processors import FlashVDMCrossAttentionProcessor, CrossAttentionProcessor, \
|
||||
FlashVDMTopMCrossAttentionProcessor
|
||||
from .model import ShapeVAE, VectsetVAE
|
||||
from .surface_extractors import SurfaceExtractors, MCSurfaceExtractor, DMCSurfaceExtractor, Latent2MeshOutput
|
||||
from .volume_decoders import HierarchicalVolumeDecoding, FlashVDMVolumeDecoding, VanillaVolumeDecoder
|
||||
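For reference, everything downstream code needs from the new package is re-exported here. A minimal, illustrative import check (not part of the commit; it only exercises the names listed above):

from hy3dgen.shapegen.models.autoencoders import (
    ShapeVAE,
    FlashVDMVolumeDecoding,
    HierarchicalVolumeDecoding,
    FlashVDMCrossAttentionProcessor,
    SurfaceExtractors,
)

# The registry maps the same mc_algo strings used elsewhere in the wrapper.
print(sorted(SurfaceExtractors.keys()))  # ['dmc', 'mc']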
493
hy3dgen/shapegen/models/autoencoders/attention_blocks.py
Normal file
@@ -0,0 +1,493 @@
|
||||
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
||||
# except for the third-party components listed below.
|
||||
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
||||
# in the respective licenses of these third-party components.
|
||||
# Users must comply with all terms and conditions of original licenses of these third-party
|
||||
# components and must ensure that the usage of the third party components adheres to
|
||||
# all relevant laws and regulations.
|
||||
|
||||
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
||||
# their software and algorithms, including trained model weights, parameters (including
|
||||
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
||||
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
||||
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
||||
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
|
||||
from .attention_processors import CrossAttentionProcessor
|
||||
from ...utils import logger
|
||||
|
||||
scaled_dot_product_attention = nn.functional.scaled_dot_product_attention
|
||||
|
||||
if os.environ.get('USE_SAGEATTN', '0') == '1':
|
||||
try:
|
||||
from sageattention import sageattn
|
||||
except ImportError:
|
||||
raise ImportError('Please install the "sageattention" package to use USE_SAGEATTN.')
|
||||
scaled_dot_product_attention = sageattn
|
||||
|
||||
|
||||
class FourierEmbedder(nn.Module):
|
||||
"""The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
|
||||
each feature dimension of `x[..., i]` into:
|
||||
[
|
||||
sin(x[..., i]),
|
||||
sin(f_1*x[..., i]),
|
||||
sin(f_2*x[..., i]),
|
||||
...
|
||||
sin(f_N * x[..., i]),
|
||||
cos(x[..., i]),
|
||||
cos(f_1*x[..., i]),
|
||||
cos(f_2*x[..., i]),
|
||||
...
|
||||
cos(f_N * x[..., i]),
|
||||
x[..., i] # only present if include_input is True.
|
||||
], here f_i is the frequency.
|
||||
|
||||
If logspace is True, the frequencies are powers of two: f_i = 2^i for i in [0, ..., num_freqs - 1];
Otherwise, the frequencies are linearly spaced between 1.0 and 2^(num_freqs - 1).
|
||||
|
||||
Args:
|
||||
num_freqs (int): the number of frequencies, default is 6;
|
||||
logspace (bool): if True, the frequencies are powers of two [2^0, ..., 2^(num_freqs - 1)];
otherwise, the frequencies are linearly spaced between 1.0 and 2^(num_freqs - 1);
|
||||
input_dim (int): the input dimension, default is 3;
|
||||
include_input (bool): include the input tensor or not, default is True.
|
||||
|
||||
Attributes:
|
||||
frequencies (torch.Tensor): if logspace is True, the frequencies are [2^0, ..., 2^(num_freqs - 1)];
otherwise, they are linearly spaced between 1.0 and 2^(num_freqs - 1);
|
||||
|
||||
out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1),
|
||||
otherwise, it is input_dim * num_freqs * 2.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_freqs: int = 6,
|
||||
logspace: bool = True,
|
||||
input_dim: int = 3,
|
||||
include_input: bool = True,
|
||||
include_pi: bool = True) -> None:
|
||||
|
||||
"""The initialization"""
|
||||
|
||||
super().__init__()
|
||||
|
||||
if logspace:
|
||||
frequencies = 2.0 ** torch.arange(
|
||||
num_freqs,
|
||||
dtype=torch.float32
|
||||
)
|
||||
else:
|
||||
frequencies = torch.linspace(
|
||||
1.0,
|
||||
2.0 ** (num_freqs - 1),
|
||||
num_freqs,
|
||||
dtype=torch.float32
|
||||
)
|
||||
|
||||
if include_pi:
|
||||
frequencies *= torch.pi
|
||||
|
||||
self.register_buffer("frequencies", frequencies, persistent=False)
|
||||
self.include_input = include_input
|
||||
self.num_freqs = num_freqs
|
||||
|
||||
self.out_dim = self.get_dims(input_dim)
|
||||
|
||||
def get_dims(self, input_dim):
|
||||
temp = 1 if self.include_input or self.num_freqs == 0 else 0
|
||||
out_dim = input_dim * (self.num_freqs * 2 + temp)
|
||||
|
||||
return out_dim
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
""" Forward process.
|
||||
|
||||
Args:
|
||||
x: tensor of shape [..., dim]
|
||||
|
||||
Returns:
|
||||
embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)]
|
||||
where temp is 1 if include_input is True and 0 otherwise.
|
||||
"""
|
||||
|
||||
if self.num_freqs > 0:
|
||||
embed = (x[..., None].contiguous() * self.frequencies).view(*x.shape[:-1], -1)
|
||||
if self.include_input:
|
||||
return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
|
||||
else:
|
||||
return torch.cat((embed.sin(), embed.cos()), dim=-1)
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
class DropPath(nn.Module):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||
"""
|
||||
|
||||
def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
|
||||
super(DropPath, self).__init__()
|
||||
self.drop_prob = drop_prob
|
||||
self.scale_by_keep = scale_by_keep
|
||||
|
||||
def forward(self, x):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||
|
||||
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
|
||||
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
||||
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
|
||||
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
|
||||
'survival rate' as the argument.
|
||||
|
||||
"""
|
||||
if self.drop_prob == 0. or not self.training:
|
||||
return x
|
||||
keep_prob = 1 - self.drop_prob
|
||||
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
||||
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
||||
if keep_prob > 0.0 and self.scale_by_keep:
|
||||
random_tensor.div_(keep_prob)
|
||||
return x * random_tensor
|
||||
|
||||
def extra_repr(self):
|
||||
return f'drop_prob={round(self.drop_prob, 3):0.3f}'
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(
|
||||
self, *,
|
||||
width: int,
|
||||
expand_ratio: int = 4,
|
||||
output_width: int = None,
|
||||
drop_path_rate: float = 0.0
|
||||
):
|
||||
super().__init__()
|
||||
self.width = width
|
||||
self.c_fc = nn.Linear(width, width * expand_ratio)
|
||||
self.c_proj = nn.Linear(width * expand_ratio, output_width if output_width is not None else width)
|
||||
self.gelu = nn.GELU()
|
||||
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
return self.drop_path(self.c_proj(self.gelu(self.c_fc(x))))
|
||||
|
||||
|
||||
class QKVMultiheadCrossAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
heads: int,
|
||||
n_data: Optional[int] = None,
|
||||
width=None,
|
||||
qk_norm=False,
|
||||
norm_layer=nn.LayerNorm
|
||||
):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
self.n_data = n_data
|
||||
self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
|
||||
self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
|
||||
|
||||
self.attn_processor = CrossAttentionProcessor()
|
||||
|
||||
def forward(self, q, kv):
|
||||
_, n_ctx, _ = q.shape
|
||||
bs, n_data, width = kv.shape
|
||||
attn_ch = width // self.heads // 2
|
||||
q = q.view(bs, n_ctx, self.heads, -1)
|
||||
kv = kv.view(bs, n_data, self.heads, -1)
|
||||
k, v = torch.split(kv, attn_ch, dim=-1)
|
||||
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
|
||||
out = self.attn_processor(self, q, k, v)
|
||||
out = out.transpose(1, 2).reshape(bs, n_ctx, -1)
|
||||
return out
|
||||
|
||||
|
||||
class MultiheadCrossAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
width: int,
|
||||
heads: int,
|
||||
qkv_bias: bool = True,
|
||||
n_data: Optional[int] = None,
|
||||
data_width: Optional[int] = None,
|
||||
norm_layer=nn.LayerNorm,
|
||||
qk_norm: bool = False,
|
||||
kv_cache: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.n_data = n_data
|
||||
self.width = width
|
||||
self.heads = heads
|
||||
self.data_width = width if data_width is None else data_width
|
||||
self.c_q = nn.Linear(width, width, bias=qkv_bias)
|
||||
self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias)
|
||||
self.c_proj = nn.Linear(width, width)
|
||||
self.attention = QKVMultiheadCrossAttention(
|
||||
heads=heads,
|
||||
n_data=n_data,
|
||||
width=width,
|
||||
norm_layer=norm_layer,
|
||||
qk_norm=qk_norm
|
||||
)
|
||||
self.kv_cache = kv_cache
|
||||
self.data = None
|
||||
|
||||
def forward(self, x, data):
|
||||
x = self.c_q(x)
|
||||
if self.kv_cache:
|
||||
if self.data is None:
|
||||
self.data = self.c_kv(data)
|
||||
logger.info('Saving kv cache; this should happen only once per mesh')
|
||||
data = self.data
|
||||
else:
|
||||
data = self.c_kv(data)
|
||||
x = self.attention(x, data)
|
||||
x = self.c_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class ResidualCrossAttentionBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
n_data: Optional[int] = None,
|
||||
width: int,
|
||||
heads: int,
|
||||
mlp_expand_ratio: int = 4,
|
||||
data_width: Optional[int] = None,
|
||||
qkv_bias: bool = True,
|
||||
norm_layer=nn.LayerNorm,
|
||||
qk_norm: bool = False
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if data_width is None:
|
||||
data_width = width
|
||||
|
||||
self.attn = MultiheadCrossAttention(
|
||||
n_data=n_data,
|
||||
width=width,
|
||||
heads=heads,
|
||||
data_width=data_width,
|
||||
qkv_bias=qkv_bias,
|
||||
norm_layer=norm_layer,
|
||||
qk_norm=qk_norm
|
||||
)
|
||||
self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
|
||||
self.ln_2 = norm_layer(data_width, elementwise_affine=True, eps=1e-6)
|
||||
self.ln_3 = norm_layer(width, elementwise_affine=True, eps=1e-6)
|
||||
self.mlp = MLP(width=width, expand_ratio=mlp_expand_ratio)
|
||||
|
||||
def forward(self, x: torch.Tensor, data: torch.Tensor):
|
||||
x = x + self.attn(self.ln_1(x), self.ln_2(data))
|
||||
x = x + self.mlp(self.ln_3(x))
|
||||
return x
|
||||
|
||||
|
||||
class QKVMultiheadAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
heads: int,
|
||||
n_ctx: int,
|
||||
width=None,
|
||||
qk_norm=False,
|
||||
norm_layer=nn.LayerNorm
|
||||
):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
self.n_ctx = n_ctx
|
||||
self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
|
||||
self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
|
||||
|
||||
def forward(self, qkv):
|
||||
bs, n_ctx, width = qkv.shape
|
||||
attn_ch = width // self.heads // 3
|
||||
qkv = qkv.view(bs, n_ctx, self.heads, -1)
|
||||
q, k, v = torch.split(qkv, attn_ch, dim=-1)
|
||||
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
|
||||
out = scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
|
||||
return out
|
||||
|
||||
|
||||
class MultiheadAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
n_ctx: int,
|
||||
width: int,
|
||||
heads: int,
|
||||
qkv_bias: bool,
|
||||
norm_layer=nn.LayerNorm,
|
||||
qk_norm: bool = False,
|
||||
drop_path_rate: float = 0.0
|
||||
):
|
||||
super().__init__()
|
||||
self.n_ctx = n_ctx
|
||||
self.width = width
|
||||
self.heads = heads
|
||||
self.c_qkv = nn.Linear(width, width * 3, bias=qkv_bias)
|
||||
self.c_proj = nn.Linear(width, width)
|
||||
self.attention = QKVMultiheadAttention(
|
||||
heads=heads,
|
||||
n_ctx=n_ctx,
|
||||
width=width,
|
||||
norm_layer=norm_layer,
|
||||
qk_norm=qk_norm
|
||||
)
|
||||
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.c_qkv(x)
|
||||
x = self.attention(x)
|
||||
x = self.drop_path(self.c_proj(x))
|
||||
return x
|
||||
|
||||
|
||||
class ResidualAttentionBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
n_ctx: int,
|
||||
width: int,
|
||||
heads: int,
|
||||
qkv_bias: bool = True,
|
||||
norm_layer=nn.LayerNorm,
|
||||
qk_norm: bool = False,
|
||||
drop_path_rate: float = 0.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.attn = MultiheadAttention(
|
||||
n_ctx=n_ctx,
|
||||
width=width,
|
||||
heads=heads,
|
||||
qkv_bias=qkv_bias,
|
||||
norm_layer=norm_layer,
|
||||
qk_norm=qk_norm,
|
||||
drop_path_rate=drop_path_rate
|
||||
)
|
||||
self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
|
||||
self.mlp = MLP(width=width, drop_path_rate=drop_path_rate)
|
||||
self.ln_2 = norm_layer(width, elementwise_affine=True, eps=1e-6)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
x = x + self.attn(self.ln_1(x))
|
||||
x = x + self.mlp(self.ln_2(x))
|
||||
return x
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
n_ctx: int,
|
||||
width: int,
|
||||
layers: int,
|
||||
heads: int,
|
||||
qkv_bias: bool = True,
|
||||
norm_layer=nn.LayerNorm,
|
||||
qk_norm: bool = False,
|
||||
drop_path_rate: float = 0.0
|
||||
):
|
||||
super().__init__()
|
||||
self.n_ctx = n_ctx
|
||||
self.width = width
|
||||
self.layers = layers
|
||||
self.resblocks = nn.ModuleList(
|
||||
[
|
||||
ResidualAttentionBlock(
|
||||
n_ctx=n_ctx,
|
||||
width=width,
|
||||
heads=heads,
|
||||
qkv_bias=qkv_bias,
|
||||
norm_layer=norm_layer,
|
||||
qk_norm=qk_norm,
|
||||
drop_path_rate=drop_path_rate
|
||||
)
|
||||
for _ in range(layers)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
for block in self.resblocks:
|
||||
x = block(x)
|
||||
return x
|
||||
|
||||
|
||||
class CrossAttentionDecoder(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
num_latents: int,
|
||||
out_channels: int,
|
||||
fourier_embedder: FourierEmbedder,
|
||||
width: int,
|
||||
heads: int,
|
||||
mlp_expand_ratio: int = 4,
|
||||
downsample_ratio: int = 1,
|
||||
enable_ln_post: bool = True,
|
||||
qkv_bias: bool = True,
|
||||
qk_norm: bool = False,
|
||||
label_type: str = "binary"
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.enable_ln_post = enable_ln_post
|
||||
self.fourier_embedder = fourier_embedder
|
||||
self.downsample_ratio = downsample_ratio
|
||||
self.query_proj = nn.Linear(self.fourier_embedder.out_dim, width)
|
||||
if self.downsample_ratio != 1:
|
||||
self.latents_proj = nn.Linear(width * downsample_ratio, width)
|
||||
if not self.enable_ln_post:
|
||||
qk_norm = False
|
||||
self.cross_attn_decoder = ResidualCrossAttentionBlock(
|
||||
n_data=num_latents,
|
||||
width=width,
|
||||
mlp_expand_ratio=mlp_expand_ratio,
|
||||
heads=heads,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_norm=qk_norm
|
||||
)
|
||||
|
||||
if self.enable_ln_post:
|
||||
self.ln_post = nn.LayerNorm(width)
|
||||
self.output_proj = nn.Linear(width, out_channels)
|
||||
self.label_type = label_type
|
||||
self.count = 0
|
||||
|
||||
def set_cross_attention_processor(self, processor):
|
||||
self.cross_attn_decoder.attn.attention.attn_processor = processor
|
||||
|
||||
def set_default_cross_attention_processor(self):
|
||||
self.cross_attn_decoder.attn.attention.attn_processor = CrossAttentionProcessor()  # instantiate the default processor
|
||||
|
||||
def forward(self, queries=None, query_embeddings=None, latents=None):
|
||||
if query_embeddings is None:
|
||||
query_embeddings = self.query_proj(self.fourier_embedder(queries).to(latents.dtype))
|
||||
self.count += query_embeddings.shape[1]
|
||||
if self.downsample_ratio != 1:
|
||||
latents = self.latents_proj(latents)
|
||||
x = self.cross_attn_decoder(query_embeddings, latents)
|
||||
if self.enable_ln_post:
|
||||
x = self.ln_post(x)
|
||||
occ = self.output_proj(x)
|
||||
return occ
|
||||
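As a quick sanity check of the FourierEmbedder sizes documented above, a short illustrative snippet (assumes this repo is importable; not part of the commit):

import torch
from hy3dgen.shapegen.models.autoencoders.attention_blocks import FourierEmbedder

embedder = FourierEmbedder(num_freqs=8, input_dim=3, include_input=True)
points = torch.rand(2, 10000, 3)  # a batch of query points
embedded = embedder(points)
# out_dim = input_dim * (2 * num_freqs + 1) = 3 * 17 = 51 when include_input is True
assert embedder.out_dim == 51
assert embedded.shape == (2, 10000, embedder.out_dim)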
96
hy3dgen/shapegen/models/autoencoders/attention_processors.py
Normal file
@@ -0,0 +1,96 @@
|
||||
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
||||
# except for the third-party components listed below.
|
||||
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
||||
# in the respective licenses of these third-party components.
|
||||
# Users must comply with all terms and conditions of original licenses of these third-party
|
||||
# components and must ensure that the usage of the third party components adheres to
|
||||
# all relevant laws and regulations.
|
||||
|
||||
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
||||
# their software and algorithms, including trained model weights, parameters (including
|
||||
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
||||
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
||||
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
||||
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
scaled_dot_product_attention = F.scaled_dot_product_attention
|
||||
if os.environ.get('CA_USE_SAGEATTN', '0') == '1':
|
||||
try:
|
||||
from sageattention import sageattn
|
||||
except ImportError:
|
||||
raise ImportError('Please install the "sageattention" package to use CA_USE_SAGEATTN.')
|
||||
scaled_dot_product_attention = sageattn
|
||||
|
||||
|
||||
class CrossAttentionProcessor:
|
||||
def __call__(self, attn, q, k, v):
|
||||
out = scaled_dot_product_attention(q, k, v)
|
||||
return out
|
||||
|
||||
|
||||
class FlashVDMCrossAttentionProcessor:
|
||||
def __init__(self, topk=None):
|
||||
self.topk = topk
|
||||
|
||||
def __call__(self, attn, q, k, v):
|
||||
if k.shape[-2] == 3072:
|
||||
topk = 1024
|
||||
elif k.shape[-2] == 512:
|
||||
topk = 256
|
||||
else:
|
||||
topk = k.shape[-2] // 3
|
||||
|
||||
if self.topk is True:
|
||||
q1 = q[:, :, ::100, :]
|
||||
sim = q1 @ k.transpose(-1, -2)
|
||||
sim = torch.mean(sim, -2)
|
||||
topk_ind = torch.topk(sim, dim=-1, k=topk).indices.squeeze(-2).unsqueeze(-1)
|
||||
topk_ind = topk_ind.expand(-1, -1, -1, v.shape[-1])
|
||||
v0 = torch.gather(v, dim=-2, index=topk_ind)
|
||||
k0 = torch.gather(k, dim=-2, index=topk_ind)
|
||||
out = scaled_dot_product_attention(q, k0, v0)
|
||||
elif self.topk is False:
|
||||
out = scaled_dot_product_attention(q, k, v)
|
||||
else:
|
||||
idx, counts = self.topk
|
||||
start = 0
|
||||
outs = []
|
||||
for grid_coord, count in zip(idx, counts):
|
||||
end = start + count
|
||||
q_chunk = q[:, :, start:end, :]
|
||||
k0, v0 = self.select_topkv(q_chunk, k, v, topk)
|
||||
out = scaled_dot_product_attention(q_chunk, k0, v0)
|
||||
outs.append(out)
|
||||
start += count
|
||||
out = torch.cat(outs, dim=-2)
|
||||
self.topk = False
|
||||
return out
|
||||
|
||||
def select_topkv(self, q_chunk, k, v, topk):
|
||||
q1 = q_chunk[:, :, ::50, :]
|
||||
sim = q1 @ k.transpose(-1, -2)
|
||||
sim = torch.mean(sim, -2)
|
||||
topk_ind = torch.topk(sim, dim=-1, k=topk).indices.squeeze(-2).unsqueeze(-1)
|
||||
topk_ind = topk_ind.expand(-1, -1, -1, v.shape[-1])
|
||||
v0 = torch.gather(v, dim=-2, index=topk_ind)
|
||||
k0 = torch.gather(k, dim=-2, index=topk_ind)
|
||||
return k0, v0
|
||||
|
||||
|
||||
class FlashVDMTopMCrossAttentionProcessor(FlashVDMCrossAttentionProcessor):
|
||||
def select_topkv(self, q_chunk, k, v, topk):
|
||||
q1 = q_chunk[:, :, ::30, :]
|
||||
sim = q1 @ k.transpose(-1, -2)
|
||||
# sim = sim.to(torch.float32)
|
||||
sim = sim.softmax(-1)
|
||||
sim = torch.mean(sim, 1)
|
||||
activated_token = torch.where(sim > 1e-6)[2]
|
||||
index = torch.unique(activated_token, return_counts=True)[0].unsqueeze(0).unsqueeze(0).unsqueeze(-1)
|
||||
index = index.expand(-1, v.shape[1], -1, v.shape[-1])
|
||||
v0 = torch.gather(v, dim=-2, index=index)
|
||||
k0 = torch.gather(k, dim=-2, index=index)
|
||||
return k0, v0
|
||||
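The FlashVDM processors cut the cost of cross-attention by scoring the latent tokens with a strided subset of the queries and keeping only the top-k keys/values before scaled dot-product attention. A standalone sketch of that idea (shapes and the stride mirror the code above; an illustration, not the shipped implementation):

import torch
import torch.nn.functional as F

def topk_cross_attention(q, k, v, topk):
    # Probe with every 100th query, rank latent tokens by mean similarity,
    # then attend only over the selected keys/values.
    sim = (q[:, :, ::100, :] @ k.transpose(-1, -2)).mean(-2)    # [B, H, N_kv]
    idx = torch.topk(sim, k=topk, dim=-1).indices.unsqueeze(-1)  # [B, H, topk, 1]
    idx = idx.expand(-1, -1, -1, v.shape[-1])
    k0 = torch.gather(k, dim=-2, index=idx)
    v0 = torch.gather(v, dim=-2, index=idx)
    return F.scaled_dot_product_attention(q, k0, v0)

q = torch.randn(1, 16, 8000, 64)   # queries: one grid chunk
k = torch.randn(1, 16, 3072, 64)   # latent keys
v = torch.randn(1, 16, 3072, 64)   # latent values
out = topk_cross_attention(q, k, v, topk=1024)  # same shape as full attention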
189
hy3dgen/shapegen/models/autoencoders/model.py
Normal file
@@ -0,0 +1,189 @@
|
||||
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
||||
# except for the third-party components listed below.
|
||||
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
||||
# in the respective licenses of these third-party components.
|
||||
# Users must comply with all terms and conditions of original licenses of these third-party
|
||||
# components and must ensure that the usage of the third party components adheres to
|
||||
# all relevant laws and regulations.
|
||||
|
||||
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
||||
# their software and algorithms, including trained model weights, parameters (including
|
||||
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
||||
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
||||
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
||||
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import yaml
|
||||
|
||||
from .attention_blocks import FourierEmbedder, Transformer, CrossAttentionDecoder
|
||||
from .surface_extractors import MCSurfaceExtractor, SurfaceExtractors
|
||||
from .volume_decoders import VanillaVolumeDecoder, FlashVDMVolumeDecoding, HierarchicalVolumeDecoding
|
||||
from ...utils import logger, synchronize_timer, smart_load_model
|
||||
|
||||
|
||||
class VectsetVAE(nn.Module):
|
||||
|
||||
@classmethod
|
||||
@synchronize_timer('VectsetVAE Model Loading')
|
||||
def from_single_file(
|
||||
cls,
|
||||
ckpt_path,
|
||||
config_path,
|
||||
device='cuda',
|
||||
dtype=torch.float16,
|
||||
use_safetensors=None,
|
||||
**kwargs,
|
||||
):
|
||||
# load config
|
||||
with open(config_path, 'r') as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
# load ckpt
|
||||
if use_safetensors:
|
||||
ckpt_path = ckpt_path.replace('.ckpt', '.safetensors')
|
||||
if not os.path.exists(ckpt_path):
|
||||
raise FileNotFoundError(f"Model file {ckpt_path} not found")
|
||||
|
||||
logger.info(f"Loading model from {ckpt_path}")
|
||||
if use_safetensors:
|
||||
import safetensors.torch
|
||||
ckpt = safetensors.torch.load_file(ckpt_path, device='cpu')
|
||||
else:
|
||||
ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True)
|
||||
|
||||
model_kwargs = config['params']
|
||||
model_kwargs.update(kwargs)
|
||||
|
||||
model = cls(**model_kwargs)
|
||||
model.load_state_dict(ckpt)
|
||||
model.to(device=device, dtype=dtype)
|
||||
return model
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
model_path,
|
||||
device='cuda',
|
||||
dtype=torch.float16,
|
||||
use_safetensors=True,
|
||||
variant='fp16',
|
||||
subfolder='hunyuan3d-vae-v2-0',
|
||||
**kwargs,
|
||||
):
|
||||
config_path, ckpt_path = smart_load_model(
|
||||
model_path,
|
||||
subfolder=subfolder,
|
||||
use_safetensors=use_safetensors,
|
||||
variant=variant
|
||||
)
|
||||
|
||||
return cls.from_single_file(
|
||||
ckpt_path,
|
||||
config_path,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
use_safetensors=use_safetensors,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
volume_decoder=None,
|
||||
surface_extractor=None
|
||||
):
|
||||
super().__init__()
|
||||
if volume_decoder is None:
|
||||
volume_decoder = VanillaVolumeDecoder()
|
||||
if surface_extractor is None:
|
||||
surface_extractor = MCSurfaceExtractor()
|
||||
self.volume_decoder = volume_decoder
|
||||
self.surface_extractor = surface_extractor
|
||||
|
||||
def latents2mesh(self, latents: torch.FloatTensor, **kwargs):
|
||||
with synchronize_timer('Volume decoding'):
|
||||
grid_logits = self.volume_decoder(latents, self.geo_decoder, **kwargs)
|
||||
with synchronize_timer('Surface extraction'):
|
||||
outputs = self.surface_extractor(grid_logits, **kwargs)
|
||||
return outputs
|
||||
|
||||
def enable_flashvdm_decoder(
|
||||
self,
|
||||
enabled: bool = True,
|
||||
adaptive_kv_selection=True,
|
||||
topk_mode='mean',
|
||||
mc_algo='dmc',
|
||||
):
|
||||
if enabled:
|
||||
if adaptive_kv_selection:
|
||||
self.volume_decoder = FlashVDMVolumeDecoding(topk_mode)
|
||||
else:
|
||||
self.volume_decoder = HierarchicalVolumeDecoding()
|
||||
if mc_algo not in SurfaceExtractors.keys():
|
||||
raise ValueError(f'Unsupported mc_algo {mc_algo}, available: {list(SurfaceExtractors.keys())}')
|
||||
self.surface_extractor = SurfaceExtractors[mc_algo]()
|
||||
else:
|
||||
self.volume_decoder = VanillaVolumeDecoder()
|
||||
self.surface_extractor = MCSurfaceExtractor()
|
||||
|
||||
|
||||
class ShapeVAE(VectsetVAE):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
num_latents: int,
|
||||
embed_dim: int,
|
||||
width: int,
|
||||
heads: int,
|
||||
num_decoder_layers: int,
|
||||
geo_decoder_downsample_ratio: int = 1,
|
||||
geo_decoder_mlp_expand_ratio: int = 4,
|
||||
geo_decoder_ln_post: bool = True,
|
||||
num_freqs: int = 8,
|
||||
include_pi: bool = True,
|
||||
qkv_bias: bool = True,
|
||||
qk_norm: bool = False,
|
||||
label_type: str = "binary",
|
||||
drop_path_rate: float = 0.0,
|
||||
scale_factor: float = 1.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.geo_decoder_ln_post = geo_decoder_ln_post
|
||||
|
||||
self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
|
||||
|
||||
self.post_kl = nn.Linear(embed_dim, width)
|
||||
|
||||
self.transformer = Transformer(
|
||||
n_ctx=num_latents,
|
||||
width=width,
|
||||
layers=num_decoder_layers,
|
||||
heads=heads,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_norm=qk_norm,
|
||||
drop_path_rate=drop_path_rate
|
||||
)
|
||||
|
||||
self.geo_decoder = CrossAttentionDecoder(
|
||||
fourier_embedder=self.fourier_embedder,
|
||||
out_channels=1,
|
||||
num_latents=num_latents,
|
||||
mlp_expand_ratio=geo_decoder_mlp_expand_ratio,
|
||||
downsample_ratio=geo_decoder_downsample_ratio,
|
||||
enable_ln_post=self.geo_decoder_ln_post,
|
||||
width=width // geo_decoder_downsample_ratio,
|
||||
heads=heads // geo_decoder_downsample_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_norm=qk_norm,
|
||||
label_type=label_type,
|
||||
)
|
||||
|
||||
self.scale_factor = scale_factor
|
||||
self.latent_shape = (num_latents, embed_dim)
|
||||
|
||||
def forward(self, latents):
|
||||
latents = self.post_kl(latents)
|
||||
latents = self.transformer(latents)
|
||||
return latents
|
||||
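A hedged end-to-end sketch of how the new classes fit together: load a ShapeVAE, switch on the FlashVDM decoder, and turn latents into meshes. The repo id, subfolder and the random latents are placeholders, not values taken from this commit:

import torch
from hy3dgen.shapegen.models.autoencoders import ShapeVAE

vae = ShapeVAE.from_pretrained('tencent/Hunyuan3D-2', subfolder='hunyuan3d-vae-v2-0')

# Swap in the FlashVDM volume decoder and a matching surface extractor.
vae.enable_flashvdm_decoder(enabled=True, adaptive_kv_selection=True,
                            topk_mode='mean', mc_algo='dmc')

latents = torch.randn(1, *vae.latent_shape, device='cuda', dtype=torch.float16)
latents = vae(latents)  # post_kl projection + transformer
meshes = vae.latents2mesh(latents, octree_resolution=384, mc_level=0.0,
                          bounds=1.01, num_chunks=8000, enable_pbar=True)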
100
hy3dgen/shapegen/models/autoencoders/surface_extractors.py
Normal file
@@ -0,0 +1,100 @@
|
||||
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
||||
# except for the third-party components listed below.
|
||||
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
||||
# in the respective licenses of these third-party components.
|
||||
# Users must comply with all terms and conditions of original licenses of these third-party
|
||||
# components and must ensure that the usage of the third party components adheres to
|
||||
# all relevant laws and regulations.
|
||||
|
||||
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
||||
# their software and algorithms, including trained model weights, parameters (including
|
||||
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
||||
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
||||
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
||||
|
||||
from typing import Union, Tuple, List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from skimage import measure
|
||||
|
||||
|
||||
class Latent2MeshOutput:
|
||||
|
||||
def __init__(self, mesh_v=None, mesh_f=None):
|
||||
self.mesh_v = mesh_v
|
||||
self.mesh_f = mesh_f
|
||||
|
||||
|
||||
def center_vertices(vertices):
|
||||
"""Translate the vertices so that bounding box is centered at zero."""
|
||||
vert_min = vertices.min(dim=0)[0]
|
||||
vert_max = vertices.max(dim=0)[0]
|
||||
vert_center = 0.5 * (vert_min + vert_max)
|
||||
return vertices - vert_center
|
||||
|
||||
|
||||
class SurfaceExtractor:
|
||||
def _compute_box_stat(self, bounds: Union[Tuple[float], List[float], float], octree_resolution: int):
|
||||
if isinstance(bounds, float):
|
||||
bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
|
||||
|
||||
bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
|
||||
bbox_size = bbox_max - bbox_min
|
||||
grid_size = [int(octree_resolution) + 1, int(octree_resolution) + 1, int(octree_resolution) + 1]
|
||||
return grid_size, bbox_min, bbox_size
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def __call__(self, grid_logits, **kwargs):
|
||||
outputs = []
|
||||
for i in range(grid_logits.shape[0]):
|
||||
try:
|
||||
vertices, faces = self.run(grid_logits[i], **kwargs)
|
||||
vertices = vertices.astype(np.float32)
|
||||
faces = np.ascontiguousarray(faces)
|
||||
outputs.append(Latent2MeshOutput(mesh_v=vertices, mesh_f=faces))
|
||||
|
||||
except Exception:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
outputs.append(None)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class MCSurfaceExtractor(SurfaceExtractor):
|
||||
def run(self, grid_logit, *, mc_level, bounds, octree_resolution, **kwargs):
|
||||
vertices, faces, normals, _ = measure.marching_cubes(
|
||||
grid_logit.cpu().numpy(),
|
||||
mc_level,
|
||||
method="lewiner"
|
||||
)
|
||||
grid_size, bbox_min, bbox_size = self._compute_box_stat(bounds, octree_resolution)
|
||||
vertices = vertices / grid_size * bbox_size + bbox_min
|
||||
return vertices, faces
|
||||
|
||||
|
||||
class DMCSurfaceExtractor(SurfaceExtractor):
|
||||
def run(self, grid_logit, *, octree_resolution, **kwargs):
|
||||
device = grid_logit.device
|
||||
if not hasattr(self, 'dmc'):
|
||||
try:
|
||||
from diso import DiffDMC
|
||||
except ImportError:
|
||||
raise ImportError("Please install diso via `pip install diso`, or set mc_algo to 'mc'")
|
||||
self.dmc = DiffDMC(dtype=torch.float32).to(device)
|
||||
sdf = -grid_logit / octree_resolution
|
||||
sdf = sdf.to(torch.float32).contiguous()
|
||||
verts, faces = self.dmc(sdf, deform=None, return_quads=False, normalize=True)
|
||||
verts = center_vertices(verts)
|
||||
vertices = verts.detach().cpu().numpy()
|
||||
faces = faces.detach().cpu().numpy()[:, ::-1]
|
||||
return vertices, faces
|
||||
|
||||
|
||||
SurfaceExtractors = {
|
||||
'mc': MCSurfaceExtractor,
|
||||
'dmc': DMCSurfaceExtractor,
|
||||
}
|
||||
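Whichever extractor is chosen, the outputs are plain Latent2MeshOutput objects (or None when extraction fails). A small, hedged helper showing a typical conversion to trimesh meshes (trimesh is already used elsewhere in this wrapper; the helper itself is illustrative):

import trimesh
from hy3dgen.shapegen.models.autoencoders import SurfaceExtractors

def outputs_to_trimesh(outputs):
    # Convert Latent2MeshOutput entries (or None) into trimesh.Trimesh objects.
    meshes = []
    for out in outputs:
        meshes.append(None if out is None else
                      trimesh.Trimesh(vertices=out.mesh_v, faces=out.mesh_f))
    return meshes

extractor = SurfaceExtractors['mc']()  # or 'dmc' when the diso package is installed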
435
hy3dgen/shapegen/models/autoencoders/volume_decoders.py
Normal file
@@ -0,0 +1,435 @@
|
||||
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
||||
# except for the third-party components listed below.
|
||||
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
||||
# in the repsective licenses of these third-party components.
|
||||
# Users must comply with all terms and conditions of original licenses of these third-party
|
||||
# components and must ensure that the usage of the third party components adheres to
|
||||
# all relevant laws and regulations.
|
||||
|
||||
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
||||
# their software and algorithms, including trained model weights, parameters (including
|
||||
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
||||
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
||||
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
||||
|
||||
from typing import Union, Tuple, List, Callable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import repeat
|
||||
from tqdm import tqdm
|
||||
|
||||
from .attention_blocks import CrossAttentionDecoder
|
||||
from .attention_processors import FlashVDMCrossAttentionProcessor, FlashVDMTopMCrossAttentionProcessor
|
||||
from ...utils import logger
|
||||
|
||||
|
||||
def extract_near_surface_volume_fn(input_tensor: torch.Tensor, alpha: float):
|
||||
device = input_tensor.device
|
||||
D = input_tensor.shape[0]
|
||||
signed_val = 0.0
|
||||
|
||||
# Add the offset and handle invalid values
|
||||
val = input_tensor + alpha
|
||||
valid_mask = val > -9000  # assume -9000 marks an invalid value
|
||||
|
||||
# Improved neighbor lookup (keeps dimensions consistent)
|
||||
def get_neighbor(t, shift, axis):
|
||||
"""根据指定轴进行位移并保持维度一致"""
|
||||
if shift == 0:
|
||||
return t.clone()
|
||||
|
||||
# Determine the padding axis (the [D, D, D] input corresponds to the z, y, x axes)
|
||||
pad_dims = [0, 0, 0, 0, 0, 0]  # format: [x_before, x_after, y_before, y_after, z_before, z_after]
|
||||
|
||||
# Set the padding according to the axis
|
||||
if axis == 0:  # x axis (last dimension)
|
||||
pad_idx = 0 if shift > 0 else 1
|
||||
pad_dims[pad_idx] = abs(shift)
|
||||
elif axis == 1:  # y axis (middle dimension)
|
||||
pad_idx = 2 if shift > 0 else 3
|
||||
pad_dims[pad_idx] = abs(shift)
|
||||
elif axis == 2:  # z axis (first dimension)
|
||||
pad_idx = 4 if shift > 0 else 5
|
||||
pad_dims[pad_idx] = abs(shift)
|
||||
|
||||
# Apply the padding (add batch and channel dims so F.pad accepts the tensor)
|
||||
padded = F.pad(t.unsqueeze(0).unsqueeze(0), pad_dims[::-1], mode='replicate')  # reverse the order to match F.pad's convention
|
||||
|
||||
# Build the dynamic slice indices
|
||||
slice_dims = [slice(None)] * 3  # start with full slices
|
||||
if axis == 0:  # x axis (dim=2)
|
||||
if shift > 0:
|
||||
slice_dims[0] = slice(shift, None)
|
||||
else:
|
||||
slice_dims[0] = slice(None, shift)
|
||||
elif axis == 1:  # y axis (dim=1)
|
||||
if shift > 0:
|
||||
slice_dims[1] = slice(shift, None)
|
||||
else:
|
||||
slice_dims[1] = slice(None, shift)
|
||||
elif axis == 2:  # z axis (dim=0)
|
||||
if shift > 0:
|
||||
slice_dims[2] = slice(shift, None)
|
||||
else:
|
||||
slice_dims[2] = slice(None, shift)
|
||||
|
||||
# Apply the slice and restore the original dimensions
|
||||
padded = padded.squeeze(0).squeeze(0)
|
||||
sliced = padded[slice_dims]
|
||||
return sliced
|
||||
|
||||
# Gather neighbors in every direction (dimensions stay consistent)
|
||||
left = get_neighbor(val, 1, axis=0)  # x direction
|
||||
right = get_neighbor(val, -1, axis=0)
|
||||
back = get_neighbor(val, 1, axis=1)  # y direction
|
||||
front = get_neighbor(val, -1, axis=1)
|
||||
down = get_neighbor(val, 1, axis=2)  # z direction
|
||||
up = get_neighbor(val, -1, axis=2)
|
||||
|
||||
# Handle invalid boundary values (torch.where keeps dimensions consistent)
|
||||
def safe_where(neighbor):
|
||||
return torch.where(neighbor > -9000, neighbor, val)
|
||||
|
||||
left = safe_where(left)
|
||||
right = safe_where(right)
|
||||
back = safe_where(back)
|
||||
front = safe_where(front)
|
||||
down = safe_where(down)
|
||||
up = safe_where(up)
|
||||
|
||||
# Compute sign agreement (cast to float32 for precision)
|
||||
sign = torch.sign(val.to(torch.float32))
|
||||
neighbors_sign = torch.stack([
|
||||
torch.sign(left.to(torch.float32)),
|
||||
torch.sign(right.to(torch.float32)),
|
||||
torch.sign(back.to(torch.float32)),
|
||||
torch.sign(front.to(torch.float32)),
|
||||
torch.sign(down.to(torch.float32)),
|
||||
torch.sign(up.to(torch.float32))
|
||||
], dim=0)
|
||||
|
||||
# Check whether all neighbor signs match the center sign
|
||||
same_sign = torch.all(neighbors_sign == sign, dim=0)
|
||||
|
||||
# Build the final mask
|
||||
mask = (~same_sign).to(torch.int32)
|
||||
return mask * valid_mask.to(torch.int32)
|
||||
|
||||
|
||||
def generate_dense_grid_points(
|
||||
bbox_min: np.ndarray,
|
||||
bbox_max: np.ndarray,
|
||||
octree_resolution: int,
|
||||
indexing: str = "ij",
|
||||
):
|
||||
length = bbox_max - bbox_min
|
||||
num_cells = octree_resolution
|
||||
|
||||
x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
|
||||
y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
|
||||
z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
|
||||
[xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
|
||||
xyz = np.stack((xs, ys, zs), axis=-1)
|
||||
grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]
|
||||
|
||||
return xyz, grid_size, length
|
||||
|
||||
|
||||
class VanillaVolumeDecoder:
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
latents: torch.FloatTensor,
|
||||
geo_decoder: Callable,
|
||||
bounds: Union[Tuple[float], List[float], float] = 1.01,
|
||||
num_chunks: int = 10000,
|
||||
octree_resolution: int = None,
|
||||
enable_pbar: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
device = latents.device
|
||||
dtype = latents.dtype
|
||||
batch_size = latents.shape[0]
|
||||
|
||||
# 1. generate query points
|
||||
if isinstance(bounds, float):
|
||||
bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
|
||||
|
||||
bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
|
||||
xyz_samples, grid_size, length = generate_dense_grid_points(
|
||||
bbox_min=bbox_min,
|
||||
bbox_max=bbox_max,
|
||||
octree_resolution=octree_resolution,
|
||||
indexing="ij"
|
||||
)
|
||||
xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3)
|
||||
|
||||
# 2. latents to 3d volume
|
||||
batch_logits = []
|
||||
for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), desc=f"Volume Decoding",
|
||||
disable=not enable_pbar):
|
||||
chunk_queries = xyz_samples[start: start + num_chunks, :]
|
||||
chunk_queries = repeat(chunk_queries, "p c -> b p c", b=batch_size)
|
||||
logits = geo_decoder(queries=chunk_queries, latents=latents)
|
||||
batch_logits.append(logits)
|
||||
|
||||
grid_logits = torch.cat(batch_logits, dim=1)
|
||||
grid_logits = grid_logits.view((batch_size, *grid_size)).float()
|
||||
|
||||
return grid_logits
|
||||
|
||||
|
||||
class HierarchicalVolumeDecoding:
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
latents: torch.FloatTensor,
|
||||
geo_decoder: Callable,
|
||||
bounds: Union[Tuple[float], List[float], float] = 1.01,
|
||||
num_chunks: int = 10000,
|
||||
mc_level: float = 0.0,
|
||||
octree_resolution: int = None,
|
||||
min_resolution: int = 63,
|
||||
enable_pbar: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
device = latents.device
|
||||
dtype = latents.dtype
|
||||
|
||||
resolutions = []
|
||||
if octree_resolution < min_resolution:
|
||||
resolutions.append(octree_resolution)
|
||||
while octree_resolution >= min_resolution:
|
||||
resolutions.append(octree_resolution)
|
||||
octree_resolution = octree_resolution // 2
|
||||
resolutions.reverse()
|
||||
|
||||
# 1. generate query points
|
||||
if isinstance(bounds, float):
|
||||
bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
|
||||
bbox_min = np.array(bounds[0:3])
|
||||
bbox_max = np.array(bounds[3:6])
|
||||
bbox_size = bbox_max - bbox_min
|
||||
|
||||
xyz_samples, grid_size, length = generate_dense_grid_points(
|
||||
bbox_min=bbox_min,
|
||||
bbox_max=bbox_max,
|
||||
octree_resolution=resolutions[0],
|
||||
indexing="ij"
|
||||
)
|
||||
|
||||
dilate = nn.Conv3d(1, 1, 3, padding=1, bias=False, device=device, dtype=dtype)
|
||||
dilate.weight = torch.nn.Parameter(torch.ones(dilate.weight.shape, dtype=dtype, device=device))
|
||||
|
||||
grid_size = np.array(grid_size)
|
||||
xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3)
|
||||
|
||||
# 2. latents to 3d volume
|
||||
batch_logits = []
|
||||
batch_size = latents.shape[0]
|
||||
for start in tqdm(range(0, xyz_samples.shape[0], num_chunks),
|
||||
desc=f"Hierarchical Volume Decoding [r{resolutions[0] + 1}]"):
|
||||
queries = xyz_samples[start: start + num_chunks, :]
|
||||
batch_queries = repeat(queries, "p c -> b p c", b=batch_size)
|
||||
logits = geo_decoder(queries=batch_queries, latents=latents)
|
||||
batch_logits.append(logits)
|
||||
|
||||
grid_logits = torch.cat(batch_logits, dim=1).view((batch_size, grid_size[0], grid_size[1], grid_size[2]))
|
||||
|
||||
for octree_depth_now in resolutions[1:]:
|
||||
grid_size = np.array([octree_depth_now + 1] * 3)
|
||||
resolution = bbox_size / octree_depth_now
|
||||
next_index = torch.zeros(tuple(grid_size), dtype=dtype, device=device)
|
||||
next_logits = torch.full(next_index.shape, -10000., dtype=dtype, device=device)
|
||||
curr_points = extract_near_surface_volume_fn(grid_logits.squeeze(0), mc_level)
|
||||
curr_points += grid_logits.squeeze(0).abs() < 0.95
|
||||
|
||||
if octree_depth_now == resolutions[-1]:
|
||||
expand_num = 0
|
||||
else:
|
||||
expand_num = 1
|
||||
for i in range(expand_num):
|
||||
curr_points = dilate(curr_points.unsqueeze(0).to(dtype)).squeeze(0)
|
||||
(cidx_x, cidx_y, cidx_z) = torch.where(curr_points > 0)
|
||||
next_index[cidx_x * 2, cidx_y * 2, cidx_z * 2] = 1
|
||||
for i in range(2 - expand_num):
|
||||
next_index = dilate(next_index.unsqueeze(0)).squeeze(0)
|
||||
nidx = torch.where(next_index > 0)
|
||||
|
||||
next_points = torch.stack(nidx, dim=1)
|
||||
next_points = (next_points * torch.tensor(resolution, dtype=next_points.dtype, device=device) +
|
||||
torch.tensor(bbox_min, dtype=next_points.dtype, device=device))
|
||||
batch_logits = []
|
||||
for start in tqdm(range(0, next_points.shape[0], num_chunks),
|
||||
desc=f"Hierarchical Volume Decoding [r{octree_depth_now + 1}]"):
|
||||
queries = next_points[start: start + num_chunks, :]
|
||||
batch_queries = repeat(queries, "p c -> b p c", b=batch_size)
|
||||
logits = geo_decoder(queries=batch_queries.to(latents.dtype), latents=latents)
|
||||
batch_logits.append(logits)
|
||||
grid_logits = torch.cat(batch_logits, dim=1)
|
||||
next_logits[nidx] = grid_logits[0, ..., 0]
|
||||
grid_logits = next_logits.unsqueeze(0)
|
||||
grid_logits[grid_logits == -10000.] = float('nan')
|
||||
|
||||
return grid_logits
|
||||
|
||||
|
||||
class FlashVDMVolumeDecoding:
|
||||
def __init__(self, topk_mode='mean'):
|
||||
if topk_mode not in ['mean', 'merge']:
|
||||
raise ValueError(f'Unsupported topk_mode {topk_mode}, available: {["mean", "merge"]}')
|
||||
|
||||
if topk_mode == 'mean':
|
||||
self.processor = FlashVDMCrossAttentionProcessor()
|
||||
else:
|
||||
self.processor = FlashVDMTopMCrossAttentionProcessor()
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
latents: torch.FloatTensor,
|
||||
geo_decoder: CrossAttentionDecoder,
|
||||
bounds: Union[Tuple[float], List[float], float] = 1.01,
|
||||
num_chunks: int = 10000,
|
||||
mc_level: float = 0.0,
|
||||
octree_resolution: int = None,
|
||||
min_resolution: int = 63,
|
||||
mini_grid_num: int = 4,
|
||||
enable_pbar: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
processor = self.processor
|
||||
geo_decoder.set_cross_attention_processor(processor)
|
||||
|
||||
device = latents.device
|
||||
dtype = latents.dtype
|
||||
|
||||
resolutions = []
|
||||
if octree_resolution < min_resolution:
|
||||
resolutions.append(octree_resolution)
|
||||
while octree_resolution >= min_resolution:
|
||||
resolutions.append(octree_resolution)
|
||||
octree_resolution = octree_resolution // 2
|
||||
resolutions.reverse()
|
||||
resolutions[0] = round(resolutions[0] / mini_grid_num) * mini_grid_num - 1
|
||||
for i, resolution in enumerate(resolutions[1:]):
|
||||
resolutions[i + 1] = resolutions[0] * 2 ** (i + 1)
|
||||
|
||||
logger.info(f"FlashVDMVolumeDecoding Resolution: {resolutions}")
|
||||
|
||||
# 1. generate query points
|
||||
if isinstance(bounds, float):
|
||||
bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
|
||||
bbox_min = np.array(bounds[0:3])
|
||||
bbox_max = np.array(bounds[3:6])
|
||||
bbox_size = bbox_max - bbox_min
|
||||
|
||||
xyz_samples, grid_size, length = generate_dense_grid_points(
|
||||
bbox_min=bbox_min,
|
||||
bbox_max=bbox_max,
|
||||
octree_resolution=resolutions[0],
|
||||
indexing="ij"
|
||||
)
|
||||
|
||||
dilate = nn.Conv3d(1, 1, 3, padding=1, bias=False, device=device, dtype=dtype)
|
||||
dilate.weight = torch.nn.Parameter(torch.ones(dilate.weight.shape, dtype=dtype, device=device))
|
||||
|
||||
grid_size = np.array(grid_size)
|
||||
|
||||
# 2. latents to 3d volume
|
||||
xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype)
|
||||
batch_size = latents.shape[0]
|
||||
mini_grid_size = xyz_samples.shape[0] // mini_grid_num
|
||||
xyz_samples = xyz_samples.view(
|
||||
mini_grid_num, mini_grid_size,
|
||||
mini_grid_num, mini_grid_size,
|
||||
mini_grid_num, mini_grid_size, 3
|
||||
).permute(
|
||||
0, 2, 4, 1, 3, 5, 6
|
||||
).reshape(
|
||||
-1, mini_grid_size * mini_grid_size * mini_grid_size, 3
|
||||
)
|
||||
batch_logits = []
|
||||
num_batchs = max(num_chunks // xyz_samples.shape[1], 1)
|
||||
for start in tqdm(range(0, xyz_samples.shape[0], num_batchs),
|
||||
desc=f"FlashVDM Volume Decoding", disable=not enable_pbar):
|
||||
queries = xyz_samples[start: start + num_batchs, :]
|
||||
batch = queries.shape[0]
|
||||
batch_latents = repeat(latents.squeeze(0), "p c -> b p c", b=batch)
|
||||
processor.topk = True
|
||||
logits = geo_decoder(queries=queries, latents=batch_latents)
|
||||
batch_logits.append(logits)
|
||||
grid_logits = torch.cat(batch_logits, dim=0).reshape(
|
||||
mini_grid_num, mini_grid_num, mini_grid_num,
|
||||
mini_grid_size, mini_grid_size,
|
||||
mini_grid_size
|
||||
).permute(0, 3, 1, 4, 2, 5).contiguous().view(
|
||||
(batch_size, grid_size[0], grid_size[1], grid_size[2])
|
||||
)
|
||||
|
||||
for octree_depth_now in resolutions[1:]:
|
||||
grid_size = np.array([octree_depth_now + 1] * 3)
|
||||
resolution = bbox_size / octree_depth_now
|
||||
next_index = torch.zeros(tuple(grid_size), dtype=dtype, device=device)
|
||||
next_logits = torch.full(next_index.shape, -10000., dtype=dtype, device=device)
|
||||
curr_points = extract_near_surface_volume_fn(grid_logits.squeeze(0), mc_level)
|
||||
curr_points += grid_logits.squeeze(0).abs() < 0.95
|
||||
|
||||
if octree_depth_now == resolutions[-1]:
|
||||
expand_num = 0
|
||||
else:
|
||||
expand_num = 1
|
||||
for i in range(expand_num):
|
||||
curr_points = dilate(curr_points.unsqueeze(0).to(dtype)).squeeze(0)
|
||||
(cidx_x, cidx_y, cidx_z) = torch.where(curr_points > 0)
|
||||
|
||||
next_index[cidx_x * 2, cidx_y * 2, cidx_z * 2] = 1
|
||||
for i in range(2 - expand_num):
|
||||
next_index = dilate(next_index.unsqueeze(0)).squeeze(0)
|
||||
nidx = torch.where(next_index > 0)
|
||||
|
||||
next_points = torch.stack(nidx, dim=1)
|
||||
next_points = (next_points * torch.tensor(resolution, dtype=torch.float32, device=device) +
|
||||
torch.tensor(bbox_min, dtype=torch.float32, device=device))
|
||||
|
||||
query_grid_num = 6
|
||||
min_val = next_points.min(axis=0).values
|
||||
max_val = next_points.max(axis=0).values
|
||||
vol_queries_index = (next_points - min_val) / (max_val - min_val) * (query_grid_num - 0.001)
|
||||
index = torch.floor(vol_queries_index).long()
|
||||
index = index[..., 0] * (query_grid_num ** 2) + index[..., 1] * query_grid_num + index[..., 2]
|
||||
index = index.sort()
|
||||
next_points = next_points[index.indices].unsqueeze(0).contiguous()
|
||||
unique_values = torch.unique(index.values, return_counts=True)
|
||||
grid_logits = torch.zeros((next_points.shape[1]), dtype=latents.dtype, device=latents.device)
|
||||
input_grid = [[], []]
|
||||
logits_grid_list = []
|
||||
start_num = 0
|
||||
sum_num = 0
|
||||
for grid_index, count in zip(unique_values[0].cpu().tolist(), unique_values[1].cpu().tolist()):
|
||||
if sum_num + count < num_chunks or sum_num == 0:
|
||||
sum_num += count
|
||||
input_grid[0].append(grid_index)
|
||||
input_grid[1].append(count)
|
||||
else:
|
||||
processor.topk = input_grid
|
||||
logits_grid = geo_decoder(queries=next_points[:, start_num:start_num + sum_num], latents=latents)
|
||||
start_num = start_num + sum_num
|
||||
logits_grid_list.append(logits_grid)
|
||||
input_grid = [[grid_index], [count]]
|
||||
sum_num = count
|
||||
if sum_num > 0:
|
||||
processor.topk = input_grid
|
||||
logits_grid = geo_decoder(queries=next_points[:, start_num:start_num + sum_num], latents=latents)
|
||||
logits_grid_list.append(logits_grid)
|
||||
logits_grid = torch.cat(logits_grid_list, dim=1)
|
||||
grid_logits[index.indices] = logits_grid.squeeze(0).squeeze(-1)
|
||||
next_logits[nidx] = grid_logits
|
||||
grid_logits = next_logits.unsqueeze(0)
|
||||
|
||||
grid_logits[grid_logits == -10000.] = float('nan')
|
||||
|
||||
return grid_logits
|
||||
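Both hierarchical decoders first build a coarse-to-fine schedule by halving octree_resolution until it drops below min_resolution, then refine only near the surface at each step. The schedule logic in isolation (mirrors the loop above; illustrative only):

def resolution_schedule(octree_resolution: int, min_resolution: int = 63):
    # Coarse-to-fine list of grid resolutions; the requested resolution comes last.
    resolutions = []
    if octree_resolution < min_resolution:
        resolutions.append(octree_resolution)
    while octree_resolution >= min_resolution:
        resolutions.append(octree_resolution)
        octree_resolution //= 2
    resolutions.reverse()
    return resolutions

print(resolution_schedule(384))  # [96, 192, 384]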
0
hy3dgen/shapegen/models/vae.py → hy3dgen/shapegen/models/vae_old.py
Executable file → Normal file
@@ -204,7 +204,7 @@ class Hunyuan3DDiTPipeline:
|
||||
config['model']['params']['guidance_embed'] = True
|
||||
config['conditioner']['params']['main_image_encoder']['kwargs']['has_guidance_embed'] = True
|
||||
config['model']['params']['attention_mode'] = attention_mode
|
||||
config['vae']['params']['attention_mode'] = attention_mode
|
||||
#config['vae']['params']['attention_mode'] = attention_mode
|
||||
|
||||
if cublas_ops:
|
||||
config['vae']['params']['cublas_ops'] = True
|
||||
|
||||
@@ -1,13 +1,3 @@
|
||||
# Open Source Model Licensed under the Apache License Version 2.0
|
||||
# and Other Licenses of the Third-Party Components therein:
|
||||
# The below Model in this distribution may have been modified by THL A29 Limited
|
||||
# ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
|
||||
|
||||
# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
|
||||
# The below software and/or models in this distribution may have been
|
||||
# modified by THL A29 Limited ("Tencent Modifications").
|
||||
# All Tencent Modifications are Copyright (C) THL A29 Limited.
|
||||
|
||||
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
||||
# except for the third-party components listed below.
|
||||
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
||||
@@ -25,11 +15,12 @@
|
||||
import tempfile
|
||||
import os
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import pymeshlab
|
||||
import trimesh
|
||||
|
||||
from .models.vae import Latent2MeshOutput
|
||||
from .models.autoencoders import Latent2MeshOutput
|
||||
|
||||
import folder_paths
|
||||
|
||||
@@ -44,6 +35,9 @@ def load_mesh(path):
|
||||
|
||||
|
||||
def reduce_face(mesh: pymeshlab.MeshSet, max_facenum: int = 200000):
|
||||
if max_facenum > mesh.current_mesh().face_number():
|
||||
return mesh
|
||||
|
||||
mesh.apply_filter(
|
||||
"meshing_decimation_quadric_edge_collapse",
|
||||
targetfacenum=max_facenum,
|
||||
@@ -287,3 +281,47 @@ class DegenerateFaceRemover:
|
||||
|
||||
mesh = export_mesh(mesh, ms)
|
||||
return mesh
|
||||
|
||||
|
||||
def mesh_normalize(mesh):
|
||||
"""
|
||||
Normalize mesh vertices to sphere
|
||||
"""
|
||||
scale_factor = 1.2
|
||||
vtx_pos = np.asarray(mesh.vertices)
|
||||
max_bb = (vtx_pos - 0).max(0)[0]
|
||||
min_bb = (vtx_pos - 0).min(0)[0]
|
||||
|
||||
center = (max_bb + min_bb) / 2
|
||||
|
||||
scale = torch.norm(torch.tensor(vtx_pos - center, dtype=torch.float32), dim=1).max() * 2.0
|
||||
|
||||
vtx_pos = (vtx_pos - center) * (scale_factor / float(scale))
|
||||
mesh.vertices = vtx_pos
|
||||
|
||||
return mesh
|
||||
|
||||
|
||||
class MeshSimplifier:
|
||||
def __init__(self, executable: str = None):
|
||||
if executable is None:
|
||||
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
executable = os.path.join(CURRENT_DIR, "mesh_simplifier.bin")
|
||||
self.executable = executable
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
mesh: Union[trimesh.Trimesh],
|
||||
) -> Union[trimesh.Trimesh]:
|
||||
with tempfile.NamedTemporaryFile(suffix='.obj', delete=False) as temp_input:
|
||||
with tempfile.NamedTemporaryFile(suffix='.obj', delete=False) as temp_output:
|
||||
mesh.export(temp_input.name)
|
||||
os.system(f'{self.executable} {temp_input.name} {temp_output.name}')
|
||||
ms = trimesh.load(temp_output.name, process=False)
|
||||
if isinstance(ms, trimesh.Scene):
|
||||
combined_mesh = trimesh.Trimesh()
|
||||
for geom in ms.geometry.values():
|
||||
combined_mesh = trimesh.util.concatenate([combined_mesh, geom])
|
||||
ms = combined_mesh
|
||||
ms = mesh_normalize(ms)
|
||||
return ms
|
||||
|
||||
123
hy3dgen/shapegen/utils.py
Normal file
@@ -0,0 +1,123 @@
|
||||
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
||||
# except for the third-party components listed below.
|
||||
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
||||
# in the respective licenses of these third-party components.
|
||||
# Users must comply with all terms and conditions of original licenses of these third-party
|
||||
# components and must ensure that the usage of the third party components adheres to
|
||||
# all relevant laws and regulations.
|
||||
|
||||
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
||||
# their software and algorithms, including trained model weights, parameters (including
|
||||
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
||||
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
||||
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
||||
|
||||
import logging
|
||||
import os
|
||||
from functools import wraps
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def get_logger(name):
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setLevel(logging.INFO)
|
||||
|
||||
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
console_handler.setFormatter(formatter)
|
||||
logger.addHandler(console_handler)
|
||||
return logger
|
||||
|
||||
|
||||
logger = get_logger('hy3dgen.shapegen')
|
||||
|
||||
|
||||
class synchronize_timer:
|
||||
""" Synchronized timer to count the inference time of `nn.Module.forward`.
|
||||
|
||||
Supports both context manager and decorator usage.
|
||||
|
||||
Example as context manager:
|
||||
```python
|
||||
with synchronize_timer('name') as t:
|
||||
run()
|
||||
```
|
||||
|
||||
Example as decorator:
|
||||
```python
|
||||
@synchronize_timer('Export to trimesh')
|
||||
def export_to_trimesh(mesh_output):
|
||||
pass
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, name=None):
|
||||
self.name = name
|
||||
|
||||
def __enter__(self):
|
||||
"""Context manager entry: start timing."""
|
||||
if os.environ.get('HY3DGEN_DEBUG', '0') == '1':
|
||||
self.start = torch.cuda.Event(enable_timing=True)
|
||||
self.end = torch.cuda.Event(enable_timing=True)
|
||||
self.start.record()
|
||||
return lambda: self.time
|
||||
|
||||
def __exit__(self, exc_type, exc_value, exc_tb):
|
||||
"""Context manager exit: stop timing and log results."""
|
||||
if os.environ.get('HY3DGEN_DEBUG', '0') == '1':
|
||||
self.end.record()
|
||||
torch.cuda.synchronize()
|
||||
self.time = self.start.elapsed_time(self.end)
|
||||
if self.name is not None:
|
||||
logger.info(f'{self.name} takes {self.time} ms')
|
||||
|
||||
def __call__(self, func):
|
||||
"""Decorator: wrap the function to time its execution."""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
with self:
|
||||
result = func(*args, **kwargs)
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def smart_load_model(
|
||||
model_path,
|
||||
subfolder,
|
||||
use_safetensors,
|
||||
variant,
|
||||
):
|
||||
original_model_path = model_path
|
||||
# try local path
|
||||
base_dir = os.environ.get('HY3DGEN_MODELS', '~/.cache/hy3dgen')
|
||||
model_path = os.path.expanduser(os.path.join(base_dir, model_path, subfolder))
|
||||
logger.info(f'Try to load model from local path: {model_path}')
|
||||
if not os.path.exists(model_path):
|
||||
logger.info('Model path does not exist, trying to download from Hugging Face')
|
||||
try:
|
||||
import huggingface_hub
|
||||
# download from huggingface
|
||||
path = huggingface_hub.snapshot_download(repo_id=original_model_path)
|
||||
model_path = os.path.join(path, subfolder)
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
"You need to install HuggingFace Hub to load models from the hub."
|
||||
)
|
||||
raise RuntimeError(f"Model path {model_path} not found")
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
if not os.path.exists(model_path):
|
||||
raise FileNotFoundError(f"Model path {original_model_path} not found")
|
||||
|
||||
extension = 'ckpt' if not use_safetensors else 'safetensors'
|
||||
variant = '' if variant is None else f'.{variant}'
|
||||
ckpt_name = f'model{variant}.{extension}'
|
||||
config_path = os.path.join(model_path, 'config.yaml')
|
||||
ckpt_path = os.path.join(model_path, ckpt_name)
|
||||
return config_path, ckpt_path
|
||||
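smart_load_model resolves checkpoints locally first (under HY3DGEN_MODELS, default ~/.cache/hy3dgen) and only then falls back to a Hugging Face snapshot download. A hedged sketch of resolving a checkpoint by hand; the repo id and subfolder are placeholders:

import os
from hy3dgen.shapegen.utils import smart_load_model

os.environ.setdefault('HY3DGEN_MODELS', os.path.expanduser('~/.cache/hy3dgen'))
os.environ['HY3DGEN_DEBUG'] = '1'  # also enables the synchronize_timer CUDA timings

config_path, ckpt_path = smart_load_model(
    'tencent/Hunyuan3D-2',            # placeholder repo id
    subfolder='hunyuan3d-vae-v2-0',
    use_safetensors=True,
    variant='fp16',
)
print(config_path, ckpt_path)  # .../config.yaml  .../model.fp16.safetensors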
11
nodes.py
@@ -1194,8 +1194,12 @@ class Hy3DVAEDecode:
|
||||
"octree_resolution": ("INT", {"default": 384, "min": 8, "max": 4096, "step": 8}),
|
||||
"num_chunks": ("INT", {"default": 8000, "min": 1, "max": 10000000, "step": 1, "tooltip": "Number of chunks to process at once, higher values use more memory, but make the process faster"}),
|
||||
"mc_level": ("FLOAT", {"default": 0, "min": -1.0, "max": 1.0, "step": 0.0001}),
|
||||
"mc_algo": (["mc", "dmc", "odc", "none"], {"default": "mc"}),
|
||||
#"mc_algo": (["mc", "dmc", "odc", "none"], {"default": "mc"}),
|
||||
"mc_algo": (["mc", "dmc"], {"default": "mc"}),
|
||||
},
|
||||
"optional": {
|
||||
"enable_flash_vdm": ("BOOLEAN", {"default": True}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("TRIMESH",)
|
||||
@@ -1203,11 +1207,14 @@
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "Hunyuan3DWrapper"
|
||||
|
||||
def process(self, vae, latents, box_v, octree_resolution, mc_level, num_chunks, mc_algo):
|
||||
def process(self, vae, latents, box_v, octree_resolution, mc_level, num_chunks, mc_algo, enable_flash_vdm=True):
|
||||
device = mm.get_torch_device()
|
||||
offload_device = mm.unet_offload_device()
|
||||
|
||||
vae.to(device)
|
||||
|
||||
vae.enable_flashvdm_decoder(enabled=enable_flash_vdm)
|
||||
|
||||
latents = 1. / vae.scale_factor * latents
|
||||
latents = vae(latents)
|
||||
|
||||
|
||||
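Inside the node, the new optional enable_flash_vdm toggle simply forwards to the VAE before decoding. Conceptually, the decode path now looks like this condensed sketch (not the full node implementation):

def decode_latents(vae, latents, enable_flash_vdm=True, **decode_kwargs):
    # Pick the FlashVDM or vanilla volume decoder, then run the usual
    # scale -> transformer -> surface-extraction path.
    vae.enable_flashvdm_decoder(enabled=enable_flash_vdm)
    latents = latents / vae.scale_factor
    latents = vae(latents)
    return vae.latents2mesh(latents, **decode_kwargs)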