diff --git a/hy3dgen/texgen/hunyuanpaint/unet/modules.py b/hy3dgen/texgen/hunyuanpaint/unet/modules.py index 5d16bc6..15bfbbf 100755 --- a/hy3dgen/texgen/hunyuanpaint/unet/modules.py +++ b/hy3dgen/texgen/hunyuanpaint/unet/modules.py @@ -281,12 +281,27 @@ class UNet2p5DConditionModel(torch.nn.Module): def from_pretrained(pretrained_model_name_or_path, **kwargs): torch_dtype = kwargs.pop('torch_dtype', torch.float32) config_path = os.path.join(pretrained_model_name_or_path, 'config.json') - unet_ckpt_path = os.path.join(pretrained_model_name_or_path, 'diffusion_pytorch_model.bin') + unet_ckpt_path_safetensors = os.path.join(pretrained_model_name_or_path, 'diffusion_pytorch_model.safetensors') + unet_ckpt_path_bin = os.path.join(pretrained_model_name_or_path, 'diffusion_pytorch_model.bin') + + if not os.path.exists(config_path): + raise FileNotFoundError(f"Config not found at {config_path}") + with open(config_path, 'r', encoding='utf-8') as file: config = json.load(file) + unet = UNet2DConditionModel(**config) unet = UNet2p5DConditionModel(unet) - unet_ckpt = torch.load(unet_ckpt_path, map_location='cpu', weights_only=True) + + # Try loading safetensors first, fall back to .bin + if os.path.exists(unet_ckpt_path_safetensors): + import safetensors.torch + unet_ckpt = safetensors.torch.load_file(unet_ckpt_path_safetensors) + elif os.path.exists(unet_ckpt_path_bin): + unet_ckpt = torch.load(unet_ckpt_path_bin, map_location='cpu', weights_only=True) + else: + raise FileNotFoundError(f"No checkpoint found at {unet_ckpt_path_safetensors} or {unet_ckpt_path_bin}") + unet.load_state_dict(unet_ckpt, strict=True) unet = unet.to(torch_dtype) return unet diff --git a/nodes.py b/nodes.py index dc7ef96..98d66d7 100644 --- a/nodes.py +++ b/nodes.py @@ -4,9 +4,45 @@ import torchvision.transforms as transforms from PIL import Image from pathlib import Path import numpy as np +import json import trimesh from .hy3dgen.shapegen import Hunyuan3DDiTFlowMatchingPipeline, FaceReducer, FloaterRemover, DegenerateFaceRemover +from .hy3dgen.texgen.hunyuanpaint.unet.modules import UNet2DConditionModel, UNet2p5DConditionModel +from .hy3dgen.texgen.hunyuanpaint.pipeline import HunyuanPaintPipeline + +from diffusers import AutoencoderKL +from diffusers.schedulers import ( + DDIMScheduler, + PNDMScheduler, + DPMSolverMultistepScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + UniPCMultistepScheduler, + HeunDiscreteScheduler, + SASolverScheduler, + DEISMultistepScheduler, + LCMScheduler + ) + +scheduler_mapping = { + "DPM++": DPMSolverMultistepScheduler, + "DPM++SDE": DPMSolverMultistepScheduler, + "Euler": EulerDiscreteScheduler, + "Euler A": EulerAncestralDiscreteScheduler, + "PNDM": PNDMScheduler, + "DDIM": DDIMScheduler, + "SASolverScheduler": SASolverScheduler, + "UniPCMultistepScheduler": UniPCMultistepScheduler, + "HeunDiscreteScheduler": HeunDiscreteScheduler, + "DEISMultistepScheduler": DEISMultistepScheduler, + "LCMScheduler": LCMScheduler +} +available_schedulers = list(scheduler_mapping.keys()) +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from accelerate import init_empty_weights +from accelerate.utils import set_module_tensor_to_device import folder_paths @@ -219,15 +255,50 @@ class DownloadAndLoadHy3DPaintModel: local_dir_use_symlinks=False, ) - from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler - custom_pipeline_path = os.path.join(script_directory, 'hy3dgen', 'texgen', 'hunyuanpaint') + torch_dtype = torch.float16 + config_path = os.path.join(model_path, 'unet', 'config.json') + unet_ckpt_path_safetensors = os.path.join(model_path, 'unet','diffusion_pytorch_model.safetensors') + unet_ckpt_path_bin = os.path.join(model_path, 'unet','diffusion_pytorch_model.bin') - pipeline = DiffusionPipeline.from_pretrained( - model_path, - custom_pipeline=custom_pipeline_path, - torch_dtype=torch.float16) + if not os.path.exists(config_path): + raise FileNotFoundError(f"Config not found at {config_path}") + + + with open(config_path, 'r', encoding='utf-8') as file: + config = json.load(file) + + with init_empty_weights(): + unet = UNet2DConditionModel(**config) + unet = UNet2p5DConditionModel(unet) + + # Try loading safetensors first, fall back to .bin + if os.path.exists(unet_ckpt_path_safetensors): + import safetensors.torch + unet_sd = safetensors.torch.load_file(unet_ckpt_path_safetensors) + elif os.path.exists(unet_ckpt_path_bin): + unet_sd = torch.load(unet_ckpt_path_bin, map_location='cpu', weights_only=True) + else: + raise FileNotFoundError(f"No checkpoint found at {unet_ckpt_path_safetensors} or {unet_ckpt_path_bin}") + + #unet.load_state_dict(unet_ckpt, strict=True) + for name, param in unet.named_parameters(): + set_module_tensor_to_device(unet, name, device=offload_device, dtype=torch_dtype, value=unet_sd[name]) + + vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae", device=device, torch_dtype=torch_dtype) + clip = CLIPTextModel.from_pretrained(model_path, subfolder="text_encoder", torch_dtype=torch_dtype) + tokenizer = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer") + scheduler = EulerAncestralDiscreteScheduler.from_pretrained(model_path, subfolder="scheduler") + feature_extractor = CLIPImageProcessor.from_pretrained(model_path, subfolder="feature_extractor") + + pipeline = HunyuanPaintPipeline( + unet=unet, + vae = vae, + text_encoder=clip, + tokenizer=tokenizer, + scheduler=scheduler, + feature_extractor=feature_extractor, + ) - pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config, timestep_spacing='trailing') pipeline.enable_model_cpu_offload() return (pipeline,) @@ -531,6 +602,10 @@ class Hy3DSampleMultiView: }, "optional": { "camera_config": ("HY3DCAMERA",), + "scheduler": (available_schedulers, + { + "default": 'Euler A' + }), } } @@ -539,7 +614,7 @@ class Hy3DSampleMultiView: FUNCTION = "process" CATEGORY = "Hunyuan3DWrapper" - def process(self, pipeline, ref_image, normal_maps, position_maps, view_size, seed, steps, camera_config=None): + def process(self, pipeline, ref_image, normal_maps, position_maps, view_size, seed, steps, camera_config=None, scheduler="Euler A"): device = mm.get_torch_device() mm.soft_empty_cache() torch.manual_seed(seed) @@ -580,6 +655,18 @@ class Hy3DSampleMultiView: callback = ComfyProgressCallback(total_steps=steps) + scheduler_config = dict(pipeline.scheduler.config) + + if scheduler in scheduler_mapping: + if scheduler == "DPM++SDE": + scheduler_config["algorithm_type"] = "sde-dpmsolver++" + else: + scheduler_config.pop("algorithm_type", None) + noise_scheduler = scheduler_mapping[scheduler].from_config(scheduler_config) + pipeline.scheduler = noise_scheduler + else: + raise ValueError(f"Unknown scheduler: {scheduler}") + multiview_images = pipeline( input_image, width=view_size,