Mirror of https://git.datalinker.icu/kijai/ComfyUI-Hunyuan3DWrapper.git (synced 2025-12-14 23:34:29 +08:00)
Fix paint model .safetensors loading, add schedulers
commit 5b19d811ab
parent 38efaa8d1f
@@ -281,12 +281,27 @@ class UNet2p5DConditionModel(torch.nn.Module):
     def from_pretrained(pretrained_model_name_or_path, **kwargs):
         torch_dtype = kwargs.pop('torch_dtype', torch.float32)
         config_path = os.path.join(pretrained_model_name_or_path, 'config.json')
-        unet_ckpt_path = os.path.join(pretrained_model_name_or_path, 'diffusion_pytorch_model.bin')
+        unet_ckpt_path_safetensors = os.path.join(pretrained_model_name_or_path, 'diffusion_pytorch_model.safetensors')
+        unet_ckpt_path_bin = os.path.join(pretrained_model_name_or_path, 'diffusion_pytorch_model.bin')
+
+        if not os.path.exists(config_path):
+            raise FileNotFoundError(f"Config not found at {config_path}")
+
         with open(config_path, 'r', encoding='utf-8') as file:
             config = json.load(file)
 
         unet = UNet2DConditionModel(**config)
         unet = UNet2p5DConditionModel(unet)
-        unet_ckpt = torch.load(unet_ckpt_path, map_location='cpu', weights_only=True)
+
+        # Try loading safetensors first, fall back to .bin
+        if os.path.exists(unet_ckpt_path_safetensors):
+            import safetensors.torch
+            unet_ckpt = safetensors.torch.load_file(unet_ckpt_path_safetensors)
+        elif os.path.exists(unet_ckpt_path_bin):
+            unet_ckpt = torch.load(unet_ckpt_path_bin, map_location='cpu', weights_only=True)
+        else:
+            raise FileNotFoundError(f"No checkpoint found at {unet_ckpt_path_safetensors} or {unet_ckpt_path_bin}")
+
         unet.load_state_dict(unet_ckpt, strict=True)
         unet = unet.to(torch_dtype)
         return unet
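For context, a minimal usage sketch of the patched loader (not part of the commit). The directory path and the top-level hy3dgen import path are assumptions; the folder only needs to contain config.json plus either the .safetensors or the older .bin weights.

    import torch
    # Assumed import path; inside this repo the class is imported relatively from
    # .hy3dgen.texgen.hunyuanpaint.unet.modules
    from hy3dgen.texgen.hunyuanpaint.unet.modules import UNet2p5DConditionModel

    # Example directory, not from the commit: any local folder with config.json and
    # diffusion_pytorch_model.safetensors (or .bin) works.
    unet = UNet2p5DConditionModel.from_pretrained(
        "models/hunyuan3d-paint/unet",
        torch_dtype=torch.float16,
    )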
nodes.py (103 changed lines)
@@ -4,9 +4,45 @@ import torchvision.transforms as transforms
 from PIL import Image
 from pathlib import Path
 import numpy as np
+import json
 import trimesh
 
 from .hy3dgen.shapegen import Hunyuan3DDiTFlowMatchingPipeline, FaceReducer, FloaterRemover, DegenerateFaceRemover
+from .hy3dgen.texgen.hunyuanpaint.unet.modules import UNet2DConditionModel, UNet2p5DConditionModel
+from .hy3dgen.texgen.hunyuanpaint.pipeline import HunyuanPaintPipeline
+
+from diffusers import AutoencoderKL
+from diffusers.schedulers import (
+    DDIMScheduler,
+    PNDMScheduler,
+    DPMSolverMultistepScheduler,
+    EulerDiscreteScheduler,
+    EulerAncestralDiscreteScheduler,
+    UniPCMultistepScheduler,
+    HeunDiscreteScheduler,
+    SASolverScheduler,
+    DEISMultistepScheduler,
+    LCMScheduler
+)
+
+scheduler_mapping = {
+    "DPM++": DPMSolverMultistepScheduler,
+    "DPM++SDE": DPMSolverMultistepScheduler,
+    "Euler": EulerDiscreteScheduler,
+    "Euler A": EulerAncestralDiscreteScheduler,
+    "PNDM": PNDMScheduler,
+    "DDIM": DDIMScheduler,
+    "SASolverScheduler": SASolverScheduler,
+    "UniPCMultistepScheduler": UniPCMultistepScheduler,
+    "HeunDiscreteScheduler": HeunDiscreteScheduler,
+    "DEISMultistepScheduler": DEISMultistepScheduler,
+    "LCMScheduler": LCMScheduler
+}
+available_schedulers = list(scheduler_mapping.keys())
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from accelerate import init_empty_weights
+from accelerate.utils import set_module_tensor_to_device
+
 import folder_paths
 
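available_schedulers is just the list of dropdown labels exposed by the sampling node; each label maps to a diffusers scheduler class. A small, self-contained sketch of how such a name is turned into a live scheduler, assuming only that diffusers is installed (the node itself does this with the pipeline's existing config, as shown further down):

    from diffusers.schedulers import EulerDiscreteScheduler, EulerAncestralDiscreteScheduler

    # Start from any existing scheduler's config, then rebuild it as another class.
    # This keeps shared settings (beta schedule, number of train timesteps, etc.)
    # while swapping the sampling algorithm.
    base = EulerDiscreteScheduler()                         # default config, no model needed
    swapped = EulerAncestralDiscreteScheduler.from_config(dict(base.config))
    print(type(swapped).__name__)                           # EulerAncestralDiscreteScheduler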
@@ -219,15 +255,50 @@ class DownloadAndLoadHy3DPaintModel:
                 local_dir_use_symlinks=False,
             )
 
-        from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
-        custom_pipeline_path = os.path.join(script_directory, 'hy3dgen', 'texgen', 'hunyuanpaint')
-
-        pipeline = DiffusionPipeline.from_pretrained(
-            model_path,
-            custom_pipeline=custom_pipeline_path,
-            torch_dtype=torch.float16)
-
-        pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config, timestep_spacing='trailing')
+        torch_dtype = torch.float16
+        config_path = os.path.join(model_path, 'unet', 'config.json')
+        unet_ckpt_path_safetensors = os.path.join(model_path, 'unet','diffusion_pytorch_model.safetensors')
+        unet_ckpt_path_bin = os.path.join(model_path, 'unet','diffusion_pytorch_model.bin')
+
+        if not os.path.exists(config_path):
+            raise FileNotFoundError(f"Config not found at {config_path}")
+
+        with open(config_path, 'r', encoding='utf-8') as file:
+            config = json.load(file)
+
+        with init_empty_weights():
+            unet = UNet2DConditionModel(**config)
+            unet = UNet2p5DConditionModel(unet)
+
+        # Try loading safetensors first, fall back to .bin
+        if os.path.exists(unet_ckpt_path_safetensors):
+            import safetensors.torch
+            unet_sd = safetensors.torch.load_file(unet_ckpt_path_safetensors)
+        elif os.path.exists(unet_ckpt_path_bin):
+            unet_sd = torch.load(unet_ckpt_path_bin, map_location='cpu', weights_only=True)
+        else:
+            raise FileNotFoundError(f"No checkpoint found at {unet_ckpt_path_safetensors} or {unet_ckpt_path_bin}")
+
+        #unet.load_state_dict(unet_ckpt, strict=True)
+        for name, param in unet.named_parameters():
+            set_module_tensor_to_device(unet, name, device=offload_device, dtype=torch_dtype, value=unet_sd[name])
+
+        vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae", device=device, torch_dtype=torch_dtype)
+        clip = CLIPTextModel.from_pretrained(model_path, subfolder="text_encoder", torch_dtype=torch_dtype)
+        tokenizer = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer")
+        scheduler = EulerAncestralDiscreteScheduler.from_pretrained(model_path, subfolder="scheduler")
+        feature_extractor = CLIPImageProcessor.from_pretrained(model_path, subfolder="feature_extractor")
+
+        pipeline = HunyuanPaintPipeline(
+            unet=unet,
+            vae = vae,
+            text_encoder=clip,
+            tokenizer=tokenizer,
+            scheduler=scheduler,
+            feature_extractor=feature_extractor,
+        )
+
         pipeline.enable_model_cpu_offload()
         return (pipeline,)
 
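The loader above builds the UNet under init_empty_weights, so its parameters start on the meta device and take no memory, and then fills each parameter straight from the state dict with set_module_tensor_to_device, which also performs the dtype cast. A toy, self-contained sketch of that pattern; the module, shapes and state dict are made up for illustration:

    import torch
    from accelerate import init_empty_weights
    from accelerate.utils import set_module_tensor_to_device

    with init_empty_weights():
        layer = torch.nn.Linear(4, 4)    # parameters created on the meta device, nothing allocated yet

    # Stand-in for the checkpoint loaded from disk.
    state_dict = {"weight": torch.randn(4, 4), "bias": torch.zeros(4)}

    for name, _ in layer.named_parameters():
        # Materialize each parameter on the target device with the requested dtype.
        set_module_tensor_to_device(layer, name, device="cpu", dtype=torch.float16, value=state_dict[name])

    print(layer.weight.device, layer.weight.dtype)   # cpu torch.float16

One general caveat with this pattern: named_parameters() does not cover buffers, which would need the same treatment if the model registers any.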
@@ -531,6 +602,10 @@ class Hy3DSampleMultiView:
             },
             "optional": {
                 "camera_config": ("HY3DCAMERA",),
+                "scheduler": (available_schedulers,
+                    {
+                        "default": 'Euler A'
+                    }),
             }
         }
 
@@ -539,7 +614,7 @@ class Hy3DSampleMultiView:
     FUNCTION = "process"
     CATEGORY = "Hunyuan3DWrapper"
 
-    def process(self, pipeline, ref_image, normal_maps, position_maps, view_size, seed, steps, camera_config=None):
+    def process(self, pipeline, ref_image, normal_maps, position_maps, view_size, seed, steps, camera_config=None, scheduler="Euler A"):
         device = mm.get_torch_device()
         mm.soft_empty_cache()
         torch.manual_seed(seed)
@@ -580,6 +655,18 @@ class Hy3DSampleMultiView:
 
         callback = ComfyProgressCallback(total_steps=steps)
 
+        scheduler_config = dict(pipeline.scheduler.config)
+
+        if scheduler in scheduler_mapping:
+            if scheduler == "DPM++SDE":
+                scheduler_config["algorithm_type"] = "sde-dpmsolver++"
+            else:
+                scheduler_config.pop("algorithm_type", None)
+            noise_scheduler = scheduler_mapping[scheduler].from_config(scheduler_config)
+            pipeline.scheduler = noise_scheduler
+        else:
+            raise ValueError(f"Unknown scheduler: {scheduler}")
+
         multiview_images = pipeline(
             input_image,
             width=view_size,
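Both "DPM++" and "DPM++SDE" resolve to DPMSolverMultistepScheduler; the algorithm_type key in the config is what selects the SDE solver variant, and popping it lets the class default (dpmsolver++) apply for the plain DPM++ choice. A standalone sketch of that behaviour, using only diffusers defaults rather than the pipeline's config:

    from diffusers.schedulers import DPMSolverMultistepScheduler

    cfg = dict(DPMSolverMultistepScheduler().config)   # default config, algorithm_type == "dpmsolver++"
    cfg["algorithm_type"] = "sde-dpmsolver++"
    sde = DPMSolverMultistepScheduler.from_config(cfg)

    cfg.pop("algorithm_type", None)                    # drop the key so the class default applies again
    ode = DPMSolverMultistepScheduler.from_config(cfg)

    print(sde.config.algorithm_type, ode.config.algorithm_type)   # sde-dpmsolver++ dpmsolver++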