Fix paint model .safetensors loading, add schedulers

This commit is contained in:
kijai 2025-01-27 15:53:17 +02:00
parent 38efaa8d1f
commit 5b19d811ab
2 changed files with 112 additions and 10 deletions

View File

@ -281,12 +281,27 @@ class UNet2p5DConditionModel(torch.nn.Module):
def from_pretrained(pretrained_model_name_or_path, **kwargs):
torch_dtype = kwargs.pop('torch_dtype', torch.float32)
config_path = os.path.join(pretrained_model_name_or_path, 'config.json')
unet_ckpt_path = os.path.join(pretrained_model_name_or_path, 'diffusion_pytorch_model.bin')
unet_ckpt_path_safetensors = os.path.join(pretrained_model_name_or_path, 'diffusion_pytorch_model.safetensors')
unet_ckpt_path_bin = os.path.join(pretrained_model_name_or_path, 'diffusion_pytorch_model.bin')
if not os.path.exists(config_path):
raise FileNotFoundError(f"Config not found at {config_path}")
with open(config_path, 'r', encoding='utf-8') as file:
config = json.load(file)
unet = UNet2DConditionModel(**config)
unet = UNet2p5DConditionModel(unet)
unet_ckpt = torch.load(unet_ckpt_path, map_location='cpu', weights_only=True)
# Try loading safetensors first, fall back to .bin
if os.path.exists(unet_ckpt_path_safetensors):
import safetensors.torch
unet_ckpt = safetensors.torch.load_file(unet_ckpt_path_safetensors)
elif os.path.exists(unet_ckpt_path_bin):
unet_ckpt = torch.load(unet_ckpt_path_bin, map_location='cpu', weights_only=True)
else:
raise FileNotFoundError(f"No checkpoint found at {unet_ckpt_path_safetensors} or {unet_ckpt_path_bin}")
unet.load_state_dict(unet_ckpt, strict=True)
unet = unet.to(torch_dtype)
return unet

103
nodes.py
View File

@ -4,9 +4,45 @@ import torchvision.transforms as transforms
from PIL import Image
from pathlib import Path
import numpy as np
import json
import trimesh
from .hy3dgen.shapegen import Hunyuan3DDiTFlowMatchingPipeline, FaceReducer, FloaterRemover, DegenerateFaceRemover
from .hy3dgen.texgen.hunyuanpaint.unet.modules import UNet2DConditionModel, UNet2p5DConditionModel
from .hy3dgen.texgen.hunyuanpaint.pipeline import HunyuanPaintPipeline
from diffusers import AutoencoderKL
from diffusers.schedulers import (
DDIMScheduler,
PNDMScheduler,
DPMSolverMultistepScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
UniPCMultistepScheduler,
HeunDiscreteScheduler,
SASolverScheduler,
DEISMultistepScheduler,
LCMScheduler
)
scheduler_mapping = {
"DPM++": DPMSolverMultistepScheduler,
"DPM++SDE": DPMSolverMultistepScheduler,
"Euler": EulerDiscreteScheduler,
"Euler A": EulerAncestralDiscreteScheduler,
"PNDM": PNDMScheduler,
"DDIM": DDIMScheduler,
"SASolverScheduler": SASolverScheduler,
"UniPCMultistepScheduler": UniPCMultistepScheduler,
"HeunDiscreteScheduler": HeunDiscreteScheduler,
"DEISMultistepScheduler": DEISMultistepScheduler,
"LCMScheduler": LCMScheduler
}
available_schedulers = list(scheduler_mapping.keys())
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
import folder_paths
@ -219,15 +255,50 @@ class DownloadAndLoadHy3DPaintModel:
local_dir_use_symlinks=False,
)
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
custom_pipeline_path = os.path.join(script_directory, 'hy3dgen', 'texgen', 'hunyuanpaint')
torch_dtype = torch.float16
config_path = os.path.join(model_path, 'unet', 'config.json')
unet_ckpt_path_safetensors = os.path.join(model_path, 'unet','diffusion_pytorch_model.safetensors')
unet_ckpt_path_bin = os.path.join(model_path, 'unet','diffusion_pytorch_model.bin')
pipeline = DiffusionPipeline.from_pretrained(
model_path,
custom_pipeline=custom_pipeline_path,
torch_dtype=torch.float16)
if not os.path.exists(config_path):
raise FileNotFoundError(f"Config not found at {config_path}")
with open(config_path, 'r', encoding='utf-8') as file:
config = json.load(file)
with init_empty_weights():
unet = UNet2DConditionModel(**config)
unet = UNet2p5DConditionModel(unet)
# Try loading safetensors first, fall back to .bin
if os.path.exists(unet_ckpt_path_safetensors):
import safetensors.torch
unet_sd = safetensors.torch.load_file(unet_ckpt_path_safetensors)
elif os.path.exists(unet_ckpt_path_bin):
unet_sd = torch.load(unet_ckpt_path_bin, map_location='cpu', weights_only=True)
else:
raise FileNotFoundError(f"No checkpoint found at {unet_ckpt_path_safetensors} or {unet_ckpt_path_bin}")
#unet.load_state_dict(unet_ckpt, strict=True)
for name, param in unet.named_parameters():
set_module_tensor_to_device(unet, name, device=offload_device, dtype=torch_dtype, value=unet_sd[name])
vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae", device=device, torch_dtype=torch_dtype)
clip = CLIPTextModel.from_pretrained(model_path, subfolder="text_encoder", torch_dtype=torch_dtype)
tokenizer = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer")
scheduler = EulerAncestralDiscreteScheduler.from_pretrained(model_path, subfolder="scheduler")
feature_extractor = CLIPImageProcessor.from_pretrained(model_path, subfolder="feature_extractor")
pipeline = HunyuanPaintPipeline(
unet=unet,
vae = vae,
text_encoder=clip,
tokenizer=tokenizer,
scheduler=scheduler,
feature_extractor=feature_extractor,
)
pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config, timestep_spacing='trailing')
pipeline.enable_model_cpu_offload()
return (pipeline,)
@ -531,6 +602,10 @@ class Hy3DSampleMultiView:
},
"optional": {
"camera_config": ("HY3DCAMERA",),
"scheduler": (available_schedulers,
{
"default": 'Euler A'
}),
}
}
@ -539,7 +614,7 @@ class Hy3DSampleMultiView:
FUNCTION = "process"
CATEGORY = "Hunyuan3DWrapper"
def process(self, pipeline, ref_image, normal_maps, position_maps, view_size, seed, steps, camera_config=None):
def process(self, pipeline, ref_image, normal_maps, position_maps, view_size, seed, steps, camera_config=None, scheduler="Euler A"):
device = mm.get_torch_device()
mm.soft_empty_cache()
torch.manual_seed(seed)
@ -580,6 +655,18 @@ class Hy3DSampleMultiView:
callback = ComfyProgressCallback(total_steps=steps)
scheduler_config = dict(pipeline.scheduler.config)
if scheduler in scheduler_mapping:
if scheduler == "DPM++SDE":
scheduler_config["algorithm_type"] = "sde-dpmsolver++"
else:
scheduler_config.pop("algorithm_type", None)
noise_scheduler = scheduler_mapping[scheduler].from_config(scheduler_config)
pipeline.scheduler = noise_scheduler
else:
raise ValueError(f"Unknown scheduler: {scheduler}")
multiview_images = pipeline(
input_image,
width=view_size,