Add alternative VAE decoding node
This was actually unused code in the VAE model, only does spatial tiling though, but seams look better
This commit is contained in:
parent
4c72322e8b
commit
36a4275b3b
@ -38,7 +38,6 @@ def fp8_linear_forward(cls, original_dtype, input):
|
||||
else:
|
||||
cls.to(original_dtype)
|
||||
out = cls.original_forward(input.to(original_dtype))
|
||||
cls.to(original_dtype)
|
||||
return out
|
||||
else:
|
||||
return cls.original_forward(input)
|
||||
|
||||
Binary file not shown.
52
nodes.py
52
nodes.py
@ -464,6 +464,56 @@ class MochiDecode:
|
||||
|
||||
return (frames,)
|
||||
|
||||
class MochiDecodeSpatialTiling:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"vae": ("MOCHIVAE",),
|
||||
"samples": ("LATENT", ),
|
||||
"enable_vae_tiling": ("BOOLEAN", {"default": False, "tooltip": "Drastically reduces memory use but may introduce seams"}),
|
||||
"num_tiles_w": ("INT", {"default": 4, "min": 2, "max": 64, "step": 2, "tooltip": "Number of horizontal tiles"}),
|
||||
"num_tiles_h": ("INT", {"default": 4, "min": 2, "max": 64, "step": 2, "tooltip": "Number of vertical tiles"}),
|
||||
"overlap": ("INT", {"default": 16, "min": 0, "max": 256, "step": 1, "tooltip": "Number of pixel of overlap between adjacent tiles"}),
|
||||
"min_block_size": ("INT", {"default": 1, "min": 1, "max": 256, "step": 1, "tooltip": "Minimum number of pixels in each dimension when subdividing"}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
RETURN_NAMES = ("images",)
|
||||
FUNCTION = "decode"
|
||||
CATEGORY = "MochiWrapper"
|
||||
|
||||
def decode(self, vae, samples, enable_vae_tiling, num_tiles_w, num_tiles_h, overlap,
|
||||
min_block_size):
|
||||
device = mm.get_torch_device()
|
||||
offload_device = mm.unet_offload_device()
|
||||
intermediate_device = mm.intermediate_device()
|
||||
samples = samples["samples"]
|
||||
samples = samples.to(torch.bfloat16).to(device)
|
||||
|
||||
B, C, T, H, W = samples.shape
|
||||
|
||||
vae.to(device)
|
||||
|
||||
with torch.autocast(mm.get_autocast_device(device), dtype=torch.bfloat16):
|
||||
if enable_vae_tiling:
|
||||
from .mochi_preview.vae.model import apply_tiled
|
||||
logging.warning("Decoding with tiling...")
|
||||
frames = apply_tiled(vae, samples, num_tiles_w = num_tiles_w, num_tiles_h = num_tiles_h, overlap=overlap, min_block_size=min_block_size)
|
||||
else:
|
||||
logging.info("Decoding without tiling...")
|
||||
frames = vae(samples)
|
||||
|
||||
vae.to(offload_device)
|
||||
|
||||
frames = frames.float()
|
||||
frames = (frames + 1.0) / 2.0
|
||||
frames.clamp_(0.0, 1.0)
|
||||
|
||||
frames = rearrange(frames, "b c t h w -> (t b) h w c").to(intermediate_device)
|
||||
|
||||
return (frames,)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"DownloadAndLoadMochiModel": DownloadAndLoadMochiModel,
|
||||
@ -472,6 +522,7 @@ NODE_CLASS_MAPPINGS = {
|
||||
"MochiTextEncode": MochiTextEncode,
|
||||
"MochiModelLoader": MochiModelLoader,
|
||||
"MochiVAELoader": MochiVAELoader,
|
||||
"MochiDecodeSpatialTiling": MochiDecodeSpatialTiling
|
||||
}
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"DownloadAndLoadMochiModel": "(Down)load Mochi Model",
|
||||
@ -480,4 +531,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"MochiTextEncode": "Mochi TextEncode",
|
||||
"MochiModelLoader": "Mochi Model Loader",
|
||||
"MochiVAELoader": "Mochi VAE Loader",
|
||||
"MochiDecodeSpatialTiling": "Mochi VAE Decode Spatial Tiling"
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user