more offloading for first stage

This commit is contained in:
kijai 2025-01-22 08:54:10 +02:00
parent af09b1c3a4
commit 7242e7e22e
2 changed files with 31 additions and 10 deletions

View File

@ -38,6 +38,7 @@ from diffusers.utils.torch_utils import randn_tensor
from tqdm import tqdm
from comfy.utils import ProgressBar
import comfy.model_management as mm
logger = logging.getLogger(__name__)
@ -145,6 +146,7 @@ class Hunyuan3DDiTPipeline:
ckpt_path,
config_path,
device='cuda',
offload_device=torch.device('cpu'),
dtype=torch.float16,
use_safetensors=None,
**kwargs,
@ -191,6 +193,7 @@ class Hunyuan3DDiTPipeline:
conditioner=conditioner,
image_processor=image_processor,
device=device,
offload_device=offload_device,
dtype=dtype,
)
model_kwargs.update(kwargs)
@ -249,7 +252,8 @@ class Hunyuan3DDiTPipeline:
scheduler,
conditioner,
image_processor,
device='cuda',
device=torch.device('cuda'),
offload_device=torch.device('cpu'),
dtype=torch.float16,
**kwargs
):
@ -259,11 +263,13 @@ class Hunyuan3DDiTPipeline:
self.conditioner = conditioner
self.image_processor = image_processor
self.to(device, dtype)
self.main_device = device
self.offload_device = offload_device
self.to(offload_device, dtype)
def to(self, device=None, dtype=None):
if device is not None:
self.device = torch.device(device)
self.vae.to(device)
self.model.to(device)
self.conditioner.to(device)
@ -274,6 +280,7 @@ class Hunyuan3DDiTPipeline:
self.conditioner.to(dtype=dtype)
def encode_cond(self, image, mask, do_classifier_free_guidance, dual_guidance):
self.conditioner.to(self.main_device)
bsz = image.shape[0]
cond = self.conditioner(image=image, mask=mask)
@ -306,6 +313,7 @@ class Hunyuan3DDiTPipeline:
return out
cond = cat_recursive(cond, un_cond)
self.conditioner.to(self.offload_device)
return cond
def prepare_extra_step_kwargs(self, generator, eta):
@ -355,9 +363,9 @@ class Hunyuan3DDiTPipeline:
image_pts.append(image_pt)
mask_pts.append(mask_pt)
image_pts = torch.cat(image_pts, dim=0).to(self.device, dtype=self.dtype)
image_pts = torch.cat(image_pts, dim=0).to(self.main_device, dtype=self.dtype)
if mask_pts[0] is not None:
mask_pts = torch.cat(mask_pts, dim=0).to(self.device, dtype=self.dtype)
mask_pts = torch.cat(mask_pts, dim=0).to(self.main_device, dtype=self.dtype)
else:
mask_pts = None
return image_pts, mask_pts
@ -414,7 +422,7 @@ class Hunyuan3DDiTPipeline:
callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)
device = self.device
device = self.main_device
dtype = self.dtype
do_classifier_free_guidance = guidance_scale >= 0 and \
getattr(self.model, 'guidance_cond_proj_dim', None) is None
@ -443,6 +451,8 @@ class Hunyuan3DDiTPipeline:
).to(device=device, dtype=latents.dtype)
comfy_pbar = ProgressBar(num_inference_steps)
self.model.to(device)
for i, t in enumerate(tqdm(timesteps, disable=not enable_pbar, desc="Diffusion Sampling:", leave=False)):
# expand the latents if we are doing classifier free guidance
if do_classifier_free_guidance:
@ -478,6 +488,8 @@ class Hunyuan3DDiTPipeline:
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, outputs)
self.model.to(self.offload_device)
mm.soft_empty_cache()
return self._export(
latents,
@ -487,6 +499,7 @@ class Hunyuan3DDiTPipeline:
def _export(self, latents, output_type, box_v, mc_level, num_chunks, octree_resolution, mc_algo):
if not output_type == "latent":
self.vae.to(self.main_device)
latents = 1. / self.vae.scale_factor * latents
latents = self.vae(latents)
outputs = self.vae.latents2mesh(
@ -497,6 +510,7 @@ class Hunyuan3DDiTPipeline:
octree_resolution=octree_resolution,
mc_algo=mc_algo,
)
self.vae.to(self.offload_device)
else:
outputs = latents
@ -531,7 +545,7 @@ class Hunyuan3DDiTFlowMatchingPipeline(Hunyuan3DDiTPipeline):
callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)
device = self.device
device = self.main_device
dtype = self.dtype
do_classifier_free_guidance = guidance_scale >= 0 and not (
hasattr(self.model, 'guidance_embed') and

View File

@ -34,10 +34,11 @@ class Hy3DModelLoader:
def loadmodel(self, model):
device = mm.get_torch_device()
offload_device=mm.unet_offload_device()
config_path = os.path.join(script_directory, "configs", "dit_config.yaml")
model_path = folder_paths.get_full_path("diffusion_models", model)
pipe = Hunyuan3DDiTFlowMatchingPipeline.from_single_file(ckpt_path=model_path, config_path=config_path, use_safetensors=True, device=device)
pipe = Hunyuan3DDiTFlowMatchingPipeline.from_single_file(ckpt_path=model_path, config_path=config_path, use_safetensors=True, device=device, offload_device=offload_device)
return (pipe,)
class DownloadAndLoadHy3DDelightModel:
@ -361,6 +362,7 @@ class Hy3DGenerateMesh:
"remove_floaters": ("BOOLEAN", {"default": True}),
"remove_degenerate_faces": ("BOOLEAN", {"default": True}),
"reduce_faces": ("BOOLEAN", {"default": True}),
"max_facenum": ("INT", {"default": 40000, "min": 1}),
},
"optional": {
"mask": ("MASK", ),
@ -372,7 +374,7 @@ class Hy3DGenerateMesh:
FUNCTION = "process"
CATEGORY = "Hunyuan3DWrapper"
def process(self, pipeline, image, steps, guidance_scale, octree_resolution, seed, remove_floaters, remove_degenerate_faces, reduce_faces,
def process(self, pipeline, image, steps, guidance_scale, octree_resolution, seed, remove_floaters, remove_degenerate_faces, reduce_faces, max_facenum,
mask=None):
device = mm.get_torch_device()
@ -394,12 +396,17 @@ class Hy3DGenerateMesh:
octree_resolution=octree_resolution,
generator=torch.manual_seed(seed))[0]
log.info(f"Generated mesh with {mesh.vertices.shape[0]} vertices and {mesh.faces.shape[0]} faces")
if remove_floaters:
mesh = FloaterRemover()(mesh)
log.info(f"Removed floaters, resulting in {mesh.vertices.shape[0]} vertices and {mesh.faces.shape[0]} faces")
if remove_degenerate_faces:
mesh = DegenerateFaceRemover()(mesh)
log.info(f"Removed degenerate faces, resulting in {mesh.vertices.shape[0]} vertices and {mesh.faces.shape[0]} faces")
if reduce_faces:
mesh = FaceReducer()(mesh)
mesh = FaceReducer()(mesh, max_facenum=max_facenum)
log.info(f"Reduced faces, resulting in {mesh.vertices.shape[0]} vertices and {mesh.faces.shape[0]} faces")
pipeline.to(offload_device)