From a5b06b02ad89f1e854ca5eced8038b829dbb88bc Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Fri, 1 Nov 2024 18:03:47 +0200
Subject: [PATCH] tiled encoding

---
 mochi_preview/t2v_synth_mochi.py |  1 -
 mochi_preview/vae/model.py       |  5 +++++
 nodes.py                         | 33 ++++++++++++++++++++++----------
 3 files changed, 28 insertions(+), 11 deletions(-)

diff --git a/mochi_preview/t2v_synth_mochi.py b/mochi_preview/t2v_synth_mochi.py
index ec2f0c8..8ea61bb 100644
--- a/mochi_preview/t2v_synth_mochi.py
+++ b/mochi_preview/t2v_synth_mochi.py
@@ -323,6 +323,5 @@ class T2VSynthMochiModel:
                 comfy_pbar.update(1)
 
         self.dit.to(self.offload_device)
-        #samples = unnormalize_latents(z.float(), self.vae_mean, self.vae_std)
         logging.info(f"samples shape: {z.shape}")
         return z
diff --git a/mochi_preview/vae/model.py b/mochi_preview/vae/model.py
index 4085623..6cfeeae 100644
--- a/mochi_preview/vae/model.py
+++ b/mochi_preview/vae/model.py
@@ -899,6 +899,11 @@ class Encoder(nn.Module):
         assert logvar.shape == means.shape
         assert means.size(1) == self.latent_dim
 
+        noise = torch.randn(means.shape, device=means.device, dtype=means.dtype, generator=None)
+
+        # Just Gaussian sample with no scaling of variance.
+        return noise * torch.exp(logvar * 0.5) + means
+
         return LatentDistribution(means, logvar)
diff --git a/nodes.py b/nodes.py
index 9104fb2..c5d9eaa 100644
--- a/nodes.py
+++ b/nodes.py
@@ -43,7 +43,8 @@ def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None):
     sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0]
     sigma_schedule = [1.0 - x for x in sigma_schedule]
     return sigma_schedule
-
+
+#region ModelLoading
 class DownloadAndLoadMochiModel:
     @classmethod
     def INPUT_TYPES(s):
@@ -358,7 +359,8 @@ class MochiVAEEncoderLoader:
             encoder = torch.compile(encoder, fullgraph=torch_compile_args["fullgraph"], mode=torch_compile_args["mode"], dynamic=False, backend=torch_compile_args["backend"])
 
         return (encoder,)
-
+#endregion
+
 class MochiTextEncode:
     @classmethod
     def INPUT_TYPES(s):
@@ -412,7 +414,7 @@
         }
 
         return (t5_embeds, clip,)
-
+#region Sampler
 class MochiSampler:
     @classmethod
     def INPUT_TYPES(s):
@@ -427,7 +429,6 @@
                 "steps": ("INT", {"default": 50, "min": 2}),
                 "cfg": ("FLOAT", {"default": 4.5, "min": 0.0, "max": 30.0, "step": 0.01}),
                 "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
-                #"batch_cfg": ("BOOLEAN", {"default": False, "tooltip": "Enable batched cfg"}),
             },
             "optional": {
                 "cfg_schedule": ("FLOAT", {"forceInput": True, "tooltip": "Override cfg schedule with a list of ints"}),
@@ -489,7 +490,9 @@
         mm.soft_empty_cache()
 
         return ({"samples": latents},)
-
+#endregion
+#region Latents
+
 class MochiDecode:
     @classmethod
     def INPUT_TYPES(s):
@@ -688,12 +691,18 @@ class MochiDecodeSpatialTiling:
         return (frames,)
 
+
 class MochiImageEncode:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": {
             "encoder": ("MOCHIVAE",),
             "images": ("IMAGE", ),
+            "enable_vae_tiling": ("BOOLEAN", {"default": False, "tooltip": "Drastically reduces memory use but may introduce seams"}),
+            "num_tiles_w": ("INT", {"default": 4, "min": 2, "max": 64, "step": 2, "tooltip": "Number of horizontal tiles"}),
+            "num_tiles_h": ("INT", {"default": 4, "min": 2, "max": 64, "step": 2, "tooltip": "Number of vertical tiles"}),
+            "overlap": ("INT", {"default": 16, "min": 0, "max": 256, "step": 1, "tooltip": "Number of pixels of overlap between adjacent tiles"}),
+            "min_block_size": ("INT", {"default": 1, "min": 1, "max": 256, "step": 1, "tooltip": "Minimum number of pixels in each dimension when subdividing"}),
"step": 1, "tooltip": "Minimum number of pixels in each dimension when subdividing"}), }, } @@ -702,23 +711,25 @@ class MochiImageEncode: FUNCTION = "decode" CATEGORY = "MochiWrapper" - def decode(self, encoder, images): + def decode(self, encoder, images, enable_vae_tiling, num_tiles_w, num_tiles_h, overlap, min_block_size): device = mm.get_torch_device() offload_device = mm.unet_offload_device() intermediate_device = mm.intermediate_device() - + from .mochi_preview.vae.model import apply_tiled B, H, W, C = images.shape images = images.unsqueeze(0) * 2 - 1 images = rearrange(images, "t b h w c -> t c b h w") images = images.to(encoder.dtype).to(device) print(images.shape) - encoder.to(device) print("images before encoding", images.shape) with torch.autocast(mm.get_autocast_device(device), dtype=encoder.dtype): - video = add_fourier_features(images) - latents = encoder(video).sample() + video = add_fourier_features(images) + if enable_vae_tiling: + latents = apply_tiled(encoder, video, num_tiles_w = num_tiles_w, num_tiles_h = num_tiles_h, overlap=overlap, min_block_size=min_block_size) + else: + latents = encoder(video) latents = vae_latents_to_dit_latents(latents) print("encoder output",latents.shape) @@ -784,6 +795,8 @@ class MochiLatentPreview: return (latent_images.float().cpu(),) +#endregion +#region NodeMappings NODE_CLASS_MAPPINGS = { "DownloadAndLoadMochiModel": DownloadAndLoadMochiModel, "MochiSampler": MochiSampler,