mirror of
https://git.datalinker.icu/kijai/ComfyUI-Hunyuan3DWrapper.git
synced 2026-01-23 13:24:26 +08:00
Fix old model loading
This commit is contained in:
parent
48c7716b2e
commit
cbe2837e50
@ -160,37 +160,21 @@ class Hunyuan3DDiTPipeline:
|
||||
scheduler="FlowMatchEulerDiscreteScheduler",
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
# # load ckpt
|
||||
# if use_safetensors:
|
||||
# ckpt_path = ckpt_path.replace('.ckpt', '.safetensors')
|
||||
# if not os.path.exists(ckpt_path):
|
||||
# raise FileNotFoundError(f"Model file {ckpt_path} not found")
|
||||
# logger.info(f"Loading model from {ckpt_path}")
|
||||
|
||||
# if use_safetensors:
|
||||
# # parse safetensors
|
||||
# import safetensors.torch
|
||||
# safetensors_ckpt = safetensors.torch.load_file(ckpt_path, device='cpu')
|
||||
# ckpt = {}
|
||||
# for key, value in safetensors_ckpt.items():
|
||||
# model_name = key.split('.')[0]
|
||||
# new_key = key[len(model_name) + 1:]
|
||||
# if model_name not in ckpt:
|
||||
# ckpt[model_name] = {}
|
||||
# ckpt[model_name][new_key] = value
|
||||
# else:
|
||||
# ckpt = torch.load(ckpt_path, map_location='cpu')
|
||||
ckpt = load_torch_file(ckpt_path)
|
||||
for k in ckpt:
|
||||
print(k)
|
||||
|
||||
new_sd = {}
|
||||
sd = load_torch_file(ckpt_path)
|
||||
if ckpt_path.endswith('.safetensors'):
|
||||
for key, value in sd.items():
|
||||
model_name = key.split('.')[0]
|
||||
new_key = key[len(model_name) + 1:]
|
||||
if model_name not in new_sd:
|
||||
new_sd[model_name] = {}
|
||||
new_sd[model_name][new_key] = value
|
||||
|
||||
script_directory = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
# load config
|
||||
|
||||
single_block_nums = set()
|
||||
for k in ckpt["model"].keys():
|
||||
for k in new_sd["model"].keys():
|
||||
if k.startswith('single_blocks.'):
|
||||
block_num = int(k.split('.')[1])
|
||||
single_block_nums.add(block_num)
|
||||
@ -205,7 +189,7 @@ class Hunyuan3DDiTPipeline:
|
||||
|
||||
|
||||
# load model
|
||||
if "guidance_in.in_layer.bias" in ckpt['model']: #guidance_in.in_layer.bias
|
||||
if "guidance_in.in_layer.bias" in new_sd['model']: #guidance_in.in_layer.bias
|
||||
logger.info("Model has guidance_in, setting guidance_embed to True")
|
||||
config['model']['params']['guidance_embed'] = True
|
||||
config['conditioner']['params']['main_image_encoder']['kwargs']['has_guidance_embed'] = True
|
||||
@ -221,15 +205,15 @@ class Hunyuan3DDiTPipeline:
|
||||
conditioner = instantiate_from_config(config['conditioner'])
|
||||
#model
|
||||
for name, param in model.named_parameters():
|
||||
set_module_tensor_to_device(model, name, device=offload_device, dtype=dtype, value=ckpt['model'][name])
|
||||
set_module_tensor_to_device(model, name, device=offload_device, dtype=dtype, value=new_sd['model'][name])
|
||||
#vae
|
||||
for name, param in vae.named_parameters():
|
||||
set_module_tensor_to_device(vae, name, device=offload_device, dtype=dtype, value=ckpt['vae'][name])
|
||||
set_module_tensor_to_device(vae, name, device=offload_device, dtype=dtype, value=new_sd['vae'][name])
|
||||
|
||||
if 'conditioner' in ckpt:
|
||||
if 'conditioner' in new_sd:
|
||||
#conditioner.load_state_dict(ckpt['conditioner'])
|
||||
for name, param in conditioner.named_parameters():
|
||||
set_module_tensor_to_device(conditioner, name, device=offload_device, dtype=dtype, value=ckpt['conditioner'][name])
|
||||
set_module_tensor_to_device(conditioner, name, device=offload_device, dtype=dtype, value=new_sd['conditioner'][name])
|
||||
|
||||
image_processor = instantiate_from_config(config['image_processor'])
|
||||
|
||||
@ -261,49 +245,6 @@ class Hunyuan3DDiTPipeline:
|
||||
|
||||
return cls(**model_kwargs), vae
|
||||
|
||||
# @classmethod
|
||||
# def from_pretrained(
|
||||
# cls,
|
||||
# model_path,
|
||||
# ckpt_name='model.ckpt',
|
||||
# config_name='config.yaml',
|
||||
# device='cuda',
|
||||
# dtype=torch.float16,
|
||||
# use_safetensors=None,
|
||||
# **kwargs,
|
||||
# ):
|
||||
# original_model_path = model_path
|
||||
# if not os.path.exists(model_path):
|
||||
# # try local path
|
||||
# base_dir = "checkpoints"
|
||||
# model_path = os.path.join(base_dir, model_path, 'hunyuan3d-dit-v2-0')
|
||||
# if not os.path.exists(model_path):
|
||||
# try:
|
||||
# import huggingface_hub
|
||||
# # download from huggingface
|
||||
# huggingface_hub.snapshot_download(
|
||||
# repo_id="tencent/Hunyuan3D-2",
|
||||
# local_dir=base_dir,)
|
||||
|
||||
# except ImportError:
|
||||
# logger.warning(
|
||||
# "You need to install HuggingFace Hub to load models from the hub."
|
||||
# )
|
||||
# raise RuntimeError(f"Model path {model_path} not found")
|
||||
# if not os.path.exists(model_path):
|
||||
# raise FileNotFoundError(f"Model path {original_model_path} not found")
|
||||
|
||||
# config_path = os.path.join(model_path, config_name)
|
||||
# ckpt_path = os.path.join(model_path, ckpt_name)
|
||||
# return cls.from_single_file(
|
||||
# ckpt_path,
|
||||
# config_path,
|
||||
# device=device,
|
||||
# dtype=dtype,
|
||||
# use_safetensors=use_safetensors,
|
||||
# **kwargs
|
||||
# )
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
#vae,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user