make latent normalization optional for testing
This commit is contained in:
parent
fbd2252dc4
commit
84e536e226
26
nodes.py
26
nodes.py
@ -571,6 +571,9 @@ class MochiDecode:
|
||||
"tile_overlap_factor_height": ("FLOAT", {"default": 0.1666, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||
"tile_overlap_factor_width": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||
},
|
||||
"optional": {
|
||||
"unnormalize": ("BOOLEAN", {"default": False, "tooltip": "Unnormalize the latents before decoding"}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
@ -579,12 +582,13 @@ class MochiDecode:
|
||||
CATEGORY = "MochiWrapper"
|
||||
|
||||
def decode(self, vae, samples, enable_vae_tiling, tile_sample_min_height, tile_sample_min_width, tile_overlap_factor_height,
|
||||
tile_overlap_factor_width, auto_tile_size, frame_batch_size):
|
||||
tile_overlap_factor_width, auto_tile_size, frame_batch_size, unnormalize=False):
|
||||
device = mm.get_torch_device()
|
||||
offload_device = mm.unet_offload_device()
|
||||
intermediate_device = mm.intermediate_device()
|
||||
samples = samples["samples"]
|
||||
samples = dit_latents_to_vae_latents(samples)
|
||||
if unnormalize:
|
||||
samples = dit_latents_to_vae_latents(samples)
|
||||
samples = samples.to(vae.dtype).to(device)
|
||||
|
||||
B, C, T, H, W = samples.shape
|
||||
@ -699,6 +703,9 @@ class MochiDecodeSpatialTiling:
|
||||
"min_block_size": ("INT", {"default": 1, "min": 1, "max": 256, "step": 1, "tooltip": "Minimum number of pixels in each dimension when subdividing"}),
|
||||
"per_batch": ("INT", {"default": 6, "min": 1, "max": 256, "step": 1, "tooltip": "Number of samples per batch, in latent space (6 frames in 1 latent)"}),
|
||||
},
|
||||
"optional": {
|
||||
"unnormalize": ("BOOLEAN", {"default": True, "tooltip": "Unnormalize the latents before decoding"}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
@ -707,12 +714,13 @@ class MochiDecodeSpatialTiling:
|
||||
CATEGORY = "MochiWrapper"
|
||||
|
||||
def decode(self, vae, samples, enable_vae_tiling, num_tiles_w, num_tiles_h, overlap,
|
||||
min_block_size, per_batch):
|
||||
min_block_size, per_batch, unnormalize=True):
|
||||
device = mm.get_torch_device()
|
||||
offload_device = mm.unet_offload_device()
|
||||
intermediate_device = mm.intermediate_device()
|
||||
samples = samples["samples"]
|
||||
samples = dit_latents_to_vae_latents(samples)
|
||||
if unnormalize:
|
||||
samples = dit_latents_to_vae_latents(samples)
|
||||
samples = samples.to(vae.dtype).to(device)
|
||||
|
||||
B, C, T, H, W = samples.shape
|
||||
@ -768,14 +776,17 @@ class MochiImageEncode:
|
||||
"overlap": ("INT", {"default": 16, "min": 0, "max": 256, "step": 1, "tooltip": "Number of pixel of overlap between adjacent tiles"}),
|
||||
"min_block_size": ("INT", {"default": 1, "min": 1, "max": 256, "step": 1, "tooltip": "Minimum number of pixels in each dimension when subdividing"}),
|
||||
},
|
||||
"optional": {
|
||||
"normalize": ("BOOLEAN", {"default": True, "tooltip": "Normalize the images before encoding"}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
RETURN_NAMES = ("samples",)
|
||||
FUNCTION = "decode"
|
||||
FUNCTION = "encode"
|
||||
CATEGORY = "MochiWrapper"
|
||||
|
||||
def decode(self, encoder, images, enable_vae_tiling, num_tiles_w, num_tiles_h, overlap, min_block_size):
|
||||
def encode(self, encoder, images, enable_vae_tiling, num_tiles_w, num_tiles_h, overlap, min_block_size, normalize=True):
|
||||
device = mm.get_torch_device()
|
||||
offload_device = mm.unet_offload_device()
|
||||
intermediate_device = mm.intermediate_device()
|
||||
@ -800,7 +811,8 @@ class MochiImageEncode:
|
||||
latents = apply_tiled(encoder, video, num_tiles_w = num_tiles_w, num_tiles_h = num_tiles_h, overlap=overlap, min_block_size=min_block_size)
|
||||
else:
|
||||
latents = encoder(video)
|
||||
latents = vae_latents_to_dit_latents(latents)
|
||||
if normalize:
|
||||
latents = vae_latents_to_dit_latents(latents)
|
||||
print("encoder output",latents.shape)
|
||||
|
||||
return ({"samples": latents},)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user