make latent normalization optional for testing

kijai 2024-11-05 22:30:13 +02:00
parent fbd2252dc4
commit 84e536e226


@@ -571,6 +571,9 @@ class MochiDecode:
                 "tile_overlap_factor_height": ("FLOAT", {"default": 0.1666, "min": 0.0, "max": 1.0, "step": 0.001}),
                 "tile_overlap_factor_width": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001}),
             },
+            "optional": {
+                "unnormalize": ("BOOLEAN", {"default": False, "tooltip": "Unnormalize the latents before decoding"}),
+            },
         }
 
     RETURN_TYPES = ("IMAGE",)
@@ -579,12 +582,13 @@ class MochiDecode:
     CATEGORY = "MochiWrapper"
 
     def decode(self, vae, samples, enable_vae_tiling, tile_sample_min_height, tile_sample_min_width, tile_overlap_factor_height,
-               tile_overlap_factor_width, auto_tile_size, frame_batch_size):
+               tile_overlap_factor_width, auto_tile_size, frame_batch_size, unnormalize=False):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         intermediate_device = mm.intermediate_device()
         samples = samples["samples"]
-        samples = dit_latents_to_vae_latents(samples)
+        if unnormalize:
+            samples = dit_latents_to_vae_latents(samples)
         samples = samples.to(vae.dtype).to(device)
 
         B, C, T, H, W = samples.shape
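For context on what the new unnormalize toggle guards: below is a minimal sketch of dit_latents_to_vae_latents, assuming it is a per-channel affine map from the DiT's normalized latent space back to the VAE's latent space. The mean/std constants and the 12-channel latent shape are placeholder assumptions, not values taken from this repo.

import torch

# Placeholder per-channel statistics; the real constants live in the repo.
# Assumption: Mochi latents are shaped [B, C, T, H, W] with C = 12.
LATENTS_MEAN = torch.zeros(12)
LATENTS_STD = torch.ones(12)

def dit_latents_to_vae_latents(latents: torch.Tensor) -> torch.Tensor:
    # Broadcast the per-channel stats over batch, time, and spatial dims,
    # then undo the normalization the DiT operates in.
    mean = LATENTS_MEAN.to(latents)[None, :, None, None, None]
    std = LATENTS_STD.to(latents)[None, :, None, None, None]
    return latents * std + mean
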
@@ -699,6 +703,9 @@ class MochiDecodeSpatialTiling:
                 "min_block_size": ("INT", {"default": 1, "min": 1, "max": 256, "step": 1, "tooltip": "Minimum number of pixels in each dimension when subdividing"}),
                 "per_batch": ("INT", {"default": 6, "min": 1, "max": 256, "step": 1, "tooltip": "Number of samples per batch, in latent space (6 frames in 1 latent)"}),
             },
+            "optional": {
+                "unnormalize": ("BOOLEAN", {"default": True, "tooltip": "Unnormalize the latents before decoding"}),
+            },
         }
 
     RETURN_TYPES = ("IMAGE",)
@@ -707,12 +714,13 @@ class MochiDecodeSpatialTiling:
     CATEGORY = "MochiWrapper"
 
     def decode(self, vae, samples, enable_vae_tiling, num_tiles_w, num_tiles_h, overlap,
-               min_block_size, per_batch):
+               min_block_size, per_batch, unnormalize=True):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         intermediate_device = mm.intermediate_device()
         samples = samples["samples"]
-        samples = dit_latents_to_vae_latents(samples)
+        if unnormalize:
+            samples = dit_latents_to_vae_latents(samples)
         samples = samples.to(vae.dtype).to(device)
 
         B, C, T, H, W = samples.shape
@@ -768,14 +776,17 @@ class MochiImageEncode:
                 "overlap": ("INT", {"default": 16, "min": 0, "max": 256, "step": 1, "tooltip": "Number of pixel of overlap between adjacent tiles"}),
                 "min_block_size": ("INT", {"default": 1, "min": 1, "max": 256, "step": 1, "tooltip": "Minimum number of pixels in each dimension when subdividing"}),
             },
+            "optional": {
+                "normalize": ("BOOLEAN", {"default": True, "tooltip": "Normalize the images before encoding"}),
+            },
         }
 
     RETURN_TYPES = ("LATENT",)
     RETURN_NAMES = ("samples",)
-    FUNCTION = "decode"
+    FUNCTION = "encode"
     CATEGORY = "MochiWrapper"
 
-    def decode(self, encoder, images, enable_vae_tiling, num_tiles_w, num_tiles_h, overlap, min_block_size):
+    def encode(self, encoder, images, enable_vae_tiling, num_tiles_w, num_tiles_h, overlap, min_block_size, normalize=True):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         intermediate_device = mm.intermediate_device()
@@ -800,7 +811,8 @@ class MochiImageEncode:
             latents = apply_tiled(encoder, video, num_tiles_w = num_tiles_w, num_tiles_h = num_tiles_h, overlap=overlap, min_block_size=min_block_size)
         else:
             latents = encoder(video)
 
-        latents = vae_latents_to_dit_latents(latents)
+        if normalize:
+            latents = vae_latents_to_dit_latents(latents)
         print("encoder output",latents.shape)
         return ({"samples": latents},)