diff --git a/nodes.py b/nodes.py
index a6e2de8..8818ace 100644
--- a/nodes.py
+++ b/nodes.py
@@ -571,6 +571,9 @@ class MochiDecode:
                 "tile_overlap_factor_height": ("FLOAT", {"default": 0.1666, "min": 0.0, "max": 1.0, "step": 0.001}),
                 "tile_overlap_factor_width": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001}),
             },
+            "optional": {
+                "unnormalize": ("BOOLEAN", {"default": False, "tooltip": "Unnormalize the latents before decoding"}),
+            },
         }
 
     RETURN_TYPES = ("IMAGE",)
@@ -579,12 +582,13 @@ class MochiDecode:
     CATEGORY = "MochiWrapper"
 
     def decode(self, vae, samples, enable_vae_tiling, tile_sample_min_height, tile_sample_min_width, tile_overlap_factor_height,
-               tile_overlap_factor_width, auto_tile_size, frame_batch_size):
+               tile_overlap_factor_width, auto_tile_size, frame_batch_size, unnormalize=False):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         intermediate_device = mm.intermediate_device()
         samples = samples["samples"]
-        samples = dit_latents_to_vae_latents(samples)
+        if unnormalize:
+            samples = dit_latents_to_vae_latents(samples)
         samples = samples.to(vae.dtype).to(device)
 
         B, C, T, H, W = samples.shape
@@ -699,6 +703,9 @@ class MochiDecodeSpatialTiling:
                 "min_block_size": ("INT", {"default": 1, "min": 1, "max": 256, "step": 1, "tooltip": "Minimum number of pixels in each dimension when subdividing"}),
                 "per_batch": ("INT", {"default": 6, "min": 1, "max": 256, "step": 1, "tooltip": "Number of samples per batch, in latent space (6 frames in 1 latent)"}),
             },
+            "optional": {
+                "unnormalize": ("BOOLEAN", {"default": True, "tooltip": "Unnormalize the latents before decoding"}),
+            },
         }
 
     RETURN_TYPES = ("IMAGE",)
@@ -707,12 +714,13 @@ class MochiDecodeSpatialTiling:
     CATEGORY = "MochiWrapper"
 
     def decode(self, vae, samples, enable_vae_tiling, num_tiles_w, num_tiles_h, overlap,
-               min_block_size, per_batch):
+               min_block_size, per_batch, unnormalize=True):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         intermediate_device = mm.intermediate_device()
         samples = samples["samples"]
-        samples = dit_latents_to_vae_latents(samples)
+        if unnormalize:
+            samples = dit_latents_to_vae_latents(samples)
         samples = samples.to(vae.dtype).to(device)
 
         B, C, T, H, W = samples.shape
@@ -768,14 +776,17 @@ class MochiImageEncode:
                 "overlap": ("INT", {"default": 16, "min": 0, "max": 256, "step": 1, "tooltip": "Number of pixel of overlap between adjacent tiles"}),
                 "min_block_size": ("INT", {"default": 1, "min": 1, "max": 256, "step": 1, "tooltip": "Minimum number of pixels in each dimension when subdividing"}),
             },
+            "optional": {
+                "normalize": ("BOOLEAN", {"default": True, "tooltip": "Normalize the images before encoding"}),
+            },
         }
 
     RETURN_TYPES = ("LATENT",)
     RETURN_NAMES = ("samples",)
-    FUNCTION = "decode"
+    FUNCTION = "encode"
     CATEGORY = "MochiWrapper"
 
-    def decode(self, encoder, images, enable_vae_tiling, num_tiles_w, num_tiles_h, overlap, min_block_size):
+    def encode(self, encoder, images, enable_vae_tiling, num_tiles_w, num_tiles_h, overlap, min_block_size, normalize=True):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         intermediate_device = mm.intermediate_device()
@@ -800,7 +811,8 @@ class MochiImageEncode:
             latents = apply_tiled(encoder, video, num_tiles_w = num_tiles_w, num_tiles_h = num_tiles_h, overlap=overlap, min_block_size=min_block_size)
         else:
             latents = encoder(video)
-        latents = vae_latents_to_dit_latents(latents)
+        if normalize:
+            latents = vae_latents_to_dit_latents(latents)
         print("encoder output",latents.shape)
         return ({"samples": latents},)