mirror of
https://git.datalinker.icu/kijai/ComfyUI-Hunyuan3DWrapper.git
synced 2026-05-02 13:15:44 +08:00
Add VAE loader node
This commit is contained in:
parent
e74f261037
commit
a0d5b05e5d
51
nodes.py
51
nodes.py
@ -13,7 +13,7 @@ from .hy3dgen.shapegen import Hunyuan3DDiTFlowMatchingPipeline, FaceReducer, Flo
|
|||||||
from .hy3dgen.texgen.hunyuanpaint.unet.modules import UNet2DConditionModel, UNet2p5DConditionModel
|
from .hy3dgen.texgen.hunyuanpaint.unet.modules import UNet2DConditionModel, UNet2p5DConditionModel
|
||||||
from .hy3dgen.texgen.hunyuanpaint.pipeline import HunyuanPaintPipeline
|
from .hy3dgen.texgen.hunyuanpaint.pipeline import HunyuanPaintPipeline
|
||||||
from .hy3dgen.shapegen.schedulers import FlowMatchEulerDiscreteScheduler, ConsistencyFlowMatchEulerDiscreteScheduler
|
from .hy3dgen.shapegen.schedulers import FlowMatchEulerDiscreteScheduler, ConsistencyFlowMatchEulerDiscreteScheduler
|
||||||
|
from .hy3dgen.shapegen.models.autoencoders import ShapeVAE
|
||||||
|
|
||||||
from diffusers import AutoencoderKL
|
from diffusers import AutoencoderKL
|
||||||
from diffusers.schedulers import (
|
from diffusers.schedulers import (
|
||||||
@ -138,6 +138,50 @@ class Hy3DModelLoader:
|
|||||||
cublas_ops=cublas_ops)
|
cublas_ops=cublas_ops)
|
||||||
|
|
||||||
return (pipe, vae,)
|
return (pipe, vae,)
|
||||||
|
|
||||||
|
class Hy3DVAELoader:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"model_name": (folder_paths.get_filename_list("vae"), {"tooltip": "These models are loaded from 'ComfyUI/models/vae'"}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("HY3DVAE",)
|
||||||
|
RETURN_NAMES = ("vae",)
|
||||||
|
FUNCTION = "loadmodel"
|
||||||
|
CATEGORY = "Hunyuan3DWrapper"
|
||||||
|
|
||||||
|
def loadmodel(self, model_name):
|
||||||
|
device = mm.get_torch_device()
|
||||||
|
offload_device=mm.unet_offload_device()
|
||||||
|
|
||||||
|
model_path = folder_paths.get_full_path("vae", model_name)
|
||||||
|
|
||||||
|
vae_sd = load_torch_file(model_path)
|
||||||
|
|
||||||
|
config = {
|
||||||
|
'num_latents': 3072,
|
||||||
|
'embed_dim': 64,
|
||||||
|
'num_freqs': 8,
|
||||||
|
'include_pi': False,
|
||||||
|
'heads': 16,
|
||||||
|
'width': 1024,
|
||||||
|
'num_decoder_layers': 16,
|
||||||
|
'qkv_bias': False,
|
||||||
|
'qk_norm': True,
|
||||||
|
'scale_factor': 0.9990943042622529,
|
||||||
|
'geo_decoder_mlp_expand_ratio': 1,
|
||||||
|
'geo_decoder_downsample_ratio': 2,
|
||||||
|
'geo_decoder_ln_post': False
|
||||||
|
}
|
||||||
|
|
||||||
|
vae = ShapeVAE(**config)
|
||||||
|
vae.load_state_dict(vae_sd)
|
||||||
|
vae.eval().to(torch.float16)
|
||||||
|
|
||||||
|
return (vae,)
|
||||||
|
|
||||||
class DownloadAndLoadHy3DDelightModel:
|
class DownloadAndLoadHy3DDelightModel:
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -1713,6 +1757,7 @@ class Hy3DNvdiffrastRenderer:
|
|||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"Hy3DModelLoader": Hy3DModelLoader,
|
"Hy3DModelLoader": Hy3DModelLoader,
|
||||||
|
"Hy3DVAELoader": Hy3DVAELoader,
|
||||||
"Hy3DGenerateMesh": Hy3DGenerateMesh,
|
"Hy3DGenerateMesh": Hy3DGenerateMesh,
|
||||||
"Hy3DGenerateMeshMultiView": Hy3DGenerateMeshMultiView,
|
"Hy3DGenerateMeshMultiView": Hy3DGenerateMeshMultiView,
|
||||||
"Hy3DExportMesh": Hy3DExportMesh,
|
"Hy3DExportMesh": Hy3DExportMesh,
|
||||||
@ -1742,10 +1787,12 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"Hy3DBPT": Hy3DBPT,
|
"Hy3DBPT": Hy3DBPT,
|
||||||
"Hy3DMeshInfo": Hy3DMeshInfo,
|
"Hy3DMeshInfo": Hy3DMeshInfo,
|
||||||
"Hy3DFastSimplifyMesh": Hy3DFastSimplifyMesh,
|
"Hy3DFastSimplifyMesh": Hy3DFastSimplifyMesh,
|
||||||
"Hy3DNvdiffrastRenderer": Hy3DNvdiffrastRenderer
|
"Hy3DNvdiffrastRenderer": Hy3DNvdiffrastRenderer,
|
||||||
}
|
}
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
"Hy3DModelLoader": "Hy3DModelLoader",
|
"Hy3DModelLoader": "Hy3DModelLoader",
|
||||||
|
#"Hy3DVAELoader": "Hy3DVAELoader",
|
||||||
"Hy3DGenerateMesh": "Hy3DGenerateMesh",
|
"Hy3DGenerateMesh": "Hy3DGenerateMesh",
|
||||||
"Hy3DGenerateMeshMultiView": "Hy3DGenerateMeshMultiView",
|
"Hy3DGenerateMeshMultiView": "Hy3DGenerateMeshMultiView",
|
||||||
"Hy3DExportMesh": "Hy3DExportMesh",
|
"Hy3DExportMesh": "Hy3DExportMesh",
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user