Separate VAE (old workflows need updating)

kijai 2025-01-26 02:48:37 +02:00
parent 7e234a017e
commit d563280607
4 changed files with 855 additions and 735 deletions
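
In short: the shape VAE is split out of the pipeline object. `from_single_file` now returns `(pipeline, vae)`, the pipeline's sampling call stops at latents, and latents-to-mesh decoding moves into the new `Hy3DVAEDecode` node. A minimal sketch of the new flow, pieced together from the hunks below (variable names and argument values are illustrative, not taken verbatim from the repo):

```python
import torch

# Load: from_single_file now returns the pipeline and the VAE separately.
pipe, vae = Hunyuan3DDiTFlowMatchingPipeline.from_single_file(
    ckpt_path=model_path,      # supplied by Hy3DModelLoader in practice
    config_path=config_path,
    use_safetensors=True,
)

# Sample: the pipeline now returns raw latents instead of a mesh.
latents = pipe(image=image, num_inference_steps=30,
               guidance_scale=5.5, generator=torch.manual_seed(0))

# Decode: what the new Hy3DVAEDecode node does (defaults from its inputs).
latents = 1.0 / vae.scale_factor * latents
latents = vae(latents)
mesh = vae.latents2mesh(latents, bounds=1.01, mc_level=0.0, num_chunks=8000,
                        octree_resolution=384, mc_algo='mc')[0]
```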

Binary file not shown (changed: 6.8 MiB → 6.9 MiB).

View File

@@ -212,7 +212,7 @@ class Hunyuan3DDiTPipeline:
vae = torch.compile(vae)
model_kwargs = dict(
vae=vae,
#vae=vae,
model=model,
scheduler=scheduler,
conditioner=conditioner,
@@ -223,56 +223,54 @@ class Hunyuan3DDiTPipeline:
)
model_kwargs.update(kwargs)
return cls(
**model_kwargs
)
return cls(**model_kwargs), vae
@classmethod
def from_pretrained(
cls,
model_path,
ckpt_name='model.ckpt',
config_name='config.yaml',
device='cuda',
dtype=torch.float16,
use_safetensors=None,
**kwargs,
):
original_model_path = model_path
if not os.path.exists(model_path):
# try local path
base_dir = "checkpoints"
model_path = os.path.join(base_dir, model_path, 'hunyuan3d-dit-v2-0')
if not os.path.exists(model_path):
try:
import huggingface_hub
# download from huggingface
huggingface_hub.snapshot_download(
repo_id="tencent/Hunyuan3D-2",
local_dir=base_dir,)
# @classmethod
# def from_pretrained(
# cls,
# model_path,
# ckpt_name='model.ckpt',
# config_name='config.yaml',
# device='cuda',
# dtype=torch.float16,
# use_safetensors=None,
# **kwargs,
# ):
# original_model_path = model_path
# if not os.path.exists(model_path):
# # try local path
# base_dir = "checkpoints"
# model_path = os.path.join(base_dir, model_path, 'hunyuan3d-dit-v2-0')
# if not os.path.exists(model_path):
# try:
# import huggingface_hub
# # download from huggingface
# huggingface_hub.snapshot_download(
# repo_id="tencent/Hunyuan3D-2",
# local_dir=base_dir,)
except ImportError:
logger.warning(
"You need to install HuggingFace Hub to load models from the hub."
)
raise RuntimeError(f"Model path {model_path} not found")
if not os.path.exists(model_path):
raise FileNotFoundError(f"Model path {original_model_path} not found")
# except ImportError:
# logger.warning(
# "You need to install HuggingFace Hub to load models from the hub."
# )
# raise RuntimeError(f"Model path {model_path} not found")
# if not os.path.exists(model_path):
# raise FileNotFoundError(f"Model path {original_model_path} not found")
config_path = os.path.join(model_path, config_name)
ckpt_path = os.path.join(model_path, ckpt_name)
return cls.from_single_file(
ckpt_path,
config_path,
device=device,
dtype=dtype,
use_safetensors=use_safetensors,
**kwargs
)
# config_path = os.path.join(model_path, config_name)
# ckpt_path = os.path.join(model_path, ckpt_name)
# return cls.from_single_file(
# ckpt_path,
# config_path,
# device=device,
# dtype=dtype,
# use_safetensors=use_safetensors,
# **kwargs
# )
def __init__(
self,
vae,
#vae,
model,
scheduler,
conditioner,
@@ -282,7 +280,7 @@ class Hunyuan3DDiTPipeline:
dtype=torch.float16,
**kwargs
):
self.vae = vae
#self.vae = vae
self.model = model
self.scheduler = scheduler
self.conditioner = conditioner
@@ -295,12 +293,12 @@ class Hunyuan3DDiTPipeline:
def to(self, device=None, dtype=None):
if device is not None:
self.vae.to(device)
#self.vae.to(device)
self.model.to(device)
self.conditioner.to(device)
if dtype is not None:
self.dtype = dtype
self.vae.to(dtype=dtype)
#self.vae.to(dtype=dtype)
self.model.to(dtype=dtype)
self.conditioner.to(dtype=dtype)
@@ -359,7 +357,10 @@ class Hunyuan3DDiTPipeline:
return extra_step_kwargs
def prepare_latents(self, batch_size, dtype, device, generator, latents=None):
shape = (batch_size, *self.vae.latent_shape)
#shape = (batch_size, *self.vae.latent_shape)
num_latents = 3072
embed_dim = 64
shape = (batch_size, num_latents, embed_dim)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -423,126 +424,126 @@ class Hunyuan3DDiTPipeline:
assert emb.shape == (w.shape[0], embedding_dim)
return emb
@torch.no_grad()
def __call__(
self,
image: Union[str, List[str], Image.Image] = None,
num_inference_steps: int = 50,
timesteps: List[int] = None,
sigmas: List[float] = None,
eta: float = 0.0,
guidance_scale: float = 7.5,
dual_guidance_scale: float = 10.5,
dual_guidance: bool = True,
generator=None,
box_v=1.01,
octree_resolution=384,
mc_level=-1 / 512,
num_chunks=8000,
mc_algo='mc',
output_type: Optional[str] = "trimesh",
enable_pbar=True,
**kwargs,
) -> List[List[trimesh.Trimesh]]:
callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)
# @torch.no_grad()
# def __call__(
# self,
# image: Union[str, List[str], Image.Image] = None,
# num_inference_steps: int = 50,
# timesteps: List[int] = None,
# sigmas: List[float] = None,
# eta: float = 0.0,
# guidance_scale: float = 7.5,
# dual_guidance_scale: float = 10.5,
# dual_guidance: bool = True,
# generator=None,
# box_v=1.01,
# octree_resolution=384,
# mc_level=-1 / 512,
# num_chunks=8000,
# mc_algo='mc',
# output_type: Optional[str] = "trimesh",
# enable_pbar=True,
# **kwargs,
# ) -> List[List[trimesh.Trimesh]]:
# callback = kwargs.pop("callback", None)
# callback_steps = kwargs.pop("callback_steps", None)
device = self.main_device
dtype = self.dtype
do_classifier_free_guidance = guidance_scale >= 0 and \
getattr(self.model, 'guidance_cond_proj_dim', None) is None
dual_guidance = dual_guidance_scale >= 0 and dual_guidance
# device = self.main_device
# dtype = self.dtype
# do_classifier_free_guidance = guidance_scale >= 0 and \
# getattr(self.model, 'guidance_cond_proj_dim', None) is None
# dual_guidance = dual_guidance_scale >= 0 and dual_guidance
image, mask = self.prepare_image(image)
cond = self.encode_cond(image=image,
mask=mask,
do_classifier_free_guidance=do_classifier_free_guidance,
dual_guidance=dual_guidance)
batch_size = image.shape[0]
# image, mask = self.prepare_image(image)
# cond = self.encode_cond(image=image,
# mask=mask,
# do_classifier_free_guidance=do_classifier_free_guidance,
# dual_guidance=dual_guidance)
# batch_size = image.shape[0]
t_dtype = torch.long
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, timesteps, sigmas)
# t_dtype = torch.long
# timesteps, num_inference_steps = retrieve_timesteps(
# self.scheduler, num_inference_steps, device, timesteps, sigmas)
latents = self.prepare_latents(batch_size, dtype, device, generator)
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# latents = self.prepare_latents(batch_size, dtype, device, generator)
# extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
guidance_cond = None
if getattr(self.model, 'guidance_cond_proj_dim', None) is not None:
print('Using lcm guidance scale')
guidance_scale_tensor = torch.tensor(guidance_scale - 1).repeat(batch_size)
guidance_cond = self.get_guidance_scale_embedding(
guidance_scale_tensor, embedding_dim=self.model.guidance_cond_proj_dim
).to(device=device, dtype=latents.dtype)
# guidance_cond = None
# if getattr(self.model, 'guidance_cond_proj_dim', None) is not None:
# print('Using lcm guidance scale')
# guidance_scale_tensor = torch.tensor(guidance_scale - 1).repeat(batch_size)
# guidance_cond = self.get_guidance_scale_embedding(
# guidance_scale_tensor, embedding_dim=self.model.guidance_cond_proj_dim
# ).to(device=device, dtype=latents.dtype)
comfy_pbar = ProgressBar(num_inference_steps)
# comfy_pbar = ProgressBar(num_inference_steps)
self.model.to(device)
for i, t in enumerate(tqdm(timesteps, disable=not enable_pbar, desc="Diffusion Sampling:", leave=False)):
# expand the latents if we are doing classifier free guidance
if do_classifier_free_guidance:
latent_model_input = torch.cat([latents] * (3 if dual_guidance else 2))
else:
latent_model_input = latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# self.model.to(device)
# for i, t in enumerate(tqdm(timesteps, disable=not enable_pbar, desc="Diffusion Sampling:", leave=False)):
# # expand the latents if we are doing classifier free guidance
# if do_classifier_free_guidance:
# latent_model_input = torch.cat([latents] * (3 if dual_guidance else 2))
# else:
# latent_model_input = latents
# latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
timestep_tensor = torch.tensor([t], dtype=t_dtype, device=device)
timestep_tensor = timestep_tensor.expand(latent_model_input.shape[0])
noise_pred = self.model(latent_model_input, timestep_tensor, cond, guidance_cond=guidance_cond)
# # predict the noise residual
# timestep_tensor = torch.tensor([t], dtype=t_dtype, device=device)
# timestep_tensor = timestep_tensor.expand(latent_model_input.shape[0])
# noise_pred = self.model(latent_model_input, timestep_tensor, cond, guidance_cond=guidance_cond)
# no drop, drop clip, all drop
if do_classifier_free_guidance:
if dual_guidance:
noise_pred_clip, noise_pred_dino, noise_pred_uncond = noise_pred.chunk(3)
noise_pred = (
noise_pred_uncond
+ guidance_scale * (noise_pred_clip - noise_pred_dino)
+ dual_guidance_scale * (noise_pred_dino - noise_pred_uncond)
)
else:
noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
# # no drop, drop clip, all drop
# if do_classifier_free_guidance:
# if dual_guidance:
# noise_pred_clip, noise_pred_dino, noise_pred_uncond = noise_pred.chunk(3)
# noise_pred = (
# noise_pred_uncond
# + guidance_scale * (noise_pred_clip - noise_pred_dino)
# + dual_guidance_scale * (noise_pred_dino - noise_pred_uncond)
# )
# else:
# noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
# noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
outputs = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)
latents = outputs.prev_sample
# # compute the previous noisy sample x_t -> x_t-1
# outputs = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)
# latents = outputs.prev_sample
comfy_pbar.update(1)
# comfy_pbar.update(1)
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, outputs)
self.model.to(self.offload_device)
mm.soft_empty_cache()
# if callback is not None and i % callback_steps == 0:
# step_idx = i // getattr(self.scheduler, "order", 1)
# callback(step_idx, t, outputs)
# self.model.to(self.offload_device)
# mm.soft_empty_cache()
return self._export(
latents,
output_type,
box_v, mc_level, num_chunks, octree_resolution, mc_algo,
)
# return self._export(
# latents,
# output_type,
# box_v, mc_level, num_chunks, octree_resolution, mc_algo,
# )
def _export(self, latents, output_type, box_v, mc_level, num_chunks, octree_resolution, mc_algo):
if not output_type == "latent":
self.vae.to(self.main_device)
latents = 1. / self.vae.scale_factor * latents
latents = self.vae(latents)
outputs = self.vae.latents2mesh(
latents,
bounds=box_v,
mc_level=mc_level,
num_chunks=num_chunks,
octree_resolution=octree_resolution,
mc_algo=mc_algo,
)
self.vae.to(self.offload_device)
else:
outputs = latents
# def _export(self, latents, output_type, box_v, mc_level, num_chunks, octree_resolution, mc_algo):
# if not output_type == "latent":
# self.vae.to(self.main_device)
# latents = 1. / self.vae.scale_factor * latents
# latents = self.vae(latents)
# outputs = self.vae.latents2mesh(
# latents,
# bounds=box_v,
# mc_level=mc_level,
# num_chunks=num_chunks,
# octree_resolution=octree_resolution,
# mc_algo=mc_algo,
# )
# self.vae.to(self.offload_device)
# else:
# outputs = latents
if output_type == 'trimesh':
outputs = export_to_trimesh(outputs)
# if output_type == 'trimesh':
# outputs = export_to_trimesh(outputs)
return outputs
# return outputs
class Hunyuan3DDiTFlowMatchingPipeline(Hunyuan3DDiTPipeline):
@@ -555,15 +556,15 @@ class Hunyuan3DDiTFlowMatchingPipeline(Hunyuan3DDiTPipeline):
num_inference_steps: int = 50,
timesteps: List[int] = None,
sigmas: List[float] = None,
eta: float = 0.0,
#eta: float = 0.0,
guidance_scale: float = 7.5,
generator=None,
box_v=1.01,
octree_resolution=384,
mc_level=0.0,
mc_algo='mc',
num_chunks=8000,
output_type: Optional[str] = "trimesh",
# box_v=1.01,
# octree_resolution=384,
# mc_level=0.0,
# mc_algo='mc',
# num_chunks=8000,
# output_type: Optional[str] = "trimesh",
enable_pbar=True,
**kwargs,
) -> List[List[trimesh.Trimesh]]:
@@ -628,9 +629,10 @@ class Hunyuan3DDiTFlowMatchingPipeline(Hunyuan3DDiTPipeline):
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, outputs)
comfy_pbar.update(1)
return self._export(
latents,
output_type,
box_v, mc_level, num_chunks, octree_resolution, mc_algo,
)
print("latents shape: ", latents.shape)
return latents
# return self._export(
# latents,
# output_type,
# box_v, mc_level, num_chunks, octree_resolution, mc_algo,
# )
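
The flow-matching `__call__` above now ends by returning raw latents, and its geometry options (`box_v`, `octree_resolution`, `mc_level`, `mc_algo`, `num_chunks`, `output_type`) are commented out of the signature. Since `**kwargs` is still accepted, an old caller passing those arguments will likely have them silently ignored rather than rejected; they are now configured on the `Hy3DVAEDecode` node introduced in the next file.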

View File

@@ -77,8 +77,8 @@ class Hy3DModelLoader:
}
}
RETURN_TYPES = ("HY3DMODEL",)
RETURN_NAMES = ("pipeline", )
RETURN_TYPES = ("HY3DMODEL", "HY3DVAE")
RETURN_NAMES = ("pipeline", "vae")
FUNCTION = "loadmodel"
CATEGORY = "Hunyuan3DWrapper"
@@ -88,7 +88,7 @@ class Hy3DModelLoader:
config_path = os.path.join(script_directory, "configs", "dit_config.yaml")
model_path = folder_paths.get_full_path("diffusion_models", model)
pipe = Hunyuan3DDiTFlowMatchingPipeline.from_single_file(
pipe, vae = Hunyuan3DDiTFlowMatchingPipeline.from_single_file(
ckpt_path=model_path,
config_path=config_path,
use_safetensors=True,
@ -97,7 +97,7 @@ class Hy3DModelLoader:
compile_args=compile_args,
attention_mode=attention_mode)
return (pipe,)
return (pipe, vae,)
class DownloadAndLoadHy3DDelightModel:
@classmethod
@@ -676,7 +676,6 @@ class Hy3DGenerateMesh:
"required": {
"pipeline": ("HY3DMODEL",),
"image": ("IMAGE", ),
"octree_resolution": ("INT", {"default": 256, "min": 64, "max": 4096, "step": 16}),
"guidance_scale": ("FLOAT", {"default": 5.5, "min": 0.0, "max": 100.0, "step": 0.01}),
"steps": ("INT", {"default": 30, "min": 1}),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
@@ -686,12 +685,12 @@ class Hy3DGenerateMesh:
}
}
RETURN_TYPES = ("HY3DMESH",)
RETURN_NAMES = ("mesh",)
RETURN_TYPES = ("HY3DLATENT",)
RETURN_NAMES = ("latents",)
FUNCTION = "process"
CATEGORY = "Hunyuan3DWrapper"
def process(self, pipeline, image, steps, guidance_scale, octree_resolution, seed, mask=None):
def process(self, pipeline, image, steps, guidance_scale, seed, mask=None):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
@@ -709,16 +708,12 @@ class Hy3DGenerateMesh:
except:
pass
mesh = pipeline(
latents = pipeline(
image=image,
mask=mask,
num_inference_steps=steps,
mc_algo='mc',
guidance_scale=guidance_scale,
octree_resolution=octree_resolution,
generator=torch.manual_seed(seed))[0]
log.info(f"Generated mesh with {mesh.vertices.shape[0]} vertices and {mesh.faces.shape[0]} faces")
generator=torch.manual_seed(seed))
print_memory(device)
try:
@@ -728,7 +723,51 @@ class Hy3DGenerateMesh:
pipeline.to(offload_device)
return (mesh, )
return (latents, )
class Hy3DVAEDecode:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"vae": ("HY3DVAE",),
"latents": ("HY3DLATENT", ),
"box_v": ("FLOAT", {"default": 1.01, "min": -10.0, "max": 10.0, "step": 0.001}),
"octree_resolution": ("INT", {"default": 384, "min": 64, "max": 4096, "step": 16}),
"num_chunks": ("INT", {"default": 8000, "min": 1, "max": 10000000, "step": 1}),
"mc_level": ("FLOAT", {"default": 0, "min": -1.0, "max": 1.0, "step": 0.0001}),
"mc_algo": (["mc", "dmc"], {"default": "mc"}),
},
}
RETURN_TYPES = ("HY3DMESH",)
RETURN_NAMES = ("mesh",)
FUNCTION = "process"
CATEGORY = "Hunyuan3DWrapper"
def process(self, vae, latents, box_v, octree_resolution, mc_level, num_chunks, mc_algo):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
vae.to(device)
latents = 1. / vae.scale_factor * latents
latents = vae(latents)
outputs = vae.latents2mesh(
latents,
bounds=box_v,
mc_level=mc_level,
num_chunks=num_chunks,
octree_resolution=octree_resolution,
mc_algo=mc_algo,
)[0]
vae.to(offload_device)
outputs.mesh_f = outputs.mesh_f[:, ::-1]
mesh_output = trimesh.Trimesh(outputs.mesh_v, outputs.mesh_f)
log.info(f"Decoded mesh with {mesh_output.vertices.shape[0]} vertices and {mesh_output.faces.shape[0]} faces")
return (mesh_output, )
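
One detail worth noting in the new node: the faces from `latents2mesh` are reversed with `outputs.mesh_f[:, ::-1]` before building the `trimesh.Trimesh`. Reversing each face's vertex order flips the triangle winding, which inverts the face normals; presumably the raw output winds the opposite way from what downstream nodes expect. A toy illustration with hypothetical data, not the node's real arrays:

```python
import numpy as np
import trimesh

verts = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
faces = np.array([[0, 1, 2]])   # counter-clockwise when seen from +z
flipped = faces[:, ::-1]        # [[2, 1, 0]]: opposite winding
print(trimesh.Trimesh(verts, faces, process=False).face_normals)    # [[0, 0, 1]]
print(trimesh.Trimesh(verts, flipped, process=False).face_normals)  # [[0, 0, -1]]
```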
class Hy3DPostprocessMesh:
@classmethod
@@ -918,7 +957,8 @@ NODE_CLASS_MAPPINGS = {
"Hy3DRenderMultiViewDepth": Hy3DRenderMultiViewDepth,
"Hy3DGetMeshPBRTextures": Hy3DGetMeshPBRTextures,
"Hy3DSetMeshPBRTextures": Hy3DSetMeshPBRTextures,
"Hy3DSetMeshPBRAttributes": Hy3DSetMeshPBRAttributes
"Hy3DSetMeshPBRAttributes": Hy3DSetMeshPBRAttributes,
"Hy3DVAEDecode": Hy3DVAEDecode
}
NODE_DISPLAY_NAME_MAPPINGS = {
"Hy3DModelLoader": "Hy3DModelLoader",
@@ -941,5 +981,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"Hy3DRenderMultiViewDepth": "Hy3D Render MultiView Depth",
"Hy3DGetMeshPBRTextures": "Hy3D Get Mesh PBR Textures",
"Hy3DSetMeshPBRTextures": "Hy3D Set Mesh PBR Textures",
"Hy3DSetMeshPBRAttributes": "Hy3D Set Mesh PBR Attributes"
"Hy3DSetMeshPBRAttributes": "Hy3D Set Mesh PBR Attributes",
"Hy3DVAEDecode": "Hy3D VAE Decode"
}
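
For existing workflows, the rewiring implied by this commit: `Hy3DModelLoader` gains a second output (`vae`), `Hy3DGenerateMesh` loses its `octree_resolution` input and emits `HY3DLATENT` instead of `HY3DMESH`, and the latents plus the VAE feed the new `Hy3DVAEDecode` node, which now owns all geometry-extraction settings (`box_v`, `octree_resolution`, `num_chunks`, `mc_level`, `mc_algo`) and outputs the `HY3DMESH` that downstream nodes expect.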