Use Accelerate to load the main model faster

kijai 2025-01-22 09:35:57 +02:00
parent 2286d0c0f7
commit 703edf2051
3 changed files with 56 additions and 9 deletions


@@ -37,6 +37,9 @@ from PIL import Image
 from diffusers.utils.torch_utils import randn_tensor
 from tqdm import tqdm
+from accelerate import init_empty_weights
+from accelerate.utils import set_module_tensor_to_device
 from comfy.utils import ProgressBar
 import comfy.model_management as mm
@@ -176,13 +179,24 @@ class Hunyuan3DDiTPipeline:
         else:
             ckpt = torch.load(ckpt_path, map_location='cpu')
         # load model
-        model = instantiate_from_config(config['model'])
-        model.load_state_dict(ckpt['model'])
-        vae = instantiate_from_config(config['vae'])
-        vae.load_state_dict(ckpt['vae'])
-        conditioner = instantiate_from_config(config['conditioner'])
+        with init_empty_weights():
+            model = instantiate_from_config(config['model'])
+            vae = instantiate_from_config(config['vae'])
+            conditioner = instantiate_from_config(config['conditioner'])
+
+        #model
+        #model.load_state_dict(ckpt['model'])
+        for name, param in model.named_parameters():
+            set_module_tensor_to_device(model, name, device=offload_device, dtype=dtype, value=ckpt['model'][name])
+
+        #vae
+        #vae.load_state_dict(ckpt['vae'])
+        for name, param in vae.named_parameters():
+            set_module_tensor_to_device(vae, name, device=offload_device, dtype=dtype, value=ckpt['vae'][name])
         if 'conditioner' in ckpt:
-            conditioner.load_state_dict(ckpt['conditioner'])
+            #conditioner.load_state_dict(ckpt['conditioner'])
+            for name, param in conditioner.named_parameters():
+                set_module_tensor_to_device(conditioner, name, device=offload_device, dtype=dtype, value=ckpt['conditioner'][name])
         image_processor = instantiate_from_config(config['image_processor'])
         scheduler = instantiate_from_config(config['scheduler'])
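
Note on the change above: init_empty_weights() builds each module on PyTorch's meta device, so construction allocates no memory and skips random weight initialization; set_module_tensor_to_device then materializes every parameter directly from the corresponding checkpoint tensor. A minimal self-contained sketch of the same pattern (TinyModel and its tensor names are illustrative, not from this repo):

    import torch
    import torch.nn as nn
    from accelerate import init_empty_weights
    from accelerate.utils import set_module_tensor_to_device

    # Illustrative stand-in for the pipeline's model/vae/conditioner.
    class TinyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(8, 8)

    # Stand-in checkpoint: a plain state dict, as torch.load() would return.
    ckpt = TinyModel().state_dict()

    # 1) Instantiate on the meta device: no allocation, no random init.
    with init_empty_weights():
        model = TinyModel()

    # 2) Materialize each parameter straight from the checkpoint.
    # (Buffers, if a module has them, need the same treatment via
    # named_buffers(); named_parameters() does not cover them.)
    for name, _ in model.named_parameters():
        set_module_tensor_to_device(model, name, device="cpu",
                                    dtype=torch.float16, value=ckpt[name])

Compared with constructing the module normally and then calling load_state_dict, each weight is touched once instead of being allocated twice (random init, then overwrite), which is where the load-time saving comes from.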


@@ -13,9 +13,7 @@ from comfy.utils import load_torch_file
 script_directory = os.path.dirname(os.path.abspath(__file__))
-import logging
-logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
-log = logging.getLogger(__name__)
+from .utils import log, print_memory
 #region Model loading
 class Hy3DModelLoader:
@@ -388,6 +386,11 @@ class Hy3DGenerateMesh:
         pipeline.to(device)
+        try:
+            torch.cuda.reset_peak_memory_stats(device)
+        except:
+            pass
+
         mesh = pipeline(
             image=image,
             mask=mask,
@@ -409,6 +412,12 @@ class Hy3DGenerateMesh:
             mesh = FaceReducer()(mesh, max_facenum=max_facenum)
             log.info(f"Reduced faces, resulting in {mesh.vertices.shape[0]} vertices and {mesh.faces.shape[0]} faces")
+
+        print_memory(device)
+        try:
+            torch.cuda.reset_peak_memory_stats(device)
+        except:
+            pass
         pipeline.to(offload_device)
         return (mesh, )
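
The bracketing here (reset_peak_memory_stats before the run, print_memory plus a second reset after) turns CUDA's cumulative peak counter into a per-run reading. A sketch of the same idea as a reusable helper (run_with_peak_report is hypothetical, not part of this repo):

    import torch

    def run_with_peak_report(fn, device):
        # Hypothetical helper: reset the peak counter, run the work,
        # then read the peak allocated during this run only.
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats(device)
        result = fn()
        if torch.cuda.is_available():
            peak_gb = torch.cuda.max_memory_allocated(device) / 1024**3
            print(f"Peak allocated during run: {peak_gb:.3f} GB")
        return result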

utils.py (new file, 24 additions)

@@ -0,0 +1,24 @@
+import importlib.metadata
+import torch
+import logging
+from packaging.version import parse as parse_version
+
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+log = logging.getLogger(__name__)
+
+def check_diffusers_version():
+    try:
+        version = importlib.metadata.version('diffusers')
+        required_version = '0.31.0'
+        # Compare parsed versions; comparing the raw strings would
+        # mis-order releases such as '0.9.0' vs '0.31.0'.
+        if parse_version(version) < parse_version(required_version):
+            raise AssertionError(f"diffusers version {version} is installed, but version {required_version} or higher is required.")
+    except importlib.metadata.PackageNotFoundError:
+        raise AssertionError("diffusers is not installed.")
+
+def print_memory(device):
+    memory = torch.cuda.memory_allocated(device) / 1024**3
+    max_memory = torch.cuda.max_memory_allocated(device) / 1024**3
+    max_reserved = torch.cuda.max_memory_reserved(device) / 1024**3
+    log.info(f"Allocated memory: {memory=:.3f} GB")
+    log.info(f"Max allocated memory: {max_memory=:.3f} GB")
+    log.info(f"Max reserved memory: {max_reserved=:.3f} GB")
+    #memory_summary = torch.cuda.memory_summary(device=device, abbreviated=False)
+    #log.info(f"Memory Summary:\n{memory_summary}")
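
For reference, a typical call site for these helpers might look like the following (assuming utils.py is importable from the caller; inside the package it is imported relatively, as in the node file above):

    import torch
    from utils import check_diffusers_version, print_memory, log

    check_diffusers_version()  # raises AssertionError if diffusers < 0.31.0

    if torch.cuda.is_available():
        device = torch.device("cuda:0")
        x = torch.randn(1024, 1024, device=device)  # allocate something to measure
        print_memory(device)  # logs allocated / max allocated / max reserved, in GB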