diff --git a/nodes.py b/nodes.py index 0df9be8..9eeac06 100644 --- a/nodes.py +++ b/nodes.py @@ -13,7 +13,7 @@ from .hy3dgen.shapegen import Hunyuan3DDiTFlowMatchingPipeline, FaceReducer, Flo from .hy3dgen.texgen.hunyuanpaint.unet.modules import UNet2DConditionModel, UNet2p5DConditionModel from .hy3dgen.texgen.hunyuanpaint.pipeline import HunyuanPaintPipeline from .hy3dgen.shapegen.schedulers import FlowMatchEulerDiscreteScheduler, ConsistencyFlowMatchEulerDiscreteScheduler - +from .hy3dgen.shapegen.models.autoencoders import ShapeVAE from diffusers import AutoencoderKL from diffusers.schedulers import ( @@ -138,6 +138,50 @@ class Hy3DModelLoader: cublas_ops=cublas_ops) return (pipe, vae,) + +class Hy3DVAELoader: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model_name": (folder_paths.get_filename_list("vae"), {"tooltip": "These models are loaded from 'ComfyUI/models/vae'"}), + }, + } + + RETURN_TYPES = ("HY3DVAE",) + RETURN_NAMES = ("vae",) + FUNCTION = "loadmodel" + CATEGORY = "Hunyuan3DWrapper" + + def loadmodel(self, model_name): + device = mm.get_torch_device() + offload_device=mm.unet_offload_device() + + model_path = folder_paths.get_full_path("vae", model_name) + + vae_sd = load_torch_file(model_path) + + config = { + 'num_latents': 3072, + 'embed_dim': 64, + 'num_freqs': 8, + 'include_pi': False, + 'heads': 16, + 'width': 1024, + 'num_decoder_layers': 16, + 'qkv_bias': False, + 'qk_norm': True, + 'scale_factor': 0.9990943042622529, + 'geo_decoder_mlp_expand_ratio': 1, + 'geo_decoder_downsample_ratio': 2, + 'geo_decoder_ln_post': False + } + + vae = ShapeVAE(**config) + vae.load_state_dict(vae_sd) + vae.eval().to(torch.float16) + + return (vae,) class DownloadAndLoadHy3DDelightModel: @classmethod @@ -1713,6 +1757,7 @@ class Hy3DNvdiffrastRenderer: NODE_CLASS_MAPPINGS = { "Hy3DModelLoader": Hy3DModelLoader, + "Hy3DVAELoader": Hy3DVAELoader, "Hy3DGenerateMesh": Hy3DGenerateMesh, "Hy3DGenerateMeshMultiView": Hy3DGenerateMeshMultiView, "Hy3DExportMesh": Hy3DExportMesh, @@ -1742,10 +1787,12 @@ NODE_CLASS_MAPPINGS = { "Hy3DBPT": Hy3DBPT, "Hy3DMeshInfo": Hy3DMeshInfo, "Hy3DFastSimplifyMesh": Hy3DFastSimplifyMesh, - "Hy3DNvdiffrastRenderer": Hy3DNvdiffrastRenderer + "Hy3DNvdiffrastRenderer": Hy3DNvdiffrastRenderer, } + NODE_DISPLAY_NAME_MAPPINGS = { "Hy3DModelLoader": "Hy3DModelLoader", + #"Hy3DVAELoader": "Hy3DVAELoader", "Hy3DGenerateMesh": "Hy3DGenerateMesh", "Hy3DGenerateMeshMultiView": "Hy3DGenerateMeshMultiView", "Hy3DExportMesh": "Hy3DExportMesh",