commit 7242e7e22e
parent af09b1c3a4
Author: kijai
Date:   2025-01-22 08:54:10 +02:00

    more offloading for first stage

2 changed files with 31 additions and 10 deletions
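For context, the change brackets each pipeline stage (conditioner, diffusion model, VAE) with device moves so that only the active stage's weights occupy VRAM, while everything else waits on an offload device. A minimal sketch of the pattern, assuming ComfyUI's comfy.model_management helpers that the diff itself uses; the run_stage helper is illustrative, not part of the commit:

    import torch
    import comfy.model_management as mm

    main_device = mm.get_torch_device()        # compute device, usually CUDA
    offload_device = mm.unet_offload_device()  # usually CPU

    def run_stage(module: torch.nn.Module, *args, **kwargs):
        # Move the stage onto the compute device only for its forward pass,
        # then park it back on the offload device and release cached VRAM.
        module.to(main_device)
        try:
            return module(*args, **kwargs)
        finally:
            module.to(offload_device)
            mm.soft_empty_cache()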

Changed file 1 of 2:

@@ -38,6 +38,7 @@ from diffusers.utils.torch_utils import randn_tensor
 from tqdm import tqdm
 from comfy.utils import ProgressBar
+import comfy.model_management as mm

 logger = logging.getLogger(__name__)
@@ -145,6 +146,7 @@ class Hunyuan3DDiTPipeline:
         ckpt_path,
         config_path,
         device='cuda',
+        offload_device=torch.device('cpu'),
         dtype=torch.float16,
         use_safetensors=None,
         **kwargs,
@@ -191,6 +193,7 @@ class Hunyuan3DDiTPipeline:
             conditioner=conditioner,
             image_processor=image_processor,
             device=device,
+            offload_device=offload_device,
             dtype=dtype,
         )
         model_kwargs.update(kwargs)
@@ -249,7 +252,8 @@ class Hunyuan3DDiTPipeline:
         scheduler,
         conditioner,
         image_processor,
-        device='cuda',
+        device=torch.device('cuda'),
+        offload_device=torch.device('cpu'),
         dtype=torch.float16,
         **kwargs
     ):
@@ -259,11 +263,13 @@ class Hunyuan3DDiTPipeline:
         self.conditioner = conditioner
         self.image_processor = image_processor
-        self.to(device, dtype)
+        self.main_device = device
+        self.offload_device = offload_device
+        self.to(offload_device, dtype)

     def to(self, device=None, dtype=None):
         if device is not None:
             self.device = torch.device(device)
             self.vae.to(device)
             self.model.to(device)
             self.conditioner.to(device)
@@ -274,6 +280,7 @@ class Hunyuan3DDiTPipeline:
             self.conditioner.to(dtype=dtype)

     def encode_cond(self, image, mask, do_classifier_free_guidance, dual_guidance):
+        self.conditioner.to(self.main_device)
         bsz = image.shape[0]
         cond = self.conditioner(image=image, mask=mask)
@@ -306,6 +313,7 @@ class Hunyuan3DDiTPipeline:
                 return out

             cond = cat_recursive(cond, un_cond)
+        self.conditioner.to(self.offload_device)
         return cond

     def prepare_extra_step_kwargs(self, generator, eta):
@@ -355,9 +363,9 @@ class Hunyuan3DDiTPipeline:
             image_pts.append(image_pt)
             mask_pts.append(mask_pt)

-        image_pts = torch.cat(image_pts, dim=0).to(self.device, dtype=self.dtype)
+        image_pts = torch.cat(image_pts, dim=0).to(self.main_device, dtype=self.dtype)
         if mask_pts[0] is not None:
-            mask_pts = torch.cat(mask_pts, dim=0).to(self.device, dtype=self.dtype)
+            mask_pts = torch.cat(mask_pts, dim=0).to(self.main_device, dtype=self.dtype)
         else:
             mask_pts = None
         return image_pts, mask_pts
@@ -414,7 +422,7 @@ class Hunyuan3DDiTPipeline:
         callback = kwargs.pop("callback", None)
         callback_steps = kwargs.pop("callback_steps", None)

-        device = self.device
+        device = self.main_device
         dtype = self.dtype
         do_classifier_free_guidance = guidance_scale >= 0 and \
             getattr(self.model, 'guidance_cond_proj_dim', None) is None
@@ -443,6 +451,8 @@ class Hunyuan3DDiTPipeline:
         ).to(device=device, dtype=latents.dtype)

         comfy_pbar = ProgressBar(num_inference_steps)
+        self.model.to(device)
+
         for i, t in enumerate(tqdm(timesteps, disable=not enable_pbar, desc="Diffusion Sampling:", leave=False)):
             # expand the latents if we are doing classifier free guidance
             if do_classifier_free_guidance:
@@ -478,6 +488,8 @@ class Hunyuan3DDiTPipeline:
             if callback is not None and i % callback_steps == 0:
                 step_idx = i // getattr(self.scheduler, "order", 1)
                 callback(step_idx, t, outputs)
+        self.model.to(self.offload_device)
+        mm.soft_empty_cache()

         return self._export(
             latents,
@@ -487,6 +499,7 @@ class Hunyuan3DDiTPipeline:
     def _export(self, latents, output_type, box_v, mc_level, num_chunks, octree_resolution, mc_algo):
         if not output_type == "latent":
+            self.vae.to(self.main_device)
             latents = 1. / self.vae.scale_factor * latents
             latents = self.vae(latents)
             outputs = self.vae.latents2mesh(
@@ -497,6 +510,7 @@ class Hunyuan3DDiTPipeline:
                 octree_resolution=octree_resolution,
                 mc_algo=mc_algo,
             )
+            self.vae.to(self.offload_device)
         else:
             outputs = latents
@@ -531,7 +545,7 @@ class Hunyuan3DDiTFlowMatchingPipeline(Hunyuan3DDiTPipeline):
         callback = kwargs.pop("callback", None)
         callback_steps = kwargs.pop("callback_steps", None)

-        device = self.device
+        device = self.main_device
         dtype = self.dtype
         do_classifier_free_guidance = guidance_scale >= 0 and not (
             hasattr(self.model, 'guidance_embed') and
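With the new offload_device argument on from_single_file and __init__, the pipeline now starts out entirely on the offload device and only moves each submodule to main_device while it is needed. A sketch of how a caller constructs it after this change, mirroring the loader node in the second file; the import and file paths are placeholders, not taken from the commit:

    import comfy.model_management as mm
    from hy3dgen.shapegen import Hunyuan3DDiTFlowMatchingPipeline  # import path assumed

    pipe = Hunyuan3DDiTFlowMatchingPipeline.from_single_file(
        ckpt_path="model.safetensors",           # placeholder path
        config_path="configs/dit_config.yaml",   # placeholder path
        use_safetensors=True,
        device=mm.get_torch_device(),            # compute device for the active stage
        offload_device=mm.unet_offload_device()  # weights rest here between stages
    )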

Changed file 2 of 2:

@@ -34,10 +34,11 @@ class Hy3DModelLoader:
     def loadmodel(self, model):
         device = mm.get_torch_device()
+        offload_device=mm.unet_offload_device()

         config_path = os.path.join(script_directory, "configs", "dit_config.yaml")
         model_path = folder_paths.get_full_path("diffusion_models", model)

-        pipe = Hunyuan3DDiTFlowMatchingPipeline.from_single_file(ckpt_path=model_path, config_path=config_path, use_safetensors=True, device=device)
+        pipe = Hunyuan3DDiTFlowMatchingPipeline.from_single_file(ckpt_path=model_path, config_path=config_path, use_safetensors=True, device=device, offload_device=offload_device)

         return (pipe,)

 class DownloadAndLoadHy3DDelightModel:
@@ -361,6 +362,7 @@ class Hy3DGenerateMesh:
                 "remove_floaters": ("BOOLEAN", {"default": True}),
                 "remove_degenerate_faces": ("BOOLEAN", {"default": True}),
                 "reduce_faces": ("BOOLEAN", {"default": True}),
+                "max_facenum": ("INT", {"default": 40000, "min": 1}),
             },
             "optional": {
                 "mask": ("MASK", ),
@@ -372,7 +374,7 @@ class Hy3DGenerateMesh:
     FUNCTION = "process"
     CATEGORY = "Hunyuan3DWrapper"

-    def process(self, pipeline, image, steps, guidance_scale, octree_resolution, seed, remove_floaters, remove_degenerate_faces, reduce_faces,
+    def process(self, pipeline, image, steps, guidance_scale, octree_resolution, seed, remove_floaters, remove_degenerate_faces, reduce_faces, max_facenum,
                 mask=None):
         device = mm.get_torch_device()
@@ -394,12 +396,17 @@ class Hy3DGenerateMesh:
             octree_resolution=octree_resolution,
             generator=torch.manual_seed(seed))[0]

+        log.info(f"Generated mesh with {mesh.vertices.shape[0]} vertices and {mesh.faces.shape[0]} faces")
+
         if remove_floaters:
             mesh = FloaterRemover()(mesh)
+            log.info(f"Removed floaters, resulting in {mesh.vertices.shape[0]} vertices and {mesh.faces.shape[0]} faces")
         if remove_degenerate_faces:
             mesh = DegenerateFaceRemover()(mesh)
+            log.info(f"Removed degenerate faces, resulting in {mesh.vertices.shape[0]} vertices and {mesh.faces.shape[0]} faces")
         if reduce_faces:
-            mesh = FaceReducer()(mesh)
+            mesh = FaceReducer()(mesh, max_facenum=max_facenum)
+            log.info(f"Reduced faces, resulting in {mesh.vertices.shape[0]} vertices and {mesh.faces.shape[0]} faces")

         pipeline.to(offload_device)
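The node change also threads the new max_facenum input through to the face reducer and logs the mesh size after each post-processing step. A sketch of the resulting chain in isolation; the import path is an assumption and the wrapper function is illustrative, not the node's actual code:

    from hy3dgen.shapegen.postprocessors import FloaterRemover, DegenerateFaceRemover, FaceReducer  # path assumed

    def postprocess(mesh, remove_floaters=True, remove_degenerate_faces=True,
                    reduce_faces=True, max_facenum=40000):
        # Each step is optional; the node logs vertex/face counts after each one.
        if remove_floaters:
            mesh = FloaterRemover()(mesh)
        if remove_degenerate_faces:
            mesh = DegenerateFaceRemover()(mesh)
        if reduce_faces:
            mesh = FaceReducer()(mesh, max_facenum=max_facenum)  # cap the face count at max_facenum
        return mesh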