Mirror of https://git.datalinker.icu/kijai/ComfyUI-Hunyuan3DWrapper.git (synced 2025-12-09 21:04:32 +08:00)
more offloading for first stage

commit 7242e7e22e
parent af09b1c3a4
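The commit applies one pattern throughout the shape-generation ("first stage") pipeline: each sub-model (conditioner, DiT, VAE) is moved to the main device only for the step that needs it, then parked back on the offload device. A minimal sketch of that pattern, assuming only torch and ComfyUI's comfy.model_management; the helper name run_offloaded is illustrative and not part of the commit:

```python
import torch
import comfy.model_management as mm

def run_offloaded(module: torch.nn.Module, *args,
                  main_device: torch.device, offload_device: torch.device):
    """Illustrative helper: run one stage on the main device, then offload it again.
    The caller is expected to place the input tensors on main_device."""
    module.to(main_device)            # pull weights onto the main (compute) device
    try:
        return module(*args)          # forward pass for this stage
    finally:
        module.to(offload_device)     # park weights back on the offload device
        mm.soft_empty_cache()         # release cached VRAM, as the diff does after sampling

# The devices would come from ComfyUI, as in Hy3DModelLoader below:
# main_device = mm.get_torch_device()
# offload_device = mm.unet_offload_device()
```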
@@ -38,6 +38,7 @@ from diffusers.utils.torch_utils import randn_tensor
 from tqdm import tqdm
 
 from comfy.utils import ProgressBar
+import comfy.model_management as mm
 
 logger = logging.getLogger(__name__)
 
@@ -145,6 +146,7 @@ class Hunyuan3DDiTPipeline:
         ckpt_path,
         config_path,
         device='cuda',
+        offload_device=torch.device('cpu'),
         dtype=torch.float16,
         use_safetensors=None,
         **kwargs,
@@ -191,6 +193,7 @@ class Hunyuan3DDiTPipeline:
             conditioner=conditioner,
             image_processor=image_processor,
             device=device,
+            offload_device=offload_device,
             dtype=dtype,
         )
         model_kwargs.update(kwargs)
@@ -249,7 +252,8 @@ class Hunyuan3DDiTPipeline:
         scheduler,
         conditioner,
         image_processor,
-        device='cuda',
+        device=torch.device('cuda'),
+        offload_device=torch.device('cpu'),
         dtype=torch.float16,
         **kwargs
     ):
@@ -259,11 +263,13 @@ class Hunyuan3DDiTPipeline:
         self.conditioner = conditioner
         self.image_processor = image_processor
 
-        self.to(device, dtype)
+        self.main_device = device
+        self.offload_device = offload_device
+
+        self.to(offload_device, dtype)
 
     def to(self, device=None, dtype=None):
         if device is not None:
-            self.device = torch.device(device)
             self.vae.to(device)
             self.model.to(device)
             self.conditioner.to(device)
@@ -274,6 +280,7 @@ class Hunyuan3DDiTPipeline:
             self.conditioner.to(dtype=dtype)
 
     def encode_cond(self, image, mask, do_classifier_free_guidance, dual_guidance):
+        self.conditioner.to(self.main_device)
         bsz = image.shape[0]
         cond = self.conditioner(image=image, mask=mask)
 
@@ -306,6 +313,7 @@ class Hunyuan3DDiTPipeline:
                 return out
 
             cond = cat_recursive(cond, un_cond)
+        self.conditioner.to(self.offload_device)
         return cond
 
     def prepare_extra_step_kwargs(self, generator, eta):
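The manual .to() bracketing that encode_cond now does around the conditioner could equally be written as a context manager. A hypothetical equivalent, not part of the commit, assuming the same main_device/offload_device attributes:

```python
from contextlib import contextmanager

@contextmanager
def on_main_device(module, main_device, offload_device):
    """Hypothetical wrapper for the .to()/.to() bracketing used in encode_cond."""
    module.to(main_device)    # bring the stage onto the compute device
    try:
        yield module
    finally:
        module.to(offload_device)  # return it to the offload device afterwards

# e.g.:
# with on_main_device(self.conditioner, self.main_device, self.offload_device):
#     cond = self.conditioner(image=image, mask=mask)
```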
@@ -355,9 +363,9 @@ class Hunyuan3DDiTPipeline:
             image_pts.append(image_pt)
             mask_pts.append(mask_pt)
 
-        image_pts = torch.cat(image_pts, dim=0).to(self.device, dtype=self.dtype)
+        image_pts = torch.cat(image_pts, dim=0).to(self.main_device, dtype=self.dtype)
         if mask_pts[0] is not None:
-            mask_pts = torch.cat(mask_pts, dim=0).to(self.device, dtype=self.dtype)
+            mask_pts = torch.cat(mask_pts, dim=0).to(self.main_device, dtype=self.dtype)
         else:
             mask_pts = None
         return image_pts, mask_pts
@@ -414,7 +422,7 @@ class Hunyuan3DDiTPipeline:
         callback = kwargs.pop("callback", None)
         callback_steps = kwargs.pop("callback_steps", None)
 
-        device = self.device
+        device = self.main_device
         dtype = self.dtype
         do_classifier_free_guidance = guidance_scale >= 0 and \
             getattr(self.model, 'guidance_cond_proj_dim', None) is None
@@ -443,6 +451,8 @@ class Hunyuan3DDiTPipeline:
         ).to(device=device, dtype=latents.dtype)
 
         comfy_pbar = ProgressBar(num_inference_steps)
+
+        self.model.to(device)
         for i, t in enumerate(tqdm(timesteps, disable=not enable_pbar, desc="Diffusion Sampling:", leave=False)):
             # expand the latents if we are doing classifier free guidance
             if do_classifier_free_guidance:
@@ -478,6 +488,8 @@ class Hunyuan3DDiTPipeline:
             if callback is not None and i % callback_steps == 0:
                 step_idx = i // getattr(self.scheduler, "order", 1)
                 callback(step_idx, t, outputs)
+        self.model.to(self.offload_device)
+        mm.soft_empty_cache()
 
         return self._export(
             latents,
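After the denoising loop, the DiT is returned to the offload device and mm.soft_empty_cache() frees cached VRAM before the VAE stage. To check the effect, a small illustrative helper (not in the commit) using ComfyUI's memory query could look like:

```python
import comfy.model_management as mm

def log_free_vram(tag: str) -> None:
    """Illustrative only: print free memory on the main device, e.g. before and
    after self.model.to(self.offload_device) / mm.soft_empty_cache()."""
    device = mm.get_torch_device()
    free_gb = mm.get_free_memory(device) / (1024 ** 3)
    print(f"[{tag}] {free_gb:.2f} GB free on {device}")
```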
@@ -487,6 +499,7 @@ class Hunyuan3DDiTPipeline:
 
     def _export(self, latents, output_type, box_v, mc_level, num_chunks, octree_resolution, mc_algo):
         if not output_type == "latent":
+            self.vae.to(self.main_device)
             latents = 1. / self.vae.scale_factor * latents
             latents = self.vae(latents)
             outputs = self.vae.latents2mesh(
@@ -497,6 +510,7 @@ class Hunyuan3DDiTPipeline:
                 octree_resolution=octree_resolution,
                 mc_algo=mc_algo,
             )
+            self.vae.to(self.offload_device)
         else:
             outputs = latents
 
@@ -531,7 +545,7 @@ class Hunyuan3DDiTFlowMatchingPipeline(Hunyuan3DDiTPipeline):
         callback = kwargs.pop("callback", None)
         callback_steps = kwargs.pop("callback_steps", None)
 
-        device = self.device
+        device = self.main_device
         dtype = self.dtype
         do_classifier_free_guidance = guidance_scale >= 0 and not (
             hasattr(self.model, 'guidance_embed') and
nodes.py (13 changed lines)
@@ -34,10 +34,11 @@ class Hy3DModelLoader:
 
     def loadmodel(self, model):
         device = mm.get_torch_device()
+        offload_device=mm.unet_offload_device()
 
         config_path = os.path.join(script_directory, "configs", "dit_config.yaml")
         model_path = folder_paths.get_full_path("diffusion_models", model)
-        pipe = Hunyuan3DDiTFlowMatchingPipeline.from_single_file(ckpt_path=model_path, config_path=config_path, use_safetensors=True, device=device)
+        pipe = Hunyuan3DDiTFlowMatchingPipeline.from_single_file(ckpt_path=model_path, config_path=config_path, use_safetensors=True, device=device, offload_device=offload_device)
         return (pipe,)
 
 class DownloadAndLoadHy3DDelightModel:
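Hy3DModelLoader now forwards ComfyUI's standard device pair to the pipeline, so the model is instantiated on the offload device and only pulled onto the GPU per stage. A minimal sketch of where those devices come from, assuming a running ComfyUI environment:

```python
import comfy.model_management as mm

# Device pair the loader now hands to the pipeline: the compute device for the
# active stage, and the UNet offload device (normally CPU) for idle weights.
main_device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
print(f"compute on {main_device}, offload to {offload_device}")

# Passed through exactly as in the diff above, e.g.:
# Hunyuan3DDiTFlowMatchingPipeline.from_single_file(..., device=main_device,
#                                                   offload_device=offload_device)
```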
@@ -361,6 +362,7 @@ class Hy3DGenerateMesh:
                 "remove_floaters": ("BOOLEAN", {"default": True}),
                 "remove_degenerate_faces": ("BOOLEAN", {"default": True}),
                 "reduce_faces": ("BOOLEAN", {"default": True}),
+                "max_facenum": ("INT", {"default": 40000, "min": 1}),
             },
             "optional": {
                 "mask": ("MASK", ),
@@ -372,7 +374,7 @@ class Hy3DGenerateMesh:
     FUNCTION = "process"
     CATEGORY = "Hunyuan3DWrapper"
 
-    def process(self, pipeline, image, steps, guidance_scale, octree_resolution, seed, remove_floaters, remove_degenerate_faces, reduce_faces,
+    def process(self, pipeline, image, steps, guidance_scale, octree_resolution, seed, remove_floaters, remove_degenerate_faces, reduce_faces, max_facenum,
                 mask=None):
 
         device = mm.get_torch_device()
@@ -394,12 +396,17 @@ class Hy3DGenerateMesh:
             octree_resolution=octree_resolution,
             generator=torch.manual_seed(seed))[0]
 
+        log.info(f"Generated mesh with {mesh.vertices.shape[0]} vertices and {mesh.faces.shape[0]} faces")
+
         if remove_floaters:
             mesh = FloaterRemover()(mesh)
+            log.info(f"Removed floaters, resulting in {mesh.vertices.shape[0]} vertices and {mesh.faces.shape[0]} faces")
         if remove_degenerate_faces:
             mesh = DegenerateFaceRemover()(mesh)
+            log.info(f"Removed degenerate faces, resulting in {mesh.vertices.shape[0]} vertices and {mesh.faces.shape[0]} faces")
         if reduce_faces:
-            mesh = FaceReducer()(mesh)
+            mesh = FaceReducer()(mesh, max_facenum=max_facenum)
+            log.info(f"Reduced faces, resulting in {mesh.vertices.shape[0]} vertices and {mesh.faces.shape[0]} faces")
 
         pipeline.to(offload_device)
 
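The new max_facenum input is forwarded to FaceReducer, and each clean-up step now logs the resulting vertex/face counts. A hedged standalone sketch of the same post-processing chain on a trimesh mesh; the import path for the postprocessors is an assumption, while the call signatures match the diff:

```python
import logging
import trimesh
# Assumed import path for the wrapped Hunyuan3D post-processors:
from hy3dgen.shapegen.postprocessors import FloaterRemover, DegenerateFaceRemover, FaceReducer

log = logging.getLogger(__name__)

def clean_mesh(mesh: trimesh.Trimesh, max_facenum: int = 40000) -> trimesh.Trimesh:
    """Illustrative post-processing chain mirroring Hy3DGenerateMesh.process()."""
    mesh = FloaterRemover()(mesh)                         # drop disconnected fragments
    mesh = DegenerateFaceRemover()(mesh)                  # drop degenerate faces
    mesh = FaceReducer()(mesh, max_facenum=max_facenum)   # decimate to the requested face budget
    log.info(f"Cleaned mesh: {mesh.vertices.shape[0]} vertices, {mesh.faces.shape[0]} faces")
    return mesh
```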