Fix old model loading

This commit is contained in:
kijai 2025-06-15 12:52:39 +03:00
parent 48c7716b2e
commit cbe2837e50

View File

@ -160,37 +160,21 @@ class Hunyuan3DDiTPipeline:
scheduler="FlowMatchEulerDiscreteScheduler",
**kwargs,
):
# # load ckpt
# if use_safetensors:
# ckpt_path = ckpt_path.replace('.ckpt', '.safetensors')
# if not os.path.exists(ckpt_path):
# raise FileNotFoundError(f"Model file {ckpt_path} not found")
# logger.info(f"Loading model from {ckpt_path}")
# if use_safetensors:
# # parse safetensors
# import safetensors.torch
# safetensors_ckpt = safetensors.torch.load_file(ckpt_path, device='cpu')
# ckpt = {}
# for key, value in safetensors_ckpt.items():
# model_name = key.split('.')[0]
# new_key = key[len(model_name) + 1:]
# if model_name not in ckpt:
# ckpt[model_name] = {}
# ckpt[model_name][new_key] = value
# else:
# ckpt = torch.load(ckpt_path, map_location='cpu')
ckpt = load_torch_file(ckpt_path)
for k in ckpt:
print(k)
new_sd = {}
sd = load_torch_file(ckpt_path)
if ckpt_path.endswith('.safetensors'):
for key, value in sd.items():
model_name = key.split('.')[0]
new_key = key[len(model_name) + 1:]
if model_name not in new_sd:
new_sd[model_name] = {}
new_sd[model_name][new_key] = value
script_directory = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# load config
single_block_nums = set()
for k in ckpt["model"].keys():
for k in new_sd["model"].keys():
if k.startswith('single_blocks.'):
block_num = int(k.split('.')[1])
single_block_nums.add(block_num)
@ -205,7 +189,7 @@ class Hunyuan3DDiTPipeline:
# load model
if "guidance_in.in_layer.bias" in ckpt['model']: #guidance_in.in_layer.bias
if "guidance_in.in_layer.bias" in new_sd['model']: #guidance_in.in_layer.bias
logger.info("Model has guidance_in, setting guidance_embed to True")
config['model']['params']['guidance_embed'] = True
config['conditioner']['params']['main_image_encoder']['kwargs']['has_guidance_embed'] = True
@ -221,15 +205,15 @@ class Hunyuan3DDiTPipeline:
conditioner = instantiate_from_config(config['conditioner'])
#model
for name, param in model.named_parameters():
set_module_tensor_to_device(model, name, device=offload_device, dtype=dtype, value=ckpt['model'][name])
set_module_tensor_to_device(model, name, device=offload_device, dtype=dtype, value=new_sd['model'][name])
#vae
for name, param in vae.named_parameters():
set_module_tensor_to_device(vae, name, device=offload_device, dtype=dtype, value=ckpt['vae'][name])
set_module_tensor_to_device(vae, name, device=offload_device, dtype=dtype, value=new_sd['vae'][name])
if 'conditioner' in ckpt:
if 'conditioner' in new_sd:
#conditioner.load_state_dict(ckpt['conditioner'])
for name, param in conditioner.named_parameters():
set_module_tensor_to_device(conditioner, name, device=offload_device, dtype=dtype, value=ckpt['conditioner'][name])
set_module_tensor_to_device(conditioner, name, device=offload_device, dtype=dtype, value=new_sd['conditioner'][name])
image_processor = instantiate_from_config(config['image_processor'])
@ -261,49 +245,6 @@ class Hunyuan3DDiTPipeline:
return cls(**model_kwargs), vae
# @classmethod
# def from_pretrained(
# cls,
# model_path,
# ckpt_name='model.ckpt',
# config_name='config.yaml',
# device='cuda',
# dtype=torch.float16,
# use_safetensors=None,
# **kwargs,
# ):
# original_model_path = model_path
# if not os.path.exists(model_path):
# # try local path
# base_dir = "checkpoints"
# model_path = os.path.join(base_dir, model_path, 'hunyuan3d-dit-v2-0')
# if not os.path.exists(model_path):
# try:
# import huggingface_hub
# # download from huggingface
# huggingface_hub.snapshot_download(
# repo_id="tencent/Hunyuan3D-2",
# local_dir=base_dir,)
# except ImportError:
# logger.warning(
# "You need to install HuggingFace Hub to load models from the hub."
# )
# raise RuntimeError(f"Model path {model_path} not found")
# if not os.path.exists(model_path):
# raise FileNotFoundError(f"Model path {original_model_path} not found")
# config_path = os.path.join(model_path, config_name)
# ckpt_path = os.path.join(model_path, ckpt_name)
# return cls.from_single_file(
# ckpt_path,
# config_path,
# device=device,
# dtype=dtype,
# use_safetensors=use_safetensors,
# **kwargs
# )
def __init__(
self,
#vae,