Add VAE loader node

This commit is contained in:
kijai 2025-03-19 14:43:18 +02:00
parent e74f261037
commit a0d5b05e5d

View File

@ -13,7 +13,7 @@ from .hy3dgen.shapegen import Hunyuan3DDiTFlowMatchingPipeline, FaceReducer, Flo
from .hy3dgen.texgen.hunyuanpaint.unet.modules import UNet2DConditionModel, UNet2p5DConditionModel
from .hy3dgen.texgen.hunyuanpaint.pipeline import HunyuanPaintPipeline
from .hy3dgen.shapegen.schedulers import FlowMatchEulerDiscreteScheduler, ConsistencyFlowMatchEulerDiscreteScheduler
from .hy3dgen.shapegen.models.autoencoders import ShapeVAE
from diffusers import AutoencoderKL
from diffusers.schedulers import (
@ -138,6 +138,50 @@ class Hy3DModelLoader:
cublas_ops=cublas_ops)
return (pipe, vae,)
class Hy3DVAELoader:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model_name": (folder_paths.get_filename_list("vae"), {"tooltip": "These models are loaded from 'ComfyUI/models/vae'"}),
},
}
RETURN_TYPES = ("HY3DVAE",)
RETURN_NAMES = ("vae",)
FUNCTION = "loadmodel"
CATEGORY = "Hunyuan3DWrapper"
def loadmodel(self, model_name):
device = mm.get_torch_device()
offload_device=mm.unet_offload_device()
model_path = folder_paths.get_full_path("vae", model_name)
vae_sd = load_torch_file(model_path)
config = {
'num_latents': 3072,
'embed_dim': 64,
'num_freqs': 8,
'include_pi': False,
'heads': 16,
'width': 1024,
'num_decoder_layers': 16,
'qkv_bias': False,
'qk_norm': True,
'scale_factor': 0.9990943042622529,
'geo_decoder_mlp_expand_ratio': 1,
'geo_decoder_downsample_ratio': 2,
'geo_decoder_ln_post': False
}
vae = ShapeVAE(**config)
vae.load_state_dict(vae_sd)
vae.eval().to(torch.float16)
return (vae,)
class DownloadAndLoadHy3DDelightModel:
@classmethod
@ -1713,6 +1757,7 @@ class Hy3DNvdiffrastRenderer:
NODE_CLASS_MAPPINGS = {
"Hy3DModelLoader": Hy3DModelLoader,
"Hy3DVAELoader": Hy3DVAELoader,
"Hy3DGenerateMesh": Hy3DGenerateMesh,
"Hy3DGenerateMeshMultiView": Hy3DGenerateMeshMultiView,
"Hy3DExportMesh": Hy3DExportMesh,
@ -1742,10 +1787,12 @@ NODE_CLASS_MAPPINGS = {
"Hy3DBPT": Hy3DBPT,
"Hy3DMeshInfo": Hy3DMeshInfo,
"Hy3DFastSimplifyMesh": Hy3DFastSimplifyMesh,
"Hy3DNvdiffrastRenderer": Hy3DNvdiffrastRenderer
"Hy3DNvdiffrastRenderer": Hy3DNvdiffrastRenderer,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"Hy3DModelLoader": "Hy3DModelLoader",
#"Hy3DVAELoader": "Hy3DVAELoader",
"Hy3DGenerateMesh": "Hy3DGenerateMesh",
"Hy3DGenerateMeshMultiView": "Hy3DGenerateMeshMultiView",
"Hy3DExportMesh": "Hy3DExportMesh",