diff --git a/mochi_preview/vae/model.py b/mochi_preview/vae/model.py index 385eefd..9514df1 100644 --- a/mochi_preview/vae/model.py +++ b/mochi_preview/vae/model.py @@ -550,6 +550,7 @@ class Decoder(nn.Module): nonlinearity: str = "silu", output_nonlinearity: str = "silu", causal: bool = True, + dtype: torch.dtype = torch.float32, **block_kwargs, ): super().__init__() @@ -558,6 +559,7 @@ class Decoder(nn.Module): self.channel_multipliers = channel_multipliers self.num_res_blocks = num_res_blocks self.output_nonlinearity = output_nonlinearity + self.dtype = dtype assert nonlinearity == "silu" assert causal @@ -718,18 +720,27 @@ def blend_vertical(a: torch.Tensor, b: torch.Tensor, overlap: int) -> torch.Tens def nearest_multiple(x: int, multiple: int) -> int: return round(x / multiple) * multiple - +from tqdm import tqdm +from comfy.utils import ProgressBar def apply_tiled( fn: Callable[[torch.Tensor], torch.Tensor], x: torch.Tensor, num_tiles_w: int, num_tiles_h: int, - overlap: int = 0, # Number of pixel of overlap between adjacent tiles. - # Use a factor of 2 times the latent downsample factor. + overlap: int = 0, # Number of pixels of overlap between adjacent tiles. min_block_size: int = 1, # Minimum number of pixels in each dimension when subdividing. + pbar: Optional[tqdm] = None, + comfy_pbar: Optional[ProgressBar] = None, ): + if pbar is None: + total_tiles = num_tiles_w * num_tiles_h + pbar = tqdm(total=total_tiles) + comfy_pbar = ProgressBar(total_tiles) if num_tiles_w == 1 and num_tiles_h == 1: - return fn(x) + result = fn(x) + pbar.update(1) + comfy_pbar.update(1) + return result assert ( num_tiles_w & (num_tiles_w - 1) == 0 @@ -752,10 +763,10 @@ def apply_tiled( assert num_tiles_w % 2 == 0, f"num_tiles_w={num_tiles_w} must be even" left = apply_tiled( - fn, left, num_tiles_w // 2, num_tiles_h, overlap, min_block_size + fn, left, num_tiles_w // 2, num_tiles_h, overlap, min_block_size, pbar, comfy_pbar ) right = apply_tiled( - fn, right, num_tiles_w // 2, num_tiles_h, overlap, min_block_size + fn, right, num_tiles_w // 2, num_tiles_h, overlap, min_block_size, pbar, comfy_pbar ) if left is None or right is None: return None @@ -774,10 +785,10 @@ def apply_tiled( assert num_tiles_h % 2 == 0, f"num_tiles_h={num_tiles_h} must be even" top = apply_tiled( - fn, top, num_tiles_w, num_tiles_h // 2, overlap, min_block_size + fn, top, num_tiles_w, num_tiles_h // 2, overlap, min_block_size, pbar, comfy_pbar ) bottom = apply_tiled( - fn, bottom, num_tiles_w, num_tiles_h // 2, overlap, min_block_size + fn, bottom, num_tiles_w, num_tiles_h // 2, overlap, min_block_size, pbar, comfy_pbar ) if top is None or bottom is None: return None diff --git a/nodes.py b/nodes.py index b58fbab..fe9f3ad 100644 --- a/nodes.py +++ b/nodes.py @@ -240,6 +240,7 @@ class MochiVAELoader: }, "optional": { "torch_compile_args": ("MOCHICOMPILEARGS", {"tooltip": "Optional torch.compile arguments",}), + "precision": (["fp16", "fp32", "bf16"], {"default": "bf16"}), }, } @@ -248,12 +249,14 @@ class MochiVAELoader: FUNCTION = "loadmodel" CATEGORY = "MochiWrapper" - def loadmodel(self, model_name, torch_compile_args=None): + def loadmodel(self, model_name, torch_compile_args=None, precision="bf16"): device = mm.get_torch_device() offload_device = mm.unet_offload_device() mm.soft_empty_cache() + dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision] + vae_path = folder_paths.get_full_path_or_raise("vae", model_name) with (init_empty_weights() if is_accelerate_available else nullcontext()): @@ -271,25 +274,22 @@ class MochiVAELoader: nonlinearity="silu", output_nonlinearity="silu", causal=True, + dtype=dtype, ) vae_sd = load_torch_file(vae_path) if is_accelerate_available: - for key in vae_sd: - set_module_tensor_to_device(vae, key, dtype=torch.float32, device=offload_device, value=vae_sd[key]) + for name, param in vae.named_parameters(): + set_module_tensor_to_device(vae, name, dtype=dtype, device=offload_device, value=vae_sd[name]) else: vae.load_state_dict(vae_sd, strict=True) - vae.to(torch.bfloat16).to("cpu") + vae.to(dtype).to(offload_device) vae.eval() del vae_sd if torch_compile_args is not None: vae.to(device) - # for i, block in enumerate(vae.blocks): - # if "CausalUpsampleBlock" in str(type(block)): - # print("Compiling block", block) vae = torch.compile(vae, fullgraph=torch_compile_args["fullgraph"], mode=torch_compile_args["mode"], dynamic=False, backend=torch_compile_args["backend"]) - return (vae,) class MochiTextEncode: @@ -447,7 +447,7 @@ class MochiDecode: offload_device = mm.unet_offload_device() intermediate_device = mm.intermediate_device() samples = samples["samples"] - samples = samples.to(torch.bfloat16).to(device) + samples = samples.to(vae.dtype).to(device) B, C, T, H, W = samples.shape @@ -574,21 +574,27 @@ class MochiDecodeSpatialTiling: offload_device = mm.unet_offload_device() intermediate_device = mm.intermediate_device() samples = samples["samples"] - samples = samples.to(torch.bfloat16).to(device) + samples = samples.to(vae.dtype).to(device) B, C, T, H, W = samples.shape + vae.to(device) decoded_list = [] with torch.autocast(mm.get_autocast_device(device), dtype=torch.bfloat16): if enable_vae_tiling: from .mochi_preview.vae.model import apply_tiled - logging.warning("Decoding with tiling...") + + pbar = ProgressBar(T // per_batch) for i in range(0, T, per_batch): + if i >= T: + break end_index = min(i + per_batch, T) - chunk = samples[:, :, i:end_index, :, :] + logging.info(f"Decoding {end_index - i} samples with tiling...") + chunk = samples[:, :, i:end_index, :, :] frames = apply_tiled(vae, chunk, num_tiles_w = num_tiles_w, num_tiles_h = num_tiles_h, overlap=overlap, min_block_size=min_block_size) - print(frames.shape) + logging.info(f"Decoded {frames.shape[2]} frames from {end_index - i} samples") + pbar.update(1) # Blend the first and last frames of each pair if len(decoded_list) > 0: previous_frames = decoded_list[-1]