Mirror of https://git.datalinker.icu/kijai/ComfyUI-Hunyuan3DWrapper.git (synced 2025-12-08 20:34:28 +08:00)
Use Accelerate to load the main model faster

commit 703edf2051
parent 2286d0c0f7
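In outline: the pipeline's model, vae, and conditioner are now instantiated on the meta device under accelerate's init_empty_weights(), and each parameter is materialized straight from the checkpoint at the target dtype and offload device via set_module_tensor_to_device(), so the modules are never randomly initialized and then overwritten. nodes.py drops its local logging setup in favor of a new shared utils.py, which also provides the print_memory() helper used for the peak-VRAM reporting added around mesh generation.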
@@ -37,6 +37,9 @@ from PIL import Image
 from diffusers.utils.torch_utils import randn_tensor
 from tqdm import tqdm
 
+from accelerate import init_empty_weights
+from accelerate.utils import set_module_tensor_to_device
+
 from comfy.utils import ProgressBar
 import comfy.model_management as mm
 
@@ -176,13 +179,24 @@ class Hunyuan3DDiTPipeline:
         else:
             ckpt = torch.load(ckpt_path, map_location='cpu')
         # load model
-        model = instantiate_from_config(config['model'])
-        model.load_state_dict(ckpt['model'])
-        vae = instantiate_from_config(config['vae'])
-        vae.load_state_dict(ckpt['vae'])
-        conditioner = instantiate_from_config(config['conditioner'])
+        with init_empty_weights():
+            model = instantiate_from_config(config['model'])
+            vae = instantiate_from_config(config['vae'])
+            conditioner = instantiate_from_config(config['conditioner'])
+
+        #model
+        #model.load_state_dict(ckpt['model'])
+        for name, param in model.named_parameters():
+            set_module_tensor_to_device(model, name, device=offload_device, dtype=dtype, value=ckpt['model'][name])
+        #vae
+        #vae.load_state_dict(ckpt['vae'])
+        for name, param in vae.named_parameters():
+            set_module_tensor_to_device(vae, name, device=offload_device, dtype=dtype, value=ckpt['vae'][name])
 
         if 'conditioner' in ckpt:
-            conditioner.load_state_dict(ckpt['conditioner'])
+            #conditioner.load_state_dict(ckpt['conditioner'])
+            for name, param in conditioner.named_parameters():
+                set_module_tensor_to_device(conditioner, name, device=offload_device, dtype=dtype, value=ckpt['conditioner'][name])
+
         image_processor = instantiate_from_config(config['image_processor'])
         scheduler = instantiate_from_config(config['scheduler'])
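For reference, a minimal self-contained sketch of the loading pattern above, assuming only torch and accelerate; ToyModel and the fake state dict are illustrative stand-ins, not code from this repo:

import torch
import torch.nn as nn
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(512, 512)

# Construction under init_empty_weights() places every parameter on the
# meta device: the module is built instantly and allocates no weight memory.
with init_empty_weights():
    model = ToyModel()

# Stand-in for ckpt['model'] in the diff above.
state_dict = {'linear.weight': torch.randn(512, 512),
              'linear.bias': torch.randn(512)}

# Materialize each parameter exactly once, directly at the target device
# and dtype, instead of random-initializing the weights and overwriting
# them with load_state_dict() afterwards.
for name, _ in model.named_parameters():
    set_module_tensor_to_device(model, name, device='cpu',
                                dtype=torch.float16, value=state_dict[name])

This is what makes the commit title's "faster" plausible: the module never holds a second, randomly initialized copy of its weights, so construction is near-instant and peak memory during loading drops.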
nodes.py (15 lines changed)
@@ -13,9 +13,7 @@ from comfy.utils import load_torch_file
 
 script_directory = os.path.dirname(os.path.abspath(__file__))
 
-import logging
-logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
-log = logging.getLogger(__name__)
+from .utils import log, print_memory
 
 #region Model loading
 class Hy3DModelLoader:
@@ -388,6 +386,11 @@ class Hy3DGenerateMesh:
 
         pipeline.to(device)
 
+        try:
+            torch.cuda.reset_peak_memory_stats(device)
+        except:
+            pass
+
         mesh = pipeline(
             image=image,
             mask=mask,
@@ -409,6 +412,12 @@
             mesh = FaceReducer()(mesh, max_facenum=max_facenum)
             log.info(f"Reduced faces, resulting in {mesh.vertices.shape[0]} vertices and {mesh.faces.shape[0]} faces")
 
+        print_memory(device)
+        try:
+            torch.cuda.reset_peak_memory_stats(device)
+        except:
+            pass
+
         pipeline.to(offload_device)
 
         return (mesh, )
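The reset/report pair added above is the usual CUDA peak-memory measurement pattern: clear the high-water mark before the work, read it afterwards. As a hypothetical standalone helper (run_with_peak_memory is not part of this repo):

import torch

def run_with_peak_memory(fn, device):
    # Clear the high-water mark; guarded because the call raises on
    # non-CUDA devices such as CPU or MPS.
    try:
        torch.cuda.reset_peak_memory_stats(device)
    except Exception:
        pass
    result = fn()
    # max_memory_allocated() reports the peak since the last reset,
    # i.e. the peak VRAM used while fn() ran.
    peak_gb = torch.cuda.max_memory_allocated(device) / 1024**3
    print(f"peak allocated: {peak_gb:.3f} GB")
    return result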
utils.py (new file, 24 lines)
@@ -0,0 +1,24 @@
+import importlib.metadata
+import torch
+import logging
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+log = logging.getLogger(__name__)
+
+def check_diffusers_version():
+    try:
+        version = importlib.metadata.version('diffusers')
+        required_version = '0.31.0'
+        if version < required_version:
+            raise AssertionError(f"diffusers version {version} is installed, but version {required_version} or higher is required.")
+    except importlib.metadata.PackageNotFoundError:
+        raise AssertionError("diffusers is not installed.")
+
+def print_memory(device):
+    memory = torch.cuda.memory_allocated(device) / 1024**3
+    max_memory = torch.cuda.max_memory_allocated(device) / 1024**3
+    max_reserved = torch.cuda.max_memory_reserved(device) / 1024**3
+    log.info(f"Allocated memory: {memory=:.3f} GB")
+    log.info(f"Max allocated memory: {max_memory=:.3f} GB")
+    log.info(f"Max reserved memory: {max_reserved=:.3f} GB")
+    #memory_summary = torch.cuda.memory_summary(device=device, abbreviated=False)
+    #log.info(f"Memory Summary:\n{memory_summary}")
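One caveat in check_diffusers_version() above: version < required_version compares strings lexicographically, so for example '0.4.0' < '0.31.0' is False ('4' sorts after '3') even though 0.4.0 is the older release, and the check would wrongly pass. A more robust variant, sketched here on the assumption that the packaging library is available (a common transitive dependency, but not guaranteed):

import importlib.metadata
from packaging.version import Version

def check_diffusers_version(required: str = '0.31.0'):
    try:
        installed = importlib.metadata.version('diffusers')
    except importlib.metadata.PackageNotFoundError:
        raise AssertionError("diffusers is not installed.")
    # Version() compares release numbers numerically, not as strings.
    if Version(installed) < Version(required):
        raise AssertionError(
            f"diffusers version {installed} is installed, "
            f"but version {required} or higher is required.")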