mirror of
https://git.datalinker.icu/kijai/ComfyUI-Hunyuan3DWrapper.git
synced 2025-12-29 00:36:31 +08:00
Compare commits
13 Commits
921b0d78a9
...
cbe2837e50
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cbe2837e50 | ||
|
|
48c7716b2e | ||
|
|
2ceba3841b | ||
|
|
4a87c71089 | ||
|
|
c1f95d9a6e | ||
|
|
896f0a3531 | ||
|
|
c50d0b47b5 | ||
|
|
b8e2d2c800 | ||
|
|
04b4d350a0 | ||
|
|
5e684d2334 | ||
|
|
b432560c19 | ||
|
|
977213426f | ||
|
|
48dae15e4a |
82
configs/dit_config_2_1.yaml
Normal file
82
configs/dit_config_2_1.yaml
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
model:
|
||||||
|
target: .hy3dshape.hy3dshape.models.denoisers.hunyuandit.HunYuanDiTPlain
|
||||||
|
params:
|
||||||
|
input_size: &num_latents 4096
|
||||||
|
in_channels: 64
|
||||||
|
hidden_size: 2048
|
||||||
|
context_dim: 1024
|
||||||
|
depth: 21
|
||||||
|
num_heads: 16
|
||||||
|
qk_norm: true
|
||||||
|
text_len: 1370
|
||||||
|
with_decoupled_ca: false
|
||||||
|
use_attention_pooling: false
|
||||||
|
qk_norm_type: 'rms'
|
||||||
|
qkv_bias: false
|
||||||
|
use_pos_emb: false
|
||||||
|
num_moe_layers: 6
|
||||||
|
num_experts: 8
|
||||||
|
moe_top_k: 2
|
||||||
|
|
||||||
|
vae:
|
||||||
|
target: .hy3dshape.hy3dshape.models.autoencoders.ShapeVAE
|
||||||
|
params:
|
||||||
|
num_latents: *num_latents
|
||||||
|
embed_dim: 64
|
||||||
|
num_freqs: 8
|
||||||
|
include_pi: false
|
||||||
|
heads: 16
|
||||||
|
width: 1024
|
||||||
|
num_encoder_layers: 8
|
||||||
|
num_decoder_layers: 16
|
||||||
|
qkv_bias: false
|
||||||
|
qk_norm: true
|
||||||
|
scale_factor: 1.0039506158752403
|
||||||
|
geo_decoder_mlp_expand_ratio: 4
|
||||||
|
geo_decoder_downsample_ratio: 1
|
||||||
|
geo_decoder_ln_post: true
|
||||||
|
point_feats: 4
|
||||||
|
pc_size: 81920
|
||||||
|
pc_sharpedge_size: 0
|
||||||
|
|
||||||
|
conditioner:
|
||||||
|
target: .hy3dshape.hy3dshape.models.conditioner.SingleImageEncoder
|
||||||
|
params:
|
||||||
|
main_image_encoder:
|
||||||
|
type: DinoImageEncoder # dino large
|
||||||
|
kwargs:
|
||||||
|
config:
|
||||||
|
attention_probs_dropout_prob: 0.0
|
||||||
|
drop_path_rate: 0.0
|
||||||
|
hidden_act: gelu
|
||||||
|
hidden_dropout_prob: 0.0
|
||||||
|
hidden_size: 1024
|
||||||
|
image_size: 518
|
||||||
|
initializer_range: 0.02
|
||||||
|
layer_norm_eps: 1.e-6
|
||||||
|
layerscale_value: 1.0
|
||||||
|
mlp_ratio: 4
|
||||||
|
model_type: dinov2
|
||||||
|
num_attention_heads: 16
|
||||||
|
num_channels: 3
|
||||||
|
num_hidden_layers: 24
|
||||||
|
patch_size: 14
|
||||||
|
qkv_bias: true
|
||||||
|
torch_dtype: float32
|
||||||
|
use_swiglu_ffn: false
|
||||||
|
image_size: 518
|
||||||
|
use_cls_token: true
|
||||||
|
|
||||||
|
scheduler:
|
||||||
|
target: .hy3dshape.hy3dshape.schedulers.FlowMatchEulerDiscreteScheduler
|
||||||
|
params:
|
||||||
|
num_train_timesteps: 1000
|
||||||
|
|
||||||
|
image_processor:
|
||||||
|
target: .hy3dshape.hy3dshape.preprocessors.ImageProcessorV2
|
||||||
|
params:
|
||||||
|
size: 512
|
||||||
|
border_ratio: 0.15
|
||||||
|
|
||||||
|
pipeline:
|
||||||
|
target: .hy3dshape.hy3dshape.pipelines.Hunyuan3DDiTFlowMatchingPipeline
|
||||||
@ -40,7 +40,7 @@ from tqdm import tqdm
|
|||||||
from accelerate import init_empty_weights
|
from accelerate import init_empty_weights
|
||||||
from accelerate.utils import set_module_tensor_to_device
|
from accelerate.utils import set_module_tensor_to_device
|
||||||
|
|
||||||
from comfy.utils import ProgressBar
|
from comfy.utils import ProgressBar, load_torch_file
|
||||||
import comfy.model_management as mm
|
import comfy.model_management as mm
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -124,13 +124,15 @@ def export_to_trimesh(mesh_output):
|
|||||||
|
|
||||||
|
|
||||||
def get_obj_from_str(string, reload=False):
|
def get_obj_from_str(string, reload=False):
|
||||||
package_directory_name = os.path.basename(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
|
||||||
|
|
||||||
module, cls = string.rsplit(".", 1)
|
module, cls = string.rsplit(".", 1)
|
||||||
if reload:
|
if reload:
|
||||||
module_imp = importlib.import_module(module)
|
module_imp = importlib.import_module(module)
|
||||||
importlib.reload(module_imp)
|
importlib.reload(module_imp)
|
||||||
return getattr(importlib.import_module(module, package=package_directory_name), cls)
|
try:
|
||||||
|
obj = getattr(importlib.import_module(module, package=os.path.basename(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))), cls)
|
||||||
|
except:
|
||||||
|
obj = getattr(importlib.import_module(module, package=os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath( __file__ ))))), cls)
|
||||||
|
return obj
|
||||||
|
|
||||||
|
|
||||||
def instantiate_from_config(config, **kwargs):
|
def instantiate_from_config(config, **kwargs):
|
||||||
@ -158,33 +160,21 @@ class Hunyuan3DDiTPipeline:
|
|||||||
scheduler="FlowMatchEulerDiscreteScheduler",
|
scheduler="FlowMatchEulerDiscreteScheduler",
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
new_sd = {}
|
||||||
# load ckpt
|
sd = load_torch_file(ckpt_path)
|
||||||
if use_safetensors:
|
if ckpt_path.endswith('.safetensors'):
|
||||||
ckpt_path = ckpt_path.replace('.ckpt', '.safetensors')
|
for key, value in sd.items():
|
||||||
if not os.path.exists(ckpt_path):
|
|
||||||
raise FileNotFoundError(f"Model file {ckpt_path} not found")
|
|
||||||
logger.info(f"Loading model from {ckpt_path}")
|
|
||||||
|
|
||||||
if use_safetensors:
|
|
||||||
# parse safetensors
|
|
||||||
import safetensors.torch
|
|
||||||
safetensors_ckpt = safetensors.torch.load_file(ckpt_path, device='cpu')
|
|
||||||
ckpt = {}
|
|
||||||
for key, value in safetensors_ckpt.items():
|
|
||||||
model_name = key.split('.')[0]
|
model_name = key.split('.')[0]
|
||||||
new_key = key[len(model_name) + 1:]
|
new_key = key[len(model_name) + 1:]
|
||||||
if model_name not in ckpt:
|
if model_name not in new_sd:
|
||||||
ckpt[model_name] = {}
|
new_sd[model_name] = {}
|
||||||
ckpt[model_name][new_key] = value
|
new_sd[model_name][new_key] = value
|
||||||
else:
|
|
||||||
ckpt = torch.load(ckpt_path, map_location='cpu')
|
|
||||||
|
|
||||||
script_directory = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
script_directory = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
# load config
|
# load config
|
||||||
|
|
||||||
single_block_nums = set()
|
single_block_nums = set()
|
||||||
for k in ckpt["model"].keys():
|
for k in new_sd["model"].keys():
|
||||||
if k.startswith('single_blocks.'):
|
if k.startswith('single_blocks.'):
|
||||||
block_num = int(k.split('.')[1])
|
block_num = int(k.split('.')[1])
|
||||||
single_block_nums.add(block_num)
|
single_block_nums.add(block_num)
|
||||||
@ -199,7 +189,7 @@ class Hunyuan3DDiTPipeline:
|
|||||||
|
|
||||||
|
|
||||||
# load model
|
# load model
|
||||||
if "guidance_in.in_layer.bias" in ckpt['model']: #guidance_in.in_layer.bias
|
if "guidance_in.in_layer.bias" in new_sd['model']: #guidance_in.in_layer.bias
|
||||||
logger.info("Model has guidance_in, setting guidance_embed to True")
|
logger.info("Model has guidance_in, setting guidance_embed to True")
|
||||||
config['model']['params']['guidance_embed'] = True
|
config['model']['params']['guidance_embed'] = True
|
||||||
config['conditioner']['params']['main_image_encoder']['kwargs']['has_guidance_embed'] = True
|
config['conditioner']['params']['main_image_encoder']['kwargs']['has_guidance_embed'] = True
|
||||||
@ -215,15 +205,15 @@ class Hunyuan3DDiTPipeline:
|
|||||||
conditioner = instantiate_from_config(config['conditioner'])
|
conditioner = instantiate_from_config(config['conditioner'])
|
||||||
#model
|
#model
|
||||||
for name, param in model.named_parameters():
|
for name, param in model.named_parameters():
|
||||||
set_module_tensor_to_device(model, name, device=offload_device, dtype=dtype, value=ckpt['model'][name])
|
set_module_tensor_to_device(model, name, device=offload_device, dtype=dtype, value=new_sd['model'][name])
|
||||||
#vae
|
#vae
|
||||||
for name, param in vae.named_parameters():
|
for name, param in vae.named_parameters():
|
||||||
set_module_tensor_to_device(vae, name, device=offload_device, dtype=dtype, value=ckpt['vae'][name])
|
set_module_tensor_to_device(vae, name, device=offload_device, dtype=dtype, value=new_sd['vae'][name])
|
||||||
|
|
||||||
if 'conditioner' in ckpt:
|
if 'conditioner' in new_sd:
|
||||||
#conditioner.load_state_dict(ckpt['conditioner'])
|
#conditioner.load_state_dict(ckpt['conditioner'])
|
||||||
for name, param in conditioner.named_parameters():
|
for name, param in conditioner.named_parameters():
|
||||||
set_module_tensor_to_device(conditioner, name, device=offload_device, dtype=dtype, value=ckpt['conditioner'][name])
|
set_module_tensor_to_device(conditioner, name, device=offload_device, dtype=dtype, value=new_sd['conditioner'][name])
|
||||||
|
|
||||||
image_processor = instantiate_from_config(config['image_processor'])
|
image_processor = instantiate_from_config(config['image_processor'])
|
||||||
|
|
||||||
@ -255,49 +245,6 @@ class Hunyuan3DDiTPipeline:
|
|||||||
|
|
||||||
return cls(**model_kwargs), vae
|
return cls(**model_kwargs), vae
|
||||||
|
|
||||||
# @classmethod
|
|
||||||
# def from_pretrained(
|
|
||||||
# cls,
|
|
||||||
# model_path,
|
|
||||||
# ckpt_name='model.ckpt',
|
|
||||||
# config_name='config.yaml',
|
|
||||||
# device='cuda',
|
|
||||||
# dtype=torch.float16,
|
|
||||||
# use_safetensors=None,
|
|
||||||
# **kwargs,
|
|
||||||
# ):
|
|
||||||
# original_model_path = model_path
|
|
||||||
# if not os.path.exists(model_path):
|
|
||||||
# # try local path
|
|
||||||
# base_dir = "checkpoints"
|
|
||||||
# model_path = os.path.join(base_dir, model_path, 'hunyuan3d-dit-v2-0')
|
|
||||||
# if not os.path.exists(model_path):
|
|
||||||
# try:
|
|
||||||
# import huggingface_hub
|
|
||||||
# # download from huggingface
|
|
||||||
# huggingface_hub.snapshot_download(
|
|
||||||
# repo_id="tencent/Hunyuan3D-2",
|
|
||||||
# local_dir=base_dir,)
|
|
||||||
|
|
||||||
# except ImportError:
|
|
||||||
# logger.warning(
|
|
||||||
# "You need to install HuggingFace Hub to load models from the hub."
|
|
||||||
# )
|
|
||||||
# raise RuntimeError(f"Model path {model_path} not found")
|
|
||||||
# if not os.path.exists(model_path):
|
|
||||||
# raise FileNotFoundError(f"Model path {original_model_path} not found")
|
|
||||||
|
|
||||||
# config_path = os.path.join(model_path, config_name)
|
|
||||||
# ckpt_path = os.path.join(model_path, ckpt_name)
|
|
||||||
# return cls.from_single_file(
|
|
||||||
# ckpt_path,
|
|
||||||
# config_path,
|
|
||||||
# device=device,
|
|
||||||
# dtype=dtype,
|
|
||||||
# use_safetensors=use_safetensors,
|
|
||||||
# **kwargs
|
|
||||||
# )
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
#vae,
|
#vae,
|
||||||
|
|||||||
@ -22,8 +22,8 @@
|
|||||||
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
||||||
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
||||||
|
|
||||||
import custom_rasterizer_kernel
|
|
||||||
import torch
|
import torch
|
||||||
|
import custom_rasterizer_kernel
|
||||||
|
|
||||||
|
|
||||||
def rasterize(pos, tri, resolution, clamp_depth=torch.zeros(0), use_depth_prior=0):
|
def rasterize(pos, tri, resolution, clamp_depth=torch.zeros(0), use_depth_prior=0):
|
||||||
|
|||||||
81
hy3dshape/LICENSE
Normal file
81
hy3dshape/LICENSE
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
TENCENT HUNYUAN 3D 2.1 COMMUNITY LICENSE AGREEMENT
|
||||||
|
Tencent Hunyuan 3D 2.1 Release Date: June 13, 2025
|
||||||
|
THIS LICENSE AGREEMENT DOES NOT APPLY IN THE EUROPEAN UNION, UNITED KINGDOM AND SOUTH KOREA AND IS EXPRESSLY LIMITED TO THE TERRITORY, AS DEFINED BELOW.
|
||||||
|
By clicking to agree or by using, reproducing, modifying, distributing, performing or displaying any portion or element of the Tencent Hunyuan 3D 2.1 Works, including via any Hosted Service, You will be deemed to have recognized and accepted the content of this Agreement, which is effective immediately.
|
||||||
|
1. DEFINITIONS.
|
||||||
|
a. “Acceptable Use Policy” shall mean the policy made available by Tencent as set forth in the Exhibit A.
|
||||||
|
b. “Agreement” shall mean the terms and conditions for use, reproduction, distribution, modification, performance and displaying of Tencent Hunyuan 3D 2.1 Works or any portion or element thereof set forth herein.
|
||||||
|
c. “Documentation” shall mean the specifications, manuals and documentation for Tencent Hunyuan 3D 2.1 made publicly available by Tencent.
|
||||||
|
d. “Hosted Service” shall mean a hosted service offered via an application programming interface (API), web access, or any other electronic or remote means.
|
||||||
|
e. “Licensee,” “You” or “Your” shall mean a natural person or legal entity exercising the rights granted by this Agreement and/or using the Tencent Hunyuan 3D 2.1 Works for any purpose and in any field of use.
|
||||||
|
f. “Materials” shall mean, collectively, Tencent’s proprietary Tencent Hunyuan 3D 2.1 and Documentation (and any portion thereof) as made available by Tencent under this Agreement.
|
||||||
|
g. “Model Derivatives” shall mean all: (i) modifications to Tencent Hunyuan 3D 2.1 or any Model Derivative of Tencent Hunyuan 3D 2.1; (ii) works based on Tencent Hunyuan 3D 2.1 or any Model Derivative of Tencent Hunyuan 3D 2.1; or (iii) any other machine learning model which is created by transfer of patterns of the weights, parameters, operations, or Output of Tencent Hunyuan 3D 2.1 or any Model Derivative of Tencent Hunyuan 3D 2.1, to that model in order to cause that model to perform similarly to Tencent Hunyuan 3D 2.1 or a Model Derivative of Tencent Hunyuan 3D 2.1, including distillation methods, methods that use intermediate data representations, or methods based on the generation of synthetic data Outputs by Tencent Hunyuan 3D 2.1 or a Model Derivative of Tencent Hunyuan 3D 2.1 for training that model. For clarity, Outputs by themselves are not deemed Model Derivatives.
|
||||||
|
h. “Output” shall mean the information and/or content output of Tencent Hunyuan 3D 2.1 or a Model Derivative that results from operating or otherwise using Tencent Hunyuan 3D 2.1 or a Model Derivative, including via a Hosted Service.
|
||||||
|
i. “Tencent,” “We” or “Us” shall mean THL Q Limited.
|
||||||
|
j. “Tencent Hunyuan 3D 2.1” shall mean the 3D generation models and their software and algorithms, including trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing made publicly available by Us at [ https://github.com/Tencent-Hunyuan/Hunyuan3D-2.1].
|
||||||
|
k. “Tencent Hunyuan 3D 2.1 Works” shall mean: (i) the Materials; (ii) Model Derivatives; and (iii) all derivative works thereof.
|
||||||
|
l. “Territory” shall mean the worldwide territory, excluding the territory of the European Union, United Kingdom and South Korea.
|
||||||
|
m. “Third Party” or “Third Parties” shall mean individuals or legal entities that are not under common control with Us or You.
|
||||||
|
n. “including” shall mean including but not limited to.
|
||||||
|
2. GRANT OF RIGHTS.
|
||||||
|
We grant You, for the Territory only, a non-exclusive, non-transferable and royalty-free limited license under Tencent’s intellectual property or other rights owned by Us embodied in or utilized by the Materials to use, reproduce, distribute, create derivative works of (including Model Derivatives), and make modifications to the Materials, only in accordance with the terms of this Agreement and the Acceptable Use Policy, and You must not violate (or encourage or permit anyone else to violate) any term of this Agreement or the Acceptable Use Policy.
|
||||||
|
3. DISTRIBUTION.
|
||||||
|
You may, subject to Your compliance with this Agreement, distribute or make available to Third Parties the Tencent Hunyuan 3D 2.1 Works, exclusively in the Territory, provided that You meet all of the following conditions:
|
||||||
|
a. You must provide all such Third Party recipients of the Tencent Hunyuan 3D 2.1 Works or products or services using them a copy of this Agreement;
|
||||||
|
b. You must cause any modified files to carry prominent notices stating that You changed the files;
|
||||||
|
c. You are encouraged to: (i) publish at least one technology introduction blogpost or one public statement expressing Your experience of using the Tencent Hunyuan 3D 2.1 Works; and (ii) mark the products or services developed by using the Tencent Hunyuan 3D 2.1 Works to indicate that the product/service is “Powered by Tencent Hunyuan”; and
|
||||||
|
d. All distributions to Third Parties (other than through a Hosted Service) must be accompanied by a “Notice” text file that contains the following notice: “Tencent Hunyuan 3D 2.1 is licensed under the Tencent Hunyuan 3D 2.1 Community License Agreement, Copyright © 2025 Tencent. All Rights Reserved. The trademark rights of “Tencent Hunyuan” are owned by Tencent or its affiliate.”
|
||||||
|
You may add Your own copyright statement to Your modifications and, except as set forth in this Section and in Section 5, may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Model Derivatives as a whole, provided Your use, reproduction, modification, distribution, performance and display of the work otherwise complies with the terms and conditions of this Agreement (including as regards the Territory). If You receive Tencent Hunyuan 3D 2.1 Works from a Licensee as part of an integrated end user product, then this Section 3 of this Agreement will not apply to You.
|
||||||
|
4. ADDITIONAL COMMERCIAL TERMS.
|
||||||
|
If, on the Tencent Hunyuan 3D 2.1 version release date, the monthly active users of all products or services made available by or for Licensee is greater than 1 million monthly active users in the preceding calendar month, You must request a license from Tencent, which Tencent may grant to You in its sole discretion, and You are not authorized to exercise any of the rights under this Agreement unless or until Tencent otherwise expressly grants You such rights.
|
||||||
|
Subject to Tencent's written approval, you may request a license for the use of Tencent Hunyuan 3D 2.1 by submitting the following information to hunyuan3d@tencent.com:
|
||||||
|
a. Your company’s name and associated business sector that plans to use Tencent Hunyuan 3D 2.1.
|
||||||
|
b. Your intended use case and the purpose of using Tencent Hunyuan 3D 2.1.
|
||||||
|
c. Your plans to modify Tencent Hunyuan 3D 2.1 or create Model Derivatives.
|
||||||
|
5. RULES OF USE.
|
||||||
|
a. Your use of the Tencent Hunyuan 3D 2.1 Works must comply with applicable laws and regulations (including trade compliance laws and regulations) and adhere to the Acceptable Use Policy for the Tencent Hunyuan 3D 2.1 Works, which is hereby incorporated by reference into this Agreement. You must include the use restrictions referenced in these Sections 5(a) and 5(b) as an enforceable provision in any agreement (e.g., license agreement, terms of use, etc.) governing the use and/or distribution of Tencent Hunyuan 3D 2.1 Works and You must provide notice to subsequent users to whom You distribute that Tencent Hunyuan 3D 2.1 Works are subject to the use restrictions in these Sections 5(a) and 5(b).
|
||||||
|
b. You must not use the Tencent Hunyuan 3D 2.1 Works or any Output or results of the Tencent Hunyuan 3D 2.1 Works to improve any other AI model (other than Tencent Hunyuan 3D 2.1 or Model Derivatives thereof).
|
||||||
|
c. You must not use, reproduce, modify, distribute, or display the Tencent Hunyuan 3D 2.1 Works, Output or results of the Tencent Hunyuan 3D 2.1 Works outside the Territory. Any such use outside the Territory is unlicensed and unauthorized under this Agreement.
|
||||||
|
6. INTELLECTUAL PROPERTY.
|
||||||
|
a. Subject to Tencent’s ownership of Tencent Hunyuan 3D 2.1 Works made by or for Tencent and intellectual property rights therein, conditioned upon Your compliance with the terms and conditions of this Agreement, as between You and Tencent, You will be the owner of any derivative works and modifications of the Materials and any Model Derivatives that are made by or for You.
|
||||||
|
b. No trademark licenses are granted under this Agreement, and in connection with the Tencent Hunyuan 3D 2.1 Works, Licensee may not use any name or mark owned by or associated with Tencent or any of its affiliates, except as required for reasonable and customary use in describing and distributing the Tencent Hunyuan 3D 2.1 Works. Tencent hereby grants You a license to use “Tencent Hunyuan” (the “Mark”) in the Territory solely as required to comply with the provisions of Section 3(c), provided that You comply with any applicable laws related to trademark protection. All goodwill arising out of Your use of the Mark will inure to the benefit of Tencent.
|
||||||
|
c. If You commence a lawsuit or other proceedings (including a cross-claim or counterclaim in a lawsuit) against Us or any person or entity alleging that the Materials or any Output, or any portion of any of the foregoing, infringe any intellectual property or other right owned or licensable by You, then all licenses granted to You under this Agreement shall terminate as of the date such lawsuit or other proceeding is filed. You will defend, indemnify and hold harmless Us from and against any claim by any Third Party arising out of or related to Your or the Third Party’s use or distribution of the Tencent Hunyuan 3D 2.1 Works.
|
||||||
|
d. Tencent claims no rights in Outputs You generate. You and Your users are solely responsible for Outputs and their subsequent uses.
|
||||||
|
7. DISCLAIMERS OF WARRANTY AND LIMITATIONS OF LIABILITY.
|
||||||
|
a. We are not obligated to support, update, provide training for, or develop any further version of the Tencent Hunyuan 3D 2.1 Works or to grant any license thereto.
|
||||||
|
b. UNLESS AND ONLY TO THE EXTENT REQUIRED BY APPLICABLE LAW, THE TENCENT HUNYUAN 3D 2.1 WORKS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED “AS IS” WITHOUT ANY EXPRESS OR IMPLIED WARRANTIES OF ANY KIND INCLUDING ANY WARRANTIES OF TITLE, MERCHANTABILITY, NONINFRINGEMENT, COURSE OF DEALING, USAGE OF TRADE, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING, REPRODUCING, MODIFYING, PERFORMING, DISPLAYING OR DISTRIBUTING ANY OF THE TENCENT HUNYUAN 3D 2.1 WORKS OR OUTPUTS AND ASSUME ANY AND ALL RISKS ASSOCIATED WITH YOUR OR A THIRD PARTY’S USE OR DISTRIBUTION OF ANY OF THE TENCENT HUNYUAN 3D 2.1 WORKS OR OUTPUTS AND YOUR EXERCISE OF RIGHTS AND PERMISSIONS UNDER THIS AGREEMENT.
|
||||||
|
c. TO THE FULLEST EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT SHALL TENCENT OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, FOR ANY DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, CONSEQUENTIAL OR PUNITIVE DAMAGES, OR LOST PROFITS OF ANY KIND ARISING FROM THIS AGREEMENT OR RELATED TO ANY OF THE TENCENT HUNYUAN 3D 2.1 WORKS OR OUTPUTS, EVEN IF TENCENT OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
|
||||||
|
8. SURVIVAL AND TERMINATION.
|
||||||
|
a. The term of this Agreement shall commence upon Your acceptance of this Agreement or access to the Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein.
|
||||||
|
b. We may terminate this Agreement if You breach any of the terms or conditions of this Agreement. Upon termination of this Agreement, You must promptly delete and cease use of the Tencent Hunyuan 3D 2.1 Works. Sections 6(a), 6(c), 7 and 9 shall survive the termination of this Agreement.
|
||||||
|
9. GOVERNING LAW AND JURISDICTION.
|
||||||
|
a. This Agreement and any dispute arising out of or relating to it will be governed by the laws of the Hong Kong Special Administrative Region of the People’s Republic of China, without regard to conflict of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement.
|
||||||
|
b. Exclusive jurisdiction and venue for any dispute arising out of or relating to this Agreement will be a court of competent jurisdiction in the Hong Kong Special Administrative Region of the People’s Republic of China, and Tencent and Licensee consent to the exclusive jurisdiction of such court with respect to any such dispute.
|
||||||
|
|
||||||
|
EXHIBIT A
|
||||||
|
ACCEPTABLE USE POLICY
|
||||||
|
|
||||||
|
Tencent reserves the right to update this Acceptable Use Policy from time to time.
|
||||||
|
Last modified: November 5, 2024
|
||||||
|
|
||||||
|
Tencent endeavors to promote safe and fair use of its tools and features, including Tencent Hunyuan 3D 2.1. You agree not to use Tencent Hunyuan 3D 2.1 or Model Derivatives:
|
||||||
|
1. Outside the Territory;
|
||||||
|
2. In any way that violates any applicable national, federal, state, local, international or any other law or regulation;
|
||||||
|
3. To harm Yourself or others;
|
||||||
|
4. To repurpose or distribute output from Tencent Hunyuan 3D 2.1 or any Model Derivatives to harm Yourself or others;
|
||||||
|
5. To override or circumvent the safety guardrails and safeguards We have put in place;
|
||||||
|
6. For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
|
||||||
|
7. To generate or disseminate verifiably false information and/or content with the purpose of harming others or influencing elections;
|
||||||
|
8. To generate or facilitate false online engagement, including fake reviews and other means of fake online engagement;
|
||||||
|
9. To intentionally defame, disparage or otherwise harass others;
|
||||||
|
10. To generate and/or disseminate malware (including ransomware) or any other content to be used for the purpose of harming electronic systems;
|
||||||
|
11. To generate or disseminate personal identifiable information with the purpose of harming others;
|
||||||
|
12. To generate or disseminate information (including images, code, posts, articles), and place the information in any public context (including –through the use of bot generated tweets), without expressly and conspicuously identifying that the information and/or content is machine generated;
|
||||||
|
13. To impersonate another individual without consent, authorization, or legal right;
|
||||||
|
14. To make high-stakes automated decisions in domains that affect an individual’s safety, rights or wellbeing (e.g., law enforcement, migration, medicine/health, management of critical infrastructure, safety components of products, essential services, credit, employment, housing, education, social scoring, or insurance);
|
||||||
|
15. In a manner that violates or disrespects the social ethics and moral standards of other countries or regions;
|
||||||
|
16. To perform, facilitate, threaten, incite, plan, promote or encourage violent extremism or terrorism;
|
||||||
|
17. For any use intended to discriminate against or harm individuals or groups based on protected characteristics or categories, online or offline social behavior or known or predicted personal or personality characteristics;
|
||||||
|
18. To intentionally exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
|
||||||
|
19. For military purposes;
|
||||||
|
20. To engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or other professional practices.
|
||||||
@ -0,0 +1,174 @@
|
|||||||
|
name: "DiT: Flux large flowmatching; VAE: 1024 token length; ImageEncoder: DINO Giant; ImageSize: 518"
|
||||||
|
|
||||||
|
training:
|
||||||
|
steps: 10_0000_0000
|
||||||
|
use_amp: true
|
||||||
|
amp_type: "bf16"
|
||||||
|
base_lr: 1.e-5
|
||||||
|
gradient_clip_val: 1.0
|
||||||
|
gradient_clip_algorithm: "norm"
|
||||||
|
every_n_train_steps: 2000 # 5000
|
||||||
|
val_check_interval: 50 # 4096
|
||||||
|
limit_val_batches: 16
|
||||||
|
|
||||||
|
dataset:
|
||||||
|
target: hy3dshape.data.dit_asl.AlignedShapeLatentModule
|
||||||
|
params:
|
||||||
|
#! Base setting
|
||||||
|
batch_size: 4
|
||||||
|
num_workers: 8
|
||||||
|
val_num_workers: 4
|
||||||
|
|
||||||
|
# Data
|
||||||
|
train_data_list: tools/mini_trainset/preprocessed
|
||||||
|
val_data_list: tools/mini_trainset/preprocessed
|
||||||
|
|
||||||
|
#! Image loading
|
||||||
|
cond_stage_key: "image" # image / text / image_text
|
||||||
|
image_size: 518
|
||||||
|
mean: &mean [0.5, 0.5, 0.5]
|
||||||
|
std: &std [0.5, 0.5, 0.5]
|
||||||
|
|
||||||
|
#! Point cloud sampling
|
||||||
|
pc_size: &pc_size 30720
|
||||||
|
pc_sharpedge_size: &pc_sharpedge_size 30720
|
||||||
|
sharpedge_label: &sharpedge_label true
|
||||||
|
return_normal: true
|
||||||
|
|
||||||
|
#! Augmentation
|
||||||
|
padding: true
|
||||||
|
|
||||||
|
model:
|
||||||
|
target: hy3dshape.models.diffusion.flow_matching_sit.Diffuser
|
||||||
|
params:
|
||||||
|
first_stage_key: "surface"
|
||||||
|
cond_stage_key: "image"
|
||||||
|
scale_by_std: false
|
||||||
|
z_scale_factor: &z_scale_factor 0.9990943042622529 # 1 / 1.0009065167661184
|
||||||
|
torch_compile: false
|
||||||
|
|
||||||
|
# ema_config:
|
||||||
|
# ema_model: LitEma
|
||||||
|
# ema_decay: 0.999
|
||||||
|
# ema_inference: false
|
||||||
|
|
||||||
|
first_stage_config:
|
||||||
|
target: hy3dshape.models.autoencoders.ShapeVAE
|
||||||
|
from_pretrained: tencent/Hunyuan3D-2.1
|
||||||
|
params:
|
||||||
|
num_latents: &num_latents 512
|
||||||
|
embed_dim: 64
|
||||||
|
num_freqs: 8
|
||||||
|
include_pi: false
|
||||||
|
heads: 16
|
||||||
|
width: 1024
|
||||||
|
point_feats: 4
|
||||||
|
num_decoder_layers: 16
|
||||||
|
pc_size: *pc_size
|
||||||
|
pc_sharpedge_size: *pc_sharpedge_size
|
||||||
|
qkv_bias: false
|
||||||
|
qk_norm: true
|
||||||
|
scale_factor: *z_scale_factor
|
||||||
|
geo_decoder_mlp_expand_ratio: 4
|
||||||
|
geo_decoder_downsample_ratio: 1
|
||||||
|
geo_decoder_ln_post: true
|
||||||
|
|
||||||
|
cond_stage_config:
|
||||||
|
target: hy3dshape.models.conditioner.SingleImageEncoder
|
||||||
|
params:
|
||||||
|
main_image_encoder:
|
||||||
|
type: DinoImageEncoder # dino giant
|
||||||
|
kwargs:
|
||||||
|
config:
|
||||||
|
attention_probs_dropout_prob: 0.0
|
||||||
|
drop_path_rate: 0.0
|
||||||
|
hidden_act: gelu
|
||||||
|
hidden_dropout_prob: 0.0
|
||||||
|
hidden_size: 1536
|
||||||
|
image_size: 518
|
||||||
|
initializer_range: 0.02
|
||||||
|
layer_norm_eps: 1.e-6
|
||||||
|
layerscale_value: 1.0
|
||||||
|
mlp_ratio: 4
|
||||||
|
model_type: dinov2
|
||||||
|
num_attention_heads: 24
|
||||||
|
num_channels: 3
|
||||||
|
num_hidden_layers: 40
|
||||||
|
patch_size: 14
|
||||||
|
qkv_bias: true
|
||||||
|
torch_dtype: float32
|
||||||
|
use_swiglu_ffn: true
|
||||||
|
image_size: 518
|
||||||
|
|
||||||
|
denoiser_cfg:
|
||||||
|
target: hy3dshape.models.denoisers.hunyuan3ddit.Hunyuan3DDiT
|
||||||
|
params:
|
||||||
|
ckpt_path: ~/.cache/hy3dgen/tencent/Hunyuan3D-2-1-Shape/dit/model.fp16.ckpt
|
||||||
|
input_size: *num_latents
|
||||||
|
context_in_dim: 1536
|
||||||
|
hidden_size: 1024
|
||||||
|
mlp_ratio: 4.0
|
||||||
|
num_heads: 16
|
||||||
|
depth: 16
|
||||||
|
depth_single_blocks: 32
|
||||||
|
axes_dim: [64]
|
||||||
|
theta: 10000
|
||||||
|
qkv_bias: true
|
||||||
|
use_pe: false
|
||||||
|
force_norm_fp32: true
|
||||||
|
|
||||||
|
scheduler_cfg:
|
||||||
|
transport:
|
||||||
|
target: hy3dshape.models.diffusion.transport.create_transport
|
||||||
|
params:
|
||||||
|
path_type: Linear
|
||||||
|
prediction: velocity
|
||||||
|
sampler:
|
||||||
|
target: hy3dshape.models.diffusion.transport.Sampler
|
||||||
|
params: {}
|
||||||
|
ode_params:
|
||||||
|
sampling_method: euler # dopri5 ...
|
||||||
|
num_steps: &num_steps 50
|
||||||
|
|
||||||
|
optimizer_cfg:
|
||||||
|
optimizer:
|
||||||
|
target: torch.optim.AdamW
|
||||||
|
params:
|
||||||
|
betas: [0.9, 0.99]
|
||||||
|
eps: 1.e-6
|
||||||
|
weight_decay: 1.e-2
|
||||||
|
|
||||||
|
scheduler:
|
||||||
|
target: hy3dshape.utils.trainings.lr_scheduler.LambdaWarmUpCosineFactorScheduler
|
||||||
|
params:
|
||||||
|
warm_up_steps: 50 # 5000
|
||||||
|
f_start: 1.e-6
|
||||||
|
f_min: 1.e-3
|
||||||
|
f_max: 1.0
|
||||||
|
|
||||||
|
pipeline_cfg:
|
||||||
|
target: hy3dshape.pipelines.Hunyuan3DDiTFlowMatchingPipeline
|
||||||
|
|
||||||
|
image_processor_cfg:
|
||||||
|
target: hy3dshape.preprocessors.ImageProcessorV2
|
||||||
|
params: {}
|
||||||
|
|
||||||
|
callbacks:
|
||||||
|
logger:
|
||||||
|
target: hy3dshape.utils.trainings.mesh_log_callback.ImageConditionalASLDiffuserLogger
|
||||||
|
params:
|
||||||
|
step_frequency: 100 # 10000
|
||||||
|
num_samples: 1
|
||||||
|
sample_times: 1
|
||||||
|
mean: *mean
|
||||||
|
std: *std
|
||||||
|
bounds: [-1.01, -1.01, -1.01, 1.01, 1.01, 1.01]
|
||||||
|
octree_depth: 8
|
||||||
|
num_chunks: 50000
|
||||||
|
mc_level: 0.0
|
||||||
|
|
||||||
|
file_loggers:
|
||||||
|
target: hy3dshape.utils.trainings.mesh_log_callback.ImageConditionalFixASLDiffuserLogger
|
||||||
|
params:
|
||||||
|
step_frequency: 50 # 5000
|
||||||
|
test_data_path: "tools/mini_testset/images.json"
|
||||||
@ -0,0 +1,173 @@
|
|||||||
|
name: "DiT: Flux large flowmatching; VAE: 1024 token length; ImageEncoder: DINO Giant; ImageSize: 518"
|
||||||
|
|
||||||
|
training:
|
||||||
|
steps: 10_0000_0000
|
||||||
|
use_amp: true
|
||||||
|
amp_type: "bf16"
|
||||||
|
base_lr: 1e-4
|
||||||
|
gradient_clip_val: 1.0
|
||||||
|
gradient_clip_algorithm: "norm"
|
||||||
|
every_n_train_steps: 2000 # 5000
|
||||||
|
val_check_interval: 50 # 4096
|
||||||
|
limit_val_batches: 16
|
||||||
|
|
||||||
|
dataset:
|
||||||
|
target: hy3dshape.data.dit_asl.AlignedShapeLatentModule
|
||||||
|
params:
|
||||||
|
#! Base setting
|
||||||
|
batch_size: 2
|
||||||
|
num_workers: 8
|
||||||
|
val_num_workers: 4
|
||||||
|
|
||||||
|
# Data
|
||||||
|
train_data_list: tools/mini_trainset/preprocessed
|
||||||
|
val_data_list: tools/mini_trainset/preprocessed
|
||||||
|
|
||||||
|
#! Image loading
|
||||||
|
cond_stage_key: "image" # image / text / image_text
|
||||||
|
image_size: 518
|
||||||
|
mean: &mean [0.5, 0.5, 0.5]
|
||||||
|
std: &std [0.5, 0.5, 0.5]
|
||||||
|
|
||||||
|
#! Point cloud sampling
|
||||||
|
pc_size: &pc_size 10240
|
||||||
|
pc_sharpedge_size: &pc_sharpedge_size 10240
|
||||||
|
sharpedge_label: &sharpedge_label true
|
||||||
|
return_normal: true
|
||||||
|
|
||||||
|
#! Augmentation
|
||||||
|
padding: true
|
||||||
|
|
||||||
|
model:
|
||||||
|
target: hy3dshape.models.diffusion.flow_matching_sit.Diffuser
|
||||||
|
params:
|
||||||
|
first_stage_key: "surface"
|
||||||
|
cond_stage_key: "image"
|
||||||
|
scale_by_std: false
|
||||||
|
z_scale_factor: &z_scale_factor 0.9990943042622529 # 1 / 1.0009065167661184
|
||||||
|
torch_compile: false
|
||||||
|
|
||||||
|
# ema_config:
|
||||||
|
# ema_model: LitEma
|
||||||
|
# ema_decay: 0.999
|
||||||
|
# ema_inference: false
|
||||||
|
|
||||||
|
first_stage_config:
|
||||||
|
target: hy3dshape.models.autoencoders.ShapeVAE
|
||||||
|
from_pretrained: tencent/Hunyuan3D-2.1
|
||||||
|
params:
|
||||||
|
num_latents: &num_latents 512
|
||||||
|
embed_dim: 64
|
||||||
|
num_freqs: 8
|
||||||
|
include_pi: false
|
||||||
|
heads: 16
|
||||||
|
width: 1024
|
||||||
|
point_feats: 4
|
||||||
|
num_decoder_layers: 16
|
||||||
|
pc_size: *pc_size
|
||||||
|
pc_sharpedge_size: *pc_sharpedge_size
|
||||||
|
qkv_bias: false
|
||||||
|
qk_norm: true
|
||||||
|
scale_factor: *z_scale_factor
|
||||||
|
geo_decoder_mlp_expand_ratio: 4
|
||||||
|
geo_decoder_downsample_ratio: 1
|
||||||
|
geo_decoder_ln_post: true
|
||||||
|
|
||||||
|
cond_stage_config:
|
||||||
|
target: hy3dshape.models.conditioner.SingleImageEncoder
|
||||||
|
params:
|
||||||
|
main_image_encoder:
|
||||||
|
type: DinoImageEncoder # dino giant
|
||||||
|
kwargs:
|
||||||
|
config:
|
||||||
|
attention_probs_dropout_prob: 0.0
|
||||||
|
drop_path_rate: 0.0
|
||||||
|
hidden_act: gelu
|
||||||
|
hidden_dropout_prob: 0.0
|
||||||
|
hidden_size: 1536
|
||||||
|
image_size: 518
|
||||||
|
initializer_range: 0.02
|
||||||
|
layer_norm_eps: 1.e-6
|
||||||
|
layerscale_value: 1.0
|
||||||
|
mlp_ratio: 4
|
||||||
|
model_type: dinov2
|
||||||
|
num_attention_heads: 24
|
||||||
|
num_channels: 3
|
||||||
|
num_hidden_layers: 40
|
||||||
|
patch_size: 14
|
||||||
|
qkv_bias: true
|
||||||
|
torch_dtype: float32
|
||||||
|
use_swiglu_ffn: true
|
||||||
|
image_size: 518
|
||||||
|
|
||||||
|
denoiser_cfg:
|
||||||
|
target: hy3dshape.models.denoisers.hunyuan3ddit.Hunyuan3DDiT
|
||||||
|
params:
|
||||||
|
input_size: *num_latents
|
||||||
|
context_in_dim: 1536
|
||||||
|
hidden_size: 1024
|
||||||
|
mlp_ratio: 4.0
|
||||||
|
num_heads: 16
|
||||||
|
depth: 8
|
||||||
|
depth_single_blocks: 16
|
||||||
|
axes_dim: [64]
|
||||||
|
theta: 10000
|
||||||
|
qkv_bias: true
|
||||||
|
use_pe: false
|
||||||
|
force_norm_fp32: true
|
||||||
|
|
||||||
|
scheduler_cfg:
|
||||||
|
transport:
|
||||||
|
target: hy3dshape.models.diffusion.transport.create_transport
|
||||||
|
params:
|
||||||
|
path_type: Linear
|
||||||
|
prediction: velocity
|
||||||
|
sampler:
|
||||||
|
target: hy3dshape.models.diffusion.transport.Sampler
|
||||||
|
params: {}
|
||||||
|
ode_params:
|
||||||
|
sampling_method: euler # dopri5 ...
|
||||||
|
num_steps: &num_steps 50
|
||||||
|
|
||||||
|
optimizer_cfg:
|
||||||
|
optimizer:
|
||||||
|
target: torch.optim.AdamW
|
||||||
|
params:
|
||||||
|
betas: [0.9, 0.99]
|
||||||
|
eps: 1.e-6
|
||||||
|
weight_decay: 1.e-2
|
||||||
|
|
||||||
|
scheduler:
|
||||||
|
target: hy3dshape.utils.trainings.lr_scheduler.LambdaWarmUpCosineFactorScheduler
|
||||||
|
params:
|
||||||
|
warm_up_steps: 50 # 5000
|
||||||
|
f_start: 1.e-6
|
||||||
|
f_min: 1.e-3
|
||||||
|
f_max: 1.0
|
||||||
|
|
||||||
|
pipeline_cfg:
|
||||||
|
target: hy3dshape.pipelines.Hunyuan3DDiTFlowMatchingPipeline
|
||||||
|
|
||||||
|
image_processor_cfg:
|
||||||
|
target: hy3dshape.preprocessors.ImageProcessorV2
|
||||||
|
params: {}
|
||||||
|
|
||||||
|
callbacks:
|
||||||
|
logger:
|
||||||
|
target: hy3dshape.utils.trainings.mesh_log_callback.ImageConditionalASLDiffuserLogger
|
||||||
|
params:
|
||||||
|
step_frequency: 100 # 10000
|
||||||
|
num_samples: 1
|
||||||
|
sample_times: 1
|
||||||
|
mean: *mean
|
||||||
|
std: *std
|
||||||
|
bounds: [-1.01, -1.01, -1.01, 1.01, 1.01, 1.01]
|
||||||
|
octree_depth: 8
|
||||||
|
num_chunks: 50000
|
||||||
|
mc_level: 0.0
|
||||||
|
|
||||||
|
file_loggers:
|
||||||
|
target: hy3dshape.utils.trainings.mesh_log_callback.ImageConditionalFixASLDiffuserLogger
|
||||||
|
params:
|
||||||
|
step_frequency: 50 # 5000
|
||||||
|
test_data_path: "tools/mini_testset/images.json"
|
||||||
@ -0,0 +1,180 @@
|
|||||||
|
name: "DiT: Flux large flowmatching; VAE: 1024 token length; ImageEncoder: DINO Giant; ImageSize: 518"
|
||||||
|
|
||||||
|
training:
|
||||||
|
steps: 10_0000_0000
|
||||||
|
use_amp: true
|
||||||
|
amp_type: "bf16"
|
||||||
|
base_lr: 1e-5
|
||||||
|
gradient_clip_val: 1.0
|
||||||
|
gradient_clip_algorithm: "norm"
|
||||||
|
every_n_train_steps: 2000 # 5000
|
||||||
|
val_check_interval: 50 # 4096
|
||||||
|
limit_val_batches: 16
|
||||||
|
|
||||||
|
dataset:
|
||||||
|
target: hy3dshape.data.dit_asl.AlignedShapeLatentModule
|
||||||
|
params:
|
||||||
|
#! Base setting
|
||||||
|
batch_size: 4
|
||||||
|
num_workers: 8
|
||||||
|
val_num_workers: 4
|
||||||
|
|
||||||
|
# Data
|
||||||
|
train_data_list: tools/mini_trainset/preprocessed
|
||||||
|
val_data_list: tools/mini_trainset/preprocessed
|
||||||
|
|
||||||
|
#! Image loading
|
||||||
|
cond_stage_key: "image" # image / text / image_text
|
||||||
|
image_size: 518
|
||||||
|
mean: &mean [0.5, 0.5, 0.5]
|
||||||
|
std: &std [0.5, 0.5, 0.5]
|
||||||
|
|
||||||
|
#! Point cloud sampling
|
||||||
|
pc_size: &pc_size 81920
|
||||||
|
pc_sharpedge_size: &pc_sharpedge_size 0
|
||||||
|
sharpedge_label: &sharpedge_label true
|
||||||
|
return_normal: true
|
||||||
|
|
||||||
|
#! Augmentation
|
||||||
|
padding: true
|
||||||
|
|
||||||
|
model:
|
||||||
|
target: hy3dshape.models.diffusion.flow_matching_sit.Diffuser
|
||||||
|
params:
|
||||||
|
first_stage_key: "surface"
|
||||||
|
cond_stage_key: "image"
|
||||||
|
scale_by_std: false
|
||||||
|
z_scale_factor: &z_scale_factor 1.0039506158752403
|
||||||
|
torch_compile: false
|
||||||
|
|
||||||
|
# ema_config:
|
||||||
|
# ema_model: LitEma
|
||||||
|
# ema_decay: 0.999
|
||||||
|
# ema_inference: false
|
||||||
|
|
||||||
|
first_stage_config:
|
||||||
|
target: hy3dshape.models.autoencoders.ShapeVAE
|
||||||
|
from_pretrained: tencent/Hunyuan3D-2.1
|
||||||
|
params:
|
||||||
|
num_latents: &num_latents 4096
|
||||||
|
embed_dim: 64
|
||||||
|
num_freqs: 8
|
||||||
|
include_pi: false
|
||||||
|
heads: 16
|
||||||
|
width: 1024
|
||||||
|
num_encoder_layers: 8
|
||||||
|
num_decoder_layers: 16
|
||||||
|
qkv_bias: false
|
||||||
|
qk_norm: true
|
||||||
|
scale_factor: *z_scale_factor
|
||||||
|
geo_decoder_mlp_expand_ratio: 4
|
||||||
|
geo_decoder_downsample_ratio: 1
|
||||||
|
geo_decoder_ln_post: true
|
||||||
|
point_feats: 4
|
||||||
|
pc_size: *pc_size
|
||||||
|
pc_sharpedge_size: *pc_sharpedge_size
|
||||||
|
|
||||||
|
cond_stage_config:
|
||||||
|
target: hy3dshape.models.conditioner.SingleImageEncoder
|
||||||
|
params:
|
||||||
|
main_image_encoder:
|
||||||
|
type: DinoImageEncoder # dino large
|
||||||
|
kwargs:
|
||||||
|
config:
|
||||||
|
attention_probs_dropout_prob: 0.0
|
||||||
|
drop_path_rate: 0.0
|
||||||
|
hidden_act: gelu
|
||||||
|
hidden_dropout_prob: 0.0
|
||||||
|
hidden_size: 1024
|
||||||
|
image_size: 518
|
||||||
|
initializer_range: 0.02
|
||||||
|
layer_norm_eps: 1.e-6
|
||||||
|
layerscale_value: 1.0
|
||||||
|
mlp_ratio: 4
|
||||||
|
model_type: dinov2
|
||||||
|
num_attention_heads: 16
|
||||||
|
num_channels: 3
|
||||||
|
num_hidden_layers: 24
|
||||||
|
patch_size: 14
|
||||||
|
qkv_bias: true
|
||||||
|
torch_dtype: float32
|
||||||
|
use_swiglu_ffn: false
|
||||||
|
image_size: 518
|
||||||
|
use_cls_token: true
|
||||||
|
|
||||||
|
|
||||||
|
denoiser_cfg:
|
||||||
|
target: hy3dshape.models.denoisers.hunyuandit.HunYuanDiTPlain
|
||||||
|
params:
|
||||||
|
input_size: *num_latents
|
||||||
|
in_channels: 64
|
||||||
|
hidden_size: 2048
|
||||||
|
context_dim: 1024
|
||||||
|
depth: 21
|
||||||
|
num_heads: 16
|
||||||
|
qk_norm: true
|
||||||
|
text_len: 1370
|
||||||
|
with_decoupled_ca: false
|
||||||
|
use_attention_pooling: false
|
||||||
|
qk_norm_type: 'rms'
|
||||||
|
qkv_bias: false
|
||||||
|
use_pos_emb: false
|
||||||
|
num_moe_layers: 6
|
||||||
|
num_experts: 8
|
||||||
|
moe_top_k: 2
|
||||||
|
|
||||||
|
scheduler_cfg:
|
||||||
|
transport:
|
||||||
|
target: hy3dshape.models.diffusion.transport.create_transport
|
||||||
|
params:
|
||||||
|
path_type: Linear
|
||||||
|
prediction: velocity
|
||||||
|
sampler:
|
||||||
|
target: hy3dshape.models.diffusion.transport.Sampler
|
||||||
|
params: {}
|
||||||
|
ode_params:
|
||||||
|
sampling_method: euler # dopri5 ...
|
||||||
|
num_steps: &num_steps 50
|
||||||
|
|
||||||
|
optimizer_cfg:
|
||||||
|
optimizer:
|
||||||
|
target: torch.optim.AdamW
|
||||||
|
params:
|
||||||
|
betas: [0.9, 0.99]
|
||||||
|
eps: 1.e-6
|
||||||
|
weight_decay: 1.e-2
|
||||||
|
|
||||||
|
scheduler:
|
||||||
|
target: hy3dshape.utils.trainings.lr_scheduler.LambdaWarmUpCosineFactorScheduler
|
||||||
|
params:
|
||||||
|
warm_up_steps: 50 # 5000
|
||||||
|
f_start: 1.e-6
|
||||||
|
f_min: 1.e-3
|
||||||
|
f_max: 1.0
|
||||||
|
|
||||||
|
pipeline_cfg:
|
||||||
|
target: hy3dshape.pipelines.Hunyuan3DDiTFlowMatchingPipeline
|
||||||
|
|
||||||
|
image_processor_cfg:
|
||||||
|
target: hy3dshape.preprocessors.ImageProcessorV2
|
||||||
|
params: {}
|
||||||
|
|
||||||
|
callbacks:
|
||||||
|
logger:
|
||||||
|
target: hy3dshape.utils.trainings.mesh_log_callback.ImageConditionalASLDiffuserLogger
|
||||||
|
params:
|
||||||
|
step_frequency: 100 # 10000
|
||||||
|
num_samples: 1
|
||||||
|
sample_times: 1
|
||||||
|
mean: *mean
|
||||||
|
std: *std
|
||||||
|
bounds: [-1.01, -1.01, -1.01, 1.01, 1.01, 1.01]
|
||||||
|
octree_depth: 8
|
||||||
|
num_chunks: 50000
|
||||||
|
mc_level: 0.0
|
||||||
|
|
||||||
|
file_loggers:
|
||||||
|
target: hy3dshape.utils.trainings.mesh_log_callback.ImageConditionalFixASLDiffuserLogger
|
||||||
|
params:
|
||||||
|
step_frequency: 50 # 5000
|
||||||
|
test_data_path: "tools/mini_testset/images.json"
|
||||||
@ -0,0 +1,180 @@
|
|||||||
|
name: "DiT: Flux large flowmatching; VAE: 1024 token length; ImageEncoder: DINO Giant; ImageSize: 518"
|
||||||
|
|
||||||
|
training:
|
||||||
|
steps: 10_0000_0000
|
||||||
|
use_amp: true
|
||||||
|
amp_type: "bf16"
|
||||||
|
base_lr: 1e-4
|
||||||
|
gradient_clip_val: 1.0
|
||||||
|
gradient_clip_algorithm: "norm"
|
||||||
|
every_n_train_steps: 2000 # 5000
|
||||||
|
val_check_interval: 50 # 4096
|
||||||
|
limit_val_batches: 16
|
||||||
|
|
||||||
|
dataset:
|
||||||
|
target: hy3dshape.data.dit_asl.AlignedShapeLatentModule
|
||||||
|
params:
|
||||||
|
#! Base setting
|
||||||
|
batch_size: 2
|
||||||
|
num_workers: 8
|
||||||
|
val_num_workers: 4
|
||||||
|
|
||||||
|
# Data
|
||||||
|
train_data_list: tools/mini_trainset/preprocessed
|
||||||
|
val_data_list: tools/mini_trainset/preprocessed
|
||||||
|
|
||||||
|
#! Image loading
|
||||||
|
cond_stage_key: "image" # image / text / image_text
|
||||||
|
image_size: 518
|
||||||
|
mean: &mean [0.5, 0.5, 0.5]
|
||||||
|
std: &std [0.5, 0.5, 0.5]
|
||||||
|
|
||||||
|
#! Point cloud sampling
|
||||||
|
pc_size: &pc_size 81920
|
||||||
|
pc_sharpedge_size: &pc_sharpedge_size 0
|
||||||
|
sharpedge_label: &sharpedge_label true
|
||||||
|
return_normal: true
|
||||||
|
|
||||||
|
#! Augmentation
|
||||||
|
padding: true
|
||||||
|
|
||||||
|
model:
|
||||||
|
target: hy3dshape.models.diffusion.flow_matching_sit.Diffuser
|
||||||
|
params:
|
||||||
|
first_stage_key: "surface"
|
||||||
|
cond_stage_key: "image"
|
||||||
|
scale_by_std: false
|
||||||
|
z_scale_factor: &z_scale_factor 1.0039506158752403
|
||||||
|
torch_compile: false
|
||||||
|
|
||||||
|
# ema_config:
|
||||||
|
# ema_model: LitEma
|
||||||
|
# ema_decay: 0.999
|
||||||
|
# ema_inference: false
|
||||||
|
|
||||||
|
first_stage_config:
|
||||||
|
target: hy3dshape.models.autoencoders.ShapeVAE
|
||||||
|
from_pretrained: tencent/Hunyuan3D-2.1
|
||||||
|
params:
|
||||||
|
num_latents: &num_latents 4096
|
||||||
|
embed_dim: 64
|
||||||
|
num_freqs: 8
|
||||||
|
include_pi: false
|
||||||
|
heads: 16
|
||||||
|
width: 1024
|
||||||
|
num_encoder_layers: 8
|
||||||
|
num_decoder_layers: 16
|
||||||
|
qkv_bias: false
|
||||||
|
qk_norm: true
|
||||||
|
scale_factor: *z_scale_factor
|
||||||
|
geo_decoder_mlp_expand_ratio: 4
|
||||||
|
geo_decoder_downsample_ratio: 1
|
||||||
|
geo_decoder_ln_post: true
|
||||||
|
point_feats: 4
|
||||||
|
pc_size: *pc_size
|
||||||
|
pc_sharpedge_size: *pc_sharpedge_size
|
||||||
|
|
||||||
|
cond_stage_config:
|
||||||
|
target: hy3dshape.models.conditioner.SingleImageEncoder
|
||||||
|
params:
|
||||||
|
main_image_encoder:
|
||||||
|
type: DinoImageEncoder # dino large
|
||||||
|
kwargs:
|
||||||
|
config:
|
||||||
|
attention_probs_dropout_prob: 0.0
|
||||||
|
drop_path_rate: 0.0
|
||||||
|
hidden_act: gelu
|
||||||
|
hidden_dropout_prob: 0.0
|
||||||
|
hidden_size: 1024
|
||||||
|
image_size: 518
|
||||||
|
initializer_range: 0.02
|
||||||
|
layer_norm_eps: 1.e-6
|
||||||
|
layerscale_value: 1.0
|
||||||
|
mlp_ratio: 4
|
||||||
|
model_type: dinov2
|
||||||
|
num_attention_heads: 16
|
||||||
|
num_channels: 3
|
||||||
|
num_hidden_layers: 24
|
||||||
|
patch_size: 14
|
||||||
|
qkv_bias: true
|
||||||
|
torch_dtype: float32
|
||||||
|
use_swiglu_ffn: false
|
||||||
|
image_size: 518
|
||||||
|
use_cls_token: true
|
||||||
|
|
||||||
|
|
||||||
|
denoiser_cfg:
|
||||||
|
target: hy3dshape.models.denoisers.hunyuandit.HunYuanDiTPlain
|
||||||
|
params:
|
||||||
|
input_size: *num_latents
|
||||||
|
in_channels: 64
|
||||||
|
hidden_size: 2048
|
||||||
|
context_dim: 1024
|
||||||
|
depth: 11
|
||||||
|
num_heads: 16
|
||||||
|
qk_norm: true
|
||||||
|
text_len: 1370
|
||||||
|
with_decoupled_ca: false
|
||||||
|
use_attention_pooling: false
|
||||||
|
qk_norm_type: 'rms'
|
||||||
|
qkv_bias: false
|
||||||
|
use_pos_emb: false
|
||||||
|
num_moe_layers: 6
|
||||||
|
num_experts: 8
|
||||||
|
moe_top_k: 2
|
||||||
|
|
||||||
|
scheduler_cfg:
|
||||||
|
transport:
|
||||||
|
target: hy3dshape.models.diffusion.transport.create_transport
|
||||||
|
params:
|
||||||
|
path_type: Linear
|
||||||
|
prediction: velocity
|
||||||
|
sampler:
|
||||||
|
target: hy3dshape.models.diffusion.transport.Sampler
|
||||||
|
params: {}
|
||||||
|
ode_params:
|
||||||
|
sampling_method: euler # dopri5 ...
|
||||||
|
num_steps: &num_steps 50
|
||||||
|
|
||||||
|
optimizer_cfg:
|
||||||
|
optimizer:
|
||||||
|
target: torch.optim.AdamW
|
||||||
|
params:
|
||||||
|
betas: [0.9, 0.99]
|
||||||
|
eps: 1.e-6
|
||||||
|
weight_decay: 1.e-2
|
||||||
|
|
||||||
|
scheduler:
|
||||||
|
target: hy3dshape.utils.trainings.lr_scheduler.LambdaWarmUpCosineFactorScheduler
|
||||||
|
params:
|
||||||
|
warm_up_steps: 50 # 5000
|
||||||
|
f_start: 1.e-6
|
||||||
|
f_min: 1.e-3
|
||||||
|
f_max: 1.0
|
||||||
|
|
||||||
|
pipeline_cfg:
|
||||||
|
target: hy3dshape.pipelines.Hunyuan3DDiTFlowMatchingPipeline
|
||||||
|
|
||||||
|
image_processor_cfg:
|
||||||
|
target: hy3dshape.preprocessors.ImageProcessorV2
|
||||||
|
params: {}
|
||||||
|
|
||||||
|
callbacks:
|
||||||
|
logger:
|
||||||
|
target: hy3dshape.utils.trainings.mesh_log_callback.ImageConditionalASLDiffuserLogger
|
||||||
|
params:
|
||||||
|
step_frequency: 100 # 10000
|
||||||
|
num_samples: 1
|
||||||
|
sample_times: 1
|
||||||
|
mean: *mean
|
||||||
|
std: *std
|
||||||
|
bounds: [-1.01, -1.01, -1.01, 1.01, 1.01, 1.01]
|
||||||
|
octree_depth: 8
|
||||||
|
num_chunks: 50000
|
||||||
|
mc_level: 0.0
|
||||||
|
|
||||||
|
file_loggers:
|
||||||
|
target: hy3dshape.utils.trainings.mesh_log_callback.ImageConditionalFixASLDiffuserLogger
|
||||||
|
params:
|
||||||
|
step_frequency: 50 # 5000
|
||||||
|
test_data_path: "tools/mini_testset/images.json"
|
||||||
@ -0,0 +1,180 @@
|
|||||||
|
name: "DiT: Flux large flowmatching; VAE: 1024 token length; ImageEncoder: DINO Giant; ImageSize: 518"
|
||||||
|
|
||||||
|
training:
|
||||||
|
steps: 10_0000_0000
|
||||||
|
use_amp: true
|
||||||
|
amp_type: "bf16"
|
||||||
|
base_lr: 1e-4
|
||||||
|
gradient_clip_val: 1.0
|
||||||
|
gradient_clip_algorithm: "norm"
|
||||||
|
every_n_train_steps: 2000 # 5000
|
||||||
|
val_check_interval: 50 # 4096
|
||||||
|
limit_val_batches: 16
|
||||||
|
|
||||||
|
dataset:
|
||||||
|
target: hy3dshape.data.dit_asl.AlignedShapeLatentModule
|
||||||
|
params:
|
||||||
|
#! Base setting
|
||||||
|
batch_size: 2
|
||||||
|
num_workers: 8
|
||||||
|
val_num_workers: 4
|
||||||
|
|
||||||
|
# Data
|
||||||
|
train_data_list: tools/mini_trainset/preprocessed
|
||||||
|
val_data_list: tools/mini_trainset/preprocessed
|
||||||
|
|
||||||
|
#! Image loading
|
||||||
|
cond_stage_key: "image" # image / text / image_text
|
||||||
|
image_size: 518
|
||||||
|
mean: &mean [0.5, 0.5, 0.5]
|
||||||
|
std: &std [0.5, 0.5, 0.5]
|
||||||
|
|
||||||
|
#! Point cloud sampling
|
||||||
|
pc_size: &pc_size 81920
|
||||||
|
pc_sharpedge_size: &pc_sharpedge_size 0
|
||||||
|
sharpedge_label: &sharpedge_label true
|
||||||
|
return_normal: true
|
||||||
|
|
||||||
|
#! Augmentation
|
||||||
|
padding: true
|
||||||
|
|
||||||
|
model:
|
||||||
|
target: hy3dshape.models.diffusion.flow_matching_sit.Diffuser
|
||||||
|
params:
|
||||||
|
first_stage_key: "surface"
|
||||||
|
cond_stage_key: "image"
|
||||||
|
scale_by_std: false
|
||||||
|
z_scale_factor: &z_scale_factor 1.0039506158752403
|
||||||
|
torch_compile: false
|
||||||
|
|
||||||
|
# ema_config:
|
||||||
|
# ema_model: LitEma
|
||||||
|
# ema_decay: 0.999
|
||||||
|
# ema_inference: false
|
||||||
|
|
||||||
|
first_stage_config:
|
||||||
|
target: hy3dshape.models.autoencoders.ShapeVAE
|
||||||
|
from_pretrained: tencent/Hunyuan3D-2.1
|
||||||
|
params:
|
||||||
|
num_latents: &num_latents 512
|
||||||
|
embed_dim: 64
|
||||||
|
num_freqs: 8
|
||||||
|
include_pi: false
|
||||||
|
heads: 16
|
||||||
|
width: 1024
|
||||||
|
num_encoder_layers: 8
|
||||||
|
num_decoder_layers: 16
|
||||||
|
qkv_bias: false
|
||||||
|
qk_norm: true
|
||||||
|
scale_factor: *z_scale_factor
|
||||||
|
geo_decoder_mlp_expand_ratio: 4
|
||||||
|
geo_decoder_downsample_ratio: 1
|
||||||
|
geo_decoder_ln_post: true
|
||||||
|
point_feats: 4
|
||||||
|
pc_size: *pc_size
|
||||||
|
pc_sharpedge_size: *pc_sharpedge_size
|
||||||
|
|
||||||
|
cond_stage_config:
|
||||||
|
target: hy3dshape.models.conditioner.SingleImageEncoder
|
||||||
|
params:
|
||||||
|
main_image_encoder:
|
||||||
|
type: DinoImageEncoder # dino large
|
||||||
|
kwargs:
|
||||||
|
config:
|
||||||
|
attention_probs_dropout_prob: 0.0
|
||||||
|
drop_path_rate: 0.0
|
||||||
|
hidden_act: gelu
|
||||||
|
hidden_dropout_prob: 0.0
|
||||||
|
hidden_size: 1024
|
||||||
|
image_size: 518
|
||||||
|
initializer_range: 0.02
|
||||||
|
layer_norm_eps: 1.e-6
|
||||||
|
layerscale_value: 1.0
|
||||||
|
mlp_ratio: 4
|
||||||
|
model_type: dinov2
|
||||||
|
num_attention_heads: 16
|
||||||
|
num_channels: 3
|
||||||
|
num_hidden_layers: 24
|
||||||
|
patch_size: 14
|
||||||
|
qkv_bias: true
|
||||||
|
torch_dtype: float32
|
||||||
|
use_swiglu_ffn: false
|
||||||
|
image_size: 518
|
||||||
|
use_cls_token: true
|
||||||
|
|
||||||
|
|
||||||
|
denoiser_cfg:
|
||||||
|
target: hy3dshape.models.denoisers.hunyuandit.HunYuanDiTPlain
|
||||||
|
params:
|
||||||
|
input_size: *num_latents
|
||||||
|
in_channels: 64
|
||||||
|
hidden_size: 768
|
||||||
|
context_dim: 1024
|
||||||
|
depth: 6
|
||||||
|
num_heads: 12
|
||||||
|
qk_norm: true
|
||||||
|
text_len: 1370
|
||||||
|
with_decoupled_ca: false
|
||||||
|
use_attention_pooling: false
|
||||||
|
qk_norm_type: 'rms'
|
||||||
|
qkv_bias: false
|
||||||
|
use_pos_emb: false
|
||||||
|
num_moe_layers: 3
|
||||||
|
num_experts: 4
|
||||||
|
moe_top_k: 2
|
||||||
|
|
||||||
|
scheduler_cfg:
|
||||||
|
transport:
|
||||||
|
target: hy3dshape.models.diffusion.transport.create_transport
|
||||||
|
params:
|
||||||
|
path_type: Linear
|
||||||
|
prediction: velocity
|
||||||
|
sampler:
|
||||||
|
target: hy3dshape.models.diffusion.transport.Sampler
|
||||||
|
params: {}
|
||||||
|
ode_params:
|
||||||
|
sampling_method: euler # dopri5 ...
|
||||||
|
num_steps: &num_steps 50
|
||||||
|
|
||||||
|
optimizer_cfg:
|
||||||
|
optimizer:
|
||||||
|
target: torch.optim.AdamW
|
||||||
|
params:
|
||||||
|
betas: [0.9, 0.99]
|
||||||
|
eps: 1.e-6
|
||||||
|
weight_decay: 1.e-2
|
||||||
|
|
||||||
|
scheduler:
|
||||||
|
target: hy3dshape.utils.trainings.lr_scheduler.LambdaWarmUpCosineFactorScheduler
|
||||||
|
params:
|
||||||
|
warm_up_steps: 50 # 5000
|
||||||
|
f_start: 1.e-6
|
||||||
|
f_min: 1.e-3
|
||||||
|
f_max: 1.0
|
||||||
|
|
||||||
|
pipeline_cfg:
|
||||||
|
target: hy3dshape.pipelines.Hunyuan3DDiTFlowMatchingPipeline
|
||||||
|
|
||||||
|
image_processor_cfg:
|
||||||
|
target: hy3dshape.preprocessors.ImageProcessorV2
|
||||||
|
params: {}
|
||||||
|
|
||||||
|
callbacks:
|
||||||
|
logger:
|
||||||
|
target: hy3dshape.utils.trainings.mesh_log_callback.ImageConditionalASLDiffuserLogger
|
||||||
|
params:
|
||||||
|
step_frequency: 100 # 10000
|
||||||
|
num_samples: 1
|
||||||
|
sample_times: 1
|
||||||
|
mean: *mean
|
||||||
|
std: *std
|
||||||
|
bounds: [-1.01, -1.01, -1.01, 1.01, 1.01, 1.01]
|
||||||
|
octree_depth: 8
|
||||||
|
num_chunks: 50000
|
||||||
|
mc_level: 0.0
|
||||||
|
|
||||||
|
file_loggers:
|
||||||
|
target: hy3dshape.utils.trainings.mesh_log_callback.ImageConditionalFixASLDiffuserLogger
|
||||||
|
params:
|
||||||
|
step_frequency: 50 # 5000
|
||||||
|
test_data_path: "tools/mini_testset/images.json"
|
||||||
17
hy3dshape/hy3dshape/__init__.py
Normal file
17
hy3dshape/hy3dshape/__init__.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
||||||
|
# except for the third-party components listed below.
|
||||||
|
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
||||||
|
# in the repsective licenses of these third-party components.
|
||||||
|
# Users must comply with all terms and conditions of original licenses of these third-party
|
||||||
|
# components and must ensure that the usage of the third party components adheres to
|
||||||
|
# all relevant laws and regulations.
|
||||||
|
|
||||||
|
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
||||||
|
# their software and algorithms, including trained model weights, parameters (including
|
||||||
|
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
||||||
|
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
||||||
|
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
||||||
|
|
||||||
|
from .pipelines import Hunyuan3DDiTPipeline, Hunyuan3DDiTFlowMatchingPipeline
|
||||||
|
from .postprocessors import FaceReducer, FloaterRemover, DegenerateFaceRemover, MeshSimplifier
|
||||||
|
from .preprocessors import ImageProcessorV2, IMAGE_PROCESSORS, DEFAULT_IMAGEPROCESSOR
|
||||||
384
hy3dshape/hy3dshape/data/dit_asl.py
Normal file
384
hy3dshape/hy3dshape/data/dit_asl.py
Normal file
@ -0,0 +1,384 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
||||||
|
# except for the third-party components listed below.
|
||||||
|
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
||||||
|
# in the repsective licenses of these third-party components.
|
||||||
|
# Users must comply with all terms and conditions of original licenses of these third-party
|
||||||
|
# components and must ensure that the usage of the third party components adheres to
|
||||||
|
# all relevant laws and regulations.
|
||||||
|
|
||||||
|
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
||||||
|
# their software and algorithms, including trained model weights, parameters (including
|
||||||
|
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
||||||
|
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
||||||
|
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
||||||
|
|
||||||
|
|
||||||
|
import os
|
||||||
|
import io
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import random
|
||||||
|
import traceback
|
||||||
|
from typing import Optional, Union, List, Tuple, Dict
|
||||||
|
|
||||||
|
import json
|
||||||
|
import glob
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import trimesh
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torchvision.transforms as transforms
|
||||||
|
from pytorch_lightning import LightningDataModule
|
||||||
|
from pytorch_lightning.utilities import rank_zero_info
|
||||||
|
|
||||||
|
from .utils import worker_init_fn, pytorch_worker_seed, make_seed
|
||||||
|
|
||||||
|
|
||||||
|
class ResampledShards(torch.utils.data.dataset.IterableDataset):
|
||||||
|
def __init__(self, datalist, nshards=sys.maxsize, worker_seed=None, deterministic=False):
|
||||||
|
super().__init__()
|
||||||
|
self.datalist = datalist
|
||||||
|
self.nshards = nshards
|
||||||
|
# If no worker_seed provided, use pytorch_worker_seed function; else use given seed
|
||||||
|
self.worker_seed = pytorch_worker_seed if worker_seed is None else worker_seed
|
||||||
|
self.deterministic = deterministic
|
||||||
|
self.epoch = -1
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
self.epoch += 1
|
||||||
|
if self.deterministic:
|
||||||
|
seed = make_seed(self.worker_seed(), self.epoch)
|
||||||
|
else:
|
||||||
|
seed = make_seed(self.worker_seed(), self.epoch,
|
||||||
|
os.getpid(), time.time_ns(), os.urandom(4))
|
||||||
|
self.rng = random.Random(seed)
|
||||||
|
for _ in range(self.nshards):
|
||||||
|
index = self.rng.randint(0, len(self.datalist) - 1)
|
||||||
|
yield self.datalist[index]
|
||||||
|
|
||||||
|
|
||||||
|
def read_npz(data):
|
||||||
|
# Load a numpy .npz file from a file path or file-like object
|
||||||
|
# The commented line shows how to load from bytes in memory
|
||||||
|
# return np.load(io.BytesIO(data))
|
||||||
|
return np.load(data)
|
||||||
|
|
||||||
|
|
||||||
|
def read_json(path):
|
||||||
|
# Read and parse a JSON file from the given file path
|
||||||
|
with open(path, 'r', encoding='utf-8') as file:
|
||||||
|
data = json.load(file)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def padding(image, mask, center=True, padding_ratio_range=[1.15, 1.15]):
|
||||||
|
"""
|
||||||
|
Pad the input image and mask to a square shape with padding ratio.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (np.ndarray): Input image array of shape (H, W, C).
|
||||||
|
mask (np.ndarray): Corresponding mask array of shape (H, W).
|
||||||
|
center (bool): Whether to center the original image in the padded output.
|
||||||
|
padding_ratio_range (list): Range [min, max] to randomly select padding ratio.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
newimg (np.ndarray): Padded image of shape (resize_side, resize_side, 3).
|
||||||
|
newmask (np.ndarray): Padded mask of shape (resize_side, resize_side).
|
||||||
|
"""
|
||||||
|
h, w = image.shape[:2]
|
||||||
|
max_side = max(h, w)
|
||||||
|
|
||||||
|
# Select padding ratio either fixed or randomly within the given range
|
||||||
|
if padding_ratio_range[0] == padding_ratio_range[1]:
|
||||||
|
padding_ratio = padding_ratio_range[0]
|
||||||
|
else:
|
||||||
|
padding_ratio = random.uniform(padding_ratio_range[0], padding_ratio_range[1])
|
||||||
|
resize_side = int(max_side * padding_ratio)
|
||||||
|
# resize_side = int(max_side * 1.15)
|
||||||
|
|
||||||
|
pad_h = resize_side - h
|
||||||
|
pad_w = resize_side - w
|
||||||
|
if center:
|
||||||
|
start_h = pad_h // 2
|
||||||
|
else:
|
||||||
|
start_h = pad_h - resize_side // 20
|
||||||
|
|
||||||
|
start_w = pad_w // 2
|
||||||
|
|
||||||
|
# Create new white image and black mask with padded size
|
||||||
|
newimg = np.ones((resize_side, resize_side, 3), dtype=np.uint8) * 255
|
||||||
|
newmask = np.zeros((resize_side, resize_side), dtype=np.uint8)
|
||||||
|
|
||||||
|
# Place original image and mask into the padded canvas
|
||||||
|
newimg[start_h:start_h + h, start_w:start_w + w] = image
|
||||||
|
newmask[start_h:start_h + h, start_w:start_w + w] = mask
|
||||||
|
|
||||||
|
return newimg, newmask
|
||||||
|
|
||||||
|
|
||||||
|
def viz_pc(surface, normal, image_input, name):
|
||||||
|
image_input = image_input.cpu().numpy()
|
||||||
|
image_input = image_input.transpose(1, 2, 0) * 0.5 + 0.5
|
||||||
|
image_input = (image_input * 255).astype(np.uint8)
|
||||||
|
cv2.imwrite(name + '.png', cv2.cvtColor(image_input, cv2.COLOR_RGB2BGR))
|
||||||
|
surface = surface.cpu().numpy()
|
||||||
|
normal = normal.cpu().numpy()
|
||||||
|
surface_mesh = trimesh.Trimesh(surface, vertex_colors=(normal + 1) / 2)
|
||||||
|
surface_mesh.export(name + '.obj')
|
||||||
|
|
||||||
|
|
||||||
|
class AlignedShapeLatentDataset(torch.utils.data.dataset.IterableDataset):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
data_list: str = None,
|
||||||
|
cond_stage_key: str = "image",
|
||||||
|
image_transform = None,
|
||||||
|
pc_size: int = 2048,
|
||||||
|
pc_sharpedge_size: int = 2048,
|
||||||
|
sharpedge_label: bool = False,
|
||||||
|
return_normal: bool = False,
|
||||||
|
deterministic = False,
|
||||||
|
worker_seed = None,
|
||||||
|
padding = True,
|
||||||
|
padding_ratio_range=[1.15, 1.15]
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
if isinstance(data_list, str) and data_list.endswith('.json'):
|
||||||
|
self.data_list = read_json(data_list_json)
|
||||||
|
elif isinstance(data_list, str) and os.path.isdir(data_list):
|
||||||
|
self.data_list = glob.glob(data_list + '/*')
|
||||||
|
else:
|
||||||
|
self.data_list = data_list
|
||||||
|
assert isinstance(self.data_list, list)
|
||||||
|
self.rng = random.Random(0)
|
||||||
|
|
||||||
|
self.cond_stage_key = cond_stage_key
|
||||||
|
self.image_transform = image_transform
|
||||||
|
|
||||||
|
self.pc_size = pc_size
|
||||||
|
self.pc_sharpedge_size = pc_sharpedge_size
|
||||||
|
self.sharpedge_label = sharpedge_label
|
||||||
|
self.return_normal = return_normal
|
||||||
|
|
||||||
|
self.padding = padding
|
||||||
|
self.padding_ratio_range = padding_ratio_range
|
||||||
|
|
||||||
|
rank_zero_info(f'*' * 50)
|
||||||
|
rank_zero_info(f'Dataset Infos:')
|
||||||
|
rank_zero_info(f'# of 3D file: {len(self.data_list)}')
|
||||||
|
rank_zero_info(f'# of Surface Points: {self.pc_size}')
|
||||||
|
rank_zero_info(f'# of Sharpedge Surface Points: {self.pc_sharpedge_size}')
|
||||||
|
rank_zero_info(f'Using sharp edge label: {self.sharpedge_label}')
|
||||||
|
rank_zero_info(f'*' * 50)
|
||||||
|
|
||||||
|
|
||||||
|
def load_surface_sdf_points(self, rng, random_surface, sharpedge_surface):
|
||||||
|
surface_normal = []
|
||||||
|
if self.pc_size > 0:
|
||||||
|
ind = rng.choice(random_surface.shape[0], self.pc_size, replace=False)
|
||||||
|
random_surface = random_surface[ind]
|
||||||
|
if self.sharpedge_label:
|
||||||
|
sharpedge_label = np.zeros((self.pc_size, 1))
|
||||||
|
random_surface = np.concatenate((random_surface, sharpedge_label), axis=1)
|
||||||
|
surface_normal.append(random_surface)
|
||||||
|
|
||||||
|
if self.pc_sharpedge_size > 0:
|
||||||
|
ind_sharpedge = rng.choice(sharpedge_surface.shape[0], self.pc_sharpedge_size, replace=False)
|
||||||
|
sharpedge_surface = sharpedge_surface[ind_sharpedge]
|
||||||
|
if self.sharpedge_label:
|
||||||
|
sharpedge_label = np.ones((self.pc_sharpedge_size, 1))
|
||||||
|
sharpedge_surface = np.concatenate((sharpedge_surface, sharpedge_label), axis=1)
|
||||||
|
surface_normal.append(sharpedge_surface)
|
||||||
|
|
||||||
|
surface_normal = np.concatenate(surface_normal, axis=0)
|
||||||
|
surface_normal = torch.FloatTensor(surface_normal)
|
||||||
|
surface = surface_normal[:, 0:3]
|
||||||
|
normal = surface_normal[:, 3:6]
|
||||||
|
assert surface.shape[0] == self.pc_size + self.pc_sharpedge_size
|
||||||
|
|
||||||
|
geo_points = 0.0
|
||||||
|
normal = torch.nn.functional.normalize(normal, p=2, dim=1)
|
||||||
|
if self.return_normal:
|
||||||
|
surface = torch.cat([surface, normal], dim=-1)
|
||||||
|
if self.sharpedge_label:
|
||||||
|
surface = torch.cat([surface, surface_normal[:, -1:]], dim=-1)
|
||||||
|
return surface, geo_points
|
||||||
|
|
||||||
|
def load_render(self, imgs_path):
|
||||||
|
imgs_choice = self.rng.sample(imgs_path, 1)
|
||||||
|
images, masks = [], []
|
||||||
|
for image_path in imgs_choice:
|
||||||
|
image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
|
||||||
|
assert image.shape[2] == 4
|
||||||
|
alpha = image[:, :, 3:4].astype(np.float32) / 255
|
||||||
|
forground = image[:, :, :3]
|
||||||
|
background = np.ones_like(forground) * 255
|
||||||
|
img_new = forground * alpha + background * (1 - alpha)
|
||||||
|
image = img_new.astype(np.uint8)
|
||||||
|
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||||
|
mask = (alpha[:, :, 0] * 255).astype(np.uint8)
|
||||||
|
|
||||||
|
if self.padding:
|
||||||
|
h, w = image.shape[:2]
|
||||||
|
binary = mask > 0.3
|
||||||
|
non_zero_coords = np.argwhere(binary)
|
||||||
|
x_min, y_min = non_zero_coords.min(axis=0)
|
||||||
|
x_max, y_max = non_zero_coords.max(axis=0)
|
||||||
|
image, mask = padding(
|
||||||
|
image[max(x_min - 5, 0):min(x_max + 5, h), max(y_min - 5, 0):min(y_max + 5, w)],
|
||||||
|
mask[max(x_min - 5, 0):min(x_max + 5, h), max(y_min - 5, 0):min(y_max + 5, w)],
|
||||||
|
center=True, padding_ratio_range=self.padding_ratio_range)
|
||||||
|
|
||||||
|
if self.image_transform:
|
||||||
|
image = self.image_transform(image)
|
||||||
|
mask = np.stack((mask, mask, mask), axis=-1)
|
||||||
|
mask = self.image_transform(mask)
|
||||||
|
|
||||||
|
images.append(image)
|
||||||
|
masks.append(mask)
|
||||||
|
|
||||||
|
images = torch.cat(images, dim=0)
|
||||||
|
masks = torch.cat(masks, dim=0)[:1, ...]
|
||||||
|
return images, masks
|
||||||
|
|
||||||
|
def decode(self, item):
|
||||||
|
uid = item.split('/')[-1]
|
||||||
|
render_img_paths = [os.path.join(item, f'render_cond/{i:03d}.png') for i in range(24)]
|
||||||
|
# transforms_json_path = os.path.join(item, 'render_cond/transforms.json')
|
||||||
|
surface_npz_path = os.path.join(item, f'geo_data/{uid}_surface.npz')
|
||||||
|
# sdf_npz_path = os.path.join(item, f'geo_data/{uid}_sdf.npz')
|
||||||
|
# watertight_obj_path = os.path.join(item, f'geo_data/{uid}_watertight.obj')
|
||||||
|
sample = {}
|
||||||
|
sample["image"] = render_img_paths
|
||||||
|
surface_data = read_npz(surface_npz_path)
|
||||||
|
sample["random_surface"] = surface_data['random_surface']
|
||||||
|
sample["sharpedge_surface"] = surface_data['sharp_surface']
|
||||||
|
return sample
|
||||||
|
|
||||||
|
def transform(self, sample):
|
||||||
|
rng = np.random.default_rng()
|
||||||
|
random_surface = sample.get("random_surface", 0)
|
||||||
|
sharpedge_surface = sample.get("sharpedge_surface", 0)
|
||||||
|
image_input, mask_input = self.load_render(sample['image'])
|
||||||
|
surface, geo_points = self.load_surface_sdf_points(rng, random_surface, sharpedge_surface)
|
||||||
|
sample = {
|
||||||
|
"surface": surface,
|
||||||
|
"geo_points": geo_points,
|
||||||
|
"image": image_input,
|
||||||
|
"mask": mask_input,
|
||||||
|
}
|
||||||
|
return sample
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
total_num = 0
|
||||||
|
failed_num = 0
|
||||||
|
for data in ResampledShards(self.data_list):
|
||||||
|
total_num += 1
|
||||||
|
if total_num % 1000 == 0:
|
||||||
|
print(f"Current failure rate of data loading:")
|
||||||
|
print(f"{failed_num}/{total_num}={failed_num/total_num}")
|
||||||
|
try:
|
||||||
|
sample = self.decode(data)
|
||||||
|
sample = self.transform(sample)
|
||||||
|
except Exception as err:
|
||||||
|
print(err)
|
||||||
|
failed_num += 1
|
||||||
|
continue
|
||||||
|
yield sample
|
||||||
|
|
||||||
|
|
||||||
|
class AlignedShapeLatentModule(LightningDataModule):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
batch_size: int = 1,
|
||||||
|
num_workers: int = 4,
|
||||||
|
val_num_workers: int = 2,
|
||||||
|
train_data_list: str = None,
|
||||||
|
val_data_list: str = None,
|
||||||
|
cond_stage_key: str = "all",
|
||||||
|
image_size: int = 224,
|
||||||
|
mean: Union[List[float], Tuple[float]] = (0.485, 0.456, 0.406),
|
||||||
|
std: Union[List[float], Tuple[float]] = (0.229, 0.224, 0.225),
|
||||||
|
pc_size: int = 2048,
|
||||||
|
pc_sharpedge_size: int = 2048,
|
||||||
|
sharpedge_label: bool = False,
|
||||||
|
return_normal: bool = False,
|
||||||
|
padding = True,
|
||||||
|
padding_ratio_range=[1.15, 1.15]
|
||||||
|
):
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.num_workers = num_workers
|
||||||
|
self.val_num_workers = val_num_workers
|
||||||
|
|
||||||
|
self.train_data_list = train_data_list
|
||||||
|
self.val_data_list = val_data_list
|
||||||
|
|
||||||
|
self.cond_stage_key = cond_stage_key
|
||||||
|
self.image_size = image_size
|
||||||
|
self.mean = mean
|
||||||
|
self.std = std
|
||||||
|
self.train_image_transform = transforms.Compose([
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Resize(self.image_size),
|
||||||
|
transforms.Normalize(mean=self.mean, std=self.std)])
|
||||||
|
self.val_image_transform = transforms.Compose([
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Resize(self.image_size),
|
||||||
|
transforms.Normalize(mean=self.mean, std=self.std)])
|
||||||
|
|
||||||
|
self.pc_size = pc_size
|
||||||
|
self.pc_sharpedge_size = pc_sharpedge_size
|
||||||
|
self.sharpedge_label = sharpedge_label
|
||||||
|
self.return_normal = return_normal
|
||||||
|
|
||||||
|
self.padding = padding
|
||||||
|
self.padding_ratio_range = padding_ratio_range
|
||||||
|
|
||||||
|
def train_dataloader(self):
|
||||||
|
asl_params = {
|
||||||
|
"data_list": self.train_data_list,
|
||||||
|
"cond_stage_key": self.cond_stage_key,
|
||||||
|
"image_transform": self.train_image_transform,
|
||||||
|
"pc_size": self.pc_size,
|
||||||
|
"pc_sharpedge_size": self.pc_sharpedge_size,
|
||||||
|
"sharpedge_label": self.sharpedge_label,
|
||||||
|
"return_normal": self.return_normal,
|
||||||
|
"padding": self.padding,
|
||||||
|
"padding_ratio_range": self.padding_ratio_range
|
||||||
|
}
|
||||||
|
dataset = AlignedShapeLatentDataset(**asl_params)
|
||||||
|
return torch.utils.data.DataLoader(
|
||||||
|
dataset,
|
||||||
|
batch_size=self.batch_size,
|
||||||
|
num_workers=self.num_workers,
|
||||||
|
pin_memory=True,
|
||||||
|
drop_last=True,
|
||||||
|
worker_init_fn=worker_init_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
def val_dataloader(self):
|
||||||
|
asl_params = {
|
||||||
|
"data_list": self.val_data_list,
|
||||||
|
"cond_stage_key": self.cond_stage_key,
|
||||||
|
"image_transform": self.val_image_transform,
|
||||||
|
"pc_size": self.pc_size,
|
||||||
|
"pc_sharpedge_size": self.pc_sharpedge_size,
|
||||||
|
"sharpedge_label": self.sharpedge_label,
|
||||||
|
"return_normal": self.return_normal,
|
||||||
|
"padding": self.padding,
|
||||||
|
"padding_ratio_range": self.padding_ratio_range
|
||||||
|
}
|
||||||
|
dataset = AlignedShapeLatentDataset(**asl_params)
|
||||||
|
return torch.utils.data.DataLoader(
|
||||||
|
dataset,
|
||||||
|
batch_size=self.batch_size,
|
||||||
|
num_workers=self.val_num_workers,
|
||||||
|
pin_memory=True,
|
||||||
|
drop_last=True,
|
||||||
|
worker_init_fn=worker_init_fn,
|
||||||
|
)
|
||||||
186
hy3dshape/hy3dshape/data/utils.py
Normal file
186
hy3dshape/hy3dshape/data/utils.py
Normal file
@ -0,0 +1,186 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
|
||||||
|
# This file is part of the WebDataset library.
|
||||||
|
# See the LICENSE file for licensing terms (BSD-style).
|
||||||
|
|
||||||
|
|
||||||
|
"""Miscellaneous utility functions."""
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import itertools as itt
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
from typing import Any, Callable, Iterator, Union
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def make_seed(*args):
|
||||||
|
seed = 0
|
||||||
|
for arg in args:
|
||||||
|
seed = (seed * 31 + hash(arg)) & 0x7FFFFFFF
|
||||||
|
return seed
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineStage:
|
||||||
|
def invoke(self, *args, **kw):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
def identity(x: Any) -> Any:
|
||||||
|
"""Return the argument as is."""
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def safe_eval(s: str, expr: str = "{}"):
|
||||||
|
"""Evaluate the given expression more safely."""
|
||||||
|
if re.sub("[^A-Za-z0-9_]", "", s) != s:
|
||||||
|
raise ValueError(f"safe_eval: illegal characters in: '{s}'")
|
||||||
|
return eval(expr.format(s))
|
||||||
|
|
||||||
|
|
||||||
|
def lookup_sym(sym: str, modules: list):
|
||||||
|
"""Look up a symbol in a list of modules."""
|
||||||
|
for mname in modules:
|
||||||
|
module = importlib.import_module(mname, package="webdataset")
|
||||||
|
result = getattr(module, sym, None)
|
||||||
|
if result is not None:
|
||||||
|
return result
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def repeatedly0(
|
||||||
|
loader: Iterator, nepochs: int = sys.maxsize, nbatches: int = sys.maxsize
|
||||||
|
):
|
||||||
|
"""Repeatedly returns batches from a DataLoader."""
|
||||||
|
for _ in range(nepochs):
|
||||||
|
yield from itt.islice(loader, nbatches)
|
||||||
|
|
||||||
|
|
||||||
|
def guess_batchsize(batch: Union[tuple, list]):
|
||||||
|
"""Guess the batch size by looking at the length of the first element in a tuple."""
|
||||||
|
return len(batch[0])
|
||||||
|
|
||||||
|
|
||||||
|
def repeatedly(
|
||||||
|
source: Iterator,
|
||||||
|
nepochs: int = None,
|
||||||
|
nbatches: int = None,
|
||||||
|
nsamples: int = None,
|
||||||
|
batchsize: Callable[..., int] = guess_batchsize,
|
||||||
|
):
|
||||||
|
"""Repeatedly yield samples from an iterator."""
|
||||||
|
epoch = 0
|
||||||
|
batch = 0
|
||||||
|
total = 0
|
||||||
|
while True:
|
||||||
|
for sample in source:
|
||||||
|
yield sample
|
||||||
|
batch += 1
|
||||||
|
if nbatches is not None and batch >= nbatches:
|
||||||
|
return
|
||||||
|
if nsamples is not None:
|
||||||
|
total += guess_batchsize(sample)
|
||||||
|
if total >= nsamples:
|
||||||
|
return
|
||||||
|
epoch += 1
|
||||||
|
if nepochs is not None and epoch >= nepochs:
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
def pytorch_worker_info(group=None): # sourcery skip: use-contextlib-suppress
|
||||||
|
"""Return node and worker info for PyTorch and some distributed environments."""
|
||||||
|
rank = 0
|
||||||
|
world_size = 1
|
||||||
|
worker = 0
|
||||||
|
num_workers = 1
|
||||||
|
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
|
||||||
|
rank = int(os.environ["RANK"])
|
||||||
|
world_size = int(os.environ["WORLD_SIZE"])
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
import torch.distributed
|
||||||
|
|
||||||
|
if torch.distributed.is_available() and torch.distributed.is_initialized():
|
||||||
|
group = group or torch.distributed.group.WORLD
|
||||||
|
rank = torch.distributed.get_rank(group=group)
|
||||||
|
world_size = torch.distributed.get_world_size(group=group)
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
pass
|
||||||
|
if "WORKER" in os.environ and "NUM_WORKERS" in os.environ:
|
||||||
|
worker = int(os.environ["WORKER"])
|
||||||
|
num_workers = int(os.environ["NUM_WORKERS"])
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
import torch.utils.data
|
||||||
|
|
||||||
|
worker_info = torch.utils.data.get_worker_info()
|
||||||
|
if worker_info is not None:
|
||||||
|
worker = worker_info.id
|
||||||
|
num_workers = worker_info.num_workers
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return rank, world_size, worker, num_workers
|
||||||
|
|
||||||
|
|
||||||
|
def pytorch_worker_seed(group=None):
|
||||||
|
"""Compute a distinct, deterministic RNG seed for each worker and node."""
|
||||||
|
rank, world_size, worker, num_workers = pytorch_worker_info(group=group)
|
||||||
|
return rank * 1000 + worker
|
||||||
|
|
||||||
|
def worker_init_fn(_):
|
||||||
|
worker_info = torch.utils.data.get_worker_info()
|
||||||
|
worker_id = worker_info.id
|
||||||
|
|
||||||
|
# dataset = worker_info.dataset
|
||||||
|
# split_size = dataset.num_records // worker_info.num_workers
|
||||||
|
# # reset num_records to the true number to retain reliable length information
|
||||||
|
# dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size]
|
||||||
|
# current_id = np.random.choice(len(np.random.get_state()[1]), 1)
|
||||||
|
# return np.random.seed(np.random.get_state()[1][current_id] + worker_id)
|
||||||
|
|
||||||
|
return np.random.seed(np.random.get_state()[1][0] + worker_id)
|
||||||
|
|
||||||
|
|
||||||
|
def collation_fn(samples, combine_tensors=True, combine_scalars=True):
|
||||||
|
"""
|
||||||
|
|
||||||
|
Args:
|
||||||
|
samples (list[dict]):
|
||||||
|
combine_tensors:
|
||||||
|
combine_scalars:
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
result = {}
|
||||||
|
|
||||||
|
keys = samples[0].keys()
|
||||||
|
|
||||||
|
for key in keys:
|
||||||
|
result[key] = []
|
||||||
|
|
||||||
|
for sample in samples:
|
||||||
|
for key in keys:
|
||||||
|
val = sample[key]
|
||||||
|
result[key].append(val)
|
||||||
|
|
||||||
|
for key in keys:
|
||||||
|
val_list = result[key]
|
||||||
|
if isinstance(val_list[0], (int, float)):
|
||||||
|
if combine_scalars:
|
||||||
|
result[key] = np.array(result[key])
|
||||||
|
|
||||||
|
elif isinstance(val_list[0], torch.Tensor):
|
||||||
|
if combine_tensors:
|
||||||
|
result[key] = torch.stack(val_list)
|
||||||
|
|
||||||
|
elif isinstance(val_list[0], np.ndarray):
|
||||||
|
if combine_tensors:
|
||||||
|
result[key] = np.stack(val_list)
|
||||||
|
|
||||||
|
return result
|
||||||
28
hy3dshape/hy3dshape/models/__init__.py
Normal file
28
hy3dshape/hy3dshape/models/__init__.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
# Open Source Model Licensed under the Apache License Version 2.0
|
||||||
|
# and Other Licenses of the Third-Party Components therein:
|
||||||
|
# The below Model in this distribution may have been modified by THL A29 Limited
|
||||||
|
# ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
|
||||||
|
|
||||||
|
# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
|
||||||
|
# The below software and/or models in this distribution may have been
|
||||||
|
# modified by THL A29 Limited ("Tencent Modifications").
|
||||||
|
# All Tencent Modifications are Copyright (C) THL A29 Limited.
|
||||||
|
|
||||||
|
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
||||||
|
# except for the third-party components listed below.
|
||||||
|
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
||||||
|
# in the repsective licenses of these third-party components.
|
||||||
|
# Users must comply with all terms and conditions of original licenses of these third-party
|
||||||
|
# components and must ensure that the usage of the third party components adheres to
|
||||||
|
# all relevant laws and regulations.
|
||||||
|
|
||||||
|
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
||||||
|
# their software and algorithms, including trained model weights, parameters (including
|
||||||
|
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
||||||
|
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
||||||
|
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
||||||
|
|
||||||
|
|
||||||
|
from .autoencoders import ShapeVAE
|
||||||
|
from .conditioner import DualImageEncoder, SingleImageEncoder, DinoImageEncoder, CLIPImageEncoder
|
||||||
|
from .denoisers import Hunyuan3DDiT
|
||||||
20
hy3dshape/hy3dshape/models/autoencoders/__init__.py
Normal file
20
hy3dshape/hy3dshape/models/autoencoders/__init__.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
||||||
|
# except for the third-party components listed below.
|
||||||
|
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
||||||
|
# in the repsective licenses of these third-party components.
|
||||||
|
# Users must comply with all terms and conditions of original licenses of these third-party
|
||||||
|
# components and must ensure that the usage of the third party components adheres to
|
||||||
|
# all relevant laws and regulations.
|
||||||
|
|
||||||
|
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
||||||
|
# their software and algorithms, including trained model weights, parameters (including
|
||||||
|
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
||||||
|
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
||||||
|
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
||||||
|
|
||||||
|
from .attention_blocks import CrossAttentionDecoder
|
||||||
|
from .attention_processors import FlashVDMCrossAttentionProcessor, CrossAttentionProcessor, \
|
||||||
|
FlashVDMTopMCrossAttentionProcessor
|
||||||
|
from .model import ShapeVAE, VectsetVAE
|
||||||
|
from .surface_extractors import SurfaceExtractors, MCSurfaceExtractor, DMCSurfaceExtractor, Latent2MeshOutput
|
||||||
|
from .volume_decoders import HierarchicalVolumeDecoding, FlashVDMVolumeDecoding, VanillaVolumeDecoder
|
||||||
716
hy3dshape/hy3dshape/models/autoencoders/attention_blocks.py
Normal file
716
hy3dshape/hy3dshape/models/autoencoders/attention_blocks.py
Normal file
@ -0,0 +1,716 @@
|
|||||||
|
# Open Source Model Licensed under the Apache License Version 2.0
|
||||||
|
# and Other Licenses of the Third-Party Components therein:
|
||||||
|
# The below Model in this distribution may have been modified by THL A29 Limited
|
||||||
|
# ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
|
||||||
|
|
||||||
|
# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
|
||||||
|
# The below software and/or models in this distribution may have been
|
||||||
|
# modified by THL A29 Limited ("Tencent Modifications").
|
||||||
|
# All Tencent Modifications are Copyright (C) THL A29 Limited.
|
||||||
|
|
||||||
|
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
||||||
|
# except for the third-party components listed below.
|
||||||
|
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
||||||
|
# in the repsective licenses of these third-party components.
|
||||||
|
# Users must comply with all terms and conditions of original licenses of these third-party
|
||||||
|
# components and must ensure that the usage of the third party components adheres to
|
||||||
|
# all relevant laws and regulations.
|
||||||
|
|
||||||
|
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
||||||
|
# their software and algorithms, including trained model weights, parameters (including
|
||||||
|
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
||||||
|
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
||||||
|
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
||||||
|
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import Optional, Union, List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from einops import rearrange
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from .attention_processors import CrossAttentionProcessor
|
||||||
|
from ...utils import logger
|
||||||
|
|
||||||
|
scaled_dot_product_attention = nn.functional.scaled_dot_product_attention
|
||||||
|
|
||||||
|
if os.environ.get('USE_SAGEATTN', '0') == '1':
|
||||||
|
try:
|
||||||
|
from sageattention import sageattn
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError('Please install the package "sageattention" to use this USE_SAGEATTN.')
|
||||||
|
scaled_dot_product_attention = sageattn
|
||||||
|
|
||||||
|
|
||||||
|
class FourierEmbedder(nn.Module):
|
||||||
|
"""The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
|
||||||
|
each feature dimension of `x[..., i]` into:
|
||||||
|
[
|
||||||
|
sin(x[..., i]),
|
||||||
|
sin(f_1*x[..., i]),
|
||||||
|
sin(f_2*x[..., i]),
|
||||||
|
...
|
||||||
|
sin(f_N * x[..., i]),
|
||||||
|
cos(x[..., i]),
|
||||||
|
cos(f_1*x[..., i]),
|
||||||
|
cos(f_2*x[..., i]),
|
||||||
|
...
|
||||||
|
cos(f_N * x[..., i]),
|
||||||
|
x[..., i] # only present if include_input is True.
|
||||||
|
], here f_i is the frequency.
|
||||||
|
|
||||||
|
Denote the space is [0 / num_freqs, 1 / num_freqs, 2 / num_freqs, 3 / num_freqs, ..., (num_freqs - 1) / num_freqs].
|
||||||
|
If logspace is True, then the frequency f_i is [2^(0 / num_freqs), ..., 2^(i / num_freqs), ...];
|
||||||
|
Otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)].
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_freqs (int): the number of frequencies, default is 6;
|
||||||
|
logspace (bool): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...],
|
||||||
|
otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)];
|
||||||
|
input_dim (int): the input dimension, default is 3;
|
||||||
|
include_input (bool): include the input tensor or not, default is True.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
frequencies (torch.Tensor): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...],
|
||||||
|
otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1);
|
||||||
|
|
||||||
|
out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1),
|
||||||
|
otherwise, it is input_dim * num_freqs * 2.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
num_freqs: int = 6,
|
||||||
|
logspace: bool = True,
|
||||||
|
input_dim: int = 3,
|
||||||
|
include_input: bool = True,
|
||||||
|
include_pi: bool = True) -> None:
|
||||||
|
|
||||||
|
"""The initialization"""
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if logspace:
|
||||||
|
frequencies = 2.0 ** torch.arange(
|
||||||
|
num_freqs,
|
||||||
|
dtype=torch.float32
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
frequencies = torch.linspace(
|
||||||
|
1.0,
|
||||||
|
2.0 ** (num_freqs - 1),
|
||||||
|
num_freqs,
|
||||||
|
dtype=torch.float32
|
||||||
|
)
|
||||||
|
|
||||||
|
if include_pi:
|
||||||
|
frequencies *= torch.pi
|
||||||
|
|
||||||
|
self.register_buffer("frequencies", frequencies, persistent=False)
|
||||||
|
self.include_input = include_input
|
||||||
|
self.num_freqs = num_freqs
|
||||||
|
|
||||||
|
self.out_dim = self.get_dims(input_dim)
|
||||||
|
|
||||||
|
def get_dims(self, input_dim):
|
||||||
|
temp = 1 if self.include_input or self.num_freqs == 0 else 0
|
||||||
|
out_dim = input_dim * (self.num_freqs * 2 + temp)
|
||||||
|
|
||||||
|
return out_dim
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
""" Forward process.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: tensor of shape [..., dim]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)]
|
||||||
|
where temp is 1 if include_input is True and 0 otherwise.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if self.num_freqs > 0:
|
||||||
|
embed = (x[..., None].contiguous() * self.frequencies).view(*x.shape[:-1], -1)
|
||||||
|
if self.include_input:
|
||||||
|
return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
|
||||||
|
else:
|
||||||
|
return torch.cat((embed.sin(), embed.cos()), dim=-1)
|
||||||
|
else:
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class DropPath(nn.Module):
|
||||||
|
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
|
||||||
|
super(DropPath, self).__init__()
|
||||||
|
self.drop_prob = drop_prob
|
||||||
|
self.scale_by_keep = scale_by_keep
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||||
|
|
||||||
|
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
|
||||||
|
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
||||||
|
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
|
||||||
|
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
|
||||||
|
'survival rate' as the argument.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if self.drop_prob == 0. or not self.training:
|
||||||
|
return x
|
||||||
|
keep_prob = 1 - self.drop_prob
|
||||||
|
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
||||||
|
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
||||||
|
if keep_prob > 0.0 and self.scale_by_keep:
|
||||||
|
random_tensor.div_(keep_prob)
|
||||||
|
return x * random_tensor
|
||||||
|
|
||||||
|
def extra_repr(self):
|
||||||
|
return f'drop_prob={round(self.drop_prob, 3):0.3f}'
|
||||||
|
|
||||||
|
|
||||||
|
class MLP(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, *,
|
||||||
|
width: int,
|
||||||
|
expand_ratio: int = 4,
|
||||||
|
output_width: int = None,
|
||||||
|
drop_path_rate: float = 0.0
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.width = width
|
||||||
|
self.c_fc = nn.Linear(width, width * expand_ratio)
|
||||||
|
self.c_proj = nn.Linear(width * expand_ratio, output_width if output_width is not None else width)
|
||||||
|
self.gelu = nn.GELU()
|
||||||
|
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.drop_path(self.c_proj(self.gelu(self.c_fc(x))))
|
||||||
|
|
||||||
|
|
||||||
|
class QKVMultiheadCrossAttention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
heads: int,
|
||||||
|
n_data: Optional[int] = None,
|
||||||
|
width=None,
|
||||||
|
qk_norm=False,
|
||||||
|
norm_layer=nn.LayerNorm
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.heads = heads
|
||||||
|
self.n_data = n_data
|
||||||
|
self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
|
||||||
|
self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
|
||||||
|
|
||||||
|
self.attn_processor = CrossAttentionProcessor()
|
||||||
|
|
||||||
|
def forward(self, q, kv):
|
||||||
|
_, n_ctx, _ = q.shape
|
||||||
|
bs, n_data, width = kv.shape
|
||||||
|
attn_ch = width // self.heads // 2
|
||||||
|
q = q.view(bs, n_ctx, self.heads, -1)
|
||||||
|
kv = kv.view(bs, n_data, self.heads, -1)
|
||||||
|
k, v = torch.split(kv, attn_ch, dim=-1)
|
||||||
|
|
||||||
|
q = self.q_norm(q)
|
||||||
|
k = self.k_norm(k)
|
||||||
|
q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
|
||||||
|
out = self.attn_processor(self, q, k, v)
|
||||||
|
out = out.transpose(1, 2).reshape(bs, n_ctx, -1)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class MultiheadCrossAttention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
width: int,
|
||||||
|
heads: int,
|
||||||
|
qkv_bias: bool = True,
|
||||||
|
n_data: Optional[int] = None,
|
||||||
|
data_width: Optional[int] = None,
|
||||||
|
norm_layer=nn.LayerNorm,
|
||||||
|
qk_norm: bool = False,
|
||||||
|
kv_cache: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.n_data = n_data
|
||||||
|
self.width = width
|
||||||
|
self.heads = heads
|
||||||
|
self.data_width = width if data_width is None else data_width
|
||||||
|
self.c_q = nn.Linear(width, width, bias=qkv_bias)
|
||||||
|
self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias)
|
||||||
|
self.c_proj = nn.Linear(width, width)
|
||||||
|
self.attention = QKVMultiheadCrossAttention(
|
||||||
|
heads=heads,
|
||||||
|
n_data=n_data,
|
||||||
|
width=width,
|
||||||
|
norm_layer=norm_layer,
|
||||||
|
qk_norm=qk_norm
|
||||||
|
)
|
||||||
|
self.kv_cache = kv_cache
|
||||||
|
self.data = None
|
||||||
|
|
||||||
|
def forward(self, x, data):
|
||||||
|
x = self.c_q(x)
|
||||||
|
if self.kv_cache:
|
||||||
|
if self.data is None:
|
||||||
|
self.data = self.c_kv(data)
|
||||||
|
logger.info('Save kv cache,this should be called only once for one mesh')
|
||||||
|
data = self.data
|
||||||
|
else:
|
||||||
|
data = self.c_kv(data)
|
||||||
|
x = self.attention(x, data)
|
||||||
|
x = self.c_proj(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualCrossAttentionBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
n_data: Optional[int] = None,
|
||||||
|
width: int,
|
||||||
|
heads: int,
|
||||||
|
mlp_expand_ratio: int = 4,
|
||||||
|
data_width: Optional[int] = None,
|
||||||
|
qkv_bias: bool = True,
|
||||||
|
norm_layer=nn.LayerNorm,
|
||||||
|
qk_norm: bool = False
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if data_width is None:
|
||||||
|
data_width = width
|
||||||
|
|
||||||
|
self.attn = MultiheadCrossAttention(
|
||||||
|
n_data=n_data,
|
||||||
|
width=width,
|
||||||
|
heads=heads,
|
||||||
|
data_width=data_width,
|
||||||
|
qkv_bias=qkv_bias,
|
||||||
|
norm_layer=norm_layer,
|
||||||
|
qk_norm=qk_norm
|
||||||
|
)
|
||||||
|
self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
|
||||||
|
self.ln_2 = norm_layer(data_width, elementwise_affine=True, eps=1e-6)
|
||||||
|
self.ln_3 = norm_layer(width, elementwise_affine=True, eps=1e-6)
|
||||||
|
self.mlp = MLP(width=width, expand_ratio=mlp_expand_ratio)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, data: torch.Tensor):
|
||||||
|
x = x + self.attn(self.ln_1(x), self.ln_2(data))
|
||||||
|
x = x + self.mlp(self.ln_3(x))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class QKVMultiheadAttention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
heads: int,
|
||||||
|
n_ctx: int,
|
||||||
|
width=None,
|
||||||
|
qk_norm=False,
|
||||||
|
norm_layer=nn.LayerNorm
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.heads = heads
|
||||||
|
self.n_ctx = n_ctx
|
||||||
|
self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
|
||||||
|
self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
|
||||||
|
|
||||||
|
def forward(self, qkv):
|
||||||
|
bs, n_ctx, width = qkv.shape
|
||||||
|
attn_ch = width // self.heads // 3
|
||||||
|
qkv = qkv.view(bs, n_ctx, self.heads, -1)
|
||||||
|
q, k, v = torch.split(qkv, attn_ch, dim=-1)
|
||||||
|
|
||||||
|
q = self.q_norm(q)
|
||||||
|
k = self.k_norm(k)
|
||||||
|
|
||||||
|
q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
|
||||||
|
out = scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class MultiheadAttention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
n_ctx: int,
|
||||||
|
width: int,
|
||||||
|
heads: int,
|
||||||
|
qkv_bias: bool,
|
||||||
|
norm_layer=nn.LayerNorm,
|
||||||
|
qk_norm: bool = False,
|
||||||
|
drop_path_rate: float = 0.0
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.n_ctx = n_ctx
|
||||||
|
self.width = width
|
||||||
|
self.heads = heads
|
||||||
|
self.c_qkv = nn.Linear(width, width * 3, bias=qkv_bias)
|
||||||
|
self.c_proj = nn.Linear(width, width)
|
||||||
|
self.attention = QKVMultiheadAttention(
|
||||||
|
heads=heads,
|
||||||
|
n_ctx=n_ctx,
|
||||||
|
width=width,
|
||||||
|
norm_layer=norm_layer,
|
||||||
|
qk_norm=qk_norm
|
||||||
|
)
|
||||||
|
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.c_qkv(x)
|
||||||
|
x = self.attention(x)
|
||||||
|
x = self.drop_path(self.c_proj(x))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualAttentionBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
n_ctx: int,
|
||||||
|
width: int,
|
||||||
|
heads: int,
|
||||||
|
qkv_bias: bool = True,
|
||||||
|
norm_layer=nn.LayerNorm,
|
||||||
|
qk_norm: bool = False,
|
||||||
|
drop_path_rate: float = 0.0,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.attn = MultiheadAttention(
|
||||||
|
n_ctx=n_ctx,
|
||||||
|
width=width,
|
||||||
|
heads=heads,
|
||||||
|
qkv_bias=qkv_bias,
|
||||||
|
norm_layer=norm_layer,
|
||||||
|
qk_norm=qk_norm,
|
||||||
|
drop_path_rate=drop_path_rate
|
||||||
|
)
|
||||||
|
self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
|
||||||
|
self.mlp = MLP(width=width, drop_path_rate=drop_path_rate)
|
||||||
|
self.ln_2 = norm_layer(width, elementwise_affine=True, eps=1e-6)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
x = x + self.attn(self.ln_1(x))
|
||||||
|
x = x + self.mlp(self.ln_2(x))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Transformer(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
n_ctx: int,
|
||||||
|
width: int,
|
||||||
|
layers: int,
|
||||||
|
heads: int,
|
||||||
|
qkv_bias: bool = True,
|
||||||
|
norm_layer=nn.LayerNorm,
|
||||||
|
qk_norm: bool = False,
|
||||||
|
drop_path_rate: float = 0.0
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.n_ctx = n_ctx
|
||||||
|
self.width = width
|
||||||
|
self.layers = layers
|
||||||
|
self.resblocks = nn.ModuleList(
|
||||||
|
[
|
||||||
|
ResidualAttentionBlock(
|
||||||
|
n_ctx=n_ctx,
|
||||||
|
width=width,
|
||||||
|
heads=heads,
|
||||||
|
qkv_bias=qkv_bias,
|
||||||
|
norm_layer=norm_layer,
|
||||||
|
qk_norm=qk_norm,
|
||||||
|
drop_path_rate=drop_path_rate
|
||||||
|
)
|
||||||
|
for _ in range(layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
for block in self.resblocks:
|
||||||
|
x = block(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class CrossAttentionDecoder(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
num_latents: int,
|
||||||
|
out_channels: int,
|
||||||
|
fourier_embedder: FourierEmbedder,
|
||||||
|
width: int,
|
||||||
|
heads: int,
|
||||||
|
mlp_expand_ratio: int = 4,
|
||||||
|
downsample_ratio: int = 1,
|
||||||
|
enable_ln_post: bool = True,
|
||||||
|
qkv_bias: bool = True,
|
||||||
|
qk_norm: bool = False,
|
||||||
|
label_type: str = "binary"
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.enable_ln_post = enable_ln_post
|
||||||
|
self.fourier_embedder = fourier_embedder
|
||||||
|
self.downsample_ratio = downsample_ratio
|
||||||
|
self.query_proj = nn.Linear(self.fourier_embedder.out_dim, width)
|
||||||
|
if self.downsample_ratio != 1:
|
||||||
|
self.latents_proj = nn.Linear(width * downsample_ratio, width)
|
||||||
|
if self.enable_ln_post == False:
|
||||||
|
qk_norm = False
|
||||||
|
self.cross_attn_decoder = ResidualCrossAttentionBlock(
|
||||||
|
n_data=num_latents,
|
||||||
|
width=width,
|
||||||
|
mlp_expand_ratio=mlp_expand_ratio,
|
||||||
|
heads=heads,
|
||||||
|
qkv_bias=qkv_bias,
|
||||||
|
qk_norm=qk_norm
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.enable_ln_post:
|
||||||
|
self.ln_post = nn.LayerNorm(width)
|
||||||
|
self.output_proj = nn.Linear(width, out_channels)
|
||||||
|
self.label_type = label_type
|
||||||
|
self.count = 0
|
||||||
|
|
||||||
|
def set_cross_attention_processor(self, processor):
|
||||||
|
self.cross_attn_decoder.attn.attention.attn_processor = processor
|
||||||
|
|
||||||
|
def set_default_cross_attention_processor(self):
|
||||||
|
self.cross_attn_decoder.attn.attention.attn_processor = CrossAttentionProcessor
|
||||||
|
|
||||||
|
def forward(self, queries=None, query_embeddings=None, latents=None):
|
||||||
|
if query_embeddings is None:
|
||||||
|
query_embeddings = self.query_proj(self.fourier_embedder(queries).to(latents.dtype))
|
||||||
|
self.count += query_embeddings.shape[1]
|
||||||
|
if self.downsample_ratio != 1:
|
||||||
|
latents = self.latents_proj(latents)
|
||||||
|
x = self.cross_attn_decoder(query_embeddings, latents)
|
||||||
|
if self.enable_ln_post:
|
||||||
|
x = self.ln_post(x)
|
||||||
|
occ = self.output_proj(x)
|
||||||
|
return occ
|
||||||
|
|
||||||
|
|
||||||
|
def fps(
|
||||||
|
src: torch.Tensor,
|
||||||
|
batch: Optional[Tensor] = None,
|
||||||
|
ratio: Optional[Union[Tensor, float]] = None,
|
||||||
|
random_start: bool = True,
|
||||||
|
batch_size: Optional[int] = None,
|
||||||
|
ptr: Optional[Union[Tensor, List[int]]] = None,
|
||||||
|
):
|
||||||
|
src = src.float()
|
||||||
|
from torch_cluster import fps as fps_fn
|
||||||
|
output = fps_fn(src, batch, ratio, random_start, batch_size, ptr)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class PointCrossAttentionEncoder(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, *,
|
||||||
|
num_latents: int,
|
||||||
|
downsample_ratio: float,
|
||||||
|
pc_size: int,
|
||||||
|
pc_sharpedge_size: int,
|
||||||
|
fourier_embedder: FourierEmbedder,
|
||||||
|
point_feats: int,
|
||||||
|
width: int,
|
||||||
|
heads: int,
|
||||||
|
layers: int,
|
||||||
|
normal_pe: bool = False,
|
||||||
|
qkv_bias: bool = True,
|
||||||
|
use_ln_post: bool = False,
|
||||||
|
use_checkpoint: bool = False,
|
||||||
|
qk_norm: bool = False
|
||||||
|
):
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.use_checkpoint = use_checkpoint
|
||||||
|
self.num_latents = num_latents
|
||||||
|
self.downsample_ratio = downsample_ratio
|
||||||
|
self.point_feats = point_feats
|
||||||
|
self.normal_pe = normal_pe
|
||||||
|
|
||||||
|
if pc_sharpedge_size == 0:
|
||||||
|
print(
|
||||||
|
f'PointCrossAttentionEncoder INFO: pc_sharpedge_size is not given, using pc_size as pc_sharpedge_size')
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
f'PointCrossAttentionEncoder INFO: pc_sharpedge_size is given, using pc_size={pc_size}, pc_sharpedge_size={pc_sharpedge_size}')
|
||||||
|
|
||||||
|
self.pc_size = pc_size
|
||||||
|
self.pc_sharpedge_size = pc_sharpedge_size
|
||||||
|
|
||||||
|
self.fourier_embedder = fourier_embedder
|
||||||
|
|
||||||
|
self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width)
|
||||||
|
self.cross_attn = ResidualCrossAttentionBlock(
|
||||||
|
width=width,
|
||||||
|
heads=heads,
|
||||||
|
qkv_bias=qkv_bias,
|
||||||
|
qk_norm=qk_norm
|
||||||
|
)
|
||||||
|
|
||||||
|
self.self_attn = None
|
||||||
|
if layers > 0:
|
||||||
|
self.self_attn = Transformer(
|
||||||
|
n_ctx=num_latents,
|
||||||
|
width=width,
|
||||||
|
layers=layers,
|
||||||
|
heads=heads,
|
||||||
|
qkv_bias=qkv_bias,
|
||||||
|
qk_norm=qk_norm
|
||||||
|
)
|
||||||
|
|
||||||
|
if use_ln_post:
|
||||||
|
self.ln_post = nn.LayerNorm(width)
|
||||||
|
else:
|
||||||
|
self.ln_post = None
|
||||||
|
|
||||||
|
def sample_points_and_latents(self, pc: torch.FloatTensor, feats: Optional[torch.FloatTensor] = None):
|
||||||
|
B, N, D = pc.shape
|
||||||
|
num_pts = self.num_latents * self.downsample_ratio
|
||||||
|
|
||||||
|
# Compute number of latents
|
||||||
|
num_latents = int(num_pts / self.downsample_ratio)
|
||||||
|
|
||||||
|
# Compute the number of random and sharpedge latents
|
||||||
|
num_random_query = self.pc_size / (self.pc_size + self.pc_sharpedge_size) * num_latents
|
||||||
|
num_sharpedge_query = num_latents - num_random_query
|
||||||
|
|
||||||
|
# Split random and sharpedge surface points
|
||||||
|
random_pc, sharpedge_pc = torch.split(pc, [self.pc_size, self.pc_sharpedge_size], dim=1)
|
||||||
|
assert random_pc.shape[1] <= self.pc_size, "Random surface points size must be less than or equal to pc_size"
|
||||||
|
assert sharpedge_pc.shape[
|
||||||
|
1] <= self.pc_sharpedge_size, "Sharpedge surface points size must be less than or equal to pc_sharpedge_size"
|
||||||
|
|
||||||
|
# Randomly select random surface points and random query points
|
||||||
|
input_random_pc_size = int(num_random_query * self.downsample_ratio)
|
||||||
|
random_query_ratio = num_random_query / input_random_pc_size
|
||||||
|
idx_random_pc = torch.randperm(random_pc.shape[1], device=random_pc.device)[:input_random_pc_size]
|
||||||
|
input_random_pc = random_pc[:, idx_random_pc, :]
|
||||||
|
flatten_input_random_pc = input_random_pc.view(B * input_random_pc_size, D)
|
||||||
|
N_down = int(flatten_input_random_pc.shape[0] / B)
|
||||||
|
batch_down = torch.arange(B).to(pc.device)
|
||||||
|
batch_down = torch.repeat_interleave(batch_down, N_down)
|
||||||
|
idx_query_random = fps(flatten_input_random_pc, batch_down, ratio=random_query_ratio)
|
||||||
|
query_random_pc = flatten_input_random_pc[idx_query_random].view(B, -1, D)
|
||||||
|
|
||||||
|
# Randomly select sharpedge surface points and sharpedge query points
|
||||||
|
input_sharpedge_pc_size = int(num_sharpedge_query * self.downsample_ratio)
|
||||||
|
if input_sharpedge_pc_size == 0:
|
||||||
|
input_sharpedge_pc = torch.zeros(B, 0, D, dtype=input_random_pc.dtype).to(pc.device)
|
||||||
|
query_sharpedge_pc = torch.zeros(B, 0, D, dtype=query_random_pc.dtype).to(pc.device)
|
||||||
|
else:
|
||||||
|
sharpedge_query_ratio = num_sharpedge_query / input_sharpedge_pc_size
|
||||||
|
idx_sharpedge_pc = torch.randperm(sharpedge_pc.shape[1], device=sharpedge_pc.device)[
|
||||||
|
:input_sharpedge_pc_size]
|
||||||
|
input_sharpedge_pc = sharpedge_pc[:, idx_sharpedge_pc, :]
|
||||||
|
flatten_input_sharpedge_surface_points = input_sharpedge_pc.view(B * input_sharpedge_pc_size, D)
|
||||||
|
N_down = int(flatten_input_sharpedge_surface_points.shape[0] / B)
|
||||||
|
batch_down = torch.arange(B).to(pc.device)
|
||||||
|
batch_down = torch.repeat_interleave(batch_down, N_down)
|
||||||
|
idx_query_sharpedge = fps(flatten_input_sharpedge_surface_points, batch_down, ratio=sharpedge_query_ratio)
|
||||||
|
query_sharpedge_pc = flatten_input_sharpedge_surface_points[idx_query_sharpedge].view(B, -1, D)
|
||||||
|
|
||||||
|
# Concatenate random and sharpedge surface points and query points
|
||||||
|
query_pc = torch.cat([query_random_pc, query_sharpedge_pc], dim=1)
|
||||||
|
input_pc = torch.cat([input_random_pc, input_sharpedge_pc], dim=1)
|
||||||
|
|
||||||
|
# PE
|
||||||
|
query = self.fourier_embedder(query_pc)
|
||||||
|
data = self.fourier_embedder(input_pc)
|
||||||
|
|
||||||
|
# Concat normal if given
|
||||||
|
if self.point_feats != 0:
|
||||||
|
|
||||||
|
random_surface_feats, sharpedge_surface_feats = torch.split(feats, [self.pc_size, self.pc_sharpedge_size],
|
||||||
|
dim=1)
|
||||||
|
input_random_surface_feats = random_surface_feats[:, idx_random_pc, :]
|
||||||
|
flatten_input_random_surface_feats = input_random_surface_feats.view(B * input_random_pc_size, -1)
|
||||||
|
query_random_feats = flatten_input_random_surface_feats[idx_query_random].view(B, -1,
|
||||||
|
flatten_input_random_surface_feats.shape[
|
||||||
|
-1])
|
||||||
|
|
||||||
|
if input_sharpedge_pc_size == 0:
|
||||||
|
input_sharpedge_surface_feats = torch.zeros(B, 0, self.point_feats,
|
||||||
|
dtype=input_random_surface_feats.dtype).to(pc.device)
|
||||||
|
query_sharpedge_feats = torch.zeros(B, 0, self.point_feats, dtype=query_random_feats.dtype).to(
|
||||||
|
pc.device)
|
||||||
|
else:
|
||||||
|
input_sharpedge_surface_feats = sharpedge_surface_feats[:, idx_sharpedge_pc, :]
|
||||||
|
flatten_input_sharpedge_surface_feats = input_sharpedge_surface_feats.view(B * input_sharpedge_pc_size,
|
||||||
|
-1)
|
||||||
|
query_sharpedge_feats = flatten_input_sharpedge_surface_feats[idx_query_sharpedge].view(B, -1,
|
||||||
|
flatten_input_sharpedge_surface_feats.shape[
|
||||||
|
-1])
|
||||||
|
|
||||||
|
query_feats = torch.cat([query_random_feats, query_sharpedge_feats], dim=1)
|
||||||
|
input_feats = torch.cat([input_random_surface_feats, input_sharpedge_surface_feats], dim=1)
|
||||||
|
|
||||||
|
if self.normal_pe:
|
||||||
|
query_normal_pe = self.fourier_embedder(query_feats[..., :3])
|
||||||
|
input_normal_pe = self.fourier_embedder(input_feats[..., :3])
|
||||||
|
query_feats = torch.cat([query_normal_pe, query_feats[..., 3:]], dim=-1)
|
||||||
|
input_feats = torch.cat([input_normal_pe, input_feats[..., 3:]], dim=-1)
|
||||||
|
|
||||||
|
query = torch.cat([query, query_feats], dim=-1)
|
||||||
|
data = torch.cat([data, input_feats], dim=-1)
|
||||||
|
|
||||||
|
if input_sharpedge_pc_size == 0:
|
||||||
|
query_sharpedge_pc = torch.zeros(B, 1, D).to(pc.device)
|
||||||
|
input_sharpedge_pc = torch.zeros(B, 1, D).to(pc.device)
|
||||||
|
|
||||||
|
# print(f'query_pc: {query_pc.shape}')
|
||||||
|
# print(f'input_pc: {input_pc.shape}')
|
||||||
|
# print(f'query_random_pc: {query_random_pc.shape}')
|
||||||
|
# print(f'input_random_pc: {input_random_pc.shape}')
|
||||||
|
# print(f'query_sharpedge_pc: {query_sharpedge_pc.shape}')
|
||||||
|
# print(f'input_sharpedge_pc: {input_sharpedge_pc.shape}')
|
||||||
|
|
||||||
|
return query.view(B, -1, query.shape[-1]), data.view(B, -1, data.shape[-1]), [query_pc, input_pc,
|
||||||
|
query_random_pc, input_random_pc,
|
||||||
|
query_sharpedge_pc,
|
||||||
|
input_sharpedge_pc]
|
||||||
|
|
||||||
|
def forward(self, pc, feats):
|
||||||
|
"""
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pc (torch.FloatTensor): [B, N, 3]
|
||||||
|
feats (torch.FloatTensor or None): [B, N, C]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
query, data, pc_infos = self.sample_points_and_latents(pc, feats)
|
||||||
|
|
||||||
|
query = self.input_proj(query)
|
||||||
|
query = query
|
||||||
|
data = self.input_proj(data)
|
||||||
|
data = data
|
||||||
|
|
||||||
|
latents = self.cross_attn(query, data)
|
||||||
|
if self.self_attn is not None:
|
||||||
|
latents = self.self_attn(latents)
|
||||||
|
|
||||||
|
if self.ln_post is not None:
|
||||||
|
latents = self.ln_post(latents)
|
||||||
|
|
||||||
|
return latents, pc_infos
|
||||||
@ -0,0 +1,96 @@
|
|||||||
|
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
||||||
|
# except for the third-party components listed below.
|
||||||
|
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
||||||
|
# in the repsective licenses of these third-party components.
|
||||||
|
# Users must comply with all terms and conditions of original licenses of these third-party
|
||||||
|
# components and must ensure that the usage of the third party components adheres to
|
||||||
|
# all relevant laws and regulations.
|
||||||
|
|
||||||
|
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
||||||
|
# their software and algorithms, including trained model weights, parameters (including
|
||||||
|
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
||||||
|
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
||||||
|
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
scaled_dot_product_attention = F.scaled_dot_product_attention
|
||||||
|
if os.environ.get('CA_USE_SAGEATTN', '0') == '1':
|
||||||
|
try:
|
||||||
|
from sageattention import sageattn
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError('Please install the package "sageattention" to use this USE_SAGEATTN.')
|
||||||
|
scaled_dot_product_attention = sageattn
|
||||||
|
|
||||||
|
|
||||||
|
class CrossAttentionProcessor:
|
||||||
|
def __call__(self, attn, q, k, v):
|
||||||
|
out = scaled_dot_product_attention(q, k, v)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class FlashVDMCrossAttentionProcessor:
|
||||||
|
def __init__(self, topk=None):
|
||||||
|
self.topk = topk
|
||||||
|
|
||||||
|
def __call__(self, attn, q, k, v):
|
||||||
|
if k.shape[-2] == 3072:
|
||||||
|
topk = 1024
|
||||||
|
elif k.shape[-2] == 512:
|
||||||
|
topk = 256
|
||||||
|
else:
|
||||||
|
topk = k.shape[-2] // 3
|
||||||
|
|
||||||
|
if self.topk is True:
|
||||||
|
q1 = q[:, :, ::100, :]
|
||||||
|
sim = q1 @ k.transpose(-1, -2)
|
||||||
|
sim = torch.mean(sim, -2)
|
||||||
|
topk_ind = torch.topk(sim, dim=-1, k=topk).indices.squeeze(-2).unsqueeze(-1)
|
||||||
|
topk_ind = topk_ind.expand(-1, -1, -1, v.shape[-1])
|
||||||
|
v0 = torch.gather(v, dim=-2, index=topk_ind)
|
||||||
|
k0 = torch.gather(k, dim=-2, index=topk_ind)
|
||||||
|
out = scaled_dot_product_attention(q, k0, v0)
|
||||||
|
elif self.topk is False:
|
||||||
|
out = scaled_dot_product_attention(q, k, v)
|
||||||
|
else:
|
||||||
|
idx, counts = self.topk
|
||||||
|
start = 0
|
||||||
|
outs = []
|
||||||
|
for grid_coord, count in zip(idx, counts):
|
||||||
|
end = start + count
|
||||||
|
q_chunk = q[:, :, start:end, :]
|
||||||
|
k0, v0 = self.select_topkv(q_chunk, k, v, topk)
|
||||||
|
out = scaled_dot_product_attention(q_chunk, k0, v0)
|
||||||
|
outs.append(out)
|
||||||
|
start += count
|
||||||
|
out = torch.cat(outs, dim=-2)
|
||||||
|
self.topk = False
|
||||||
|
return out
|
||||||
|
|
||||||
|
def select_topkv(self, q_chunk, k, v, topk):
|
||||||
|
q1 = q_chunk[:, :, ::50, :]
|
||||||
|
sim = q1 @ k.transpose(-1, -2)
|
||||||
|
sim = torch.mean(sim, -2)
|
||||||
|
topk_ind = torch.topk(sim, dim=-1, k=topk).indices.squeeze(-2).unsqueeze(-1)
|
||||||
|
topk_ind = topk_ind.expand(-1, -1, -1, v.shape[-1])
|
||||||
|
v0 = torch.gather(v, dim=-2, index=topk_ind)
|
||||||
|
k0 = torch.gather(k, dim=-2, index=topk_ind)
|
||||||
|
return k0, v0
|
||||||
|
|
||||||
|
|
||||||
|
class FlashVDMTopMCrossAttentionProcessor(FlashVDMCrossAttentionProcessor):
|
||||||
|
def select_topkv(self, q_chunk, k, v, topk):
|
||||||
|
q1 = q_chunk[:, :, ::30, :]
|
||||||
|
sim = q1 @ k.transpose(-1, -2)
|
||||||
|
# sim = sim.to(torch.float32)
|
||||||
|
sim = sim.softmax(-1)
|
||||||
|
sim = torch.mean(sim, 1)
|
||||||
|
activated_token = torch.where(sim > 1e-6)[2]
|
||||||
|
index = torch.unique(activated_token, return_counts=True)[0].unsqueeze(0).unsqueeze(0).unsqueeze(-1)
|
||||||
|
index = index.expand(-1, v.shape[1], -1, v.shape[-1])
|
||||||
|
v0 = torch.gather(v, dim=-2, index=index)
|
||||||
|
k0 = torch.gather(k, dim=-2, index=index)
|
||||||
|
return k0, v0
|
||||||
339
hy3dshape/hy3dshape/models/autoencoders/model.py
Normal file
339
hy3dshape/hy3dshape/models/autoencoders/model.py
Normal file
@ -0,0 +1,339 @@
|
|||||||
|
# Open Source Model Licensed under the Apache License Version 2.0
|
||||||
|
# and Other Licenses of the Third-Party Components therein:
|
||||||
|
# The below Model in this distribution may have been modified by THL A29 Limited
|
||||||
|
# ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
|
||||||
|
|
||||||
|
# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
|
||||||
|
# The below software and/or models in this distribution may have been
|
||||||
|
# modified by THL A29 Limited ("Tencent Modifications").
|
||||||
|
# All Tencent Modifications are Copyright (C) THL A29 Limited.
|
||||||
|
|
||||||
|
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
||||||
|
# except for the third-party components listed below.
|
||||||
|
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
||||||
|
# in the repsective licenses of these third-party components.
|
||||||
|
# Users must comply with all terms and conditions of original licenses of these third-party
|
||||||
|
# components and must ensure that the usage of the third party components adheres to
|
||||||
|
# all relevant laws and regulations.
|
||||||
|
|
||||||
|
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
||||||
|
# their software and algorithms, including trained model weights, parameters (including
|
||||||
|
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
||||||
|
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
||||||
|
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
||||||
|
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import Union, List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from .attention_blocks import FourierEmbedder, Transformer, CrossAttentionDecoder, PointCrossAttentionEncoder
|
||||||
|
from .surface_extractors import MCSurfaceExtractor, SurfaceExtractors
|
||||||
|
from .volume_decoders import VanillaVolumeDecoder, FlashVDMVolumeDecoding, HierarchicalVolumeDecoding
|
||||||
|
from ...utils import logger, synchronize_timer, smart_load_model
|
||||||
|
|
||||||
|
|
||||||
|
class DiagonalGaussianDistribution(object):
|
||||||
|
def __init__(self, parameters: Union[torch.Tensor, List[torch.Tensor]], deterministic=False, feat_dim=1):
|
||||||
|
"""
|
||||||
|
Initialize a diagonal Gaussian distribution with mean and log-variance parameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
parameters (Union[torch.Tensor, List[torch.Tensor]]):
|
||||||
|
Either a single tensor containing concatenated mean and log-variance along `feat_dim`,
|
||||||
|
or a list of two tensors [mean, logvar].
|
||||||
|
deterministic (bool, optional): If True, the distribution is deterministic (zero variance).
|
||||||
|
Default is False. feat_dim (int, optional): Dimension along which mean and logvar are
|
||||||
|
concatenated if parameters is a single tensor. Default is 1.
|
||||||
|
"""
|
||||||
|
self.feat_dim = feat_dim
|
||||||
|
self.parameters = parameters
|
||||||
|
|
||||||
|
if isinstance(parameters, list):
|
||||||
|
self.mean = parameters[0]
|
||||||
|
self.logvar = parameters[1]
|
||||||
|
else:
|
||||||
|
self.mean, self.logvar = torch.chunk(parameters, 2, dim=feat_dim)
|
||||||
|
|
||||||
|
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
||||||
|
self.deterministic = deterministic
|
||||||
|
self.std = torch.exp(0.5 * self.logvar)
|
||||||
|
self.var = torch.exp(self.logvar)
|
||||||
|
if self.deterministic:
|
||||||
|
self.var = self.std = torch.zeros_like(self.mean)
|
||||||
|
|
||||||
|
def sample(self):
|
||||||
|
"""
|
||||||
|
Sample from the diagonal Gaussian distribution.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: A sample tensor with the same shape as the mean.
|
||||||
|
"""
|
||||||
|
x = self.mean + self.std * torch.randn_like(self.mean)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def kl(self, other=None, dims=(1, 2, 3)):
|
||||||
|
"""
|
||||||
|
Compute the Kullback-Leibler (KL) divergence between this distribution and another.
|
||||||
|
|
||||||
|
If `other` is None, compute KL divergence to a standard normal distribution N(0, I).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
other (DiagonalGaussianDistribution, optional): Another diagonal Gaussian distribution.
|
||||||
|
dims (tuple, optional): Dimensions along which to compute the mean KL divergence.
|
||||||
|
Default is (1, 2, 3).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: The mean KL divergence value.
|
||||||
|
"""
|
||||||
|
if self.deterministic:
|
||||||
|
return torch.Tensor([0.])
|
||||||
|
else:
|
||||||
|
if other is None:
|
||||||
|
return 0.5 * torch.mean(torch.pow(self.mean, 2)
|
||||||
|
+ self.var - 1.0 - self.logvar,
|
||||||
|
dim=dims)
|
||||||
|
else:
|
||||||
|
return 0.5 * torch.mean(
|
||||||
|
torch.pow(self.mean - other.mean, 2) / other.var
|
||||||
|
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
|
||||||
|
dim=dims)
|
||||||
|
|
||||||
|
def nll(self, sample, dims=(1, 2, 3)):
|
||||||
|
if self.deterministic:
|
||||||
|
return torch.Tensor([0.])
|
||||||
|
logtwopi = np.log(2.0 * np.pi)
|
||||||
|
return 0.5 * torch.sum(
|
||||||
|
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
|
||||||
|
dim=dims)
|
||||||
|
|
||||||
|
def mode(self):
|
||||||
|
return self.mean
|
||||||
|
|
||||||
|
|
||||||
|
class VectsetVAE(nn.Module):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@synchronize_timer('VectsetVAE Model Loading')
|
||||||
|
def from_single_file(
|
||||||
|
cls,
|
||||||
|
ckpt_path,
|
||||||
|
config_path,
|
||||||
|
device='cuda',
|
||||||
|
dtype=torch.float16,
|
||||||
|
use_safetensors=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
# load config
|
||||||
|
with open(config_path, 'r') as f:
|
||||||
|
config = yaml.safe_load(f)
|
||||||
|
|
||||||
|
# load ckpt
|
||||||
|
if use_safetensors:
|
||||||
|
ckpt_path = ckpt_path.replace('.ckpt', '.safetensors')
|
||||||
|
if not os.path.exists(ckpt_path):
|
||||||
|
raise FileNotFoundError(f"Model file {ckpt_path} not found")
|
||||||
|
|
||||||
|
logger.info(f"Loading model from {ckpt_path}")
|
||||||
|
if use_safetensors:
|
||||||
|
import safetensors.torch
|
||||||
|
ckpt = safetensors.torch.load_file(ckpt_path, device='cpu')
|
||||||
|
else:
|
||||||
|
ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True)
|
||||||
|
|
||||||
|
model_kwargs = config['params']
|
||||||
|
model_kwargs.update(kwargs)
|
||||||
|
|
||||||
|
model = cls(**model_kwargs)
|
||||||
|
model.load_state_dict(ckpt)
|
||||||
|
model.to(device=device, dtype=dtype)
|
||||||
|
return model
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(
|
||||||
|
cls,
|
||||||
|
model_path,
|
||||||
|
device='cuda',
|
||||||
|
dtype=torch.float16,
|
||||||
|
use_safetensors=False,
|
||||||
|
variant='fp16',
|
||||||
|
subfolder='hunyuan3d-vae-v2-1',
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
config_path, ckpt_path = smart_load_model(
|
||||||
|
model_path,
|
||||||
|
subfolder=subfolder,
|
||||||
|
use_safetensors=use_safetensors,
|
||||||
|
variant=variant
|
||||||
|
)
|
||||||
|
|
||||||
|
return cls.from_single_file(
|
||||||
|
ckpt_path,
|
||||||
|
config_path,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
use_safetensors=use_safetensors,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
def init_from_ckpt(self, path, ignore_keys=()):
|
||||||
|
state_dict = torch.load(path, map_location="cpu")
|
||||||
|
state_dict = state_dict.get("state_dict", state_dict)
|
||||||
|
keys = list(state_dict.keys())
|
||||||
|
for k in keys:
|
||||||
|
for ik in ignore_keys:
|
||||||
|
if k.startswith(ik):
|
||||||
|
print("Deleting key {} from state_dict.".format(k))
|
||||||
|
del state_dict[k]
|
||||||
|
missing, unexpected = self.load_state_dict(state_dict, strict=False)
|
||||||
|
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
|
||||||
|
if len(missing) > 0:
|
||||||
|
print(f"Missing Keys: {missing}")
|
||||||
|
print(f"Unexpected Keys: {unexpected}")
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
volume_decoder=None,
|
||||||
|
surface_extractor=None
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
if volume_decoder is None:
|
||||||
|
volume_decoder = VanillaVolumeDecoder()
|
||||||
|
if surface_extractor is None:
|
||||||
|
surface_extractor = MCSurfaceExtractor()
|
||||||
|
self.volume_decoder = volume_decoder
|
||||||
|
self.surface_extractor = surface_extractor
|
||||||
|
|
||||||
|
def latents2mesh(self, latents: torch.FloatTensor, **kwargs):
|
||||||
|
with synchronize_timer('Volume decoding'):
|
||||||
|
grid_logits = self.volume_decoder(latents, self.geo_decoder, **kwargs)
|
||||||
|
with synchronize_timer('Surface extraction'):
|
||||||
|
outputs = self.surface_extractor(grid_logits, **kwargs)
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def enable_flashvdm_decoder(
|
||||||
|
self,
|
||||||
|
enabled: bool = True,
|
||||||
|
adaptive_kv_selection=True,
|
||||||
|
topk_mode='mean',
|
||||||
|
mc_algo='dmc',
|
||||||
|
):
|
||||||
|
if enabled:
|
||||||
|
if adaptive_kv_selection:
|
||||||
|
self.volume_decoder = FlashVDMVolumeDecoding(topk_mode)
|
||||||
|
else:
|
||||||
|
self.volume_decoder = HierarchicalVolumeDecoding()
|
||||||
|
if mc_algo not in SurfaceExtractors.keys():
|
||||||
|
raise ValueError(f'Unsupported mc_algo {mc_algo}, available:{list(SurfaceExtractors.keys())}')
|
||||||
|
self.surface_extractor = SurfaceExtractors[mc_algo]()
|
||||||
|
else:
|
||||||
|
self.volume_decoder = VanillaVolumeDecoder()
|
||||||
|
self.surface_extractor = MCSurfaceExtractor()
|
||||||
|
|
||||||
|
|
||||||
|
class ShapeVAE(VectsetVAE):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
num_latents: int,
|
||||||
|
embed_dim: int,
|
||||||
|
width: int,
|
||||||
|
heads: int,
|
||||||
|
num_decoder_layers: int,
|
||||||
|
num_encoder_layers: int = 8,
|
||||||
|
pc_size: int = 5120,
|
||||||
|
pc_sharpedge_size: int = 5120,
|
||||||
|
point_feats: int = 3,
|
||||||
|
downsample_ratio: int = 20,
|
||||||
|
geo_decoder_downsample_ratio: int = 1,
|
||||||
|
geo_decoder_mlp_expand_ratio: int = 4,
|
||||||
|
geo_decoder_ln_post: bool = True,
|
||||||
|
num_freqs: int = 8,
|
||||||
|
include_pi: bool = True,
|
||||||
|
qkv_bias: bool = True,
|
||||||
|
qk_norm: bool = False,
|
||||||
|
label_type: str = "binary",
|
||||||
|
drop_path_rate: float = 0.0,
|
||||||
|
scale_factor: float = 1.0,
|
||||||
|
use_ln_post: bool = True,
|
||||||
|
ckpt_path = None
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.geo_decoder_ln_post = geo_decoder_ln_post
|
||||||
|
self.downsample_ratio = downsample_ratio
|
||||||
|
|
||||||
|
self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
|
||||||
|
|
||||||
|
self.encoder = PointCrossAttentionEncoder(
|
||||||
|
fourier_embedder=self.fourier_embedder,
|
||||||
|
num_latents=num_latents,
|
||||||
|
downsample_ratio=self.downsample_ratio,
|
||||||
|
pc_size=pc_size,
|
||||||
|
pc_sharpedge_size=pc_sharpedge_size,
|
||||||
|
point_feats=point_feats,
|
||||||
|
width=width,
|
||||||
|
heads=heads,
|
||||||
|
layers=num_encoder_layers,
|
||||||
|
qkv_bias=qkv_bias,
|
||||||
|
use_ln_post=use_ln_post,
|
||||||
|
qk_norm=qk_norm
|
||||||
|
)
|
||||||
|
|
||||||
|
self.pre_kl = nn.Linear(width, embed_dim * 2)
|
||||||
|
self.post_kl = nn.Linear(embed_dim, width)
|
||||||
|
|
||||||
|
self.transformer = Transformer(
|
||||||
|
n_ctx=num_latents,
|
||||||
|
width=width,
|
||||||
|
layers=num_decoder_layers,
|
||||||
|
heads=heads,
|
||||||
|
qkv_bias=qkv_bias,
|
||||||
|
qk_norm=qk_norm,
|
||||||
|
drop_path_rate=drop_path_rate
|
||||||
|
)
|
||||||
|
|
||||||
|
self.geo_decoder = CrossAttentionDecoder(
|
||||||
|
fourier_embedder=self.fourier_embedder,
|
||||||
|
out_channels=1,
|
||||||
|
num_latents=num_latents,
|
||||||
|
mlp_expand_ratio=geo_decoder_mlp_expand_ratio,
|
||||||
|
downsample_ratio=geo_decoder_downsample_ratio,
|
||||||
|
enable_ln_post=self.geo_decoder_ln_post,
|
||||||
|
width=width // geo_decoder_downsample_ratio,
|
||||||
|
heads=heads // geo_decoder_downsample_ratio,
|
||||||
|
qkv_bias=qkv_bias,
|
||||||
|
qk_norm=qk_norm,
|
||||||
|
label_type=label_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.scale_factor = scale_factor
|
||||||
|
self.latent_shape = (num_latents, embed_dim)
|
||||||
|
|
||||||
|
if ckpt_path is not None:
|
||||||
|
self.init_from_ckpt(ckpt_path)
|
||||||
|
|
||||||
|
def forward(self, latents):
|
||||||
|
latents = self.post_kl(latents)
|
||||||
|
latents = self.transformer(latents)
|
||||||
|
return latents
|
||||||
|
|
||||||
|
def encode(self, surface, sample_posterior=True):
|
||||||
|
pc, feats = surface[:, :, :3], surface[:, :, 3:]
|
||||||
|
latents, _ = self.encoder(pc, feats)
|
||||||
|
# print(latents.shape, self.pre_kl.weight.shape)
|
||||||
|
moments = self.pre_kl(latents)
|
||||||
|
posterior = DiagonalGaussianDistribution(moments, feat_dim=-1)
|
||||||
|
if sample_posterior:
|
||||||
|
latents = posterior.sample()
|
||||||
|
else:
|
||||||
|
latents = posterior.mode()
|
||||||
|
return latents
|
||||||
|
|
||||||
|
def decode(self, latents):
|
||||||
|
latents = self.post_kl(latents)
|
||||||
|
latents = self.transformer(latents)
|
||||||
|
return latents
|
||||||
164
hy3dshape/hy3dshape/models/autoencoders/surface_extractors.py
Normal file
164
hy3dshape/hy3dshape/models/autoencoders/surface_extractors.py
Normal file
@ -0,0 +1,164 @@
|
|||||||
|
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
||||||
|
# except for the third-party components listed below.
|
||||||
|
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
||||||
|
# in the repsective licenses of these third-party components.
|
||||||
|
# Users must comply with all terms and conditions of original licenses of these third-party
|
||||||
|
# components and must ensure that the usage of the third party components adheres to
|
||||||
|
# all relevant laws and regulations.
|
||||||
|
|
||||||
|
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
||||||
|
# their software and algorithms, including trained model weights, parameters (including
|
||||||
|
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
||||||
|
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
||||||
|
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
||||||
|
|
||||||
|
from typing import Union, Tuple, List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from skimage import measure
|
||||||
|
|
||||||
|
|
||||||
|
class Latent2MeshOutput:
|
||||||
|
def __init__(self, mesh_v=None, mesh_f=None):
|
||||||
|
self.mesh_v = mesh_v
|
||||||
|
self.mesh_f = mesh_f
|
||||||
|
|
||||||
|
|
||||||
|
def center_vertices(vertices):
|
||||||
|
"""Translate the vertices so that bounding box is centered at zero."""
|
||||||
|
vert_min = vertices.min(dim=0)[0]
|
||||||
|
vert_max = vertices.max(dim=0)[0]
|
||||||
|
vert_center = 0.5 * (vert_min + vert_max)
|
||||||
|
return vertices - vert_center
|
||||||
|
|
||||||
|
|
||||||
|
class SurfaceExtractor:
|
||||||
|
def _compute_box_stat(self, bounds: Union[Tuple[float], List[float], float], octree_resolution: int):
|
||||||
|
"""
|
||||||
|
Compute grid size, bounding box minimum coordinates, and bounding box size based on input
|
||||||
|
bounds and resolution.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
bounds (Union[Tuple[float], List[float], float]): Bounding box coordinates or a single
|
||||||
|
float representing half side length.
|
||||||
|
If float, bounds are assumed symmetric around zero in all axes.
|
||||||
|
Expected format if list/tuple: [xmin, ymin, zmin, xmax, ymax, zmax].
|
||||||
|
octree_resolution (int): Resolution of the octree grid.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
grid_size (List[int]): Grid size along each axis (x, y, z), each equal to octree_resolution + 1.
|
||||||
|
bbox_min (np.ndarray): Minimum coordinates of the bounding box (xmin, ymin, zmin).
|
||||||
|
bbox_size (np.ndarray): Size of the bounding box along each axis (xmax - xmin, etc.).
|
||||||
|
"""
|
||||||
|
if isinstance(bounds, float):
|
||||||
|
bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
|
||||||
|
|
||||||
|
bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
|
||||||
|
bbox_size = bbox_max - bbox_min
|
||||||
|
grid_size = [int(octree_resolution) + 1, int(octree_resolution) + 1, int(octree_resolution) + 1]
|
||||||
|
return grid_size, bbox_min, bbox_size
|
||||||
|
|
||||||
|
def run(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Abstract method to extract surface mesh from grid logits.
|
||||||
|
|
||||||
|
This method should be implemented by subclasses.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotImplementedError: Always, since this is an abstract method.
|
||||||
|
"""
|
||||||
|
return NotImplementedError
|
||||||
|
|
||||||
|
def __call__(self, grid_logits, **kwargs):
|
||||||
|
"""
|
||||||
|
Process a batch of grid logits to extract surface meshes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
grid_logits (torch.Tensor): Batch of grid logits with shape (batch_size, ...).
|
||||||
|
**kwargs: Additional keyword arguments passed to the `run` method.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Optional[Latent2MeshOutput]]: List of mesh outputs for each grid in the batch.
|
||||||
|
If extraction fails for a grid, None is appended at that position.
|
||||||
|
"""
|
||||||
|
outputs = []
|
||||||
|
for i in range(grid_logits.shape[0]):
|
||||||
|
try:
|
||||||
|
vertices, faces = self.run(grid_logits[i], **kwargs)
|
||||||
|
vertices = vertices.astype(np.float32)
|
||||||
|
faces = np.ascontiguousarray(faces)
|
||||||
|
outputs.append(Latent2MeshOutput(mesh_v=vertices, mesh_f=faces))
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
outputs.append(None)
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
class MCSurfaceExtractor(SurfaceExtractor):
|
||||||
|
def run(self, grid_logit, *, mc_level, bounds, octree_resolution, **kwargs):
|
||||||
|
"""
|
||||||
|
Extract surface mesh using the Marching Cubes algorithm.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
grid_logit (torch.Tensor): 3D grid logits tensor representing the scalar field.
|
||||||
|
mc_level (float): The level (iso-value) at which to extract the surface.
|
||||||
|
bounds (Union[Tuple[float], List[float], float]): Bounding box coordinates or half side length.
|
||||||
|
octree_resolution (int): Resolution of the octree grid.
|
||||||
|
**kwargs: Additional keyword arguments (ignored).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[np.ndarray, np.ndarray]: Tuple containing:
|
||||||
|
- vertices (np.ndarray): Extracted mesh vertices, scaled and translated to bounding
|
||||||
|
box coordinates.
|
||||||
|
- faces (np.ndarray): Extracted mesh faces (triangles).
|
||||||
|
"""
|
||||||
|
vertices, faces, normals, _ = measure.marching_cubes(grid_logit.cpu().numpy(),
|
||||||
|
mc_level,
|
||||||
|
method="lewiner")
|
||||||
|
grid_size, bbox_min, bbox_size = self._compute_box_stat(bounds, octree_resolution)
|
||||||
|
vertices = vertices / grid_size * bbox_size + bbox_min
|
||||||
|
return vertices, faces
|
||||||
|
|
||||||
|
|
||||||
|
class DMCSurfaceExtractor(SurfaceExtractor):
|
||||||
|
def run(self, grid_logit, *, octree_resolution, **kwargs):
|
||||||
|
"""
|
||||||
|
Extract surface mesh using Differentiable Marching Cubes (DMC) algorithm.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
grid_logit (torch.Tensor): 3D grid logits tensor representing the scalar field.
|
||||||
|
octree_resolution (int): Resolution of the octree grid.
|
||||||
|
**kwargs: Additional keyword arguments (ignored).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[np.ndarray, np.ndarray]: Tuple containing:
|
||||||
|
- vertices (np.ndarray): Extracted mesh vertices, centered and converted to numpy.
|
||||||
|
- faces (np.ndarray): Extracted mesh faces (triangles), with reversed vertex order.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ImportError: If the 'diso' package is not installed.
|
||||||
|
"""
|
||||||
|
device = grid_logit.device
|
||||||
|
if not hasattr(self, 'dmc'):
|
||||||
|
try:
|
||||||
|
from diso import DiffDMC
|
||||||
|
self.dmc = DiffDMC(dtype=torch.float32).to(device)
|
||||||
|
except:
|
||||||
|
raise ImportError("Please install diso via `pip install diso`, or set mc_algo to 'mc'")
|
||||||
|
sdf = -grid_logit / octree_resolution
|
||||||
|
sdf = sdf.to(torch.float32).contiguous()
|
||||||
|
verts, faces = self.dmc(sdf, deform=None, return_quads=False, normalize=True)
|
||||||
|
verts = center_vertices(verts)
|
||||||
|
vertices = verts.detach().cpu().numpy()
|
||||||
|
faces = faces.detach().cpu().numpy()[:, ::-1]
|
||||||
|
return vertices, faces
|
||||||
|
|
||||||
|
|
||||||
|
SurfaceExtractors = {
|
||||||
|
'mc': MCSurfaceExtractor,
|
||||||
|
'dmc': DMCSurfaceExtractor,
|
||||||
|
}
|
||||||
435
hy3dshape/hy3dshape/models/autoencoders/volume_decoders.py
Normal file
435
hy3dshape/hy3dshape/models/autoencoders/volume_decoders.py
Normal file
@ -0,0 +1,435 @@
|
|||||||
|
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
||||||
|
# except for the third-party components listed below.
|
||||||
|
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
||||||
|
# in the repsective licenses of these third-party components.
|
||||||
|
# Users must comply with all terms and conditions of original licenses of these third-party
|
||||||
|
# components and must ensure that the usage of the third party components adheres to
|
||||||
|
# all relevant laws and regulations.
|
||||||
|
|
||||||
|
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
||||||
|
# their software and algorithms, including trained model weights, parameters (including
|
||||||
|
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
||||||
|
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
||||||
|
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
||||||
|
|
||||||
|
from typing import Union, Tuple, List, Callable
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from einops import repeat
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from .attention_blocks import CrossAttentionDecoder
|
||||||
|
from .attention_processors import FlashVDMCrossAttentionProcessor, FlashVDMTopMCrossAttentionProcessor
|
||||||
|
from ...utils import logger
|
||||||
|
|
||||||
|
|
||||||
|
def extract_near_surface_volume_fn(input_tensor: torch.Tensor, alpha: float):
|
||||||
|
device = input_tensor.device
|
||||||
|
D = input_tensor.shape[0]
|
||||||
|
signed_val = 0.0
|
||||||
|
|
||||||
|
# 添加偏移并处理无效值
|
||||||
|
val = input_tensor + alpha
|
||||||
|
valid_mask = val > -9000 # 假设-9000是无效值
|
||||||
|
|
||||||
|
# 改进的邻居获取函数(保持维度一致)
|
||||||
|
def get_neighbor(t, shift, axis):
|
||||||
|
"""根据指定轴进行位移并保持维度一致"""
|
||||||
|
if shift == 0:
|
||||||
|
return t.clone()
|
||||||
|
|
||||||
|
# 确定填充轴(输入为[D, D, D]对应z,y,x轴)
|
||||||
|
pad_dims = [0, 0, 0, 0, 0, 0] # 格式:[x前,x后,y前,y后,z前,z后]
|
||||||
|
|
||||||
|
# 根据轴类型设置填充
|
||||||
|
if axis == 0: # x轴(最后一个维度)
|
||||||
|
pad_idx = 0 if shift > 0 else 1
|
||||||
|
pad_dims[pad_idx] = abs(shift)
|
||||||
|
elif axis == 1: # y轴(中间维度)
|
||||||
|
pad_idx = 2 if shift > 0 else 3
|
||||||
|
pad_dims[pad_idx] = abs(shift)
|
||||||
|
elif axis == 2: # z轴(第一个维度)
|
||||||
|
pad_idx = 4 if shift > 0 else 5
|
||||||
|
pad_dims[pad_idx] = abs(shift)
|
||||||
|
|
||||||
|
# 执行填充(添加batch和channel维度适配F.pad)
|
||||||
|
padded = F.pad(t.unsqueeze(0).unsqueeze(0), pad_dims[::-1], mode='replicate') # 反转顺序适配F.pad
|
||||||
|
|
||||||
|
# 构建动态切片索引
|
||||||
|
slice_dims = [slice(None)] * 3 # 初始化为全切片
|
||||||
|
if axis == 0: # x轴(dim=2)
|
||||||
|
if shift > 0:
|
||||||
|
slice_dims[0] = slice(shift, None)
|
||||||
|
else:
|
||||||
|
slice_dims[0] = slice(None, shift)
|
||||||
|
elif axis == 1: # y轴(dim=1)
|
||||||
|
if shift > 0:
|
||||||
|
slice_dims[1] = slice(shift, None)
|
||||||
|
else:
|
||||||
|
slice_dims[1] = slice(None, shift)
|
||||||
|
elif axis == 2: # z轴(dim=0)
|
||||||
|
if shift > 0:
|
||||||
|
slice_dims[2] = slice(shift, None)
|
||||||
|
else:
|
||||||
|
slice_dims[2] = slice(None, shift)
|
||||||
|
|
||||||
|
# 应用切片并恢复维度
|
||||||
|
padded = padded.squeeze(0).squeeze(0)
|
||||||
|
sliced = padded[slice_dims]
|
||||||
|
return sliced
|
||||||
|
|
||||||
|
# 获取各方向邻居(确保维度一致)
|
||||||
|
left = get_neighbor(val, 1, axis=0) # x方向
|
||||||
|
right = get_neighbor(val, -1, axis=0)
|
||||||
|
back = get_neighbor(val, 1, axis=1) # y方向
|
||||||
|
front = get_neighbor(val, -1, axis=1)
|
||||||
|
down = get_neighbor(val, 1, axis=2) # z方向
|
||||||
|
up = get_neighbor(val, -1, axis=2)
|
||||||
|
|
||||||
|
# 处理边界无效值(使用where保持维度一致)
|
||||||
|
def safe_where(neighbor):
|
||||||
|
return torch.where(neighbor > -9000, neighbor, val)
|
||||||
|
|
||||||
|
left = safe_where(left)
|
||||||
|
right = safe_where(right)
|
||||||
|
back = safe_where(back)
|
||||||
|
front = safe_where(front)
|
||||||
|
down = safe_where(down)
|
||||||
|
up = safe_where(up)
|
||||||
|
|
||||||
|
# 计算符号一致性(转换为float32确保精度)
|
||||||
|
sign = torch.sign(val.to(torch.float32))
|
||||||
|
neighbors_sign = torch.stack([
|
||||||
|
torch.sign(left.to(torch.float32)),
|
||||||
|
torch.sign(right.to(torch.float32)),
|
||||||
|
torch.sign(back.to(torch.float32)),
|
||||||
|
torch.sign(front.to(torch.float32)),
|
||||||
|
torch.sign(down.to(torch.float32)),
|
||||||
|
torch.sign(up.to(torch.float32))
|
||||||
|
], dim=0)
|
||||||
|
|
||||||
|
# 检查所有符号是否一致
|
||||||
|
same_sign = torch.all(neighbors_sign == sign, dim=0)
|
||||||
|
|
||||||
|
# 生成最终掩码
|
||||||
|
mask = (~same_sign).to(torch.int32)
|
||||||
|
return mask * valid_mask.to(torch.int32)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_dense_grid_points(
|
||||||
|
bbox_min: np.ndarray,
|
||||||
|
bbox_max: np.ndarray,
|
||||||
|
octree_resolution: int,
|
||||||
|
indexing: str = "ij",
|
||||||
|
):
|
||||||
|
length = bbox_max - bbox_min
|
||||||
|
num_cells = octree_resolution
|
||||||
|
|
||||||
|
x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
|
||||||
|
y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
|
||||||
|
z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
|
||||||
|
[xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
|
||||||
|
xyz = np.stack((xs, ys, zs), axis=-1)
|
||||||
|
grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]
|
||||||
|
|
||||||
|
return xyz, grid_size, length
|
||||||
|
|
||||||
|
|
||||||
|
class VanillaVolumeDecoder:
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
latents: torch.FloatTensor,
|
||||||
|
geo_decoder: Callable,
|
||||||
|
bounds: Union[Tuple[float], List[float], float] = 1.01,
|
||||||
|
num_chunks: int = 10000,
|
||||||
|
octree_resolution: int = None,
|
||||||
|
enable_pbar: bool = True,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
device = latents.device
|
||||||
|
dtype = latents.dtype
|
||||||
|
batch_size = latents.shape[0]
|
||||||
|
|
||||||
|
# 1. generate query points
|
||||||
|
if isinstance(bounds, float):
|
||||||
|
bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
|
||||||
|
|
||||||
|
bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
|
||||||
|
xyz_samples, grid_size, length = generate_dense_grid_points(
|
||||||
|
bbox_min=bbox_min,
|
||||||
|
bbox_max=bbox_max,
|
||||||
|
octree_resolution=octree_resolution,
|
||||||
|
indexing="ij"
|
||||||
|
)
|
||||||
|
xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3)
|
||||||
|
|
||||||
|
# 2. latents to 3d volume
|
||||||
|
batch_logits = []
|
||||||
|
for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), desc=f"Volume Decoding",
|
||||||
|
disable=not enable_pbar):
|
||||||
|
chunk_queries = xyz_samples[start: start + num_chunks, :]
|
||||||
|
chunk_queries = repeat(chunk_queries, "p c -> b p c", b=batch_size)
|
||||||
|
logits = geo_decoder(queries=chunk_queries, latents=latents)
|
||||||
|
batch_logits.append(logits)
|
||||||
|
|
||||||
|
grid_logits = torch.cat(batch_logits, dim=1)
|
||||||
|
grid_logits = grid_logits.view((batch_size, *grid_size)).float()
|
||||||
|
|
||||||
|
return grid_logits
|
||||||
|
|
||||||
|
|
||||||
|
class HierarchicalVolumeDecoding:
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
latents: torch.FloatTensor,
|
||||||
|
geo_decoder: Callable,
|
||||||
|
bounds: Union[Tuple[float], List[float], float] = 1.01,
|
||||||
|
num_chunks: int = 10000,
|
||||||
|
mc_level: float = 0.0,
|
||||||
|
octree_resolution: int = None,
|
||||||
|
min_resolution: int = 63,
|
||||||
|
enable_pbar: bool = True,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
device = latents.device
|
||||||
|
dtype = latents.dtype
|
||||||
|
|
||||||
|
resolutions = []
|
||||||
|
if octree_resolution < min_resolution:
|
||||||
|
resolutions.append(octree_resolution)
|
||||||
|
while octree_resolution >= min_resolution:
|
||||||
|
resolutions.append(octree_resolution)
|
||||||
|
octree_resolution = octree_resolution // 2
|
||||||
|
resolutions.reverse()
|
||||||
|
|
||||||
|
# 1. generate query points
|
||||||
|
if isinstance(bounds, float):
|
||||||
|
bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
|
||||||
|
bbox_min = np.array(bounds[0:3])
|
||||||
|
bbox_max = np.array(bounds[3:6])
|
||||||
|
bbox_size = bbox_max - bbox_min
|
||||||
|
|
||||||
|
xyz_samples, grid_size, length = generate_dense_grid_points(
|
||||||
|
bbox_min=bbox_min,
|
||||||
|
bbox_max=bbox_max,
|
||||||
|
octree_resolution=resolutions[0],
|
||||||
|
indexing="ij"
|
||||||
|
)
|
||||||
|
|
||||||
|
dilate = nn.Conv3d(1, 1, 3, padding=1, bias=False, device=device, dtype=dtype)
|
||||||
|
dilate.weight = torch.nn.Parameter(torch.ones(dilate.weight.shape, dtype=dtype, device=device))
|
||||||
|
|
||||||
|
grid_size = np.array(grid_size)
|
||||||
|
xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3)
|
||||||
|
|
||||||
|
# 2. latents to 3d volume
|
||||||
|
batch_logits = []
|
||||||
|
batch_size = latents.shape[0]
|
||||||
|
for start in tqdm(range(0, xyz_samples.shape[0], num_chunks),
|
||||||
|
desc=f"Hierarchical Volume Decoding [r{resolutions[0] + 1}]"):
|
||||||
|
queries = xyz_samples[start: start + num_chunks, :]
|
||||||
|
batch_queries = repeat(queries, "p c -> b p c", b=batch_size)
|
||||||
|
logits = geo_decoder(queries=batch_queries, latents=latents)
|
||||||
|
batch_logits.append(logits)
|
||||||
|
|
||||||
|
grid_logits = torch.cat(batch_logits, dim=1).view((batch_size, grid_size[0], grid_size[1], grid_size[2]))
|
||||||
|
|
||||||
|
for octree_depth_now in resolutions[1:]:
|
||||||
|
grid_size = np.array([octree_depth_now + 1] * 3)
|
||||||
|
resolution = bbox_size / octree_depth_now
|
||||||
|
next_index = torch.zeros(tuple(grid_size), dtype=dtype, device=device)
|
||||||
|
next_logits = torch.full(next_index.shape, -10000., dtype=dtype, device=device)
|
||||||
|
curr_points = extract_near_surface_volume_fn(grid_logits.squeeze(0), mc_level)
|
||||||
|
curr_points += grid_logits.squeeze(0).abs() < 0.95
|
||||||
|
|
||||||
|
if octree_depth_now == resolutions[-1]:
|
||||||
|
expand_num = 0
|
||||||
|
else:
|
||||||
|
expand_num = 1
|
||||||
|
for i in range(expand_num):
|
||||||
|
curr_points = dilate(curr_points.unsqueeze(0).to(dtype)).squeeze(0)
|
||||||
|
(cidx_x, cidx_y, cidx_z) = torch.where(curr_points > 0)
|
||||||
|
next_index[cidx_x * 2, cidx_y * 2, cidx_z * 2] = 1
|
||||||
|
for i in range(2 - expand_num):
|
||||||
|
next_index = dilate(next_index.unsqueeze(0)).squeeze(0)
|
||||||
|
nidx = torch.where(next_index > 0)
|
||||||
|
|
||||||
|
next_points = torch.stack(nidx, dim=1)
|
||||||
|
next_points = (next_points * torch.tensor(resolution, dtype=next_points.dtype, device=device) +
|
||||||
|
torch.tensor(bbox_min, dtype=next_points.dtype, device=device))
|
||||||
|
batch_logits = []
|
||||||
|
for start in tqdm(range(0, next_points.shape[0], num_chunks),
|
||||||
|
desc=f"Hierarchical Volume Decoding [r{octree_depth_now + 1}]"):
|
||||||
|
queries = next_points[start: start + num_chunks, :]
|
||||||
|
batch_queries = repeat(queries, "p c -> b p c", b=batch_size)
|
||||||
|
logits = geo_decoder(queries=batch_queries.to(latents.dtype), latents=latents)
|
||||||
|
batch_logits.append(logits)
|
||||||
|
grid_logits = torch.cat(batch_logits, dim=1)
|
||||||
|
next_logits[nidx] = grid_logits[0, ..., 0]
|
||||||
|
grid_logits = next_logits.unsqueeze(0)
|
||||||
|
grid_logits[grid_logits == -10000.] = float('nan')
|
||||||
|
|
||||||
|
return grid_logits
|
||||||
|
|
||||||
|
|
||||||
|
class FlashVDMVolumeDecoding:
|
||||||
|
def __init__(self, topk_mode='mean'):
|
||||||
|
if topk_mode not in ['mean', 'merge']:
|
||||||
|
raise ValueError(f'Unsupported topk_mode {topk_mode}, available: {["mean", "merge"]}')
|
||||||
|
|
||||||
|
if topk_mode == 'mean':
|
||||||
|
self.processor = FlashVDMCrossAttentionProcessor()
|
||||||
|
else:
|
||||||
|
self.processor = FlashVDMTopMCrossAttentionProcessor()
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
latents: torch.FloatTensor,
|
||||||
|
geo_decoder: CrossAttentionDecoder,
|
||||||
|
bounds: Union[Tuple[float], List[float], float] = 1.01,
|
||||||
|
num_chunks: int = 10000,
|
||||||
|
mc_level: float = 0.0,
|
||||||
|
octree_resolution: int = None,
|
||||||
|
min_resolution: int = 63,
|
||||||
|
mini_grid_num: int = 4,
|
||||||
|
enable_pbar: bool = True,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
processor = self.processor
|
||||||
|
geo_decoder.set_cross_attention_processor(processor)
|
||||||
|
|
||||||
|
device = latents.device
|
||||||
|
dtype = latents.dtype
|
||||||
|
|
||||||
|
resolutions = []
|
||||||
|
if octree_resolution < min_resolution:
|
||||||
|
resolutions.append(octree_resolution)
|
||||||
|
while octree_resolution >= min_resolution:
|
||||||
|
resolutions.append(octree_resolution)
|
||||||
|
octree_resolution = octree_resolution // 2
|
||||||
|
resolutions.reverse()
|
||||||
|
resolutions[0] = round(resolutions[0] / mini_grid_num) * mini_grid_num - 1
|
||||||
|
for i, resolution in enumerate(resolutions[1:]):
|
||||||
|
resolutions[i + 1] = resolutions[0] * 2 ** (i + 1)
|
||||||
|
|
||||||
|
logger.info(f"FlashVDMVolumeDecoding Resolution: {resolutions}")
|
||||||
|
|
||||||
|
# 1. generate query points
|
||||||
|
if isinstance(bounds, float):
|
||||||
|
bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
|
||||||
|
bbox_min = np.array(bounds[0:3])
|
||||||
|
bbox_max = np.array(bounds[3:6])
|
||||||
|
bbox_size = bbox_max - bbox_min
|
||||||
|
|
||||||
|
xyz_samples, grid_size, length = generate_dense_grid_points(
|
||||||
|
bbox_min=bbox_min,
|
||||||
|
bbox_max=bbox_max,
|
||||||
|
octree_resolution=resolutions[0],
|
||||||
|
indexing="ij"
|
||||||
|
)
|
||||||
|
|
||||||
|
dilate = nn.Conv3d(1, 1, 3, padding=1, bias=False, device=device, dtype=dtype)
|
||||||
|
dilate.weight = torch.nn.Parameter(torch.ones(dilate.weight.shape, dtype=dtype, device=device))
|
||||||
|
|
||||||
|
grid_size = np.array(grid_size)
|
||||||
|
|
||||||
|
# 2. latents to 3d volume
|
||||||
|
xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype)
|
||||||
|
batch_size = latents.shape[0]
|
||||||
|
mini_grid_size = xyz_samples.shape[0] // mini_grid_num
|
||||||
|
xyz_samples = xyz_samples.view(
|
||||||
|
mini_grid_num, mini_grid_size,
|
||||||
|
mini_grid_num, mini_grid_size,
|
||||||
|
mini_grid_num, mini_grid_size, 3
|
||||||
|
).permute(
|
||||||
|
0, 2, 4, 1, 3, 5, 6
|
||||||
|
).reshape(
|
||||||
|
-1, mini_grid_size * mini_grid_size * mini_grid_size, 3
|
||||||
|
)
|
||||||
|
batch_logits = []
|
||||||
|
num_batchs = max(num_chunks // xyz_samples.shape[1], 1)
|
||||||
|
for start in tqdm(range(0, xyz_samples.shape[0], num_batchs),
|
||||||
|
desc=f"FlashVDM Volume Decoding", disable=not enable_pbar):
|
||||||
|
queries = xyz_samples[start: start + num_batchs, :]
|
||||||
|
batch = queries.shape[0]
|
||||||
|
batch_latents = repeat(latents.squeeze(0), "p c -> b p c", b=batch)
|
||||||
|
processor.topk = True
|
||||||
|
logits = geo_decoder(queries=queries, latents=batch_latents)
|
||||||
|
batch_logits.append(logits)
|
||||||
|
grid_logits = torch.cat(batch_logits, dim=0).reshape(
|
||||||
|
mini_grid_num, mini_grid_num, mini_grid_num,
|
||||||
|
mini_grid_size, mini_grid_size,
|
||||||
|
mini_grid_size
|
||||||
|
).permute(0, 3, 1, 4, 2, 5).contiguous().view(
|
||||||
|
(batch_size, grid_size[0], grid_size[1], grid_size[2])
|
||||||
|
)
|
||||||
|
|
||||||
|
for octree_depth_now in resolutions[1:]:
|
||||||
|
grid_size = np.array([octree_depth_now + 1] * 3)
|
||||||
|
resolution = bbox_size / octree_depth_now
|
||||||
|
next_index = torch.zeros(tuple(grid_size), dtype=dtype, device=device)
|
||||||
|
next_logits = torch.full(next_index.shape, -10000., dtype=dtype, device=device)
|
||||||
|
curr_points = extract_near_surface_volume_fn(grid_logits.squeeze(0), mc_level)
|
||||||
|
curr_points += grid_logits.squeeze(0).abs() < 0.95
|
||||||
|
|
||||||
|
if octree_depth_now == resolutions[-1]:
|
||||||
|
expand_num = 0
|
||||||
|
else:
|
||||||
|
expand_num = 1
|
||||||
|
for i in range(expand_num):
|
||||||
|
curr_points = dilate(curr_points.unsqueeze(0).to(dtype)).squeeze(0)
|
||||||
|
(cidx_x, cidx_y, cidx_z) = torch.where(curr_points > 0)
|
||||||
|
|
||||||
|
next_index[cidx_x * 2, cidx_y * 2, cidx_z * 2] = 1
|
||||||
|
for i in range(2 - expand_num):
|
||||||
|
next_index = dilate(next_index.unsqueeze(0)).squeeze(0)
|
||||||
|
nidx = torch.where(next_index > 0)
|
||||||
|
|
||||||
|
next_points = torch.stack(nidx, dim=1)
|
||||||
|
next_points = (next_points * torch.tensor(resolution, dtype=torch.float32, device=device) +
|
||||||
|
torch.tensor(bbox_min, dtype=torch.float32, device=device))
|
||||||
|
|
||||||
|
query_grid_num = 6
|
||||||
|
min_val = next_points.min(axis=0).values
|
||||||
|
max_val = next_points.max(axis=0).values
|
||||||
|
vol_queries_index = (next_points - min_val) / (max_val - min_val) * (query_grid_num - 0.001)
|
||||||
|
index = torch.floor(vol_queries_index).long()
|
||||||
|
index = index[..., 0] * (query_grid_num ** 2) + index[..., 1] * query_grid_num + index[..., 2]
|
||||||
|
index = index.sort()
|
||||||
|
next_points = next_points[index.indices].unsqueeze(0).contiguous()
|
||||||
|
unique_values = torch.unique(index.values, return_counts=True)
|
||||||
|
grid_logits = torch.zeros((next_points.shape[1]), dtype=latents.dtype, device=latents.device)
|
||||||
|
input_grid = [[], []]
|
||||||
|
logits_grid_list = []
|
||||||
|
start_num = 0
|
||||||
|
sum_num = 0
|
||||||
|
for grid_index, count in zip(unique_values[0].cpu().tolist(), unique_values[1].cpu().tolist()):
|
||||||
|
if sum_num + count < num_chunks or sum_num == 0:
|
||||||
|
sum_num += count
|
||||||
|
input_grid[0].append(grid_index)
|
||||||
|
input_grid[1].append(count)
|
||||||
|
else:
|
||||||
|
processor.topk = input_grid
|
||||||
|
logits_grid = geo_decoder(queries=next_points[:, start_num:start_num + sum_num], latents=latents)
|
||||||
|
start_num = start_num + sum_num
|
||||||
|
logits_grid_list.append(logits_grid)
|
||||||
|
input_grid = [[grid_index], [count]]
|
||||||
|
sum_num = count
|
||||||
|
if sum_num > 0:
|
||||||
|
processor.topk = input_grid
|
||||||
|
logits_grid = geo_decoder(queries=next_points[:, start_num:start_num + sum_num], latents=latents)
|
||||||
|
logits_grid_list.append(logits_grid)
|
||||||
|
logits_grid = torch.cat(logits_grid_list, dim=1)
|
||||||
|
grid_logits[index.indices] = logits_grid.squeeze(0).squeeze(-1)
|
||||||
|
next_logits[nidx] = grid_logits
|
||||||
|
grid_logits = next_logits.unsqueeze(0)
|
||||||
|
|
||||||
|
grid_logits[grid_logits == -10000.] = float('nan')
|
||||||
|
|
||||||
|
return grid_logits
|
||||||
257
hy3dshape/hy3dshape/models/conditioner.py
Normal file
257
hy3dshape/hy3dshape/models/conditioner.py
Normal file
@ -0,0 +1,257 @@
|
|||||||
|
# Open Source Model Licensed under the Apache License Version 2.0
|
||||||
|
# and Other Licenses of the Third-Party Components therein:
|
||||||
|
# The below Model in this distribution may have been modified by THL A29 Limited
|
||||||
|
# ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
|
||||||
|
|
||||||
|
# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
|
||||||
|
# The below software and/or models in this distribution may have been
|
||||||
|
# modified by THL A29 Limited ("Tencent Modifications").
|
||||||
|
# All Tencent Modifications are Copyright (C) THL A29 Limited.
|
||||||
|
|
||||||
|
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
||||||
|
# except for the third-party components listed below.
|
||||||
|
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
||||||
|
# in the repsective licenses of these third-party components.
|
||||||
|
# Users must comply with all terms and conditions of original licenses of these third-party
|
||||||
|
# components and must ensure that the usage of the third party components adheres to
|
||||||
|
# all relevant laws and regulations.
|
||||||
|
|
||||||
|
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
||||||
|
# their software and algorithms, including trained model weights, parameters (including
|
||||||
|
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
||||||
|
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
||||||
|
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torchvision import transforms
|
||||||
|
from transformers import (
|
||||||
|
CLIPVisionModelWithProjection,
|
||||||
|
CLIPVisionConfig,
|
||||||
|
Dinov2Model,
|
||||||
|
Dinov2Config,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
||||||
|
"""
|
||||||
|
embed_dim: output dimension for each position
|
||||||
|
pos: a list of positions to be encoded: size (M,)
|
||||||
|
out: (M, D)
|
||||||
|
"""
|
||||||
|
assert embed_dim % 2 == 0
|
||||||
|
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
||||||
|
omega /= embed_dim / 2.
|
||||||
|
omega = 1. / 10000 ** omega # (D/2,)
|
||||||
|
|
||||||
|
pos = pos.reshape(-1) # (M,)
|
||||||
|
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
|
||||||
|
|
||||||
|
emb_sin = np.sin(out) # (M, D/2)
|
||||||
|
emb_cos = np.cos(out) # (M, D/2)
|
||||||
|
|
||||||
|
return np.concatenate([emb_sin, emb_cos], axis=1)
|
||||||
|
|
||||||
|
|
||||||
|
class ImageEncoder(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
version=None,
|
||||||
|
config=None,
|
||||||
|
use_cls_token=True,
|
||||||
|
image_size=224,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if config is None:
|
||||||
|
self.model = self.MODEL_CLASS.from_pretrained(version)
|
||||||
|
else:
|
||||||
|
self.model = self.MODEL_CLASS(self.MODEL_CONFIG_CLASS.from_dict(config))
|
||||||
|
self.model.eval()
|
||||||
|
self.model.requires_grad_(False)
|
||||||
|
self.use_cls_token = use_cls_token
|
||||||
|
self.size = image_size // 14
|
||||||
|
self.num_patches = (image_size // 14) ** 2
|
||||||
|
if self.use_cls_token:
|
||||||
|
self.num_patches += 1
|
||||||
|
|
||||||
|
self.transform = transforms.Compose(
|
||||||
|
[
|
||||||
|
transforms.Resize(image_size, transforms.InterpolationMode.BILINEAR, antialias=True),
|
||||||
|
transforms.CenterCrop(image_size),
|
||||||
|
transforms.Normalize(
|
||||||
|
mean=self.mean,
|
||||||
|
std=self.std,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, image, mask=None, value_range=(-1, 1), **kwargs):
|
||||||
|
if value_range is not None:
|
||||||
|
low, high = value_range
|
||||||
|
image = (image - low) / (high - low)
|
||||||
|
|
||||||
|
image = image.to(self.model.device, dtype=self.model.dtype)
|
||||||
|
inputs = self.transform(image)
|
||||||
|
outputs = self.model(inputs)
|
||||||
|
|
||||||
|
last_hidden_state = outputs.last_hidden_state
|
||||||
|
if not self.use_cls_token:
|
||||||
|
last_hidden_state = last_hidden_state[:, 1:, :]
|
||||||
|
|
||||||
|
return last_hidden_state
|
||||||
|
|
||||||
|
def unconditional_embedding(self, batch_size, **kwargs):
|
||||||
|
device = next(self.model.parameters()).device
|
||||||
|
dtype = next(self.model.parameters()).dtype
|
||||||
|
zero = torch.zeros(
|
||||||
|
batch_size,
|
||||||
|
self.num_patches,
|
||||||
|
self.model.config.hidden_size,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
return zero
|
||||||
|
|
||||||
|
|
||||||
|
class CLIPImageEncoder(ImageEncoder):
|
||||||
|
MODEL_CLASS = CLIPVisionModelWithProjection
|
||||||
|
MODEL_CONFIG_CLASS = CLIPVisionConfig
|
||||||
|
mean = [0.48145466, 0.4578275, 0.40821073]
|
||||||
|
std = [0.26862954, 0.26130258, 0.27577711]
|
||||||
|
|
||||||
|
|
||||||
|
class DinoImageEncoder(ImageEncoder):
|
||||||
|
MODEL_CLASS = Dinov2Model
|
||||||
|
MODEL_CONFIG_CLASS = Dinov2Config
|
||||||
|
mean = [0.485, 0.456, 0.406]
|
||||||
|
std = [0.229, 0.224, 0.225]
|
||||||
|
|
||||||
|
|
||||||
|
class DinoImageEncoderMV(DinoImageEncoder):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
version=None,
|
||||||
|
config=None,
|
||||||
|
use_cls_token=True,
|
||||||
|
image_size=224,
|
||||||
|
view_num=4,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(version, config, use_cls_token, image_size, **kwargs)
|
||||||
|
self.view_num = view_num
|
||||||
|
self.num_patches = self.num_patches
|
||||||
|
pos = np.arange(self.view_num, dtype=np.float32)
|
||||||
|
view_embedding = torch.from_numpy(
|
||||||
|
get_1d_sincos_pos_embed_from_grid(self.model.config.hidden_size, pos)).float()
|
||||||
|
|
||||||
|
view_embedding = view_embedding.unsqueeze(1).repeat(1, self.num_patches, 1)
|
||||||
|
self.view_embed = view_embedding.unsqueeze(0)
|
||||||
|
|
||||||
|
def forward(self, image, mask=None, value_range=(-1, 1), view_idxs=None):
|
||||||
|
if value_range is not None:
|
||||||
|
low, high = value_range
|
||||||
|
image = (image - low) / (high - low)
|
||||||
|
|
||||||
|
image = image.to(self.model.device, dtype=self.model.dtype)
|
||||||
|
|
||||||
|
bs, num_views, c, h, w = image.shape
|
||||||
|
image = image.view(bs * num_views, c, h, w)
|
||||||
|
|
||||||
|
inputs = self.transform(image)
|
||||||
|
outputs = self.model(inputs)
|
||||||
|
|
||||||
|
last_hidden_state = outputs.last_hidden_state
|
||||||
|
last_hidden_state = last_hidden_state.view(
|
||||||
|
bs, num_views, last_hidden_state.shape[-2],
|
||||||
|
last_hidden_state.shape[-1]
|
||||||
|
)
|
||||||
|
|
||||||
|
view_embedding = self.view_embed.to(last_hidden_state.dtype).to(last_hidden_state.device)
|
||||||
|
if view_idxs is not None:
|
||||||
|
assert len(view_idxs) == bs
|
||||||
|
view_embeddings = []
|
||||||
|
for i in range(bs):
|
||||||
|
view_idx = view_idxs[i]
|
||||||
|
assert num_views == len(view_idx)
|
||||||
|
view_embeddings.append(self.view_embed[:, view_idx, ...])
|
||||||
|
view_embedding = torch.cat(view_embeddings, 0).to(last_hidden_state.dtype).to(last_hidden_state.device)
|
||||||
|
|
||||||
|
if num_views != self.view_num:
|
||||||
|
view_embedding = view_embedding[:, :num_views, ...]
|
||||||
|
last_hidden_state = last_hidden_state + view_embedding
|
||||||
|
last_hidden_state = last_hidden_state.view(bs, num_views * last_hidden_state.shape[-2],
|
||||||
|
last_hidden_state.shape[-1])
|
||||||
|
return last_hidden_state
|
||||||
|
|
||||||
|
def unconditional_embedding(self, batch_size, view_idxs=None, **kwargs):
|
||||||
|
device = next(self.model.parameters()).device
|
||||||
|
dtype = next(self.model.parameters()).dtype
|
||||||
|
zero = torch.zeros(
|
||||||
|
batch_size,
|
||||||
|
self.num_patches * len(view_idxs[0]),
|
||||||
|
self.model.config.hidden_size,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
return zero
|
||||||
|
|
||||||
|
|
||||||
|
def build_image_encoder(config):
|
||||||
|
if config['type'] == 'CLIPImageEncoder':
|
||||||
|
return CLIPImageEncoder(**config['kwargs'])
|
||||||
|
elif config['type'] == 'DinoImageEncoder':
|
||||||
|
return DinoImageEncoder(**config['kwargs'])
|
||||||
|
elif config['type'] == 'DinoImageEncoderMV':
|
||||||
|
return DinoImageEncoderMV(**config['kwargs'])
|
||||||
|
else:
|
||||||
|
raise ValueError(f'Unknown image encoder type: {config["type"]}')
|
||||||
|
|
||||||
|
|
||||||
|
class DualImageEncoder(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
main_image_encoder,
|
||||||
|
additional_image_encoder,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.main_image_encoder = build_image_encoder(main_image_encoder)
|
||||||
|
self.additional_image_encoder = build_image_encoder(additional_image_encoder)
|
||||||
|
|
||||||
|
def forward(self, image, mask=None, **kwargs):
|
||||||
|
outputs = {
|
||||||
|
'main': self.main_image_encoder(image, mask=mask, **kwargs),
|
||||||
|
'additional': self.additional_image_encoder(image, mask=mask, **kwargs),
|
||||||
|
}
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def unconditional_embedding(self, batch_size, **kwargs):
|
||||||
|
outputs = {
|
||||||
|
'main': self.main_image_encoder.unconditional_embedding(batch_size, **kwargs),
|
||||||
|
'additional': self.additional_image_encoder.unconditional_embedding(batch_size, **kwargs),
|
||||||
|
}
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
class SingleImageEncoder(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
main_image_encoder,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.main_image_encoder = build_image_encoder(main_image_encoder)
|
||||||
|
|
||||||
|
def forward(self, image, mask=None, **kwargs):
|
||||||
|
outputs = {
|
||||||
|
'main': self.main_image_encoder(image, mask=mask, **kwargs),
|
||||||
|
}
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def unconditional_embedding(self, batch_size, **kwargs):
|
||||||
|
outputs = {
|
||||||
|
'main': self.main_image_encoder.unconditional_embedding(batch_size, **kwargs),
|
||||||
|
}
|
||||||
|
return outputs
|
||||||
15
hy3dshape/hy3dshape/models/denoisers/__init__.py
Normal file
15
hy3dshape/hy3dshape/models/denoisers/__init__.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
||||||
|
# except for the third-party components listed below.
|
||||||
|
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
||||||
|
# in the repsective licenses of these third-party components.
|
||||||
|
# Users must comply with all terms and conditions of original licenses of these third-party
|
||||||
|
# components and must ensure that the usage of the third party components adheres to
|
||||||
|
# all relevant laws and regulations.
|
||||||
|
|
||||||
|
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
||||||
|
# their software and algorithms, including trained model weights, parameters (including
|
||||||
|
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
||||||
|
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
||||||
|
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
||||||
|
|
||||||
|
from .hunyuan3ddit import Hunyuan3DDiT
|
||||||
404
hy3dshape/hy3dshape/models/denoisers/hunyuan3ddit.py
Normal file
404
hy3dshape/hy3dshape/models/denoisers/hunyuan3ddit.py
Normal file
@ -0,0 +1,404 @@
|
|||||||
|
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
||||||
|
# except for the third-party components listed below.
|
||||||
|
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
||||||
|
# in the repsective licenses of these third-party components.
|
||||||
|
# Users must comply with all terms and conditions of original licenses of these third-party
|
||||||
|
# components and must ensure that the usage of the third party components adheres to
|
||||||
|
# all relevant laws and regulations.
|
||||||
|
|
||||||
|
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
||||||
|
# their software and algorithms, including trained model weights, parameters (including
|
||||||
|
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
||||||
|
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
||||||
|
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import math
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import List, Tuple, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import Tensor, nn
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
# set up attention backend
|
||||||
|
scaled_dot_product_attention = nn.functional.scaled_dot_product_attention
|
||||||
|
if os.environ.get('USE_SAGEATTN', '0') == '1':
|
||||||
|
try:
|
||||||
|
from sageattention import sageattn
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError('Please install the package "sageattention" to use this USE_SAGEATTN.')
|
||||||
|
scaled_dot_product_attention = sageattn
|
||||||
|
|
||||||
|
|
||||||
|
def attention(q: Tensor, k: Tensor, v: Tensor, **kwargs) -> Tensor:
|
||||||
|
x = scaled_dot_product_attention(q, k, v)
|
||||||
|
x = rearrange(x, "B H L D -> B L (H D)")
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
|
||||||
|
"""
|
||||||
|
Create sinusoidal timestep embeddings.
|
||||||
|
:param t: a 1-D Tensor of N indices, one per batch element.
|
||||||
|
These may be fractional.
|
||||||
|
:param dim: the dimension of the output.
|
||||||
|
:param max_period: controls the minimum frequency of the embeddings.
|
||||||
|
:return: an (N, D) Tensor of positional embeddings.
|
||||||
|
"""
|
||||||
|
t = time_factor * t
|
||||||
|
half = dim // 2
|
||||||
|
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half)
|
||||||
|
freqs = freqs.to(t.device)
|
||||||
|
|
||||||
|
args = t[:, None].float() * freqs[None]
|
||||||
|
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||||
|
if dim % 2:
|
||||||
|
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||||
|
if torch.is_floating_point(t):
|
||||||
|
embedding = embedding.to(t)
|
||||||
|
return embedding
|
||||||
|
|
||||||
|
|
||||||
|
class GELU(nn.Module):
|
||||||
|
def __init__(self, approximate='tanh'):
|
||||||
|
super().__init__()
|
||||||
|
self.approximate = approximate
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
return nn.functional.gelu(x.contiguous(), approximate=self.approximate)
|
||||||
|
|
||||||
|
|
||||||
|
class MLPEmbedder(nn.Module):
|
||||||
|
def __init__(self, in_dim: int, hidden_dim: int):
|
||||||
|
super().__init__()
|
||||||
|
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
|
||||||
|
self.silu = nn.SiLU()
|
||||||
|
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
return self.out_layer(self.silu(self.in_layer(x)))
|
||||||
|
|
||||||
|
|
||||||
|
class RMSNorm(torch.nn.Module):
|
||||||
|
def __init__(self, dim: int):
|
||||||
|
super().__init__()
|
||||||
|
self.scale = nn.Parameter(torch.ones(dim))
|
||||||
|
|
||||||
|
def forward(self, x: Tensor):
|
||||||
|
x_dtype = x.dtype
|
||||||
|
x = x.float()
|
||||||
|
rrms = torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + 1e-6)
|
||||||
|
return (x * rrms).to(dtype=x_dtype) * self.scale
|
||||||
|
|
||||||
|
|
||||||
|
class QKNorm(torch.nn.Module):
|
||||||
|
def __init__(self, dim: int):
|
||||||
|
super().__init__()
|
||||||
|
self.query_norm = RMSNorm(dim)
|
||||||
|
self.key_norm = RMSNorm(dim)
|
||||||
|
|
||||||
|
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tuple[Tensor, Tensor]:
|
||||||
|
q = self.query_norm(q)
|
||||||
|
k = self.key_norm(k)
|
||||||
|
return q.to(v), k.to(v)
|
||||||
|
|
||||||
|
|
||||||
|
class SelfAttention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
num_heads: int = 8,
|
||||||
|
qkv_bias: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.num_heads = num_heads
|
||||||
|
head_dim = dim // num_heads
|
||||||
|
|
||||||
|
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||||
|
self.norm = QKNorm(head_dim)
|
||||||
|
self.proj = nn.Linear(dim, dim)
|
||||||
|
|
||||||
|
def forward(self, x: Tensor, pe: Tensor) -> Tensor:
|
||||||
|
qkv = self.qkv(x)
|
||||||
|
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||||
|
q, k = self.norm(q, k, v)
|
||||||
|
x = attention(q, k, v, pe=pe)
|
||||||
|
x = self.proj(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModulationOut:
|
||||||
|
shift: Tensor
|
||||||
|
scale: Tensor
|
||||||
|
gate: Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class Modulation(nn.Module):
|
||||||
|
def __init__(self, dim: int, double: bool):
|
||||||
|
super().__init__()
|
||||||
|
self.is_double = double
|
||||||
|
self.multiplier = 6 if double else 3
|
||||||
|
self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
|
||||||
|
|
||||||
|
def forward(self, vec: Tensor) -> Tuple[ModulationOut, Optional[ModulationOut]]:
|
||||||
|
out = self.lin(nn.functional.silu(vec))[:, None, :]
|
||||||
|
out = out.chunk(self.multiplier, dim=-1)
|
||||||
|
|
||||||
|
return (
|
||||||
|
ModulationOut(*out[:3]),
|
||||||
|
ModulationOut(*out[3:]) if self.is_double else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DoubleStreamBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int,
|
||||||
|
num_heads: int,
|
||||||
|
mlp_ratio: float,
|
||||||
|
qkv_bias: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.img_mod = Modulation(hidden_size, double=True)
|
||||||
|
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||||
|
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
|
||||||
|
|
||||||
|
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||||
|
self.img_mlp = nn.Sequential(
|
||||||
|
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
|
||||||
|
GELU(approximate="tanh"),
|
||||||
|
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.txt_mod = Modulation(hidden_size, double=True)
|
||||||
|
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||||
|
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
|
||||||
|
|
||||||
|
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||||
|
self.txt_mlp = nn.Sequential(
|
||||||
|
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
|
||||||
|
GELU(approximate="tanh"),
|
||||||
|
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> Tuple[Tensor, Tensor]:
|
||||||
|
img_mod1, img_mod2 = self.img_mod(vec)
|
||||||
|
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
||||||
|
|
||||||
|
img_modulated = self.img_norm1(img)
|
||||||
|
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
|
||||||
|
img_qkv = self.img_attn.qkv(img_modulated)
|
||||||
|
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||||
|
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
|
||||||
|
|
||||||
|
txt_modulated = self.txt_norm1(txt)
|
||||||
|
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
|
||||||
|
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
||||||
|
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||||
|
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
||||||
|
|
||||||
|
q = torch.cat((txt_q, img_q), dim=2)
|
||||||
|
k = torch.cat((txt_k, img_k), dim=2)
|
||||||
|
v = torch.cat((txt_v, img_v), dim=2)
|
||||||
|
|
||||||
|
attn = attention(q, k, v, pe=pe)
|
||||||
|
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
|
||||||
|
|
||||||
|
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
|
||||||
|
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
|
||||||
|
|
||||||
|
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
|
||||||
|
txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
|
||||||
|
return img, txt
|
||||||
|
|
||||||
|
|
||||||
|
class SingleStreamBlock(nn.Module):
|
||||||
|
"""
|
||||||
|
A DiT block with parallel linear layers as described in
|
||||||
|
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int,
|
||||||
|
num_heads: int,
|
||||||
|
mlp_ratio: float = 4.0,
|
||||||
|
qk_scale: Optional[float] = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.hidden_dim = hidden_size
|
||||||
|
self.num_heads = num_heads
|
||||||
|
head_dim = hidden_size // num_heads
|
||||||
|
self.scale = qk_scale or head_dim ** -0.5
|
||||||
|
|
||||||
|
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||||
|
# qkv and mlp_in
|
||||||
|
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
|
||||||
|
# proj and mlp_out
|
||||||
|
self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
|
||||||
|
|
||||||
|
self.norm = QKNorm(head_dim)
|
||||||
|
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||||
|
|
||||||
|
self.mlp_act = GELU(approximate="tanh")
|
||||||
|
self.modulation = Modulation(hidden_size, double=False)
|
||||||
|
|
||||||
|
def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
|
||||||
|
mod, _ = self.modulation(vec)
|
||||||
|
|
||||||
|
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
|
||||||
|
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||||
|
|
||||||
|
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||||
|
q, k = self.norm(q, k, v)
|
||||||
|
|
||||||
|
# compute attention
|
||||||
|
attn = attention(q, k, v, pe=pe)
|
||||||
|
# compute activation in mlp stream, cat again and run second linear layer
|
||||||
|
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
||||||
|
return x + mod.gate * output
|
||||||
|
|
||||||
|
|
||||||
|
class LastLayer(nn.Module):
|
||||||
|
def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
|
||||||
|
super().__init__()
|
||||||
|
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||||
|
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
|
||||||
|
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
|
||||||
|
|
||||||
|
def forward(self, x: Tensor, vec: Tensor) -> Tensor:
|
||||||
|
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
|
||||||
|
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
|
||||||
|
x = self.linear(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Hunyuan3DDiT(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int = 64,
|
||||||
|
context_in_dim: int = 1536,
|
||||||
|
hidden_size: int = 1024,
|
||||||
|
mlp_ratio: float = 4.0,
|
||||||
|
num_heads: int = 16,
|
||||||
|
depth: int = 16,
|
||||||
|
depth_single_blocks: int = 32,
|
||||||
|
axes_dim: List[int] = [64],
|
||||||
|
theta: int = 10_000,
|
||||||
|
qkv_bias: bool = True,
|
||||||
|
time_factor: float = 1000,
|
||||||
|
guidance_embed: bool = False,
|
||||||
|
ckpt_path: Optional[str] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.context_in_dim = context_in_dim
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.mlp_ratio = mlp_ratio
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.depth = depth
|
||||||
|
self.depth_single_blocks = depth_single_blocks
|
||||||
|
self.axes_dim = axes_dim
|
||||||
|
self.theta = theta
|
||||||
|
self.qkv_bias = qkv_bias
|
||||||
|
self.time_factor = time_factor
|
||||||
|
self.out_channels = self.in_channels
|
||||||
|
self.guidance_embed = guidance_embed
|
||||||
|
|
||||||
|
if hidden_size % num_heads != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}"
|
||||||
|
)
|
||||||
|
pe_dim = hidden_size // num_heads
|
||||||
|
if sum(axes_dim) != pe_dim:
|
||||||
|
raise ValueError(f"Got {axes_dim} but expected positional dim {pe_dim}")
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.latent_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
|
||||||
|
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
|
||||||
|
self.cond_in = nn.Linear(context_in_dim, self.hidden_size)
|
||||||
|
self.guidance_in = (
|
||||||
|
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if guidance_embed else nn.Identity()
|
||||||
|
)
|
||||||
|
|
||||||
|
self.double_blocks = nn.ModuleList(
|
||||||
|
[
|
||||||
|
DoubleStreamBlock(
|
||||||
|
self.hidden_size,
|
||||||
|
self.num_heads,
|
||||||
|
mlp_ratio=mlp_ratio,
|
||||||
|
qkv_bias=qkv_bias,
|
||||||
|
)
|
||||||
|
for _ in range(depth)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.single_blocks = nn.ModuleList(
|
||||||
|
[
|
||||||
|
SingleStreamBlock(
|
||||||
|
self.hidden_size,
|
||||||
|
self.num_heads,
|
||||||
|
mlp_ratio=mlp_ratio,
|
||||||
|
)
|
||||||
|
for _ in range(depth_single_blocks)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
|
||||||
|
|
||||||
|
if ckpt_path is not None:
|
||||||
|
print('restored denoiser ckpt', ckpt_path)
|
||||||
|
|
||||||
|
ckpt = torch.load(ckpt_path, map_location="cpu")
|
||||||
|
if 'state_dict' not in ckpt:
|
||||||
|
# deepspeed ckpt
|
||||||
|
state_dict = {}
|
||||||
|
for k in ckpt.keys():
|
||||||
|
new_k = k.replace('_forward_module.', '')
|
||||||
|
state_dict[new_k] = ckpt[k]
|
||||||
|
else:
|
||||||
|
state_dict = ckpt["state_dict"]
|
||||||
|
|
||||||
|
final_state_dict = {}
|
||||||
|
for k, v in state_dict.items():
|
||||||
|
if k.startswith('model.'):
|
||||||
|
final_state_dict[k.replace('model.', '')] = v
|
||||||
|
else:
|
||||||
|
final_state_dict[k] = v
|
||||||
|
missing, unexpected = self.load_state_dict(final_state_dict, strict=False)
|
||||||
|
print('unexpected keys:', unexpected)
|
||||||
|
print('missing keys:', missing)
|
||||||
|
|
||||||
|
def forward(self, x, t, contexts, **kwargs) -> Tensor:
|
||||||
|
cond = contexts['main']
|
||||||
|
latent = self.latent_in(x)
|
||||||
|
|
||||||
|
vec = self.time_in(timestep_embedding(t, 256, self.time_factor).to(dtype=latent.dtype))
|
||||||
|
if self.guidance_embed:
|
||||||
|
guidance = kwargs.get('guidance', None)
|
||||||
|
if guidance is None:
|
||||||
|
raise ValueError("Didn't get guidance strength for guidance distilled model.")
|
||||||
|
vec = vec + self.guidance_in(timestep_embedding(guidance, 256, self.time_factor))
|
||||||
|
|
||||||
|
cond = self.cond_in(cond)
|
||||||
|
pe = None
|
||||||
|
|
||||||
|
for block in self.double_blocks:
|
||||||
|
latent, cond = block(img=latent, txt=cond, vec=vec, pe=pe)
|
||||||
|
|
||||||
|
latent = torch.cat((cond, latent), 1)
|
||||||
|
for block in self.single_blocks:
|
||||||
|
latent = block(latent, vec=vec, pe=pe)
|
||||||
|
|
||||||
|
latent = latent[:, cond.shape[1]:, ...]
|
||||||
|
latent = self.final_layer(latent, vec)
|
||||||
|
return latent
|
||||||
596
hy3dshape/hy3dshape/models/denoisers/hunyuandit.py
Normal file
596
hy3dshape/hy3dshape/models/denoisers/hunyuandit.py
Normal file
@ -0,0 +1,596 @@
|
|||||||
|
# Open Source Model Licensed under the Apache License Version 2.0
|
||||||
|
# and Other Licenses of the Third-Party Components therein:
|
||||||
|
# The below Model in this distribution may have been modified by THL A29 Limited
|
||||||
|
# ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
|
||||||
|
|
||||||
|
# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
|
||||||
|
# The below software and/or models in this distribution may have been
|
||||||
|
# modified by THL A29 Limited ("Tencent Modifications").
|
||||||
|
# All Tencent Modifications are Copyright (C) THL A29 Limited.
|
||||||
|
|
||||||
|
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
||||||
|
# except for the third-party components listed below.
|
||||||
|
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
||||||
|
# in the repsective licenses of these third-party components.
|
||||||
|
# Users must comply with all terms and conditions of original licenses of these third-party
|
||||||
|
# components and must ensure that the usage of the third party components adheres to
|
||||||
|
# all relevant laws and regulations.
|
||||||
|
|
||||||
|
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
||||||
|
# their software and algorithms, including trained model weights, parameters (including
|
||||||
|
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
||||||
|
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
||||||
|
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
from .moe_layers import MoEBlock
|
||||||
|
|
||||||
|
|
||||||
|
def modulate(x, shift, scale):
|
||||||
|
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||||
|
|
||||||
|
|
||||||
|
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
||||||
|
"""
|
||||||
|
embed_dim: output dimension for each position
|
||||||
|
pos: a list of positions to be encoded: size (M,)
|
||||||
|
out: (M, D)
|
||||||
|
"""
|
||||||
|
assert embed_dim % 2 == 0
|
||||||
|
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
||||||
|
omega /= embed_dim / 2.
|
||||||
|
omega = 1. / 10000 ** omega # (D/2,)
|
||||||
|
|
||||||
|
pos = pos.reshape(-1) # (M,)
|
||||||
|
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
|
||||||
|
|
||||||
|
emb_sin = np.sin(out) # (M, D/2)
|
||||||
|
emb_cos = np.cos(out) # (M, D/2)
|
||||||
|
|
||||||
|
return np.concatenate([emb_sin, emb_cos], axis=1)
|
||||||
|
|
||||||
|
|
||||||
|
class Timesteps(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
num_channels: int,
|
||||||
|
downscale_freq_shift: float = 0.0,
|
||||||
|
scale: int = 1,
|
||||||
|
max_period: int = 10000
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.num_channels = num_channels
|
||||||
|
self.downscale_freq_shift = downscale_freq_shift
|
||||||
|
self.scale = scale
|
||||||
|
self.max_period = max_period
|
||||||
|
|
||||||
|
def forward(self, timesteps):
|
||||||
|
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
|
||||||
|
embedding_dim = self.num_channels
|
||||||
|
half_dim = embedding_dim // 2
|
||||||
|
exponent = -math.log(self.max_period) * torch.arange(
|
||||||
|
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
|
||||||
|
exponent = exponent / (half_dim - self.downscale_freq_shift)
|
||||||
|
emb = torch.exp(exponent)
|
||||||
|
emb = timesteps[:, None].float() * emb[None, :]
|
||||||
|
emb = self.scale * emb
|
||||||
|
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
|
||||||
|
if embedding_dim % 2 == 1:
|
||||||
|
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
||||||
|
return emb
|
||||||
|
|
||||||
|
|
||||||
|
class TimestepEmbedder(nn.Module):
|
||||||
|
"""
|
||||||
|
Embeds scalar timesteps into vector representations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, hidden_size, frequency_embedding_size=256, cond_proj_dim=None, out_size=None):
|
||||||
|
super().__init__()
|
||||||
|
if out_size is None:
|
||||||
|
out_size = hidden_size
|
||||||
|
self.mlp = nn.Sequential(
|
||||||
|
nn.Linear(hidden_size, frequency_embedding_size, bias=True),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Linear(frequency_embedding_size, out_size, bias=True),
|
||||||
|
)
|
||||||
|
self.frequency_embedding_size = frequency_embedding_size
|
||||||
|
|
||||||
|
if cond_proj_dim is not None:
|
||||||
|
self.cond_proj = nn.Linear(cond_proj_dim, frequency_embedding_size, bias=False)
|
||||||
|
|
||||||
|
self.time_embed = Timesteps(hidden_size)
|
||||||
|
|
||||||
|
def forward(self, t, condition):
|
||||||
|
|
||||||
|
t_freq = self.time_embed(t).type(self.mlp[0].weight.dtype)
|
||||||
|
|
||||||
|
# t_freq = timestep_embedding(t, self.frequency_embedding_size).type(self.mlp[0].weight.dtype)
|
||||||
|
if condition is not None:
|
||||||
|
t_freq = t_freq + self.cond_proj(condition)
|
||||||
|
|
||||||
|
t = self.mlp(t_freq)
|
||||||
|
t = t.unsqueeze(dim=1)
|
||||||
|
return t
|
||||||
|
|
||||||
|
|
||||||
|
class MLP(nn.Module):
|
||||||
|
def __init__(self, *, width: int):
|
||||||
|
super().__init__()
|
||||||
|
self.width = width
|
||||||
|
self.fc1 = nn.Linear(width, width * 4)
|
||||||
|
self.fc2 = nn.Linear(width * 4, width)
|
||||||
|
self.gelu = nn.GELU()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.fc2(self.gelu(self.fc1(x)))
|
||||||
|
|
||||||
|
|
||||||
|
class CrossAttention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
qdim,
|
||||||
|
kdim,
|
||||||
|
num_heads,
|
||||||
|
qkv_bias=True,
|
||||||
|
qk_norm=False,
|
||||||
|
norm_layer=nn.LayerNorm,
|
||||||
|
with_decoupled_ca=False,
|
||||||
|
decoupled_ca_dim=16,
|
||||||
|
decoupled_ca_weight=1.0,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.qdim = qdim
|
||||||
|
self.kdim = kdim
|
||||||
|
self.num_heads = num_heads
|
||||||
|
assert self.qdim % num_heads == 0, "self.qdim must be divisible by num_heads"
|
||||||
|
self.head_dim = self.qdim // num_heads
|
||||||
|
assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
|
||||||
|
self.scale = self.head_dim ** -0.5
|
||||||
|
|
||||||
|
self.to_q = nn.Linear(qdim, qdim, bias=qkv_bias)
|
||||||
|
self.to_k = nn.Linear(kdim, qdim, bias=qkv_bias)
|
||||||
|
self.to_v = nn.Linear(kdim, qdim, bias=qkv_bias)
|
||||||
|
|
||||||
|
# TODO: eps should be 1 / 65530 if using fp16
|
||||||
|
self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
|
||||||
|
self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
|
||||||
|
self.out_proj = nn.Linear(qdim, qdim, bias=True)
|
||||||
|
|
||||||
|
self.with_dca = with_decoupled_ca
|
||||||
|
if self.with_dca:
|
||||||
|
self.kv_proj_dca = nn.Linear(kdim, 2 * qdim, bias=qkv_bias)
|
||||||
|
self.k_norm_dca = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
|
||||||
|
self.dca_dim = decoupled_ca_dim
|
||||||
|
self.dca_weight = decoupled_ca_weight
|
||||||
|
|
||||||
|
def forward(self, x, y):
|
||||||
|
"""
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
x: torch.Tensor
|
||||||
|
(batch, seqlen1, hidden_dim) (where hidden_dim = num heads * head dim)
|
||||||
|
y: torch.Tensor
|
||||||
|
(batch, seqlen2, hidden_dim2)
|
||||||
|
freqs_cis_img: torch.Tensor
|
||||||
|
(batch, hidden_dim // 2), RoPE for image
|
||||||
|
"""
|
||||||
|
b, s1, c = x.shape # [b, s1, D]
|
||||||
|
|
||||||
|
if self.with_dca:
|
||||||
|
token_len = y.shape[1]
|
||||||
|
context_dca = y[:, -self.dca_dim:, :]
|
||||||
|
kv_dca = self.kv_proj_dca(context_dca).view(b, self.dca_dim, 2, self.num_heads, self.head_dim)
|
||||||
|
k_dca, v_dca = kv_dca.unbind(dim=2) # [b, s, h, d]
|
||||||
|
k_dca = self.k_norm_dca(k_dca)
|
||||||
|
y = y[:, :(token_len - self.dca_dim), :]
|
||||||
|
|
||||||
|
_, s2, c = y.shape # [b, s2, 1024]
|
||||||
|
q = self.to_q(x)
|
||||||
|
k = self.to_k(y)
|
||||||
|
v = self.to_v(y)
|
||||||
|
|
||||||
|
kv = torch.cat((k, v), dim=-1)
|
||||||
|
split_size = kv.shape[-1] // self.num_heads // 2
|
||||||
|
kv = kv.view(1, -1, self.num_heads, split_size * 2)
|
||||||
|
k, v = torch.split(kv, split_size, dim=-1)
|
||||||
|
|
||||||
|
q = q.view(b, s1, self.num_heads, self.head_dim) # [b, s1, h, d]
|
||||||
|
k = k.view(b, s2, self.num_heads, self.head_dim) # [b, s2, h, d]
|
||||||
|
v = v.view(b, s2, self.num_heads, self.head_dim) # [b, s2, h, d]
|
||||||
|
|
||||||
|
q = self.q_norm(q)
|
||||||
|
k = self.k_norm(k)
|
||||||
|
|
||||||
|
with torch.backends.cuda.sdp_kernel(
|
||||||
|
enable_flash=True,
|
||||||
|
enable_math=False,
|
||||||
|
enable_mem_efficient=True
|
||||||
|
):
|
||||||
|
q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.num_heads), (q, k, v))
|
||||||
|
context = F.scaled_dot_product_attention(
|
||||||
|
q, k, v
|
||||||
|
).transpose(1, 2).reshape(b, s1, -1)
|
||||||
|
|
||||||
|
if self.with_dca:
|
||||||
|
with torch.backends.cuda.sdp_kernel(
|
||||||
|
enable_flash=True,
|
||||||
|
enable_math=False,
|
||||||
|
enable_mem_efficient=True
|
||||||
|
):
|
||||||
|
k_dca, v_dca = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.num_heads),
|
||||||
|
(k_dca, v_dca))
|
||||||
|
context_dca = F.scaled_dot_product_attention(
|
||||||
|
q, k_dca, v_dca).transpose(1, 2).reshape(b, s1, -1)
|
||||||
|
|
||||||
|
context = context + self.dca_weight * context_dca
|
||||||
|
|
||||||
|
out = self.out_proj(context) # context.reshape - B, L1, -1
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
"""
|
||||||
|
We rename some layer names to align with flash attention
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
num_heads,
|
||||||
|
qkv_bias=True,
|
||||||
|
qk_norm=False,
|
||||||
|
norm_layer=nn.LayerNorm,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.num_heads = num_heads
|
||||||
|
assert self.dim % num_heads == 0, 'dim should be divisible by num_heads'
|
||||||
|
self.head_dim = self.dim // num_heads
|
||||||
|
# This assertion is aligned with flash attention
|
||||||
|
assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
|
||||||
|
self.scale = self.head_dim ** -0.5
|
||||||
|
|
||||||
|
self.to_q = nn.Linear(dim, dim, bias=qkv_bias)
|
||||||
|
self.to_k = nn.Linear(dim, dim, bias=qkv_bias)
|
||||||
|
self.to_v = nn.Linear(dim, dim, bias=qkv_bias)
|
||||||
|
# TODO: eps should be 1 / 65530 if using fp16
|
||||||
|
self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
|
||||||
|
self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
|
||||||
|
self.out_proj = nn.Linear(dim, dim)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
B, N, C = x.shape
|
||||||
|
|
||||||
|
q = self.to_q(x)
|
||||||
|
k = self.to_k(x)
|
||||||
|
v = self.to_v(x)
|
||||||
|
|
||||||
|
qkv = torch.cat((q, k, v), dim=-1)
|
||||||
|
split_size = qkv.shape[-1] // self.num_heads // 3
|
||||||
|
qkv = qkv.view(1, -1, self.num_heads, split_size * 3)
|
||||||
|
q, k, v = torch.split(qkv, split_size, dim=-1)
|
||||||
|
|
||||||
|
q = q.reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2) # [b, h, s, d]
|
||||||
|
k = k.reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2) # [b, h, s, d]
|
||||||
|
v = v.reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
q = self.q_norm(q) # [b, h, s, d]
|
||||||
|
k = self.k_norm(k) # [b, h, s, d]
|
||||||
|
|
||||||
|
with torch.backends.cuda.sdp_kernel(
|
||||||
|
enable_flash=True,
|
||||||
|
enable_math=False,
|
||||||
|
enable_mem_efficient=True
|
||||||
|
):
|
||||||
|
x = F.scaled_dot_product_attention(q, k, v)
|
||||||
|
x = x.transpose(1, 2).reshape(B, N, -1)
|
||||||
|
|
||||||
|
x = self.out_proj(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class HunYuanDiTBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size,
|
||||||
|
c_emb_size,
|
||||||
|
num_heads,
|
||||||
|
text_states_dim=1024,
|
||||||
|
use_flash_attn=False,
|
||||||
|
qk_norm=False,
|
||||||
|
norm_layer=nn.LayerNorm,
|
||||||
|
qk_norm_layer=nn.RMSNorm,
|
||||||
|
with_decoupled_ca=False,
|
||||||
|
decoupled_ca_dim=16,
|
||||||
|
decoupled_ca_weight=1.0,
|
||||||
|
init_scale=1.0,
|
||||||
|
qkv_bias=True,
|
||||||
|
skip_connection=True,
|
||||||
|
timested_modulate=False,
|
||||||
|
use_moe: bool = False,
|
||||||
|
num_experts: int = 8,
|
||||||
|
moe_top_k: int = 2,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.use_flash_attn = use_flash_attn
|
||||||
|
use_ele_affine = True
|
||||||
|
|
||||||
|
# ========================= Self-Attention =========================
|
||||||
|
self.norm1 = norm_layer(hidden_size, elementwise_affine=use_ele_affine, eps=1e-6)
|
||||||
|
self.attn1 = Attention(hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm,
|
||||||
|
norm_layer=qk_norm_layer)
|
||||||
|
|
||||||
|
# ========================= FFN =========================
|
||||||
|
self.norm2 = norm_layer(hidden_size, elementwise_affine=use_ele_affine, eps=1e-6)
|
||||||
|
|
||||||
|
# ========================= Add =========================
|
||||||
|
# Simply use add like SDXL.
|
||||||
|
self.timested_modulate = timested_modulate
|
||||||
|
if self.timested_modulate:
|
||||||
|
self.default_modulation = nn.Sequential(
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Linear(c_emb_size, hidden_size, bias=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
# ========================= Cross-Attention =========================
|
||||||
|
self.attn2 = CrossAttention(hidden_size, text_states_dim, num_heads=num_heads, qkv_bias=qkv_bias,
|
||||||
|
qk_norm=qk_norm, norm_layer=qk_norm_layer,
|
||||||
|
with_decoupled_ca=with_decoupled_ca, decoupled_ca_dim=decoupled_ca_dim,
|
||||||
|
decoupled_ca_weight=decoupled_ca_weight, init_scale=init_scale,
|
||||||
|
)
|
||||||
|
self.norm3 = norm_layer(hidden_size, elementwise_affine=True, eps=1e-6)
|
||||||
|
|
||||||
|
if skip_connection:
|
||||||
|
self.skip_norm = norm_layer(hidden_size, elementwise_affine=True, eps=1e-6)
|
||||||
|
self.skip_linear = nn.Linear(2 * hidden_size, hidden_size)
|
||||||
|
else:
|
||||||
|
self.skip_linear = None
|
||||||
|
|
||||||
|
self.use_moe = use_moe
|
||||||
|
if self.use_moe:
|
||||||
|
print("using moe")
|
||||||
|
self.moe = MoEBlock(
|
||||||
|
hidden_size,
|
||||||
|
num_experts=num_experts,
|
||||||
|
moe_top_k=moe_top_k,
|
||||||
|
dropout=0.0,
|
||||||
|
activation_fn="gelu",
|
||||||
|
final_dropout=False,
|
||||||
|
ff_inner_dim=int(hidden_size * 4.0),
|
||||||
|
ff_bias=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.mlp = MLP(width=hidden_size)
|
||||||
|
|
||||||
|
def forward(self, x, c=None, text_states=None, skip_value=None):
|
||||||
|
|
||||||
|
if self.skip_linear is not None:
|
||||||
|
cat = torch.cat([skip_value, x], dim=-1)
|
||||||
|
x = self.skip_linear(cat)
|
||||||
|
x = self.skip_norm(x)
|
||||||
|
|
||||||
|
# Self-Attention
|
||||||
|
if self.timested_modulate:
|
||||||
|
shift_msa = self.default_modulation(c).unsqueeze(dim=1)
|
||||||
|
x = x + shift_msa
|
||||||
|
|
||||||
|
attn_out = self.attn1(self.norm1(x))
|
||||||
|
|
||||||
|
x = x + attn_out
|
||||||
|
|
||||||
|
# Cross-Attention
|
||||||
|
x = x + self.attn2(self.norm2(x), text_states)
|
||||||
|
|
||||||
|
# FFN Layer
|
||||||
|
mlp_inputs = self.norm3(x)
|
||||||
|
|
||||||
|
if self.use_moe:
|
||||||
|
x = x + self.moe(mlp_inputs)
|
||||||
|
else:
|
||||||
|
x = x + self.mlp(mlp_inputs)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionPool(nn.Module):
|
||||||
|
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
|
||||||
|
super().__init__()
|
||||||
|
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim ** 0.5)
|
||||||
|
self.k_proj = nn.Linear(embed_dim, embed_dim)
|
||||||
|
self.q_proj = nn.Linear(embed_dim, embed_dim)
|
||||||
|
self.v_proj = nn.Linear(embed_dim, embed_dim)
|
||||||
|
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
|
||||||
|
self.num_heads = num_heads
|
||||||
|
|
||||||
|
def forward(self, x, attention_mask=None):
|
||||||
|
x = x.permute(1, 0, 2) # NLC -> LNC
|
||||||
|
if attention_mask is not None:
|
||||||
|
attention_mask = attention_mask.unsqueeze(-1).permute(1, 0, 2)
|
||||||
|
global_emb = (x * attention_mask).sum(dim=0) / attention_mask.sum(dim=0)
|
||||||
|
x = torch.cat([global_emb[None,], x], dim=0)
|
||||||
|
|
||||||
|
else:
|
||||||
|
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC
|
||||||
|
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (L+1)NC
|
||||||
|
x, _ = F.multi_head_attention_forward(
|
||||||
|
query=x[:1], key=x, value=x,
|
||||||
|
embed_dim_to_check=x.shape[-1],
|
||||||
|
num_heads=self.num_heads,
|
||||||
|
q_proj_weight=self.q_proj.weight,
|
||||||
|
k_proj_weight=self.k_proj.weight,
|
||||||
|
v_proj_weight=self.v_proj.weight,
|
||||||
|
in_proj_weight=None,
|
||||||
|
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
|
||||||
|
bias_k=None,
|
||||||
|
bias_v=None,
|
||||||
|
add_zero_attn=False,
|
||||||
|
dropout_p=0,
|
||||||
|
out_proj_weight=self.c_proj.weight,
|
||||||
|
out_proj_bias=self.c_proj.bias,
|
||||||
|
use_separate_proj_weight=True,
|
||||||
|
training=self.training,
|
||||||
|
need_weights=False
|
||||||
|
)
|
||||||
|
return x.squeeze(0)
|
||||||
|
|
||||||
|
|
||||||
|
class FinalLayer(nn.Module):
|
||||||
|
"""
|
||||||
|
The final layer of HunYuanDiT.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, final_hidden_size, out_channels):
|
||||||
|
super().__init__()
|
||||||
|
self.final_hidden_size = final_hidden_size
|
||||||
|
self.norm_final = nn.LayerNorm(final_hidden_size, elementwise_affine=True, eps=1e-6)
|
||||||
|
self.linear = nn.Linear(final_hidden_size, out_channels, bias=True)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.norm_final(x)
|
||||||
|
x = x[:, 1:]
|
||||||
|
x = self.linear(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class HunYuanDiTPlain(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_size=1024,
|
||||||
|
in_channels=4,
|
||||||
|
hidden_size=1024,
|
||||||
|
context_dim=1024,
|
||||||
|
depth=24,
|
||||||
|
num_heads=16,
|
||||||
|
mlp_ratio=4.0,
|
||||||
|
norm_type='layer',
|
||||||
|
qk_norm_type='rms',
|
||||||
|
qk_norm=False,
|
||||||
|
text_len=257,
|
||||||
|
with_decoupled_ca=False,
|
||||||
|
additional_cond_hidden_state=768,
|
||||||
|
decoupled_ca_dim=16,
|
||||||
|
decoupled_ca_weight=1.0,
|
||||||
|
use_pos_emb=False,
|
||||||
|
use_attention_pooling=True,
|
||||||
|
guidance_cond_proj_dim=None,
|
||||||
|
qkv_bias=True,
|
||||||
|
num_moe_layers: int = 6,
|
||||||
|
num_experts: int = 8,
|
||||||
|
moe_top_k: int = 2,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.input_size = input_size
|
||||||
|
self.depth = depth
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = in_channels
|
||||||
|
self.num_heads = num_heads
|
||||||
|
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.norm = nn.LayerNorm if norm_type == 'layer' else nn.RMSNorm
|
||||||
|
self.qk_norm = nn.RMSNorm if qk_norm_type == 'rms' else nn.LayerNorm
|
||||||
|
self.context_dim = context_dim
|
||||||
|
|
||||||
|
self.with_decoupled_ca = with_decoupled_ca
|
||||||
|
self.decoupled_ca_dim = decoupled_ca_dim
|
||||||
|
self.decoupled_ca_weight = decoupled_ca_weight
|
||||||
|
self.use_pos_emb = use_pos_emb
|
||||||
|
self.use_attention_pooling = use_attention_pooling
|
||||||
|
self.guidance_cond_proj_dim = guidance_cond_proj_dim
|
||||||
|
|
||||||
|
self.text_len = text_len
|
||||||
|
|
||||||
|
self.x_embedder = nn.Linear(in_channels, hidden_size, bias=True)
|
||||||
|
self.t_embedder = TimestepEmbedder(hidden_size, hidden_size * 4, cond_proj_dim=guidance_cond_proj_dim)
|
||||||
|
|
||||||
|
# Will use fixed sin-cos embedding:
|
||||||
|
if self.use_pos_emb:
|
||||||
|
self.register_buffer("pos_embed", torch.zeros(1, input_size, hidden_size))
|
||||||
|
pos = np.arange(self.input_size, dtype=np.float32)
|
||||||
|
pos_embed = get_1d_sincos_pos_embed_from_grid(self.pos_embed.shape[-1], pos)
|
||||||
|
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
|
||||||
|
|
||||||
|
self.use_attention_pooling = use_attention_pooling
|
||||||
|
if use_attention_pooling:
|
||||||
|
self.pooler = AttentionPool(self.text_len, context_dim, num_heads=8, output_dim=1024)
|
||||||
|
self.extra_embedder = nn.Sequential(
|
||||||
|
nn.Linear(1024, hidden_size * 4),
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Linear(hidden_size * 4, hidden_size, bias=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
if with_decoupled_ca:
|
||||||
|
self.additional_cond_hidden_state = additional_cond_hidden_state
|
||||||
|
self.additional_cond_proj = nn.Sequential(
|
||||||
|
nn.Linear(additional_cond_hidden_state, hidden_size * 4),
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Linear(hidden_size * 4, 1024, bias=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
# HUnYuanDiT Blocks
|
||||||
|
self.blocks = nn.ModuleList([
|
||||||
|
HunYuanDiTBlock(hidden_size=hidden_size,
|
||||||
|
c_emb_size=hidden_size,
|
||||||
|
num_heads=num_heads,
|
||||||
|
mlp_ratio=mlp_ratio,
|
||||||
|
text_states_dim=context_dim,
|
||||||
|
qk_norm=qk_norm,
|
||||||
|
norm_layer=self.norm,
|
||||||
|
qk_norm_layer=self.qk_norm,
|
||||||
|
skip_connection=layer > depth // 2,
|
||||||
|
with_decoupled_ca=with_decoupled_ca,
|
||||||
|
decoupled_ca_dim=decoupled_ca_dim,
|
||||||
|
decoupled_ca_weight=decoupled_ca_weight,
|
||||||
|
qkv_bias=qkv_bias,
|
||||||
|
use_moe=True if depth - layer <= num_moe_layers else False,
|
||||||
|
num_experts=num_experts,
|
||||||
|
moe_top_k=moe_top_k
|
||||||
|
)
|
||||||
|
for layer in range(depth)
|
||||||
|
])
|
||||||
|
self.depth = depth
|
||||||
|
|
||||||
|
self.final_layer = FinalLayer(hidden_size, self.out_channels)
|
||||||
|
|
||||||
|
def forward(self, x, t, contexts, **kwargs):
|
||||||
|
cond = contexts['main']
|
||||||
|
|
||||||
|
t = self.t_embedder(t, condition=kwargs.get('guidance_cond'))
|
||||||
|
x = self.x_embedder(x)
|
||||||
|
|
||||||
|
if self.use_pos_emb:
|
||||||
|
pos_embed = self.pos_embed.to(x.dtype)
|
||||||
|
x = x + pos_embed
|
||||||
|
|
||||||
|
if self.use_attention_pooling:
|
||||||
|
extra_vec = self.pooler(cond, None)
|
||||||
|
c = t + self.extra_embedder(extra_vec) # [B, D]
|
||||||
|
else:
|
||||||
|
c = t
|
||||||
|
|
||||||
|
if self.with_decoupled_ca:
|
||||||
|
additional_cond = self.additional_cond_proj(contexts['additional'])
|
||||||
|
cond = torch.cat([cond, additional_cond], dim=1)
|
||||||
|
|
||||||
|
x = torch.cat([c, x], dim=1)
|
||||||
|
|
||||||
|
skip_value_list = []
|
||||||
|
for layer, block in enumerate(self.blocks):
|
||||||
|
skip_value = None if layer <= self.depth // 2 else skip_value_list.pop()
|
||||||
|
x = block(x, c, cond, skip_value=skip_value)
|
||||||
|
if layer < self.depth // 2:
|
||||||
|
skip_value_list.append(x)
|
||||||
|
|
||||||
|
x = self.final_layer(x)
|
||||||
|
return x
|
||||||
177
hy3dshape/hy3dshape/models/denoisers/moe_layers.py
Normal file
177
hy3dshape/hy3dshape/models/denoisers/moe_layers.py
Normal file
@ -0,0 +1,177 @@
|
|||||||
|
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
||||||
|
# except for the third-party components listed below.
|
||||||
|
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
||||||
|
# in the repsective licenses of these third-party components.
|
||||||
|
# Users must comply with all terms and conditions of original licenses of these third-party
|
||||||
|
# components and must ensure that the usage of the third party components adheres to
|
||||||
|
# all relevant laws and regulations.
|
||||||
|
|
||||||
|
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
||||||
|
# their software and algorithms, including trained model weights, parameters (including
|
||||||
|
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
||||||
|
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
||||||
|
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import numpy as np
|
||||||
|
import math
|
||||||
|
from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
|
||||||
|
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from diffusers.models.attention import FeedForward
|
||||||
|
|
||||||
|
class AddAuxiliaryLoss(torch.autograd.Function):
|
||||||
|
"""
|
||||||
|
The trick function of adding auxiliary (aux) loss,
|
||||||
|
which includes the gradient of the aux loss during backpropagation.
|
||||||
|
"""
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, x, loss):
|
||||||
|
assert loss.numel() == 1
|
||||||
|
ctx.dtype = loss.dtype
|
||||||
|
ctx.required_aux_loss = loss.requires_grad
|
||||||
|
return x
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad_output):
|
||||||
|
grad_loss = None
|
||||||
|
if ctx.required_aux_loss:
|
||||||
|
grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device)
|
||||||
|
return grad_output, grad_loss
|
||||||
|
|
||||||
|
class MoEGate(nn.Module):
|
||||||
|
def __init__(self, embed_dim, num_experts=16, num_experts_per_tok=2, aux_loss_alpha=0.01):
|
||||||
|
super().__init__()
|
||||||
|
self.top_k = num_experts_per_tok
|
||||||
|
self.n_routed_experts = num_experts
|
||||||
|
|
||||||
|
self.scoring_func = 'softmax'
|
||||||
|
self.alpha = aux_loss_alpha
|
||||||
|
self.seq_aux = False
|
||||||
|
|
||||||
|
# topk selection algorithm
|
||||||
|
self.norm_topk_prob = False
|
||||||
|
self.gating_dim = embed_dim
|
||||||
|
self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
|
||||||
|
self.reset_parameters()
|
||||||
|
|
||||||
|
def reset_parameters(self) -> None:
|
||||||
|
import torch.nn.init as init
|
||||||
|
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
bsz, seq_len, h = hidden_states.shape
|
||||||
|
# print(bsz, seq_len, h)
|
||||||
|
### compute gating score
|
||||||
|
hidden_states = hidden_states.view(-1, h)
|
||||||
|
logits = F.linear(hidden_states, self.weight, None)
|
||||||
|
if self.scoring_func == 'softmax':
|
||||||
|
scores = logits.softmax(dim=-1)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}')
|
||||||
|
|
||||||
|
### select top-k experts
|
||||||
|
topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
|
||||||
|
|
||||||
|
### norm gate to sum 1
|
||||||
|
if self.top_k > 1 and self.norm_topk_prob:
|
||||||
|
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
|
||||||
|
topk_weight = topk_weight / denominator
|
||||||
|
|
||||||
|
### expert-level computation auxiliary loss
|
||||||
|
if self.training and self.alpha > 0.0:
|
||||||
|
scores_for_aux = scores
|
||||||
|
aux_topk = self.top_k
|
||||||
|
# always compute aux loss based on the naive greedy topk method
|
||||||
|
topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
|
||||||
|
if self.seq_aux:
|
||||||
|
scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
|
||||||
|
ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
|
||||||
|
ce.scatter_add_(
|
||||||
|
1,
|
||||||
|
topk_idx_for_aux_loss,
|
||||||
|
torch.ones(
|
||||||
|
bsz, seq_len * aux_topk,
|
||||||
|
device=hidden_states.device
|
||||||
|
)
|
||||||
|
).div_(seq_len * aux_topk / self.n_routed_experts)
|
||||||
|
aux_loss = (ce * scores_for_seq_aux.mean(dim = 1)).sum(dim = 1).mean()
|
||||||
|
aux_loss = aux_loss * self.alpha
|
||||||
|
else:
|
||||||
|
mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1),
|
||||||
|
num_classes=self.n_routed_experts)
|
||||||
|
ce = mask_ce.float().mean(0)
|
||||||
|
Pi = scores_for_aux.mean(0)
|
||||||
|
fi = ce * self.n_routed_experts
|
||||||
|
aux_loss = (Pi * fi).sum() * self.alpha
|
||||||
|
else:
|
||||||
|
aux_loss = None
|
||||||
|
return topk_idx, topk_weight, aux_loss
|
||||||
|
|
||||||
|
class MoEBlock(nn.Module):
|
||||||
|
def __init__(self, dim, num_experts=8, moe_top_k=2,
|
||||||
|
activation_fn = "gelu", dropout=0.0, final_dropout = False,
|
||||||
|
ff_inner_dim = None, ff_bias = True):
|
||||||
|
super().__init__()
|
||||||
|
self.moe_top_k = moe_top_k
|
||||||
|
self.experts = nn.ModuleList([
|
||||||
|
FeedForward(dim,dropout=dropout,
|
||||||
|
activation_fn=activation_fn,
|
||||||
|
final_dropout=final_dropout,
|
||||||
|
inner_dim=ff_inner_dim,
|
||||||
|
bias=ff_bias)
|
||||||
|
for i in range(num_experts)])
|
||||||
|
self.gate = MoEGate(embed_dim=dim, num_experts=num_experts, num_experts_per_tok=moe_top_k)
|
||||||
|
|
||||||
|
self.shared_experts = FeedForward(dim,dropout=dropout, activation_fn=activation_fn,
|
||||||
|
final_dropout=final_dropout, inner_dim=ff_inner_dim,
|
||||||
|
bias=ff_bias)
|
||||||
|
|
||||||
|
def initialize_weight(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
identity = hidden_states
|
||||||
|
orig_shape = hidden_states.shape
|
||||||
|
topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
|
||||||
|
|
||||||
|
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
||||||
|
flat_topk_idx = topk_idx.view(-1)
|
||||||
|
if self.training:
|
||||||
|
hidden_states = hidden_states.repeat_interleave(self.moe_top_k, dim=0)
|
||||||
|
y = torch.empty_like(hidden_states, dtype=hidden_states.dtype)
|
||||||
|
for i, expert in enumerate(self.experts):
|
||||||
|
tmp = expert(hidden_states[flat_topk_idx == i])
|
||||||
|
y[flat_topk_idx == i] = tmp.to(hidden_states.dtype)
|
||||||
|
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
|
||||||
|
y = y.view(*orig_shape)
|
||||||
|
y = AddAuxiliaryLoss.apply(y, aux_loss)
|
||||||
|
else:
|
||||||
|
y = self.moe_infer(hidden_states, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
|
||||||
|
y = y + self.shared_experts(identity)
|
||||||
|
return y
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
|
||||||
|
expert_cache = torch.zeros_like(x)
|
||||||
|
idxs = flat_expert_indices.argsort()
|
||||||
|
tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
|
||||||
|
token_idxs = idxs // self.moe_top_k
|
||||||
|
for i, end_idx in enumerate(tokens_per_expert):
|
||||||
|
start_idx = 0 if i == 0 else tokens_per_expert[i-1]
|
||||||
|
if start_idx == end_idx:
|
||||||
|
continue
|
||||||
|
expert = self.experts[i]
|
||||||
|
exp_token_idx = token_idxs[start_idx:end_idx]
|
||||||
|
expert_tokens = x[exp_token_idx]
|
||||||
|
expert_out = expert(expert_tokens)
|
||||||
|
expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
|
||||||
|
|
||||||
|
# for fp16 and other dtype
|
||||||
|
expert_cache = expert_cache.to(expert_out.dtype)
|
||||||
|
expert_cache.scatter_reduce_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]),
|
||||||
|
expert_out,
|
||||||
|
reduce='sum')
|
||||||
|
return expert_cache
|
||||||
354
hy3dshape/hy3dshape/models/diffusion/flow_matching_sit.py
Normal file
354
hy3dshape/hy3dshape/models/diffusion/flow_matching_sit.py
Normal file
@ -0,0 +1,354 @@
|
|||||||
|
import os
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from typing import List, Tuple, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.optim import lr_scheduler
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
from pytorch_lightning.utilities import rank_zero_info
|
||||||
|
from pytorch_lightning.utilities import rank_zero_only
|
||||||
|
|
||||||
|
from ...utils.ema import LitEma
|
||||||
|
from ...utils.misc import instantiate_from_config, instantiate_non_trainable_model
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class Diffuser(pl.LightningModule):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
first_stage_config,
|
||||||
|
cond_stage_config,
|
||||||
|
denoiser_cfg,
|
||||||
|
scheduler_cfg,
|
||||||
|
optimizer_cfg,
|
||||||
|
pipeline_cfg=None,
|
||||||
|
image_processor_cfg=None,
|
||||||
|
lora_config=None,
|
||||||
|
ema_config=None,
|
||||||
|
first_stage_key: str = "surface",
|
||||||
|
cond_stage_key: str = "image",
|
||||||
|
scale_by_std: bool = False,
|
||||||
|
z_scale_factor: float = 1.0,
|
||||||
|
ckpt_path: Optional[str] = None,
|
||||||
|
ignore_keys: Union[Tuple[str], List[str]] = (),
|
||||||
|
torch_compile: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.first_stage_key = first_stage_key
|
||||||
|
self.cond_stage_key = cond_stage_key
|
||||||
|
|
||||||
|
# ========= init optimizer config ========= #
|
||||||
|
self.optimizer_cfg = optimizer_cfg
|
||||||
|
|
||||||
|
# ========= init diffusion scheduler ========= #
|
||||||
|
self.scheduler_cfg = scheduler_cfg
|
||||||
|
self.sampler = None
|
||||||
|
if 'transport' in scheduler_cfg:
|
||||||
|
self.transport = instantiate_from_config(scheduler_cfg.transport)
|
||||||
|
self.sampler = instantiate_from_config(scheduler_cfg.sampler, transport=self.transport)
|
||||||
|
self.sample_fn = self.sampler.sample_ode(**scheduler_cfg.sampler.ode_params)
|
||||||
|
|
||||||
|
# ========= init the model ========= #
|
||||||
|
self.denoiser_cfg = denoiser_cfg
|
||||||
|
self.model = instantiate_from_config(denoiser_cfg, device=None, dtype=None)
|
||||||
|
self.cond_stage_model = instantiate_from_config(cond_stage_config)
|
||||||
|
|
||||||
|
self.ckpt_path = ckpt_path
|
||||||
|
if ckpt_path is not None:
|
||||||
|
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
||||||
|
|
||||||
|
# ========= config lora model ========= #
|
||||||
|
if lora_config is not None:
|
||||||
|
from peft import LoraConfig, get_peft_model
|
||||||
|
loraconfig = LoraConfig(
|
||||||
|
r=lora_config.rank,
|
||||||
|
lora_alpha=lora_config.rank,
|
||||||
|
target_modules=lora_config.get('target_modules')
|
||||||
|
)
|
||||||
|
self.model = get_peft_model(self.model, loraconfig)
|
||||||
|
|
||||||
|
# ========= config ema model ========= #
|
||||||
|
self.ema_config = ema_config
|
||||||
|
if self.ema_config is not None:
|
||||||
|
if self.ema_config.ema_model == 'DSEma':
|
||||||
|
# from michelangelo.models.modules.ema_deepspeed import DSEma
|
||||||
|
from ..utils.ema_deepspeed import DSEma
|
||||||
|
self.model_ema = DSEma(self.model, decay=self.ema_config.ema_decay)
|
||||||
|
else:
|
||||||
|
self.model_ema = LitEma(self.model, decay=self.ema_config.ema_decay)
|
||||||
|
#do not initilize EMA weight from ckpt path, since I need to change moe layers
|
||||||
|
if ckpt_path is not None:
|
||||||
|
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
||||||
|
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
||||||
|
|
||||||
|
# ========= init vae at last to prevent it is overridden by loaded ckpt ========= #
|
||||||
|
self.first_stage_model = instantiate_non_trainable_model(first_stage_config)
|
||||||
|
|
||||||
|
self.scale_by_std = scale_by_std
|
||||||
|
if scale_by_std:
|
||||||
|
self.register_buffer("z_scale_factor", torch.tensor(z_scale_factor))
|
||||||
|
else:
|
||||||
|
self.z_scale_factor = z_scale_factor
|
||||||
|
|
||||||
|
# ========= init pipeline for inference ========= #
|
||||||
|
self.image_processor_cfg = image_processor_cfg
|
||||||
|
self.image_processor = None
|
||||||
|
if self.image_processor_cfg is not None:
|
||||||
|
self.image_processor = instantiate_from_config(self.image_processor_cfg)
|
||||||
|
self.pipeline_cfg = pipeline_cfg
|
||||||
|
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||||
|
scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
|
||||||
|
self.pipeline = instantiate_from_config(
|
||||||
|
pipeline_cfg,
|
||||||
|
vae=self.first_stage_model,
|
||||||
|
model=self.model,
|
||||||
|
scheduler=scheduler, # self.sampler,
|
||||||
|
conditioner=self.cond_stage_model,
|
||||||
|
image_processor=self.image_processor,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ========= torch compile to accelerate ========= #
|
||||||
|
self.torch_compile = torch_compile
|
||||||
|
if self.torch_compile:
|
||||||
|
torch.nn.Module.compile(self.model)
|
||||||
|
torch.nn.Module.compile(self.first_stage_model)
|
||||||
|
torch.nn.Module.compile(self.cond_stage_model)
|
||||||
|
print(f'*' * 100)
|
||||||
|
print(f'Compile model for acceleration')
|
||||||
|
print(f'*' * 100)
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def ema_scope(self, context=None):
|
||||||
|
if self.ema_config is not None and self.ema_config.get('ema_inference', False):
|
||||||
|
self.model_ema.store(self.model)
|
||||||
|
self.model_ema.copy_to(self.model)
|
||||||
|
if context is not None:
|
||||||
|
print(f"{context}: Switched to EMA weights")
|
||||||
|
try:
|
||||||
|
yield None
|
||||||
|
finally:
|
||||||
|
if self.ema_config is not None and self.ema_config.get('ema_inference', False):
|
||||||
|
self.model_ema.restore(self.model)
|
||||||
|
if context is not None:
|
||||||
|
print(f"{context}: Restored training weights")
|
||||||
|
|
||||||
|
def init_from_ckpt(self, path, ignore_keys=()):
|
||||||
|
ckpt = torch.load(path, map_location="cpu")
|
||||||
|
if 'state_dict' not in ckpt:
|
||||||
|
# deepspeed ckpt
|
||||||
|
state_dict = {}
|
||||||
|
for k in ckpt.keys():
|
||||||
|
new_k = k.replace('_forward_module.', '')
|
||||||
|
state_dict[new_k] = ckpt[k]
|
||||||
|
else:
|
||||||
|
state_dict = ckpt["state_dict"]
|
||||||
|
|
||||||
|
keys = list(state_dict.keys())
|
||||||
|
for k in keys:
|
||||||
|
for ik in ignore_keys:
|
||||||
|
if ik in k:
|
||||||
|
print("Deleting key {} from state_dict.".format(k))
|
||||||
|
del state_dict[k]
|
||||||
|
|
||||||
|
missing, unexpected = self.load_state_dict(state_dict, strict=False)
|
||||||
|
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
|
||||||
|
if len(missing) > 0:
|
||||||
|
print(f"Missing Keys: {missing}")
|
||||||
|
print(f"Unexpected Keys: {unexpected}")
|
||||||
|
|
||||||
|
def on_load_checkpoint(self, checkpoint):
|
||||||
|
"""
|
||||||
|
The pt_model is trained separately, so we already have access to its
|
||||||
|
checkpoint and load it separately with `self.set_pt_model`.
|
||||||
|
|
||||||
|
However, the PL Trainer is strict about
|
||||||
|
checkpoint loading (not configurable), so it expects the loaded state_dict
|
||||||
|
to match exactly the keys in the model state_dict.
|
||||||
|
|
||||||
|
So, when loading the checkpoint, before matching keys, we add all pt_model keys
|
||||||
|
from self.state_dict() to the checkpoint state dict, so that they match
|
||||||
|
"""
|
||||||
|
for key in self.state_dict().keys():
|
||||||
|
if key.startswith("model_ema") and key not in checkpoint["state_dict"]:
|
||||||
|
checkpoint["state_dict"][key] = self.state_dict()[key]
|
||||||
|
|
||||||
|
def configure_optimizers(self) -> Tuple[List, List]:
|
||||||
|
lr = self.learning_rate
|
||||||
|
|
||||||
|
params_list = []
|
||||||
|
trainable_parameters = list(self.model.parameters())
|
||||||
|
params_list.append({'params': trainable_parameters, 'lr': lr})
|
||||||
|
|
||||||
|
no_decay = ['bias', 'norm.weight', 'norm.bias', 'norm1.weight', 'norm1.bias', 'norm2.weight', 'norm2.bias']
|
||||||
|
|
||||||
|
|
||||||
|
if self.optimizer_cfg.get('train_image_encoder', False):
|
||||||
|
image_encoder_parameters = list(self.cond_stage_model.named_parameters())
|
||||||
|
image_encoder_parameters_decay = [param for name, param in image_encoder_parameters if
|
||||||
|
not any((no_decay_name in name) for no_decay_name in no_decay)]
|
||||||
|
image_encoder_parameters_nodecay = [param for name, param in image_encoder_parameters if
|
||||||
|
any((no_decay_name in name) for no_decay_name in no_decay)]
|
||||||
|
# filter trainable params
|
||||||
|
image_encoder_parameters_decay = [param for param in image_encoder_parameters_decay if
|
||||||
|
param.requires_grad]
|
||||||
|
image_encoder_parameters_nodecay = [param for param in image_encoder_parameters_nodecay if
|
||||||
|
param.requires_grad]
|
||||||
|
|
||||||
|
print(f"Image Encoder Params: {len(image_encoder_parameters_decay)} decay, ")
|
||||||
|
print(f"Image Encoder Params: {len(image_encoder_parameters_nodecay)} nodecay, ")
|
||||||
|
|
||||||
|
image_encoder_lr = self.optimizer_cfg['image_encoder_lr']
|
||||||
|
image_encoder_lr_multiply = self.optimizer_cfg.get('image_encoder_lr_multiply', 1.0)
|
||||||
|
image_encoder_lr = image_encoder_lr if image_encoder_lr is not None else lr * image_encoder_lr_multiply
|
||||||
|
params_list.append(
|
||||||
|
{'params': image_encoder_parameters_decay, 'lr': image_encoder_lr,
|
||||||
|
'weight_decay': 0.05})
|
||||||
|
params_list.append(
|
||||||
|
{'params': image_encoder_parameters_nodecay, 'lr': image_encoder_lr,
|
||||||
|
'weight_decay': 0.})
|
||||||
|
|
||||||
|
optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=params_list, lr=lr)
|
||||||
|
if hasattr(self.optimizer_cfg, 'scheduler'):
|
||||||
|
scheduler_func = instantiate_from_config(
|
||||||
|
self.optimizer_cfg.scheduler,
|
||||||
|
max_decay_steps=self.trainer.max_steps,
|
||||||
|
lr_max=lr
|
||||||
|
)
|
||||||
|
scheduler = {
|
||||||
|
"scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule),
|
||||||
|
"interval": "step",
|
||||||
|
"frequency": 1
|
||||||
|
}
|
||||||
|
schedulers = [scheduler]
|
||||||
|
else:
|
||||||
|
schedulers = []
|
||||||
|
optimizers = [optimizer]
|
||||||
|
|
||||||
|
return optimizers, schedulers
|
||||||
|
|
||||||
|
@rank_zero_only
|
||||||
|
@torch.no_grad()
|
||||||
|
def on_train_batch_start(self, batch, batch_idx):
|
||||||
|
# only for very first batch
|
||||||
|
if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 \
|
||||||
|
and batch_idx == 0 and self.ckpt_path is None:
|
||||||
|
# set rescale weight to 1./std of encodings
|
||||||
|
print("### USING STD-RESCALING ###")
|
||||||
|
|
||||||
|
z_q = self.encode_first_stage(batch[self.first_stage_key])
|
||||||
|
z = z_q.detach()
|
||||||
|
|
||||||
|
del self.z_scale_factor
|
||||||
|
self.register_buffer("z_scale_factor", 1. / z.flatten().std())
|
||||||
|
print(f"setting self.z_scale_factor to {self.z_scale_factor}")
|
||||||
|
|
||||||
|
print("### USING STD-RESCALING ###")
|
||||||
|
|
||||||
|
def on_train_batch_end(self, *args, **kwargs):
|
||||||
|
if self.ema_config is not None:
|
||||||
|
self.model_ema(self.model)
|
||||||
|
|
||||||
|
def on_train_epoch_start(self) -> None:
|
||||||
|
pl.seed_everything(self.trainer.global_rank)
|
||||||
|
|
||||||
|
def forward(self, batch):
|
||||||
|
with torch.autocast(device_type="cuda", dtype=torch.bfloat16): #float32 for text
|
||||||
|
contexts = self.cond_stage_model(image=batch.get('image'), text=batch.get('text'), mask=batch.get('mask'))
|
||||||
|
# t5_text = contexts['t5_text']['prompt_embeds']
|
||||||
|
# nan_count = torch.isnan(t5_text).sum()
|
||||||
|
# if nan_count > 0:
|
||||||
|
# print("t5_text has %d NaN values"%(nan_count))
|
||||||
|
with torch.autocast(device_type="cuda", dtype=torch.float16):
|
||||||
|
with torch.no_grad():
|
||||||
|
latents = self.first_stage_model.encode(batch[self.first_stage_key], sample_posterior=True)
|
||||||
|
latents = self.z_scale_factor * latents
|
||||||
|
# print(latents.shape)
|
||||||
|
|
||||||
|
# check vae encode and decode is ok? answer is ok !
|
||||||
|
# import time
|
||||||
|
# from hy3dshape.pipelines import export_to_trimesh
|
||||||
|
# latents = 1. / self.z_scale_factor * latents
|
||||||
|
# latents = self.first_stage_model(latents)
|
||||||
|
# outputs = self.first_stage_model.latents2mesh(
|
||||||
|
# latents,
|
||||||
|
# bounds=1.01,
|
||||||
|
# mc_level=0.0,
|
||||||
|
# num_chunks=20000,
|
||||||
|
# octree_resolution=256,
|
||||||
|
# mc_algo='mc',
|
||||||
|
# enable_pbar=True
|
||||||
|
# )
|
||||||
|
# mesh = export_to_trimesh(outputs)
|
||||||
|
# if isinstance(mesh, list):
|
||||||
|
# for midx, m in enumerate(mesh):
|
||||||
|
# m.export(f"check_{midx}_{time.time()}.glb")
|
||||||
|
# else:
|
||||||
|
# mesh.export(f"check_{time.time()}.glb")
|
||||||
|
|
||||||
|
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
|
||||||
|
loss = self.transport.training_losses(self.model, latents, dict(contexts=contexts))["loss"].mean()
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def training_step(self, batch, batch_idx, optimizer_idx=0):
|
||||||
|
loss = self.forward(batch)
|
||||||
|
split = 'train'
|
||||||
|
loss_dict = {
|
||||||
|
f"{split}/simple": loss.detach(),
|
||||||
|
f"{split}/total_loss": loss.detach(),
|
||||||
|
f"{split}/lr_abs": self.optimizers().param_groups[0]['lr'],
|
||||||
|
}
|
||||||
|
self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def validation_step(self, batch, batch_idx, optimizer_idx=0):
|
||||||
|
loss = self.forward(batch)
|
||||||
|
split = 'val'
|
||||||
|
loss_dict = {
|
||||||
|
f"{split}/simple": loss.detach(),
|
||||||
|
f"{split}/total_loss": loss.detach(),
|
||||||
|
f"{split}/lr_abs": self.optimizers().param_groups[0]['lr'],
|
||||||
|
}
|
||||||
|
self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample(self, batch, output_type='trimesh', **kwargs):
|
||||||
|
self.cond_stage_model.disable_drop = True
|
||||||
|
|
||||||
|
generator = torch.Generator().manual_seed(0)
|
||||||
|
|
||||||
|
with self.ema_scope("Sample"):
|
||||||
|
with torch.amp.autocast(device_type='cuda'):
|
||||||
|
try:
|
||||||
|
self.pipeline.device = self.device
|
||||||
|
self.pipeline.dtype = self.dtype
|
||||||
|
print("### USING PIPELINE ###")
|
||||||
|
print(f'device: {self.device} dtype : {self.dtype}')
|
||||||
|
additional_params = {'output_type':output_type}
|
||||||
|
|
||||||
|
image = batch.get("image", None)
|
||||||
|
mask = batch.get('mask', None)
|
||||||
|
|
||||||
|
# if not isinstance(image, torch.Tensor): print(image.shape)
|
||||||
|
# if isinstance(mask, torch.Tensor): print(mask.shape)
|
||||||
|
|
||||||
|
outputs = self.pipeline(image=image,
|
||||||
|
mask=mask,
|
||||||
|
generator=generator,
|
||||||
|
**additional_params)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
print(f"Unexpected {e=}, {type(e)=}")
|
||||||
|
with open("error.txt", "a") as f:
|
||||||
|
f.write(str(e))
|
||||||
|
f.write(traceback.format_exc())
|
||||||
|
f.write("\n")
|
||||||
|
outputs = [None]
|
||||||
|
self.cond_stage_model.disable_drop = False
|
||||||
|
return [outputs]
|
||||||
97
hy3dshape/hy3dshape/models/diffusion/transport/__init__.py
Normal file
97
hy3dshape/hy3dshape/models/diffusion/transport/__init__.py
Normal file
@ -0,0 +1,97 @@
|
|||||||
|
# This file includes code derived from the SiT project (https://github.com/willisma/SiT),
|
||||||
|
# which is licensed under the MIT License.
|
||||||
|
#
|
||||||
|
# MIT License
|
||||||
|
#
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
# of this software and associated documentation files (the "Software"), to deal
|
||||||
|
# in the Software without restriction, including without limitation the rights
|
||||||
|
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
# copies of the Software, and to permit persons to whom the Software is
|
||||||
|
# furnished to do so, subject to the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be included in all
|
||||||
|
# copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
# SOFTWARE.
|
||||||
|
|
||||||
|
from .transport import Transport, ModelType, WeightType, PathType, Sampler
|
||||||
|
|
||||||
|
|
||||||
|
def create_transport(
|
||||||
|
path_type='Linear',
|
||||||
|
prediction="velocity",
|
||||||
|
loss_weight=None,
|
||||||
|
train_eps=None,
|
||||||
|
sample_eps=None,
|
||||||
|
train_sample_type="uniform",
|
||||||
|
mean = 0.0,
|
||||||
|
std = 1.0,
|
||||||
|
shift_scale = 1.0,
|
||||||
|
):
|
||||||
|
"""function for creating Transport object
|
||||||
|
**Note**: model prediction defaults to velocity
|
||||||
|
Args:
|
||||||
|
- path_type: type of path to use; default to linear
|
||||||
|
- learn_score: set model prediction to score
|
||||||
|
- learn_noise: set model prediction to noise
|
||||||
|
- velocity_weighted: weight loss by velocity weight
|
||||||
|
- likelihood_weighted: weight loss by likelihood weight
|
||||||
|
- train_eps: small epsilon for avoiding instability during training
|
||||||
|
- sample_eps: small epsilon for avoiding instability during sampling
|
||||||
|
"""
|
||||||
|
|
||||||
|
if prediction == "noise":
|
||||||
|
model_type = ModelType.NOISE
|
||||||
|
elif prediction == "score":
|
||||||
|
model_type = ModelType.SCORE
|
||||||
|
else:
|
||||||
|
model_type = ModelType.VELOCITY
|
||||||
|
|
||||||
|
if loss_weight == "velocity":
|
||||||
|
loss_type = WeightType.VELOCITY
|
||||||
|
elif loss_weight == "likelihood":
|
||||||
|
loss_type = WeightType.LIKELIHOOD
|
||||||
|
else:
|
||||||
|
loss_type = WeightType.NONE
|
||||||
|
|
||||||
|
path_choice = {
|
||||||
|
"Linear": PathType.LINEAR,
|
||||||
|
"GVP": PathType.GVP,
|
||||||
|
"VP": PathType.VP,
|
||||||
|
}
|
||||||
|
|
||||||
|
path_type = path_choice[path_type]
|
||||||
|
|
||||||
|
if (path_type in [PathType.VP]):
|
||||||
|
train_eps = 1e-5 if train_eps is None else train_eps
|
||||||
|
sample_eps = 1e-3 if train_eps is None else sample_eps
|
||||||
|
elif (path_type in [PathType.GVP, PathType.LINEAR] and model_type != ModelType.VELOCITY):
|
||||||
|
train_eps = 1e-3 if train_eps is None else train_eps
|
||||||
|
sample_eps = 1e-3 if train_eps is None else sample_eps
|
||||||
|
else: # velocity & [GVP, LINEAR] is stable everywhere
|
||||||
|
train_eps = 0
|
||||||
|
sample_eps = 0
|
||||||
|
|
||||||
|
# create flow state
|
||||||
|
state = Transport(
|
||||||
|
model_type=model_type,
|
||||||
|
path_type=path_type,
|
||||||
|
loss_type=loss_type,
|
||||||
|
train_eps=train_eps,
|
||||||
|
sample_eps=sample_eps,
|
||||||
|
train_sample_type=train_sample_type,
|
||||||
|
mean=mean,
|
||||||
|
std=std,
|
||||||
|
shift_scale =shift_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
return state
|
||||||
142
hy3dshape/hy3dshape/models/diffusion/transport/integrators.py
Normal file
142
hy3dshape/hy3dshape/models/diffusion/transport/integrators.py
Normal file
@ -0,0 +1,142 @@
|
|||||||
|
# This file includes code derived from the SiT project (https://github.com/willisma/SiT),
|
||||||
|
# which is licensed under the MIT License.
|
||||||
|
#
|
||||||
|
# MIT License
|
||||||
|
#
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
# of this software and associated documentation files (the "Software"), to deal
|
||||||
|
# in the Software without restriction, including without limitation the rights
|
||||||
|
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
# copies of the Software, and to permit persons to whom the Software is
|
||||||
|
# furnished to do so, subject to the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be included in all
|
||||||
|
# copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
# SOFTWARE.
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch as th
|
||||||
|
import torch.nn as nn
|
||||||
|
from torchdiffeq import odeint
|
||||||
|
from functools import partial
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
class sde:
|
||||||
|
"""SDE solver class"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
drift,
|
||||||
|
diffusion,
|
||||||
|
*,
|
||||||
|
t0,
|
||||||
|
t1,
|
||||||
|
num_steps,
|
||||||
|
sampler_type,
|
||||||
|
):
|
||||||
|
assert t0 < t1, "SDE sampler has to be in forward time"
|
||||||
|
|
||||||
|
self.num_timesteps = num_steps
|
||||||
|
self.t = th.linspace(t0, t1, num_steps)
|
||||||
|
self.dt = self.t[1] - self.t[0]
|
||||||
|
self.drift = drift
|
||||||
|
self.diffusion = diffusion
|
||||||
|
self.sampler_type = sampler_type
|
||||||
|
|
||||||
|
def __Euler_Maruyama_step(self, x, mean_x, t, model, **model_kwargs):
|
||||||
|
w_cur = th.randn(x.size()).to(x)
|
||||||
|
t = th.ones(x.size(0)).to(x) * t
|
||||||
|
dw = w_cur * th.sqrt(self.dt)
|
||||||
|
drift = self.drift(x, t, model, **model_kwargs)
|
||||||
|
diffusion = self.diffusion(x, t)
|
||||||
|
mean_x = x + drift * self.dt
|
||||||
|
x = mean_x + th.sqrt(2 * diffusion) * dw
|
||||||
|
return x, mean_x
|
||||||
|
|
||||||
|
def __Heun_step(self, x, _, t, model, **model_kwargs):
|
||||||
|
w_cur = th.randn(x.size()).to(x)
|
||||||
|
dw = w_cur * th.sqrt(self.dt)
|
||||||
|
t_cur = th.ones(x.size(0)).to(x) * t
|
||||||
|
diffusion = self.diffusion(x, t_cur)
|
||||||
|
xhat = x + th.sqrt(2 * diffusion) * dw
|
||||||
|
K1 = self.drift(xhat, t_cur, model, **model_kwargs)
|
||||||
|
xp = xhat + self.dt * K1
|
||||||
|
K2 = self.drift(xp, t_cur + self.dt, model, **model_kwargs)
|
||||||
|
return xhat + 0.5 * self.dt * (K1 + K2), xhat # at last time point we do not perform the heun step
|
||||||
|
|
||||||
|
def __forward_fn(self):
|
||||||
|
"""TODO: generalize here by adding all private functions ending with steps to it"""
|
||||||
|
sampler_dict = {
|
||||||
|
"Euler": self.__Euler_Maruyama_step,
|
||||||
|
"Heun": self.__Heun_step,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
sampler = sampler_dict[self.sampler_type]
|
||||||
|
except:
|
||||||
|
raise NotImplementedError("Smapler type not implemented.")
|
||||||
|
|
||||||
|
return sampler
|
||||||
|
|
||||||
|
def sample(self, init, model, **model_kwargs):
|
||||||
|
"""forward loop of sde"""
|
||||||
|
x = init
|
||||||
|
mean_x = init
|
||||||
|
samples = []
|
||||||
|
sampler = self.__forward_fn()
|
||||||
|
for ti in self.t[:-1]:
|
||||||
|
with th.no_grad():
|
||||||
|
x, mean_x = sampler(x, mean_x, ti, model, **model_kwargs)
|
||||||
|
samples.append(x)
|
||||||
|
|
||||||
|
return samples
|
||||||
|
|
||||||
|
class ode:
|
||||||
|
"""ODE solver class"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
drift,
|
||||||
|
*,
|
||||||
|
t0,
|
||||||
|
t1,
|
||||||
|
sampler_type,
|
||||||
|
num_steps,
|
||||||
|
atol,
|
||||||
|
rtol,
|
||||||
|
):
|
||||||
|
assert t0 < t1, "ODE sampler has to be in forward time"
|
||||||
|
|
||||||
|
self.drift = drift
|
||||||
|
self.t = th.linspace(t0, t1, num_steps)
|
||||||
|
self.atol = atol
|
||||||
|
self.rtol = rtol
|
||||||
|
self.sampler_type = sampler_type
|
||||||
|
|
||||||
|
def sample(self, x, model, **model_kwargs):
|
||||||
|
|
||||||
|
device = x[0].device if isinstance(x, tuple) else x.device
|
||||||
|
def _fn(t, x):
|
||||||
|
t = th.ones(x[0].size(0)).to(device) * t if isinstance(x, tuple) else th.ones(x.size(0)).to(device) * t
|
||||||
|
model_output = self.drift(x, t, model, **model_kwargs)
|
||||||
|
return model_output
|
||||||
|
|
||||||
|
t = self.t.to(device)
|
||||||
|
atol = [self.atol] * len(x) if isinstance(x, tuple) else [self.atol]
|
||||||
|
rtol = [self.rtol] * len(x) if isinstance(x, tuple) else [self.rtol]
|
||||||
|
samples = odeint(
|
||||||
|
_fn,
|
||||||
|
x,
|
||||||
|
t,
|
||||||
|
method=self.sampler_type,
|
||||||
|
atol=atol,
|
||||||
|
rtol=rtol
|
||||||
|
)
|
||||||
|
return samples
|
||||||
220
hy3dshape/hy3dshape/models/diffusion/transport/path.py
Normal file
220
hy3dshape/hy3dshape/models/diffusion/transport/path.py
Normal file
@ -0,0 +1,220 @@
|
|||||||
|
# This file includes code derived from the SiT project (https://github.com/willisma/SiT),
|
||||||
|
# which is licensed under the MIT License.
|
||||||
|
#
|
||||||
|
# MIT License
|
||||||
|
#
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
# of this software and associated documentation files (the "Software"), to deal
|
||||||
|
# in the Software without restriction, including without limitation the rights
|
||||||
|
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
# copies of the Software, and to permit persons to whom the Software is
|
||||||
|
# furnished to do so, subject to the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be included in all
|
||||||
|
# copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
# SOFTWARE.
|
||||||
|
|
||||||
|
import torch as th
|
||||||
|
import numpy as np
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
def expand_t_like_x(t, x):
|
||||||
|
"""Function to reshape time t to broadcastable dimension of x
|
||||||
|
Args:
|
||||||
|
t: [batch_dim,], time vector
|
||||||
|
x: [batch_dim,...], data point
|
||||||
|
"""
|
||||||
|
dims = [1] * (len(x.size()) - 1)
|
||||||
|
t = t.view(t.size(0), *dims)
|
||||||
|
return t
|
||||||
|
|
||||||
|
|
||||||
|
#################### Coupling Plans ####################
|
||||||
|
|
||||||
|
class ICPlan:
|
||||||
|
"""Linear Coupling Plan"""
|
||||||
|
def __init__(self, sigma=0.0):
|
||||||
|
self.sigma = sigma
|
||||||
|
|
||||||
|
def compute_alpha_t(self, t):
|
||||||
|
"""Compute the data coefficient along the path"""
|
||||||
|
return t, 1
|
||||||
|
|
||||||
|
def compute_sigma_t(self, t):
|
||||||
|
"""Compute the noise coefficient along the path"""
|
||||||
|
return 1 - t, -1
|
||||||
|
|
||||||
|
def compute_d_alpha_alpha_ratio_t(self, t):
|
||||||
|
"""Compute the ratio between d_alpha and alpha"""
|
||||||
|
return 1 / t
|
||||||
|
|
||||||
|
def compute_drift(self, x, t):
|
||||||
|
"""We always output sde according to score parametrization; """
|
||||||
|
t = expand_t_like_x(t, x)
|
||||||
|
alpha_ratio = self.compute_d_alpha_alpha_ratio_t(t)
|
||||||
|
sigma_t, d_sigma_t = self.compute_sigma_t(t)
|
||||||
|
drift = alpha_ratio * x
|
||||||
|
diffusion = alpha_ratio * (sigma_t ** 2) - sigma_t * d_sigma_t
|
||||||
|
|
||||||
|
return -drift, diffusion
|
||||||
|
|
||||||
|
def compute_diffusion(self, x, t, form="constant", norm=1.0):
|
||||||
|
"""Compute the diffusion term of the SDE
|
||||||
|
Args:
|
||||||
|
x: [batch_dim, ...], data point
|
||||||
|
t: [batch_dim,], time vector
|
||||||
|
form: str, form of the diffusion term
|
||||||
|
norm: float, norm of the diffusion term
|
||||||
|
"""
|
||||||
|
t = expand_t_like_x(t, x)
|
||||||
|
choices = {
|
||||||
|
"constant": norm,
|
||||||
|
"SBDM": norm * self.compute_drift(x, t)[1],
|
||||||
|
"sigma": norm * self.compute_sigma_t(t)[0],
|
||||||
|
"linear": norm * (1 - t),
|
||||||
|
"decreasing": 0.25 * (norm * th.cos(np.pi * t) + 1) ** 2,
|
||||||
|
"inccreasing-decreasing": norm * th.sin(np.pi * t) ** 2,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
diffusion = choices[form]
|
||||||
|
except KeyError:
|
||||||
|
raise NotImplementedError(f"Diffusion form {form} not implemented")
|
||||||
|
|
||||||
|
return diffusion
|
||||||
|
|
||||||
|
def get_score_from_velocity(self, velocity, x, t):
|
||||||
|
"""Wrapper function: transfrom velocity prediction model to score
|
||||||
|
Args:
|
||||||
|
velocity: [batch_dim, ...] shaped tensor; velocity model output
|
||||||
|
x: [batch_dim, ...] shaped tensor; x_t data point
|
||||||
|
t: [batch_dim,] time tensor
|
||||||
|
"""
|
||||||
|
t = expand_t_like_x(t, x)
|
||||||
|
alpha_t, d_alpha_t = self.compute_alpha_t(t)
|
||||||
|
sigma_t, d_sigma_t = self.compute_sigma_t(t)
|
||||||
|
mean = x
|
||||||
|
reverse_alpha_ratio = alpha_t / d_alpha_t
|
||||||
|
var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t
|
||||||
|
score = (reverse_alpha_ratio * velocity - mean) / var
|
||||||
|
return score
|
||||||
|
|
||||||
|
def get_noise_from_velocity(self, velocity, x, t):
|
||||||
|
"""Wrapper function: transfrom velocity prediction model to denoiser
|
||||||
|
Args:
|
||||||
|
velocity: [batch_dim, ...] shaped tensor; velocity model output
|
||||||
|
x: [batch_dim, ...] shaped tensor; x_t data point
|
||||||
|
t: [batch_dim,] time tensor
|
||||||
|
"""
|
||||||
|
t = expand_t_like_x(t, x)
|
||||||
|
alpha_t, d_alpha_t = self.compute_alpha_t(t)
|
||||||
|
sigma_t, d_sigma_t = self.compute_sigma_t(t)
|
||||||
|
mean = x
|
||||||
|
reverse_alpha_ratio = alpha_t / d_alpha_t
|
||||||
|
var = reverse_alpha_ratio * d_sigma_t - sigma_t
|
||||||
|
noise = (reverse_alpha_ratio * velocity - mean) / var
|
||||||
|
return noise
|
||||||
|
|
||||||
|
def get_velocity_from_score(self, score, x, t):
|
||||||
|
"""Wrapper function: transfrom score prediction model to velocity
|
||||||
|
Args:
|
||||||
|
score: [batch_dim, ...] shaped tensor; score model output
|
||||||
|
x: [batch_dim, ...] shaped tensor; x_t data point
|
||||||
|
t: [batch_dim,] time tensor
|
||||||
|
"""
|
||||||
|
t = expand_t_like_x(t, x)
|
||||||
|
drift, var = self.compute_drift(x, t)
|
||||||
|
velocity = var * score - drift
|
||||||
|
return velocity
|
||||||
|
|
||||||
|
def compute_mu_t(self, t, x0, x1):
|
||||||
|
"""Compute the mean of time-dependent density p_t"""
|
||||||
|
t = expand_t_like_x(t, x1)
|
||||||
|
alpha_t, _ = self.compute_alpha_t(t)
|
||||||
|
sigma_t, _ = self.compute_sigma_t(t)
|
||||||
|
# t*x1 + (1-t)*x0 ; t=0 x0; t=1 x1
|
||||||
|
return alpha_t * x1 + sigma_t * x0
|
||||||
|
|
||||||
|
def compute_xt(self, t, x0, x1):
|
||||||
|
"""Sample xt from time-dependent density p_t; rng is required"""
|
||||||
|
xt = self.compute_mu_t(t, x0, x1)
|
||||||
|
return xt
|
||||||
|
|
||||||
|
def compute_ut(self, t, x0, x1, xt):
|
||||||
|
"""Compute the vector field corresponding to p_t"""
|
||||||
|
t = expand_t_like_x(t, x1)
|
||||||
|
_, d_alpha_t = self.compute_alpha_t(t)
|
||||||
|
_, d_sigma_t = self.compute_sigma_t(t)
|
||||||
|
return d_alpha_t * x1 + d_sigma_t * x0
|
||||||
|
|
||||||
|
def plan(self, t, x0, x1):
|
||||||
|
xt = self.compute_xt(t, x0, x1)
|
||||||
|
ut = self.compute_ut(t, x0, x1, xt)
|
||||||
|
return t, xt, ut
|
||||||
|
|
||||||
|
|
||||||
|
class VPCPlan(ICPlan):
|
||||||
|
"""class for VP path flow matching"""
|
||||||
|
|
||||||
|
def __init__(self, sigma_min=0.1, sigma_max=20.0):
|
||||||
|
self.sigma_min = sigma_min
|
||||||
|
self.sigma_max = sigma_max
|
||||||
|
self.log_mean_coeff = lambda t: -0.25 * ((1 - t) ** 2) * \
|
||||||
|
(self.sigma_max - self.sigma_min) - 0.5 * (1 - t) * self.sigma_min
|
||||||
|
self.d_log_mean_coeff = lambda t: 0.5 * (1 - t) * \
|
||||||
|
(self.sigma_max - self.sigma_min) + 0.5 * self.sigma_min
|
||||||
|
|
||||||
|
|
||||||
|
def compute_alpha_t(self, t):
|
||||||
|
"""Compute coefficient of x1"""
|
||||||
|
alpha_t = self.log_mean_coeff(t)
|
||||||
|
alpha_t = th.exp(alpha_t)
|
||||||
|
d_alpha_t = alpha_t * self.d_log_mean_coeff(t)
|
||||||
|
return alpha_t, d_alpha_t
|
||||||
|
|
||||||
|
def compute_sigma_t(self, t):
|
||||||
|
"""Compute coefficient of x0"""
|
||||||
|
p_sigma_t = 2 * self.log_mean_coeff(t)
|
||||||
|
sigma_t = th.sqrt(1 - th.exp(p_sigma_t))
|
||||||
|
d_sigma_t = th.exp(p_sigma_t) * (2 * self.d_log_mean_coeff(t)) / (-2 * sigma_t)
|
||||||
|
return sigma_t, d_sigma_t
|
||||||
|
|
||||||
|
def compute_d_alpha_alpha_ratio_t(self, t):
|
||||||
|
"""Special purposed function for computing numerical stabled d_alpha_t / alpha_t"""
|
||||||
|
return self.d_log_mean_coeff(t)
|
||||||
|
|
||||||
|
def compute_drift(self, x, t):
|
||||||
|
"""Compute the drift term of the SDE"""
|
||||||
|
t = expand_t_like_x(t, x)
|
||||||
|
beta_t = self.sigma_min + (1 - t) * (self.sigma_max - self.sigma_min)
|
||||||
|
return -0.5 * beta_t * x, beta_t / 2
|
||||||
|
|
||||||
|
|
||||||
|
class GVPCPlan(ICPlan):
|
||||||
|
def __init__(self, sigma=0.0):
|
||||||
|
super().__init__(sigma)
|
||||||
|
|
||||||
|
def compute_alpha_t(self, t):
|
||||||
|
"""Compute coefficient of x1"""
|
||||||
|
alpha_t = th.sin(t * np.pi / 2)
|
||||||
|
d_alpha_t = np.pi / 2 * th.cos(t * np.pi / 2)
|
||||||
|
return alpha_t, d_alpha_t
|
||||||
|
|
||||||
|
def compute_sigma_t(self, t):
|
||||||
|
"""Compute coefficient of x0"""
|
||||||
|
sigma_t = th.cos(t * np.pi / 2)
|
||||||
|
d_sigma_t = -np.pi / 2 * th.sin(t * np.pi / 2)
|
||||||
|
return sigma_t, d_sigma_t
|
||||||
|
|
||||||
|
def compute_d_alpha_alpha_ratio_t(self, t):
|
||||||
|
"""Special purposed function for computing numerical stabled d_alpha_t / alpha_t"""
|
||||||
|
return np.pi / (2 * th.tan(t * np.pi / 2))
|
||||||
534
hy3dshape/hy3dshape/models/diffusion/transport/transport.py
Normal file
534
hy3dshape/hy3dshape/models/diffusion/transport/transport.py
Normal file
@ -0,0 +1,534 @@
|
|||||||
|
# This file includes code derived from the SiT project (https://github.com/willisma/SiT),
|
||||||
|
# which is licensed under the MIT License.
|
||||||
|
#
|
||||||
|
# MIT License
|
||||||
|
#
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
# of this software and associated documentation files (the "Software"), to deal
|
||||||
|
# in the Software without restriction, including without limitation the rights
|
||||||
|
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
# copies of the Software, and to permit persons to whom the Software is
|
||||||
|
# furnished to do so, subject to the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be included in all
|
||||||
|
# copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
# SOFTWARE.
|
||||||
|
|
||||||
|
import torch as th
|
||||||
|
import numpy as np
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import enum
|
||||||
|
|
||||||
|
from . import path
|
||||||
|
from .utils import EasyDict, log_state, mean_flat
|
||||||
|
from .integrators import ode, sde
|
||||||
|
|
||||||
|
|
||||||
|
class ModelType(enum.Enum):
|
||||||
|
"""
|
||||||
|
Which type of output the model predicts.
|
||||||
|
"""
|
||||||
|
|
||||||
|
NOISE = enum.auto() # the model predicts epsilon
|
||||||
|
SCORE = enum.auto() # the model predicts \nabla \log p(x)
|
||||||
|
VELOCITY = enum.auto() # the model predicts v(x)
|
||||||
|
|
||||||
|
|
||||||
|
class PathType(enum.Enum):
|
||||||
|
"""
|
||||||
|
Which type of path to use.
|
||||||
|
"""
|
||||||
|
|
||||||
|
LINEAR = enum.auto()
|
||||||
|
GVP = enum.auto()
|
||||||
|
VP = enum.auto()
|
||||||
|
|
||||||
|
|
||||||
|
class WeightType(enum.Enum):
|
||||||
|
"""
|
||||||
|
Which type of weighting to use.
|
||||||
|
"""
|
||||||
|
|
||||||
|
NONE = enum.auto()
|
||||||
|
VELOCITY = enum.auto()
|
||||||
|
LIKELIHOOD = enum.auto()
|
||||||
|
|
||||||
|
|
||||||
|
class Transport:
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
model_type,
|
||||||
|
path_type,
|
||||||
|
loss_type,
|
||||||
|
train_eps,
|
||||||
|
sample_eps,
|
||||||
|
train_sample_type = "uniform",
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
path_options = {
|
||||||
|
PathType.LINEAR: path.ICPlan,
|
||||||
|
PathType.GVP: path.GVPCPlan,
|
||||||
|
PathType.VP: path.VPCPlan,
|
||||||
|
}
|
||||||
|
|
||||||
|
self.loss_type = loss_type
|
||||||
|
self.model_type = model_type
|
||||||
|
self.path_sampler = path_options[path_type]()
|
||||||
|
self.train_eps = train_eps
|
||||||
|
self.sample_eps = sample_eps
|
||||||
|
self.train_sample_type = train_sample_type
|
||||||
|
if self.train_sample_type == "logit_normal":
|
||||||
|
self.mean = kwargs['mean']
|
||||||
|
self.std = kwargs['std']
|
||||||
|
self.shift_scale = kwargs['shift_scale']
|
||||||
|
print(f"using logit normal sample, shift scale is {self.shift_scale}")
|
||||||
|
|
||||||
|
def prior_logp(self, z):
|
||||||
|
'''
|
||||||
|
Standard multivariate normal prior
|
||||||
|
Assume z is batched
|
||||||
|
'''
|
||||||
|
shape = th.tensor(z.size())
|
||||||
|
N = th.prod(shape[1:])
|
||||||
|
_fn = lambda x: -N / 2. * np.log(2 * np.pi) - th.sum(x ** 2) / 2.
|
||||||
|
return th.vmap(_fn)(z)
|
||||||
|
|
||||||
|
def check_interval(
|
||||||
|
self,
|
||||||
|
train_eps,
|
||||||
|
sample_eps,
|
||||||
|
*,
|
||||||
|
diffusion_form="SBDM",
|
||||||
|
sde=False,
|
||||||
|
reverse=False,
|
||||||
|
eval=False,
|
||||||
|
last_step_size=0.0,
|
||||||
|
):
|
||||||
|
t0 = 0
|
||||||
|
t1 = 1
|
||||||
|
eps = train_eps if not eval else sample_eps
|
||||||
|
if (type(self.path_sampler) in [path.VPCPlan]):
|
||||||
|
|
||||||
|
t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size
|
||||||
|
|
||||||
|
elif (type(self.path_sampler) in [path.ICPlan, path.GVPCPlan]) \
|
||||||
|
and (
|
||||||
|
self.model_type != ModelType.VELOCITY or sde): # avoid numerical issue by taking a first semi-implicit step
|
||||||
|
|
||||||
|
t0 = eps if (diffusion_form == "SBDM" and sde) or self.model_type != ModelType.VELOCITY else 0
|
||||||
|
t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size
|
||||||
|
|
||||||
|
if reverse:
|
||||||
|
t0, t1 = 1 - t0, 1 - t1
|
||||||
|
|
||||||
|
return t0, t1
|
||||||
|
|
||||||
|
def sample(self, x1):
|
||||||
|
"""Sampling x0 & t based on shape of x1 (if needed)
|
||||||
|
Args:
|
||||||
|
x1 - data point; [batch, *dim]
|
||||||
|
"""
|
||||||
|
|
||||||
|
x0 = th.randn_like(x1)
|
||||||
|
if self.train_sample_type=="uniform":
|
||||||
|
t0, t1 = self.check_interval(self.train_eps, self.sample_eps)
|
||||||
|
t = th.rand((x1.shape[0],)) * (t1 - t0) + t0
|
||||||
|
t = t.to(x1)
|
||||||
|
elif self.train_sample_type=="logit_normal":
|
||||||
|
t = th.randn((x1.shape[0],)) * self.std + self.mean
|
||||||
|
t = t.to(x1)
|
||||||
|
t = 1/(1+th.exp(-t))
|
||||||
|
|
||||||
|
t = np.sqrt(self.shift_scale)*t/(1+(np.sqrt(self.shift_scale)-1)*t)
|
||||||
|
|
||||||
|
return t, x0, x1
|
||||||
|
|
||||||
|
def training_losses(
|
||||||
|
self,
|
||||||
|
model,
|
||||||
|
x1,
|
||||||
|
model_kwargs=None
|
||||||
|
):
|
||||||
|
"""Loss for training the score model
|
||||||
|
Args:
|
||||||
|
- model: backbone model; could be score, noise, or velocity
|
||||||
|
- x1: datapoint
|
||||||
|
- model_kwargs: additional arguments for the model
|
||||||
|
"""
|
||||||
|
if model_kwargs == None:
|
||||||
|
model_kwargs = {}
|
||||||
|
|
||||||
|
t, x0, x1 = self.sample(x1)
|
||||||
|
t, xt, ut = self.path_sampler.plan(t, x0, x1)
|
||||||
|
model_output = model(xt, t, **model_kwargs)
|
||||||
|
B, *_, C = xt.shape
|
||||||
|
assert model_output.size() == (B, *xt.size()[1:-1], C)
|
||||||
|
|
||||||
|
terms = {}
|
||||||
|
terms['pred'] = model_output
|
||||||
|
if self.model_type == ModelType.VELOCITY:
|
||||||
|
terms['loss'] = mean_flat(((model_output - ut) ** 2))
|
||||||
|
else:
|
||||||
|
_, drift_var = self.path_sampler.compute_drift(xt, t)
|
||||||
|
sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, xt))
|
||||||
|
if self.loss_type in [WeightType.VELOCITY]:
|
||||||
|
weight = (drift_var / sigma_t) ** 2
|
||||||
|
elif self.loss_type in [WeightType.LIKELIHOOD]:
|
||||||
|
weight = drift_var / (sigma_t ** 2)
|
||||||
|
elif self.loss_type in [WeightType.NONE]:
|
||||||
|
weight = 1
|
||||||
|
else:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
if self.model_type == ModelType.NOISE:
|
||||||
|
terms['loss'] = mean_flat(weight * ((model_output - x0) ** 2))
|
||||||
|
else:
|
||||||
|
terms['loss'] = mean_flat(weight * ((model_output * sigma_t + x0) ** 2))
|
||||||
|
|
||||||
|
return terms
|
||||||
|
|
||||||
|
def get_drift(
|
||||||
|
self
|
||||||
|
):
|
||||||
|
"""member function for obtaining the drift of the probability flow ODE"""
|
||||||
|
|
||||||
|
def score_ode(x, t, model, **model_kwargs):
|
||||||
|
drift_mean, drift_var = self.path_sampler.compute_drift(x, t)
|
||||||
|
model_output = model(x, t, **model_kwargs)
|
||||||
|
return (-drift_mean + drift_var * model_output) # by change of variable
|
||||||
|
|
||||||
|
def noise_ode(x, t, model, **model_kwargs):
|
||||||
|
drift_mean, drift_var = self.path_sampler.compute_drift(x, t)
|
||||||
|
sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))
|
||||||
|
model_output = model(x, t, **model_kwargs)
|
||||||
|
score = model_output / -sigma_t
|
||||||
|
return (-drift_mean + drift_var * score)
|
||||||
|
|
||||||
|
def velocity_ode(x, t, model, **model_kwargs):
|
||||||
|
model_output = model(x, t, **model_kwargs)
|
||||||
|
return model_output
|
||||||
|
|
||||||
|
if self.model_type == ModelType.NOISE:
|
||||||
|
drift_fn = noise_ode
|
||||||
|
elif self.model_type == ModelType.SCORE:
|
||||||
|
drift_fn = score_ode
|
||||||
|
else:
|
||||||
|
drift_fn = velocity_ode
|
||||||
|
|
||||||
|
def body_fn(x, t, model, **model_kwargs):
|
||||||
|
model_output = drift_fn(x, t, model, **model_kwargs)
|
||||||
|
assert model_output.shape == x.shape, "Output shape from ODE solver must match input shape"
|
||||||
|
return model_output
|
||||||
|
|
||||||
|
return body_fn
|
||||||
|
|
||||||
|
def get_score(
|
||||||
|
self,
|
||||||
|
):
|
||||||
|
"""member function for obtaining score of
|
||||||
|
x_t = alpha_t * x + sigma_t * eps"""
|
||||||
|
if self.model_type == ModelType.NOISE:
|
||||||
|
score_fn = lambda x, t, model, **kwargs: model(x, t, **kwargs) / - \
|
||||||
|
self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))[0]
|
||||||
|
elif self.model_type == ModelType.SCORE:
|
||||||
|
score_fn = lambda x, t, model, **kwagrs: model(x, t, **kwagrs)
|
||||||
|
elif self.model_type == ModelType.VELOCITY:
|
||||||
|
score_fn = lambda x, t, model, **kwargs: self.path_sampler.get_score_from_velocity(model(x, t, **kwargs), x,
|
||||||
|
t)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
return score_fn
|
||||||
|
|
||||||
|
|
||||||
|
class Sampler:
|
||||||
|
"""Sampler class for the transport model"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
transport,
|
||||||
|
):
|
||||||
|
"""Constructor for a general sampler; supporting different sampling methods
|
||||||
|
Args:
|
||||||
|
- transport: an tranport object specify model prediction & interpolant type
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.transport = transport
|
||||||
|
self.drift = self.transport.get_drift()
|
||||||
|
self.score = self.transport.get_score()
|
||||||
|
|
||||||
|
def __get_sde_diffusion_and_drift(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
diffusion_form="SBDM",
|
||||||
|
diffusion_norm=1.0,
|
||||||
|
):
|
||||||
|
|
||||||
|
def diffusion_fn(x, t):
|
||||||
|
diffusion = self.transport.path_sampler.compute_diffusion(x, t, form=diffusion_form, norm=diffusion_norm)
|
||||||
|
return diffusion
|
||||||
|
|
||||||
|
sde_drift = \
|
||||||
|
lambda x, t, model, **kwargs: \
|
||||||
|
self.drift(x, t, model, **kwargs) + diffusion_fn(x, t) * self.score(x, t, model, **kwargs)
|
||||||
|
|
||||||
|
sde_diffusion = diffusion_fn
|
||||||
|
|
||||||
|
return sde_drift, sde_diffusion
|
||||||
|
|
||||||
|
def __get_last_step(
|
||||||
|
self,
|
||||||
|
sde_drift,
|
||||||
|
*,
|
||||||
|
last_step,
|
||||||
|
last_step_size,
|
||||||
|
):
|
||||||
|
"""Get the last step function of the SDE solver"""
|
||||||
|
|
||||||
|
if last_step is None:
|
||||||
|
last_step_fn = \
|
||||||
|
lambda x, t, model, **model_kwargs: \
|
||||||
|
x
|
||||||
|
elif last_step == "Mean":
|
||||||
|
last_step_fn = \
|
||||||
|
lambda x, t, model, **model_kwargs: \
|
||||||
|
x + sde_drift(x, t, model, **model_kwargs) * last_step_size
|
||||||
|
elif last_step == "Tweedie":
|
||||||
|
alpha = self.transport.path_sampler.compute_alpha_t # simple aliasing; the original name was too long
|
||||||
|
sigma = self.transport.path_sampler.compute_sigma_t
|
||||||
|
last_step_fn = \
|
||||||
|
lambda x, t, model, **model_kwargs: \
|
||||||
|
x / alpha(t)[0][0] + (sigma(t)[0][0] ** 2) / alpha(t)[0][0] * self.score(x, t, model,
|
||||||
|
**model_kwargs)
|
||||||
|
elif last_step == "Euler":
|
||||||
|
last_step_fn = \
|
||||||
|
lambda x, t, model, **model_kwargs: \
|
||||||
|
x + self.drift(x, t, model, **model_kwargs) * last_step_size
|
||||||
|
else:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
return last_step_fn
|
||||||
|
|
||||||
|
def sample_sde(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
sampling_method="Euler",
|
||||||
|
diffusion_form="SBDM",
|
||||||
|
diffusion_norm=1.0,
|
||||||
|
last_step="Mean",
|
||||||
|
last_step_size=0.04,
|
||||||
|
num_steps=250,
|
||||||
|
):
|
||||||
|
"""returns a sampling function with given SDE settings
|
||||||
|
Args:
|
||||||
|
- sampling_method: type of sampler used in solving the SDE; default to be Euler-Maruyama
|
||||||
|
- diffusion_form: function form of diffusion coefficient; default to be matching SBDM
|
||||||
|
- diffusion_norm: function magnitude of diffusion coefficient; default to 1
|
||||||
|
- last_step: type of the last step; default to identity
|
||||||
|
- last_step_size: size of the last step; default to match the stride of 250 steps over [0,1]
|
||||||
|
- num_steps: total integration step of SDE
|
||||||
|
"""
|
||||||
|
|
||||||
|
if last_step is None:
|
||||||
|
last_step_size = 0.0
|
||||||
|
|
||||||
|
sde_drift, sde_diffusion = self.__get_sde_diffusion_and_drift(
|
||||||
|
diffusion_form=diffusion_form,
|
||||||
|
diffusion_norm=diffusion_norm,
|
||||||
|
)
|
||||||
|
|
||||||
|
t0, t1 = self.transport.check_interval(
|
||||||
|
self.transport.train_eps,
|
||||||
|
self.transport.sample_eps,
|
||||||
|
diffusion_form=diffusion_form,
|
||||||
|
sde=True,
|
||||||
|
eval=True,
|
||||||
|
reverse=False,
|
||||||
|
last_step_size=last_step_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
_sde = sde(
|
||||||
|
sde_drift,
|
||||||
|
sde_diffusion,
|
||||||
|
t0=t0,
|
||||||
|
t1=t1,
|
||||||
|
num_steps=num_steps,
|
||||||
|
sampler_type=sampling_method
|
||||||
|
)
|
||||||
|
|
||||||
|
last_step_fn = self.__get_last_step(sde_drift, last_step=last_step, last_step_size=last_step_size)
|
||||||
|
|
||||||
|
def _sample(init, model, **model_kwargs):
|
||||||
|
xs = _sde.sample(init, model, **model_kwargs)
|
||||||
|
ts = th.ones(init.size(0), device=init.device) * t1
|
||||||
|
x = last_step_fn(xs[-1], ts, model, **model_kwargs)
|
||||||
|
xs.append(x)
|
||||||
|
|
||||||
|
assert len(xs) == num_steps, "Samples does not match the number of steps"
|
||||||
|
|
||||||
|
return xs
|
||||||
|
|
||||||
|
return _sample
|
||||||
|
|
||||||
|
def sample_ode(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
sampling_method="dopri5",
|
||||||
|
num_steps=50,
|
||||||
|
atol=1e-6,
|
||||||
|
rtol=1e-3,
|
||||||
|
reverse=False,
|
||||||
|
):
|
||||||
|
"""returns a sampling function with given ODE settings
|
||||||
|
Args:
|
||||||
|
- sampling_method: type of sampler used in solving the ODE; default to be Dopri5
|
||||||
|
- num_steps:
|
||||||
|
- fixed solver (Euler, Heun): the actual number of integration steps performed
|
||||||
|
- adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
|
||||||
|
- atol: absolute error tolerance for the solver
|
||||||
|
- rtol: relative error tolerance for the solver
|
||||||
|
- reverse: whether solving the ODE in reverse (data to noise); default to False
|
||||||
|
"""
|
||||||
|
if reverse:
|
||||||
|
drift = lambda x, t, model, **kwargs: self.drift(x, th.ones_like(t) * (1 - t), model, **kwargs)
|
||||||
|
else:
|
||||||
|
drift = self.drift
|
||||||
|
|
||||||
|
t0, t1 = self.transport.check_interval(
|
||||||
|
self.transport.train_eps,
|
||||||
|
self.transport.sample_eps,
|
||||||
|
sde=False,
|
||||||
|
eval=True,
|
||||||
|
reverse=reverse,
|
||||||
|
last_step_size=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
_ode = ode(
|
||||||
|
drift=drift,
|
||||||
|
t0=t0,
|
||||||
|
t1=t1,
|
||||||
|
sampler_type=sampling_method,
|
||||||
|
num_steps=num_steps,
|
||||||
|
atol=atol,
|
||||||
|
rtol=rtol,
|
||||||
|
)
|
||||||
|
|
||||||
|
return _ode.sample
|
||||||
|
|
||||||
|
def sample_ode_intermediate(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
sampling_method="dopri5",
|
||||||
|
num_steps=50,
|
||||||
|
atol=1e-6,
|
||||||
|
rtol=1e-3,
|
||||||
|
t=0.5,
|
||||||
|
reverse=False,
|
||||||
|
):
|
||||||
|
"""returns a sampling function with given ODE settings
|
||||||
|
Args:
|
||||||
|
- sampling_method: type of sampler used in solving the ODE; default to be Dopri5
|
||||||
|
- num_steps:
|
||||||
|
- fixed solver (Euler, Heun): the actual number of integration steps performed
|
||||||
|
- adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
|
||||||
|
- atol: absolute error tolerance for the solver
|
||||||
|
- rtol: relative error tolerance for the solver
|
||||||
|
- reverse: whether solving the ODE in reverse (data to noise); default to False
|
||||||
|
"""
|
||||||
|
if reverse:
|
||||||
|
drift = lambda x, t, model, **kwargs: self.drift(x, th.ones_like(t) * (1 - t), model, **kwargs)
|
||||||
|
else:
|
||||||
|
drift = self.drift
|
||||||
|
|
||||||
|
t0, t1 = self.transport.check_interval(
|
||||||
|
self.transport.train_eps,
|
||||||
|
self.transport.sample_eps,
|
||||||
|
sde=False,
|
||||||
|
eval=True,
|
||||||
|
reverse=reverse,
|
||||||
|
last_step_size=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
_ode = ode(
|
||||||
|
drift=drift,
|
||||||
|
t0=t,
|
||||||
|
t1=t1,
|
||||||
|
sampler_type=sampling_method,
|
||||||
|
num_steps=num_steps,
|
||||||
|
atol=atol,
|
||||||
|
rtol=rtol,
|
||||||
|
)
|
||||||
|
|
||||||
|
return _ode.sample
|
||||||
|
|
||||||
|
def sample_ode_likelihood(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
sampling_method="dopri5",
|
||||||
|
num_steps=50,
|
||||||
|
atol=1e-6,
|
||||||
|
rtol=1e-3,
|
||||||
|
):
|
||||||
|
|
||||||
|
"""returns a sampling function for calculating likelihood with given ODE settings
|
||||||
|
Args:
|
||||||
|
- sampling_method: type of sampler used in solving the ODE; default to be Dopri5
|
||||||
|
- num_steps:
|
||||||
|
- fixed solver (Euler, Heun): the actual number of integration steps performed
|
||||||
|
- adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
|
||||||
|
- atol: absolute error tolerance for the solver
|
||||||
|
- rtol: relative error tolerance for the solver
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _likelihood_drift(x, t, model, **model_kwargs):
|
||||||
|
x, _ = x
|
||||||
|
eps = th.randint(2, x.size(), dtype=th.float, device=x.device) * 2 - 1
|
||||||
|
t = th.ones_like(t) * (1 - t)
|
||||||
|
with th.enable_grad():
|
||||||
|
x.requires_grad = True
|
||||||
|
grad = th.autograd.grad(th.sum(self.drift(x, t, model, **model_kwargs) * eps), x)[0]
|
||||||
|
logp_grad = th.sum(grad * eps, dim=tuple(range(1, len(x.size()))))
|
||||||
|
drift = self.drift(x, t, model, **model_kwargs)
|
||||||
|
return (-drift, logp_grad)
|
||||||
|
|
||||||
|
t0, t1 = self.transport.check_interval(
|
||||||
|
self.transport.train_eps,
|
||||||
|
self.transport.sample_eps,
|
||||||
|
sde=False,
|
||||||
|
eval=True,
|
||||||
|
reverse=False,
|
||||||
|
last_step_size=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
_ode = ode(
|
||||||
|
drift=_likelihood_drift,
|
||||||
|
t0=t0,
|
||||||
|
t1=t1,
|
||||||
|
sampler_type=sampling_method,
|
||||||
|
num_steps=num_steps,
|
||||||
|
atol=atol,
|
||||||
|
rtol=rtol,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _sample_fn(x, model, **model_kwargs):
|
||||||
|
init_logp = th.zeros(x.size(0)).to(x)
|
||||||
|
input = (x, init_logp)
|
||||||
|
drift, delta_logp = _ode.sample(input, model, **model_kwargs)
|
||||||
|
drift, delta_logp = drift[-1], delta_logp[-1]
|
||||||
|
prior_logp = self.transport.prior_logp(drift)
|
||||||
|
logp = prior_logp - delta_logp
|
||||||
|
return logp, drift
|
||||||
|
|
||||||
|
return _sample_fn
|
||||||
54
hy3dshape/hy3dshape/models/diffusion/transport/utils.py
Normal file
54
hy3dshape/hy3dshape/models/diffusion/transport/utils.py
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
# This file includes code derived from the SiT project (https://github.com/willisma/SiT),
|
||||||
|
# which is licensed under the MIT License.
|
||||||
|
#
|
||||||
|
# MIT License
|
||||||
|
#
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
# of this software and associated documentation files (the "Software"), to deal
|
||||||
|
# in the Software without restriction, including without limitation the rights
|
||||||
|
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
# copies of the Software, and to permit persons to whom the Software is
|
||||||
|
# furnished to do so, subject to the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be included in all
|
||||||
|
# copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
# SOFTWARE.
|
||||||
|
|
||||||
|
import torch as th
|
||||||
|
|
||||||
|
class EasyDict:
|
||||||
|
|
||||||
|
def __init__(self, sub_dict):
|
||||||
|
for k, v in sub_dict.items():
|
||||||
|
setattr(self, k, v)
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
return getattr(self, key)
|
||||||
|
|
||||||
|
def mean_flat(x):
|
||||||
|
"""
|
||||||
|
Take the mean over all non-batch dimensions.
|
||||||
|
"""
|
||||||
|
return th.mean(x, dim=list(range(1, len(x.size()))))
|
||||||
|
|
||||||
|
def log_state(state):
|
||||||
|
result = []
|
||||||
|
|
||||||
|
sorted_state = dict(sorted(state.items()))
|
||||||
|
for key, value in sorted_state.items():
|
||||||
|
# Check if the value is an instance of a class
|
||||||
|
if "<object" in str(value) or "object at" in str(value):
|
||||||
|
result.append(f"{key}: [{value.__class__.__name__}]")
|
||||||
|
else:
|
||||||
|
result.append(f"{key}: {value}")
|
||||||
|
|
||||||
|
return '\n'.join(result)
|
||||||
792
hy3dshape/hy3dshape/pipelines.py
Normal file
792
hy3dshape/hy3dshape/pipelines.py
Normal file
@ -0,0 +1,792 @@
|
|||||||
|
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
||||||
|
# except for the third-party components listed below.
|
||||||
|
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
||||||
|
# in the repsective licenses of these third-party components.
|
||||||
|
# Users must comply with all terms and conditions of original licenses of these third-party
|
||||||
|
# components and must ensure that the usage of the third party components adheres to
|
||||||
|
# all relevant laws and regulations.
|
||||||
|
|
||||||
|
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
||||||
|
# their software and algorithms, including trained model weights, parameters (including
|
||||||
|
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
||||||
|
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
||||||
|
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import importlib
|
||||||
|
import inspect
|
||||||
|
import os
|
||||||
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import trimesh
|
||||||
|
import yaml
|
||||||
|
from PIL import Image
|
||||||
|
from diffusers.utils.torch_utils import randn_tensor
|
||||||
|
from diffusers.utils.import_utils import is_accelerate_version, is_accelerate_available
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from .models.autoencoders import ShapeVAE
|
||||||
|
from .models.autoencoders import SurfaceExtractors
|
||||||
|
from .utils import logger, synchronize_timer, smart_load_model
|
||||||
|
|
||||||
|
from comfy.utils import ProgressBar, load_torch_file
|
||||||
|
import comfy.model_management as mm
|
||||||
|
|
||||||
|
|
||||||
|
def retrieve_timesteps(
|
||||||
|
scheduler,
|
||||||
|
num_inference_steps: Optional[int] = None,
|
||||||
|
device: Optional[Union[str, torch.device]] = None,
|
||||||
|
timesteps: Optional[List[int]] = None,
|
||||||
|
sigmas: Optional[List[float]] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||||
|
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scheduler (`SchedulerMixin`):
|
||||||
|
The scheduler to get timesteps from.
|
||||||
|
num_inference_steps (`int`):
|
||||||
|
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||||
|
must be `None`.
|
||||||
|
device (`str` or `torch.device`, *optional*):
|
||||||
|
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||||
|
timesteps (`List[int]`, *optional*):
|
||||||
|
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||||
|
`num_inference_steps` and `sigmas` must be `None`.
|
||||||
|
sigmas (`List[float]`, *optional*):
|
||||||
|
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||||
|
`num_inference_steps` and `timesteps` must be `None`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||||
|
second element is the number of inference steps.
|
||||||
|
"""
|
||||||
|
if timesteps is not None and sigmas is not None:
|
||||||
|
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||||
|
if timesteps is not None:
|
||||||
|
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||||
|
if not accepts_timesteps:
|
||||||
|
raise ValueError(
|
||||||
|
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||||
|
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||||
|
)
|
||||||
|
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||||
|
timesteps = scheduler.timesteps
|
||||||
|
num_inference_steps = len(timesteps)
|
||||||
|
elif sigmas is not None:
|
||||||
|
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||||
|
if not accept_sigmas:
|
||||||
|
raise ValueError(
|
||||||
|
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||||
|
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||||
|
)
|
||||||
|
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||||
|
timesteps = scheduler.timesteps
|
||||||
|
num_inference_steps = len(timesteps)
|
||||||
|
else:
|
||||||
|
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||||
|
timesteps = scheduler.timesteps
|
||||||
|
return timesteps, num_inference_steps
|
||||||
|
|
||||||
|
|
||||||
|
@synchronize_timer('Export to trimesh')
|
||||||
|
def export_to_trimesh(mesh_output):
|
||||||
|
if isinstance(mesh_output, list):
|
||||||
|
outputs = []
|
||||||
|
for mesh in mesh_output:
|
||||||
|
if mesh is None:
|
||||||
|
outputs.append(None)
|
||||||
|
else:
|
||||||
|
mesh.mesh_f = mesh.mesh_f[:, ::-1]
|
||||||
|
mesh_output = trimesh.Trimesh(mesh.mesh_v, mesh.mesh_f)
|
||||||
|
outputs.append(mesh_output)
|
||||||
|
return outputs
|
||||||
|
else:
|
||||||
|
mesh_output.mesh_f = mesh_output.mesh_f[:, ::-1]
|
||||||
|
mesh_output = trimesh.Trimesh(mesh_output.mesh_v, mesh_output.mesh_f)
|
||||||
|
return mesh_output
|
||||||
|
|
||||||
|
|
||||||
|
def get_obj_from_str(string, reload=False):
|
||||||
|
module, cls = string.rsplit(".", 1)
|
||||||
|
if reload:
|
||||||
|
module_imp = importlib.import_module(module)
|
||||||
|
importlib.reload(module_imp)
|
||||||
|
try:
|
||||||
|
obj = getattr(importlib.import_module(module, package=os.path.basename(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))), cls)
|
||||||
|
except:
|
||||||
|
obj = getattr(importlib.import_module(module, package=os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath( __file__ ))))), cls)
|
||||||
|
return obj
|
||||||
|
|
||||||
|
|
||||||
|
def instantiate_from_config(config, **kwargs):
|
||||||
|
if "target" not in config:
|
||||||
|
raise KeyError("Expected key `target` to instantiate.")
|
||||||
|
cls = get_obj_from_str(config["target"])
|
||||||
|
params = config.get("params", dict())
|
||||||
|
kwargs.update(params)
|
||||||
|
instance = cls(**kwargs)
|
||||||
|
return instance
|
||||||
|
|
||||||
|
|
||||||
|
class Hunyuan3DDiTPipeline:
|
||||||
|
model_cpu_offload_seq = "conditioner->model->vae"
|
||||||
|
_exclude_from_cpu_offload = []
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@synchronize_timer('Hunyuan3DDiTPipeline Model Loading')
|
||||||
|
def from_single_file(
|
||||||
|
cls,
|
||||||
|
ckpt_path,
|
||||||
|
config_path,
|
||||||
|
device='cuda',
|
||||||
|
dtype=torch.float16,
|
||||||
|
use_safetensors=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
# load config
|
||||||
|
with open(config_path, 'r') as f:
|
||||||
|
config = yaml.safe_load(f)
|
||||||
|
|
||||||
|
# # load ckpt
|
||||||
|
# if use_safetensors:
|
||||||
|
# ckpt_path = ckpt_path.replace('.ckpt', '.safetensors')
|
||||||
|
# if not os.path.exists(ckpt_path):
|
||||||
|
# raise FileNotFoundError(f"Model file {ckpt_path} not found")
|
||||||
|
# logger.info(f"Loading model from {ckpt_path}")
|
||||||
|
|
||||||
|
# if use_safetensors:
|
||||||
|
# # parse safetensors
|
||||||
|
# import safetensors.torch
|
||||||
|
# safetensors_ckpt = safetensors.torch.load_file(ckpt_path, device='cpu')
|
||||||
|
# ckpt = {}
|
||||||
|
# for key, value in safetensors_ckpt.items():
|
||||||
|
# model_name = key.split('.')[0]
|
||||||
|
# new_key = key[len(model_name) + 1:]
|
||||||
|
# if model_name not in ckpt:
|
||||||
|
# ckpt[model_name] = {}
|
||||||
|
# ckpt[model_name][new_key] = value
|
||||||
|
# else:
|
||||||
|
# ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True)
|
||||||
|
|
||||||
|
ckpt = load_torch_file(ckpt_path)
|
||||||
|
# load model
|
||||||
|
model = instantiate_from_config(config['model'])
|
||||||
|
model.load_state_dict(ckpt['model'])
|
||||||
|
vae = instantiate_from_config(config['vae'])
|
||||||
|
vae.load_state_dict(ckpt['vae'], strict=False)
|
||||||
|
conditioner = instantiate_from_config(config['conditioner'])
|
||||||
|
if 'conditioner' in ckpt:
|
||||||
|
conditioner.load_state_dict(ckpt['conditioner'])
|
||||||
|
image_processor = instantiate_from_config(config['image_processor'])
|
||||||
|
scheduler = instantiate_from_config(config['scheduler'])
|
||||||
|
|
||||||
|
model_kwargs = dict(
|
||||||
|
vae=vae,
|
||||||
|
model=model,
|
||||||
|
scheduler=scheduler,
|
||||||
|
conditioner=conditioner,
|
||||||
|
image_processor=image_processor,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
model_kwargs.update(kwargs)
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
**model_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(
|
||||||
|
cls,
|
||||||
|
model_path,
|
||||||
|
device='cuda',
|
||||||
|
dtype=torch.float16,
|
||||||
|
use_safetensors=False,
|
||||||
|
variant='fp16',
|
||||||
|
subfolder='hunyuan3d-dit-v2-1',
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
kwargs['from_pretrained_kwargs'] = dict(
|
||||||
|
model_path=model_path,
|
||||||
|
subfolder=subfolder,
|
||||||
|
use_safetensors=use_safetensors,
|
||||||
|
variant=variant,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
config_path, ckpt_path = smart_load_model(
|
||||||
|
model_path,
|
||||||
|
subfolder=subfolder,
|
||||||
|
use_safetensors=use_safetensors,
|
||||||
|
variant=variant
|
||||||
|
)
|
||||||
|
return cls.from_single_file(
|
||||||
|
ckpt_path,
|
||||||
|
config_path,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
use_safetensors=use_safetensors,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vae,
|
||||||
|
model,
|
||||||
|
scheduler,
|
||||||
|
conditioner,
|
||||||
|
image_processor,
|
||||||
|
device='cuda',
|
||||||
|
dtype=torch.float16,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
self.vae = vae
|
||||||
|
self.model = model
|
||||||
|
self.scheduler = scheduler
|
||||||
|
self.conditioner = conditioner
|
||||||
|
self.image_processor = image_processor
|
||||||
|
self.kwargs = kwargs
|
||||||
|
self.to(device, dtype)
|
||||||
|
|
||||||
|
def compile(self):
|
||||||
|
self.vae = torch.compile(self.vae)
|
||||||
|
self.model = torch.compile(self.model)
|
||||||
|
self.conditioner = torch.compile(self.conditioner)
|
||||||
|
|
||||||
|
def enable_flashvdm(
|
||||||
|
self,
|
||||||
|
enabled: bool = True,
|
||||||
|
adaptive_kv_selection=True,
|
||||||
|
topk_mode='mean',
|
||||||
|
mc_algo='mc',
|
||||||
|
replace_vae=True,
|
||||||
|
):
|
||||||
|
if enabled:
|
||||||
|
model_path = self.kwargs['from_pretrained_kwargs']['model_path']
|
||||||
|
turbo_vae_mapping = {
|
||||||
|
'Hunyuan3D-2': ('tencent/Hunyuan3D-2', 'hunyuan3d-vae-v2-0-turbo'),
|
||||||
|
'Hunyuan3D-2mv': ('tencent/Hunyuan3D-2', 'hunyuan3d-vae-v2-0-turbo'),
|
||||||
|
'Hunyuan3D-2mini': ('tencent/Hunyuan3D-2mini', 'hunyuan3d-vae-v2-mini-turbo'),
|
||||||
|
}
|
||||||
|
model_name = model_path.split('/')[-1]
|
||||||
|
if replace_vae and model_name in turbo_vae_mapping:
|
||||||
|
model_path, subfolder = turbo_vae_mapping[model_name]
|
||||||
|
self.vae = ShapeVAE.from_pretrained(
|
||||||
|
model_path, subfolder=subfolder,
|
||||||
|
use_safetensors=self.kwargs['from_pretrained_kwargs']['use_safetensors'],
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
self.vae.enable_flashvdm_decoder(
|
||||||
|
enabled=enabled,
|
||||||
|
adaptive_kv_selection=adaptive_kv_selection,
|
||||||
|
topk_mode=topk_mode,
|
||||||
|
mc_algo=mc_algo
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
model_path = self.kwargs['from_pretrained_kwargs']['model_path']
|
||||||
|
vae_mapping = {
|
||||||
|
'Hunyuan3D-2': ('tencent/Hunyuan3D-2', 'hunyuan3d-vae-v2-0'),
|
||||||
|
'Hunyuan3D-2mv': ('tencent/Hunyuan3D-2', 'hunyuan3d-vae-v2-0'),
|
||||||
|
'Hunyuan3D-2mini': ('tencent/Hunyuan3D-2mini', 'hunyuan3d-vae-v2-mini'),
|
||||||
|
}
|
||||||
|
model_name = model_path.split('/')[-1]
|
||||||
|
if model_name in vae_mapping:
|
||||||
|
model_path, subfolder = vae_mapping[model_name]
|
||||||
|
self.vae = ShapeVAE.from_pretrained(model_path, subfolder=subfolder)
|
||||||
|
self.vae.enable_flashvdm_decoder(enabled=False)
|
||||||
|
|
||||||
|
def to(self, device=None, dtype=None):
|
||||||
|
if dtype is not None:
|
||||||
|
self.dtype = dtype
|
||||||
|
self.vae.to(dtype=dtype)
|
||||||
|
self.model.to(dtype=dtype)
|
||||||
|
self.conditioner.to(dtype=dtype)
|
||||||
|
if device is not None:
|
||||||
|
self.device = torch.device(device)
|
||||||
|
self.vae.to(device)
|
||||||
|
self.model.to(device)
|
||||||
|
self.conditioner.to(device)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _execution_device(self):
|
||||||
|
r"""
|
||||||
|
Returns the device on which the pipeline's models will be executed. After calling
|
||||||
|
[`~DiffusionPipeline.enable_sequential_cpu_offload`] the execution device can only be inferred from
|
||||||
|
Accelerate's module hooks.
|
||||||
|
"""
|
||||||
|
for name, model in self.components.items():
|
||||||
|
if not isinstance(model, torch.nn.Module) or name in self._exclude_from_cpu_offload:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not hasattr(model, "_hf_hook"):
|
||||||
|
return self.device
|
||||||
|
for module in model.modules():
|
||||||
|
if (
|
||||||
|
hasattr(module, "_hf_hook")
|
||||||
|
and hasattr(module._hf_hook, "execution_device")
|
||||||
|
and module._hf_hook.execution_device is not None
|
||||||
|
):
|
||||||
|
return torch.device(module._hf_hook.execution_device)
|
||||||
|
return self.device
|
||||||
|
|
||||||
|
def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
|
||||||
|
r"""
|
||||||
|
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
|
||||||
|
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
|
||||||
|
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
|
||||||
|
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
gpu_id (`int`, *optional*):
|
||||||
|
The ID of the accelerator that shall be used in inference. If not specified, it will default to 0.
|
||||||
|
device (`torch.Device` or `str`, *optional*, defaults to "cuda"):
|
||||||
|
The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
|
||||||
|
default to "cuda".
|
||||||
|
"""
|
||||||
|
if self.model_cpu_offload_seq is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Model CPU offload cannot be enabled because no `model_cpu_offload_seq` class attribute is set."
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
|
||||||
|
from accelerate import cpu_offload_with_hook
|
||||||
|
else:
|
||||||
|
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
|
||||||
|
|
||||||
|
torch_device = torch.device(device)
|
||||||
|
device_index = torch_device.index
|
||||||
|
|
||||||
|
if gpu_id is not None and device_index is not None:
|
||||||
|
raise ValueError(
|
||||||
|
f"You have passed both `gpu_id`={gpu_id} and an index as part of the passed device `device`={device}"
|
||||||
|
f"Cannot pass both. Please make sure to either not define `gpu_id` or not pass the index as part of "
|
||||||
|
f"the device: `device`={torch_device.type}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# _offload_gpu_id should be set to passed gpu_id (or id in passed `device`)
|
||||||
|
# or default to previously set id or default to 0
|
||||||
|
self._offload_gpu_id = gpu_id or torch_device.index or getattr(self, "_offload_gpu_id", 0)
|
||||||
|
|
||||||
|
device_type = torch_device.type
|
||||||
|
device = torch.device(f"{device_type}:{self._offload_gpu_id}")
|
||||||
|
|
||||||
|
if self.device.type != "cpu":
|
||||||
|
self.to("cpu")
|
||||||
|
device_mod = getattr(torch, self.device.type, None)
|
||||||
|
if hasattr(device_mod, "empty_cache") and device_mod.is_available():
|
||||||
|
device_mod.empty_cache()
|
||||||
|
# otherwise we don't see the memory savings (but they probably exist)
|
||||||
|
|
||||||
|
all_model_components = {k: v for k, v in self.components.items() if isinstance(v, torch.nn.Module)}
|
||||||
|
|
||||||
|
self._all_hooks = []
|
||||||
|
hook = None
|
||||||
|
for model_str in self.model_cpu_offload_seq.split("->"):
|
||||||
|
model = all_model_components.pop(model_str, None)
|
||||||
|
if not isinstance(model, torch.nn.Module):
|
||||||
|
continue
|
||||||
|
|
||||||
|
_, hook = cpu_offload_with_hook(model, device, prev_module_hook=hook)
|
||||||
|
self._all_hooks.append(hook)
|
||||||
|
|
||||||
|
# CPU offload models that are not in the seq chain unless they are explicitly excluded
|
||||||
|
# these models will stay on CPU until maybe_free_model_hooks is called
|
||||||
|
# some models cannot be in the seq chain because they are iteratively called,
|
||||||
|
# such as controlnet
|
||||||
|
for name, model in all_model_components.items():
|
||||||
|
if not isinstance(model, torch.nn.Module):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if name in self._exclude_from_cpu_offload:
|
||||||
|
model.to(device)
|
||||||
|
else:
|
||||||
|
_, hook = cpu_offload_with_hook(model, device)
|
||||||
|
self._all_hooks.append(hook)
|
||||||
|
|
||||||
|
def maybe_free_model_hooks(self):
|
||||||
|
r"""
|
||||||
|
Function that offloads all components, removes all model hooks that were added when using
|
||||||
|
`enable_model_cpu_offload` and then applies them again. In case the model has not been offloaded this function
|
||||||
|
is a no-op. Make sure to add this function to the end of the `__call__` function of your pipeline so that it
|
||||||
|
functions correctly when applying enable_model_cpu_offload.
|
||||||
|
"""
|
||||||
|
if not hasattr(self, "_all_hooks") or len(self._all_hooks) == 0:
|
||||||
|
# `enable_model_cpu_offload` has not be called, so silently do nothing
|
||||||
|
return
|
||||||
|
|
||||||
|
for hook in self._all_hooks:
|
||||||
|
# offload model and remove hook from model
|
||||||
|
hook.offload()
|
||||||
|
hook.remove()
|
||||||
|
|
||||||
|
# make sure the model is in the same state as before calling it
|
||||||
|
self.enable_model_cpu_offload()
|
||||||
|
|
||||||
|
@synchronize_timer('Encode cond')
|
||||||
|
def encode_cond(self, image, additional_cond_inputs, do_classifier_free_guidance, dual_guidance):
|
||||||
|
bsz = image.shape[0]
|
||||||
|
cond = self.conditioner(image=image, **additional_cond_inputs)
|
||||||
|
|
||||||
|
if do_classifier_free_guidance:
|
||||||
|
un_cond = self.conditioner.unconditional_embedding(bsz, **additional_cond_inputs)
|
||||||
|
|
||||||
|
if dual_guidance:
|
||||||
|
un_cond_drop_main = copy.deepcopy(un_cond)
|
||||||
|
un_cond_drop_main['additional'] = cond['additional']
|
||||||
|
|
||||||
|
def cat_recursive(a, b, c):
|
||||||
|
if isinstance(a, torch.Tensor):
|
||||||
|
return torch.cat([a, b, c], dim=0).to(self.dtype)
|
||||||
|
out = {}
|
||||||
|
for k in a.keys():
|
||||||
|
out[k] = cat_recursive(a[k], b[k], c[k])
|
||||||
|
return out
|
||||||
|
|
||||||
|
cond = cat_recursive(cond, un_cond_drop_main, un_cond)
|
||||||
|
else:
|
||||||
|
def cat_recursive(a, b):
|
||||||
|
if isinstance(a, torch.Tensor):
|
||||||
|
return torch.cat([a, b], dim=0).to(self.dtype)
|
||||||
|
out = {}
|
||||||
|
for k in a.keys():
|
||||||
|
out[k] = cat_recursive(a[k], b[k])
|
||||||
|
return out
|
||||||
|
|
||||||
|
cond = cat_recursive(cond, un_cond)
|
||||||
|
return cond
|
||||||
|
|
||||||
|
def prepare_extra_step_kwargs(self, generator, eta):
|
||||||
|
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||||
|
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||||
|
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||||
|
# and should be between [0, 1]
|
||||||
|
|
||||||
|
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||||
|
extra_step_kwargs = {}
|
||||||
|
if accepts_eta:
|
||||||
|
extra_step_kwargs["eta"] = eta
|
||||||
|
|
||||||
|
# check if the scheduler accepts generator
|
||||||
|
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||||
|
if accepts_generator:
|
||||||
|
extra_step_kwargs["generator"] = generator
|
||||||
|
return extra_step_kwargs
|
||||||
|
|
||||||
|
def prepare_latents(self, batch_size, dtype, device, generator, latents=None):
|
||||||
|
shape = (batch_size, *self.vae.latent_shape)
|
||||||
|
if isinstance(generator, list) and len(generator) != batch_size:
|
||||||
|
raise ValueError(
|
||||||
|
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||||
|
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||||
|
)
|
||||||
|
|
||||||
|
if latents is None:
|
||||||
|
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||||
|
else:
|
||||||
|
latents = latents.to(device)
|
||||||
|
|
||||||
|
# scale the initial noise by the standard deviation required by the scheduler
|
||||||
|
latents = latents * getattr(self.scheduler, 'init_noise_sigma', 1.0)
|
||||||
|
return latents
|
||||||
|
|
||||||
|
def prepare_image(self, image, mask=None) -> dict:
|
||||||
|
if isinstance(image, torch.Tensor) and isinstance(mask, torch.Tensor):
|
||||||
|
outputs = {
|
||||||
|
'image': image,
|
||||||
|
'mask': mask
|
||||||
|
}
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
if isinstance(image, str) and not os.path.exists(image):
|
||||||
|
raise FileNotFoundError(f"Couldn't find image at path {image}")
|
||||||
|
|
||||||
|
if not isinstance(image, list):
|
||||||
|
image = [image]
|
||||||
|
|
||||||
|
outputs = []
|
||||||
|
for img in image:
|
||||||
|
output = self.image_processor(img)
|
||||||
|
outputs.append(output)
|
||||||
|
|
||||||
|
cond_input = {k: [] for k in outputs[0].keys()}
|
||||||
|
for output in outputs:
|
||||||
|
for key, value in output.items():
|
||||||
|
cond_input[key].append(value)
|
||||||
|
for key, value in cond_input.items():
|
||||||
|
if isinstance(value[0], torch.Tensor):
|
||||||
|
cond_input[key] = torch.cat(value, dim=0)
|
||||||
|
|
||||||
|
return cond_input
|
||||||
|
|
||||||
|
def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
|
||||||
|
"""
|
||||||
|
See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timesteps (`torch.Tensor`):
|
||||||
|
generate embedding vectors at these timesteps
|
||||||
|
embedding_dim (`int`, *optional*, defaults to 512):
|
||||||
|
dimension of the embeddings to generate
|
||||||
|
dtype:
|
||||||
|
data type of the generated embeddings
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
|
||||||
|
"""
|
||||||
|
assert len(w.shape) == 1
|
||||||
|
w = w * 1000.0
|
||||||
|
|
||||||
|
half_dim = embedding_dim // 2
|
||||||
|
emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
|
||||||
|
emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
|
||||||
|
emb = w.to(dtype)[:, None] * emb[None, :]
|
||||||
|
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
||||||
|
if embedding_dim % 2 == 1: # zero pad
|
||||||
|
emb = torch.nn.functional.pad(emb, (0, 1))
|
||||||
|
assert emb.shape == (w.shape[0], embedding_dim)
|
||||||
|
return emb
|
||||||
|
|
||||||
|
def set_surface_extractor(self, mc_algo):
|
||||||
|
if mc_algo is None:
|
||||||
|
return
|
||||||
|
logger.info('The parameters `mc_algo` is deprecated, and will be removed in future versions.\n'
|
||||||
|
'Please use: \n'
|
||||||
|
'from hy3dshape.models.autoencoders import SurfaceExtractors\n'
|
||||||
|
'pipeline.vae.surface_extractor = SurfaceExtractors[mc_algo]() instead\n')
|
||||||
|
if mc_algo not in SurfaceExtractors.keys():
|
||||||
|
raise ValueError(f"Unknown mc_algo {mc_algo}")
|
||||||
|
self.vae.surface_extractor = SurfaceExtractors[mc_algo]()
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
image: Union[str, List[str], Image.Image] = None,
|
||||||
|
num_inference_steps: int = 50,
|
||||||
|
timesteps: List[int] = None,
|
||||||
|
sigmas: List[float] = None,
|
||||||
|
eta: float = 0.0,
|
||||||
|
guidance_scale: float = 7.5,
|
||||||
|
dual_guidance_scale: float = 10.5,
|
||||||
|
dual_guidance: bool = True,
|
||||||
|
generator=None,
|
||||||
|
box_v=1.01,
|
||||||
|
octree_resolution=384,
|
||||||
|
mc_level=-1 / 512,
|
||||||
|
num_chunks=8000,
|
||||||
|
mc_algo=None,
|
||||||
|
output_type: Optional[str] = "trimesh",
|
||||||
|
enable_pbar=True,
|
||||||
|
**kwargs,
|
||||||
|
) -> List[List[trimesh.Trimesh]]:
|
||||||
|
callback = kwargs.pop("callback", None)
|
||||||
|
callback_steps = kwargs.pop("callback_steps", None)
|
||||||
|
|
||||||
|
self.set_surface_extractor(mc_algo)
|
||||||
|
|
||||||
|
device = self.device
|
||||||
|
dtype = self.dtype
|
||||||
|
do_classifier_free_guidance = guidance_scale >= 0 and \
|
||||||
|
getattr(self.model, 'guidance_cond_proj_dim', None) is None
|
||||||
|
dual_guidance = dual_guidance_scale >= 0 and dual_guidance
|
||||||
|
|
||||||
|
if isinstance(image, torch.Tensor):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
cond_inputs = self.prepare_image(image)
|
||||||
|
image = cond_inputs.pop('image')
|
||||||
|
|
||||||
|
cond = self.encode_cond(
|
||||||
|
image=image,
|
||||||
|
additional_cond_inputs=cond_inputs,
|
||||||
|
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||||
|
dual_guidance=False,
|
||||||
|
)
|
||||||
|
batch_size = image.shape[0]
|
||||||
|
|
||||||
|
t_dtype = torch.long
|
||||||
|
timesteps, num_inference_steps = retrieve_timesteps(
|
||||||
|
self.scheduler, num_inference_steps, device, timesteps, sigmas)
|
||||||
|
|
||||||
|
latents = self.prepare_latents(batch_size, dtype, device, generator)
|
||||||
|
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||||
|
|
||||||
|
guidance_cond = None
|
||||||
|
if getattr(self.model, 'guidance_cond_proj_dim', None) is not None:
|
||||||
|
logger.info('Using lcm guidance scale')
|
||||||
|
guidance_scale_tensor = torch.tensor(guidance_scale - 1).repeat(batch_size)
|
||||||
|
guidance_cond = self.get_guidance_scale_embedding(
|
||||||
|
guidance_scale_tensor, embedding_dim=self.model.guidance_cond_proj_dim
|
||||||
|
).to(device=device, dtype=latents.dtype)
|
||||||
|
with synchronize_timer('Diffusion Sampling'):
|
||||||
|
for i, t in enumerate(tqdm(timesteps, disable=not enable_pbar, desc="Diffusion Sampling:", leave=False)):
|
||||||
|
# expand the latents if we are doing classifier free guidance
|
||||||
|
if do_classifier_free_guidance:
|
||||||
|
latent_model_input = torch.cat([latents] * (3 if dual_guidance else 2))
|
||||||
|
else:
|
||||||
|
latent_model_input = latents
|
||||||
|
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||||
|
|
||||||
|
# predict the noise residual
|
||||||
|
timestep_tensor = torch.tensor([t], dtype=t_dtype, device=device)
|
||||||
|
timestep_tensor = timestep_tensor.expand(latent_model_input.shape[0])
|
||||||
|
noise_pred = self.model(latent_model_input, timestep_tensor, cond, guidance_cond=guidance_cond)
|
||||||
|
|
||||||
|
# no drop, drop clip, all drop
|
||||||
|
if do_classifier_free_guidance:
|
||||||
|
if dual_guidance:
|
||||||
|
noise_pred_clip, noise_pred_dino, noise_pred_uncond = noise_pred.chunk(3)
|
||||||
|
noise_pred = (
|
||||||
|
noise_pred_uncond
|
||||||
|
+ guidance_scale * (noise_pred_clip - noise_pred_dino)
|
||||||
|
+ dual_guidance_scale * (noise_pred_dino - noise_pred_uncond)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
|
||||||
|
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
||||||
|
|
||||||
|
# compute the previous noisy sample x_t -> x_t-1
|
||||||
|
outputs = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)
|
||||||
|
latents = outputs.prev_sample
|
||||||
|
|
||||||
|
if callback is not None and i % callback_steps == 0:
|
||||||
|
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||||
|
callback(step_idx, t, outputs)
|
||||||
|
|
||||||
|
return self._export(
|
||||||
|
latents,
|
||||||
|
output_type,
|
||||||
|
box_v, mc_level, num_chunks, octree_resolution, mc_algo,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _export(
|
||||||
|
self,
|
||||||
|
latents,
|
||||||
|
output_type='trimesh',
|
||||||
|
box_v=1.01,
|
||||||
|
mc_level=0.0,
|
||||||
|
num_chunks=20000,
|
||||||
|
octree_resolution=256,
|
||||||
|
mc_algo='mc',
|
||||||
|
enable_pbar=True
|
||||||
|
):
|
||||||
|
if not output_type == "latent":
|
||||||
|
latents = 1. / self.vae.scale_factor * latents
|
||||||
|
latents = self.vae(latents)
|
||||||
|
outputs = self.vae.latents2mesh(
|
||||||
|
latents,
|
||||||
|
bounds=box_v,
|
||||||
|
mc_level=mc_level,
|
||||||
|
num_chunks=num_chunks,
|
||||||
|
octree_resolution=octree_resolution,
|
||||||
|
mc_algo=mc_algo,
|
||||||
|
enable_pbar=enable_pbar,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
outputs = latents
|
||||||
|
|
||||||
|
if output_type == 'trimesh':
|
||||||
|
outputs = export_to_trimesh(outputs)
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
class Hunyuan3DDiTFlowMatchingPipeline(Hunyuan3DDiTPipeline):
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
image: Union[str, List[str], Image.Image, dict, List[dict], torch.Tensor] = None,
|
||||||
|
num_inference_steps: int = 50,
|
||||||
|
timesteps: List[int] = None,
|
||||||
|
sigmas: List[float] = None,
|
||||||
|
eta: float = 0.0,
|
||||||
|
guidance_scale: float = 5.0,
|
||||||
|
generator=None,
|
||||||
|
box_v=1.01,
|
||||||
|
octree_resolution=384,
|
||||||
|
mc_level=0.0,
|
||||||
|
mc_algo=None,
|
||||||
|
num_chunks=8000,
|
||||||
|
output_type: Optional[str] = "trimesh",
|
||||||
|
enable_pbar=True,
|
||||||
|
mask = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> List[List[trimesh.Trimesh]]:
|
||||||
|
callback = kwargs.pop("callback", None)
|
||||||
|
callback_steps = kwargs.pop("callback_steps", None)
|
||||||
|
|
||||||
|
self.set_surface_extractor(mc_algo)
|
||||||
|
|
||||||
|
device = self.device
|
||||||
|
dtype = self.dtype
|
||||||
|
do_classifier_free_guidance = guidance_scale >= 0 and not (
|
||||||
|
hasattr(self.model, 'guidance_embed') and
|
||||||
|
self.model.guidance_embed is True
|
||||||
|
)
|
||||||
|
|
||||||
|
# print('image', type(image), 'mask', type(mask))
|
||||||
|
cond_inputs = self.prepare_image(image, mask)
|
||||||
|
image = cond_inputs.pop('image')
|
||||||
|
cond = self.encode_cond(
|
||||||
|
image=image,
|
||||||
|
additional_cond_inputs=cond_inputs,
|
||||||
|
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||||
|
dual_guidance=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
batch_size = image.shape[0]
|
||||||
|
|
||||||
|
# 5. Prepare timesteps
|
||||||
|
# NOTE: this is slightly different from common usage, we start from 0.
|
||||||
|
sigmas = np.linspace(0, 1, num_inference_steps) if sigmas is None else sigmas
|
||||||
|
timesteps, num_inference_steps = retrieve_timesteps(
|
||||||
|
self.scheduler,
|
||||||
|
num_inference_steps,
|
||||||
|
device,
|
||||||
|
sigmas=sigmas,
|
||||||
|
)
|
||||||
|
latents = self.prepare_latents(batch_size, dtype, device, generator)
|
||||||
|
|
||||||
|
guidance = None
|
||||||
|
if hasattr(self.model, 'guidance_embed') and \
|
||||||
|
self.model.guidance_embed is True:
|
||||||
|
guidance = torch.tensor([guidance_scale] * batch_size, device=device, dtype=dtype)
|
||||||
|
# logger.info(f'Using guidance embed with scale {guidance_scale}')
|
||||||
|
|
||||||
|
with synchronize_timer('Diffusion Sampling'):
|
||||||
|
for i, t in enumerate(tqdm(timesteps, disable=not enable_pbar, desc="Diffusion Sampling:")):
|
||||||
|
# expand the latents if we are doing classifier free guidance
|
||||||
|
if do_classifier_free_guidance:
|
||||||
|
latent_model_input = torch.cat([latents] * 2)
|
||||||
|
else:
|
||||||
|
latent_model_input = latents
|
||||||
|
|
||||||
|
# NOTE: we assume model get timesteps ranged from 0 to 1
|
||||||
|
timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
|
||||||
|
timestep = timestep / self.scheduler.config.num_train_timesteps
|
||||||
|
noise_pred = self.model(latent_model_input, timestep, cond, guidance=guidance)
|
||||||
|
|
||||||
|
if do_classifier_free_guidance:
|
||||||
|
noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
|
||||||
|
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
||||||
|
|
||||||
|
# compute the previous noisy sample x_t -> x_t-1
|
||||||
|
outputs = self.scheduler.step(noise_pred, t, latents)
|
||||||
|
latents = outputs.prev_sample
|
||||||
|
|
||||||
|
if callback is not None and i % callback_steps == 0:
|
||||||
|
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||||
|
callback(step_idx, t, outputs)
|
||||||
|
|
||||||
|
return self._export(
|
||||||
|
latents,
|
||||||
|
output_type,
|
||||||
|
box_v, mc_level, num_chunks, octree_resolution, mc_algo,
|
||||||
|
enable_pbar=enable_pbar,
|
||||||
|
)
|
||||||
202
hy3dshape/hy3dshape/postprocessors.py
Normal file
202
hy3dshape/hy3dshape/postprocessors.py
Normal file
@ -0,0 +1,202 @@
|
|||||||
|
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
||||||
|
# except for the third-party components listed below.
|
||||||
|
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
||||||
|
# in the repsective licenses of these third-party components.
|
||||||
|
# Users must comply with all terms and conditions of original licenses of these third-party
|
||||||
|
# components and must ensure that the usage of the third party components adheres to
|
||||||
|
# all relevant laws and regulations.
|
||||||
|
|
||||||
|
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
||||||
|
# their software and algorithms, including trained model weights, parameters (including
|
||||||
|
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
||||||
|
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
||||||
|
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pymeshlab
|
||||||
|
import torch
|
||||||
|
import trimesh
|
||||||
|
|
||||||
|
from .models.autoencoders import Latent2MeshOutput
|
||||||
|
from .utils import synchronize_timer
|
||||||
|
|
||||||
|
|
||||||
|
def load_mesh(path):
|
||||||
|
if path.endswith(".glb"):
|
||||||
|
mesh = trimesh.load(path)
|
||||||
|
else:
|
||||||
|
mesh = pymeshlab.MeshSet()
|
||||||
|
mesh.load_new_mesh(path)
|
||||||
|
return mesh
|
||||||
|
|
||||||
|
|
||||||
|
def reduce_face(mesh: pymeshlab.MeshSet, max_facenum: int = 200000):
|
||||||
|
if max_facenum > mesh.current_mesh().face_number():
|
||||||
|
return mesh
|
||||||
|
|
||||||
|
mesh.apply_filter(
|
||||||
|
"meshing_decimation_quadric_edge_collapse",
|
||||||
|
targetfacenum=max_facenum,
|
||||||
|
qualitythr=1.0,
|
||||||
|
preserveboundary=True,
|
||||||
|
boundaryweight=3,
|
||||||
|
preservenormal=True,
|
||||||
|
preservetopology=True,
|
||||||
|
autoclean=True
|
||||||
|
)
|
||||||
|
return mesh
|
||||||
|
|
||||||
|
|
||||||
|
def remove_floater(mesh: pymeshlab.MeshSet):
|
||||||
|
mesh.apply_filter("compute_selection_by_small_disconnected_components_per_face",
|
||||||
|
nbfaceratio=0.005)
|
||||||
|
mesh.apply_filter("compute_selection_transfer_face_to_vertex", inclusive=False)
|
||||||
|
mesh.apply_filter("meshing_remove_selected_vertices_and_faces")
|
||||||
|
return mesh
|
||||||
|
|
||||||
|
|
||||||
|
def pymeshlab2trimesh(mesh: pymeshlab.MeshSet):
|
||||||
|
with tempfile.NamedTemporaryFile(suffix='.ply', delete=False) as temp_file:
|
||||||
|
mesh.save_current_mesh(temp_file.name)
|
||||||
|
mesh = trimesh.load(temp_file.name)
|
||||||
|
# 检查加载的对象类型
|
||||||
|
if isinstance(mesh, trimesh.Scene):
|
||||||
|
combined_mesh = trimesh.Trimesh()
|
||||||
|
# 如果是Scene,遍历所有的geometry并合并
|
||||||
|
for geom in mesh.geometry.values():
|
||||||
|
combined_mesh = trimesh.util.concatenate([combined_mesh, geom])
|
||||||
|
mesh = combined_mesh
|
||||||
|
return mesh
|
||||||
|
|
||||||
|
|
||||||
|
def trimesh2pymeshlab(mesh: trimesh.Trimesh):
|
||||||
|
with tempfile.NamedTemporaryFile(suffix='.ply', delete=False) as temp_file:
|
||||||
|
if isinstance(mesh, trimesh.scene.Scene):
|
||||||
|
for idx, obj in enumerate(mesh.geometry.values()):
|
||||||
|
if idx == 0:
|
||||||
|
temp_mesh = obj
|
||||||
|
else:
|
||||||
|
temp_mesh = temp_mesh + obj
|
||||||
|
mesh = temp_mesh
|
||||||
|
mesh.export(temp_file.name)
|
||||||
|
mesh = pymeshlab.MeshSet()
|
||||||
|
mesh.load_new_mesh(temp_file.name)
|
||||||
|
return mesh
|
||||||
|
|
||||||
|
|
||||||
|
def export_mesh(input, output):
|
||||||
|
if isinstance(input, pymeshlab.MeshSet):
|
||||||
|
mesh = output
|
||||||
|
elif isinstance(input, Latent2MeshOutput):
|
||||||
|
output = Latent2MeshOutput()
|
||||||
|
output.mesh_v = output.current_mesh().vertex_matrix()
|
||||||
|
output.mesh_f = output.current_mesh().face_matrix()
|
||||||
|
mesh = output
|
||||||
|
else:
|
||||||
|
mesh = pymeshlab2trimesh(output)
|
||||||
|
return mesh
|
||||||
|
|
||||||
|
|
||||||
|
def import_mesh(mesh: Union[pymeshlab.MeshSet, trimesh.Trimesh, Latent2MeshOutput, str]) -> pymeshlab.MeshSet:
|
||||||
|
if isinstance(mesh, str):
|
||||||
|
mesh = load_mesh(mesh)
|
||||||
|
elif isinstance(mesh, Latent2MeshOutput):
|
||||||
|
mesh = pymeshlab.MeshSet()
|
||||||
|
mesh_pymeshlab = pymeshlab.Mesh(vertex_matrix=mesh.mesh_v, face_matrix=mesh.mesh_f)
|
||||||
|
mesh.add_mesh(mesh_pymeshlab, "converted_mesh")
|
||||||
|
|
||||||
|
if isinstance(mesh, (trimesh.Trimesh, trimesh.scene.Scene)):
|
||||||
|
mesh = trimesh2pymeshlab(mesh)
|
||||||
|
|
||||||
|
return mesh
|
||||||
|
|
||||||
|
|
||||||
|
class FaceReducer:
|
||||||
|
@synchronize_timer('FaceReducer')
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
mesh: Union[pymeshlab.MeshSet, trimesh.Trimesh, Latent2MeshOutput, str],
|
||||||
|
max_facenum: int = 40000
|
||||||
|
) -> Union[pymeshlab.MeshSet, trimesh.Trimesh]:
|
||||||
|
ms = import_mesh(mesh)
|
||||||
|
ms = reduce_face(ms, max_facenum=max_facenum)
|
||||||
|
mesh = export_mesh(mesh, ms)
|
||||||
|
return mesh
|
||||||
|
|
||||||
|
|
||||||
|
class FloaterRemover:
|
||||||
|
@synchronize_timer('FloaterRemover')
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
mesh: Union[pymeshlab.MeshSet, trimesh.Trimesh, Latent2MeshOutput, str],
|
||||||
|
) -> Union[pymeshlab.MeshSet, trimesh.Trimesh, Latent2MeshOutput]:
|
||||||
|
ms = import_mesh(mesh)
|
||||||
|
ms = remove_floater(ms)
|
||||||
|
mesh = export_mesh(mesh, ms)
|
||||||
|
return mesh
|
||||||
|
|
||||||
|
|
||||||
|
class DegenerateFaceRemover:
|
||||||
|
@synchronize_timer('DegenerateFaceRemover')
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
mesh: Union[pymeshlab.MeshSet, trimesh.Trimesh, Latent2MeshOutput, str],
|
||||||
|
) -> Union[pymeshlab.MeshSet, trimesh.Trimesh, Latent2MeshOutput]:
|
||||||
|
ms = import_mesh(mesh)
|
||||||
|
|
||||||
|
with tempfile.NamedTemporaryFile(suffix='.ply', delete=False) as temp_file:
|
||||||
|
ms.save_current_mesh(temp_file.name)
|
||||||
|
ms = pymeshlab.MeshSet()
|
||||||
|
ms.load_new_mesh(temp_file.name)
|
||||||
|
|
||||||
|
mesh = export_mesh(mesh, ms)
|
||||||
|
return mesh
|
||||||
|
|
||||||
|
|
||||||
|
def mesh_normalize(mesh):
|
||||||
|
"""
|
||||||
|
Normalize mesh vertices to sphere
|
||||||
|
"""
|
||||||
|
scale_factor = 1.2
|
||||||
|
vtx_pos = np.asarray(mesh.vertices)
|
||||||
|
max_bb = (vtx_pos - 0).max(0)[0]
|
||||||
|
min_bb = (vtx_pos - 0).min(0)[0]
|
||||||
|
|
||||||
|
center = (max_bb + min_bb) / 2
|
||||||
|
|
||||||
|
scale = torch.norm(torch.tensor(vtx_pos - center, dtype=torch.float32), dim=1).max() * 2.0
|
||||||
|
|
||||||
|
vtx_pos = (vtx_pos - center) * (scale_factor / float(scale))
|
||||||
|
mesh.vertices = vtx_pos
|
||||||
|
|
||||||
|
return mesh
|
||||||
|
|
||||||
|
|
||||||
|
class MeshSimplifier:
|
||||||
|
def __init__(self, executable: str = None):
|
||||||
|
if executable is None:
|
||||||
|
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
executable = os.path.join(CURRENT_DIR, "mesh_simplifier.bin")
|
||||||
|
self.executable = executable
|
||||||
|
|
||||||
|
@synchronize_timer('MeshSimplifier')
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
mesh: Union[trimesh.Trimesh],
|
||||||
|
) -> Union[trimesh.Trimesh]:
|
||||||
|
with tempfile.NamedTemporaryFile(suffix='.obj', delete=False) as temp_input:
|
||||||
|
with tempfile.NamedTemporaryFile(suffix='.obj', delete=False) as temp_output:
|
||||||
|
mesh.export(temp_input.name)
|
||||||
|
os.system(f'{self.executable} {temp_input.name} {temp_output.name}')
|
||||||
|
ms = trimesh.load(temp_output.name, process=False)
|
||||||
|
if isinstance(ms, trimesh.Scene):
|
||||||
|
combined_mesh = trimesh.Trimesh()
|
||||||
|
for geom in ms.geometry.values():
|
||||||
|
combined_mesh = trimesh.util.concatenate([combined_mesh, geom])
|
||||||
|
ms = combined_mesh
|
||||||
|
ms = mesh_normalize(ms)
|
||||||
|
return ms
|
||||||
167
hy3dshape/hy3dshape/preprocessors.py
Normal file
167
hy3dshape/hy3dshape/preprocessors.py
Normal file
@ -0,0 +1,167 @@
|
|||||||
|
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
||||||
|
# except for the third-party components listed below.
|
||||||
|
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
||||||
|
# in the repsective licenses of these third-party components.
|
||||||
|
# Users must comply with all terms and conditions of original licenses of these third-party
|
||||||
|
# components and must ensure that the usage of the third party components adheres to
|
||||||
|
# all relevant laws and regulations.
|
||||||
|
|
||||||
|
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
||||||
|
# their software and algorithms, including trained model weights, parameters (including
|
||||||
|
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
||||||
|
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
||||||
|
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from einops import repeat, rearrange
|
||||||
|
|
||||||
|
|
||||||
|
def array_to_tensor(np_array):
|
||||||
|
image_pt = torch.tensor(np_array).float()
|
||||||
|
image_pt = image_pt / 255 * 2 - 1
|
||||||
|
image_pt = rearrange(image_pt, "h w c -> c h w")
|
||||||
|
image_pts = repeat(image_pt, "c h w -> b c h w", b=1)
|
||||||
|
return image_pts
|
||||||
|
|
||||||
|
|
||||||
|
class ImageProcessorV2:
|
||||||
|
def __init__(self, size=512, border_ratio=None):
|
||||||
|
self.size = size
|
||||||
|
self.border_ratio = border_ratio
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def recenter(image, border_ratio: float = 0.2):
|
||||||
|
""" recenter an image to leave some empty space at the image border.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (ndarray): input image, float/uint8 [H, W, 3/4]
|
||||||
|
mask (ndarray): alpha mask, bool [H, W]
|
||||||
|
border_ratio (float, optional): border ratio, image will be resized to (1 - border_ratio). Defaults to 0.2.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ndarray: output image, float/uint8 [H, W, 3/4]
|
||||||
|
"""
|
||||||
|
|
||||||
|
if image.shape[-1] == 4:
|
||||||
|
mask = image[..., 3]
|
||||||
|
else:
|
||||||
|
mask = np.ones_like(image[..., 0:1]) * 255
|
||||||
|
image = np.concatenate([image, mask], axis=-1)
|
||||||
|
mask = mask[..., 0]
|
||||||
|
|
||||||
|
H, W, C = image.shape
|
||||||
|
|
||||||
|
size = max(H, W)
|
||||||
|
result = np.zeros((size, size, C), dtype=np.uint8)
|
||||||
|
|
||||||
|
coords = np.nonzero(mask)
|
||||||
|
x_min, x_max = coords[0].min(), coords[0].max()
|
||||||
|
y_min, y_max = coords[1].min(), coords[1].max()
|
||||||
|
h = x_max - x_min
|
||||||
|
w = y_max - y_min
|
||||||
|
if h == 0 or w == 0:
|
||||||
|
raise ValueError('input image is empty')
|
||||||
|
desired_size = int(size * (1 - border_ratio))
|
||||||
|
scale = desired_size / max(h, w)
|
||||||
|
h2 = int(h * scale)
|
||||||
|
w2 = int(w * scale)
|
||||||
|
x2_min = (size - h2) // 2
|
||||||
|
x2_max = x2_min + h2
|
||||||
|
|
||||||
|
y2_min = (size - w2) // 2
|
||||||
|
y2_max = y2_min + w2
|
||||||
|
|
||||||
|
result[x2_min:x2_max, y2_min:y2_max] = cv2.resize(image[x_min:x_max, y_min:y_max], (w2, h2),
|
||||||
|
interpolation=cv2.INTER_AREA)
|
||||||
|
|
||||||
|
bg = np.ones((result.shape[0], result.shape[1], 3), dtype=np.uint8) * 255
|
||||||
|
|
||||||
|
mask = result[..., 3:].astype(np.float32) / 255
|
||||||
|
result = result[..., :3] * mask + bg * (1 - mask)
|
||||||
|
|
||||||
|
mask = mask * 255
|
||||||
|
result = result.clip(0, 255).astype(np.uint8)
|
||||||
|
mask = mask.clip(0, 255).astype(np.uint8)
|
||||||
|
return result, mask
|
||||||
|
|
||||||
|
def load_image(self, image, border_ratio=0.15, to_tensor=True):
|
||||||
|
if isinstance(image, str):
|
||||||
|
image = cv2.imread(image, cv2.IMREAD_UNCHANGED)
|
||||||
|
image, mask = self.recenter(image, border_ratio=border_ratio)
|
||||||
|
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||||
|
elif isinstance(image, Image.Image):
|
||||||
|
image = image.convert("RGBA")
|
||||||
|
image = np.asarray(image)
|
||||||
|
image, mask = self.recenter(image, border_ratio=border_ratio)
|
||||||
|
|
||||||
|
image = cv2.resize(image, (self.size, self.size), interpolation=cv2.INTER_CUBIC)
|
||||||
|
mask = cv2.resize(mask, (self.size, self.size), interpolation=cv2.INTER_NEAREST)
|
||||||
|
mask = mask[..., np.newaxis]
|
||||||
|
|
||||||
|
if to_tensor:
|
||||||
|
image = array_to_tensor(image)
|
||||||
|
mask = array_to_tensor(mask)
|
||||||
|
return image, mask
|
||||||
|
|
||||||
|
def __call__(self, image, border_ratio=0.15, to_tensor=True, **kwargs):
|
||||||
|
if self.border_ratio is not None:
|
||||||
|
border_ratio = self.border_ratio
|
||||||
|
image, mask = self.load_image(image, border_ratio=border_ratio, to_tensor=to_tensor)
|
||||||
|
outputs = {
|
||||||
|
'image': image,
|
||||||
|
'mask': mask
|
||||||
|
}
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
class MVImageProcessorV2(ImageProcessorV2):
|
||||||
|
"""
|
||||||
|
view order: front, front clockwise 90, back, front clockwise 270
|
||||||
|
"""
|
||||||
|
return_view_idx = True
|
||||||
|
|
||||||
|
def __init__(self, size=512, border_ratio=None):
|
||||||
|
super().__init__(size, border_ratio)
|
||||||
|
self.view2idx = {
|
||||||
|
'front': 0,
|
||||||
|
'left': 1,
|
||||||
|
'back': 2,
|
||||||
|
'right': 3
|
||||||
|
}
|
||||||
|
|
||||||
|
def __call__(self, image_dict, border_ratio=0.15, to_tensor=True, **kwargs):
|
||||||
|
if self.border_ratio is not None:
|
||||||
|
border_ratio = self.border_ratio
|
||||||
|
|
||||||
|
images = []
|
||||||
|
masks = []
|
||||||
|
view_idxs = []
|
||||||
|
for idx, (view_tag, image) in enumerate(image_dict.items()):
|
||||||
|
view_idxs.append(self.view2idx[view_tag])
|
||||||
|
image, mask = self.load_image(image, border_ratio=border_ratio, to_tensor=to_tensor)
|
||||||
|
images.append(image)
|
||||||
|
masks.append(mask)
|
||||||
|
|
||||||
|
zipped_lists = zip(view_idxs, images, masks)
|
||||||
|
sorted_zipped_lists = sorted(zipped_lists)
|
||||||
|
view_idxs, images, masks = zip(*sorted_zipped_lists)
|
||||||
|
|
||||||
|
image = torch.cat(images, 0).unsqueeze(0)
|
||||||
|
mask = torch.cat(masks, 0).unsqueeze(0)
|
||||||
|
outputs = {
|
||||||
|
'image': image,
|
||||||
|
'mask': mask,
|
||||||
|
'view_idxs': view_idxs
|
||||||
|
}
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
IMAGE_PROCESSORS = {
|
||||||
|
"v2": ImageProcessorV2,
|
||||||
|
'mv_v2': MVImageProcessorV2,
|
||||||
|
}
|
||||||
|
|
||||||
|
DEFAULT_IMAGEPROCESSOR = 'v2'
|
||||||
25
hy3dshape/hy3dshape/rembg.py
Normal file
25
hy3dshape/hy3dshape/rembg.py
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
||||||
|
# except for the third-party components listed below.
|
||||||
|
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
||||||
|
# in the repsective licenses of these third-party components.
|
||||||
|
# Users must comply with all terms and conditions of original licenses of these third-party
|
||||||
|
# components and must ensure that the usage of the third party components adheres to
|
||||||
|
# all relevant laws and regulations.
|
||||||
|
|
||||||
|
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
||||||
|
# their software and algorithms, including trained model weights, parameters (including
|
||||||
|
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
||||||
|
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
||||||
|
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
from rembg import remove, new_session
|
||||||
|
|
||||||
|
|
||||||
|
class BackgroundRemover():
|
||||||
|
def __init__(self):
|
||||||
|
self.session = new_session()
|
||||||
|
|
||||||
|
def __call__(self, image: Image.Image):
|
||||||
|
output = remove(image, session=self.session, bgcolor=[255, 255, 255, 0])
|
||||||
|
return output
|
||||||
480
hy3dshape/hy3dshape/schedulers.py
Normal file
480
hy3dshape/hy3dshape/schedulers.py
Normal file
@ -0,0 +1,480 @@
|
|||||||
|
# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
||||||
|
# except for the third-party components listed below.
|
||||||
|
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
||||||
|
# in the repsective licenses of these third-party components.
|
||||||
|
# Users must comply with all terms and conditions of original licenses of these third-party
|
||||||
|
# components and must ensure that the usage of the third party components adheres to
|
||||||
|
# all relevant laws and regulations.
|
||||||
|
|
||||||
|
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
||||||
|
# their software and algorithms, including trained model weights, parameters (including
|
||||||
|
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
||||||
|
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
||||||
|
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
||||||
|
|
||||||
|
import math
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||||
|
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
||||||
|
from diffusers.utils import BaseOutput, logging
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
|
||||||
|
"""
|
||||||
|
Output class for the scheduler's `step` function output.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
||||||
|
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
|
||||||
|
denoising loop.
|
||||||
|
"""
|
||||||
|
|
||||||
|
prev_sample: torch.FloatTensor
|
||||||
|
|
||||||
|
|
||||||
|
class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||||
|
"""
|
||||||
|
NOTE: this is very similar to diffusers.FlowMatchEulerDiscreteScheduler. Except our timesteps are reversed
|
||||||
|
|
||||||
|
Euler scheduler.
|
||||||
|
|
||||||
|
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
|
||||||
|
methods the library implements for all schedulers such as loading and saving.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_train_timesteps (`int`, defaults to 1000):
|
||||||
|
The number of diffusion steps to train the model.
|
||||||
|
timestep_spacing (`str`, defaults to `"linspace"`):
|
||||||
|
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
||||||
|
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
||||||
|
shift (`float`, defaults to 1.0):
|
||||||
|
The shift value for the timestep schedule.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_compatibles = []
|
||||||
|
order = 1
|
||||||
|
|
||||||
|
@register_to_config
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_train_timesteps: int = 1000,
|
||||||
|
shift: float = 1.0,
|
||||||
|
use_dynamic_shifting=False,
|
||||||
|
):
|
||||||
|
timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32).copy()
|
||||||
|
timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
|
||||||
|
|
||||||
|
sigmas = timesteps / num_train_timesteps
|
||||||
|
if not use_dynamic_shifting:
|
||||||
|
# when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
|
||||||
|
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
|
||||||
|
|
||||||
|
self.timesteps = sigmas * num_train_timesteps
|
||||||
|
|
||||||
|
self._step_index = None
|
||||||
|
self._begin_index = None
|
||||||
|
|
||||||
|
self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||||
|
self.sigma_min = self.sigmas[-1].item()
|
||||||
|
self.sigma_max = self.sigmas[0].item()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def step_index(self):
|
||||||
|
"""
|
||||||
|
The index counter for current timestep. It will increase 1 after each scheduler step.
|
||||||
|
"""
|
||||||
|
return self._step_index
|
||||||
|
|
||||||
|
@property
|
||||||
|
def begin_index(self):
|
||||||
|
"""
|
||||||
|
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
||||||
|
"""
|
||||||
|
return self._begin_index
|
||||||
|
|
||||||
|
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
||||||
|
def set_begin_index(self, begin_index: int = 0):
|
||||||
|
"""
|
||||||
|
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
begin_index (`int`):
|
||||||
|
The begin index for the scheduler.
|
||||||
|
"""
|
||||||
|
self._begin_index = begin_index
|
||||||
|
|
||||||
|
def scale_noise(
|
||||||
|
self,
|
||||||
|
sample: torch.FloatTensor,
|
||||||
|
timestep: Union[float, torch.FloatTensor],
|
||||||
|
noise: Optional[torch.FloatTensor] = None,
|
||||||
|
) -> torch.FloatTensor:
|
||||||
|
"""
|
||||||
|
Forward process in flow-matching
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sample (`torch.FloatTensor`):
|
||||||
|
The input sample.
|
||||||
|
timestep (`int`, *optional*):
|
||||||
|
The current timestep in the diffusion chain.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`torch.FloatTensor`:
|
||||||
|
A scaled input sample.
|
||||||
|
"""
|
||||||
|
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
||||||
|
sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype)
|
||||||
|
|
||||||
|
if sample.device.type == "mps" and torch.is_floating_point(timestep):
|
||||||
|
# mps does not support float64
|
||||||
|
schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32)
|
||||||
|
timestep = timestep.to(sample.device, dtype=torch.float32)
|
||||||
|
else:
|
||||||
|
schedule_timesteps = self.timesteps.to(sample.device)
|
||||||
|
timestep = timestep.to(sample.device)
|
||||||
|
|
||||||
|
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
|
||||||
|
if self.begin_index is None:
|
||||||
|
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep]
|
||||||
|
elif self.step_index is not None:
|
||||||
|
# add_noise is called after first denoising step (for inpainting)
|
||||||
|
step_indices = [self.step_index] * timestep.shape[0]
|
||||||
|
else:
|
||||||
|
# add noise is called before first denoising step to create initial latent(img2img)
|
||||||
|
step_indices = [self.begin_index] * timestep.shape[0]
|
||||||
|
|
||||||
|
sigma = sigmas[step_indices].flatten()
|
||||||
|
while len(sigma.shape) < len(sample.shape):
|
||||||
|
sigma = sigma.unsqueeze(-1)
|
||||||
|
|
||||||
|
sample = sigma * noise + (1.0 - sigma) * sample
|
||||||
|
|
||||||
|
return sample
|
||||||
|
|
||||||
|
def _sigma_to_t(self, sigma):
|
||||||
|
return sigma * self.config.num_train_timesteps
|
||||||
|
|
||||||
|
def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
|
||||||
|
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
||||||
|
|
||||||
|
def set_timesteps(
|
||||||
|
self,
|
||||||
|
num_inference_steps: int = None,
|
||||||
|
device: Union[str, torch.device] = None,
|
||||||
|
sigmas: Optional[List[float]] = None,
|
||||||
|
mu: Optional[float] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_inference_steps (`int`):
|
||||||
|
The number of diffusion steps used when generating samples with a pre-trained model.
|
||||||
|
device (`str` or `torch.device`, *optional*):
|
||||||
|
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if self.config.use_dynamic_shifting and mu is None:
|
||||||
|
raise ValueError(" you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`")
|
||||||
|
|
||||||
|
if sigmas is None:
|
||||||
|
self.num_inference_steps = num_inference_steps
|
||||||
|
timesteps = np.linspace(
|
||||||
|
self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
|
||||||
|
)
|
||||||
|
|
||||||
|
sigmas = timesteps / self.config.num_train_timesteps
|
||||||
|
|
||||||
|
if self.config.use_dynamic_shifting:
|
||||||
|
sigmas = self.time_shift(mu, 1.0, sigmas)
|
||||||
|
else:
|
||||||
|
sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
|
||||||
|
|
||||||
|
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
|
||||||
|
timesteps = sigmas * self.config.num_train_timesteps
|
||||||
|
|
||||||
|
self.timesteps = timesteps.to(device=device)
|
||||||
|
self.sigmas = torch.cat([sigmas, torch.ones(1, device=sigmas.device)])
|
||||||
|
|
||||||
|
self._step_index = None
|
||||||
|
self._begin_index = None
|
||||||
|
|
||||||
|
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||||
|
if schedule_timesteps is None:
|
||||||
|
schedule_timesteps = self.timesteps
|
||||||
|
|
||||||
|
indices = (schedule_timesteps == timestep).nonzero()
|
||||||
|
|
||||||
|
# The sigma index that is taken for the **very** first `step`
|
||||||
|
# is always the second index (or the last index if there is only 1)
|
||||||
|
# This way we can ensure we don't accidentally skip a sigma in
|
||||||
|
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||||
|
pos = 1 if len(indices) > 1 else 0
|
||||||
|
|
||||||
|
return indices[pos].item()
|
||||||
|
|
||||||
|
def _init_step_index(self, timestep):
|
||||||
|
if self.begin_index is None:
|
||||||
|
if isinstance(timestep, torch.Tensor):
|
||||||
|
timestep = timestep.to(self.timesteps.device)
|
||||||
|
self._step_index = self.index_for_timestep(timestep)
|
||||||
|
else:
|
||||||
|
self._step_index = self._begin_index
|
||||||
|
|
||||||
|
def step(
|
||||||
|
self,
|
||||||
|
model_output: torch.FloatTensor,
|
||||||
|
timestep: Union[float, torch.FloatTensor],
|
||||||
|
sample: torch.FloatTensor,
|
||||||
|
s_churn: float = 0.0,
|
||||||
|
s_tmin: float = 0.0,
|
||||||
|
s_tmax: float = float("inf"),
|
||||||
|
s_noise: float = 1.0,
|
||||||
|
generator: Optional[torch.Generator] = None,
|
||||||
|
return_dict: bool = True,
|
||||||
|
) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
|
||||||
|
"""
|
||||||
|
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
|
||||||
|
process from the learned model outputs (most often the predicted noise).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_output (`torch.FloatTensor`):
|
||||||
|
The direct output from learned diffusion model.
|
||||||
|
timestep (`float`):
|
||||||
|
The current discrete timestep in the diffusion chain.
|
||||||
|
sample (`torch.FloatTensor`):
|
||||||
|
A current instance of a sample created by the diffusion process.
|
||||||
|
s_churn (`float`):
|
||||||
|
s_tmin (`float`):
|
||||||
|
s_tmax (`float`):
|
||||||
|
s_noise (`float`, defaults to 1.0):
|
||||||
|
Scaling factor for noise added to the sample.
|
||||||
|
generator (`torch.Generator`, *optional*):
|
||||||
|
A random number generator.
|
||||||
|
return_dict (`bool`):
|
||||||
|
Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
|
||||||
|
tuple.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
|
||||||
|
If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
|
||||||
|
returned, otherwise a tuple is returned where the first element is the sample tensor.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if (
|
||||||
|
isinstance(timestep, int)
|
||||||
|
or isinstance(timestep, torch.IntTensor)
|
||||||
|
or isinstance(timestep, torch.LongTensor)
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
(
|
||||||
|
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
|
||||||
|
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
|
||||||
|
" one of the `scheduler.timesteps` as a timestep."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.step_index is None:
|
||||||
|
self._init_step_index(timestep)
|
||||||
|
|
||||||
|
# Upcast to avoid precision issues when computing prev_sample
|
||||||
|
sample = sample.to(torch.float32)
|
||||||
|
|
||||||
|
sigma = self.sigmas[self.step_index]
|
||||||
|
sigma_next = self.sigmas[self.step_index + 1]
|
||||||
|
|
||||||
|
prev_sample = sample + (sigma_next - sigma) * model_output
|
||||||
|
|
||||||
|
# Cast sample back to model compatible dtype
|
||||||
|
prev_sample = prev_sample.to(model_output.dtype)
|
||||||
|
|
||||||
|
# upon completion increase step index by one
|
||||||
|
self._step_index += 1
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return (prev_sample,)
|
||||||
|
|
||||||
|
return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.config.num_train_timesteps
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ConsistencyFlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
|
||||||
|
prev_sample: torch.FloatTensor
|
||||||
|
pred_original_sample: torch.FloatTensor
|
||||||
|
|
||||||
|
|
||||||
|
class ConsistencyFlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||||
|
_compatibles = []
|
||||||
|
order = 1
|
||||||
|
|
||||||
|
@register_to_config
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_train_timesteps: int = 1000,
|
||||||
|
pcm_timesteps: int = 50,
|
||||||
|
):
|
||||||
|
sigmas = np.linspace(0, 1, num_train_timesteps)
|
||||||
|
step_ratio = num_train_timesteps // pcm_timesteps
|
||||||
|
|
||||||
|
euler_timesteps = (np.arange(1, pcm_timesteps) * step_ratio).round().astype(np.int64) - 1
|
||||||
|
euler_timesteps = np.asarray([0] + euler_timesteps.tolist())
|
||||||
|
|
||||||
|
self.euler_timesteps = euler_timesteps
|
||||||
|
self.sigmas = sigmas[self.euler_timesteps]
|
||||||
|
self.sigmas = torch.from_numpy((self.sigmas.copy())).to(dtype=torch.float32)
|
||||||
|
self.timesteps = self.sigmas * num_train_timesteps
|
||||||
|
self._step_index = None
|
||||||
|
self._begin_index = None
|
||||||
|
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||||
|
|
||||||
|
@property
|
||||||
|
def step_index(self):
|
||||||
|
"""
|
||||||
|
The index counter for current timestep. It will increase 1 after each scheduler step.
|
||||||
|
"""
|
||||||
|
return self._step_index
|
||||||
|
|
||||||
|
@property
|
||||||
|
def begin_index(self):
|
||||||
|
"""
|
||||||
|
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
||||||
|
"""
|
||||||
|
return self._begin_index
|
||||||
|
|
||||||
|
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
||||||
|
def set_begin_index(self, begin_index: int = 0):
|
||||||
|
"""
|
||||||
|
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
begin_index (`int`):
|
||||||
|
The begin index for the scheduler.
|
||||||
|
"""
|
||||||
|
self._begin_index = begin_index
|
||||||
|
|
||||||
|
def _sigma_to_t(self, sigma):
|
||||||
|
return sigma * self.config.num_train_timesteps
|
||||||
|
|
||||||
|
def set_timesteps(
|
||||||
|
self,
|
||||||
|
num_inference_steps: int = None,
|
||||||
|
device: Union[str, torch.device] = None,
|
||||||
|
sigmas: Optional[List[float]] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_inference_steps (`int`):
|
||||||
|
The number of diffusion steps used when generating samples with a pre-trained model.
|
||||||
|
device (`str` or `torch.device`, *optional*):
|
||||||
|
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||||
|
"""
|
||||||
|
self.num_inference_steps = num_inference_steps if num_inference_steps is not None else len(sigmas)
|
||||||
|
inference_indices = np.linspace(
|
||||||
|
0, self.config.pcm_timesteps, num=self.num_inference_steps, endpoint=False
|
||||||
|
)
|
||||||
|
inference_indices = np.floor(inference_indices).astype(np.int64)
|
||||||
|
inference_indices = torch.from_numpy(inference_indices).long()
|
||||||
|
|
||||||
|
self.sigmas_ = self.sigmas[inference_indices]
|
||||||
|
timesteps = self.sigmas_ * self.config.num_train_timesteps
|
||||||
|
self.timesteps = timesteps.to(device=device)
|
||||||
|
self.sigmas_ = torch.cat(
|
||||||
|
[self.sigmas_, torch.ones(1, device=self.sigmas_.device)]
|
||||||
|
)
|
||||||
|
|
||||||
|
self._step_index = None
|
||||||
|
self._begin_index = None
|
||||||
|
|
||||||
|
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||||
|
if schedule_timesteps is None:
|
||||||
|
schedule_timesteps = self.timesteps
|
||||||
|
|
||||||
|
indices = (schedule_timesteps == timestep).nonzero()
|
||||||
|
|
||||||
|
# The sigma index that is taken for the **very** first `step`
|
||||||
|
# is always the second index (or the last index if there is only 1)
|
||||||
|
# This way we can ensure we don't accidentally skip a sigma in
|
||||||
|
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||||
|
pos = 1 if len(indices) > 1 else 0
|
||||||
|
|
||||||
|
return indices[pos].item()
|
||||||
|
|
||||||
|
def _init_step_index(self, timestep):
|
||||||
|
if self.begin_index is None:
|
||||||
|
if isinstance(timestep, torch.Tensor):
|
||||||
|
timestep = timestep.to(self.timesteps.device)
|
||||||
|
self._step_index = self.index_for_timestep(timestep)
|
||||||
|
else:
|
||||||
|
self._step_index = self._begin_index
|
||||||
|
|
||||||
|
def step(
|
||||||
|
self,
|
||||||
|
model_output: torch.FloatTensor,
|
||||||
|
timestep: Union[float, torch.FloatTensor],
|
||||||
|
sample: torch.FloatTensor,
|
||||||
|
generator: Optional[torch.Generator] = None,
|
||||||
|
return_dict: bool = True,
|
||||||
|
) -> Union[ConsistencyFlowMatchEulerDiscreteSchedulerOutput, Tuple]:
|
||||||
|
if (
|
||||||
|
isinstance(timestep, int)
|
||||||
|
or isinstance(timestep, torch.IntTensor)
|
||||||
|
or isinstance(timestep, torch.LongTensor)
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
(
|
||||||
|
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
|
||||||
|
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
|
||||||
|
" one of the `scheduler.timesteps` as a timestep."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.step_index is None:
|
||||||
|
self._init_step_index(timestep)
|
||||||
|
|
||||||
|
sample = sample.to(torch.float32)
|
||||||
|
|
||||||
|
sigma = self.sigmas_[self.step_index]
|
||||||
|
sigma_next = self.sigmas_[self.step_index + 1]
|
||||||
|
|
||||||
|
prev_sample = sample + (sigma_next - sigma) * model_output
|
||||||
|
prev_sample = prev_sample.to(model_output.dtype)
|
||||||
|
|
||||||
|
pred_original_sample = sample + (1.0 - sigma) * model_output
|
||||||
|
pred_original_sample = pred_original_sample.to(model_output.dtype)
|
||||||
|
|
||||||
|
self._step_index += 1
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return (prev_sample,)
|
||||||
|
|
||||||
|
return ConsistencyFlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample,
|
||||||
|
pred_original_sample=pred_original_sample)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.config.num_train_timesteps
|
||||||
234
hy3dshape/hy3dshape/surface_loaders.py
Normal file
234
hy3dshape/hy3dshape/surface_loaders.py
Normal file
@ -0,0 +1,234 @@
|
|||||||
|
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
||||||
|
# except for the third-party components listed below.
|
||||||
|
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
||||||
|
# in the repsective licenses of these third-party components.
|
||||||
|
# Users must comply with all terms and conditions of original licenses of these third-party
|
||||||
|
# components and must ensure that the usage of the third party components adheres to
|
||||||
|
# all relevant laws and regulations.
|
||||||
|
|
||||||
|
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
||||||
|
# their software and algorithms, including trained model weights, parameters (including
|
||||||
|
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
||||||
|
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
||||||
|
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import trimesh
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_mesh(mesh, scale=0.9999):
|
||||||
|
"""
|
||||||
|
Normalize the mesh to fit inside a centered cube with a specified scale.
|
||||||
|
|
||||||
|
The mesh is translated so that its bounding box center is at the origin,
|
||||||
|
then uniformly scaled so that the longest side of the bounding box fits within [-scale, scale].
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mesh (trimesh.Trimesh): Input mesh to normalize.
|
||||||
|
scale (float, optional): Scaling factor to slightly shrink the mesh inside the unit cube. Default is 0.9999.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
trimesh.Trimesh: The normalized mesh with applied translation and scaling.
|
||||||
|
"""
|
||||||
|
bbox = mesh.bounds
|
||||||
|
center = (bbox[1] + bbox[0]) / 2
|
||||||
|
scale_ = (bbox[1] - bbox[0]).max()
|
||||||
|
|
||||||
|
mesh.apply_translation(-center)
|
||||||
|
mesh.apply_scale(1 / scale_ * 2 * scale)
|
||||||
|
|
||||||
|
return mesh
|
||||||
|
|
||||||
|
|
||||||
|
def sample_pointcloud(mesh, num=200000):
|
||||||
|
"""
|
||||||
|
Sample points uniformly from the surface of the mesh along with their corresponding face normals.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mesh (trimesh.Trimesh): Input mesh to sample from.
|
||||||
|
num (int, optional): Number of points to sample. Default is 200000.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
- points: Sampled points as a float tensor of shape (num, 3).
|
||||||
|
- normals: Corresponding normals as a float tensor of shape (num, 3).
|
||||||
|
"""
|
||||||
|
points, face_idx = mesh.sample(num, return_index=True)
|
||||||
|
normals = mesh.face_normals[face_idx]
|
||||||
|
points = torch.from_numpy(points.astype(np.float32))
|
||||||
|
normals = torch.from_numpy(normals.astype(np.float32))
|
||||||
|
return points, normals
|
||||||
|
|
||||||
|
|
||||||
|
def load_surface(mesh, num_points=8192):
|
||||||
|
"""
|
||||||
|
Normalize the mesh, sample points and normals from its surface, and randomly select a subset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mesh (trimesh.Trimesh): Input mesh to process.
|
||||||
|
num_points (int, optional): Number of points to randomly select
|
||||||
|
from the sampled surface points. Default is 8192.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[torch.Tensor, trimesh.Trimesh]:
|
||||||
|
- surface: Tensor of shape (1, num_points, 6), concatenating points and normals.
|
||||||
|
- mesh: The normalized mesh.
|
||||||
|
"""
|
||||||
|
|
||||||
|
mesh = normalize_mesh(mesh, scale=0.98)
|
||||||
|
surface, normal = sample_pointcloud(mesh)
|
||||||
|
|
||||||
|
rng = np.random.default_rng(seed=0)
|
||||||
|
ind = rng.choice(surface.shape[0], num_points, replace=False)
|
||||||
|
surface = torch.FloatTensor(surface[ind])
|
||||||
|
normal = torch.FloatTensor(normal[ind])
|
||||||
|
|
||||||
|
surface = torch.cat([surface, normal], dim=-1).unsqueeze(0)
|
||||||
|
|
||||||
|
return surface, mesh
|
||||||
|
|
||||||
|
|
||||||
|
def sharp_sample_pointcloud(mesh, num=16384):
|
||||||
|
"""
|
||||||
|
Sample points and normals preferentially from sharp edges of the mesh.
|
||||||
|
|
||||||
|
Sharp edges are detected based on the angle between vertex normals and face normals.
|
||||||
|
Points are sampled along these edges proportionally to edge length.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mesh (trimesh.Trimesh): Input mesh to sample from.
|
||||||
|
num (int, optional): Number of points to sample from sharp edges. Default is 16384.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[np.ndarray, np.ndarray]:
|
||||||
|
- samples: Sampled points along sharp edges, shape (num, 3).
|
||||||
|
- normals: Corresponding interpolated normals, shape (num, 3).
|
||||||
|
"""
|
||||||
|
V = mesh.vertices
|
||||||
|
N = mesh.face_normals
|
||||||
|
VN = mesh.vertex_normals
|
||||||
|
F = mesh.faces
|
||||||
|
VN2 = np.ones(V.shape[0])
|
||||||
|
for i in range(3):
|
||||||
|
dot = np.stack((VN2[F[:, i]], np.sum(VN[F[:, i]] * N, axis=-1)), axis=-1)
|
||||||
|
VN2[F[:, i]] = np.min(dot, axis=-1)
|
||||||
|
|
||||||
|
sharp_mask = VN2 < 0.985
|
||||||
|
# collect edge
|
||||||
|
edge_a = np.concatenate((F[:, 0], F[:, 1], F[:, 2]))
|
||||||
|
edge_b = np.concatenate((F[:, 1], F[:, 2], F[:, 0]))
|
||||||
|
sharp_edge = ((sharp_mask[edge_a] * sharp_mask[edge_b]))
|
||||||
|
edge_a = edge_a[sharp_edge > 0]
|
||||||
|
edge_b = edge_b[sharp_edge > 0]
|
||||||
|
|
||||||
|
sharp_verts_a = V[edge_a]
|
||||||
|
sharp_verts_b = V[edge_b]
|
||||||
|
sharp_verts_an = VN[edge_a]
|
||||||
|
sharp_verts_bn = VN[edge_b]
|
||||||
|
|
||||||
|
weights = np.linalg.norm(sharp_verts_b - sharp_verts_a, axis=-1)
|
||||||
|
weights /= np.sum(weights)
|
||||||
|
|
||||||
|
random_number = np.random.rand(num)
|
||||||
|
w = np.random.rand(num, 1)
|
||||||
|
index = np.searchsorted(weights.cumsum(), random_number)
|
||||||
|
samples = w * sharp_verts_a[index] + (1 - w) * sharp_verts_b[index]
|
||||||
|
normals = w * sharp_verts_an[index] + (1 - w) * sharp_verts_bn[index]
|
||||||
|
return samples, normals
|
||||||
|
|
||||||
|
|
||||||
|
def load_surface_sharpegde(mesh, num_points=4096, num_sharp_points=4096, sharpedge_flag=True):
|
||||||
|
try:
|
||||||
|
mesh_full = trimesh.util.concatenate(mesh.dump())
|
||||||
|
except Exception as err:
|
||||||
|
mesh_full = trimesh.util.concatenate(mesh)
|
||||||
|
mesh_full = normalize_mesh(mesh_full)
|
||||||
|
|
||||||
|
origin_num = mesh_full.faces.shape[0]
|
||||||
|
original_vertices = mesh_full.vertices
|
||||||
|
original_faces = mesh_full.faces
|
||||||
|
|
||||||
|
mesh = trimesh.Trimesh(vertices=original_vertices, faces=original_faces[:origin_num])
|
||||||
|
mesh_fill = trimesh.Trimesh(vertices=original_vertices, faces=original_faces[origin_num:])
|
||||||
|
area = mesh.area
|
||||||
|
area_fill = mesh_fill.area
|
||||||
|
sample_num = 499712 // 2
|
||||||
|
num_fill = int(sample_num * (area_fill / (area + area_fill)))
|
||||||
|
num = sample_num - num_fill
|
||||||
|
|
||||||
|
random_surface, random_normal = sample_pointcloud(mesh, num=num)
|
||||||
|
if num_fill == 0:
|
||||||
|
random_surface_fill, random_normal_fill = np.zeros((0, 3)), np.zeros((0, 3))
|
||||||
|
else:
|
||||||
|
random_surface_fill, random_normal_fill = sample_pointcloud(mesh_fill, num=num_fill)
|
||||||
|
random_sharp_surface, sharp_normal = sharp_sample_pointcloud(mesh, num=sample_num)
|
||||||
|
|
||||||
|
# save_surface
|
||||||
|
surface = np.concatenate((random_surface, random_normal), axis=1).astype(np.float16)
|
||||||
|
surface_fill = np.concatenate((random_surface_fill, random_normal_fill), axis=1).astype(np.float16)
|
||||||
|
sharp_surface = np.concatenate((random_sharp_surface, sharp_normal), axis=1).astype(np.float16)
|
||||||
|
surface = np.concatenate((surface, surface_fill), axis=0)
|
||||||
|
if sharpedge_flag:
|
||||||
|
sharpedge_label = np.zeros((surface.shape[0], 1))
|
||||||
|
surface = np.concatenate((surface, sharpedge_label), axis=1)
|
||||||
|
sharpedge_label = np.ones((sharp_surface.shape[0], 1))
|
||||||
|
sharp_surface = np.concatenate((sharp_surface, sharpedge_label), axis=1)
|
||||||
|
rng = np.random.default_rng()
|
||||||
|
ind = rng.choice(surface.shape[0], num_points, replace=False)
|
||||||
|
surface = torch.FloatTensor(surface[ind])
|
||||||
|
ind = rng.choice(sharp_surface.shape[0], num_sharp_points, replace=False)
|
||||||
|
sharp_surface = torch.FloatTensor(sharp_surface[ind])
|
||||||
|
|
||||||
|
return torch.cat([surface, sharp_surface], dim=0).unsqueeze(0), mesh_full
|
||||||
|
|
||||||
|
|
||||||
|
class SurfaceLoader:
|
||||||
|
def __init__(self, num_points=8192):
|
||||||
|
self.num_points = num_points
|
||||||
|
|
||||||
|
def __call__(self, mesh_or_mesh_path, num_points=None):
|
||||||
|
if num_points is None:
|
||||||
|
num_points = self.num_points
|
||||||
|
|
||||||
|
mesh = mesh_or_mesh_path
|
||||||
|
if isinstance(mesh, str):
|
||||||
|
mesh = trimesh.load(mesh, force="mesh", merge_primitives=True)
|
||||||
|
if isinstance(mesh, trimesh.scene.Scene):
|
||||||
|
for idx, obj in enumerate(mesh.geometry.values()):
|
||||||
|
if idx == 0:
|
||||||
|
temp_mesh = obj
|
||||||
|
else:
|
||||||
|
temp_mesh = temp_mesh + obj
|
||||||
|
mesh = temp_mesh
|
||||||
|
surface, mesh = load_surface(mesh, num_points=num_points)
|
||||||
|
return surface
|
||||||
|
|
||||||
|
|
||||||
|
class SharpEdgeSurfaceLoader:
|
||||||
|
def __init__(self, num_uniform_points=8192, num_sharp_points=8192, **kwargs):
|
||||||
|
self.num_uniform_points = num_uniform_points
|
||||||
|
self.num_sharp_points = num_sharp_points
|
||||||
|
self.num_points = num_uniform_points + num_sharp_points
|
||||||
|
|
||||||
|
def __call__(self, mesh_or_mesh_path, num_uniform_points=None, num_sharp_points=None):
|
||||||
|
if num_uniform_points is None:
|
||||||
|
num_uniform_points = self.num_uniform_points
|
||||||
|
if num_sharp_points is None:
|
||||||
|
num_sharp_points = self.num_sharp_points
|
||||||
|
|
||||||
|
mesh = mesh_or_mesh_path
|
||||||
|
if isinstance(mesh, str):
|
||||||
|
mesh = trimesh.load(mesh, force="mesh", merge_primitives=True)
|
||||||
|
if isinstance(mesh, trimesh.scene.Scene):
|
||||||
|
for idx, obj in enumerate(mesh.geometry.values()):
|
||||||
|
if idx == 0:
|
||||||
|
temp_mesh = obj
|
||||||
|
else:
|
||||||
|
temp_mesh = temp_mesh + obj
|
||||||
|
mesh = temp_mesh
|
||||||
|
surface, mesh = load_surface_sharpegde(mesh, num_points=num_uniform_points, num_sharp_points=num_sharp_points)
|
||||||
|
return surface
|
||||||
5
hy3dshape/hy3dshape/utils/__init__.py
Normal file
5
hy3dshape/hy3dshape/utils/__init__.py
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
from .misc import get_config_from_file
|
||||||
|
from .misc import instantiate_from_config
|
||||||
|
from .utils import get_logger, logger, synchronize_timer, smart_load_model
|
||||||
76
hy3dshape/hy3dshape/utils/ema.py
Normal file
76
hy3dshape/hy3dshape/utils/ema.py
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
class LitEma(nn.Module):
|
||||||
|
def __init__(self, model, decay=0.9999, use_num_updates=True):
|
||||||
|
super().__init__()
|
||||||
|
if decay < 0.0 or decay > 1.0:
|
||||||
|
raise ValueError('Decay must be between 0 and 1')
|
||||||
|
|
||||||
|
self.m_name2s_name = {}
|
||||||
|
self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
|
||||||
|
self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_updates
|
||||||
|
else torch.tensor(-1, dtype=torch.int))
|
||||||
|
|
||||||
|
for name, p in model.named_parameters():
|
||||||
|
if p.requires_grad:
|
||||||
|
# remove as '.'-character is not allowed in buffers
|
||||||
|
s_name = name.replace('.', '_____')
|
||||||
|
self.m_name2s_name.update({name: s_name})
|
||||||
|
self.register_buffer(s_name, p.clone().detach().data)
|
||||||
|
|
||||||
|
self.collected_params = []
|
||||||
|
|
||||||
|
def forward(self, model):
|
||||||
|
decay = self.decay
|
||||||
|
|
||||||
|
if self.num_updates >= 0:
|
||||||
|
self.num_updates += 1
|
||||||
|
decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
|
||||||
|
|
||||||
|
one_minus_decay = 1.0 - decay
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
m_param = dict(model.named_parameters())
|
||||||
|
shadow_params = dict(self.named_buffers())
|
||||||
|
|
||||||
|
for key in m_param:
|
||||||
|
if m_param[key].requires_grad:
|
||||||
|
sname = self.m_name2s_name[key]
|
||||||
|
shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
|
||||||
|
shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
|
||||||
|
else:
|
||||||
|
assert not key in self.m_name2s_name
|
||||||
|
|
||||||
|
def copy_to(self, model):
|
||||||
|
m_param = dict(model.named_parameters())
|
||||||
|
shadow_params = dict(self.named_buffers())
|
||||||
|
for key in m_param:
|
||||||
|
if m_param[key].requires_grad:
|
||||||
|
m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
|
||||||
|
else:
|
||||||
|
assert not key in self.m_name2s_name
|
||||||
|
|
||||||
|
def store(self, model):
|
||||||
|
"""
|
||||||
|
Save the current parameters for restoring later.
|
||||||
|
Args:
|
||||||
|
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
|
||||||
|
temporarily stored.
|
||||||
|
"""
|
||||||
|
self.collected_params = [param.clone() for param in model.parameters()]
|
||||||
|
|
||||||
|
def restore(self, model):
|
||||||
|
"""
|
||||||
|
Restore the parameters stored with the `store` method.
|
||||||
|
Useful to validate the model with EMA parameters without affecting the
|
||||||
|
original optimization process. Store the parameters before the
|
||||||
|
`copy_to` method. After validation (or model saving), use this to
|
||||||
|
restore the former parameters.
|
||||||
|
Args:
|
||||||
|
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
|
||||||
|
updated with the stored parameters.
|
||||||
|
"""
|
||||||
|
for c_param, param in zip(self.collected_params, model.parameters()):
|
||||||
|
param.data.copy_(c_param.data)
|
||||||
125
hy3dshape/hy3dshape/utils/misc.py
Normal file
125
hy3dshape/hy3dshape/utils/misc.py
Normal file
@ -0,0 +1,125 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
from omegaconf import OmegaConf, DictConfig, ListConfig
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
|
||||||
|
def get_config_from_file(config_file: str) -> Union[DictConfig, ListConfig]:
|
||||||
|
config_file = OmegaConf.load(config_file)
|
||||||
|
|
||||||
|
if 'base_config' in config_file.keys():
|
||||||
|
if config_file['base_config'] == "default_base":
|
||||||
|
base_config = OmegaConf.create()
|
||||||
|
# base_config = get_default_config()
|
||||||
|
elif config_file['base_config'].endswith(".yaml"):
|
||||||
|
base_config = get_config_from_file(config_file['base_config'])
|
||||||
|
else:
|
||||||
|
raise ValueError(f"{config_file} must be `.yaml` file or it contains `base_config` key.")
|
||||||
|
|
||||||
|
config_file = {key: value for key, value in config_file if key != "base_config"}
|
||||||
|
|
||||||
|
return OmegaConf.merge(base_config, config_file)
|
||||||
|
|
||||||
|
return config_file
|
||||||
|
|
||||||
|
|
||||||
|
def get_obj_from_str(string, reload=False):
|
||||||
|
module, cls = string.rsplit(".", 1)
|
||||||
|
if reload:
|
||||||
|
module_imp = importlib.import_module(module)
|
||||||
|
importlib.reload(module_imp)
|
||||||
|
return getattr(importlib.import_module(module, package=None), cls)
|
||||||
|
|
||||||
|
|
||||||
|
def get_obj_from_config(config):
|
||||||
|
if "target" not in config:
|
||||||
|
raise KeyError("Expected key `target` to instantiate.")
|
||||||
|
|
||||||
|
return get_obj_from_str(config["target"])
|
||||||
|
|
||||||
|
|
||||||
|
def instantiate_from_config(config, **kwargs):
|
||||||
|
if "target" not in config:
|
||||||
|
raise KeyError("Expected key `target` to instantiate.")
|
||||||
|
|
||||||
|
cls = get_obj_from_str(config["target"])
|
||||||
|
|
||||||
|
if config.get("from_pretrained", None):
|
||||||
|
return cls.from_pretrained(
|
||||||
|
config["from_pretrained"],
|
||||||
|
use_safetensors=config.get('use_safetensors', False),
|
||||||
|
variant=config.get('variant', 'fp16'))
|
||||||
|
|
||||||
|
params = config.get("params", dict())
|
||||||
|
# params.update(kwargs)
|
||||||
|
# instance = cls(**params)
|
||||||
|
kwargs.update(params)
|
||||||
|
instance = cls(**kwargs)
|
||||||
|
|
||||||
|
return instance
|
||||||
|
|
||||||
|
|
||||||
|
def disabled_train(self, mode=True):
|
||||||
|
"""Overwrite model.train with this function to make sure train/eval mode
|
||||||
|
does not change anymore."""
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
def instantiate_non_trainable_model(config):
|
||||||
|
model = instantiate_from_config(config)
|
||||||
|
model = model.eval()
|
||||||
|
model.train = disabled_train
|
||||||
|
for param in model.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def is_dist_avail_and_initialized():
|
||||||
|
if not dist.is_available():
|
||||||
|
return False
|
||||||
|
if not dist.is_initialized():
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def get_rank():
|
||||||
|
if not is_dist_avail_and_initialized():
|
||||||
|
return 0
|
||||||
|
return dist.get_rank()
|
||||||
|
|
||||||
|
|
||||||
|
def get_world_size():
|
||||||
|
if not is_dist_avail_and_initialized():
|
||||||
|
return 1
|
||||||
|
return dist.get_world_size()
|
||||||
|
|
||||||
|
|
||||||
|
def all_gather_batch(tensors):
|
||||||
|
"""
|
||||||
|
Performs all_gather operation on the provided tensors.
|
||||||
|
"""
|
||||||
|
# Queue the gathered tensors
|
||||||
|
world_size = get_world_size()
|
||||||
|
# There is no need for reduction in the single-proc case
|
||||||
|
if world_size == 1:
|
||||||
|
return tensors
|
||||||
|
tensor_list = []
|
||||||
|
output_tensor = []
|
||||||
|
for tensor in tensors:
|
||||||
|
tensor_all = [torch.ones_like(tensor) for _ in range(world_size)]
|
||||||
|
dist.all_gather(
|
||||||
|
tensor_all,
|
||||||
|
tensor,
|
||||||
|
async_op=False # performance opt
|
||||||
|
)
|
||||||
|
|
||||||
|
tensor_list.append(tensor_all)
|
||||||
|
|
||||||
|
for tensor_all in tensor_list:
|
||||||
|
output_tensor.append(torch.cat(tensor_all, dim=0))
|
||||||
|
return output_tensor
|
||||||
1
hy3dshape/hy3dshape/utils/trainings/__init__.py
Normal file
1
hy3dshape/hy3dshape/utils/trainings/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
213
hy3dshape/hy3dshape/utils/trainings/callback.py
Normal file
213
hy3dshape/hy3dshape/utils/trainings/callback.py
Normal file
@ -0,0 +1,213 @@
|
|||||||
|
# ------------------------------------------------------------------------------------
|
||||||
|
# Modified from Taming Transformers (https://github.com/CompVis/taming-transformers)
|
||||||
|
# Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved.
|
||||||
|
# ------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import wandb
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
from pathlib import Path
|
||||||
|
from omegaconf import OmegaConf, DictConfig
|
||||||
|
from typing import Tuple, Generic, Dict, Callable, Optional, Any
|
||||||
|
from pprint import pprint
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torchvision
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
import pytorch_lightning.loggers
|
||||||
|
from pytorch_lightning.loggers import WandbLogger
|
||||||
|
from pytorch_lightning.loggers.logger import DummyLogger
|
||||||
|
from pytorch_lightning.utilities import rank_zero_only, rank_zero_info
|
||||||
|
from pytorch_lightning.callbacks import Callback
|
||||||
|
|
||||||
|
from functools import wraps
|
||||||
|
|
||||||
|
def node_zero_only(fn: Callable) -> Callable:
|
||||||
|
@wraps(fn)
|
||||||
|
def wrapped_fn(*args, **kwargs) -> Optional[Any]:
|
||||||
|
if node_zero_only.node == 0:
|
||||||
|
return fn(*args, **kwargs)
|
||||||
|
return None
|
||||||
|
return wrapped_fn
|
||||||
|
|
||||||
|
node_zero_only.node = getattr(node_zero_only, 'node', int(os.environ.get('NODE_RANK', 0)))
|
||||||
|
|
||||||
|
def node_zero_experiment(fn: Callable) -> Callable:
|
||||||
|
"""Returns the real experiment on rank 0 and otherwise the DummyExperiment."""
|
||||||
|
@wraps(fn)
|
||||||
|
def experiment(self):
|
||||||
|
@node_zero_only
|
||||||
|
def get_experiment():
|
||||||
|
return fn(self)
|
||||||
|
return get_experiment() or DummyLogger.experiment
|
||||||
|
return experiment
|
||||||
|
|
||||||
|
# customize wandb for node 0 only
|
||||||
|
class MyWandbLogger(WandbLogger):
|
||||||
|
@WandbLogger.experiment.getter
|
||||||
|
@node_zero_experiment
|
||||||
|
def experiment(self):
|
||||||
|
return super().experiment
|
||||||
|
|
||||||
|
class SetupCallback(Callback):
|
||||||
|
def __init__(self, config: DictConfig, exp_config: DictConfig,
|
||||||
|
basedir: Path, logdir: str = "log", ckptdir: str = "ckpt") -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.logdir = basedir / logdir
|
||||||
|
self.ckptdir = basedir / ckptdir
|
||||||
|
self.config = config
|
||||||
|
self.exp_config = exp_config
|
||||||
|
|
||||||
|
# def on_pretrain_routine_start(self, trainer: pl.trainer.Trainer, pl_module: pl.LightningModule) -> None:
|
||||||
|
# if trainer.global_rank == 0:
|
||||||
|
# # Create logdirs and save configs
|
||||||
|
# os.makedirs(self.logdir, exist_ok=True)
|
||||||
|
# os.makedirs(self.ckptdir, exist_ok=True)
|
||||||
|
#
|
||||||
|
# print("Experiment config")
|
||||||
|
# print(self.exp_config.pretty())
|
||||||
|
#
|
||||||
|
# print("Model config")
|
||||||
|
# print(self.config.pretty())
|
||||||
|
|
||||||
|
def on_fit_start(self, trainer: pl.trainer.Trainer, pl_module: pl.LightningModule) -> None:
|
||||||
|
if trainer.global_rank == 0:
|
||||||
|
# Create logdirs and save configs
|
||||||
|
os.makedirs(self.logdir, exist_ok=True)
|
||||||
|
os.makedirs(self.ckptdir, exist_ok=True)
|
||||||
|
|
||||||
|
# print("Experiment config")
|
||||||
|
# pprint(self.exp_config)
|
||||||
|
#
|
||||||
|
# print("Model config")
|
||||||
|
# pprint(self.config)
|
||||||
|
|
||||||
|
|
||||||
|
class ImageLogger(Callback):
|
||||||
|
def __init__(self, batch_frequency: int, max_images: int, clamp: bool = True,
|
||||||
|
increase_log_steps: bool = True) -> None:
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
self.batch_freq = batch_frequency
|
||||||
|
self.max_images = max_images
|
||||||
|
self.logger_log_images = {
|
||||||
|
pl.loggers.WandbLogger: self._wandb,
|
||||||
|
pl.loggers.TestTubeLogger: self._testtube,
|
||||||
|
}
|
||||||
|
self.log_steps = [2 ** n for n in range(int(np.log2(self.batch_freq)) + 1)]
|
||||||
|
if not increase_log_steps:
|
||||||
|
self.log_steps = [self.batch_freq]
|
||||||
|
self.clamp = clamp
|
||||||
|
|
||||||
|
@rank_zero_only
|
||||||
|
def _wandb(self, pl_module, images, batch_idx, split):
|
||||||
|
# raise ValueError("No way wandb")
|
||||||
|
grids = dict()
|
||||||
|
for k in images:
|
||||||
|
grid = torchvision.utils.make_grid(images[k])
|
||||||
|
grids[f"{split}/{k}"] = wandb.Image(grid)
|
||||||
|
pl_module.logger.experiment.log(grids)
|
||||||
|
|
||||||
|
@rank_zero_only
|
||||||
|
def _testtube(self, pl_module, images, batch_idx, split):
|
||||||
|
for k in images:
|
||||||
|
grid = torchvision.utils.make_grid(images[k])
|
||||||
|
grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
|
||||||
|
|
||||||
|
tag = f"{split}/{k}"
|
||||||
|
pl_module.logger.experiment.add_image(
|
||||||
|
tag, grid,
|
||||||
|
global_step=pl_module.global_step)
|
||||||
|
|
||||||
|
@rank_zero_only
|
||||||
|
def log_local(self, save_dir: str, split: str, images: Dict,
|
||||||
|
global_step: int, current_epoch: int, batch_idx: int) -> None:
|
||||||
|
root = os.path.join(save_dir, "results", split)
|
||||||
|
os.makedirs(root, exist_ok=True)
|
||||||
|
for k in images:
|
||||||
|
grid = torchvision.utils.make_grid(images[k], nrow=4)
|
||||||
|
|
||||||
|
grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
|
||||||
|
grid = grid.numpy()
|
||||||
|
grid = (grid * 255).astype(np.uint8)
|
||||||
|
filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
|
||||||
|
k,
|
||||||
|
global_step,
|
||||||
|
current_epoch,
|
||||||
|
batch_idx)
|
||||||
|
path = os.path.join(root, filename)
|
||||||
|
os.makedirs(os.path.split(path)[0], exist_ok=True)
|
||||||
|
Image.fromarray(grid).save(path)
|
||||||
|
|
||||||
|
def log_img(self, pl_module: pl.LightningModule, batch: Tuple[torch.LongTensor, torch.FloatTensor], batch_idx: int,
|
||||||
|
split: str = "train") -> None:
|
||||||
|
if (self.check_frequency(batch_idx) and # batch_idx % self.batch_freq == 0
|
||||||
|
hasattr(pl_module, "log_images") and
|
||||||
|
callable(pl_module.log_images) and
|
||||||
|
self.max_images > 0):
|
||||||
|
logger = type(pl_module.logger)
|
||||||
|
|
||||||
|
is_train = pl_module.training
|
||||||
|
if is_train:
|
||||||
|
pl_module.eval()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
images = pl_module.log_images(batch, split=split, pl_module=pl_module)
|
||||||
|
|
||||||
|
for k in images:
|
||||||
|
N = min(images[k].shape[0], self.max_images)
|
||||||
|
images[k] = images[k][:N].detach().cpu()
|
||||||
|
if self.clamp:
|
||||||
|
images[k] = images[k].clamp(0, 1)
|
||||||
|
|
||||||
|
self.log_local(pl_module.logger.save_dir, split, images,
|
||||||
|
pl_module.global_step, pl_module.current_epoch, batch_idx)
|
||||||
|
|
||||||
|
logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None)
|
||||||
|
logger_log_images(pl_module, images, pl_module.global_step, split)
|
||||||
|
|
||||||
|
if is_train:
|
||||||
|
pl_module.train()
|
||||||
|
|
||||||
|
def check_frequency(self, batch_idx: int) -> bool:
|
||||||
|
if (batch_idx % self.batch_freq) == 0 or (batch_idx in self.log_steps):
|
||||||
|
try:
|
||||||
|
self.log_steps.pop(0)
|
||||||
|
except IndexError:
|
||||||
|
pass
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def on_train_batch_end(self, trainer: pl.trainer.Trainer, pl_module: pl.LightningModule,
|
||||||
|
outputs: Generic, batch: Tuple[torch.LongTensor, torch.FloatTensor], batch_idx: int) -> None:
|
||||||
|
self.log_img(pl_module, batch, batch_idx, split="train")
|
||||||
|
|
||||||
|
def on_validation_batch_end(self, trainer: pl.trainer.Trainer, pl_module: pl.LightningModule,
|
||||||
|
outputs: Generic, batch: Tuple[torch.LongTensor, torch.FloatTensor],
|
||||||
|
dataloader_idx: int, batch_idx: int) -> None:
|
||||||
|
self.log_img(pl_module, batch, batch_idx, split="val")
|
||||||
|
|
||||||
|
|
||||||
|
class CUDACallback(Callback):
|
||||||
|
# see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py
|
||||||
|
def on_train_epoch_start(self, trainer, pl_module):
|
||||||
|
# Reset the memory use counter
|
||||||
|
torch.cuda.reset_peak_memory_stats(trainer.root_gpu)
|
||||||
|
torch.cuda.synchronize(trainer.root_gpu)
|
||||||
|
self.start_time = time.time()
|
||||||
|
|
||||||
|
def on_train_epoch_end(self, trainer, pl_module, outputs):
|
||||||
|
torch.cuda.synchronize(trainer.root_gpu)
|
||||||
|
max_memory = torch.cuda.max_memory_allocated(trainer.root_gpu) / 2 ** 20
|
||||||
|
epoch_time = time.time() - self.start_time
|
||||||
|
|
||||||
|
try:
|
||||||
|
max_memory = trainer.training_type_plugin.reduce(max_memory)
|
||||||
|
epoch_time = trainer.training_type_plugin.reduce(epoch_time)
|
||||||
|
|
||||||
|
rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
|
||||||
|
rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB")
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
53
hy3dshape/hy3dshape/utils/trainings/lr_scheduler.py
Normal file
53
hy3dshape/hy3dshape/utils/trainings/lr_scheduler.py
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
||||||
|
# except for the third-party components listed below.
|
||||||
|
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
||||||
|
# in the repsective licenses of these third-party components.
|
||||||
|
# Users must comply with all terms and conditions of original licenses of these third-party
|
||||||
|
# components and must ensure that the usage of the third party components adheres to
|
||||||
|
# all relevant laws and regulations.
|
||||||
|
|
||||||
|
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
||||||
|
# their software and algorithms, including trained model weights, parameters (including
|
||||||
|
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
||||||
|
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
||||||
|
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class BaseScheduler(object):
|
||||||
|
|
||||||
|
def schedule(self, n, **kwargs):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class LambdaWarmUpCosineFactorScheduler(BaseScheduler):
|
||||||
|
"""
|
||||||
|
note: use with a base_lr of 1.0
|
||||||
|
"""
|
||||||
|
def __init__(self, warm_up_steps, f_min, f_max, f_start, max_decay_steps, verbosity_interval=0, **ignore_kwargs):
|
||||||
|
self.lr_warm_up_steps = warm_up_steps
|
||||||
|
self.f_start = f_start
|
||||||
|
self.f_min = f_min
|
||||||
|
self.f_max = f_max
|
||||||
|
self.lr_max_decay_steps = max_decay_steps
|
||||||
|
self.last_f = 0.
|
||||||
|
self.verbosity_interval = verbosity_interval
|
||||||
|
|
||||||
|
def schedule(self, n, **kwargs):
|
||||||
|
if self.verbosity_interval > 0:
|
||||||
|
if n % self.verbosity_interval == 0:
|
||||||
|
print(f"current step: {n}, recent lr-multiplier: {self.f_start}")
|
||||||
|
if n < self.lr_warm_up_steps:
|
||||||
|
f = (self.f_max - self.f_start) / self.lr_warm_up_steps * n + self.f_start
|
||||||
|
self.last_f = f
|
||||||
|
return f
|
||||||
|
else:
|
||||||
|
t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
|
||||||
|
t = min(t, 1.0)
|
||||||
|
f = self.f_min + 0.5 * (self.f_max - self.f_min) * (1 + np.cos(t * np.pi))
|
||||||
|
self.last_f = f
|
||||||
|
return f
|
||||||
|
|
||||||
|
def __call__(self, n, **kwargs):
|
||||||
|
return self.schedule(n, **kwargs)
|
||||||
128
hy3dshape/hy3dshape/utils/trainings/mesh.py
Normal file
128
hy3dshape/hy3dshape/utils/trainings/mesh.py
Normal file
@ -0,0 +1,128 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
||||||
|
# except for the third-party components listed below.
|
||||||
|
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
||||||
|
# in the repsective licenses of these third-party components.
|
||||||
|
# Users must comply with all terms and conditions of original licenses of these third-party
|
||||||
|
# components and must ensure that the usage of the third party components adheres to
|
||||||
|
# all relevant laws and regulations.
|
||||||
|
|
||||||
|
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
||||||
|
# their software and algorithms, including trained model weights, parameters (including
|
||||||
|
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
||||||
|
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
||||||
|
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import PIL.Image
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import trimesh
|
||||||
|
|
||||||
|
|
||||||
|
def save_obj(pointnp_px3, facenp_fx3, fname):
|
||||||
|
fid = open(fname, "w")
|
||||||
|
write_str = ""
|
||||||
|
for pidx, p in enumerate(pointnp_px3):
|
||||||
|
pp = p
|
||||||
|
write_str += "v %f %f %f\n" % (pp[0], pp[1], pp[2])
|
||||||
|
|
||||||
|
for i, f in enumerate(facenp_fx3):
|
||||||
|
f1 = f + 1
|
||||||
|
write_str += "f %d %d %d\n" % (f1[0], f1[1], f1[2])
|
||||||
|
fid.write(write_str)
|
||||||
|
fid.close()
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
def savemeshtes2(pointnp_px3, tcoords_px2, facenp_fx3, facetex_fx3, tex_map, fname):
|
||||||
|
fol, na = os.path.split(fname)
|
||||||
|
na, _ = os.path.splitext(na)
|
||||||
|
|
||||||
|
matname = "%s/%s.mtl" % (fol, na)
|
||||||
|
fid = open(matname, "w")
|
||||||
|
fid.write("newmtl material_0\n")
|
||||||
|
fid.write("Kd 1 1 1\n")
|
||||||
|
fid.write("Ka 0 0 0\n")
|
||||||
|
fid.write("Ks 0.4 0.4 0.4\n")
|
||||||
|
fid.write("Ns 10\n")
|
||||||
|
fid.write("illum 2\n")
|
||||||
|
fid.write("map_Kd %s.png\n" % na)
|
||||||
|
fid.close()
|
||||||
|
####
|
||||||
|
|
||||||
|
fid = open(fname, "w")
|
||||||
|
fid.write("mtllib %s.mtl\n" % na)
|
||||||
|
|
||||||
|
for pidx, p3 in enumerate(pointnp_px3):
|
||||||
|
pp = p3
|
||||||
|
fid.write("v %f %f %f\n" % (pp[0], pp[1], pp[2]))
|
||||||
|
|
||||||
|
for pidx, p2 in enumerate(tcoords_px2):
|
||||||
|
pp = p2
|
||||||
|
fid.write("vt %f %f\n" % (pp[0], pp[1]))
|
||||||
|
|
||||||
|
fid.write("usemtl material_0\n")
|
||||||
|
for i, f in enumerate(facenp_fx3):
|
||||||
|
f1 = f + 1
|
||||||
|
f2 = facetex_fx3[i] + 1
|
||||||
|
fid.write("f %d/%d %d/%d %d/%d\n" % (f1[0], f2[0], f1[1], f2[1], f1[2], f2[2]))
|
||||||
|
fid.close()
|
||||||
|
|
||||||
|
PIL.Image.fromarray(np.ascontiguousarray(tex_map), "RGB").save(
|
||||||
|
os.path.join(fol, "%s.png" % na))
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
class MeshOutput(object):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
mesh_v: np.ndarray,
|
||||||
|
mesh_f: np.ndarray,
|
||||||
|
vertex_colors: Optional[np.ndarray] = None,
|
||||||
|
uvs: Optional[np.ndarray] = None,
|
||||||
|
mesh_tex_idx: Optional[np.ndarray] = None,
|
||||||
|
tex_map: Optional[np.ndarray] = None):
|
||||||
|
|
||||||
|
self.mesh_v = mesh_v
|
||||||
|
self.mesh_f = mesh_f
|
||||||
|
self.vertex_colors = vertex_colors
|
||||||
|
self.uvs = uvs
|
||||||
|
self.mesh_tex_idx = mesh_tex_idx
|
||||||
|
self.tex_map = tex_map
|
||||||
|
|
||||||
|
def contain_uv_texture(self):
|
||||||
|
return (self.uvs is not None) and (self.mesh_tex_idx is not None) and (self.tex_map is not None)
|
||||||
|
|
||||||
|
def contain_vertex_colors(self):
|
||||||
|
return self.vertex_colors is not None
|
||||||
|
|
||||||
|
def export(self, fname):
|
||||||
|
|
||||||
|
if self.contain_uv_texture():
|
||||||
|
savemeshtes2(
|
||||||
|
self.mesh_v,
|
||||||
|
self.uvs,
|
||||||
|
self.mesh_f,
|
||||||
|
self.mesh_tex_idx,
|
||||||
|
self.tex_map,
|
||||||
|
fname
|
||||||
|
)
|
||||||
|
|
||||||
|
elif self.contain_vertex_colors():
|
||||||
|
mesh_obj = trimesh.Trimesh(vertices=self.mesh_v, faces=self.mesh_f, vertex_colors=self.vertex_colors)
|
||||||
|
mesh_obj.export(fname)
|
||||||
|
|
||||||
|
else:
|
||||||
|
save_obj(
|
||||||
|
self.mesh_v,
|
||||||
|
self.mesh_f,
|
||||||
|
fname
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
336
hy3dshape/hy3dshape/utils/trainings/mesh_log_callback.py
Normal file
336
hy3dshape/hy3dshape/utils/trainings/mesh_log_callback.py
Normal file
@ -0,0 +1,336 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
||||||
|
# except for the third-party components listed below.
|
||||||
|
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
||||||
|
# in the repsective licenses of these third-party components.
|
||||||
|
# Users must comply with all terms and conditions of original licenses of these third-party
|
||||||
|
# components and must ensure that the usage of the third party components adheres to
|
||||||
|
# all relevant laws and regulations.
|
||||||
|
|
||||||
|
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
||||||
|
# their software and algorithms, including trained model weights, parameters (including
|
||||||
|
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
||||||
|
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
||||||
|
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
||||||
|
|
||||||
|
import json
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
from typing import Tuple, Generic, Dict, List, Union, Optional
|
||||||
|
|
||||||
|
import trimesh
|
||||||
|
import numpy as np
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
import pytorch_lightning.loggers
|
||||||
|
import torch
|
||||||
|
import torchvision
|
||||||
|
from pytorch_lightning.callbacks import Callback
|
||||||
|
from pytorch_lightning.utilities import rank_zero_only
|
||||||
|
|
||||||
|
from hy3dshape.pipelines import export_to_trimesh
|
||||||
|
from hy3dshape.utils.trainings.mesh import MeshOutput
|
||||||
|
from hy3dshape.utils.visualizers import html_util
|
||||||
|
from hy3dshape.utils.visualizers.pythreejs_viewer import PyThreeJSViewer
|
||||||
|
|
||||||
|
|
||||||
|
class ImageConditionalASLDiffuserLogger(Callback):
|
||||||
|
def __init__(self,
|
||||||
|
step_frequency: int,
|
||||||
|
num_samples: int = 1,
|
||||||
|
mean: Optional[Union[List[float], Tuple[float]]] = None,
|
||||||
|
std: Optional[Union[List[float], Tuple[float]]] = None,
|
||||||
|
bounds: Union[List[float], Tuple[float]] = (-1.1, -1.1, -1.1, 1.1, 1.1, 1.1),
|
||||||
|
**kwargs) -> None:
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
self.bbox_size = np.array(bounds[3:6]) - np.array(bounds[0:3])
|
||||||
|
|
||||||
|
if mean is not None:
|
||||||
|
mean = np.asarray(mean)
|
||||||
|
|
||||||
|
if std is not None:
|
||||||
|
std = np.asarray(std)
|
||||||
|
|
||||||
|
self.mean = mean
|
||||||
|
self.std = std
|
||||||
|
|
||||||
|
self.step_freq = step_frequency
|
||||||
|
self.num_samples = num_samples
|
||||||
|
self.has_train_logged = False
|
||||||
|
self.logger_log_images = {
|
||||||
|
pl.loggers.WandbLogger: self._wandb,
|
||||||
|
}
|
||||||
|
|
||||||
|
self.viewer = PyThreeJSViewer(settings={}, render_mode="WEBSITE")
|
||||||
|
|
||||||
|
@rank_zero_only
|
||||||
|
def _wandb(self, pl_module, images, batch_idx, split):
|
||||||
|
# raise ValueError("No way wandb")
|
||||||
|
grids = dict()
|
||||||
|
for k in images:
|
||||||
|
grid = torchvision.utils.make_grid(images[k])
|
||||||
|
grids[f"{split}/{k}"] = wandb.Image(grid)
|
||||||
|
pl_module.logger.experiment.log(grids)
|
||||||
|
|
||||||
|
def log_local(self,
|
||||||
|
outputs: List[List['Latent2MeshOutput']],
|
||||||
|
images: Union[np.ndarray, List[np.ndarray]],
|
||||||
|
description: List[str],
|
||||||
|
keys: List[str],
|
||||||
|
save_dir: str, split: str,
|
||||||
|
global_step: int, current_epoch: int, batch_idx: int,
|
||||||
|
prog_bar: bool = False,
|
||||||
|
multi_views=None, # yf ...
|
||||||
|
) -> None:
|
||||||
|
|
||||||
|
folder = "gs-{:010}_e-{:06}_b-{:06}".format(global_step, current_epoch, batch_idx)
|
||||||
|
visual_dir = os.path.join(save_dir, "visuals", split, folder)
|
||||||
|
os.makedirs(visual_dir, exist_ok=True)
|
||||||
|
|
||||||
|
num_samples = len(images)
|
||||||
|
|
||||||
|
for i in range(num_samples):
|
||||||
|
key_i = keys[i]
|
||||||
|
image_i = self.denormalize_image(images[i])
|
||||||
|
shape_tag_i = description[i]
|
||||||
|
|
||||||
|
for j in range(1):
|
||||||
|
mesh = outputs[j][i]
|
||||||
|
if mesh is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
mesh_v = mesh.mesh_v.copy()
|
||||||
|
mesh_v[:, 0] += j * np.max(self.bbox_size)
|
||||||
|
self.viewer.add_mesh(mesh_v, mesh.mesh_f)
|
||||||
|
|
||||||
|
image_tag = html_util.to_image_embed_tag(image_i)
|
||||||
|
mesh_tag = self.viewer.to_html(html_frame=False)
|
||||||
|
|
||||||
|
table_tag = f"""
|
||||||
|
<table border = "1">
|
||||||
|
<caption> {shape_tag_i} - {key_i} </caption>
|
||||||
|
<caption> Input Image | Generated Mesh </caption>
|
||||||
|
<tr>
|
||||||
|
<td>{image_tag}</td>
|
||||||
|
<td>{mesh_tag}</td>
|
||||||
|
</tr>
|
||||||
|
</table>
|
||||||
|
"""
|
||||||
|
|
||||||
|
if multi_views is not None:
|
||||||
|
multi_views_i = self.make_grid(multi_views[i])
|
||||||
|
views_tag = html_util.to_image_embed_tag(self.denormalize_image(multi_views_i))
|
||||||
|
table_tag = f"""
|
||||||
|
<table border = "1">
|
||||||
|
<caption> {shape_tag_i} - {key_i} </caption>
|
||||||
|
<caption> Input Image | Generated Mesh </caption>
|
||||||
|
<tr>
|
||||||
|
<td>{image_tag}</td>
|
||||||
|
<td>{views_tag}</td>
|
||||||
|
<td>{mesh_tag}</td>
|
||||||
|
</tr>
|
||||||
|
</table>
|
||||||
|
"""
|
||||||
|
|
||||||
|
html_frame = html_util.to_html_frame(table_tag)
|
||||||
|
if len(key_i) > 100:
|
||||||
|
key_i = key_i[:100]
|
||||||
|
with open(os.path.join(visual_dir, f"{key_i}.html"), "w") as writer:
|
||||||
|
writer.write(html_frame)
|
||||||
|
|
||||||
|
self.viewer.reset()
|
||||||
|
|
||||||
|
def log_sample(self,
|
||||||
|
pl_module: pl.LightningModule,
|
||||||
|
batch: Dict[str, torch.FloatTensor],
|
||||||
|
batch_idx: int,
|
||||||
|
split: str = "train") -> None:
|
||||||
|
"""
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pl_module:
|
||||||
|
batch (dict): the batch sample information, and it contains:
|
||||||
|
- surface (torch.FloatTensor):
|
||||||
|
- image (torch.FloatTensor):
|
||||||
|
batch_idx (int):
|
||||||
|
split (str):
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
is_train = pl_module.training
|
||||||
|
if is_train:
|
||||||
|
pl_module.eval()
|
||||||
|
|
||||||
|
batch_size = len(batch["surface"])
|
||||||
|
replace = batch_size < self.num_samples
|
||||||
|
ids = np.random.choice(batch_size, self.num_samples, replace=replace)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
# run text to mesh
|
||||||
|
# keys = [batch["__key__"][i] for i in ids]
|
||||||
|
keys = [f'key_{i}' for i in ids]
|
||||||
|
# texts = [batch["text"][i] for i in ids]
|
||||||
|
texts = [f'text_{i}'for i in ids]
|
||||||
|
# description = [batch["description"][i] for i in ids]
|
||||||
|
description = [f'desc_{i}' for i in ids]
|
||||||
|
images = batch["image"][ids]
|
||||||
|
mask_input = batch["mask"][ids] if 'mask' in batch else None
|
||||||
|
sample_batch = {
|
||||||
|
"__key__": keys,
|
||||||
|
"image": images,
|
||||||
|
'text': texts,
|
||||||
|
'mask': mask_input,
|
||||||
|
}
|
||||||
|
|
||||||
|
# if 'cam_parm' in batch:
|
||||||
|
# sample_batch['cam_parm'] = batch['cam_parm'][ids]
|
||||||
|
|
||||||
|
# if 'multi_views' in batch: # yf ...
|
||||||
|
# sample_batch['multi_views'] = batch['multi_views'][ids]
|
||||||
|
|
||||||
|
outputs = pl_module.sample(
|
||||||
|
batch=sample_batch,
|
||||||
|
output_type='latents2mesh'
|
||||||
|
)
|
||||||
|
|
||||||
|
images = images.cpu().float().numpy()
|
||||||
|
# images = self.denormalize_image(images)
|
||||||
|
# images = np.transpose(images, (0, 2, 3, 1))
|
||||||
|
# images = ((images + 1) / 2 * 255).astype(np.uint8)
|
||||||
|
|
||||||
|
self.log_local(outputs, images, description, keys, pl_module.logger.save_dir, split,
|
||||||
|
pl_module.global_step, pl_module.current_epoch, batch_idx, prog_bar=False,
|
||||||
|
multi_views=sample_batch.get('multi_views'))
|
||||||
|
|
||||||
|
if is_train: pl_module.train()
|
||||||
|
|
||||||
|
def make_grid(self, images): # return (3,h,w) in (0,1) ...
|
||||||
|
images_resized = []
|
||||||
|
for img in images:
|
||||||
|
img_resized = torchvision.transforms.functional.resize(img, (320, 320))
|
||||||
|
images_resized.append(img_resized)
|
||||||
|
image = torchvision.utils.make_grid(images_resized, nrow=2, padding=5, pad_value=255)
|
||||||
|
|
||||||
|
image = image.cpu().numpy()
|
||||||
|
# image = np.transpose(image, (1, 2, 0))
|
||||||
|
# image = (image * 255).astype(np.uint8)
|
||||||
|
|
||||||
|
return image
|
||||||
|
|
||||||
|
def check_frequency(self, step: int) -> bool:
|
||||||
|
if step % self.step_freq == 0:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def on_train_batch_end(self, trainer: pl.trainer.Trainer, pl_module: pl.LightningModule,
|
||||||
|
outputs: Generic, batch: Dict[str, torch.FloatTensor], batch_idx: int) -> None:
|
||||||
|
|
||||||
|
if (self.check_frequency(pl_module.global_step) and # batch_idx % self.batch_freq == 0
|
||||||
|
hasattr(pl_module, "sample") and
|
||||||
|
callable(pl_module.sample) and
|
||||||
|
self.num_samples > 0):
|
||||||
|
self.log_sample(pl_module, batch, batch_idx, split="train")
|
||||||
|
self.has_train_logged = True
|
||||||
|
|
||||||
|
def on_validation_batch_end(self, trainer: pl.trainer.Trainer, pl_module: pl.LightningModule,
|
||||||
|
outputs: Generic, batch: Dict[str, torch.FloatTensor],
|
||||||
|
dataloader_idx: int, batch_idx: int) -> None:
|
||||||
|
|
||||||
|
if self.has_train_logged:
|
||||||
|
self.log_sample(pl_module, batch, batch_idx, split="val")
|
||||||
|
self.has_train_logged = False
|
||||||
|
|
||||||
|
def denormalize_image(self, image):
|
||||||
|
"""
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (np.ndarray): [3, h, w]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
image (np.ndarray): [h, w, 3], np.uint8, [0, 255].
|
||||||
|
"""
|
||||||
|
# image = np.transpose(image, (0, 2, 3, 1))
|
||||||
|
image = np.transpose(image, (1, 2, 0))
|
||||||
|
|
||||||
|
if self.std is not None:
|
||||||
|
image = image * self.std
|
||||||
|
|
||||||
|
if self.mean is not None:
|
||||||
|
image = image + self.mean
|
||||||
|
|
||||||
|
image = (image * 255).astype(np.uint8)
|
||||||
|
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
class ImageConditionalFixASLDiffuserLogger(Callback):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
step_frequency: int,
|
||||||
|
test_data_path: str,
|
||||||
|
max_size: int = None,
|
||||||
|
save_dir: str = 'infer',
|
||||||
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.step_freq = step_frequency
|
||||||
|
self.viewer = PyThreeJSViewer(settings={}, render_mode="WEBSITE")
|
||||||
|
|
||||||
|
self.test_data_path = test_data_path
|
||||||
|
with open(self.test_data_path, 'r') as f:
|
||||||
|
data = json.load(f)
|
||||||
|
self.file_list = data['file_list']
|
||||||
|
self.file_folder = data['file_folder']
|
||||||
|
if max_size is not None:
|
||||||
|
self.file_list = self.file_list[:max_size]
|
||||||
|
self.kwargs = kwargs
|
||||||
|
self.save_dir = save_dir
|
||||||
|
|
||||||
|
def on_train_batch_end(
|
||||||
|
self,
|
||||||
|
trainer: pl.trainer.Trainer,
|
||||||
|
pl_module: pl.LightningModule,
|
||||||
|
outputs: Generic,
|
||||||
|
batch: Dict[str, torch.FloatTensor],
|
||||||
|
batch_idx: int,
|
||||||
|
):
|
||||||
|
if pl_module.global_step % self.step_freq == 0:
|
||||||
|
is_train = pl_module.training
|
||||||
|
if is_train:
|
||||||
|
pl_module.eval()
|
||||||
|
|
||||||
|
folder_path = self.file_folder
|
||||||
|
folder_name = os.path.basename(folder_path)
|
||||||
|
folder = "gs-{:010}_e-{:06}_b-{:06}".format(pl_module.global_step, pl_module.current_epoch, batch_idx)
|
||||||
|
visual_dir = os.path.join(pl_module.logger.save_dir, self.save_dir, folder, folder_name)
|
||||||
|
os.makedirs(visual_dir, exist_ok=True)
|
||||||
|
|
||||||
|
image_paths = self.file_list
|
||||||
|
chunk_size = math.ceil(len(image_paths) / trainer.world_size)
|
||||||
|
if pl_module.global_rank == trainer.world_size - 1:
|
||||||
|
image_paths = image_paths[pl_module.global_rank * chunk_size:]
|
||||||
|
else:
|
||||||
|
image_paths = image_paths[pl_module.global_rank * chunk_size:(pl_module.global_rank + 1) * chunk_size]
|
||||||
|
|
||||||
|
print(f'Rank{pl_module.global_rank}: processing {len(image_paths)}|{len(self.file_list)} images')
|
||||||
|
for image_path in image_paths:
|
||||||
|
if folder_path in image_path:
|
||||||
|
save_path = image_path.replace(folder_path, visual_dir)
|
||||||
|
else:
|
||||||
|
save_path = os.path.join(visual_dir, os.path.basename(image_path))
|
||||||
|
save_path = os.path.splitext(save_path)[0] + '.glb'
|
||||||
|
|
||||||
|
print(image_path)
|
||||||
|
with torch.no_grad():
|
||||||
|
mesh = pl_module.sample(batch={"image": image_path}, **self.kwargs)[0][0]
|
||||||
|
if isinstance(mesh, tuple) and len(mesh)==2:
|
||||||
|
mesh = export_to_trimesh(mesh)
|
||||||
|
elif isinstance(mesh, trimesh.Trimesh):
|
||||||
|
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
||||||
|
mesh.export(save_path)
|
||||||
|
|
||||||
|
if is_train:
|
||||||
|
pl_module.train()
|
||||||
78
hy3dshape/hy3dshape/utils/trainings/peft.py
Normal file
78
hy3dshape/hy3dshape/utils/trainings/peft.py
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
||||||
|
# except for the third-party components listed below.
|
||||||
|
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
||||||
|
# in the repsective licenses of these third-party components.
|
||||||
|
# Users must comply with all terms and conditions of original licenses of these third-party
|
||||||
|
# components and must ensure that the usage of the third party components adheres to
|
||||||
|
# all relevant laws and regulations.
|
||||||
|
|
||||||
|
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
||||||
|
# their software and algorithms, including trained model weights, parameters (including
|
||||||
|
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
||||||
|
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
||||||
|
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
||||||
|
|
||||||
|
import os
|
||||||
|
from pytorch_lightning.callbacks import Callback
|
||||||
|
from omegaconf import OmegaConf, ListConfig
|
||||||
|
|
||||||
|
class PeftSaveCallback(Callback):
|
||||||
|
def __init__(self, peft_model, save_dir: str, save_every_n_steps: int = None):
|
||||||
|
super().__init__()
|
||||||
|
self.peft_model = peft_model
|
||||||
|
self.save_dir = save_dir
|
||||||
|
self.save_every_n_steps = save_every_n_steps
|
||||||
|
os.makedirs(self.save_dir, exist_ok=True)
|
||||||
|
|
||||||
|
def recursive_convert(self, obj):
|
||||||
|
from omegaconf import OmegaConf, ListConfig
|
||||||
|
if isinstance(obj, (OmegaConf, ListConfig)):
|
||||||
|
return OmegaConf.to_container(obj, resolve=True)
|
||||||
|
elif isinstance(obj, dict):
|
||||||
|
return {k: self.recursive_convert(v) for k, v in obj.items()}
|
||||||
|
elif isinstance(obj, list):
|
||||||
|
return [self.recursive_convert(i) for i in obj]
|
||||||
|
elif isinstance(obj, type):
|
||||||
|
# 避免修改类对象
|
||||||
|
return obj
|
||||||
|
elif hasattr(obj, '__dict__'):
|
||||||
|
for attr_name, attr_value in vars(obj).items():
|
||||||
|
setattr(obj, attr_name, self.recursive_convert(attr_value))
|
||||||
|
return obj
|
||||||
|
else:
|
||||||
|
return obj
|
||||||
|
|
||||||
|
# def recursive_convert(self, obj):
|
||||||
|
# if isinstance(obj, (OmegaConf, ListConfig)):
|
||||||
|
# return OmegaConf.to_container(obj, resolve=True)
|
||||||
|
# elif isinstance(obj, dict):
|
||||||
|
# return {k: self.recursive_convert(v) for k, v in obj.items()}
|
||||||
|
# elif isinstance(obj, list):
|
||||||
|
# return [self.recursive_convert(i) for i in obj]
|
||||||
|
# elif hasattr(obj, '__dict__'):
|
||||||
|
# for attr_name, attr_value in vars(obj).items():
|
||||||
|
# setattr(obj, attr_name, self.recursive_convert(attr_value))
|
||||||
|
# return obj
|
||||||
|
# else:
|
||||||
|
# return obj
|
||||||
|
|
||||||
|
def _convert_peft_config(self):
|
||||||
|
pc = self.peft_model.peft_config
|
||||||
|
self.peft_model.peft_config = self.recursive_convert(pc)
|
||||||
|
|
||||||
|
def on_train_epoch_end(self, trainer, pl_module):
|
||||||
|
self._convert_peft_config()
|
||||||
|
save_path = os.path.join(self.save_dir, f"epoch_{trainer.current_epoch}")
|
||||||
|
self.peft_model.save_pretrained(save_path)
|
||||||
|
print(f"[PeftSaveCallback] Saved LoRA weights to {save_path}")
|
||||||
|
|
||||||
|
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
|
||||||
|
if self.save_every_n_steps is not None:
|
||||||
|
global_step = trainer.global_step
|
||||||
|
if global_step % self.save_every_n_steps == 0 and global_step > 0:
|
||||||
|
self._convert_peft_config()
|
||||||
|
save_path = os.path.join(self.save_dir, f"step_{global_step}")
|
||||||
|
self.peft_model.save_pretrained(save_path)
|
||||||
|
print(f"[PeftSaveCallback] Saved LoRA weights to {save_path}")
|
||||||
128
hy3dshape/hy3dshape/utils/utils.py
Normal file
128
hy3dshape/hy3dshape/utils/utils.py
Normal file
@ -0,0 +1,128 @@
|
|||||||
|
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
||||||
|
# except for the third-party components listed below.
|
||||||
|
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
||||||
|
# in the repsective licenses of these third-party components.
|
||||||
|
# Users must comply with all terms and conditions of original licenses of these third-party
|
||||||
|
# components and must ensure that the usage of the third party components adheres to
|
||||||
|
# all relevant laws and regulations.
|
||||||
|
|
||||||
|
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
||||||
|
# their software and algorithms, including trained model weights, parameters (including
|
||||||
|
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
||||||
|
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
||||||
|
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from functools import wraps
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def get_logger(name):
|
||||||
|
logger = logging.getLogger(name)
|
||||||
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
console_handler = logging.StreamHandler()
|
||||||
|
console_handler.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||||
|
console_handler.setFormatter(formatter)
|
||||||
|
logger.addHandler(console_handler)
|
||||||
|
return logger
|
||||||
|
|
||||||
|
|
||||||
|
logger = get_logger('hy3dgen.shapgen')
|
||||||
|
|
||||||
|
|
||||||
|
class synchronize_timer:
|
||||||
|
""" Synchronized timer to count the inference time of `nn.Module.forward`.
|
||||||
|
|
||||||
|
Supports both context manager and decorator usage.
|
||||||
|
|
||||||
|
Example as context manager:
|
||||||
|
```python
|
||||||
|
with synchronize_timer('name') as t:
|
||||||
|
run()
|
||||||
|
```
|
||||||
|
|
||||||
|
Example as decorator:
|
||||||
|
```python
|
||||||
|
@synchronize_timer('Export to trimesh')
|
||||||
|
def export_to_trimesh(mesh_output):
|
||||||
|
pass
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, name=None):
|
||||||
|
self.name = name
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
"""Context manager entry: start timing."""
|
||||||
|
if os.environ.get('HY3DGEN_DEBUG', '0') == '1':
|
||||||
|
self.start = torch.cuda.Event(enable_timing=True)
|
||||||
|
self.end = torch.cuda.Event(enable_timing=True)
|
||||||
|
self.start.record()
|
||||||
|
return lambda: self.time
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_value, exc_tb):
|
||||||
|
"""Context manager exit: stop timing and log results."""
|
||||||
|
if os.environ.get('HY3DGEN_DEBUG', '0') == '1':
|
||||||
|
self.end.record()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
self.time = self.start.elapsed_time(self.end)
|
||||||
|
if self.name is not None:
|
||||||
|
logger.info(f'{self.name} takes {self.time} ms')
|
||||||
|
|
||||||
|
def __call__(self, func):
|
||||||
|
"""Decorator: wrap the function to time its execution."""
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
with self:
|
||||||
|
result = func(*args, **kwargs)
|
||||||
|
return result
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def smart_load_model(
|
||||||
|
model_path,
|
||||||
|
subfolder,
|
||||||
|
use_safetensors,
|
||||||
|
variant,
|
||||||
|
):
|
||||||
|
original_model_path = model_path
|
||||||
|
# try local path
|
||||||
|
base_dir = os.environ.get('HY3DGEN_MODELS', '~/.cache/hy3dgen')
|
||||||
|
model_fld = os.path.expanduser(os.path.join(base_dir, model_path))
|
||||||
|
model_path = os.path.expanduser(os.path.join(base_dir, model_path, subfolder))
|
||||||
|
logger.info(f'Try to load model from local path: {model_path}')
|
||||||
|
if not os.path.exists(model_path):
|
||||||
|
logger.info('Model path not exists, try to download from huggingface')
|
||||||
|
try:
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
# 只下载指定子目录
|
||||||
|
path = snapshot_download(
|
||||||
|
repo_id=original_model_path,
|
||||||
|
allow_patterns=[f"{subfolder}/*"], # 关键修改:模式匹配子文件夹
|
||||||
|
local_dir=model_fld
|
||||||
|
)
|
||||||
|
model_path = os.path.join(path, subfolder) # 保持路径拼接逻辑不变
|
||||||
|
except ImportError:
|
||||||
|
logger.warning(
|
||||||
|
"You need to install HuggingFace Hub to load models from the hub."
|
||||||
|
)
|
||||||
|
raise RuntimeError(f"Model path {model_path} not found")
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
if not os.path.exists(model_path):
|
||||||
|
raise FileNotFoundError(f"Model path {original_model_path} not found")
|
||||||
|
|
||||||
|
extension = 'ckpt' if not use_safetensors else 'safetensors'
|
||||||
|
variant = '' if variant is None else f'.{variant}'
|
||||||
|
ckpt_name = f'model{variant}.{extension}'
|
||||||
|
config_path = os.path.join(model_path, 'config.yaml')
|
||||||
|
ckpt_path = os.path.join(model_path, ckpt_name)
|
||||||
|
return config_path, ckpt_path
|
||||||
1
hy3dshape/hy3dshape/utils/visualizers/__init__.py
Normal file
1
hy3dshape/hy3dshape/utils/visualizers/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
57
hy3dshape/hy3dshape/utils/visualizers/color_util.py
Normal file
57
hy3dshape/hy3dshape/utils/visualizers/color_util.py
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
||||||
|
# except for the third-party components listed below.
|
||||||
|
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
||||||
|
# in the repsective licenses of these third-party components.
|
||||||
|
# Users must comply with all terms and conditions of original licenses of these third-party
|
||||||
|
# components and must ensure that the usage of the third party components adheres to
|
||||||
|
# all relevant laws and regulations.
|
||||||
|
|
||||||
|
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
||||||
|
# their software and algorithms, including trained model weights, parameters (including
|
||||||
|
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
||||||
|
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
||||||
|
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
|
||||||
|
# Helper functions
|
||||||
|
def get_colors(inp, colormap="viridis", normalize=True, vmin=None, vmax=None):
|
||||||
|
colormap = plt.cm.get_cmap(colormap)
|
||||||
|
if normalize:
|
||||||
|
vmin = np.min(inp)
|
||||||
|
vmax = np.max(inp)
|
||||||
|
|
||||||
|
norm = plt.Normalize(vmin, vmax)
|
||||||
|
return colormap(norm(inp))[:, :3]
|
||||||
|
|
||||||
|
|
||||||
|
def gen_checkers(n_checkers_x, n_checkers_y, width=256, height=256):
|
||||||
|
# tex dims need to be power of two.
|
||||||
|
array = np.ones((width, height, 3), dtype='float32')
|
||||||
|
|
||||||
|
# width in texels of each checker
|
||||||
|
checker_w = width / n_checkers_x
|
||||||
|
checker_h = height / n_checkers_y
|
||||||
|
|
||||||
|
for y in range(height):
|
||||||
|
for x in range(width):
|
||||||
|
color_key = int(x / checker_w) + int(y / checker_h)
|
||||||
|
if color_key % 2 == 0:
|
||||||
|
array[x, y, :] = [1., 0.874, 0.0]
|
||||||
|
else:
|
||||||
|
array[x, y, :] = [0., 0., 0.]
|
||||||
|
return array
|
||||||
|
|
||||||
|
|
||||||
|
def gen_circle(width=256, height=256):
|
||||||
|
xx, yy = np.mgrid[:width, :height]
|
||||||
|
circle = (xx - width / 2 + 0.5) ** 2 + (yy - height / 2 + 0.5) ** 2
|
||||||
|
array = np.ones((width, height, 4), dtype='float32')
|
||||||
|
array[:, :, 0] = (circle <= width)
|
||||||
|
array[:, :, 1] = (circle <= width)
|
||||||
|
array[:, :, 2] = (circle <= width)
|
||||||
|
array[:, :, 3] = circle <= width
|
||||||
|
return array
|
||||||
|
|
||||||
64
hy3dshape/hy3dshape/utils/visualizers/html_util.py
Normal file
64
hy3dshape/hy3dshape/utils/visualizers/html_util.py
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
||||||
|
# except for the third-party components listed below.
|
||||||
|
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
||||||
|
# in the repsective licenses of these third-party components.
|
||||||
|
# Users must comply with all terms and conditions of original licenses of these third-party
|
||||||
|
# components and must ensure that the usage of the third party components adheres to
|
||||||
|
# all relevant laws and regulations.
|
||||||
|
|
||||||
|
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
||||||
|
# their software and algorithms, including trained model weights, parameters (including
|
||||||
|
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
||||||
|
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
||||||
|
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
||||||
|
|
||||||
|
import io
|
||||||
|
import base64
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
def to_html_frame(content):
|
||||||
|
|
||||||
|
html_frame = f"""
|
||||||
|
<html>
|
||||||
|
<body>
|
||||||
|
{content}
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
"""
|
||||||
|
|
||||||
|
return html_frame
|
||||||
|
|
||||||
|
|
||||||
|
def to_single_row_table(caption: str, content: str):
|
||||||
|
|
||||||
|
table_html = f"""
|
||||||
|
<table border = "1">
|
||||||
|
<caption>{caption}</caption>
|
||||||
|
<tr>
|
||||||
|
<td>{content}</td>
|
||||||
|
</tr>
|
||||||
|
</table>
|
||||||
|
"""
|
||||||
|
|
||||||
|
return table_html
|
||||||
|
|
||||||
|
|
||||||
|
def to_image_embed_tag(image: np.ndarray):
|
||||||
|
|
||||||
|
# Convert np.ndarray to bytes
|
||||||
|
img = Image.fromarray(image)
|
||||||
|
raw_bytes = io.BytesIO()
|
||||||
|
img.save(raw_bytes, "PNG")
|
||||||
|
|
||||||
|
# Encode bytes to base64
|
||||||
|
image_base64 = base64.b64encode(raw_bytes.getvalue()).decode("utf-8")
|
||||||
|
|
||||||
|
image_tag = f"""
|
||||||
|
<img src="data:image/png;base64,{image_base64}" alt="Embedded Image">
|
||||||
|
"""
|
||||||
|
|
||||||
|
return image_tag
|
||||||
549
hy3dshape/hy3dshape/utils/visualizers/pythreejs_viewer.py
Normal file
549
hy3dshape/hy3dshape/utils/visualizers/pythreejs_viewer.py
Normal file
@ -0,0 +1,549 @@
|
|||||||
|
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
||||||
|
# except for the third-party components listed below.
|
||||||
|
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
||||||
|
# in the repsective licenses of these third-party components.
|
||||||
|
# Users must comply with all terms and conditions of original licenses of these third-party
|
||||||
|
# components and must ensure that the usage of the third party components adheres to
|
||||||
|
# all relevant laws and regulations.
|
||||||
|
|
||||||
|
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
||||||
|
# their software and algorithms, including trained model weights, parameters (including
|
||||||
|
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
||||||
|
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
||||||
|
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
||||||
|
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from ipywidgets import embed
|
||||||
|
import pythreejs as p3s
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from .color_util import get_colors, gen_circle, gen_checkers
|
||||||
|
|
||||||
|
|
||||||
|
EMBED_URL = "https://cdn.jsdelivr.net/npm/@jupyter-widgets/html-manager@1.0.1/dist/embed-amd.js"
|
||||||
|
|
||||||
|
|
||||||
|
class PyThreeJSViewer(object):
|
||||||
|
|
||||||
|
def __init__(self, settings, render_mode="WEBSITE"):
|
||||||
|
self.render_mode = render_mode
|
||||||
|
self.__update_settings(settings)
|
||||||
|
self._light = p3s.DirectionalLight(color='white', position=[0, 0, 1], intensity=0.6)
|
||||||
|
self._light2 = p3s.AmbientLight(intensity=0.5)
|
||||||
|
self._cam = p3s.PerspectiveCamera(position=[0, 0, 1], lookAt=[0, 0, 0], fov=self.__s["fov"],
|
||||||
|
aspect=self.__s["width"] / self.__s["height"], children=[self._light])
|
||||||
|
self._orbit = p3s.OrbitControls(controlling=self._cam)
|
||||||
|
self._scene = p3s.Scene(children=[self._cam, self._light2], background=self.__s["background"]) # "#4c4c80"
|
||||||
|
self._renderer = p3s.Renderer(camera=self._cam, scene=self._scene, controls=[self._orbit],
|
||||||
|
width=self.__s["width"], height=self.__s["height"],
|
||||||
|
antialias=self.__s["antialias"])
|
||||||
|
|
||||||
|
self.__objects = {}
|
||||||
|
self.__cnt = 0
|
||||||
|
|
||||||
|
def jupyter_mode(self):
|
||||||
|
self.render_mode = "JUPYTER"
|
||||||
|
|
||||||
|
def offline(self):
|
||||||
|
self.render_mode = "OFFLINE"
|
||||||
|
|
||||||
|
def website(self):
|
||||||
|
self.render_mode = "WEBSITE"
|
||||||
|
|
||||||
|
def __get_shading(self, shading):
|
||||||
|
shad = {"flat": True, "wireframe": False, "wire_width": 0.03, "wire_color": "black",
|
||||||
|
"side": 'DoubleSide', "colormap": "viridis", "normalize": [None, None],
|
||||||
|
"bbox": False, "roughness": 0.5, "metalness": 0.25, "reflectivity": 1.0,
|
||||||
|
"line_width": 1.0, "line_color": "black",
|
||||||
|
"point_color": "red", "point_size": 0.01, "point_shape": "circle",
|
||||||
|
"text_color": "red"
|
||||||
|
}
|
||||||
|
for k in shading:
|
||||||
|
shad[k] = shading[k]
|
||||||
|
return shad
|
||||||
|
|
||||||
|
def __update_settings(self, settings={}):
|
||||||
|
sett = {"width": 1600, "height": 800, "antialias": True, "scale": 1.5, "background": "#ffffff",
|
||||||
|
"fov": 30}
|
||||||
|
for k in settings:
|
||||||
|
sett[k] = settings[k]
|
||||||
|
self.__s = sett
|
||||||
|
|
||||||
|
def __add_object(self, obj, parent=None):
|
||||||
|
if not parent: # Object is added to global scene and objects dict
|
||||||
|
self.__objects[self.__cnt] = obj
|
||||||
|
self.__cnt += 1
|
||||||
|
self._scene.add(obj["mesh"])
|
||||||
|
else: # Object is added to parent object and NOT to objects dict
|
||||||
|
parent.add(obj["mesh"])
|
||||||
|
|
||||||
|
self.__update_view()
|
||||||
|
|
||||||
|
if self.render_mode == "JUPYTER":
|
||||||
|
return self.__cnt - 1
|
||||||
|
elif self.render_mode == "WEBSITE":
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __add_line_geometry(self, lines, shading, obj=None):
|
||||||
|
lines = lines.astype("float32", copy=False)
|
||||||
|
mi = np.min(lines, axis=0)
|
||||||
|
ma = np.max(lines, axis=0)
|
||||||
|
|
||||||
|
geometry = p3s.LineSegmentsGeometry(positions=lines.reshape((-1, 2, 3)))
|
||||||
|
material = p3s.LineMaterial(linewidth=shading["line_width"], color=shading["line_color"])
|
||||||
|
# , vertexColors='VertexColors'),
|
||||||
|
lines = p3s.LineSegments2(geometry=geometry, material=material) # type='LinePieces')
|
||||||
|
line_obj = {"geometry": geometry, "mesh": lines, "material": material,
|
||||||
|
"max": ma, "min": mi, "type": "Lines", "wireframe": None}
|
||||||
|
|
||||||
|
if obj:
|
||||||
|
return self.__add_object(line_obj, obj), line_obj
|
||||||
|
else:
|
||||||
|
return self.__add_object(line_obj)
|
||||||
|
|
||||||
|
def __update_view(self):
|
||||||
|
if len(self.__objects) == 0:
|
||||||
|
return
|
||||||
|
ma = np.zeros((len(self.__objects), 3))
|
||||||
|
mi = np.zeros((len(self.__objects), 3))
|
||||||
|
for r, obj in enumerate(self.__objects):
|
||||||
|
ma[r] = self.__objects[obj]["max"]
|
||||||
|
mi[r] = self.__objects[obj]["min"]
|
||||||
|
ma = np.max(ma, axis=0)
|
||||||
|
mi = np.min(mi, axis=0)
|
||||||
|
diag = np.linalg.norm(ma - mi)
|
||||||
|
mean = ((ma - mi) / 2 + mi).tolist()
|
||||||
|
scale = self.__s["scale"] * (diag)
|
||||||
|
self._orbit.target = mean
|
||||||
|
self._cam.lookAt(mean)
|
||||||
|
self._cam.position = [mean[0], mean[1], mean[2] + scale]
|
||||||
|
self._light.position = [mean[0], mean[1], mean[2] + scale]
|
||||||
|
|
||||||
|
self._orbit.exec_three_obj_method('update')
|
||||||
|
self._cam.exec_three_obj_method('updateProjectionMatrix')
|
||||||
|
|
||||||
|
def __get_bbox(self, v):
|
||||||
|
m = np.min(v, axis=0)
|
||||||
|
M = np.max(v, axis=0)
|
||||||
|
|
||||||
|
# Corners of the bounding box
|
||||||
|
v_box = np.array([[m[0], m[1], m[2]], [M[0], m[1], m[2]], [M[0], M[1], m[2]], [m[0], M[1], m[2]],
|
||||||
|
[m[0], m[1], M[2]], [M[0], m[1], M[2]], [M[0], M[1], M[2]], [m[0], M[1], M[2]]])
|
||||||
|
|
||||||
|
f_box = np.array([[0, 1], [1, 2], [2, 3], [3, 0], [4, 5], [5, 6], [6, 7], [7, 4],
|
||||||
|
[0, 4], [1, 5], [2, 6], [7, 3]], dtype=np.uint32)
|
||||||
|
return v_box, f_box
|
||||||
|
|
||||||
|
def __get_colors(self, v, f, c, sh):
|
||||||
|
coloring = "VertexColors"
|
||||||
|
if type(c) == np.ndarray and c.size == 3: # Single color
|
||||||
|
colors = np.ones_like(v)
|
||||||
|
colors[:, 0] = c[0]
|
||||||
|
colors[:, 1] = c[1]
|
||||||
|
colors[:, 2] = c[2]
|
||||||
|
# print("Single colors")
|
||||||
|
elif type(c) == np.ndarray and len(c.shape) == 2 and c.shape[1] == 3: # Color values for
|
||||||
|
if c.shape[0] == f.shape[0]: # faces
|
||||||
|
colors = np.hstack([c, c, c]).reshape((-1, 3))
|
||||||
|
coloring = "FaceColors"
|
||||||
|
# print("Face color values")
|
||||||
|
elif c.shape[0] == v.shape[0]: # vertices
|
||||||
|
colors = c
|
||||||
|
# print("Vertex color values")
|
||||||
|
else: # Wrong size, fallback
|
||||||
|
print("Invalid color array given! Supported are numpy arrays.", type(c))
|
||||||
|
colors = np.ones_like(v)
|
||||||
|
colors[:, 0] = 1.0
|
||||||
|
colors[:, 1] = 0.874
|
||||||
|
colors[:, 2] = 0.0
|
||||||
|
elif type(c) == np.ndarray and c.size == f.shape[0]: # Function values for faces
|
||||||
|
normalize = sh["normalize"][0] != None and sh["normalize"][1] != None
|
||||||
|
cc = get_colors(c, sh["colormap"], normalize=normalize,
|
||||||
|
vmin=sh["normalize"][0], vmax=sh["normalize"][1])
|
||||||
|
# print(cc.shape)
|
||||||
|
colors = np.hstack([cc, cc, cc]).reshape((-1, 3))
|
||||||
|
coloring = "FaceColors"
|
||||||
|
# print("Face function values")
|
||||||
|
elif type(c) == np.ndarray and c.size == v.shape[0]: # Function values for vertices
|
||||||
|
normalize = sh["normalize"][0] != None and sh["normalize"][1] != None
|
||||||
|
colors = get_colors(c, sh["colormap"], normalize=normalize,
|
||||||
|
vmin=sh["normalize"][0], vmax=sh["normalize"][1])
|
||||||
|
# print("Vertex function values")
|
||||||
|
|
||||||
|
else:
|
||||||
|
colors = np.ones_like(v)
|
||||||
|
colors[:, 0] = 1.0
|
||||||
|
colors[:, 1] = 0.874
|
||||||
|
colors[:, 2] = 0.0
|
||||||
|
|
||||||
|
# No color
|
||||||
|
if c is not None:
|
||||||
|
print("Invalid color array given! Supported are numpy arrays.", type(c))
|
||||||
|
|
||||||
|
return colors, coloring
|
||||||
|
|
||||||
|
def __get_point_colors(self, v, c, sh):
|
||||||
|
v_color = True
|
||||||
|
if c is None: # No color given, use global color
|
||||||
|
# conv = mpl.colors.ColorConverter()
|
||||||
|
colors = sh["point_color"] # np.array(conv.to_rgb(sh["point_color"]))
|
||||||
|
v_color = False
|
||||||
|
elif isinstance(c, str): # No color given, use global color
|
||||||
|
# conv = mpl.colors.ColorConverter()
|
||||||
|
colors = c # np.array(conv.to_rgb(c))
|
||||||
|
v_color = False
|
||||||
|
elif type(c) == np.ndarray and len(c.shape) == 2 and c.shape[0] == v.shape[0] and c.shape[1] == 3:
|
||||||
|
# Point color
|
||||||
|
colors = c.astype("float32", copy=False)
|
||||||
|
|
||||||
|
elif isinstance(c, np.ndarray) and len(c.shape) == 2 and c.shape[0] == v.shape[0] and c.shape[1] != 3:
|
||||||
|
# Function values for vertices, but the colors are features
|
||||||
|
c_norm = np.linalg.norm(c, ord=2, axis=-1)
|
||||||
|
normalize = sh["normalize"][0] != None and sh["normalize"][1] != None
|
||||||
|
colors = get_colors(c_norm, sh["colormap"], normalize=normalize,
|
||||||
|
vmin=sh["normalize"][0], vmax=sh["normalize"][1])
|
||||||
|
colors = colors.astype("float32", copy=False)
|
||||||
|
|
||||||
|
elif type(c) == np.ndarray and c.size == v.shape[0]: # Function color
|
||||||
|
normalize = sh["normalize"][0] != None and sh["normalize"][1] != None
|
||||||
|
colors = get_colors(c, sh["colormap"], normalize=normalize,
|
||||||
|
vmin=sh["normalize"][0], vmax=sh["normalize"][1])
|
||||||
|
colors = colors.astype("float32", copy=False)
|
||||||
|
# print("Vertex function values")
|
||||||
|
|
||||||
|
else:
|
||||||
|
print("Invalid color array given! Supported are numpy arrays.", type(c))
|
||||||
|
colors = sh["point_color"]
|
||||||
|
v_color = False
|
||||||
|
|
||||||
|
return colors, v_color
|
||||||
|
|
||||||
|
def add_mesh(self, v, f, c=None, uv=None, n=None, shading={}, texture_data=None, **kwargs):
|
||||||
|
shading.update(kwargs)
|
||||||
|
sh = self.__get_shading(shading)
|
||||||
|
mesh_obj = {}
|
||||||
|
|
||||||
|
# it is a tet
|
||||||
|
if v.shape[1] == 3 and f.shape[1] == 4:
|
||||||
|
f_tmp = np.ndarray([f.shape[0] * 4, 3], dtype=f.dtype)
|
||||||
|
for i in range(f.shape[0]):
|
||||||
|
f_tmp[i * 4 + 0] = np.array([f[i][1], f[i][0], f[i][2]])
|
||||||
|
f_tmp[i * 4 + 1] = np.array([f[i][0], f[i][1], f[i][3]])
|
||||||
|
f_tmp[i * 4 + 2] = np.array([f[i][1], f[i][2], f[i][3]])
|
||||||
|
f_tmp[i * 4 + 3] = np.array([f[i][2], f[i][0], f[i][3]])
|
||||||
|
f = f_tmp
|
||||||
|
|
||||||
|
if v.shape[1] == 2:
|
||||||
|
v = np.append(v, np.zeros([v.shape[0], 1]), 1)
|
||||||
|
|
||||||
|
# Type adjustment vertices
|
||||||
|
v = v.astype("float32", copy=False)
|
||||||
|
|
||||||
|
# Color setup
|
||||||
|
colors, coloring = self.__get_colors(v, f, c, sh)
|
||||||
|
|
||||||
|
# Type adjustment faces and colors
|
||||||
|
c = colors.astype("float32", copy=False)
|
||||||
|
|
||||||
|
# Material and geometry setup
|
||||||
|
ba_dict = {"color": p3s.BufferAttribute(c)}
|
||||||
|
if coloring == "FaceColors":
|
||||||
|
verts = np.zeros((f.shape[0] * 3, 3), dtype="float32")
|
||||||
|
for ii in range(f.shape[0]):
|
||||||
|
# print(ii*3, f[ii])
|
||||||
|
verts[ii * 3] = v[f[ii, 0]]
|
||||||
|
verts[ii * 3 + 1] = v[f[ii, 1]]
|
||||||
|
verts[ii * 3 + 2] = v[f[ii, 2]]
|
||||||
|
v = verts
|
||||||
|
else:
|
||||||
|
f = f.astype("uint32", copy=False).ravel()
|
||||||
|
ba_dict["index"] = p3s.BufferAttribute(f, normalized=False)
|
||||||
|
|
||||||
|
ba_dict["position"] = p3s.BufferAttribute(v, normalized=False)
|
||||||
|
|
||||||
|
if uv is not None:
|
||||||
|
uv = (uv - np.min(uv)) / (np.max(uv) - np.min(uv))
|
||||||
|
if texture_data is None:
|
||||||
|
texture_data = gen_checkers(20, 20)
|
||||||
|
tex = p3s.DataTexture(data=texture_data, format="RGBFormat", type="FloatType")
|
||||||
|
material = p3s.MeshStandardMaterial(map=tex, reflectivity=sh["reflectivity"], side=sh["side"],
|
||||||
|
roughness=sh["roughness"], metalness=sh["metalness"],
|
||||||
|
flatShading=sh["flat"],
|
||||||
|
polygonOffset=True, polygonOffsetFactor=1, polygonOffsetUnits=5)
|
||||||
|
ba_dict["uv"] = p3s.BufferAttribute(uv.astype("float32", copy=False))
|
||||||
|
else:
|
||||||
|
material = p3s.MeshStandardMaterial(vertexColors=coloring, reflectivity=sh["reflectivity"],
|
||||||
|
side=sh["side"], roughness=sh["roughness"], metalness=sh["metalness"],
|
||||||
|
flatShading=sh["flat"],
|
||||||
|
polygonOffset=True, polygonOffsetFactor=1, polygonOffsetUnits=5)
|
||||||
|
|
||||||
|
if type(n) != type(None) and coloring == "VertexColors": # TODO: properly handle normals for FaceColors as well
|
||||||
|
ba_dict["normal"] = p3s.BufferAttribute(n.astype("float32", copy=False), normalized=True)
|
||||||
|
|
||||||
|
geometry = p3s.BufferGeometry(attributes=ba_dict)
|
||||||
|
|
||||||
|
if coloring == "VertexColors" and type(n) == type(None):
|
||||||
|
geometry.exec_three_obj_method('computeVertexNormals')
|
||||||
|
elif coloring == "FaceColors" and type(n) == type(None):
|
||||||
|
geometry.exec_three_obj_method('computeFaceNormals')
|
||||||
|
|
||||||
|
# Mesh setup
|
||||||
|
mesh = p3s.Mesh(geometry=geometry, material=material)
|
||||||
|
|
||||||
|
# Wireframe setup
|
||||||
|
mesh_obj["wireframe"] = None
|
||||||
|
if sh["wireframe"]:
|
||||||
|
wf_geometry = p3s.WireframeGeometry(mesh.geometry) # WireframeGeometry
|
||||||
|
wf_material = p3s.LineBasicMaterial(color=sh["wire_color"], linewidth=sh["wire_width"])
|
||||||
|
wireframe = p3s.LineSegments(wf_geometry, wf_material)
|
||||||
|
mesh.add(wireframe)
|
||||||
|
mesh_obj["wireframe"] = wireframe
|
||||||
|
|
||||||
|
# Bounding box setup
|
||||||
|
if sh["bbox"]:
|
||||||
|
v_box, f_box = self.__get_bbox(v)
|
||||||
|
_, bbox = self.add_edges(v_box, f_box, sh, mesh)
|
||||||
|
mesh_obj["bbox"] = [bbox, v_box, f_box]
|
||||||
|
|
||||||
|
# Object setup
|
||||||
|
mesh_obj["max"] = np.max(v, axis=0)
|
||||||
|
mesh_obj["min"] = np.min(v, axis=0)
|
||||||
|
mesh_obj["geometry"] = geometry
|
||||||
|
mesh_obj["mesh"] = mesh
|
||||||
|
mesh_obj["material"] = material
|
||||||
|
mesh_obj["type"] = "Mesh"
|
||||||
|
mesh_obj["shading"] = sh
|
||||||
|
mesh_obj["coloring"] = coloring
|
||||||
|
mesh_obj["arrays"] = [v, f, c] # TODO replays with proper storage or remove if not needed
|
||||||
|
|
||||||
|
return self.__add_object(mesh_obj)
|
||||||
|
|
||||||
|
def add_lines(self, beginning, ending, shading={}, obj=None, **kwargs):
|
||||||
|
shading.update(kwargs)
|
||||||
|
if len(beginning.shape) == 1:
|
||||||
|
if len(beginning) == 2:
|
||||||
|
beginning = np.array([[beginning[0], beginning[1], 0]])
|
||||||
|
else:
|
||||||
|
if beginning.shape[1] == 2:
|
||||||
|
beginning = np.append(
|
||||||
|
beginning, np.zeros([beginning.shape[0], 1]), 1)
|
||||||
|
if len(ending.shape) == 1:
|
||||||
|
if len(ending) == 2:
|
||||||
|
ending = np.array([[ending[0], ending[1], 0]])
|
||||||
|
else:
|
||||||
|
if ending.shape[1] == 2:
|
||||||
|
ending = np.append(
|
||||||
|
ending, np.zeros([ending.shape[0], 1]), 1)
|
||||||
|
|
||||||
|
sh = self.__get_shading(shading)
|
||||||
|
lines = np.hstack([beginning, ending])
|
||||||
|
lines = lines.reshape((-1, 3))
|
||||||
|
return self.__add_line_geometry(lines, sh, obj)
|
||||||
|
|
||||||
|
def add_edges(self, vertices, edges, shading={}, obj=None, **kwargs):
|
||||||
|
shading.update(kwargs)
|
||||||
|
if vertices.shape[1] == 2:
|
||||||
|
vertices = np.append(
|
||||||
|
vertices, np.zeros([vertices.shape[0], 1]), 1)
|
||||||
|
sh = self.__get_shading(shading)
|
||||||
|
lines = np.zeros((edges.size, 3))
|
||||||
|
cnt = 0
|
||||||
|
for e in edges:
|
||||||
|
lines[cnt, :] = vertices[e[0]]
|
||||||
|
lines[cnt + 1, :] = vertices[e[1]]
|
||||||
|
cnt += 2
|
||||||
|
return self.__add_line_geometry(lines, sh, obj)
|
||||||
|
|
||||||
|
def add_points(self, points, c=None, shading={}, obj=None, **kwargs):
|
||||||
|
shading.update(kwargs)
|
||||||
|
if len(points.shape) == 1:
|
||||||
|
if len(points) == 2:
|
||||||
|
points = np.array([[points[0], points[1], 0]])
|
||||||
|
else:
|
||||||
|
if points.shape[1] == 2:
|
||||||
|
points = np.append(
|
||||||
|
points, np.zeros([points.shape[0], 1]), 1)
|
||||||
|
sh = self.__get_shading(shading)
|
||||||
|
points = points.astype("float32", copy=False)
|
||||||
|
mi = np.min(points, axis=0)
|
||||||
|
ma = np.max(points, axis=0)
|
||||||
|
|
||||||
|
g_attributes = {"position": p3s.BufferAttribute(points, normalized=False)}
|
||||||
|
m_attributes = {"size": sh["point_size"]}
|
||||||
|
|
||||||
|
if sh["point_shape"] == "circle": # Plot circles
|
||||||
|
tex = p3s.DataTexture(data=gen_circle(16, 16), format="RGBAFormat", type="FloatType")
|
||||||
|
m_attributes["map"] = tex
|
||||||
|
m_attributes["alphaTest"] = 0.5
|
||||||
|
m_attributes["transparency"] = True
|
||||||
|
else: # Plot squares
|
||||||
|
pass
|
||||||
|
|
||||||
|
colors, v_colors = self.__get_point_colors(points, c, sh)
|
||||||
|
if v_colors: # Colors per point
|
||||||
|
m_attributes["vertexColors"] = 'VertexColors'
|
||||||
|
g_attributes["color"] = p3s.BufferAttribute(colors, normalized=False)
|
||||||
|
|
||||||
|
else: # Colors for all points
|
||||||
|
m_attributes["color"] = colors
|
||||||
|
|
||||||
|
material = p3s.PointsMaterial(**m_attributes)
|
||||||
|
geometry = p3s.BufferGeometry(attributes=g_attributes)
|
||||||
|
points = p3s.Points(geometry=geometry, material=material)
|
||||||
|
point_obj = {"geometry": geometry, "mesh": points, "material": material,
|
||||||
|
"max": ma, "min": mi, "type": "Points", "wireframe": None}
|
||||||
|
|
||||||
|
if obj:
|
||||||
|
return self.__add_object(point_obj, obj), point_obj
|
||||||
|
else:
|
||||||
|
return self.__add_object(point_obj)
|
||||||
|
|
||||||
|
def remove_object(self, obj_id):
|
||||||
|
if obj_id not in self.__objects:
|
||||||
|
print("Invalid object id. Valid ids are: ", list(self.__objects.keys()))
|
||||||
|
return
|
||||||
|
self._scene.remove(self.__objects[obj_id]["mesh"])
|
||||||
|
del self.__objects[obj_id]
|
||||||
|
self.__update_view()
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
for obj_id in list(self.__objects.keys()).copy():
|
||||||
|
self._scene.remove(self.__objects[obj_id]["mesh"])
|
||||||
|
del self.__objects[obj_id]
|
||||||
|
self.__update_view()
|
||||||
|
|
||||||
|
def update_object(self, oid=0, vertices=None, colors=None, faces=None):
|
||||||
|
obj = self.__objects[oid]
|
||||||
|
if type(vertices) != type(None):
|
||||||
|
if obj["coloring"] == "FaceColors":
|
||||||
|
f = obj["arrays"][1]
|
||||||
|
verts = np.zeros((f.shape[0] * 3, 3), dtype="float32")
|
||||||
|
for ii in range(f.shape[0]):
|
||||||
|
# print(ii*3, f[ii])
|
||||||
|
verts[ii * 3] = vertices[f[ii, 0]]
|
||||||
|
verts[ii * 3 + 1] = vertices[f[ii, 1]]
|
||||||
|
verts[ii * 3 + 2] = vertices[f[ii, 2]]
|
||||||
|
v = verts
|
||||||
|
|
||||||
|
else:
|
||||||
|
v = vertices.astype("float32", copy=False)
|
||||||
|
obj["geometry"].attributes["position"].array = v
|
||||||
|
# self.wireframe.attributes["position"].array = v # Wireframe updates?
|
||||||
|
obj["geometry"].attributes["position"].needsUpdate = True
|
||||||
|
# obj["geometry"].exec_three_obj_method('computeVertexNormals')
|
||||||
|
if type(colors) != type(None):
|
||||||
|
colors, coloring = self.__get_colors(obj["arrays"][0], obj["arrays"][1], colors, obj["shading"])
|
||||||
|
colors = colors.astype("float32", copy=False)
|
||||||
|
obj["geometry"].attributes["color"].array = colors
|
||||||
|
obj["geometry"].attributes["color"].needsUpdate = True
|
||||||
|
if type(faces) != type(None):
|
||||||
|
if obj["coloring"] == "FaceColors":
|
||||||
|
print("Face updates are currently only possible in vertex color mode.")
|
||||||
|
return
|
||||||
|
f = faces.astype("uint32", copy=False).ravel()
|
||||||
|
print(obj["geometry"].attributes)
|
||||||
|
obj["geometry"].attributes["index"].array = f
|
||||||
|
# self.wireframe.attributes["position"].array = v # Wireframe updates?
|
||||||
|
obj["geometry"].attributes["index"].needsUpdate = True
|
||||||
|
# obj["geometry"].exec_three_obj_method('computeVertexNormals')
|
||||||
|
# self.mesh.geometry.verticesNeedUpdate = True
|
||||||
|
# self.mesh.geometry.elementsNeedUpdate = True
|
||||||
|
# self.update()
|
||||||
|
if self.render_mode == "WEBSITE":
|
||||||
|
return self
|
||||||
|
|
||||||
|
# def update(self):
|
||||||
|
# self.mesh.exec_three_obj_method('update')
|
||||||
|
# self.orbit.exec_three_obj_method('update')
|
||||||
|
# self.cam.exec_three_obj_method('updateProjectionMatrix')
|
||||||
|
# self.scene.exec_three_obj_method('update')
|
||||||
|
|
||||||
|
def add_text(self, text, shading={}, **kwargs):
|
||||||
|
shading.update(kwargs)
|
||||||
|
sh = self.__get_shading(shading)
|
||||||
|
tt = p3s.TextTexture(string=text, color=sh["text_color"])
|
||||||
|
sm = p3s.SpriteMaterial(map=tt)
|
||||||
|
text = p3s.Sprite(material=sm, scaleToTexture=True)
|
||||||
|
self._scene.add(text)
|
||||||
|
|
||||||
|
# def add_widget(self, widget, callback):
|
||||||
|
# self.widgets.append(widget)
|
||||||
|
# widget.observe(callback, names='value')
|
||||||
|
|
||||||
|
# def add_dropdown(self, options, default, desc, cb):
|
||||||
|
# widget = widgets.Dropdown(options=options, value=default, description=desc)
|
||||||
|
# self.__widgets.append(widget)
|
||||||
|
# widget.observe(cb, names="value")
|
||||||
|
# display(widget)
|
||||||
|
|
||||||
|
# def add_button(self, text, cb):
|
||||||
|
# button = widgets.Button(description=text)
|
||||||
|
# self.__widgets.append(button)
|
||||||
|
# button.on_click(cb)
|
||||||
|
# display(button)
|
||||||
|
|
||||||
|
def to_html(self, imports=True, html_frame=True):
|
||||||
|
# Bake positions (fixes centering bug in offline rendering)
|
||||||
|
if len(self.__objects) == 0:
|
||||||
|
return
|
||||||
|
ma = np.zeros((len(self.__objects), 3))
|
||||||
|
mi = np.zeros((len(self.__objects), 3))
|
||||||
|
for r, obj in enumerate(self.__objects):
|
||||||
|
ma[r] = self.__objects[obj]["max"]
|
||||||
|
mi[r] = self.__objects[obj]["min"]
|
||||||
|
ma = np.max(ma, axis=0)
|
||||||
|
mi = np.min(mi, axis=0)
|
||||||
|
diag = np.linalg.norm(ma - mi)
|
||||||
|
mean = (ma - mi) / 2 + mi
|
||||||
|
for r, obj in enumerate(self.__objects):
|
||||||
|
v = self.__objects[obj]["geometry"].attributes["position"].array
|
||||||
|
v -= mean
|
||||||
|
# v += np.array([0.0, .9, 0.0]) #! to move the obj to the center of window
|
||||||
|
|
||||||
|
scale = self.__s["scale"] * (diag)
|
||||||
|
self._orbit.target = [0.0, 0.0, 0.0]
|
||||||
|
self._cam.lookAt([0.0, 0.0, 0.0])
|
||||||
|
# self._cam.position = [0.0, 0.0, scale]
|
||||||
|
self._cam.position = [0.0, 0.5, scale * 1.3] #! show four complete meshes in the window
|
||||||
|
self._light.position = [0.0, 0.0, scale]
|
||||||
|
|
||||||
|
state = embed.dependency_state(self._renderer)
|
||||||
|
|
||||||
|
# Somehow these entries are missing when the state is exported in python.
|
||||||
|
# Exporting from the GUI works, so we are inserting the missing entries.
|
||||||
|
for k in state:
|
||||||
|
if state[k]["model_name"] == "OrbitControlsModel":
|
||||||
|
state[k]["state"]["maxAzimuthAngle"] = "inf"
|
||||||
|
state[k]["state"]["maxDistance"] = "inf"
|
||||||
|
state[k]["state"]["maxZoom"] = "inf"
|
||||||
|
state[k]["state"]["minAzimuthAngle"] = "-inf"
|
||||||
|
|
||||||
|
tpl = embed.load_requirejs_template
|
||||||
|
if not imports:
|
||||||
|
embed.load_requirejs_template = ""
|
||||||
|
|
||||||
|
s = embed.embed_snippet(self._renderer, state=state, embed_url=EMBED_URL)
|
||||||
|
# s = embed.embed_snippet(self.__w, state=state)
|
||||||
|
embed.load_requirejs_template = tpl
|
||||||
|
|
||||||
|
if html_frame:
|
||||||
|
s = "<html>\n<body>\n" + s + "\n</body>\n</html>"
|
||||||
|
|
||||||
|
# Revert changes
|
||||||
|
for r, obj in enumerate(self.__objects):
|
||||||
|
v = self.__objects[obj]["geometry"].attributes["position"].array
|
||||||
|
v += mean
|
||||||
|
self.__update_view()
|
||||||
|
|
||||||
|
return s
|
||||||
|
|
||||||
|
def save(self, filename=""):
|
||||||
|
if filename == "":
|
||||||
|
uid = str(uuid.uuid4()) + ".html"
|
||||||
|
else:
|
||||||
|
filename = filename.replace(".html", "")
|
||||||
|
uid = filename + '.html'
|
||||||
|
with open(uid, "w") as f:
|
||||||
|
f.write(self.to_html())
|
||||||
|
print("Plot saved to file %s." % uid)
|
||||||
30
hy3dshape/minimal_demo.py
Normal file
30
hy3dshape/minimal_demo.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
||||||
|
# except for the third-party components listed below.
|
||||||
|
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
||||||
|
# in the repsective licenses of these third-party components.
|
||||||
|
# Users must comply with all terms and conditions of original licenses of these third-party
|
||||||
|
# components and must ensure that the usage of the third party components adheres to
|
||||||
|
# all relevant laws and regulations.
|
||||||
|
|
||||||
|
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
||||||
|
# their software and algorithms, including trained model weights, parameters (including
|
||||||
|
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
||||||
|
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
||||||
|
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from hy3dshape.rembg import BackgroundRemover
|
||||||
|
from hy3dshape.pipelines import Hunyuan3DDiTFlowMatchingPipeline
|
||||||
|
|
||||||
|
model_path = 'tencent/Hunyuan3D-2.1'
|
||||||
|
pipeline_shapegen = Hunyuan3DDiTFlowMatchingPipeline.from_pretrained(model_path)
|
||||||
|
|
||||||
|
image_path = 'demos/demo.png'
|
||||||
|
image = Image.open(image_path).convert("RGBA")
|
||||||
|
if image.mode == 'RGB':
|
||||||
|
rembg = BackgroundRemover()
|
||||||
|
image = rembg(image)
|
||||||
|
|
||||||
|
mesh = pipeline_shapegen(image=image)[0]
|
||||||
|
mesh.export('demo.glb')
|
||||||
51
hy3dshape/minimal_vae_demo.py
Normal file
51
hy3dshape/minimal_vae_demo.py
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
||||||
|
# except for the third-party components listed below.
|
||||||
|
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
||||||
|
# in the repsective licenses of these third-party components.
|
||||||
|
# Users must comply with all terms and conditions of original licenses of these third-party
|
||||||
|
# components and must ensure that the usage of the third party components adheres to
|
||||||
|
# all relevant laws and regulations.
|
||||||
|
|
||||||
|
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
||||||
|
# their software and algorithms, including trained model weights, parameters (including
|
||||||
|
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
||||||
|
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
||||||
|
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from hy3dshape.surface_loaders import SharpEdgeSurfaceLoader
|
||||||
|
from hy3dshape.models.autoencoders import ShapeVAE
|
||||||
|
from hy3dshape.pipelines import export_to_trimesh
|
||||||
|
|
||||||
|
|
||||||
|
vae = ShapeVAE.from_pretrained(
|
||||||
|
'tencent/Hunyuan3D-2.1',
|
||||||
|
use_safetensors=False,
|
||||||
|
variant='fp16',
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
loader = SharpEdgeSurfaceLoader(
|
||||||
|
num_sharp_points=0,
|
||||||
|
num_uniform_points=81920,
|
||||||
|
)
|
||||||
|
mesh_demo = 'demos/demo.glb'
|
||||||
|
surface = loader(mesh_demo).to('cuda', dtype=torch.float16)
|
||||||
|
print(surface.shape)
|
||||||
|
|
||||||
|
latents = vae.encode(surface)
|
||||||
|
latents = vae.decode(latents)
|
||||||
|
mesh = vae.latents2mesh(
|
||||||
|
latents,
|
||||||
|
output_type='trimesh',
|
||||||
|
bounds=1.01,
|
||||||
|
mc_level=0.0,
|
||||||
|
num_chunks=20000,
|
||||||
|
octree_resolution=256,
|
||||||
|
mc_algo='mc',
|
||||||
|
enable_pbar=True
|
||||||
|
)
|
||||||
|
|
||||||
|
mesh = export_to_trimesh(mesh)[0]
|
||||||
|
mesh.export('output.obj')
|
||||||
127
nodes.py
127
nodes.py
@ -139,6 +139,54 @@ class Hy3DModelLoader:
|
|||||||
|
|
||||||
return (pipe, vae,)
|
return (pipe, vae,)
|
||||||
|
|
||||||
|
class Hy3D_2_1SimpleMeshGen:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"model": (folder_paths.get_filename_list("diffusion_models"), {"tooltip": "These models are loaded from the 'ComfyUI/models/diffusion_models' -folder",}),
|
||||||
|
"image": ("IMAGE", {"tooltip": "Image to generate mesh from"}),
|
||||||
|
"steps": ("INT", {"default": 50, "min": 1, "max": 100, "step": 1, "tooltip": "Number of diffusion steps"}),
|
||||||
|
"guidance_scale": ("FLOAT", {"default": 5.0, "min": 1, "max": 30, "step": 0.1, "tooltip": "Guidance scale"}),
|
||||||
|
"octree_resolution": ("INT", {"default": 384, "min": 32, "max": 1024, "step": 32, "tooltip": "Octree resolution"}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("TRIMESH", )
|
||||||
|
RETURN_NAMES = ("trimesh",)
|
||||||
|
FUNCTION = "loadmodel"
|
||||||
|
CATEGORY = "Hunyuan3DWrapper"
|
||||||
|
|
||||||
|
def loadmodel(self, model, image, steps, guidance_scale, octree_resolution):
|
||||||
|
device = mm.get_torch_device()
|
||||||
|
offload_device=mm.unet_offload_device()
|
||||||
|
|
||||||
|
from .hy3dshape.hy3dshape.pipelines import Hunyuan3DDiTFlowMatchingPipeline
|
||||||
|
from .hy3dshape.hy3dshape.rembg import BackgroundRemover
|
||||||
|
import torchvision.transforms as T
|
||||||
|
|
||||||
|
model_path = folder_paths.get_full_path("diffusion_models", model)
|
||||||
|
if not hasattr(self, "pipeline"):
|
||||||
|
self.pipeline = Hunyuan3DDiTFlowMatchingPipeline.from_single_file(
|
||||||
|
config_path=os.path.join(script_directory, 'configs', 'dit_config_2_1.yaml'),
|
||||||
|
ckpt_path=model_path)
|
||||||
|
|
||||||
|
to_pil = T.ToPILImage()
|
||||||
|
image = to_pil(image[0].permute(2, 0, 1))
|
||||||
|
|
||||||
|
if image.mode == 'RGB':
|
||||||
|
rembg = BackgroundRemover()
|
||||||
|
image = rembg(image)
|
||||||
|
|
||||||
|
mesh = self.pipeline(
|
||||||
|
image=image,
|
||||||
|
num_inference_steps=steps,
|
||||||
|
guidance_scale=guidance_scale,
|
||||||
|
octree_resolution=octree_resolution
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
return (mesh,)
|
||||||
|
|
||||||
class Hy3DVAELoader:
|
class Hy3DVAELoader:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
@ -315,7 +363,7 @@ class DownloadAndLoadHy3DPaintModel:
|
|||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {
|
return {
|
||||||
"required": {
|
"required": {
|
||||||
"model": (["hunyuan3d-paint-v2-0"],),
|
"model": (["hunyuan3d-paint-v2-0", "hunyuan3d-paint-v2-0-turbo"],),
|
||||||
},
|
},
|
||||||
"optional": {
|
"optional": {
|
||||||
"compile_args": ("HY3DCOMPILEARGS", {"tooltip": "torch.compile settings, when connected to the model loader, torch.compile of the selected models is attempted. Requires Triton and torch 2.5.0 is recommended"}),
|
"compile_args": ("HY3DCOMPILEARGS", {"tooltip": "torch.compile settings, when connected to the model loader, torch.compile of the selected models is attempted. Requires Triton and torch 2.5.0 is recommended"}),
|
||||||
@ -340,7 +388,7 @@ class DownloadAndLoadHy3DPaintModel:
|
|||||||
snapshot_download(
|
snapshot_download(
|
||||||
repo_id="tencent/Hunyuan3D-2",
|
repo_id="tencent/Hunyuan3D-2",
|
||||||
allow_patterns=[f"*{model}*"],
|
allow_patterns=[f"*{model}*"],
|
||||||
ignore_patterns=["*diffusion_pytorch_model.bin"],
|
ignore_patterns=["*unet/diffusion_pytorch_model.bin", "*image_encoder*"],
|
||||||
local_dir=download_path,
|
local_dir=download_path,
|
||||||
local_dir_use_symlinks=False,
|
local_dir_use_symlinks=False,
|
||||||
)
|
)
|
||||||
@ -1031,11 +1079,62 @@ class Hy3DLoadMesh:
|
|||||||
DESCRIPTION = "Loads a glb model from the given path."
|
DESCRIPTION = "Loads a glb model from the given path."
|
||||||
|
|
||||||
def load(self, glb_path):
|
def load(self, glb_path):
|
||||||
|
|
||||||
|
if not os.path.exists(glb_path):
|
||||||
|
glb_path = os.path.join(folder_paths.get_input_directory(), glb_path)
|
||||||
|
|
||||||
trimesh = Trimesh.load(glb_path, force="mesh")
|
trimesh = Trimesh.load(glb_path, force="mesh")
|
||||||
|
|
||||||
return (trimesh,)
|
return (trimesh,)
|
||||||
|
|
||||||
|
|
||||||
|
class TrimeshToMESH:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"trimesh": ("TRIMESH",),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
RETURN_TYPES = ("MESH",)
|
||||||
|
OUTPUT_TOOLTIPS = ("MESH object containing vertices and faces as torch tensors.",)
|
||||||
|
|
||||||
|
FUNCTION = "load"
|
||||||
|
CATEGORY = "Hunyuan3DWrapper"
|
||||||
|
DESCRIPTION = "Converts trimesh object to ComfyUI MESH object, which only includes mesh data"
|
||||||
|
|
||||||
|
def load(self, trimesh):
|
||||||
|
|
||||||
|
vertices = torch.tensor(trimesh.vertices, dtype=torch.float32)
|
||||||
|
faces = torch.tensor(trimesh.faces, dtype=torch.float32)
|
||||||
|
mesh = (self.MESH(vertices.unsqueeze(0), faces.unsqueeze(0)))
|
||||||
|
|
||||||
|
return (mesh,)
|
||||||
|
|
||||||
|
class MESH:
|
||||||
|
def __init__(self, vertices, faces):
|
||||||
|
self.vertices = vertices
|
||||||
|
self.faces = faces
|
||||||
|
|
||||||
|
class MESHToTrimesh:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"mesh": ("MESH",),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
RETURN_TYPES = ("TRIMESH",)
|
||||||
|
OUTPUT_TOOLTIPS = ("TRIMESH object containing vertices and faces as torch tensors.",)
|
||||||
|
|
||||||
|
FUNCTION = "load"
|
||||||
|
CATEGORY = "Hunyuan3DWrapper"
|
||||||
|
DESCRIPTION = "Converts trimesh object to ComfyUI MESH object, which only includes mesh data"
|
||||||
|
|
||||||
|
def load(self, mesh):
|
||||||
|
mesh_output = Trimesh.Trimesh(mesh.vertices[0], mesh.faces[0])
|
||||||
|
return (mesh_output,)
|
||||||
|
|
||||||
class Hy3DUploadMesh:
|
class Hy3DUploadMesh:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
@ -1178,14 +1277,22 @@ class Hy3DGenerateMeshMultiView():
|
|||||||
|
|
||||||
pipeline.to(device)
|
pipeline.to(device)
|
||||||
|
|
||||||
if front is not None:
|
if front is not None and not torch.all(front < 1e-6).item():
|
||||||
front = front.clone().permute(0, 3, 1, 2).to(device)
|
front = front.clone().permute(0, 3, 1, 2).to(device)
|
||||||
if back is not None:
|
else:
|
||||||
|
front = None
|
||||||
|
if back is not None and not torch.all(back < 1e-6).item():
|
||||||
back = back.clone().permute(0, 3, 1, 2).to(device)
|
back = back.clone().permute(0, 3, 1, 2).to(device)
|
||||||
if left is not None:
|
else:
|
||||||
|
back = None
|
||||||
|
if left is not None and not torch.all(left < 1e-6).item():
|
||||||
left = left.clone().permute(0, 3, 1, 2).to(device)
|
left = left.clone().permute(0, 3, 1, 2).to(device)
|
||||||
if right is not None:
|
else:
|
||||||
|
left = None
|
||||||
|
if right is not None and not torch.all(right < 1e-6).item():
|
||||||
right = right.clone().permute(0, 3, 1, 2).to(device)
|
right = right.clone().permute(0, 3, 1, 2).to(device)
|
||||||
|
else:
|
||||||
|
right = None
|
||||||
|
|
||||||
view_dict = {
|
view_dict = {
|
||||||
'front': front,
|
'front': front,
|
||||||
@ -1774,6 +1881,7 @@ class Hy3DNvdiffrastRenderer:
|
|||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"Hy3DModelLoader": Hy3DModelLoader,
|
"Hy3DModelLoader": Hy3DModelLoader,
|
||||||
|
"Hy3D_2_1SimpleMeshGen": Hy3D_2_1SimpleMeshGen,
|
||||||
"Hy3DVAELoader": Hy3DVAELoader,
|
"Hy3DVAELoader": Hy3DVAELoader,
|
||||||
"Hy3DGenerateMesh": Hy3DGenerateMesh,
|
"Hy3DGenerateMesh": Hy3DGenerateMesh,
|
||||||
"Hy3DGenerateMeshMultiView": Hy3DGenerateMeshMultiView,
|
"Hy3DGenerateMeshMultiView": Hy3DGenerateMeshMultiView,
|
||||||
@ -1805,10 +1913,13 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"Hy3DMeshInfo": Hy3DMeshInfo,
|
"Hy3DMeshInfo": Hy3DMeshInfo,
|
||||||
"Hy3DFastSimplifyMesh": Hy3DFastSimplifyMesh,
|
"Hy3DFastSimplifyMesh": Hy3DFastSimplifyMesh,
|
||||||
"Hy3DNvdiffrastRenderer": Hy3DNvdiffrastRenderer,
|
"Hy3DNvdiffrastRenderer": Hy3DNvdiffrastRenderer,
|
||||||
|
"TrimeshToMESH": TrimeshToMESH,
|
||||||
|
"MESHToTrimesh": MESHToTrimesh,
|
||||||
}
|
}
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
"Hy3DModelLoader": "Hy3DModelLoader",
|
"Hy3DModelLoader": "Hy3DModelLoader",
|
||||||
|
"Hy3D_2_1SimpleMeshGen": "Hy3D_2_1SimpleMeshGen",
|
||||||
#"Hy3DVAELoader": "Hy3DVAELoader",
|
#"Hy3DVAELoader": "Hy3DVAELoader",
|
||||||
"Hy3DGenerateMesh": "Hy3DGenerateMesh",
|
"Hy3DGenerateMesh": "Hy3DGenerateMesh",
|
||||||
"Hy3DGenerateMeshMultiView": "Hy3DGenerateMeshMultiView",
|
"Hy3DGenerateMeshMultiView": "Hy3DGenerateMeshMultiView",
|
||||||
@ -1839,5 +1950,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"Hy3DBPT": "Hy3D BPT",
|
"Hy3DBPT": "Hy3D BPT",
|
||||||
"Hy3DMeshInfo": "Hy3D Mesh Info",
|
"Hy3DMeshInfo": "Hy3D Mesh Info",
|
||||||
"Hy3DFastSimplifyMesh": "Hy3D Fast Simplify Mesh",
|
"Hy3DFastSimplifyMesh": "Hy3D Fast Simplify Mesh",
|
||||||
"Hy3DNvdiffrastRenderer": "Hy3D Nvdiffrast Renderer"
|
"Hy3DNvdiffrastRenderer": "Hy3D Nvdiffrast Renderer",
|
||||||
|
"TrimeshToMESH": "Trimesh to MESH",
|
||||||
|
"MESHToTrimesh": "MESH to Trimesh",
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "comfyui-hunyan3dwrapper"
|
name = "comfyui-hunyuan3dwrapper"
|
||||||
description = "Wrapper nodes for https://github.com/Tencent/Hunyuan3D-2, additional installation steps needed, please check the github repository"
|
description = "Wrapper nodes for https://github.com/Tencent/Hunyuan3D-2, additional installation steps needed, please check the github repository"
|
||||||
version = "1.0.5"
|
version = "1.0.6"
|
||||||
license = {file = "LICENSE"}
|
license = {file = "LICENSE"}
|
||||||
dependencies = ["trimesh", "diffusers>=0.31.0","accelerate","huggingface_hub","einops","opencv-python","transformers","xatlas","pymeshlab","pygltflib","scikit-learn","scikit-image","pybind11"]
|
dependencies = ["trimesh", "diffusers>=0.31.0","accelerate","huggingface_hub","einops","opencv-python","transformers","xatlas","pymeshlab","pygltflib","scikit-learn","scikit-image","pybind11"]
|
||||||
|
|
||||||
|
|||||||
Binary file not shown.
Loading…
x
Reference in New Issue
Block a user