mirror of
https://git.datalinker.icu/kijai/ComfyUI-Hunyuan3DWrapper.git
synced 2026-05-06 06:10:08 +08:00
Use Accelerate to load the main model faster
This commit is contained in:
parent
2286d0c0f7
commit
703edf2051
@ -37,6 +37,9 @@ from PIL import Image
|
|||||||
from diffusers.utils.torch_utils import randn_tensor
|
from diffusers.utils.torch_utils import randn_tensor
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from accelerate import init_empty_weights
|
||||||
|
from accelerate.utils import set_module_tensor_to_device
|
||||||
|
|
||||||
from comfy.utils import ProgressBar
|
from comfy.utils import ProgressBar
|
||||||
import comfy.model_management as mm
|
import comfy.model_management as mm
|
||||||
|
|
||||||
@ -176,13 +179,24 @@ class Hunyuan3DDiTPipeline:
|
|||||||
else:
|
else:
|
||||||
ckpt = torch.load(ckpt_path, map_location='cpu')
|
ckpt = torch.load(ckpt_path, map_location='cpu')
|
||||||
# load model
|
# load model
|
||||||
model = instantiate_from_config(config['model'])
|
with init_empty_weights():
|
||||||
model.load_state_dict(ckpt['model'])
|
model = instantiate_from_config(config['model'])
|
||||||
vae = instantiate_from_config(config['vae'])
|
vae = instantiate_from_config(config['vae'])
|
||||||
vae.load_state_dict(ckpt['vae'])
|
conditioner = instantiate_from_config(config['conditioner'])
|
||||||
conditioner = instantiate_from_config(config['conditioner'])
|
#model
|
||||||
|
#model.load_state_dict(ckpt['model'])
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
set_module_tensor_to_device(model, name, device=offload_device, dtype=dtype, value=ckpt['model'][name])
|
||||||
|
#vae
|
||||||
|
#vae.load_state_dict(ckpt['vae'])
|
||||||
|
for name, param in vae.named_parameters():
|
||||||
|
set_module_tensor_to_device(vae, name, device=offload_device, dtype=dtype, value=ckpt['vae'][name])
|
||||||
|
|
||||||
if 'conditioner' in ckpt:
|
if 'conditioner' in ckpt:
|
||||||
conditioner.load_state_dict(ckpt['conditioner'])
|
#conditioner.load_state_dict(ckpt['conditioner'])
|
||||||
|
for name, param in conditioner.named_parameters():
|
||||||
|
set_module_tensor_to_device(conditioner, name, device=offload_device, dtype=dtype, value=ckpt['conditioner'][name])
|
||||||
|
|
||||||
image_processor = instantiate_from_config(config['image_processor'])
|
image_processor = instantiate_from_config(config['image_processor'])
|
||||||
scheduler = instantiate_from_config(config['scheduler'])
|
scheduler = instantiate_from_config(config['scheduler'])
|
||||||
|
|
||||||
|
|||||||
15
nodes.py
15
nodes.py
@ -13,9 +13,7 @@ from comfy.utils import load_torch_file
|
|||||||
|
|
||||||
script_directory = os.path.dirname(os.path.abspath(__file__))
|
script_directory = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
import logging
|
from .utils import log, print_memory
|
||||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
|
||||||
log = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
#region Model loading
|
#region Model loading
|
||||||
class Hy3DModelLoader:
|
class Hy3DModelLoader:
|
||||||
@ -388,6 +386,11 @@ class Hy3DGenerateMesh:
|
|||||||
|
|
||||||
pipeline.to(device)
|
pipeline.to(device)
|
||||||
|
|
||||||
|
try:
|
||||||
|
torch.cuda.reset_peak_memory_stats(device)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
mesh = pipeline(
|
mesh = pipeline(
|
||||||
image=image,
|
image=image,
|
||||||
mask=mask,
|
mask=mask,
|
||||||
@ -409,6 +412,12 @@ class Hy3DGenerateMesh:
|
|||||||
mesh = FaceReducer()(mesh, max_facenum=max_facenum)
|
mesh = FaceReducer()(mesh, max_facenum=max_facenum)
|
||||||
log.info(f"Reduced faces, resulting in {mesh.vertices.shape[0]} vertices and {mesh.faces.shape[0]} faces")
|
log.info(f"Reduced faces, resulting in {mesh.vertices.shape[0]} vertices and {mesh.faces.shape[0]} faces")
|
||||||
|
|
||||||
|
print_memory(device)
|
||||||
|
try:
|
||||||
|
torch.cuda.reset_peak_memory_stats(device)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
pipeline.to(offload_device)
|
pipeline.to(offload_device)
|
||||||
|
|
||||||
return (mesh, )
|
return (mesh, )
|
||||||
|
|||||||
24
utils.py
Normal file
24
utils.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
import importlib.metadata
|
||||||
|
import torch
|
||||||
|
import logging
|
||||||
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
def check_diffusers_version(required_version='0.31.0', installed_version=None):
    """Verify that a recent enough ``diffusers`` release is installed.

    Args:
        required_version: Minimum acceptable release, e.g. ``'0.31.0'``.
        installed_version: Version string to validate. When ``None`` (the
            default) it is looked up from the installed package metadata.

    Raises:
        AssertionError: If ``diffusers`` is not installed, or if the
            installed release is older than ``required_version``.
    """
    if installed_version is None:
        # Keep the try body down to the one call that can raise.
        try:
            installed_version = importlib.metadata.version('diffusers')
        except importlib.metadata.PackageNotFoundError as err:
            raise AssertionError("diffusers is not installed.") from err

    # Compare numerically: plain string comparison is lexicographic, so it
    # would accept "0.4.0" and reject "0.100.0" against "0.31.0".
    if _release_tuple(installed_version) < _release_tuple(required_version):
        raise AssertionError(f"diffusers version {installed_version} is installed, but version {required_version} or higher is required.")


def _release_tuple(version, width=3):
    """Return the leading numeric release of *version* as a tuple of ints.

    Only the dotted-number prefix is used, so suffixes such as ``rc1`` or
    ``.dev0`` are ignored. The result is zero-padded to at least *width*
    components so that ``"0.31"`` and ``"0.31.0"`` compare equal. A string
    with no leading number yields all zeros.
    """
    import re  # local import: the module's import block does not include re

    found = re.match(r"\d+(?:\.\d+)*", version.strip().lstrip("vV"))
    numbers = [int(part) for part in found.group(0).split(".")] if found else []
    numbers.extend([0] * (width - len(numbers)))
    return tuple(numbers)
|
||||||
|
|
||||||
|
def print_memory(device):
    """Log current and peak CUDA memory figures for ``device``.

    Queries ``torch.cuda`` for the allocated, peak-allocated and
    peak-reserved byte counts, converts each to units of 1024**3 bytes, and
    writes one INFO line per figure to the module logger.
    """
    bytes_per_unit = 1024**3
    allocated = torch.cuda.memory_allocated(device) / bytes_per_unit
    peak_allocated = torch.cuda.max_memory_allocated(device) / bytes_per_unit
    peak_reserved = torch.cuda.max_memory_reserved(device) / bytes_per_unit

    # The "name=" prefixes are part of the emitted text and are kept verbatim.
    log.info(f"Allocated memory: memory={allocated:.3f} GB")
    log.info(f"Max allocated memory: max_memory={peak_allocated:.3f} GB")
    log.info(f"Max reserved memory: max_reserved={peak_reserved:.3f} GB")
    # torch.cuda.memory_summary(device=device, abbreviated=False) gives a
    # fuller breakdown if more detail is ever needed here.
|
||||||
Loading…
x
Reference in New Issue
Block a user