mirror of
https://git.datalinker.icu/kijai/ComfyUI-Hunyuan3DWrapper.git
synced 2026-05-02 14:09:10 +08:00
Fix old model loading
This commit is contained in:
parent
48c7716b2e
commit
cbe2837e50
@ -160,37 +160,21 @@ class Hunyuan3DDiTPipeline:
|
|||||||
scheduler="FlowMatchEulerDiscreteScheduler",
|
scheduler="FlowMatchEulerDiscreteScheduler",
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
new_sd = {}
|
||||||
# # load ckpt
|
sd = load_torch_file(ckpt_path)
|
||||||
# if use_safetensors:
|
if ckpt_path.endswith('.safetensors'):
|
||||||
# ckpt_path = ckpt_path.replace('.ckpt', '.safetensors')
|
for key, value in sd.items():
|
||||||
# if not os.path.exists(ckpt_path):
|
model_name = key.split('.')[0]
|
||||||
# raise FileNotFoundError(f"Model file {ckpt_path} not found")
|
new_key = key[len(model_name) + 1:]
|
||||||
# logger.info(f"Loading model from {ckpt_path}")
|
if model_name not in new_sd:
|
||||||
|
new_sd[model_name] = {}
|
||||||
# if use_safetensors:
|
new_sd[model_name][new_key] = value
|
||||||
# # parse safetensors
|
|
||||||
# import safetensors.torch
|
|
||||||
# safetensors_ckpt = safetensors.torch.load_file(ckpt_path, device='cpu')
|
|
||||||
# ckpt = {}
|
|
||||||
# for key, value in safetensors_ckpt.items():
|
|
||||||
# model_name = key.split('.')[0]
|
|
||||||
# new_key = key[len(model_name) + 1:]
|
|
||||||
# if model_name not in ckpt:
|
|
||||||
# ckpt[model_name] = {}
|
|
||||||
# ckpt[model_name][new_key] = value
|
|
||||||
# else:
|
|
||||||
# ckpt = torch.load(ckpt_path, map_location='cpu')
|
|
||||||
ckpt = load_torch_file(ckpt_path)
|
|
||||||
for k in ckpt:
|
|
||||||
print(k)
|
|
||||||
|
|
||||||
|
|
||||||
script_directory = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
script_directory = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
# load config
|
# load config
|
||||||
|
|
||||||
single_block_nums = set()
|
single_block_nums = set()
|
||||||
for k in ckpt["model"].keys():
|
for k in new_sd["model"].keys():
|
||||||
if k.startswith('single_blocks.'):
|
if k.startswith('single_blocks.'):
|
||||||
block_num = int(k.split('.')[1])
|
block_num = int(k.split('.')[1])
|
||||||
single_block_nums.add(block_num)
|
single_block_nums.add(block_num)
|
||||||
@ -205,7 +189,7 @@ class Hunyuan3DDiTPipeline:
|
|||||||
|
|
||||||
|
|
||||||
# load model
|
# load model
|
||||||
if "guidance_in.in_layer.bias" in ckpt['model']: #guidance_in.in_layer.bias
|
if "guidance_in.in_layer.bias" in new_sd['model']: #guidance_in.in_layer.bias
|
||||||
logger.info("Model has guidance_in, setting guidance_embed to True")
|
logger.info("Model has guidance_in, setting guidance_embed to True")
|
||||||
config['model']['params']['guidance_embed'] = True
|
config['model']['params']['guidance_embed'] = True
|
||||||
config['conditioner']['params']['main_image_encoder']['kwargs']['has_guidance_embed'] = True
|
config['conditioner']['params']['main_image_encoder']['kwargs']['has_guidance_embed'] = True
|
||||||
@ -221,15 +205,15 @@ class Hunyuan3DDiTPipeline:
|
|||||||
conditioner = instantiate_from_config(config['conditioner'])
|
conditioner = instantiate_from_config(config['conditioner'])
|
||||||
#model
|
#model
|
||||||
for name, param in model.named_parameters():
|
for name, param in model.named_parameters():
|
||||||
set_module_tensor_to_device(model, name, device=offload_device, dtype=dtype, value=ckpt['model'][name])
|
set_module_tensor_to_device(model, name, device=offload_device, dtype=dtype, value=new_sd['model'][name])
|
||||||
#vae
|
#vae
|
||||||
for name, param in vae.named_parameters():
|
for name, param in vae.named_parameters():
|
||||||
set_module_tensor_to_device(vae, name, device=offload_device, dtype=dtype, value=ckpt['vae'][name])
|
set_module_tensor_to_device(vae, name, device=offload_device, dtype=dtype, value=new_sd['vae'][name])
|
||||||
|
|
||||||
if 'conditioner' in ckpt:
|
if 'conditioner' in new_sd:
|
||||||
#conditioner.load_state_dict(ckpt['conditioner'])
|
#conditioner.load_state_dict(ckpt['conditioner'])
|
||||||
for name, param in conditioner.named_parameters():
|
for name, param in conditioner.named_parameters():
|
||||||
set_module_tensor_to_device(conditioner, name, device=offload_device, dtype=dtype, value=ckpt['conditioner'][name])
|
set_module_tensor_to_device(conditioner, name, device=offload_device, dtype=dtype, value=new_sd['conditioner'][name])
|
||||||
|
|
||||||
image_processor = instantiate_from_config(config['image_processor'])
|
image_processor = instantiate_from_config(config['image_processor'])
|
||||||
|
|
||||||
@ -261,49 +245,6 @@ class Hunyuan3DDiTPipeline:
|
|||||||
|
|
||||||
return cls(**model_kwargs), vae
|
return cls(**model_kwargs), vae
|
||||||
|
|
||||||
# @classmethod
|
|
||||||
# def from_pretrained(
|
|
||||||
# cls,
|
|
||||||
# model_path,
|
|
||||||
# ckpt_name='model.ckpt',
|
|
||||||
# config_name='config.yaml',
|
|
||||||
# device='cuda',
|
|
||||||
# dtype=torch.float16,
|
|
||||||
# use_safetensors=None,
|
|
||||||
# **kwargs,
|
|
||||||
# ):
|
|
||||||
# original_model_path = model_path
|
|
||||||
# if not os.path.exists(model_path):
|
|
||||||
# # try local path
|
|
||||||
# base_dir = "checkpoints"
|
|
||||||
# model_path = os.path.join(base_dir, model_path, 'hunyuan3d-dit-v2-0')
|
|
||||||
# if not os.path.exists(model_path):
|
|
||||||
# try:
|
|
||||||
# import huggingface_hub
|
|
||||||
# # download from huggingface
|
|
||||||
# huggingface_hub.snapshot_download(
|
|
||||||
# repo_id="tencent/Hunyuan3D-2",
|
|
||||||
# local_dir=base_dir,)
|
|
||||||
|
|
||||||
# except ImportError:
|
|
||||||
# logger.warning(
|
|
||||||
# "You need to install HuggingFace Hub to load models from the hub."
|
|
||||||
# )
|
|
||||||
# raise RuntimeError(f"Model path {model_path} not found")
|
|
||||||
# if not os.path.exists(model_path):
|
|
||||||
# raise FileNotFoundError(f"Model path {original_model_path} not found")
|
|
||||||
|
|
||||||
# config_path = os.path.join(model_path, config_name)
|
|
||||||
# ckpt_path = os.path.join(model_path, ckpt_name)
|
|
||||||
# return cls.from_single_file(
|
|
||||||
# ckpt_path,
|
|
||||||
# config_path,
|
|
||||||
# device=device,
|
|
||||||
# dtype=dtype,
|
|
||||||
# use_safetensors=use_safetensors,
|
|
||||||
# **kwargs
|
|
||||||
# )
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
#vae,
|
#vae,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user