Add alternative VAE decoding node
This was actually unused code in the VAE model, only does spatial tiling though, but seams look better
This commit is contained in:
parent
4c72322e8b
commit
36a4275b3b
@ -30,7 +30,7 @@ def fp8_linear_forward(cls, original_dtype, input):
|
|||||||
|
|
||||||
if isinstance(o, tuple):
|
if isinstance(o, tuple):
|
||||||
o = o[0]
|
o = o[0]
|
||||||
|
|
||||||
if tensor_2d:
|
if tensor_2d:
|
||||||
return o.reshape(input.shape[0], -1)
|
return o.reshape(input.shape[0], -1)
|
||||||
|
|
||||||
@ -38,7 +38,6 @@ def fp8_linear_forward(cls, original_dtype, input):
|
|||||||
else:
|
else:
|
||||||
cls.to(original_dtype)
|
cls.to(original_dtype)
|
||||||
out = cls.original_forward(input.to(original_dtype))
|
out = cls.original_forward(input.to(original_dtype))
|
||||||
cls.to(original_dtype)
|
|
||||||
return out
|
return out
|
||||||
else:
|
else:
|
||||||
return cls.original_forward(input)
|
return cls.original_forward(input)
|
||||||
|
|||||||
Binary file not shown.
52
nodes.py
52
nodes.py
@ -464,6 +464,56 @@ class MochiDecode:
|
|||||||
|
|
||||||
return (frames,)
|
return (frames,)
|
||||||
|
|
||||||
|
class MochiDecodeSpatialTiling:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": {
|
||||||
|
"vae": ("MOCHIVAE",),
|
||||||
|
"samples": ("LATENT", ),
|
||||||
|
"enable_vae_tiling": ("BOOLEAN", {"default": False, "tooltip": "Drastically reduces memory use but may introduce seams"}),
|
||||||
|
"num_tiles_w": ("INT", {"default": 4, "min": 2, "max": 64, "step": 2, "tooltip": "Number of horizontal tiles"}),
|
||||||
|
"num_tiles_h": ("INT", {"default": 4, "min": 2, "max": 64, "step": 2, "tooltip": "Number of vertical tiles"}),
|
||||||
|
"overlap": ("INT", {"default": 16, "min": 0, "max": 256, "step": 1, "tooltip": "Number of pixel of overlap between adjacent tiles"}),
|
||||||
|
"min_block_size": ("INT", {"default": 1, "min": 1, "max": 256, "step": 1, "tooltip": "Minimum number of pixels in each dimension when subdividing"}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
RETURN_NAMES = ("images",)
|
||||||
|
FUNCTION = "decode"
|
||||||
|
CATEGORY = "MochiWrapper"
|
||||||
|
|
||||||
|
def decode(self, vae, samples, enable_vae_tiling, num_tiles_w, num_tiles_h, overlap,
|
||||||
|
min_block_size):
|
||||||
|
device = mm.get_torch_device()
|
||||||
|
offload_device = mm.unet_offload_device()
|
||||||
|
intermediate_device = mm.intermediate_device()
|
||||||
|
samples = samples["samples"]
|
||||||
|
samples = samples.to(torch.bfloat16).to(device)
|
||||||
|
|
||||||
|
B, C, T, H, W = samples.shape
|
||||||
|
|
||||||
|
vae.to(device)
|
||||||
|
|
||||||
|
with torch.autocast(mm.get_autocast_device(device), dtype=torch.bfloat16):
|
||||||
|
if enable_vae_tiling:
|
||||||
|
from .mochi_preview.vae.model import apply_tiled
|
||||||
|
logging.warning("Decoding with tiling...")
|
||||||
|
frames = apply_tiled(vae, samples, num_tiles_w = num_tiles_w, num_tiles_h = num_tiles_h, overlap=overlap, min_block_size=min_block_size)
|
||||||
|
else:
|
||||||
|
logging.info("Decoding without tiling...")
|
||||||
|
frames = vae(samples)
|
||||||
|
|
||||||
|
vae.to(offload_device)
|
||||||
|
|
||||||
|
frames = frames.float()
|
||||||
|
frames = (frames + 1.0) / 2.0
|
||||||
|
frames.clamp_(0.0, 1.0)
|
||||||
|
|
||||||
|
frames = rearrange(frames, "b c t h w -> (t b) h w c").to(intermediate_device)
|
||||||
|
|
||||||
|
return (frames,)
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"DownloadAndLoadMochiModel": DownloadAndLoadMochiModel,
|
"DownloadAndLoadMochiModel": DownloadAndLoadMochiModel,
|
||||||
@ -472,6 +522,7 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"MochiTextEncode": MochiTextEncode,
|
"MochiTextEncode": MochiTextEncode,
|
||||||
"MochiModelLoader": MochiModelLoader,
|
"MochiModelLoader": MochiModelLoader,
|
||||||
"MochiVAELoader": MochiVAELoader,
|
"MochiVAELoader": MochiVAELoader,
|
||||||
|
"MochiDecodeSpatialTiling": MochiDecodeSpatialTiling
|
||||||
}
|
}
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
"DownloadAndLoadMochiModel": "(Down)load Mochi Model",
|
"DownloadAndLoadMochiModel": "(Down)load Mochi Model",
|
||||||
@ -480,4 +531,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"MochiTextEncode": "Mochi TextEncode",
|
"MochiTextEncode": "Mochi TextEncode",
|
||||||
"MochiModelLoader": "Mochi Model Loader",
|
"MochiModelLoader": "Mochi Model Loader",
|
||||||
"MochiVAELoader": "Mochi VAE Loader",
|
"MochiVAELoader": "Mochi VAE Loader",
|
||||||
|
"MochiDecodeSpatialTiling": "Mochi VAE Decode Spatial Tiling"
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user