Add alternative VAE decoding node

This was actually unused code in the VAE model, only does spatial tiling though, but seams look better
This commit is contained in:
kijai 2024-10-25 15:30:20 +03:00
parent 4c72322e8b
commit 36a4275b3b
3 changed files with 53 additions and 2 deletions

View File

@ -30,7 +30,7 @@ def fp8_linear_forward(cls, original_dtype, input):
if isinstance(o, tuple):
o = o[0]
if tensor_2d:
return o.reshape(input.shape[0], -1)
@ -38,7 +38,6 @@ def fp8_linear_forward(cls, original_dtype, input):
else:
cls.to(original_dtype)
out = cls.original_forward(input.to(original_dtype))
cls.to(original_dtype)
return out
else:
return cls.original_forward(input)

View File

@ -464,6 +464,56 @@ class MochiDecode:
return (frames,)
class MochiDecodeSpatialTiling:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"vae": ("MOCHIVAE",),
"samples": ("LATENT", ),
"enable_vae_tiling": ("BOOLEAN", {"default": False, "tooltip": "Drastically reduces memory use but may introduce seams"}),
"num_tiles_w": ("INT", {"default": 4, "min": 2, "max": 64, "step": 2, "tooltip": "Number of horizontal tiles"}),
"num_tiles_h": ("INT", {"default": 4, "min": 2, "max": 64, "step": 2, "tooltip": "Number of vertical tiles"}),
"overlap": ("INT", {"default": 16, "min": 0, "max": 256, "step": 1, "tooltip": "Number of pixel of overlap between adjacent tiles"}),
"min_block_size": ("INT", {"default": 1, "min": 1, "max": 256, "step": 1, "tooltip": "Minimum number of pixels in each dimension when subdividing"}),
},
}
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("images",)
FUNCTION = "decode"
CATEGORY = "MochiWrapper"
def decode(self, vae, samples, enable_vae_tiling, num_tiles_w, num_tiles_h, overlap,
min_block_size):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
intermediate_device = mm.intermediate_device()
samples = samples["samples"]
samples = samples.to(torch.bfloat16).to(device)
B, C, T, H, W = samples.shape
vae.to(device)
with torch.autocast(mm.get_autocast_device(device), dtype=torch.bfloat16):
if enable_vae_tiling:
from .mochi_preview.vae.model import apply_tiled
logging.warning("Decoding with tiling...")
frames = apply_tiled(vae, samples, num_tiles_w = num_tiles_w, num_tiles_h = num_tiles_h, overlap=overlap, min_block_size=min_block_size)
else:
logging.info("Decoding without tiling...")
frames = vae(samples)
vae.to(offload_device)
frames = frames.float()
frames = (frames + 1.0) / 2.0
frames.clamp_(0.0, 1.0)
frames = rearrange(frames, "b c t h w -> (t b) h w c").to(intermediate_device)
return (frames,)
NODE_CLASS_MAPPINGS = {
"DownloadAndLoadMochiModel": DownloadAndLoadMochiModel,
@ -472,6 +522,7 @@ NODE_CLASS_MAPPINGS = {
"MochiTextEncode": MochiTextEncode,
"MochiModelLoader": MochiModelLoader,
"MochiVAELoader": MochiVAELoader,
"MochiDecodeSpatialTiling": MochiDecodeSpatialTiling
}
NODE_DISPLAY_NAME_MAPPINGS = {
"DownloadAndLoadMochiModel": "(Down)load Mochi Model",
@ -480,4 +531,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"MochiTextEncode": "Mochi TextEncode",
"MochiModelLoader": "Mochi Model Loader",
"MochiVAELoader": "Mochi VAE Loader",
"MochiDecodeSpatialTiling": "Mochi VAE Decode Spatial Tiling"
}