Separate VAE (old workflows need updating)

This commit is contained in:
kijai 2025-01-26 02:48:37 +02:00
parent 7e234a017e
commit d563280607
4 changed files with 855 additions and 735 deletions

Binary file not shown.

Before

Width:  |  Height:  |  Size: 6.8 MiB

After

Width:  |  Height:  |  Size: 6.9 MiB

View File

@ -212,7 +212,7 @@ class Hunyuan3DDiTPipeline:
vae = torch.compile(vae) vae = torch.compile(vae)
model_kwargs = dict( model_kwargs = dict(
vae=vae, #vae=vae,
model=model, model=model,
scheduler=scheduler, scheduler=scheduler,
conditioner=conditioner, conditioner=conditioner,
@ -223,56 +223,54 @@ class Hunyuan3DDiTPipeline:
) )
model_kwargs.update(kwargs) model_kwargs.update(kwargs)
return cls( return cls(**model_kwargs), vae
**model_kwargs
)
@classmethod # @classmethod
def from_pretrained( # def from_pretrained(
cls, # cls,
model_path, # model_path,
ckpt_name='model.ckpt', # ckpt_name='model.ckpt',
config_name='config.yaml', # config_name='config.yaml',
device='cuda', # device='cuda',
dtype=torch.float16, # dtype=torch.float16,
use_safetensors=None, # use_safetensors=None,
**kwargs, # **kwargs,
): # ):
original_model_path = model_path # original_model_path = model_path
if not os.path.exists(model_path): # if not os.path.exists(model_path):
# try local path # # try local path
base_dir = "checkpoints" # base_dir = "checkpoints"
model_path = os.path.join(base_dir, model_path, 'hunyuan3d-dit-v2-0') # model_path = os.path.join(base_dir, model_path, 'hunyuan3d-dit-v2-0')
if not os.path.exists(model_path): # if not os.path.exists(model_path):
try: # try:
import huggingface_hub # import huggingface_hub
# download from huggingface # # download from huggingface
huggingface_hub.snapshot_download( # huggingface_hub.snapshot_download(
repo_id="tencent/Hunyuan3D-2", # repo_id="tencent/Hunyuan3D-2",
local_dir=base_dir,) # local_dir=base_dir,)
except ImportError: # except ImportError:
logger.warning( # logger.warning(
"You need to install HuggingFace Hub to load models from the hub." # "You need to install HuggingFace Hub to load models from the hub."
) # )
raise RuntimeError(f"Model path {model_path} not found") # raise RuntimeError(f"Model path {model_path} not found")
if not os.path.exists(model_path): # if not os.path.exists(model_path):
raise FileNotFoundError(f"Model path {original_model_path} not found") # raise FileNotFoundError(f"Model path {original_model_path} not found")
config_path = os.path.join(model_path, config_name) # config_path = os.path.join(model_path, config_name)
ckpt_path = os.path.join(model_path, ckpt_name) # ckpt_path = os.path.join(model_path, ckpt_name)
return cls.from_single_file( # return cls.from_single_file(
ckpt_path, # ckpt_path,
config_path, # config_path,
device=device, # device=device,
dtype=dtype, # dtype=dtype,
use_safetensors=use_safetensors, # use_safetensors=use_safetensors,
**kwargs # **kwargs
) # )
def __init__( def __init__(
self, self,
vae, #vae,
model, model,
scheduler, scheduler,
conditioner, conditioner,
@ -282,7 +280,7 @@ class Hunyuan3DDiTPipeline:
dtype=torch.float16, dtype=torch.float16,
**kwargs **kwargs
): ):
self.vae = vae #self.vae = vae
self.model = model self.model = model
self.scheduler = scheduler self.scheduler = scheduler
self.conditioner = conditioner self.conditioner = conditioner
@ -295,12 +293,12 @@ class Hunyuan3DDiTPipeline:
def to(self, device=None, dtype=None): def to(self, device=None, dtype=None):
if device is not None: if device is not None:
self.vae.to(device) #self.vae.to(device)
self.model.to(device) self.model.to(device)
self.conditioner.to(device) self.conditioner.to(device)
if dtype is not None: if dtype is not None:
self.dtype = dtype self.dtype = dtype
self.vae.to(dtype=dtype) #self.vae.to(dtype=dtype)
self.model.to(dtype=dtype) self.model.to(dtype=dtype)
self.conditioner.to(dtype=dtype) self.conditioner.to(dtype=dtype)
@ -359,7 +357,10 @@ class Hunyuan3DDiTPipeline:
return extra_step_kwargs return extra_step_kwargs
def prepare_latents(self, batch_size, dtype, device, generator, latents=None): def prepare_latents(self, batch_size, dtype, device, generator, latents=None):
shape = (batch_size, *self.vae.latent_shape) #shape = (batch_size, *self.vae.latent_shape)
num_latents = 3072
embed_dim = 64
shape = (batch_size, num_latents, embed_dim)
if isinstance(generator, list) and len(generator) != batch_size: if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError( raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@ -423,126 +424,126 @@ class Hunyuan3DDiTPipeline:
assert emb.shape == (w.shape[0], embedding_dim) assert emb.shape == (w.shape[0], embedding_dim)
return emb return emb
@torch.no_grad() # @torch.no_grad()
def __call__( # def __call__(
self, # self,
image: Union[str, List[str], Image.Image] = None, # image: Union[str, List[str], Image.Image] = None,
num_inference_steps: int = 50, # num_inference_steps: int = 50,
timesteps: List[int] = None, # timesteps: List[int] = None,
sigmas: List[float] = None, # sigmas: List[float] = None,
eta: float = 0.0, # eta: float = 0.0,
guidance_scale: float = 7.5, # guidance_scale: float = 7.5,
dual_guidance_scale: float = 10.5, # dual_guidance_scale: float = 10.5,
dual_guidance: bool = True, # dual_guidance: bool = True,
generator=None, # generator=None,
box_v=1.01, # box_v=1.01,
octree_resolution=384, # octree_resolution=384,
mc_level=-1 / 512, # mc_level=-1 / 512,
num_chunks=8000, # num_chunks=8000,
mc_algo='mc', # mc_algo='mc',
output_type: Optional[str] = "trimesh", # output_type: Optional[str] = "trimesh",
enable_pbar=True, # enable_pbar=True,
**kwargs, # **kwargs,
) -> List[List[trimesh.Trimesh]]: # ) -> List[List[trimesh.Trimesh]]:
callback = kwargs.pop("callback", None) # callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None) # callback_steps = kwargs.pop("callback_steps", None)
device = self.main_device # device = self.main_device
dtype = self.dtype # dtype = self.dtype
do_classifier_free_guidance = guidance_scale >= 0 and \ # do_classifier_free_guidance = guidance_scale >= 0 and \
getattr(self.model, 'guidance_cond_proj_dim', None) is None # getattr(self.model, 'guidance_cond_proj_dim', None) is None
dual_guidance = dual_guidance_scale >= 0 and dual_guidance # dual_guidance = dual_guidance_scale >= 0 and dual_guidance
image, mask = self.prepare_image(image) # image, mask = self.prepare_image(image)
cond = self.encode_cond(image=image, # cond = self.encode_cond(image=image,
mask=mask, # mask=mask,
do_classifier_free_guidance=do_classifier_free_guidance, # do_classifier_free_guidance=do_classifier_free_guidance,
dual_guidance=dual_guidance) # dual_guidance=dual_guidance)
batch_size = image.shape[0] # batch_size = image.shape[0]
t_dtype = torch.long # t_dtype = torch.long
timesteps, num_inference_steps = retrieve_timesteps( # timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, timesteps, sigmas) # self.scheduler, num_inference_steps, device, timesteps, sigmas)
latents = self.prepare_latents(batch_size, dtype, device, generator) # latents = self.prepare_latents(batch_size, dtype, device, generator)
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
guidance_cond = None # guidance_cond = None
if getattr(self.model, 'guidance_cond_proj_dim', None) is not None: # if getattr(self.model, 'guidance_cond_proj_dim', None) is not None:
print('Using lcm guidance scale') # print('Using lcm guidance scale')
guidance_scale_tensor = torch.tensor(guidance_scale - 1).repeat(batch_size) # guidance_scale_tensor = torch.tensor(guidance_scale - 1).repeat(batch_size)
guidance_cond = self.get_guidance_scale_embedding( # guidance_cond = self.get_guidance_scale_embedding(
guidance_scale_tensor, embedding_dim=self.model.guidance_cond_proj_dim # guidance_scale_tensor, embedding_dim=self.model.guidance_cond_proj_dim
).to(device=device, dtype=latents.dtype) # ).to(device=device, dtype=latents.dtype)
comfy_pbar = ProgressBar(num_inference_steps) # comfy_pbar = ProgressBar(num_inference_steps)
self.model.to(device) # self.model.to(device)
for i, t in enumerate(tqdm(timesteps, disable=not enable_pbar, desc="Diffusion Sampling:", leave=False)): # for i, t in enumerate(tqdm(timesteps, disable=not enable_pbar, desc="Diffusion Sampling:", leave=False)):
# expand the latents if we are doing classifier free guidance # # expand the latents if we are doing classifier free guidance
if do_classifier_free_guidance: # if do_classifier_free_guidance:
latent_model_input = torch.cat([latents] * (3 if dual_guidance else 2)) # latent_model_input = torch.cat([latents] * (3 if dual_guidance else 2))
else: # else:
latent_model_input = latents # latent_model_input = latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual # # predict the noise residual
timestep_tensor = torch.tensor([t], dtype=t_dtype, device=device) # timestep_tensor = torch.tensor([t], dtype=t_dtype, device=device)
timestep_tensor = timestep_tensor.expand(latent_model_input.shape[0]) # timestep_tensor = timestep_tensor.expand(latent_model_input.shape[0])
noise_pred = self.model(latent_model_input, timestep_tensor, cond, guidance_cond=guidance_cond) # noise_pred = self.model(latent_model_input, timestep_tensor, cond, guidance_cond=guidance_cond)
# no drop, drop clip, all drop # # no drop, drop clip, all drop
if do_classifier_free_guidance: # if do_classifier_free_guidance:
if dual_guidance: # if dual_guidance:
noise_pred_clip, noise_pred_dino, noise_pred_uncond = noise_pred.chunk(3) # noise_pred_clip, noise_pred_dino, noise_pred_uncond = noise_pred.chunk(3)
noise_pred = ( # noise_pred = (
noise_pred_uncond # noise_pred_uncond
+ guidance_scale * (noise_pred_clip - noise_pred_dino) # + guidance_scale * (noise_pred_clip - noise_pred_dino)
+ dual_guidance_scale * (noise_pred_dino - noise_pred_uncond) # + dual_guidance_scale * (noise_pred_dino - noise_pred_uncond)
) # )
else: # else:
noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2) # noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) # noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1 # # compute the previous noisy sample x_t -> x_t-1
outputs = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs) # outputs = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)
latents = outputs.prev_sample # latents = outputs.prev_sample
comfy_pbar.update(1) # comfy_pbar.update(1)
if callback is not None and i % callback_steps == 0: # if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1) # step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, outputs) # callback(step_idx, t, outputs)
self.model.to(self.offload_device) # self.model.to(self.offload_device)
mm.soft_empty_cache() # mm.soft_empty_cache()
return self._export( # return self._export(
latents, # latents,
output_type, # output_type,
box_v, mc_level, num_chunks, octree_resolution, mc_algo, # box_v, mc_level, num_chunks, octree_resolution, mc_algo,
) # )
def _export(self, latents, output_type, box_v, mc_level, num_chunks, octree_resolution, mc_algo): # def _export(self, latents, output_type, box_v, mc_level, num_chunks, octree_resolution, mc_algo):
if not output_type == "latent": # if not output_type == "latent":
self.vae.to(self.main_device) # self.vae.to(self.main_device)
latents = 1. / self.vae.scale_factor * latents # latents = 1. / self.vae.scale_factor * latents
latents = self.vae(latents) # latents = self.vae(latents)
outputs = self.vae.latents2mesh( # outputs = self.vae.latents2mesh(
latents, # latents,
bounds=box_v, # bounds=box_v,
mc_level=mc_level, # mc_level=mc_level,
num_chunks=num_chunks, # num_chunks=num_chunks,
octree_resolution=octree_resolution, # octree_resolution=octree_resolution,
mc_algo=mc_algo, # mc_algo=mc_algo,
) # )
self.vae.to(self.offload_device) # self.vae.to(self.offload_device)
else: # else:
outputs = latents # outputs = latents
if output_type == 'trimesh': # if output_type == 'trimesh':
outputs = export_to_trimesh(outputs) # outputs = export_to_trimesh(outputs)
return outputs # return outputs
class Hunyuan3DDiTFlowMatchingPipeline(Hunyuan3DDiTPipeline): class Hunyuan3DDiTFlowMatchingPipeline(Hunyuan3DDiTPipeline):
@ -555,15 +556,15 @@ class Hunyuan3DDiTFlowMatchingPipeline(Hunyuan3DDiTPipeline):
num_inference_steps: int = 50, num_inference_steps: int = 50,
timesteps: List[int] = None, timesteps: List[int] = None,
sigmas: List[float] = None, sigmas: List[float] = None,
eta: float = 0.0, #eta: float = 0.0,
guidance_scale: float = 7.5, guidance_scale: float = 7.5,
generator=None, generator=None,
box_v=1.01, # box_v=1.01,
octree_resolution=384, # octree_resolution=384,
mc_level=0.0, # mc_level=0.0,
mc_algo='mc', # mc_algo='mc',
num_chunks=8000, # num_chunks=8000,
output_type: Optional[str] = "trimesh", # output_type: Optional[str] = "trimesh",
enable_pbar=True, enable_pbar=True,
**kwargs, **kwargs,
) -> List[List[trimesh.Trimesh]]: ) -> List[List[trimesh.Trimesh]]:
@ -628,9 +629,10 @@ class Hunyuan3DDiTFlowMatchingPipeline(Hunyuan3DDiTPipeline):
step_idx = i // getattr(self.scheduler, "order", 1) step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, outputs) callback(step_idx, t, outputs)
comfy_pbar.update(1) comfy_pbar.update(1)
print("latents shape: ", latents.shape)
return self._export( return latents
latents, # return self._export(
output_type, # latents,
box_v, mc_level, num_chunks, octree_resolution, mc_algo, # output_type,
) # box_v, mc_level, num_chunks, octree_resolution, mc_algo,
# )

View File

@ -77,8 +77,8 @@ class Hy3DModelLoader:
} }
} }
RETURN_TYPES = ("HY3DMODEL",) RETURN_TYPES = ("HY3DMODEL", "HY3DVAE")
RETURN_NAMES = ("pipeline", ) RETURN_NAMES = ("pipeline", "vae")
FUNCTION = "loadmodel" FUNCTION = "loadmodel"
CATEGORY = "Hunyuan3DWrapper" CATEGORY = "Hunyuan3DWrapper"
@ -88,7 +88,7 @@ class Hy3DModelLoader:
config_path = os.path.join(script_directory, "configs", "dit_config.yaml") config_path = os.path.join(script_directory, "configs", "dit_config.yaml")
model_path = folder_paths.get_full_path("diffusion_models", model) model_path = folder_paths.get_full_path("diffusion_models", model)
pipe = Hunyuan3DDiTFlowMatchingPipeline.from_single_file( pipe, vae = Hunyuan3DDiTFlowMatchingPipeline.from_single_file(
ckpt_path=model_path, ckpt_path=model_path,
config_path=config_path, config_path=config_path,
use_safetensors=True, use_safetensors=True,
@ -97,7 +97,7 @@ class Hy3DModelLoader:
compile_args=compile_args, compile_args=compile_args,
attention_mode=attention_mode) attention_mode=attention_mode)
return (pipe,) return (pipe, vae,)
class DownloadAndLoadHy3DDelightModel: class DownloadAndLoadHy3DDelightModel:
@classmethod @classmethod
@ -676,7 +676,6 @@ class Hy3DGenerateMesh:
"required": { "required": {
"pipeline": ("HY3DMODEL",), "pipeline": ("HY3DMODEL",),
"image": ("IMAGE", ), "image": ("IMAGE", ),
"octree_resolution": ("INT", {"default": 256, "min": 64, "max": 4096, "step": 16}),
"guidance_scale": ("FLOAT", {"default": 5.5, "min": 0.0, "max": 100.0, "step": 0.01}), "guidance_scale": ("FLOAT", {"default": 5.5, "min": 0.0, "max": 100.0, "step": 0.01}),
"steps": ("INT", {"default": 30, "min": 1}), "steps": ("INT", {"default": 30, "min": 1}),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
@ -686,12 +685,12 @@ class Hy3DGenerateMesh:
} }
} }
RETURN_TYPES = ("HY3DMESH",) RETURN_TYPES = ("HY3DLATENT",)
RETURN_NAMES = ("mesh",) RETURN_NAMES = ("latents",)
FUNCTION = "process" FUNCTION = "process"
CATEGORY = "Hunyuan3DWrapper" CATEGORY = "Hunyuan3DWrapper"
def process(self, pipeline, image, steps, guidance_scale, octree_resolution, seed, mask=None): def process(self, pipeline, image, steps, guidance_scale, seed, mask=None):
device = mm.get_torch_device() device = mm.get_torch_device()
offload_device = mm.unet_offload_device() offload_device = mm.unet_offload_device()
@ -709,16 +708,12 @@ class Hy3DGenerateMesh:
except: except:
pass pass
mesh = pipeline( latents = pipeline(
image=image, image=image,
mask=mask, mask=mask,
num_inference_steps=steps, num_inference_steps=steps,
mc_algo='mc',
guidance_scale=guidance_scale, guidance_scale=guidance_scale,
octree_resolution=octree_resolution, generator=torch.manual_seed(seed))
generator=torch.manual_seed(seed))[0]
log.info(f"Generated mesh with {mesh.vertices.shape[0]} vertices and {mesh.faces.shape[0]} faces")
print_memory(device) print_memory(device)
try: try:
@ -728,7 +723,51 @@ class Hy3DGenerateMesh:
pipeline.to(offload_device) pipeline.to(offload_device)
return (mesh, ) return (latents, )
class Hy3DVAEDecode:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"vae": ("HY3DVAE",),
"latents": ("HY3DLATENT", ),
"box_v": ("FLOAT", {"default": 1.01, "min": -10.0, "max": 10.0, "step": 0.001}),
"octree_resolution": ("INT", {"default": 384, "min": 64, "max": 4096, "step": 16}),
"num_chunks": ("INT", {"default": 8000, "min": 1, "max": 10000000, "step": 1}),
"mc_level": ("FLOAT", {"default": 0, "min": -1.0, "max": 1.0, "step": 0.0001}),
"mc_algo": (["mc", "dmc"], {"default": "mc"}),
},
}
RETURN_TYPES = ("HY3DMESH",)
RETURN_NAMES = ("mesh",)
FUNCTION = "process"
CATEGORY = "Hunyuan3DWrapper"
def process(self, vae, latents, box_v, octree_resolution, mc_level, num_chunks, mc_algo):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
vae.to(device)
latents = 1. / vae.scale_factor * latents
latents = vae(latents)
outputs = vae.latents2mesh(
latents,
bounds=box_v,
mc_level=mc_level,
num_chunks=num_chunks,
octree_resolution=octree_resolution,
mc_algo=mc_algo,
)[0]
vae.to(offload_device)
outputs.mesh_f = outputs.mesh_f[:, ::-1]
mesh_output = trimesh.Trimesh(outputs.mesh_v, outputs.mesh_f)
log.info(f"Decoded mesh with {mesh_output.vertices.shape[0]} vertices and {mesh_output.faces.shape[0]} faces")
return (mesh_output, )
class Hy3DPostprocessMesh: class Hy3DPostprocessMesh:
@classmethod @classmethod
@ -918,7 +957,8 @@ NODE_CLASS_MAPPINGS = {
"Hy3DRenderMultiViewDepth": Hy3DRenderMultiViewDepth, "Hy3DRenderMultiViewDepth": Hy3DRenderMultiViewDepth,
"Hy3DGetMeshPBRTextures": Hy3DGetMeshPBRTextures, "Hy3DGetMeshPBRTextures": Hy3DGetMeshPBRTextures,
"Hy3DSetMeshPBRTextures": Hy3DSetMeshPBRTextures, "Hy3DSetMeshPBRTextures": Hy3DSetMeshPBRTextures,
"Hy3DSetMeshPBRAttributes": Hy3DSetMeshPBRAttributes "Hy3DSetMeshPBRAttributes": Hy3DSetMeshPBRAttributes,
"Hy3DVAEDecode": Hy3DVAEDecode
} }
NODE_DISPLAY_NAME_MAPPINGS = { NODE_DISPLAY_NAME_MAPPINGS = {
"Hy3DModelLoader": "Hy3DModelLoader", "Hy3DModelLoader": "Hy3DModelLoader",
@ -941,5 +981,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"Hy3DRenderMultiViewDepth": "Hy3D Render MultiView Depth", "Hy3DRenderMultiViewDepth": "Hy3D Render MultiView Depth",
"Hy3DGetMeshPBRTextures": "Hy3D Get Mesh PBR Textures", "Hy3DGetMeshPBRTextures": "Hy3D Get Mesh PBR Textures",
"Hy3DSetMeshPBRTextures": "Hy3D Set Mesh PBR Textures", "Hy3DSetMeshPBRTextures": "Hy3D Set Mesh PBR Textures",
"Hy3DSetMeshPBRAttributes": "Hy3D Set Mesh PBR Attributes" "Hy3DSetMeshPBRAttributes": "Hy3D Set Mesh PBR Attributes",
"Hy3DVAEDecode": "Hy3D VAE Decode"
} }