diff --git a/mochi_preview/t2v_synth_mochi.py b/mochi_preview/t2v_synth_mochi.py
index 45805a7..de6ea8d 100644
--- a/mochi_preview/t2v_synth_mochi.py
+++ b/mochi_preview/t2v_synth_mochi.py
@@ -1,17 +1,14 @@
 import json
-import random
 from typing import Dict, List
 
-import numpy as np
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.data
-from torch import nn
 
 from .dit.joint_model.context_parallel import get_cp_rank_size
 from tqdm import tqdm
 from comfy.utils import ProgressBar, load_torch_file
+import comfy.model_management as mm
 
 from contextlib import nullcontext
 try:
@@ -86,29 +83,6 @@ def compute_packed_indices(
         "valid_token_indices_kv": valid_token_indices,
     }
 
-
-def shift_sigma(
-    sigma: np.ndarray,
-    shift: float,
-):
-    """Shift noise standard deviation toward higher values.
-
-    Useful for training a model at high resolutions,
-    or sampling more finely at high noise levels.
-
-    Equivalent to:
-        sigma_shift = shift / (shift + 1 / sigma - 1)
-    except for sigma = 0.
-
-    Args:
-        sigma: noise standard deviation in [0, 1]
-        shift: shift factor >= 1.
-            For shift > 1, shifts sigma to higher values.
-            For shift = 1, identity function.
-    """
-    return shift * sigma / (shift * sigma + 1 - sigma)
-
-
 class T2VSynthMochiModel:
     def __init__(
         self,
@@ -239,23 +213,16 @@ class T2VSynthMochiModel:
 
     @torch.inference_mode(mode=True)
     def run(self, args, stream_results):
-        random.seed(args["seed"])
-        np.random.seed(args["seed"])
         torch.manual_seed(args["seed"])
+        torch.cuda.manual_seed(args["seed"])
         generator = torch.Generator(device=self.device)
         generator.manual_seed(args["seed"])
 
-        # assert (
-        #     len(args["prompt"]) == 1
-        # ), f"Expected exactly one prompt, got {len(args['prompt'])}"
-        #prompt = args["prompt"][0]
-        #neg_prompt = args["negative_prompt"][0] if len(args["negative_prompt"]) else ""
-        B = 1
-
-        w = args["width"]
-        h = args["height"]
-        t = args["num_frames"]
+        num_frames = args["num_frames"]
+        height = args["height"]
+        width = args["width"]
+
         batch_cfg = args["mochi_args"]["batch_cfg"]
         sample_steps = args["mochi_args"]["num_inference_steps"]
         cfg_schedule = args["mochi_args"].get("cfg_schedule")
 
@@ -267,7 +234,7 @@ class T2VSynthMochiModel:
         assert (
             len(sigma_schedule) == sample_steps + 1
         ), f"sigma_schedule must have length {sample_steps + 1}, got {len(sigma_schedule)}"
-        assert (t - 1) % 6 == 0, f"t - 1 must be divisible by 6, got {t - 1}"
+        assert (num_frames - 1) % 6 == 0, f"num_frames - 1 must be divisible by 6, got {num_frames - 1}"
 
         # if batch_cfg:
         #     sample_batched = self.get_conditioning(
@@ -277,15 +244,19 @@ class T2VSynthMochiModel:
         # sample = self.get_conditioning([prompt], zero_last_n_prompts=0)
         # sample_null = self.get_conditioning([neg_prompt] * B, zero_last_n_prompts=B if neg_prompt == "" else 0)
+        # create z
         spatial_downsample = 8
         temporal_downsample = 6
-        latent_t = (t - 1) // temporal_downsample + 1
-        latent_w, latent_h = w // spatial_downsample, h // spatial_downsample
-
-        latent_dims = dict(lT=latent_t, lW=latent_w, lH=latent_h)
         in_channels = 12
+        B = 1
+        C = in_channels
+        T = (num_frames - 1) // temporal_downsample + 1
+        H = height // spatial_downsample
+        W = width // spatial_downsample
+        latent_dims = dict(lT=T, lW=W, lH=H)
+
         z = torch.randn(
-            (B, in_channels, latent_t, latent_h, latent_w),
+            (B, C, T, H, W),
             device=self.device,
             generator=generator,
             dtype=torch.float32,
         )
@@ -307,22 +278,6 @@ class T2VSynthMochiModel:
             "y_feat": [args["negative_embeds"]["embeds"].to(self.device)]
         }
 
-        # print(sample["y_mask"])
-        # print(type(sample["y_mask"]))
-        # print(sample["y_mask"][0].shape)
-
-        # print(sample["y_feat"])
-        # print(type(sample["y_feat"]))
-        # print(sample["y_feat"][0].shape)
-
-        # print(sample_null["y_mask"])
-        # print(type(sample_null["y_mask"]))
-        # print(sample_null["y_mask"][0].shape)
-
-        # print(sample_null["y_feat"])
-        # print(type(sample_null["y_feat"]))
-        # print(sample_null["y_feat"][0].shape)
-
         sample["packed_indices"] = self.get_packed_indices(
             sample["y_mask"], **latent_dims
         )
         sample_null["packed_indices"] = self.get_packed_indices(
             sample_null["y_mask"], **latent_dims
         )
@@ -331,8 +286,6 @@ class T2VSynthMochiModel:
 
         def model_fn(*, z, sigma, cfg_scale):
-            #print("z", z.dtype, z.device)
-            #print("sigma", sigma.dtype, sigma.device)
             self.dit.to(self.device)
             # if batch_cfg:
             #     with torch.autocast("cuda", dtype=torch.bfloat16):
             #         out = self.dit(z, sigma, **sample_batched)
             #else:
             nonlocal sample, sample_null
-            with torch.autocast("cuda", dtype=torch.bfloat16):
+            with torch.autocast(mm.get_autocast_device(self.device), dtype=torch.bfloat16):
                 out_cond = self.dit(z, sigma, **sample)
                 out_uncond = self.dit(z, sigma, **sample_null)
                 assert out_cond.shape == out_uncond.shape
@@ -364,8 +317,6 @@ class T2VSynthMochiModel:
 
             pred = pred.to(z)
             output_cond = output_cond.to(z)
-            #if stream_results:
-            #    yield i / sample_steps, None, False
             z = z + dsigma * pred
             comfy_pbar.update(1)
diff --git a/nodes.py b/nodes.py
index af6e74f..43aa1b2 100644
--- a/nodes.py
+++ b/nodes.py
@@ -248,7 +248,7 @@ class MochiDecode:
                 "samples": ("LATENT", ),
                 "enable_vae_tiling": ("BOOLEAN", {"default": False, "tooltip": "Drastically reduces memory use but may introduce seams"}),
                 "auto_tile_size": ("BOOLEAN", {"default": True, "tooltip": "Auto size based on height and width, default is half the size"}),
-                "frame_batch_size": ("INT", {"default": 6, "min": 1, "max": 64, "step": 1}),
+                "frame_batch_size": ("INT", {"default": 6, "min": 1, "max": 64, "step": 1, "tooltip": "Number of frames in latent space (downscale factor is 6) to decode at once"}),
                 "tile_sample_min_height": ("INT", {"default": 240, "min": 16, "max": 2048, "step": 8, "tooltip": "Minimum tile height, default is half the height"}),
                 "tile_sample_min_width": ("INT", {"default": 424, "min": 16, "max": 2048, "step": 8, "tooltip": "Minimum tile width, default is half the width"}),
                 "tile_overlap_factor_height": ("FLOAT", {"default": 0.1666, "min": 0.0, "max": 1.0, "step": 0.001}),
@@ -268,6 +268,17 @@ class MochiDecode:
 
         samples = samples["samples"]
         samples = samples.to(torch.bfloat16).to(device)
 
+        B, C, T, H, W = samples.shape
+
+        self.tile_overlap_factor_height = tile_overlap_factor_height if not auto_tile_size else 1 / 6
+        self.tile_overlap_factor_width = tile_overlap_factor_width if not auto_tile_size else 1 / 5
+
+        self.tile_sample_min_height = tile_sample_min_height if not auto_tile_size else H // 2 * 8
+        self.tile_sample_min_width = tile_sample_min_width if not auto_tile_size else W // 2 * 8
+
+        self.tile_latent_min_height = int(self.tile_sample_min_height / 8)
+        self.tile_latent_min_width = int(self.tile_sample_min_width / 8)
+
         def blend_v(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
             blend_extent = min(a.shape[3], b.shape[3], blend_extent)
@@ -284,70 +295,68 @@ class MochiDecode:
                 x / blend_extent
             )
             return b
+
+        def decode_tiled(samples):
+            overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height))
+            overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width))
+            blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height)
+            blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
+            row_limit_height = self.tile_sample_min_height - blend_extent_height
+            row_limit_width = self.tile_sample_min_width - blend_extent_width
 
-        self.tile_overlap_factor_height = tile_overlap_factor_height if not auto_tile_size else 1 / 6
-        self.tile_overlap_factor_width = tile_overlap_factor_width if not auto_tile_size else 1 / 5
+            # Split z into overlapping tiles and decode them separately.
+            # The tiles have an overlap to avoid seams between tiles.
+            comfy_pbar = ProgressBar(len(range(0, H, overlap_height)))
+            rows = []
+            for i in tqdm(range(0, H, overlap_height), desc="Processing rows"):
+                row = []
+                for j in tqdm(range(0, W, overlap_width), desc="Processing columns", leave=False):
+                    time = []
+                    for k in tqdm(range(T // frame_batch_size), desc="Processing frames", leave=False):
+                        remaining_frames = T % frame_batch_size
+                        start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
+                        end_frame = frame_batch_size * (k + 1) + remaining_frames
+                        tile = samples[
+                            :,
+                            :,
+                            start_frame:end_frame,
+                            i : i + self.tile_latent_min_height,
+                            j : j + self.tile_latent_min_width,
+                        ]
+                        tile = vae(tile)
+                        time.append(tile)
+                    row.append(torch.cat(time, dim=2))
+                rows.append(row)
+                comfy_pbar.update(1)
 
-        self.tile_sample_min_height = tile_sample_min_height if not auto_tile_size else samples.shape[3] // 2 * 8
-        self.tile_sample_min_width = tile_sample_min_width if not auto_tile_size else samples.shape[4] // 2 * 8
+            result_rows = []
+            for i, row in enumerate(tqdm(rows, desc="Blending rows")):
+                result_row = []
+                for j, tile in enumerate(tqdm(row, desc="Blending tiles", leave=False)):
+                    # blend the above tile and the left tile
+                    # to the current tile and add the current tile to the result row
+                    if i > 0:
+                        tile = blend_v(rows[i - 1][j], tile, blend_extent_height)
+                    if j > 0:
+                        tile = blend_h(row[j - 1], tile, blend_extent_width)
+                    result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
+                result_rows.append(torch.cat(result_row, dim=4))
 
-        self.tile_latent_min_height = int(self.tile_sample_min_height / 8)
-        self.tile_latent_min_width = int(self.tile_sample_min_width / 8)
+            return torch.cat(result_rows, dim=3)
 
         vae.to(device)
-        with torch.amp.autocast("cuda", dtype=torch.bfloat16):
-            if not enable_vae_tiling:
+        with torch.autocast(mm.get_autocast_device(device), dtype=torch.bfloat16):
+            if enable_vae_tiling and frame_batch_size > T:
+                logging.warning(f"Frame batch size is larger than the number of samples ({T}), disabling tiling")
+                samples = vae(samples)
+            elif not enable_vae_tiling:
+                logging.warning("Attempting to decode without tiling, very memory intensive")
                 samples = vae(samples)
             else:
-                batch_size, num_channels, num_frames, height, width = samples.shape
-                overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height))
-                overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width))
-                blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height)
-                blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
-                row_limit_height = self.tile_sample_min_height - blend_extent_height
-                row_limit_width = self.tile_sample_min_width - blend_extent_width
-
-                # Split z into overlapping tiles and decode them separately.
-                # The tiles have an overlap to avoid seams between tiles.
-                comfy_pbar = ProgressBar(len(range(0, height, overlap_height)))
-                rows = []
-                for i in tqdm(range(0, height, overlap_height), desc="Processing rows"):
-                    row = []
-                    for j in tqdm(range(0, width, overlap_width), desc="Processing columns", leave=False):
-                        time = []
-                        for k in tqdm(range(num_frames // frame_batch_size), desc="Processing frames", leave=False):
-                            remaining_frames = num_frames % frame_batch_size
-                            start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
-                            end_frame = frame_batch_size * (k + 1) + remaining_frames
-                            tile = samples[
-                                :,
-                                :,
-                                start_frame:end_frame,
-                                i : i + self.tile_latent_min_height,
-                                j : j + self.tile_latent_min_width,
-                            ]
-                            tile = vae(tile)
-                            time.append(tile)
-                        row.append(torch.cat(time, dim=2))
-                    rows.append(row)
-                    comfy_pbar.update(1)
-
-                result_rows = []
-                for i, row in enumerate(tqdm(rows, desc="Blending rows")):
-                    result_row = []
-                    for j, tile in enumerate(tqdm(row, desc="Blending tiles", leave=False)):
-                        # blend the above tile and the left tile
-                        # to the current tile and add the current tile to the result row
-                        if i > 0:
-                            tile = blend_v(rows[i - 1][j], tile, blend_extent_height)
-                        if j > 0:
-                            tile = blend_h(row[j - 1], tile, blend_extent_width)
-                        result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
-                    result_rows.append(torch.cat(result_row, dim=4))
-
-                samples = torch.cat(result_rows, dim=3)
+                logging.info("Decoding with tiling")
+                samples = decode_tiled(samples)
+
         vae.to(offload_device)
-        #print("samples", samples.shape, samples.dtype, samples.device)
         samples = samples.float()
         samples = (samples + 1.0) / 2.0
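A few notes on the arithmetic this patch relies on. The new B/C/T/H/W bookkeeping in run() encodes the VAE's fixed geometry: 8x spatial and 6x temporal downsampling, plus one extra latent frame for the first pixel frame, which is why (num_frames - 1) must be divisible by 6. A standalone sanity check of that shape math, with made-up input sizes:

    # Latent geometry from run(); the 163x480x848 request is illustrative only.
    num_frames, height, width = 163, 480, 848
    assert (num_frames - 1) % 6 == 0
    T = (num_frames - 1) // 6 + 1   # 28 latent frames
    H = height // 8                 # 60 latent rows
    W = width // 8                  # 106 latent columns
    print(T, H, W)                  # -> 28 60 106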
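model_fn runs the DiT twice per step (conditional and unconditional) under the device-aware autocast, and the sampling loop advances the latent with z = z + dsigma * pred. The guidance combination itself is not visible in these hunks; assuming the standard classifier-free guidance formula, the loop reduces to a sketch like the following, where toy tensors stand in for the DiT outputs and the sign convention for dsigma is assumed:

    import torch

    def cfg_combine(out_cond, out_uncond, cfg_scale):
        # standard classifier-free guidance; assumed, not shown in the hunks above
        return out_uncond + cfg_scale * (out_cond - out_uncond)

    z = torch.randn(1, 12, 28, 60, 106)            # (B, C, T, H, W) latent
    sigma_schedule = torch.linspace(1.0, 0.0, 11)  # toy schedule, length sample_steps + 1
    for i in range(len(sigma_schedule) - 1):
        dsigma = sigma_schedule[i + 1] - sigma_schedule[i]  # step size; sign convention assumed
        # random stand-ins for the two DiT forward passes
        pred = cfg_combine(torch.randn_like(z), torch.randn_like(z), cfg_scale=4.5)
        z = z + dsigma * pred                      # Euler step, as in the patched loop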
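On the decoder side, auto_tile_size now derives tile sizes from the latent shape itself (H // 2 * 8, W // 2 * 8), and decode_tiled converts the overlap factors into a latent-space stride, a pixel-space cross-fade width, and a crop limit. Worked through with the auto defaults for a 480-px-tall video; the numbers below are computed for illustration, not taken from the patch:

    # Tiling arithmetic from decode_tiled, with illustrative sizes.
    H = 60                                    # latent rows for a 480-px frame
    tile_sample_min_height = H // 2 * 8       # 240 px decoded per tile
    tile_latent_min_height = int(240 / 8)     # 30 latent rows per tile
    overlap_height = int(30 * (1 - 1 / 6))    # tiles start every 25 latent rows
    blend_extent_height = int(240 * (1 / 6))  # cross-fade over 40 px
    row_limit_height = 240 - 40               # keep at most 200 px per tile row
    # Row bands start at latent rows 0, 25, 50; blend_v cross-fades each
    # decoded band into the previous one over the 40-px overlap, then the
    # band is cropped to row_limit_height before concatenation.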
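The innermost loop decodes the T latent frames in batches of frame_batch_size, folding any remainder into the first batch so that the later batches stay aligned:

    # Frame batching from decode_tiled; T and frame_batch_size are illustrative.
    T, frame_batch_size = 28, 6
    remaining_frames = T % frame_batch_size   # 4 leftover frames
    for k in range(T // frame_batch_size):
        start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
        end_frame = frame_batch_size * (k + 1) + remaining_frames
        print(start_frame, end_frame)         # 0 10, 10 16, 16 22, 22 28

This is also why the new guard falls back to un-tiled decoding when frame_batch_size > T: the range above would be empty and no frames would be decoded at all.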