From 36a4275b3b8121ccf65df534fbd53f313757377a Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Fri, 25 Oct 2024 15:30:20 +0300 Subject: [PATCH] Add alternative VAE decoding node This was actually unused code in the VAE model, only does spatial tiling though, but seams look better --- fp8_optimization.py | 3 +- .../vae/__pycache__/model.cpython-312.pyc | Bin 34153 -> 34181 bytes nodes.py | 52 ++++++++++++++++++ 3 files changed, 53 insertions(+), 2 deletions(-) diff --git a/fp8_optimization.py b/fp8_optimization.py index c06913b..3d9685a 100644 --- a/fp8_optimization.py +++ b/fp8_optimization.py @@ -30,7 +30,7 @@ def fp8_linear_forward(cls, original_dtype, input): if isinstance(o, tuple): o = o[0] - + if tensor_2d: return o.reshape(input.shape[0], -1) @@ -38,7 +38,6 @@ def fp8_linear_forward(cls, original_dtype, input): else: cls.to(original_dtype) out = cls.original_forward(input.to(original_dtype)) - cls.to(original_dtype) return out else: return cls.original_forward(input) diff --git a/mochi_preview/vae/__pycache__/model.cpython-312.pyc b/mochi_preview/vae/__pycache__/model.cpython-312.pyc index b4fb4a5d499b584fc0adf59719f5aa152918993b..b448eecb5d46f291c54576da3e56535b05d9296b 100644 GIT binary patch delta 76 zcmaFa#njr(#C@8Vmx}=ie)mal#80Hx&_&;S4c delta 48 zcmZqeW_sDh#C@8Vmx}=iKJZFxtAj4vokEz3+T-)zg~Q4at~ Cj1H9m diff --git a/nodes.py b/nodes.py index 2cd1b3d..64ad2b1 100644 --- a/nodes.py +++ b/nodes.py @@ -464,6 +464,56 @@ class MochiDecode: return (frames,) +class MochiDecodeSpatialTiling: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "vae": ("MOCHIVAE",), + "samples": ("LATENT", ), + "enable_vae_tiling": ("BOOLEAN", {"default": False, "tooltip": "Drastically reduces memory use but may introduce seams"}), + "num_tiles_w": ("INT", {"default": 4, "min": 2, "max": 64, "step": 2, "tooltip": "Number of horizontal tiles"}), + "num_tiles_h": ("INT", {"default": 4, "min": 2, "max": 64, "step": 2, "tooltip": "Number of vertical tiles"}), + "overlap": ("INT", {"default": 16, "min": 0, "max": 256, "step": 1, "tooltip": "Number of pixel of overlap between adjacent tiles"}), + "min_block_size": ("INT", {"default": 1, "min": 1, "max": 256, "step": 1, "tooltip": "Minimum number of pixels in each dimension when subdividing"}), + }, + } + + RETURN_TYPES = ("IMAGE",) + RETURN_NAMES = ("images",) + FUNCTION = "decode" + CATEGORY = "MochiWrapper" + + def decode(self, vae, samples, enable_vae_tiling, num_tiles_w, num_tiles_h, overlap, + min_block_size): + device = mm.get_torch_device() + offload_device = mm.unet_offload_device() + intermediate_device = mm.intermediate_device() + samples = samples["samples"] + samples = samples.to(torch.bfloat16).to(device) + + B, C, T, H, W = samples.shape + + vae.to(device) + + with torch.autocast(mm.get_autocast_device(device), dtype=torch.bfloat16): + if enable_vae_tiling: + from .mochi_preview.vae.model import apply_tiled + logging.warning("Decoding with tiling...") + frames = apply_tiled(vae, samples, num_tiles_w = num_tiles_w, num_tiles_h = num_tiles_h, overlap=overlap, min_block_size=min_block_size) + else: + logging.info("Decoding without tiling...") + frames = vae(samples) + + vae.to(offload_device) + + frames = frames.float() + frames = (frames + 1.0) / 2.0 + frames.clamp_(0.0, 1.0) + + frames = rearrange(frames, "b c t h w -> (t b) h w c").to(intermediate_device) + + return (frames,) + NODE_CLASS_MAPPINGS = { "DownloadAndLoadMochiModel": DownloadAndLoadMochiModel, @@ -472,6 +522,7 @@ NODE_CLASS_MAPPINGS = { "MochiTextEncode": MochiTextEncode, "MochiModelLoader": MochiModelLoader, "MochiVAELoader": MochiVAELoader, + "MochiDecodeSpatialTiling": MochiDecodeSpatialTiling } NODE_DISPLAY_NAME_MAPPINGS = { "DownloadAndLoadMochiModel": "(Down)load Mochi Model", @@ -480,4 +531,5 @@ NODE_DISPLAY_NAME_MAPPINGS = { "MochiTextEncode": "Mochi TextEncode", "MochiModelLoader": "Mochi Model Loader", "MochiVAELoader": "Mochi VAE Loader", + "MochiDecodeSpatialTiling": "Mochi VAE Decode Spatial Tiling" }