mirror of
https://git.datalinker.icu/kijai/ComfyUI-Hunyuan3DWrapper.git
synced 2025-12-08 20:34:28 +08:00
More offloading for the first stage
This commit is contained in:
parent
af09b1c3a4
commit
7242e7e22e
@ -38,6 +38,7 @@ from diffusers.utils.torch_utils import randn_tensor
|
||||
from tqdm import tqdm
|
||||
|
||||
from comfy.utils import ProgressBar
|
||||
import comfy.model_management as mm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -145,6 +146,7 @@ class Hunyuan3DDiTPipeline:
|
||||
ckpt_path,
|
||||
config_path,
|
||||
device='cuda',
|
||||
offload_device=torch.device('cpu'),
|
||||
dtype=torch.float16,
|
||||
use_safetensors=None,
|
||||
**kwargs,
|
||||
@ -191,6 +193,7 @@ class Hunyuan3DDiTPipeline:
|
||||
conditioner=conditioner,
|
||||
image_processor=image_processor,
|
||||
device=device,
|
||||
offload_device=offload_device,
|
||||
dtype=dtype,
|
||||
)
|
||||
model_kwargs.update(kwargs)
|
||||
@ -249,7 +252,8 @@ class Hunyuan3DDiTPipeline:
|
||||
scheduler,
|
||||
conditioner,
|
||||
image_processor,
|
||||
device='cuda',
|
||||
device=torch.device('cuda'),
|
||||
offload_device=torch.device('cpu'),
|
||||
dtype=torch.float16,
|
||||
**kwargs
|
||||
):
|
||||
@ -259,11 +263,13 @@ class Hunyuan3DDiTPipeline:
|
||||
self.conditioner = conditioner
|
||||
self.image_processor = image_processor
|
||||
|
||||
self.to(device, dtype)
|
||||
self.main_device = device
|
||||
self.offload_device = offload_device
|
||||
|
||||
self.to(offload_device, dtype)
|
||||
|
||||
def to(self, device=None, dtype=None):
|
||||
if device is not None:
|
||||
self.device = torch.device(device)
|
||||
self.vae.to(device)
|
||||
self.model.to(device)
|
||||
self.conditioner.to(device)
|
||||
@ -274,6 +280,7 @@ class Hunyuan3DDiTPipeline:
|
||||
self.conditioner.to(dtype=dtype)
|
||||
|
||||
def encode_cond(self, image, mask, do_classifier_free_guidance, dual_guidance):
|
||||
self.conditioner.to(self.main_device)
|
||||
bsz = image.shape[0]
|
||||
cond = self.conditioner(image=image, mask=mask)
|
||||
|
||||
@ -306,6 +313,7 @@ class Hunyuan3DDiTPipeline:
|
||||
return out
|
||||
|
||||
cond = cat_recursive(cond, un_cond)
|
||||
self.conditioner.to(self.offload_device)
|
||||
return cond
|
||||
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
@ -355,9 +363,9 @@ class Hunyuan3DDiTPipeline:
|
||||
image_pts.append(image_pt)
|
||||
mask_pts.append(mask_pt)
|
||||
|
||||
image_pts = torch.cat(image_pts, dim=0).to(self.device, dtype=self.dtype)
|
||||
image_pts = torch.cat(image_pts, dim=0).to(self.main_device, dtype=self.dtype)
|
||||
if mask_pts[0] is not None:
|
||||
mask_pts = torch.cat(mask_pts, dim=0).to(self.device, dtype=self.dtype)
|
||||
mask_pts = torch.cat(mask_pts, dim=0).to(self.main_device, dtype=self.dtype)
|
||||
else:
|
||||
mask_pts = None
|
||||
return image_pts, mask_pts
|
||||
@ -414,7 +422,7 @@ class Hunyuan3DDiTPipeline:
|
||||
callback = kwargs.pop("callback", None)
|
||||
callback_steps = kwargs.pop("callback_steps", None)
|
||||
|
||||
device = self.device
|
||||
device = self.main_device
|
||||
dtype = self.dtype
|
||||
do_classifier_free_guidance = guidance_scale >= 0 and \
|
||||
getattr(self.model, 'guidance_cond_proj_dim', None) is None
|
||||
@ -443,6 +451,8 @@ class Hunyuan3DDiTPipeline:
|
||||
).to(device=device, dtype=latents.dtype)
|
||||
|
||||
comfy_pbar = ProgressBar(num_inference_steps)
|
||||
|
||||
self.model.to(device)
|
||||
for i, t in enumerate(tqdm(timesteps, disable=not enable_pbar, desc="Diffusion Sampling:", leave=False)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
if do_classifier_free_guidance:
|
||||
@ -478,6 +488,8 @@ class Hunyuan3DDiTPipeline:
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, t, outputs)
|
||||
self.model.to(self.offload_device)
|
||||
mm.soft_empty_cache()
|
||||
|
||||
return self._export(
|
||||
latents,
|
||||
@ -487,6 +499,7 @@ class Hunyuan3DDiTPipeline:
|
||||
|
||||
def _export(self, latents, output_type, box_v, mc_level, num_chunks, octree_resolution, mc_algo):
|
||||
if not output_type == "latent":
|
||||
self.vae.to(self.main_device)
|
||||
latents = 1. / self.vae.scale_factor * latents
|
||||
latents = self.vae(latents)
|
||||
outputs = self.vae.latents2mesh(
|
||||
@ -497,6 +510,7 @@ class Hunyuan3DDiTPipeline:
|
||||
octree_resolution=octree_resolution,
|
||||
mc_algo=mc_algo,
|
||||
)
|
||||
self.vae.to(self.offload_device)
|
||||
else:
|
||||
outputs = latents
|
||||
|
||||
@ -531,7 +545,7 @@ class Hunyuan3DDiTFlowMatchingPipeline(Hunyuan3DDiTPipeline):
|
||||
callback = kwargs.pop("callback", None)
|
||||
callback_steps = kwargs.pop("callback_steps", None)
|
||||
|
||||
device = self.device
|
||||
device = self.main_device
|
||||
dtype = self.dtype
|
||||
do_classifier_free_guidance = guidance_scale >= 0 and not (
|
||||
hasattr(self.model, 'guidance_embed') and
|
||||
|
||||
13
nodes.py
13
nodes.py
@ -34,10 +34,11 @@ class Hy3DModelLoader:
|
||||
|
||||
def loadmodel(self, model):
|
||||
device = mm.get_torch_device()
|
||||
offload_device=mm.unet_offload_device()
|
||||
|
||||
config_path = os.path.join(script_directory, "configs", "dit_config.yaml")
|
||||
model_path = folder_paths.get_full_path("diffusion_models", model)
|
||||
pipe = Hunyuan3DDiTFlowMatchingPipeline.from_single_file(ckpt_path=model_path, config_path=config_path, use_safetensors=True, device=device)
|
||||
pipe = Hunyuan3DDiTFlowMatchingPipeline.from_single_file(ckpt_path=model_path, config_path=config_path, use_safetensors=True, device=device, offload_device=offload_device)
|
||||
return (pipe,)
|
||||
|
||||
class DownloadAndLoadHy3DDelightModel:
|
||||
@ -361,6 +362,7 @@ class Hy3DGenerateMesh:
|
||||
"remove_floaters": ("BOOLEAN", {"default": True}),
|
||||
"remove_degenerate_faces": ("BOOLEAN", {"default": True}),
|
||||
"reduce_faces": ("BOOLEAN", {"default": True}),
|
||||
"max_facenum": ("INT", {"default": 40000, "min": 1}),
|
||||
},
|
||||
"optional": {
|
||||
"mask": ("MASK", ),
|
||||
@ -372,7 +374,7 @@ class Hy3DGenerateMesh:
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "Hunyuan3DWrapper"
|
||||
|
||||
def process(self, pipeline, image, steps, guidance_scale, octree_resolution, seed, remove_floaters, remove_degenerate_faces, reduce_faces,
|
||||
def process(self, pipeline, image, steps, guidance_scale, octree_resolution, seed, remove_floaters, remove_degenerate_faces, reduce_faces, max_facenum,
|
||||
mask=None):
|
||||
|
||||
device = mm.get_torch_device()
|
||||
@ -394,12 +396,17 @@ class Hy3DGenerateMesh:
|
||||
octree_resolution=octree_resolution,
|
||||
generator=torch.manual_seed(seed))[0]
|
||||
|
||||
log.info(f"Generated mesh with {mesh.vertices.shape[0]} vertices and {mesh.faces.shape[0]} faces")
|
||||
|
||||
if remove_floaters:
|
||||
mesh = FloaterRemover()(mesh)
|
||||
log.info(f"Removed floaters, resulting in {mesh.vertices.shape[0]} vertices and {mesh.faces.shape[0]} faces")
|
||||
if remove_degenerate_faces:
|
||||
mesh = DegenerateFaceRemover()(mesh)
|
||||
log.info(f"Removed degenerate faces, resulting in {mesh.vertices.shape[0]} vertices and {mesh.faces.shape[0]} faces")
|
||||
if reduce_faces:
|
||||
mesh = FaceReducer()(mesh)
|
||||
mesh = FaceReducer()(mesh, max_facenum=max_facenum)
|
||||
log.info(f"Reduced faces, resulting in {mesh.vertices.shape[0]} vertices and {mesh.faces.shape[0]} faces")
|
||||
|
||||
pipeline.to(offload_device)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user