diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..2e96bd6 --- /dev/null +++ b/__init__.py @@ -0,0 +1,3 @@ +from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS + +__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"] \ No newline at end of file diff --git a/__pycache__/__init__.cpython-312.pyc b/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000..f6ee936 Binary files /dev/null and b/__pycache__/__init__.cpython-312.pyc differ diff --git a/__pycache__/nodes.cpython-312.pyc b/__pycache__/nodes.cpython-312.pyc new file mode 100644 index 0000000..fba7b54 Binary files /dev/null and b/__pycache__/nodes.cpython-312.pyc differ diff --git a/configs/vae_stats.json b/configs/vae_stats.json new file mode 100644 index 0000000..e3278af --- /dev/null +++ b/configs/vae_stats.json @@ -0,0 +1,4 @@ +{ + "mean": [-0.06730895953510081, -0.038011381506090416, -0.07477820912866141, -0.05565264470995561, 0.012767231469026969, -0.04703542746246419, 0.043896967884726704, -0.09346305707025976, -0.09918314763016893, -0.008729793427399178, -0.011931556316503654, -0.0321993391887285], + "std": [0.9263795028493863, 0.9248894543193766, 0.9393059390890617, 0.959253732819592, 0.8244560132752793, 0.917259975397747, 0.9294154431013696, 1.3720942357788521, 0.881393668867029, 0.9168315692124348, 0.9185249279345552, 0.9274757570805041] +} diff --git a/infer.py b/infer.py new file mode 100644 index 0000000..da83061 --- /dev/null +++ b/infer.py @@ -0,0 +1,213 @@ +import json +import os +import tempfile +import time + +import click +import numpy as np +#import ray +from einops import rearrange +from PIL import Image +from tqdm import tqdm + +from mochi_preview.t2v_synth_mochi import T2VSynthMochiModel + +model = None +model_path = "weights" +def noexcept(f): + try: + return f() + except: + pass +# class MochiWrapper: +# def __init__(self, *, num_workers, **actor_kwargs): +# super().__init__() +# RemoteClass = ray.remote(T2VSynthMochiModel) +# self.workers = [ +# RemoteClass.options(num_gpus=1).remote( +# device_id=0, world_size=num_workers, local_rank=i, **actor_kwargs +# ) +# for i in range(num_workers) +# ] +# # Ensure the __init__ method has finished on all workers +# for worker in self.workers: +# ray.get(worker.__ray_ready__.remote()) +# self.is_loaded = True + +# def __call__(self, args): +# work_refs = [ +# worker.run.remote(args, i == 0) for i, worker in enumerate(self.workers) +# ] + +# try: +# for result in work_refs[0]: +# yield ray.get(result) + +# # Handle the (very unlikely) edge-case where a worker that's not the 1st one +# # fails (don't want an uncaught error) +# for result in work_refs[1:]: +# ray.get(result) +# except Exception as e: +# # Get exception from other workers +# for ref in work_refs[1:]: +# noexcept(lambda: ray.get(ref)) +# raise e + +def set_model_path(path): + global model_path + model_path = path + + +def load_model(): + global model, model_path + if model is None: + #ray.init() + MOCHI_DIR = model_path + VAE_CHECKPOINT_PATH = f"{MOCHI_DIR}/mochi_preview_vae_bf16.safetensors" + MODEL_CONFIG_PATH = f"{MOCHI_DIR}/dit-config.yaml" + MODEL_CHECKPOINT_PATH = f"{MOCHI_DIR}/mochi_preview_dit_fp8_e4m3fn.safetensors" + + model = T2VSynthMochiModel( + device_id=0, + world_size=1, + local_rank=0, + vae_stats_path=f"{MOCHI_DIR}/vae_stats.json", + vae_checkpoint_path=VAE_CHECKPOINT_PATH, + dit_config_path=MODEL_CONFIG_PATH, + dit_checkpoint_path=MODEL_CHECKPOINT_PATH, + ) + +def linear_quadratic_schedule(num_steps, threshold_noise, 
linear_steps=None): + if linear_steps is None: + linear_steps = num_steps // 2 + linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)] + threshold_noise_step_diff = linear_steps - threshold_noise * num_steps + quadratic_steps = num_steps - linear_steps + quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps ** 2) + linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps ** 2) + const = quadratic_coef * (linear_steps ** 2) + quadratic_sigma_schedule = [ + quadratic_coef * (i ** 2) + linear_coef * i + const + for i in range(linear_steps, num_steps) + ] + sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0] + sigma_schedule = [1.0 - x for x in sigma_schedule] + return sigma_schedule + +def generate_video( + prompt, + negative_prompt, + width, + height, + num_frames, + seed, + cfg_scale, + num_inference_steps, +): + load_model() + + # sigma_schedule should be a list of floats of length (num_inference_steps + 1), + # such that sigma_schedule[0] == 1.0 and sigma_schedule[-1] == 0.0 and monotonically decreasing. + sigma_schedule = linear_quadratic_schedule(num_inference_steps, 0.025) + + # cfg_schedule should be a list of floats of length num_inference_steps. + # For simplicity, we just use the same cfg scale at all timesteps, + # but more optimal schedules may use varying cfg, e.g: + # [5.0] * (num_inference_steps // 2) + [4.5] * (num_inference_steps // 2) + cfg_schedule = [cfg_scale] * num_inference_steps + + args = { + "height": height, + "width": width, + "num_frames": num_frames, + "mochi_args": { + "sigma_schedule": sigma_schedule, + "cfg_schedule": cfg_schedule, + "num_inference_steps": num_inference_steps, + "batch_cfg": False, + }, + "prompt": [prompt], + "negative_prompt": [negative_prompt], + "seed": seed, + } + + final_frames = None + for cur_progress, frames, finished in tqdm(model.run(args, stream_results=True), total=num_inference_steps + 1): + final_frames = frames + + assert isinstance(final_frames, np.ndarray) + assert final_frames.dtype == np.float32 + + final_frames = rearrange(final_frames, "t b h w c -> b t h w c") + final_frames = final_frames[0] + + os.makedirs("outputs", exist_ok=True) + output_path = os.path.join("outputs", f"output_{int(time.time())}.mp4") + + with tempfile.TemporaryDirectory() as tmpdir: + frame_paths = [] + for i, frame in enumerate(final_frames): + frame = (frame * 255).astype(np.uint8) + frame_img = Image.fromarray(frame) + frame_path = os.path.join(tmpdir, f"frame_{i:04d}.png") + frame_img.save(frame_path) + frame_paths.append(frame_path) + + frame_pattern = os.path.join(tmpdir, "frame_%04d.png") + ffmpeg_cmd = f"ffmpeg -y -r 30 -i {frame_pattern} -vcodec libx264 -pix_fmt yuv420p {output_path}" + os.system(ffmpeg_cmd) + + json_path = os.path.splitext(output_path)[0] + ".json" + with open(json_path, "w") as f: + json.dump(args, f, indent=4) + + return output_path + + +@click.command() +@click.option("--prompt", default=""" + a high-motion drone POV flying at high speed through a vast desert environment, with dynamic camera movements capturing sweeping sand dunes, + rocky terrain, and the occasional dry brush. The camera smoothly glides over the rugged landscape, weaving between towering rock formations and + diving low across the sand. As the drone zooms forward, the motion gradually slows down, shifting into a close-up, hyper-detailed shot of a spider + resting on a sunlit rock. 
The scene emphasizes cinematic motion, natural lighting, and intricate texture details on both the rock and the spider’s body, + with a shallow depth of field to focus on the fine details of the spider’s legs and the rough surface beneath it. The atmosphere should feel immersive and alive, + with the wind subtly blowing sand grains across the frame.""" + , required=False, help="Prompt for video generation.") +@click.option( + "--negative_prompt", default="", help="Negative prompt for video generation." +) +@click.option("--width", default=848, type=int, help="Width of the video.") +@click.option("--height", default=480, type=int, help="Height of the video.") +@click.option("--num_frames", default=163, type=int, help="Number of frames.") +@click.option("--seed", default=12345, type=int, help="Random seed.") +@click.option("--cfg_scale", default=4.5, type=float, help="CFG Scale.") +@click.option( + "--num_steps", default=64, type=int, help="Number of inference steps." +) +@click.option("--model_dir", required=True, help="Path to the model directory.") +def generate_cli( + prompt, + negative_prompt, + width, + height, + num_frames, + seed, + cfg_scale, + num_steps, + model_dir, +): + set_model_path(model_dir) + output = generate_video( + prompt, + negative_prompt, + width, + height, + num_frames, + seed, + cfg_scale, + num_steps, + ) + click.echo(f"Video generated at: {output}") + +if __name__ == "__main__": + generate_cli() diff --git a/mochi_preview/__init__.py b/mochi_preview/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mochi_preview/__pycache__/__init__.cpython-311.pyc b/mochi_preview/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..d8e3adf Binary files /dev/null and b/mochi_preview/__pycache__/__init__.cpython-311.pyc differ diff --git a/mochi_preview/__pycache__/__init__.cpython-312.pyc b/mochi_preview/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000..0838524 Binary files /dev/null and b/mochi_preview/__pycache__/__init__.cpython-312.pyc differ diff --git a/mochi_preview/__pycache__/t2v_synth_mochi.cpython-311.pyc b/mochi_preview/__pycache__/t2v_synth_mochi.cpython-311.pyc new file mode 100644 index 0000000..d208bdd Binary files /dev/null and b/mochi_preview/__pycache__/t2v_synth_mochi.cpython-311.pyc differ diff --git a/mochi_preview/__pycache__/t2v_synth_mochi.cpython-312.pyc b/mochi_preview/__pycache__/t2v_synth_mochi.cpython-312.pyc new file mode 100644 index 0000000..4b28b7e Binary files /dev/null and b/mochi_preview/__pycache__/t2v_synth_mochi.cpython-312.pyc differ diff --git a/mochi_preview/__pycache__/utils.cpython-311.pyc b/mochi_preview/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000..3d3564b Binary files /dev/null and b/mochi_preview/__pycache__/utils.cpython-311.pyc differ diff --git a/mochi_preview/__pycache__/utils.cpython-312.pyc b/mochi_preview/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000..4017fe1 Binary files /dev/null and b/mochi_preview/__pycache__/utils.cpython-312.pyc differ diff --git a/mochi_preview/dit/joint_model/__init__.py b/mochi_preview/dit/joint_model/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mochi_preview/dit/joint_model/__pycache__/__init__.cpython-311.pyc b/mochi_preview/dit/joint_model/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..4270e08 Binary files /dev/null and b/mochi_preview/dit/joint_model/__pycache__/__init__.cpython-311.pyc differ diff --git 
a/mochi_preview/dit/joint_model/__pycache__/__init__.cpython-312.pyc b/mochi_preview/dit/joint_model/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000..afd29b5 Binary files /dev/null and b/mochi_preview/dit/joint_model/__pycache__/__init__.cpython-312.pyc differ diff --git a/mochi_preview/dit/joint_model/__pycache__/asymm_models_joint.cpython-311.pyc b/mochi_preview/dit/joint_model/__pycache__/asymm_models_joint.cpython-311.pyc new file mode 100644 index 0000000..15bbc00 Binary files /dev/null and b/mochi_preview/dit/joint_model/__pycache__/asymm_models_joint.cpython-311.pyc differ diff --git a/mochi_preview/dit/joint_model/__pycache__/asymm_models_joint.cpython-312.pyc b/mochi_preview/dit/joint_model/__pycache__/asymm_models_joint.cpython-312.pyc new file mode 100644 index 0000000..63e495a Binary files /dev/null and b/mochi_preview/dit/joint_model/__pycache__/asymm_models_joint.cpython-312.pyc differ diff --git a/mochi_preview/dit/joint_model/__pycache__/context_parallel.cpython-311.pyc b/mochi_preview/dit/joint_model/__pycache__/context_parallel.cpython-311.pyc new file mode 100644 index 0000000..964d807 Binary files /dev/null and b/mochi_preview/dit/joint_model/__pycache__/context_parallel.cpython-311.pyc differ diff --git a/mochi_preview/dit/joint_model/__pycache__/context_parallel.cpython-312.pyc b/mochi_preview/dit/joint_model/__pycache__/context_parallel.cpython-312.pyc new file mode 100644 index 0000000..dce93f9 Binary files /dev/null and b/mochi_preview/dit/joint_model/__pycache__/context_parallel.cpython-312.pyc differ diff --git a/mochi_preview/dit/joint_model/__pycache__/layers.cpython-311.pyc b/mochi_preview/dit/joint_model/__pycache__/layers.cpython-311.pyc new file mode 100644 index 0000000..d6c2ed5 Binary files /dev/null and b/mochi_preview/dit/joint_model/__pycache__/layers.cpython-311.pyc differ diff --git a/mochi_preview/dit/joint_model/__pycache__/layers.cpython-312.pyc b/mochi_preview/dit/joint_model/__pycache__/layers.cpython-312.pyc new file mode 100644 index 0000000..9672afb Binary files /dev/null and b/mochi_preview/dit/joint_model/__pycache__/layers.cpython-312.pyc differ diff --git a/mochi_preview/dit/joint_model/__pycache__/mod_rmsnorm.cpython-311.pyc b/mochi_preview/dit/joint_model/__pycache__/mod_rmsnorm.cpython-311.pyc new file mode 100644 index 0000000..5812523 Binary files /dev/null and b/mochi_preview/dit/joint_model/__pycache__/mod_rmsnorm.cpython-311.pyc differ diff --git a/mochi_preview/dit/joint_model/__pycache__/mod_rmsnorm.cpython-312.pyc b/mochi_preview/dit/joint_model/__pycache__/mod_rmsnorm.cpython-312.pyc new file mode 100644 index 0000000..4de68dd Binary files /dev/null and b/mochi_preview/dit/joint_model/__pycache__/mod_rmsnorm.cpython-312.pyc differ diff --git a/mochi_preview/dit/joint_model/__pycache__/residual_tanh_gated_rmsnorm.cpython-311.pyc b/mochi_preview/dit/joint_model/__pycache__/residual_tanh_gated_rmsnorm.cpython-311.pyc new file mode 100644 index 0000000..13cd096 Binary files /dev/null and b/mochi_preview/dit/joint_model/__pycache__/residual_tanh_gated_rmsnorm.cpython-311.pyc differ diff --git a/mochi_preview/dit/joint_model/__pycache__/residual_tanh_gated_rmsnorm.cpython-312.pyc b/mochi_preview/dit/joint_model/__pycache__/residual_tanh_gated_rmsnorm.cpython-312.pyc new file mode 100644 index 0000000..39f734c Binary files /dev/null and b/mochi_preview/dit/joint_model/__pycache__/residual_tanh_gated_rmsnorm.cpython-312.pyc differ diff --git a/mochi_preview/dit/joint_model/__pycache__/rope_mixed.cpython-311.pyc 
b/mochi_preview/dit/joint_model/__pycache__/rope_mixed.cpython-311.pyc new file mode 100644 index 0000000..3435b9e Binary files /dev/null and b/mochi_preview/dit/joint_model/__pycache__/rope_mixed.cpython-311.pyc differ diff --git a/mochi_preview/dit/joint_model/__pycache__/rope_mixed.cpython-312.pyc b/mochi_preview/dit/joint_model/__pycache__/rope_mixed.cpython-312.pyc new file mode 100644 index 0000000..f650a00 Binary files /dev/null and b/mochi_preview/dit/joint_model/__pycache__/rope_mixed.cpython-312.pyc differ diff --git a/mochi_preview/dit/joint_model/__pycache__/temporal_rope.cpython-311.pyc b/mochi_preview/dit/joint_model/__pycache__/temporal_rope.cpython-311.pyc new file mode 100644 index 0000000..9c9187d Binary files /dev/null and b/mochi_preview/dit/joint_model/__pycache__/temporal_rope.cpython-311.pyc differ diff --git a/mochi_preview/dit/joint_model/__pycache__/temporal_rope.cpython-312.pyc b/mochi_preview/dit/joint_model/__pycache__/temporal_rope.cpython-312.pyc new file mode 100644 index 0000000..37405a5 Binary files /dev/null and b/mochi_preview/dit/joint_model/__pycache__/temporal_rope.cpython-312.pyc differ diff --git a/mochi_preview/dit/joint_model/__pycache__/utils.cpython-311.pyc b/mochi_preview/dit/joint_model/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000..d34b936 Binary files /dev/null and b/mochi_preview/dit/joint_model/__pycache__/utils.cpython-311.pyc differ diff --git a/mochi_preview/dit/joint_model/__pycache__/utils.cpython-312.pyc b/mochi_preview/dit/joint_model/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000..514a891 Binary files /dev/null and b/mochi_preview/dit/joint_model/__pycache__/utils.cpython-312.pyc differ diff --git a/mochi_preview/dit/joint_model/asymm_models_joint.py b/mochi_preview/dit/joint_model/asymm_models_joint.py new file mode 100644 index 0000000..7c319bc --- /dev/null +++ b/mochi_preview/dit/joint_model/asymm_models_joint.py @@ -0,0 +1,675 @@ +import os +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from torch.nn.attention import sdpa_kernel + +from .context_parallel import all_to_all_collect_tokens, all_to_all_collect_heads, all_gather, get_cp_rank_size, is_cp_active +from .layers import ( + FeedForward, + PatchEmbed, + RMSNorm, + TimestepEmbedder, +) +from .mod_rmsnorm import modulated_rmsnorm +from .residual_tanh_gated_rmsnorm import ( + residual_tanh_gated_rmsnorm, +) +from .rope_mixed import ( + compute_mixed_rotation, + create_position_matrix, +) +from .temporal_rope import apply_rotary_emb_qk_real +from .utils import ( + AttentionPool, + modulate, + pad_and_split_xy, + unify_streams, +) + +try: + from flash_attn import flash_attn_varlen_qkvpacked_func + FLASH_ATTN_IS_AVAILABLE = True +except ImportError: + FLASH_ATTN_IS_AVAILABLE = False + +COMPILE_FINAL_LAYER = False #os.environ.get("COMPILE_DIT") == "1" +COMPILE_MMDIT_BLOCK = False #os.environ.get("COMPILE_DIT") == "1" + + +class AsymmetricAttention(nn.Module): + def __init__( + self, + dim_x: int, + dim_y: int, + num_heads: int = 8, + qkv_bias: bool = True, + qk_norm: bool = False, + attn_drop: float = 0.0, + update_y: bool = True, + out_bias: bool = True, + attend_to_padding: bool = False, + softmax_scale: Optional[float] = None, + device: Optional[torch.device] = None, + clip_feat_dim: Optional[int] = None, + pooled_caption_mlp_bias: bool = True, + use_transformer_engine: bool = False, + ): + super().__init__() + self.dim_x = 
dim_x + self.dim_y = dim_y + self.num_heads = num_heads + self.head_dim = dim_x // num_heads + self.attn_drop = attn_drop + self.update_y = update_y + self.attend_to_padding = attend_to_padding + self.softmax_scale = softmax_scale + if dim_x % num_heads != 0: + raise ValueError( + f"dim_x={dim_x} should be divisible by num_heads={num_heads}" + ) + + # Input layers. + self.qkv_bias = qkv_bias + self.qkv_x = nn.Linear(dim_x, 3 * dim_x, bias=qkv_bias, device=device) + # Project text features to match visual features (dim_y -> dim_x) + self.qkv_y = nn.Linear(dim_y, 3 * dim_x, bias=qkv_bias, device=device) + + # Query and key normalization for stability. + assert qk_norm + self.q_norm_x = RMSNorm(self.head_dim, device=device) + self.k_norm_x = RMSNorm(self.head_dim, device=device) + self.q_norm_y = RMSNorm(self.head_dim, device=device) + self.k_norm_y = RMSNorm(self.head_dim, device=device) + + # Output layers. y features go back down from dim_x -> dim_y. + self.proj_x = nn.Linear(dim_x, dim_x, bias=out_bias, device=device) + self.proj_y = ( + nn.Linear(dim_x, dim_y, bias=out_bias, device=device) + if update_y + else nn.Identity() + ) + + def run_qkv_y(self, y): + cp_rank, cp_size = get_cp_rank_size() + local_heads = self.num_heads // cp_size + + if is_cp_active(): + # Only predict local heads. + assert not self.qkv_bias + W_qkv_y = self.qkv_y.weight.view( + 3, self.num_heads, self.head_dim, self.dim_y + ) + W_qkv_y = W_qkv_y.narrow(1, cp_rank * local_heads, local_heads) + W_qkv_y = W_qkv_y.reshape(3 * local_heads * self.head_dim, self.dim_y) + qkv_y = F.linear(y, W_qkv_y, None) # (B, L, 3 * local_h * head_dim) + else: + qkv_y = self.qkv_y(y) # (B, L, 3 * dim) + + qkv_y = qkv_y.view(qkv_y.size(0), qkv_y.size(1), 3, local_heads, self.head_dim) + q_y, k_y, v_y = qkv_y.unbind(2) + return q_y, k_y, v_y + + def prepare_qkv( + self, + x: torch.Tensor, # (B, N, dim_x) + y: torch.Tensor, # (B, L, dim_y) + *, + scale_x: torch.Tensor, + scale_y: torch.Tensor, + rope_cos: torch.Tensor, + rope_sin: torch.Tensor, + valid_token_indices: torch.Tensor, + ): + # Pre-norm for visual features + x = modulated_rmsnorm(x, scale_x) # (B, M, dim_x) where M = N / cp_group_size + #print("x in attn", x.dtype, x.device) + + # Process visual features + qkv_x = self.qkv_x(x) # (B, M, 3 * dim_x) + assert qkv_x.dtype == torch.bfloat16 + qkv_x = all_to_all_collect_tokens( + qkv_x, self.num_heads + ) # (3, B, N, local_h, head_dim) + + # Process text features + y = modulated_rmsnorm(y, scale_y) # (B, L, dim_y) + q_y, k_y, v_y = self.run_qkv_y(y) # (B, L, local_heads, head_dim) + #print("y in attn", y.dtype, y.device) + #print(q_y.dtype, q_y.device) + #print(self.q_norm_y.weight.dtype, self.q_norm_y.weight.device) + # self.q_norm_y.weight = self.q_norm_y.weight.to(q_y.dtype) + # self.q_norm_y.bias = self.q_norm_y.bias.to(q_y.dtype) + # self.k_norm_y.weight = self.k_norm_y.weight.to(k_y.dtype) + # self.k_norm_y.bias = self.k_norm_y.bias.to(k_y.dtype) + q_y = self.q_norm_y(q_y) + k_y = self.k_norm_y(k_y) + + # Split qkv_x into q, k, v + q_x, k_x, v_x = qkv_x.unbind(0) # (B, N, local_h, head_dim) + q_x = self.q_norm_x(q_x) + q_x = apply_rotary_emb_qk_real(q_x, rope_cos, rope_sin) + k_x = self.k_norm_x(k_x) + k_x = apply_rotary_emb_qk_real(k_x, rope_cos, rope_sin) + + # Unite streams + qkv = unify_streams( + q_x, + k_x, + v_x, + q_y, + k_y, + v_y, + valid_token_indices, + ) + + return qkv + + @torch.compiler.disable() + def run_attention( + self, + qkv: torch.Tensor, # (total <= B * (N + L), 3, local_heads, head_dim) + *, + B: int, + 
        L: int,
+        M: int,
+        cu_seqlens: torch.Tensor,
+        max_seqlen_in_batch: int,
+        valid_token_indices: torch.Tensor,
+    ):
+        _, cp_size = get_cp_rank_size()
+        N = cp_size * M
+        assert self.num_heads % cp_size == 0
+        local_heads = self.num_heads // cp_size
+        local_dim = local_heads * self.head_dim
+        total = qkv.size(0)
+
+        if FLASH_ATTN_IS_AVAILABLE:
+            with torch.autocast("cuda", enabled=False):
+                out: torch.Tensor = flash_attn_varlen_qkvpacked_func(
+                    qkv,
+                    cu_seqlens=cu_seqlens,
+                    max_seqlen=max_seqlen_in_batch,
+                    dropout_p=0.0,
+                    softmax_scale=self.softmax_scale,
+                )  # (total, local_heads, head_dim)
+                out = out.view(total, local_dim)
+        else:
+            # No fallback is implemented for the packed (varlen) QKV layout.
+            raise NotImplementedError("Flash attention is currently required.")
+
+        x, y = pad_and_split_xy(out, valid_token_indices, B, N, L, qkv.dtype)
+        assert x.size() == (B, N, local_dim)
+        assert y.size() == (B, L, local_dim)
+
+        x = x.view(B, N, local_heads, self.head_dim)
+        x = all_to_all_collect_heads(x)  # (B, M, dim_x = num_heads * head_dim)
+        x = self.proj_x(x)  # (B, M, dim_x)
+
+        if is_cp_active():
+            y = all_gather(y)  # (cp_size * B, L, local_heads * head_dim)
+            y = rearrange(
+                y, "(G B) L D -> B L (G D)", G=cp_size, D=local_dim
+            )  # (B, L, dim_x)
+        y = self.proj_y(y)  # (B, L, dim_y)
+        return x, y
+
+    def forward(
+        self,
+        x: torch.Tensor,  # (B, N, dim_x)
+        y: torch.Tensor,  # (B, L, dim_y)
+        *,
+        scale_x: torch.Tensor,  # (B, dim_x), modulation for pre-RMSNorm.
+        scale_y: torch.Tensor,  # (B, dim_y), modulation for pre-RMSNorm.
+        packed_indices: Dict[str, torch.Tensor] = None,
+        **rope_rotation,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Forward pass of asymmetric multi-modal attention.
+
+        Args:
+            x: (B, N, dim_x) tensor for visual tokens
+            y: (B, L, dim_y) tensor of text token features
+            packed_indices: Dict with keys for Flash Attention
+            num_frames: Number of frames in the video. N = num_frames * num_spatial_tokens
+
+        Returns:
+            x: (B, N, dim_x) tensor of visual tokens after multi-modal attention
+            y: (B, L, dim_y) tensor of text token features after multi-modal attention
+        """
+        B, L, _ = y.shape
+        _, M, _ = x.shape
+
+        # Predict a packed QKV tensor from visual and text features.
+        # Don't checkpoint the all_to_all.
+        qkv = self.prepare_qkv(
+            x=x,
+            y=y,
+            scale_x=scale_x,
+            scale_y=scale_y,
+            rope_cos=rope_rotation.get("rope_cos"),
+            rope_sin=rope_rotation.get("rope_sin"),
+            valid_token_indices=packed_indices["valid_token_indices_kv"],
+        )  # (total <= B * (N + L), 3, local_heads, head_dim)
+
+        x, y = self.run_attention(
+            qkv,
+            B=B,
+            L=L,
+            M=M,
+            cu_seqlens=packed_indices["cu_seqlens_kv"],
+            max_seqlen_in_batch=packed_indices["max_seqlen_in_batch_kv"],
+            valid_token_indices=packed_indices["valid_token_indices_kv"],
+        )
+        return x, y
+
+#@torch.compile(disable=not COMPILE_MMDIT_BLOCK)
+class AsymmetricJointBlock(nn.Module):
+    def __init__(
+        self,
+        hidden_size_x: int,
+        hidden_size_y: int,
+        num_heads: int,
+        *,
+        mlp_ratio_x: float = 8.0,  # Ratio of hidden size to d_model for MLP for visual tokens.
+ mlp_ratio_y: float = 4.0, # Ratio of hidden size to d_model for MLP for text tokens. + update_y: bool = True, # Whether to update text tokens in this block. + device: Optional[torch.device] = None, + **block_kwargs, + ): + super().__init__() + self.update_y = update_y + self.hidden_size_x = hidden_size_x + self.hidden_size_y = hidden_size_y + self.mod_x = nn.Linear(hidden_size_x, 4 * hidden_size_x, device=device) + if self.update_y: + self.mod_y = nn.Linear(hidden_size_x, 4 * hidden_size_y, device=device) + else: + self.mod_y = nn.Linear(hidden_size_x, hidden_size_y, device=device) + + # Self-attention: + self.attn = AsymmetricAttention( + hidden_size_x, + hidden_size_y, + num_heads=num_heads, + update_y=update_y, + device=device, + **block_kwargs, + ) + + # MLP. + mlp_hidden_dim_x = int(hidden_size_x * mlp_ratio_x) + assert mlp_hidden_dim_x == int(1536 * 8) + self.mlp_x = FeedForward( + in_features=hidden_size_x, + hidden_size=mlp_hidden_dim_x, + multiple_of=256, + ffn_dim_multiplier=None, + device=device, + ) + + # MLP for text not needed in last block. + if self.update_y: + mlp_hidden_dim_y = int(hidden_size_y * mlp_ratio_y) + self.mlp_y = FeedForward( + in_features=hidden_size_y, + hidden_size=mlp_hidden_dim_y, + multiple_of=256, + ffn_dim_multiplier=None, + device=device, + ) + + def forward( + self, + x: torch.Tensor, + c: torch.Tensor, + y: torch.Tensor, + **attn_kwargs, + ): + """Forward pass of a block. + + Args: + x: (B, N, dim) tensor of visual tokens + c: (B, dim) tensor of conditioned features + y: (B, L, dim) tensor of text tokens + num_frames: Number of frames in the video. N = num_frames * num_spatial_tokens + + Returns: + x: (B, N, dim) tensor of visual tokens after block + y: (B, L, dim) tensor of text tokens after block + """ + N = x.size(1) + + c = F.silu(c) + mod_x = self.mod_x(c) + scale_msa_x, gate_msa_x, scale_mlp_x, gate_mlp_x = mod_x.chunk(4, dim=1) + + mod_y = self.mod_y(c) + if self.update_y: + scale_msa_y, gate_msa_y, scale_mlp_y, gate_mlp_y = mod_y.chunk(4, dim=1) + else: + scale_msa_y = mod_y + + # Self-attention block. + x_attn, y_attn = self.attn( + x, + y, + scale_x=scale_msa_x, + scale_y=scale_msa_y, + **attn_kwargs, + ) + + assert x_attn.size(1) == N + x = residual_tanh_gated_rmsnorm(x, x_attn, gate_msa_x) + if self.update_y: + y = residual_tanh_gated_rmsnorm(y, y_attn, gate_msa_y) + + # MLP block. + x = self.ff_block_x(x, scale_mlp_x, gate_mlp_x) + if self.update_y: + y = self.ff_block_y(y, scale_mlp_y, gate_mlp_y) + + return x, y + + def ff_block_x(self, x, scale_x, gate_x): + x_mod = modulated_rmsnorm(x, scale_x) + x_res = self.mlp_x(x_mod) + x = residual_tanh_gated_rmsnorm(x, x_res, gate_x) # Sandwich norm + return x + + def ff_block_y(self, y, scale_y, gate_y): + y_mod = modulated_rmsnorm(y, scale_y) + y_res = self.mlp_y(y_mod) + y = residual_tanh_gated_rmsnorm(y, y_res, gate_y) # Sandwich norm + return y + + +#@torch.compile(disable=not COMPILE_FINAL_LAYER) +class FinalLayer(nn.Module): + """ + The final layer of DiT. 
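+    Applies an adaLN-style modulation (shift/scale derived from c) to the final LayerNorm output, then projects to per-patch output channels.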
+ """ + + def __init__( + self, + hidden_size, + patch_size, + out_channels, + device: Optional[torch.device] = None, + ): + super().__init__() + self.norm_final = nn.LayerNorm( + hidden_size, elementwise_affine=False, eps=1e-6, device=device + ) + self.mod = nn.Linear(hidden_size, 2 * hidden_size, device=device) + self.linear = nn.Linear( + hidden_size, patch_size * patch_size * out_channels, device=device + ) + + def forward(self, x, c): + c = F.silu(c) + shift, scale = self.mod(c).chunk(2, dim=1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + + +class AsymmDiTJoint(nn.Module): + """ + Diffusion model with a Transformer backbone. + + Ingests text embeddings instead of a label. + """ + + def __init__( + self, + *, + patch_size=2, + in_channels=4, + hidden_size_x=1152, + hidden_size_y=1152, + depth=48, + num_heads=16, + mlp_ratio_x=8.0, + mlp_ratio_y=4.0, + t5_feat_dim: int = 4096, + t5_token_length: int = 256, + patch_embed_bias: bool = True, + timestep_mlp_bias: bool = True, + timestep_scale: Optional[float] = None, + use_extended_posenc: bool = False, + rope_theta: float = 10000.0, + device: Optional[torch.device] = None, + **block_kwargs, + ): + super().__init__() + + + self.in_channels = in_channels + self.out_channels = in_channels + self.patch_size = patch_size + self.num_heads = num_heads + self.hidden_size_x = hidden_size_x + self.hidden_size_y = hidden_size_y + self.head_dim = ( + hidden_size_x // num_heads + ) # Head dimension and count is determined by visual. + self.use_extended_posenc = use_extended_posenc + self.t5_token_length = t5_token_length + self.t5_feat_dim = t5_feat_dim + self.rope_theta = ( + rope_theta # Scaling factor for frequency computation for temporal RoPE. + ) + + self.x_embedder = PatchEmbed( + patch_size=patch_size, + in_chans=in_channels, + embed_dim=hidden_size_x, + bias=patch_embed_bias, + device=device, + ) + # Conditionings + # Timestep + self.t_embedder = TimestepEmbedder( + hidden_size_x, bias=timestep_mlp_bias, timestep_scale=timestep_scale + ) + + # Caption Pooling (T5) + self.t5_y_embedder = AttentionPool( + t5_feat_dim, num_heads=8, output_dim=hidden_size_x, device=device + ) + + # Dense Embedding Projection (T5) + self.t5_yproj = nn.Linear( + t5_feat_dim, hidden_size_y, bias=True, device=device + ) + + # Initialize pos_frequencies as an empty parameter. + self.pos_frequencies = nn.Parameter( + torch.empty(3, self.num_heads, self.head_dim // 2, device=device) + ) + + # for depth 48: + # b = 0: AsymmetricJointBlock, update_y=True + # b = 1: AsymmetricJointBlock, update_y=True + # ... + # b = 46: AsymmetricJointBlock, update_y=True + # b = 47: AsymmetricJointBlock, update_y=False. No need to update text features. + blocks = [] + for b in range(depth): + # Joint multi-modal block + update_y = b < depth - 1 + block = AsymmetricJointBlock( + hidden_size_x, + hidden_size_y, + num_heads, + mlp_ratio_x=mlp_ratio_x, + mlp_ratio_y=mlp_ratio_y, + update_y=update_y, + device=device, + **block_kwargs, + ) + + blocks.append(block) + self.blocks = nn.ModuleList(blocks) + + self.final_layer = FinalLayer( + hidden_size_x, patch_size, self.out_channels, device=device + ) + + def embed_x(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: (B, C=12, T, H, W) tensor of visual tokens + + Returns: + x: (B, C=3072, N) tensor of visual tokens with positional embedding. 
+ """ + return self.x_embedder(x) # Convert BcTHW to BCN + + #@torch.compile(disable=not COMPILE_MMDIT_BLOCK) + def prepare( + self, + x: torch.Tensor, + sigma: torch.Tensor, + t5_feat: torch.Tensor, + t5_mask: torch.Tensor, + ): + """Prepare input and conditioning embeddings.""" + #("X", x.shape) + with torch.profiler.record_function("x_emb_pe"): + # Visual patch embeddings with positional encoding. + T, H, W = x.shape[-3:] + pH, pW = H // self.patch_size, W // self.patch_size + x = self.embed_x(x) # (B, N, D), where N = T * H * W / patch_size ** 2 + assert x.ndim == 3 + B = x.size(0) + + with torch.profiler.record_function("rope_cis"): + # Construct position array of size [N, 3]. + # pos[:, 0] is the frame index for each location, + # pos[:, 1] is the row index for each location, and + # pos[:, 2] is the column index for each location. + pH, pW = H // self.patch_size, W // self.patch_size + N = T * pH * pW + assert x.size(1) == N + pos = create_position_matrix( + T, pH=pH, pW=pW, device=x.device, dtype=torch.float32 + ) # (N, 3) + rope_cos, rope_sin = compute_mixed_rotation( + freqs=self.pos_frequencies, pos=pos + ) # Each are (N, num_heads, dim // 2) + + with torch.profiler.record_function("t_emb"): + # Global vector embedding for conditionings. + c_t = self.t_embedder(1 - sigma) # (B, D) + + with torch.profiler.record_function("t5_pool"): + # Pool T5 tokens using attention pooler + # Note y_feat[1] contains T5 token features. + # print("B", B) + # print("t5 feat shape",t5_feat.shape) + # print("t5 mask shape", t5_mask.shape) + assert ( + t5_feat.size(1) == self.t5_token_length + ), f"Expected L={self.t5_token_length}, got {t5_feat.shape} for y_feat." + t5_y_pool = self.t5_y_embedder(t5_feat, t5_mask) # (B, D) + assert ( + t5_y_pool.size(0) == B + ), f"Expected B={B}, got {t5_y_pool.shape} for t5_y_pool." + + c = c_t + t5_y_pool + + y_feat = self.t5_yproj(t5_feat) # (B, L, t5_feat_dim) --> (B, L, D) + + return x, c, y_feat, rope_cos, rope_sin + + def forward( + self, + x: torch.Tensor, + sigma: torch.Tensor, + y_feat: List[torch.Tensor], + y_mask: List[torch.Tensor], + packed_indices: Dict[str, torch.Tensor] = None, + rope_cos: torch.Tensor = None, + rope_sin: torch.Tensor = None, + ): + """Forward pass of DiT. + + Args: + x: (B, C, T, H, W) tensor of spatial inputs (images or latent representations of images) + sigma: (B,) tensor of noise standard deviations + y_feat: List((B, L, y_feat_dim) tensor of caption token features. For SDXL text encoders: L=77, y_feat_dim=2048) + y_mask: List((B, L) boolean tensor indicating which tokens are not padding) + packed_indices: Dict with keys for Flash Attention. Result of compute_packed_indices. + """ + B, _, T, H, W = x.shape + + # Use EFFICIENT_ATTENTION backend for T5 pooling, since we have a mask. + # Have to call sdpa_kernel outside of a torch.compile region. + with sdpa_kernel(torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION): + x, c, y_feat, rope_cos, rope_sin = self.prepare( + x, sigma, y_feat[0], y_mask[0] + ) + del y_mask + + cp_rank, cp_size = get_cp_rank_size() + N = x.size(1) + M = N // cp_size + assert ( + N % cp_size == 0 + ), f"Visual sequence length ({x.shape[1]}) must be divisible by cp_size ({cp_size})." 
+ + if cp_size > 1: + x = x.narrow(1, cp_rank * M, M) + + assert self.num_heads % cp_size == 0 + local_heads = self.num_heads // cp_size + rope_cos = rope_cos.narrow(1, cp_rank * local_heads, local_heads) + rope_sin = rope_sin.narrow(1, cp_rank * local_heads, local_heads) + + for i, block in enumerate(self.blocks): + x, y_feat = block( + x, + c, + y_feat, + rope_cos=rope_cos, + rope_sin=rope_sin, + packed_indices=packed_indices, + ) # (B, M, D), (B, L, D) + del y_feat # Final layers don't use dense text features. + + x = self.final_layer(x, c) # (B, M, patch_size ** 2 * out_channels) + + patch = x.size(2) + x = all_gather(x) + x = rearrange(x, "(G B) M P -> B (G M) P", G=cp_size, P=patch) + x = rearrange( + x, + "B (T hp wp) (p1 p2 c) -> B c T (hp p1) (wp p2)", + T=T, + hp=H // self.patch_size, + wp=W // self.patch_size, + p1=self.patch_size, + p2=self.patch_size, + c=self.out_channels, + ) + + return x diff --git a/mochi_preview/dit/joint_model/context_parallel.py b/mochi_preview/dit/joint_model/context_parallel.py new file mode 100644 index 0000000..d93145d --- /dev/null +++ b/mochi_preview/dit/joint_model/context_parallel.py @@ -0,0 +1,163 @@ +import torch +import torch.distributed as dist +from einops import rearrange + +_CONTEXT_PARALLEL_GROUP = None +_CONTEXT_PARALLEL_RANK = None +_CONTEXT_PARALLEL_GROUP_SIZE = None +_CONTEXT_PARALLEL_GROUP_RANKS = None + + +def local_shard(x: torch.Tensor, dim: int = 2) -> torch.Tensor: + if not _CONTEXT_PARALLEL_GROUP: + return x + + cp_rank, cp_size = get_cp_rank_size() + return x.tensor_split(cp_size, dim=dim)[cp_rank] + + +def set_cp_group(cp_group, ranks, global_rank): + global \ + _CONTEXT_PARALLEL_GROUP, \ + _CONTEXT_PARALLEL_RANK, \ + _CONTEXT_PARALLEL_GROUP_SIZE, \ + _CONTEXT_PARALLEL_GROUP_RANKS + if _CONTEXT_PARALLEL_GROUP is not None: + raise RuntimeError("CP group already initialized.") + _CONTEXT_PARALLEL_GROUP = cp_group + _CONTEXT_PARALLEL_RANK = dist.get_rank(cp_group) + _CONTEXT_PARALLEL_GROUP_SIZE = dist.get_world_size(cp_group) + _CONTEXT_PARALLEL_GROUP_RANKS = ranks + + assert ( + _CONTEXT_PARALLEL_RANK == ranks.index(global_rank) + ), f"Rank mismatch: {global_rank} in {ranks} does not have position {_CONTEXT_PARALLEL_RANK} " + assert _CONTEXT_PARALLEL_GROUP_SIZE == len( + ranks + ), f"Group size mismatch: {_CONTEXT_PARALLEL_GROUP_SIZE} != len({ranks})" + + +def get_cp_group(): + if _CONTEXT_PARALLEL_GROUP is None: + raise RuntimeError("CP group not initialized") + return _CONTEXT_PARALLEL_GROUP + + +def is_cp_active(): + return _CONTEXT_PARALLEL_GROUP is not None + + +def get_cp_rank_size(): + if _CONTEXT_PARALLEL_GROUP: + return _CONTEXT_PARALLEL_RANK, _CONTEXT_PARALLEL_GROUP_SIZE + else: + return 0, 1 + + +class AllGatherIntoTensorFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x: torch.Tensor, reduce_dtype, group: dist.ProcessGroup): + ctx.reduce_dtype = reduce_dtype + ctx.group = group + ctx.batch_size = x.size(0) + group_size = dist.get_world_size(group) + + x = x.contiguous() + output = torch.empty( + group_size * x.size(0), *x.shape[1:], dtype=x.dtype, device=x.device + ) + dist.all_gather_into_tensor(output, x, group=group) + return output + + +def all_gather(tensor: torch.Tensor) -> torch.Tensor: + if not _CONTEXT_PARALLEL_GROUP: + return tensor + + return AllGatherIntoTensorFunction.apply( + tensor, torch.float32, _CONTEXT_PARALLEL_GROUP + ) + + +@torch.compiler.disable() +def _all_to_all_single(output, input, group): + # Disable compilation since torch compile changes contiguity. 
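+    # Both buffers are validated as contiguous before the collective is issued.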
+ assert input.is_contiguous(), "Input tensor must be contiguous." + assert output.is_contiguous(), "Output tensor must be contiguous." + return dist.all_to_all_single(output, input, group=group) + + +class CollectTokens(torch.autograd.Function): + @staticmethod + def forward(ctx, qkv: torch.Tensor, group: dist.ProcessGroup, num_heads: int): + """Redistribute heads and receive tokens. + + Args: + qkv: query, key or value. Shape: [B, M, 3 * num_heads * head_dim] + + Returns: + qkv: shape: [3, B, N, local_heads, head_dim] + + where M is the number of local tokens, + N = cp_size * M is the number of global tokens, + local_heads = num_heads // cp_size is the number of local heads. + """ + ctx.group = group + ctx.num_heads = num_heads + cp_size = dist.get_world_size(group) + assert num_heads % cp_size == 0 + ctx.local_heads = num_heads // cp_size + + qkv = rearrange( + qkv, + "B M (qkv G h d) -> G M h B (qkv d)", + qkv=3, + G=cp_size, + h=ctx.local_heads, + ).contiguous() + + output_chunks = torch.empty_like(qkv) + _all_to_all_single(output_chunks, qkv, group=group) + + return rearrange(output_chunks, "G M h B (qkv d) -> qkv B (G M) h d", qkv=3) + + +def all_to_all_collect_tokens(x: torch.Tensor, num_heads: int) -> torch.Tensor: + if not _CONTEXT_PARALLEL_GROUP: + # Move QKV dimension to the front. + # B M (3 H d) -> 3 B M H d + B, M, _ = x.size() + x = x.view(B, M, 3, num_heads, -1) + return x.permute(2, 0, 1, 3, 4) + + return CollectTokens.apply(x, _CONTEXT_PARALLEL_GROUP, num_heads) + + +class CollectHeads(torch.autograd.Function): + @staticmethod + def forward(ctx, x: torch.Tensor, group: dist.ProcessGroup): + """Redistribute tokens and receive heads. + + Args: + x: Output of attention. Shape: [B, N, local_heads, head_dim] + + Returns: + Shape: [B, M, num_heads * head_dim] + """ + ctx.group = group + ctx.local_heads = x.size(2) + ctx.head_dim = x.size(3) + group_size = dist.get_world_size(group) + x = rearrange(x, "B (G M) h D -> G h M B D", G=group_size).contiguous() + output = torch.empty_like(x) + _all_to_all_single(output, x, group=group) + del x + return rearrange(output, "G h M B D -> B M (G h D)") + + +def all_to_all_collect_heads(x: torch.Tensor) -> torch.Tensor: + if not _CONTEXT_PARALLEL_GROUP: + # Merge heads. 
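+        # (B, N, local_heads, head_dim) -> (B, N, local_heads * head_dim)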
+ return x.view(x.size(0), x.size(1), x.size(2) * x.size(3)) + + return CollectHeads.apply(x, _CONTEXT_PARALLEL_GROUP) diff --git a/mochi_preview/dit/joint_model/layers.py b/mochi_preview/dit/joint_model/layers.py new file mode 100644 index 0000000..aa40a67 --- /dev/null +++ b/mochi_preview/dit/joint_model/layers.py @@ -0,0 +1,178 @@ +import collections.abc +import math +from itertools import repeat +from typing import Callable, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return tuple(x) + return tuple(repeat(x, n)) + + return parse + + +to_2tuple = _ntuple(2) + + +class TimestepEmbedder(nn.Module): + def __init__( + self, + hidden_size: int, + frequency_embedding_size: int = 256, + *, + bias: bool = True, + timestep_scale: Optional[float] = None, + device: Optional[torch.device] = None, + ): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=bias, device=device), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=bias, device=device), + ) + self.frequency_embedding_size = frequency_embedding_size + self.timestep_scale = timestep_scale + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + half = dim // 2 + freqs = torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) + freqs.mul_(-math.log(max_period) / half).exp_() + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) + return embedding + + def forward(self, t): + if self.timestep_scale is not None: + t = t * self.timestep_scale + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + + +class PooledCaptionEmbedder(nn.Module): + def __init__( + self, + caption_feature_dim: int, + hidden_size: int, + *, + bias: bool = True, + device: Optional[torch.device] = None, + ): + super().__init__() + self.caption_feature_dim = caption_feature_dim + self.hidden_size = hidden_size + self.mlp = nn.Sequential( + nn.Linear(caption_feature_dim, hidden_size, bias=bias, device=device), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=bias, device=device), + ) + + def forward(self, x): + return self.mlp(x) + + +class FeedForward(nn.Module): + def __init__( + self, + in_features: int, + hidden_size: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float], + device: Optional[torch.device] = None, + ): + super().__init__() + # keep parameter count and computation constant compared to standard FFN + hidden_size = int(2 * hidden_size / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_size = int(ffn_dim_multiplier * hidden_size) + hidden_size = multiple_of * ((hidden_size + multiple_of - 1) // multiple_of) + + self.hidden_dim = hidden_size + self.w1 = nn.Linear(in_features, 2 * hidden_size, bias=False, device=device) + self.w2 = nn.Linear(hidden_size, in_features, bias=False, device=device) + + def forward(self, x): + x, gate = self.w1(x).chunk(2, dim=-1) + x = self.w2(F.silu(x) * gate) + return x + + +class PatchEmbed(nn.Module): + def __init__( + self, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten: bool = True, + bias: bool = True, + 
dynamic_img_pad: bool = False, + device: Optional[torch.device] = None, + ): + super().__init__() + self.patch_size = to_2tuple(patch_size) + self.flatten = flatten + self.dynamic_img_pad = dynamic_img_pad + + self.proj = nn.Conv2d( + in_chans, + embed_dim, + kernel_size=patch_size, + stride=patch_size, + bias=bias, + device=device, + ) + assert norm_layer is None + self.norm = ( + norm_layer(embed_dim, device=device) if norm_layer else nn.Identity() + ) + + def forward(self, x): + B, _C, T, H, W = x.shape + if not self.dynamic_img_pad: + assert H % self.patch_size[0] == 0, f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})." + assert W % self.patch_size[1] == 0, f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})." + else: + pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0] + pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1] + x = F.pad(x, (0, pad_w, 0, pad_h)) + + x = rearrange(x, "B C T H W -> (B T) C H W", B=B, T=T) + #print("x",x.dtype, x.device) + #print(self.proj.weight.dtype, self.proj.weight.device) + x = self.proj(x) + + # Flatten temporal and spatial dimensions. + if not self.flatten: + raise NotImplementedError("Must flatten output.") + x = rearrange(x, "(B T) C H W -> B (T H W) C", B=B, T=T) + + x = self.norm(x) + return x + + +class RMSNorm(torch.nn.Module): + def __init__(self, hidden_size, eps=1e-5, device=None): + super().__init__() + self.eps = eps + self.weight = torch.nn.Parameter(torch.empty(hidden_size, device=device)) + self.register_parameter("bias", None) + + def forward(self, x): + x_fp32 = x.float() + x_normed = x_fp32 * torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + self.eps) + return (x_normed * self.weight).type_as(x) diff --git a/mochi_preview/dit/joint_model/mod_rmsnorm.py b/mochi_preview/dit/joint_model/mod_rmsnorm.py new file mode 100644 index 0000000..ffbb4c8 --- /dev/null +++ b/mochi_preview/dit/joint_model/mod_rmsnorm.py @@ -0,0 +1,23 @@ +import torch + + +class ModulatedRMSNorm(torch.autograd.Function): + @staticmethod + def forward(ctx, x, scale, eps=1e-6): + # Convert to fp32 for precision + x_fp32 = x.float() + scale_fp32 = scale.float() + + # Compute RMS + mean_square = x_fp32.pow(2).mean(-1, keepdim=True) + inv_rms = torch.rsqrt(mean_square + eps) + + # Normalize and modulate + x_normed = x_fp32 * inv_rms + x_modulated = x_normed * (1 + scale_fp32.unsqueeze(1)) + + return x_modulated.type_as(x) + + +def modulated_rmsnorm(x, scale, eps=1e-6): + return ModulatedRMSNorm.apply(x, scale, eps) diff --git a/mochi_preview/dit/joint_model/residual_tanh_gated_rmsnorm.py b/mochi_preview/dit/joint_model/residual_tanh_gated_rmsnorm.py new file mode 100644 index 0000000..0bb96e2 --- /dev/null +++ b/mochi_preview/dit/joint_model/residual_tanh_gated_rmsnorm.py @@ -0,0 +1,27 @@ +import torch + + +class ResidualTanhGatedRMSNorm(torch.autograd.Function): + @staticmethod + def forward(ctx, x, x_res, gate, eps=1e-6): + # Convert to fp32 for precision + x_res_fp32 = x_res.float() + + # Compute RMS + mean_square = x_res_fp32.pow(2).mean(-1, keepdim=True) + scale = torch.rsqrt(mean_square + eps) + + # Apply tanh to gate + tanh_gate = torch.tanh(gate).unsqueeze(1) + + # Normalize and apply gated scaling + x_normed = x_res_fp32 * scale * tanh_gate + + # Apply residual connection + output = x + x_normed.type_as(x) + + return output + + +def residual_tanh_gated_rmsnorm(x, x_res, gate, eps=1e-6): + return ResidualTanhGatedRMSNorm.apply(x, x_res, gate, eps) diff --git 
a/mochi_preview/dit/joint_model/rope_mixed.py b/mochi_preview/dit/joint_model/rope_mixed.py new file mode 100644 index 0000000..f2952bd --- /dev/null +++ b/mochi_preview/dit/joint_model/rope_mixed.py @@ -0,0 +1,88 @@ +import functools +import math + +import torch + + +def centers(start: float, stop, num, dtype=None, device=None): + """linspace through bin centers. + + Args: + start (float): Start of the range. + stop (float): End of the range. + num (int): Number of points. + dtype (torch.dtype): Data type of the points. + device (torch.device): Device of the points. + + Returns: + centers (Tensor): Centers of the bins. Shape: (num,). + """ + edges = torch.linspace(start, stop, num + 1, dtype=dtype, device=device) + return (edges[:-1] + edges[1:]) / 2 + + +@functools.lru_cache(maxsize=1) +def create_position_matrix( + T: int, + pH: int, + pW: int, + device: torch.device, + dtype: torch.dtype, + *, + target_area: float = 36864, +): + """ + Args: + T: int - Temporal dimension + pH: int - Height dimension after patchify + pW: int - Width dimension after patchify + + Returns: + pos: [T * pH * pW, 3] - position matrix + """ + with torch.no_grad(): + # Create 1D tensors for each dimension + t = torch.arange(T, dtype=dtype) + + # Positionally interpolate to area 36864. + # (3072x3072 frame with 16x16 patches = 192x192 latents). + # This automatically scales rope positions when the resolution changes. + # We use a large target area so the model is more sensitive + # to changes in the learned pos_frequencies matrix. + scale = math.sqrt(target_area / (pW * pH)) + w = centers(-pW * scale / 2, pW * scale / 2, pW) + h = centers(-pH * scale / 2, pH * scale / 2, pH) + + # Use meshgrid to create 3D grids + grid_t, grid_h, grid_w = torch.meshgrid(t, h, w, indexing="ij") + + # Stack and reshape the grids. + pos = torch.stack([grid_t, grid_h, grid_w], dim=-1) # [T, pH, pW, 3] + pos = pos.view(-1, 3) # [T * pH * pW, 3] + pos = pos.to(dtype=dtype, device=device) + + return pos + + +def compute_mixed_rotation( + freqs: torch.Tensor, + pos: torch.Tensor, +): + """ + Project each 3-dim position into per-head, per-head-dim 1D frequencies. + + Args: + freqs: [3, num_heads, num_freqs] - learned rotation frequency (for t, row, col) for each head position + pos: [N, 3] - position of each token + num_heads: int + + Returns: + freqs_cos: [N, num_heads, num_freqs] - cosine components + freqs_sin: [N, num_heads, num_freqs] - sine components + """ + with torch.autocast("cuda", enabled=False): + assert freqs.ndim == 3 + freqs_sum = torch.einsum("Nd,dhf->Nhf", pos.to(freqs), freqs) + freqs_cos = torch.cos(freqs_sum) + freqs_sin = torch.sin(freqs_sum) + return freqs_cos, freqs_sin diff --git a/mochi_preview/dit/joint_model/temporal_rope.py b/mochi_preview/dit/joint_model/temporal_rope.py new file mode 100644 index 0000000..a8276db --- /dev/null +++ b/mochi_preview/dit/joint_model/temporal_rope.py @@ -0,0 +1,34 @@ +# Based on Llama3 Implementation. +import torch + + +def apply_rotary_emb_qk_real( + xqk: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, +) -> torch.Tensor: + """ + Apply rotary embeddings to input tensors using the given frequency tensor without complex numbers. + + Args: + xqk (torch.Tensor): Query and/or Key tensors to apply rotary embeddings. Shape: (B, S, *, num_heads, D) + Can be either just query or just key, or both stacked along some batch or * dim. + freqs_cos (torch.Tensor): Precomputed cosine frequency tensor. + freqs_sin (torch.Tensor): Precomputed sine frequency tensor. 
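+            Both frequency tensors multiply the even/odd halves of xqk's last dimension (D // 2 entries each).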
+ + Returns: + torch.Tensor: The input tensor with rotary embeddings applied. + """ + assert xqk.dtype == torch.bfloat16 + # Split the last dimension into even and odd parts + xqk_even = xqk[..., 0::2] + xqk_odd = xqk[..., 1::2] + + # Apply rotation + cos_part = (xqk_even * freqs_cos - xqk_odd * freqs_sin).type_as(xqk) + sin_part = (xqk_even * freqs_sin + xqk_odd * freqs_cos).type_as(xqk) + + # Interleave the results back into the original shape + out = torch.stack([cos_part, sin_part], dim=-1).flatten(-2) + assert out.dtype == torch.bfloat16 + return out diff --git a/mochi_preview/dit/joint_model/utils.py b/mochi_preview/dit/joint_model/utils.py new file mode 100644 index 0000000..502e3ec --- /dev/null +++ b/mochi_preview/dit/joint_model/utils.py @@ -0,0 +1,189 @@ +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch.Tensor: + """ + Pool tokens in x using mask. + + NOTE: We assume x does not require gradients. + + Args: + x: (B, L, D) tensor of tokens. + mask: (B, L) boolean tensor indicating which tokens are not padding. + + Returns: + pooled: (B, D) tensor of pooled tokens. + """ + assert x.size(1) == mask.size(1) # Expected mask to have same length as tokens. + assert x.size(0) == mask.size(0) # Expected mask to have same batch size as tokens. + mask = mask[:, :, None].to(dtype=x.dtype) + mask = mask / mask.sum(dim=1, keepdim=True).clamp(min=1) + pooled = (x * mask).sum(dim=1, keepdim=keepdim) + return pooled + + +class AttentionPool(nn.Module): + def __init__( + self, + embed_dim: int, + num_heads: int, + output_dim: int = None, + device: Optional[torch.device] = None, + ): + """ + Args: + spatial_dim (int): Number of tokens in sequence length. + embed_dim (int): Dimensionality of input tokens. + num_heads (int): Number of attention heads. + output_dim (int): Dimensionality of output tokens. Defaults to embed_dim. + """ + super().__init__() + self.num_heads = num_heads + self.to_kv = nn.Linear(embed_dim, 2 * embed_dim, device=device) + self.to_q = nn.Linear(embed_dim, embed_dim, device=device) + self.to_out = nn.Linear(embed_dim, output_dim or embed_dim, device=device) + + def forward(self, x, mask): + """ + Args: + x (torch.Tensor): (B, L, D) tensor of input tokens. + mask (torch.Tensor): (B, L) boolean tensor indicating which tokens are not padding. + + NOTE: We assume x does not require gradients. + + Returns: + x (torch.Tensor): (B, D) tensor of pooled tokens. + """ + D = x.size(2) + + # Construct attention mask, shape: (B, 1, num_queries=1, num_keys=1+L). + attn_mask = mask[:, None, None, :].bool() # (B, 1, 1, L). + attn_mask = F.pad(attn_mask, (1, 0), value=True) # (B, 1, 1, 1+L). + + # Average non-padding token features. These will be used as the query. + x_pool = pool_tokens(x, mask, keepdim=True) # (B, 1, D) + + # Concat pooled features to input sequence. + x = torch.cat([x_pool, x], dim=1) # (B, L+1, D) + + # Compute queries, keys, values. Only the mean token is used to create a query. + kv = self.to_kv(x) # (B, L+1, 2 * D) + q = self.to_q(x[:, 0]) # (B, D) + + # Extract heads. 
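+        # kv: (B, 1+L, 2*D) -> (B, H, 2, 1+L, head_dim); q: (B, D) -> (B, H, 1, head_dim).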
+ head_dim = D // self.num_heads + kv = kv.unflatten(2, (2, self.num_heads, head_dim)) # (B, 1+L, 2, H, head_dim) + kv = kv.transpose(1, 3) # (B, H, 2, 1+L, head_dim) + k, v = kv.unbind(2) # (B, H, 1+L, head_dim) + q = q.unflatten(1, (self.num_heads, head_dim)) # (B, H, head_dim) + q = q.unsqueeze(2) # (B, H, 1, head_dim) + + # Compute attention. + x = F.scaled_dot_product_attention( + q, k, v, attn_mask=attn_mask, dropout_p=0.0 + ) # (B, H, 1, head_dim) + + # Concatenate heads and run output. + x = x.squeeze(2).flatten(1, 2) # (B, D = H * head_dim) + x = self.to_out(x) + return x + + +class PadSplitXY(torch.autograd.Function): + """ + Merge heads, pad and extract visual and text tokens, + and split along the sequence length. + """ + + @staticmethod + def forward( + ctx, + xy: torch.Tensor, + indices: torch.Tensor, + B: int, + N: int, + L: int, + dtype: torch.dtype, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + xy: Packed tokens. Shape: (total <= B * (N + L), num_heads * head_dim). + indices: Valid token indices out of unpacked tensor. Shape: (total,) + + Returns: + x: Visual tokens. Shape: (B, N, num_heads * head_dim). + y: Text tokens. Shape: (B, L, num_heads * head_dim). + """ + ctx.save_for_backward(indices) + ctx.B, ctx.N, ctx.L = B, N, L + D = xy.size(1) + + # Pad sequences to (B, N + L, dim). + assert indices.ndim == 1 + output = torch.zeros(B * (N + L), D, device=xy.device, dtype=dtype) + indices = indices.unsqueeze(1).expand( + -1, D + ) # (total,) -> (total, num_heads * head_dim) + output.scatter_(0, indices, xy) + xy = output.view(B, N + L, D) + + # Split visual and text tokens along the sequence length. + return torch.tensor_split(xy, (N,), dim=1) + + +def pad_and_split_xy(xy, indices, B, N, L, dtype) -> Tuple[torch.Tensor, torch.Tensor]: + return PadSplitXY.apply(xy, indices, B, N, L, dtype) + + +class UnifyStreams(torch.autograd.Function): + """Unify visual and text streams.""" + + @staticmethod + def forward( + ctx, + q_x: torch.Tensor, + k_x: torch.Tensor, + v_x: torch.Tensor, + q_y: torch.Tensor, + k_y: torch.Tensor, + v_y: torch.Tensor, + indices: torch.Tensor, + ): + """ + Args: + q_x: (B, N, num_heads, head_dim) + k_x: (B, N, num_heads, head_dim) + v_x: (B, N, num_heads, head_dim) + q_y: (B, L, num_heads, head_dim) + k_y: (B, L, num_heads, head_dim) + v_y: (B, L, num_heads, head_dim) + indices: (total <= B * (N + L)) + + Returns: + qkv: (total <= B * (N + L), 3, num_heads, head_dim) + """ + ctx.save_for_backward(indices) + B, N, num_heads, head_dim = q_x.size() + ctx.B, ctx.N, ctx.L = B, N, q_y.size(1) + D = num_heads * head_dim + + q = torch.cat([q_x, q_y], dim=1) + k = torch.cat([k_x, k_y], dim=1) + v = torch.cat([v_x, v_y], dim=1) + qkv = torch.stack([q, k, v], dim=2).view(B * (N + ctx.L), 3, D) + + indices = indices[:, None, None].expand(-1, 3, D) + qkv = torch.gather(qkv, 0, indices) # (total, 3, num_heads * head_dim) + return qkv.unflatten(2, (num_heads, head_dim)) + + +def unify_streams(q_x, k_x, v_x, q_y, k_y, v_y, indices) -> torch.Tensor: + return UnifyStreams.apply(q_x, k_x, v_x, q_y, k_y, v_y, indices) diff --git a/mochi_preview/t2v_synth_mochi.py b/mochi_preview/t2v_synth_mochi.py new file mode 100644 index 0000000..066998f --- /dev/null +++ b/mochi_preview/t2v_synth_mochi.py @@ -0,0 +1,445 @@ +import json +import os +import random +from functools import partial +from typing import Dict, List + +from safetensors.torch import load_file +import numpy as np +import torch +import torch.distributed as dist +import torch.nn as nn +import 
torch.nn.functional as F +import torch.utils.data +import yaml +from einops import rearrange, repeat +from omegaconf import OmegaConf +from torch import nn +from torch.distributed.fsdp import ( + BackwardPrefetch, + MixedPrecision, + ShardingStrategy, +) +from torch.distributed.fsdp import ( + FullyShardedDataParallel as FSDP, +) +from torch.distributed.fsdp.wrap import ( + lambda_auto_wrap_policy, + transformer_auto_wrap_policy, +) +from transformers import T5EncoderModel, T5Tokenizer +from transformers.models.t5.modeling_t5 import T5Block + +from .dit.joint_model.context_parallel import get_cp_rank_size +from .utils import Timer +from tqdm import tqdm +from comfy.utils import ProgressBar + +T5_MODEL = "weights/T5" +MAX_T5_TOKEN_LENGTH = 256 + +class T5_Tokenizer: + """Wrapper around Hugging Face tokenizer for T5 + + Args: + model_name(str): Name of tokenizer to load. + """ + + def __init__(self): + self.tokenizer = T5Tokenizer.from_pretrained(T5_MODEL, legacy=False) + + def __call__(self, prompt, padding, truncation, return_tensors, max_length=None): + """ + Args: + prompt (str): The input text to tokenize. + padding (str): The padding strategy. + truncation (bool): Flag indicating whether to truncate the tokens. + return_tensors (str): Flag indicating whether to return tensors. + max_length (int): The max length of the tokens. + """ + assert ( + not max_length or max_length == MAX_T5_TOKEN_LENGTH + ), f"Max length must be {MAX_T5_TOKEN_LENGTH} for T5." + + tokenized_output = self.tokenizer( + prompt, + padding=padding, + max_length=MAX_T5_TOKEN_LENGTH, # Max token length for T5 is set here. + truncation=truncation, + return_tensors=return_tensors, + return_attention_mask=True, + ) + + return tokenized_output + + +def unnormalize_latents( + z: torch.Tensor, + mean: torch.Tensor, + std: torch.Tensor, +) -> torch.Tensor: + """Unnormalize latents. Useful for decoding DiT samples. + + Args: + z (torch.Tensor): [B, C_z, T_z, H_z, W_z], float + + Returns: + torch.Tensor: [B, C_z, T_z, H_z, W_z], float + """ + mean = mean[:, None, None, None] + std = std[:, None, None, None] + + assert z.ndim == 5 + assert z.size(1) == mean.size(0) == std.size(0) + return z * std.to(z) + mean.to(z) + + +def setup_fsdp_sync(model, device_id, *, param_dtype, auto_wrap_policy) -> FSDP: + model = FSDP( + model, + sharding_strategy=ShardingStrategy.FULL_SHARD, + mixed_precision=MixedPrecision( + param_dtype=param_dtype, + reduce_dtype=torch.float32, + buffer_dtype=torch.float32, + ), + auto_wrap_policy=auto_wrap_policy, + backward_prefetch=BackwardPrefetch.BACKWARD_PRE, + limit_all_gathers=True, + device_id=device_id, + sync_module_states=True, + use_orig_params=True, + ) + torch.cuda.synchronize() + return model + + +def compute_packed_indices( + N: int, + text_mask: List[torch.Tensor], +) -> Dict[str, torch.Tensor]: + """ + Based on https://github.com/Dao-AILab/flash-attention/blob/765741c1eeb86c96ee71a3291ad6968cfbf4e4a1/flash_attn/bert_padding.py#L60-L80 + + Args: + N: Number of visual tokens. + text_mask: (B, L) List of boolean tensor indicating which text tokens are not padding. + + Returns: + packed_indices: Dict with keys for Flash Attention: + - valid_token_indices_kv: up to (B * (N + L),) tensor of valid token indices (non-padding) + in the packed sequence. + - cu_seqlens_kv: (B + 1,) tensor of cumulative sequence lengths in the packed sequence. + - max_seqlen_in_batch_kv: int of the maximum sequence length in the batch. 
+ """ + # Create an expanded token mask saying which tokens are valid across both visual and text tokens. + assert N > 0 and len(text_mask) == 1 + text_mask = text_mask[0] + + mask = F.pad(text_mask, (N, 0), value=True) # (B, N + L) + seqlens_in_batch = mask.sum(dim=-1, dtype=torch.int32) # (B,) + valid_token_indices = torch.nonzero( + mask.flatten(), as_tuple=False + ).flatten() # up to (B * (N + L),) + + assert valid_token_indices.size(0) >= text_mask.size(0) * N # At least (B * N,) + cu_seqlens = F.pad( + torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0) + ) + max_seqlen_in_batch = seqlens_in_batch.max().item() + + return { + "cu_seqlens_kv": cu_seqlens, + "max_seqlen_in_batch_kv": max_seqlen_in_batch, + "valid_token_indices_kv": valid_token_indices, + } + + +def shift_sigma( + sigma: np.ndarray, + shift: float, +): + """Shift noise standard deviation toward higher values. + + Useful for training a model at high resolutions, + or sampling more finely at high noise levels. + + Equivalent to: + sigma_shift = shift / (shift + 1 / sigma - 1) + except for sigma = 0. + + Args: + sigma: noise standard deviation in [0, 1] + shift: shift factor >= 1. + For shift > 1, shifts sigma to higher values. + For shift = 1, identity function. + """ + return shift * sigma / (shift * sigma + 1 - sigma) + + +class T2VSynthMochiModel: + def __init__( + self, + *, + device_id: int, + vae_stats_path: str, + dit_checkpoint_path: str, + ): + super().__init__() + t = Timer() + self.device = torch.device(device_id) + + #self.t5_tokenizer = T5_Tokenizer() + + # with t("load_text_encs"): + # t5_enc = T5EncoderModel.from_pretrained(T5_MODEL) + # self.t5_enc = t5_enc.eval().to(torch.bfloat16).to("cpu") + + with t("construct_dit"): + from .dit.joint_model.asymm_models_joint import ( + AsymmDiTJoint, + ) + model: nn.Module = torch.nn.utils.skip_init( + AsymmDiTJoint, + depth=48, + patch_size=2, + num_heads=24, + hidden_size_x=3072, + hidden_size_y=1536, + mlp_ratio_x=4.0, + mlp_ratio_y=4.0, + in_channels=12, + qk_norm=True, + qkv_bias=False, + out_bias=True, + patch_embed_bias=True, + timestep_mlp_bias=True, + timestep_scale=1000.0, + t5_feat_dim=4096, + t5_token_length=256, + rope_theta=10000.0, + ) + with t("dit_load_checkpoint"): + + model.load_state_dict(load_file(dit_checkpoint_path)) + + with t("fsdp_dit"): + self.dit = model + self.dit.eval() + for name, param in self.dit.named_parameters(): + params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"} + if not any(keyword in name for keyword in params_to_keep): + param.data = param.data.to(torch.float8_e4m3fn) + else: + param.data = param.data.to(torch.bfloat16) + + + vae_stats = json.load(open(vae_stats_path)) + self.vae_mean = torch.Tensor(vae_stats["mean"]).to(self.device) + self.vae_std = torch.Tensor(vae_stats["std"]).to(self.device) + + t.print_stats() + + def get_conditioning(self, prompts, *, zero_last_n_prompts: int): + B = len(prompts) + print(f"Getting conditioning for {B} prompts") + assert ( + 0 <= zero_last_n_prompts <= B + ), f"zero_last_n_prompts should be between 0 and {B}, got {zero_last_n_prompts}" + tokenize_kwargs = dict( + prompt=prompts, + padding="max_length", + return_tensors="pt", + truncation=True, + ) + + t5_toks = self.t5_tokenizer(**tokenize_kwargs, max_length=MAX_T5_TOKEN_LENGTH) + caption_input_ids_t5 = t5_toks["input_ids"] + caption_attention_mask_t5 = t5_toks["attention_mask"].bool() + del t5_toks + + assert caption_input_ids_t5.shape == (B, MAX_T5_TOKEN_LENGTH) + assert 
caption_attention_mask_t5.shape == (B, MAX_T5_TOKEN_LENGTH) + + if zero_last_n_prompts > 0: + # Zero the last N prompts + caption_input_ids_t5[-zero_last_n_prompts:] = 0 + caption_attention_mask_t5[-zero_last_n_prompts:] = False + + caption_input_ids_t5 = caption_input_ids_t5.to(self.device, non_blocking=True) + caption_attention_mask_t5 = caption_attention_mask_t5.to( + self.device, non_blocking=True + ) + + y_mask = [caption_attention_mask_t5] + y_feat = [] + + self.t5_enc.to(self.device) + y_feat.append( + self.t5_enc( + caption_input_ids_t5, caption_attention_mask_t5 + ).last_hidden_state.detach().to(torch.float32) + ) + print(y_feat.shape) + print(y_feat[0]) + self.t5_enc.to("cpu") + # Sometimes returns a tensor, othertimes a tuple, not sure why + # See: https://huggingface.co/genmo/mochi-1-preview/discussions/3 + assert tuple(y_feat[-1].shape) == (B, MAX_T5_TOKEN_LENGTH, 4096) + return dict(y_mask=y_mask, y_feat=y_feat) + + def get_packed_indices(self, y_mask, *, lT, lW, lH): + patch_size = 2 + N = lT * lH * lW // (patch_size**2) + assert len(y_mask) == 1 + packed_indices = compute_packed_indices(N, y_mask) + self.move_to_device_(packed_indices) + return packed_indices + + def move_to_device_(self, sample): + if isinstance(sample, dict): + for key in sample.keys(): + if isinstance(sample[key], torch.Tensor): + sample[key] = sample[key].to(self.device, non_blocking=True) + + @torch.inference_mode(mode=True) + def run(self, args, stream_results): + random.seed(args["seed"]) + np.random.seed(args["seed"]) + torch.manual_seed(args["seed"]) + + generator = torch.Generator(device=self.device) + generator.manual_seed(args["seed"]) + + # assert ( + # len(args["prompt"]) == 1 + # ), f"Expected exactly one prompt, got {len(args['prompt'])}" + #prompt = args["prompt"][0] + #neg_prompt = args["negative_prompt"][0] if len(args["negative_prompt"]) else "" + B = 1 + + w = args["width"] + h = args["height"] + t = args["num_frames"] + batch_cfg = args["mochi_args"]["batch_cfg"] + sample_steps = args["mochi_args"]["num_inference_steps"] + cfg_schedule = args["mochi_args"].get("cfg_schedule") + assert ( + len(cfg_schedule) == sample_steps + ), f"cfg_schedule must have length {sample_steps}, got {len(cfg_schedule)}" + sigma_schedule = args["mochi_args"].get("sigma_schedule") + if sigma_schedule: + assert ( + len(sigma_schedule) == sample_steps + 1 + ), f"sigma_schedule must have length {sample_steps + 1}, got {len(sigma_schedule)}" + assert (t - 1) % 6 == 0, f"t - 1 must be divisible by 6, got {t - 1}" + + # if batch_cfg: + # sample_batched = self.get_conditioning( + # [prompt] + [neg_prompt], zero_last_n_prompts=B if neg_prompt == "" else 0 + # ) + # else: + # sample = self.get_conditioning([prompt], zero_last_n_prompts=0) + # sample_null = self.get_conditioning([neg_prompt] * B, zero_last_n_prompts=B if neg_prompt == "" else 0) + + spatial_downsample = 8 + temporal_downsample = 6 + latent_t = (t - 1) // temporal_downsample + 1 + latent_w, latent_h = w // spatial_downsample, h // spatial_downsample + + latent_dims = dict(lT=latent_t, lW=latent_w, lH=latent_h) + in_channels = 12 + z = torch.randn( + (B, in_channels, latent_t, latent_h, latent_w), + device=self.device, + generator=generator, + dtype=torch.float32, + ) + + # if batch_cfg: + # sample_batched["packed_indices"] = self.get_packed_indices( + # sample_batched["y_mask"], **latent_dims + # ) + # z = repeat(z, "b ... 
-> (repeat b) ...", repeat=2) + # else: + + sample = { + "y_mask": [args["positive_embeds"]["attention_mask"].to(self.device)], + "y_feat": [args["positive_embeds"]["embeds"].to(self.device)] + } + sample_null = { + "y_mask": [args["negative_embeds"]["attention_mask"].to(self.device)], + "y_feat": [args["negative_embeds"]["embeds"].to(self.device)] + } + + # print(sample["y_mask"]) + # print(type(sample["y_mask"])) + # print(sample["y_mask"][0].shape) + + # print(sample["y_feat"]) + # print(type(sample["y_feat"])) + # print(sample["y_feat"][0].shape) + + print(sample_null["y_mask"]) + print(type(sample_null["y_mask"])) + print(sample_null["y_mask"][0].shape) + + print(sample_null["y_feat"]) + print(type(sample_null["y_feat"])) + print(sample_null["y_feat"][0].shape) + + sample["packed_indices"] = self.get_packed_indices( + sample["y_mask"], **latent_dims + ) + sample_null["packed_indices"] = self.get_packed_indices( + sample_null["y_mask"], **latent_dims + ) + + def model_fn(*, z, sigma, cfg_scale): + #print("z", z.dtype, z.device) + #print("sigma", sigma.dtype, sigma.device) + self.dit.to(self.device) + # if batch_cfg: + # with torch.autocast("cuda", dtype=torch.bfloat16): + # out = self.dit(z, sigma, **sample_batched) + # out_cond, out_uncond = torch.chunk(out, chunks=2, dim=0) + #else: + + nonlocal sample, sample_null + with torch.autocast("cuda", dtype=torch.bfloat16): + out_cond = self.dit(z, sigma, **sample) + out_uncond = self.dit(z, sigma, **sample_null) + assert out_cond.shape == out_uncond.shape + + return out_uncond + cfg_scale * (out_cond - out_uncond), out_cond + + comfy_pbar = ProgressBar(sample_steps) + for i in tqdm(range(0, sample_steps), desc="Processing Samples", total=sample_steps): + sigma = sigma_schedule[i] + dsigma = sigma - sigma_schedule[i + 1] + + # `pred` estimates `z_0 - eps`. 
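The loop that follows combines the conditional and unconditional DiT outputs with classifier-free guidance and then takes an Euler-style step along the sigma schedule. The sketch below restates that update with a stand-in denoiser (`toy_dit` is hypothetical, and the schedule values are illustrative) so the arithmetic can be checked in isolation.

```python
import torch

def toy_dit(z, sigma, cond):
    # Stand-in for self.dit(z, sigma, **sample): any tensor-valued function works here.
    return -z * (1.0 if cond else 0.9)

sigma_schedule = [1.0, 0.75, 0.5, 0.25, 0.0]   # illustrative; the real schedule is linear-quadratic
cfg_schedule = [4.5] * (len(sigma_schedule) - 1)

z = torch.randn(1, 12, 4, 8, 8)                # (B, C, latent_t, latent_h, latent_w)
for i in range(len(cfg_schedule)):
    sigma, dsigma = sigma_schedule[i], sigma_schedule[i] - sigma_schedule[i + 1]
    out_cond = toy_dit(z, sigma, cond=True)
    out_uncond = toy_dit(z, sigma, cond=False)
    # Classifier-free guidance: push the prediction away from the unconditional output.
    pred = out_uncond + cfg_schedule[i] * (out_cond - out_uncond)
    # Euler-style step: `pred` estimates z_0 - eps, so z moves by dsigma * pred.
    z = z + dsigma * pred
```

Note that `dsigma` is positive because the schedule decreases from 1.0 toward 0.0, so `z` drifts toward the model's estimate of the clean latent.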
+ pred, output_cond = model_fn( + z=z, + sigma=torch.full( + [B] if not batch_cfg else [B * 2], sigma, device=z.device + ), + cfg_scale=cfg_schedule[i], + ) + pred = pred.to(z) + output_cond = output_cond.to(z) + + #if stream_results: + # yield i / sample_steps, None, False + z = z + dsigma * pred + comfy_pbar.update(1) + + cp_rank, cp_size = get_cp_rank_size() + if batch_cfg: + z = z[:B] + z = z.tensor_split(cp_size, dim=2)[cp_rank] # split along temporal dim + self.dit.to("cpu") + + samples = unnormalize_latents(z.float(), self.vae_mean, self.vae_std) + print("samples", samples.shape, samples.dtype, samples.device) + return samples diff --git a/mochi_preview/utils.py b/mochi_preview/utils.py new file mode 100644 index 0000000..8732472 --- /dev/null +++ b/mochi_preview/utils.py @@ -0,0 +1,33 @@ +import time + + +class Timer: + def __init__(self): + self.times = {} # Dictionary to store times per stage + + def __call__(self, name): + print(f"Timing {name}") + return self.TimerContextManager(self, name) + + def print_stats(self): + total_time = sum(self.times.values()) + # Print table header + print("{:<20} {:>10} {:>10}".format("Stage", "Time(s)", "Percent")) + for name, t in self.times.items(): + percent = (t / total_time) * 100 if total_time > 0 else 0 + print("{:<20} {:>10.2f} {:>9.2f}%".format(name, t, percent)) + + class TimerContextManager: + def __init__(self, outer, name): + self.outer = outer # Reference to the Timer instance + self.name = name + self.start_time = None + + def __enter__(self): + self.start_time = time.perf_counter() + return self + + def __exit__(self, exc_type, exc_value, traceback): + end_time = time.perf_counter() + elapsed = end_time - self.start_time + self.outer.times[self.name] = self.outer.times.get(self.name, 0) + elapsed diff --git a/mochi_preview/vae/__init__.py b/mochi_preview/vae/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mochi_preview/vae/__pycache__/__init__.cpython-311.pyc b/mochi_preview/vae/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..8c3e80b Binary files /dev/null and b/mochi_preview/vae/__pycache__/__init__.cpython-311.pyc differ diff --git a/mochi_preview/vae/__pycache__/__init__.cpython-312.pyc b/mochi_preview/vae/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000..b4aadf5 Binary files /dev/null and b/mochi_preview/vae/__pycache__/__init__.cpython-312.pyc differ diff --git a/mochi_preview/vae/__pycache__/cp_conv.cpython-311.pyc b/mochi_preview/vae/__pycache__/cp_conv.cpython-311.pyc new file mode 100644 index 0000000..0254bf9 Binary files /dev/null and b/mochi_preview/vae/__pycache__/cp_conv.cpython-311.pyc differ diff --git a/mochi_preview/vae/__pycache__/cp_conv.cpython-312.pyc b/mochi_preview/vae/__pycache__/cp_conv.cpython-312.pyc new file mode 100644 index 0000000..b4105c7 Binary files /dev/null and b/mochi_preview/vae/__pycache__/cp_conv.cpython-312.pyc differ diff --git a/mochi_preview/vae/__pycache__/model.cpython-311.pyc b/mochi_preview/vae/__pycache__/model.cpython-311.pyc new file mode 100644 index 0000000..fcead8c Binary files /dev/null and b/mochi_preview/vae/__pycache__/model.cpython-311.pyc differ diff --git a/mochi_preview/vae/__pycache__/model.cpython-312.pyc b/mochi_preview/vae/__pycache__/model.cpython-312.pyc new file mode 100644 index 0000000..b4fb4a5 Binary files /dev/null and b/mochi_preview/vae/__pycache__/model.cpython-312.pyc differ diff --git a/mochi_preview/vae/cp_conv.py b/mochi_preview/vae/cp_conv.py new file mode 100644 index 
0000000..e5e96de --- /dev/null +++ b/mochi_preview/vae/cp_conv.py @@ -0,0 +1,152 @@ +from typing import Tuple, Union + +import torch +import torch.distributed as dist +import torch.nn.functional as F + +from ..dit.joint_model.context_parallel import get_cp_group, get_cp_rank_size + + +def cast_tuple(t, length=1): + return t if isinstance(t, tuple) else ((t,) * length) + + +def cp_pass_frames(x: torch.Tensor, frames_to_send: int) -> torch.Tensor: + """ + Forward pass that handles communication between ranks for inference. + Args: + x: Tensor of shape (B, C, T, H, W) + frames_to_send: int, number of frames to communicate between ranks + Returns: + output: Tensor of shape (B, C, T', H, W) + """ + cp_rank, cp_world_size = cp.get_cp_rank_size() + if frames_to_send == 0 or cp_world_size == 1: + return x + + group = get_cp_group() + global_rank = dist.get_rank() + + # Send to next rank + if cp_rank < cp_world_size - 1: + assert x.size(2) >= frames_to_send + tail = x[:, :, -frames_to_send:].contiguous() + dist.send(tail, global_rank + 1, group=group) + + # Receive from previous rank + if cp_rank > 0: + B, C, _, H, W = x.shape + recv_buffer = torch.empty( + (B, C, frames_to_send, H, W), + dtype=x.dtype, + device=x.device, + ) + dist.recv(recv_buffer, global_rank - 1, group=group) + x = torch.cat([recv_buffer, x], dim=2) + + return x + + +def _pad_to_max(x: torch.Tensor, max_T: int) -> torch.Tensor: + if max_T > x.size(2): + pad_T = max_T - x.size(2) + pad_dims = (0, 0, 0, 0, 0, pad_T) + return F.pad(x, pad_dims) + return x + + +def gather_all_frames(x: torch.Tensor) -> torch.Tensor: + """ + Gathers all frames from all processes for inference. + Args: + x: Tensor of shape (B, C, T, H, W) + Returns: + output: Tensor of shape (B, C, T_total, H, W) + """ + cp_rank, cp_size = get_cp_rank_size() + cp_group = get_cp_group() + + # Ensure the tensor is contiguous for collective operations + x = x.contiguous() + + # Get the local time dimension size + local_T = x.size(2) + local_T_tensor = torch.tensor([local_T], device=x.device, dtype=torch.int64) + + # Gather all T sizes from all processes + all_T = [torch.zeros(1, dtype=torch.int64, device=x.device) for _ in range(cp_size)] + dist.all_gather(all_T, local_T_tensor, group=cp_group) + all_T = [t.item() for t in all_T] + + # Pad the tensor at the end of the time dimension to match max_T + max_T = max(all_T) + x = _pad_to_max(x, max_T).contiguous() + + # Prepare a list to hold the gathered tensors + gathered_x = [torch.zeros_like(x).contiguous() for _ in range(cp_size)] + + # Perform the all_gather operation + dist.all_gather(gathered_x, x, group=cp_group) + + # Slice each gathered tensor back to its original T size + for idx, t_size in enumerate(all_T): + gathered_x[idx] = gathered_x[idx][:, :, :t_size] + + return torch.cat(gathered_x, dim=2) + + +def excessive_memory_usage(input: torch.Tensor, max_gb: float = 2.0) -> bool: + """Estimate memory usage based on input tensor size and data type.""" + element_size = input.element_size() # Size in bytes of each element + memory_bytes = input.numel() * element_size + memory_gb = memory_bytes / 1024**3 + return memory_gb > max_gb + + +class ContextParallelCausalConv3d(torch.nn.Conv3d): + def __init__( + self, + in_channels, + out_channels, + kernel_size: Union[int, Tuple[int, int, int]], + stride: Union[int, Tuple[int, int, int]], + **kwargs, + ): + kernel_size = cast_tuple(kernel_size, 3) + stride = cast_tuple(stride, 3) + height_pad = (kernel_size[1] - 1) // 2 + width_pad = (kernel_size[2] - 1) // 2 + + 
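A minimal single-rank sketch (not part of the patch) of the causal temporal padding this convolution relies on: `kernel_size[0] - 1` frames are replicated at the front of the time axis, so output frame `t` only depends on input frames `<= t` and the time length is preserved.

```python
import torch
import torch.nn.functional as F

# Single-rank sketch: pad k - 1 frames at the front of the time axis, then run a plain
# Conv3d with no temporal padding. Spatial padding keeps H and W unchanged.
k_t = 3
conv = torch.nn.Conv3d(4, 4, kernel_size=(k_t, 3, 3), stride=1, padding=(0, 1, 1))

x = torch.randn(1, 4, 8, 16, 16)                       # (B, C, T, H, W)
x_padded = F.pad(x, (0, 0, 0, 0, k_t - 1, 0), mode="replicate")
y = conv(x_padded)
assert y.shape[2] == x.shape[2]                        # time length is preserved
```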
super().__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + dilation=(1, 1, 1), + padding=(0, height_pad, width_pad), + **kwargs, + ) + + def forward(self, x: torch.Tensor): + cp_rank, cp_world_size = get_cp_rank_size() + + context_size = self.kernel_size[0] - 1 + if cp_rank == 0: + mode = "constant" if self.padding_mode == "zeros" else self.padding_mode + x = F.pad(x, (0, 0, 0, 0, context_size, 0), mode=mode) + + if cp_world_size == 1: + return super().forward(x) + + if all(s == 1 for s in self.stride): + # Receive some frames from previous rank. + x = cp_pass_frames(x, context_size) + return super().forward(x) + + # Less efficient implementation for strided convs. + # All gather x, infer and chunk. + x = gather_all_frames(x) # [B, C, k - 1 + global_T, H, W] + x = super().forward(x) + x_chunks = x.tensor_split(cp_world_size, dim=2) + assert len(x_chunks) == cp_world_size + return x_chunks[cp_rank] diff --git a/mochi_preview/vae/model.py b/mochi_preview/vae/model.py new file mode 100644 index 0000000..1263271 --- /dev/null +++ b/mochi_preview/vae/model.py @@ -0,0 +1,815 @@ +from typing import Callable, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from ..dit.joint_model.context_parallel import get_cp_rank_size, local_shard +from ..vae.cp_conv import cp_pass_frames, gather_all_frames + + +def cast_tuple(t, length=1): + return t if isinstance(t, tuple) else ((t,) * length) + + +class GroupNormSpatial(nn.GroupNorm): + """ + GroupNorm applied per-frame. + """ + + def forward(self, x: torch.Tensor, *, chunk_size: int = 8): + B, C, T, H, W = x.shape + x = rearrange(x, "B C T H W -> (B T) C H W") + # Run group norm in chunks. + output = torch.empty_like(x) + for b in range(0, B * T, chunk_size): + output[b : b + chunk_size] = super().forward(x[b : b + chunk_size]) + return rearrange(output, "(B T) C H W -> B C T H W", B=B, T=T) + + +class SafeConv3d(torch.nn.Conv3d): + """ + NOTE: No support for padding along time dimension. + Input must already be padded along time. 
+ """ + + def forward(self, input): + memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3 + if memory_count > 2: + part_num = int(memory_count / 2) + 1 + k = self.kernel_size[0] + input_idx = torch.arange(k - 1, input.size(2)) + input_chunks_idx = torch.chunk(input_idx, part_num, dim=0) + + # assert self.kernel_size == (3, 3, 3), f"kernel_size {self.kernel_size} != (3, 3, 3)" + assert self.stride[0] == 1, f"stride {self.stride}" + assert self.dilation[0] == 1, f"dilation {self.dilation}" + assert self.padding[0] == 0, f"padding {self.padding}" + + # Comptue output size + assert not input.requires_grad + B, _, T_in, H_in, W_in = input.shape + output_size = ( + B, + self.out_channels, + T_in - k + 1, + H_in // self.stride[1], + W_in // self.stride[2], + ) + output = torch.empty(output_size, dtype=input.dtype, device=input.device) + for input_chunk_idx in input_chunks_idx: + input_s = input_chunk_idx[0] - k + 1 + input_e = input_chunk_idx[-1] + 1 + input_chunk = input[:, :, input_s:input_e, :, :] + output_chunk = super(SafeConv3d, self).forward(input_chunk) + + output_s = input_s + output_e = output_s + output_chunk.size(2) + output[:, :, output_s:output_e, :, :] = output_chunk + + return output + else: + return super(SafeConv3d, self).forward(input) + + +class StridedSafeConv3d(torch.nn.Conv3d): + def forward(self, input, local_shard: bool = False): + assert self.stride[0] == self.kernel_size[0] + assert self.dilation[0] == 1 + assert self.padding[0] == 0 + + kernel_size = self.kernel_size[0] + stride = self.stride[0] + T_in = input.size(2) + T_out = T_in // kernel_size + + # Parallel implementation. + if local_shard: + idx = torch.arange(T_out) + idx = local_shard(idx, dim=0) + start = idx.min() * stride + end = idx.max() * stride + kernel_size + local_input = input[:, :, start:end, :, :] + return torch.nn.Conv3d.forward(self, local_input) + + raise NotImplementedError + + +class ContextParallelConv3d(SafeConv3d): + def __init__( + self, + in_channels, + out_channels, + kernel_size: Union[int, Tuple[int, int, int]], + stride: Union[int, Tuple[int, int, int]], + causal: bool = True, + context_parallel: bool = True, + **kwargs, + ): + self.causal = causal + self.context_parallel = context_parallel + kernel_size = cast_tuple(kernel_size, 3) + stride = cast_tuple(stride, 3) + height_pad = (kernel_size[1] - 1) // 2 + width_pad = (kernel_size[2] - 1) // 2 + + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + dilation=(1, 1, 1), + padding=(0, height_pad, width_pad), + **kwargs, + ) + + def forward(self, x: torch.Tensor): + cp_rank, cp_world_size = get_cp_rank_size() + + # Compute padding amounts. + context_size = self.kernel_size[0] - 1 + if self.causal: + pad_front = context_size + pad_back = 0 + else: + pad_front = context_size // 2 + pad_back = context_size - pad_front + + # Apply padding. + assert self.padding_mode == "replicate" # DEBUG + mode = "constant" if self.padding_mode == "zeros" else self.padding_mode + if self.context_parallel and cp_world_size == 1: + x = F.pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode) + else: + if cp_rank == 0: + x = F.pad(x, (0, 0, 0, 0, pad_front, 0), mode=mode) + elif cp_rank == cp_world_size - 1 and pad_back: + x = F.pad(x, (0, 0, 0, 0, 0, pad_back), mode=mode) + + if self.context_parallel and cp_world_size == 1: + return super().forward(x) + + if self.stride[0] == 1: + # Receive some frames from previous rank. 
+ x = cp_pass_frames(x, context_size) + return super().forward(x) + + # Less efficient implementation for strided convs. + # All gather x, infer and chunk. + assert ( + x.dtype == torch.bfloat16 + ), f"Expected x to be of type torch.bfloat16, got {x.dtype}" + + x = gather_all_frames(x) # [B, C, k - 1 + global_T, H, W] + return StridedSafeConv3d.forward(self, x, local_shard=True) + + +class Conv1x1(nn.Linear): + """*1x1 Conv implemented with a linear layer.""" + + def __init__(self, in_features: int, out_features: int, *args, **kwargs): + super().__init__(in_features, out_features, *args, **kwargs) + + def forward(self, x: torch.Tensor): + """Forward pass. + + Args: + x: Input tensor. Shape: [B, C, *] or [B, *, C]. + + Returns: + x: Output tensor. Shape: [B, C', *] or [B, *, C']. + """ + x = x.movedim(1, -1) + x = super().forward(x) + x = x.movedim(-1, 1) + return x + + +class DepthToSpaceTime(nn.Module): + def __init__( + self, + temporal_expansion: int, + spatial_expansion: int, + ): + super().__init__() + self.temporal_expansion = temporal_expansion + self.spatial_expansion = spatial_expansion + + # When printed, this module should show the temporal and spatial expansion factors. + def extra_repr(self): + return f"texp={self.temporal_expansion}, sexp={self.spatial_expansion}" + + def forward(self, x: torch.Tensor): + """Forward pass. + + Args: + x: Input tensor. Shape: [B, C, T, H, W]. + + Returns: + x: Rearranged tensor. Shape: [B, C/(st*s*s), T*st, H*s, W*s]. + """ + x = rearrange( + x, + "B (C st sh sw) T H W -> B C (T st) (H sh) (W sw)", + st=self.temporal_expansion, + sh=self.spatial_expansion, + sw=self.spatial_expansion, + ) + + cp_rank, _ = get_cp_rank_size() + if self.temporal_expansion > 1 and cp_rank == 0: + # Drop the first self.temporal_expansion - 1 frames. + # This is because we always want the 3x3x3 conv filter to only apply + # to the first frame, and the first frame doesn't need to be repeated. + assert all(x.shape) + x = x[:, :, self.temporal_expansion - 1 :] + assert all(x.shape) + + return x + + +def norm_fn( + in_channels: int, + affine: bool = True, +): + return GroupNormSpatial(affine=affine, num_groups=32, num_channels=in_channels) + + +class ResBlock(nn.Module): + """Residual block that preserves the spatial dimensions.""" + + def __init__( + self, + channels: int, + *, + affine: bool = True, + attn_block: Optional[nn.Module] = None, + padding_mode: str = "replicate", + causal: bool = True, + ): + super().__init__() + self.channels = channels + + assert causal + self.stack = nn.Sequential( + norm_fn(channels, affine=affine), + nn.SiLU(inplace=True), + ContextParallelConv3d( + in_channels=channels, + out_channels=channels, + kernel_size=(3, 3, 3), + stride=(1, 1, 1), + padding_mode=padding_mode, + bias=True, + causal=causal, + ), + norm_fn(channels, affine=affine), + nn.SiLU(inplace=True), + ContextParallelConv3d( + in_channels=channels, + out_channels=channels, + kernel_size=(3, 3, 3), + stride=(1, 1, 1), + padding_mode=padding_mode, + bias=True, + causal=causal, + ), + ) + + self.attn_block = attn_block if attn_block else nn.Identity() + + def forward(self, x: torch.Tensor): + """Forward pass. + + Args: + x: Input tensor. Shape: [B, C, T, H, W]. + """ + residual = x + x = self.stack(x) + x = x + residual + del residual + + return self.attn_block(x) + + +def prepare_for_attention(qkv: torch.Tensor, head_dim: int, qk_norm: bool = True): + """Prepare qkv tensor for attention and normalize qk. + + Args: + qkv: Input tensor. Shape: [B, L, 3 * num_heads * head_dim]. 
+ + Returns: + q, k, v: qkv tensor split into q, k, v. Shape: [B, num_heads, L, head_dim]. + """ + assert qkv.ndim == 3 # [B, L, 3 * num_heads * head_dim] + assert qkv.size(2) % (3 * head_dim) == 0 + num_heads = qkv.size(2) // (3 * head_dim) + qkv = qkv.unflatten(2, (3, num_heads, head_dim)) + + q, k, v = qkv.unbind(2) # [B, L, num_heads, head_dim] + q = q.transpose(1, 2) # [B, num_heads, L, head_dim] + k = k.transpose(1, 2) # [B, num_heads, L, head_dim] + v = v.transpose(1, 2) # [B, num_heads, L, head_dim] + + if qk_norm: + q = F.normalize(q, p=2, dim=-1) + k = F.normalize(k, p=2, dim=-1) + + # Mixed precision can change the dtype of normed q/k to float32. + q = q.to(dtype=qkv.dtype) + k = k.to(dtype=qkv.dtype) + + return q, k, v + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + head_dim: int = 32, + qkv_bias: bool = False, + out_bias: bool = True, + qk_norm: bool = True, + ) -> None: + super().__init__() + self.head_dim = head_dim + self.num_heads = dim // head_dim + self.qk_norm = qk_norm + + self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias) + self.out = nn.Linear(dim, dim, bias=out_bias) + + def forward( + self, + x: torch.Tensor, + *, + chunk_size=2**15, + ) -> torch.Tensor: + """Compute temporal self-attention. + + Args: + x: Input tensor. Shape: [B, C, T, H, W]. + chunk_size: Chunk size for large tensors. + + Returns: + x: Output tensor. Shape: [B, C, T, H, W]. + """ + B, _, T, H, W = x.shape + + if T == 1: + # No attention for single frame. + x = x.movedim(1, -1) # [B, C, T, H, W] -> [B, T, H, W, C] + qkv = self.qkv(x) + _, _, x = qkv.chunk(3, dim=-1) # Throw away queries and keys. + x = self.out(x) + return x.movedim(-1, 1) # [B, T, H, W, C] -> [B, C, T, H, W] + + # 1D temporal attention. + x = rearrange(x, "B C t h w -> (B h w) t C") + qkv = self.qkv(x) + + # Input: qkv with shape [B, t, 3 * num_heads * head_dim] + # Output: x with shape [B, num_heads, t, head_dim] + q, k, v = prepare_for_attention(qkv, self.head_dim, qk_norm=self.qk_norm) + + attn_kwargs = dict( + attn_mask=None, + dropout_p=0.0, + is_causal=True, + scale=self.head_dim**-0.5, + ) + + if q.size(0) <= chunk_size: + x = F.scaled_dot_product_attention( + q, k, v, **attn_kwargs + ) # [B, num_heads, t, head_dim] + else: + # Evaluate in chunks to avoid `RuntimeError: CUDA error: invalid configuration argument.` + # Chunks of 2**16 and up cause an error. 
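A shape-only sketch (illustrative dimensions, separate from the patch) of what `prepare_for_attention` above does to the fused qkv projection, including the optional L2 normalization of q and k:

```python
import torch
import torch.nn.functional as F

# A fused qkv of shape [B, L, 3 * H * d] becomes q, k, v of shape [B, H, L, d],
# with q and k normalized along the head dimension when qk_norm is enabled.
B, L, H, d = 2, 16, 8, 32
qkv = torch.randn(B, L, 3 * H * d)

qkv = qkv.unflatten(2, (3, H, d))                   # [B, L, 3, H, d]
q, k, v = qkv.unbind(2)                             # each [B, L, H, d]
q, k, v = [t.transpose(1, 2) for t in (q, k, v)]    # each [B, H, L, d]
q, k = F.normalize(q, p=2, dim=-1), F.normalize(k, p=2, dim=-1)

assert q.shape == (B, H, L, d)
assert torch.allclose(q.norm(dim=-1), torch.ones(B, H, L), atol=1e-5)
```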
+ x = torch.empty_like(q) + for i in range(0, q.size(0), chunk_size): + qc = q[i : i + chunk_size] + kc = k[i : i + chunk_size] + vc = v[i : i + chunk_size] + chunk = F.scaled_dot_product_attention(qc, kc, vc, **attn_kwargs) + x[i : i + chunk_size].copy_(chunk) + + assert x.size(0) == q.size(0) + x = x.transpose(1, 2) # [B, t, num_heads, head_dim] + x = x.flatten(2) # [B, t, num_heads * head_dim] + + x = self.out(x) + x = rearrange(x, "(B h w) t C -> B C t h w", B=B, h=H, w=W) + return x + + +class AttentionBlock(nn.Module): + def __init__( + self, + dim: int, + **attn_kwargs, + ) -> None: + super().__init__() + self.norm = norm_fn(dim) + self.attn = Attention(dim, **attn_kwargs) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + self.attn(self.norm(x)) + + +class CausalUpsampleBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + num_res_blocks: int, + *, + temporal_expansion: int = 2, + spatial_expansion: int = 2, + **block_kwargs, + ): + super().__init__() + + blocks = [] + for _ in range(num_res_blocks): + blocks.append(block_fn(in_channels, **block_kwargs)) + self.blocks = nn.Sequential(*blocks) + + self.temporal_expansion = temporal_expansion + self.spatial_expansion = spatial_expansion + + # Change channels in the final convolution layer. + self.proj = Conv1x1( + in_channels, + out_channels * temporal_expansion * (spatial_expansion**2), + ) + + self.d2st = DepthToSpaceTime( + temporal_expansion=temporal_expansion, spatial_expansion=spatial_expansion + ) + + def forward(self, x): + x = self.blocks(x) + x = self.proj(x) + x = self.d2st(x) + return x + + +def block_fn(channels, *, has_attention: bool = False, **block_kwargs): + attn_block = AttentionBlock(channels) if has_attention else None + + return ResBlock( + channels, affine=True, attn_block=attn_block, **block_kwargs + ) + + +class DownsampleBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + num_res_blocks, + *, + temporal_reduction=2, + spatial_reduction=2, + **block_kwargs, + ): + """ + Downsample block for the VAE encoder. + + Args: + in_channels: Number of input channels. + out_channels: Number of output channels. + num_res_blocks: Number of residual blocks. + temporal_reduction: Temporal reduction factor. + spatial_reduction: Spatial reduction factor. + """ + super().__init__() + layers = [] + + # Change the channel count in the strided convolution. + # This lets the ResBlock have uniform channel count, + # as in ConvNeXt. + assert in_channels != out_channels + layers.append( + ContextParallelConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(temporal_reduction, spatial_reduction, spatial_reduction), + stride=(temporal_reduction, spatial_reduction, spatial_reduction), + padding_mode="replicate", + bias=True, + ) + ) + + for _ in range(num_res_blocks): + layers.append(block_fn(out_channels, **block_kwargs)) + + self.layers = nn.Sequential(*layers) + + def forward(self, x): + return self.layers(x) + + +def add_fourier_features(inputs: torch.Tensor, start=6, stop=8, step=1): + num_freqs = (stop - start) // step + assert inputs.ndim == 5 + C = inputs.size(1) + + # Create Base 2 Fourier features. + freqs = torch.arange(start, stop, step, dtype=inputs.dtype, device=inputs.device) + assert num_freqs == len(freqs) + w = torch.pow(2.0, freqs) * (2 * torch.pi) # [num_freqs] + C = inputs.shape[1] + w = w.repeat(C)[None, :, None, None, None] # [1, C * num_freqs, 1, 1, 1] + + # Interleaved repeat of input channels to match w. 
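To make the upsampling path above concrete, a small sketch (separate from the patch, illustrative sizes) of how `CausalUpsampleBlock` expands resolution: `Conv1x1` multiplies the channel count by `temporal_expansion * spatial_expansion**2`, and `DepthToSpaceTime` folds those factors back into the time and space axes.

```python
import torch
from einops import rearrange

B, C, T, H, W = 1, 64, 5, 8, 8
st, s = 2, 2                                   # temporal_expansion, spatial_expansion

x = torch.randn(B, C * st * s * s, T, H, W)    # output of self.proj
y = rearrange(x, "B (C st sh sw) T H W -> B C (T st) (H sh) (W sw)", st=st, sh=s, sw=s)
assert y.shape == (B, C, T * st, H * s, W * s)

# On the first context-parallel rank the block then drops the first st - 1 frames,
# keeping the causal "first frame stands alone" structure of the reconstructed clip.
y = y[:, :, st - 1 :]
assert y.shape[2] == T * st - (st - 1)
```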
+ h = inputs.repeat_interleave(num_freqs, dim=1) # [B, C * num_freqs, T, H, W] + # Scale channels by frequency. + h = w * h + + return torch.cat( + [ + inputs, + torch.sin(h), + torch.cos(h), + ], + dim=1, + ) + + +class FourierFeatures(nn.Module): + def __init__(self, start: int = 6, stop: int = 8, step: int = 1): + super().__init__() + self.start = start + self.stop = stop + self.step = step + + def forward(self, inputs): + """Add Fourier features to inputs. + + Args: + inputs: Input tensor. Shape: [B, C, T, H, W] + + Returns: + h: Output tensor. Shape: [B, (1 + 2 * num_freqs) * C, T, H, W] + """ + return add_fourier_features(inputs, self.start, self.stop, self.step) + + +class Decoder(nn.Module): + def __init__( + self, + *, + out_channels: int = 3, + latent_dim: int, + base_channels: int, + channel_multipliers: List[int], + num_res_blocks: List[int], + temporal_expansions: Optional[List[int]] = None, + spatial_expansions: Optional[List[int]] = None, + has_attention: List[bool], + output_norm: bool = True, + nonlinearity: str = "silu", + output_nonlinearity: str = "silu", + causal: bool = True, + **block_kwargs, + ): + super().__init__() + self.input_channels = latent_dim + self.base_channels = base_channels + self.channel_multipliers = channel_multipliers + self.num_res_blocks = num_res_blocks + self.output_nonlinearity = output_nonlinearity + assert nonlinearity == "silu" + assert causal + + ch = [mult * base_channels for mult in channel_multipliers] + self.num_up_blocks = len(ch) - 1 + assert len(num_res_blocks) == self.num_up_blocks + 2 + + blocks = [] + + first_block = [ + nn.Conv3d(latent_dim, ch[-1], kernel_size=(1, 1, 1)) + ] # Input layer. + # First set of blocks preserve channel count. + for _ in range(num_res_blocks[-1]): + first_block.append( + block_fn( + ch[-1], + has_attention=has_attention[-1], + causal=causal, + **block_kwargs, + ) + ) + blocks.append(nn.Sequential(*first_block)) + + assert len(temporal_expansions) == len(spatial_expansions) == self.num_up_blocks + assert len(num_res_blocks) == len(has_attention) == self.num_up_blocks + 2 + + upsample_block_fn = CausalUpsampleBlock + + for i in range(self.num_up_blocks): + block = upsample_block_fn( + ch[-i - 1], + ch[-i - 2], + num_res_blocks=num_res_blocks[-i - 2], + has_attention=has_attention[-i - 2], + temporal_expansion=temporal_expansions[-i - 1], + spatial_expansion=spatial_expansions[-i - 1], + causal=causal, + **block_kwargs, + ) + blocks.append(block) + + assert not output_norm + + # Last block. Preserve channel count. + last_block = [] + for _ in range(num_res_blocks[0]): + last_block.append( + block_fn( + ch[0], has_attention=has_attention[0], causal=causal, **block_kwargs + ) + ) + blocks.append(nn.Sequential(*last_block)) + + self.blocks = nn.ModuleList(blocks) + self.output_proj = Conv1x1(ch[0], out_channels) + + def forward(self, x): + """Forward pass. + + Args: + x: Latent tensor. Shape: [B, input_channels, t, h, w]. Scaled [-1, 1]. + + Returns: + x: Reconstructed video tensor. Shape: [B, C, T, H, W]. Scaled to [-1, 1]. + T + 1 = (t - 1) * 4. + H = h * 16, W = w * 16. + """ + for block in self.blocks: + x = block(x) + + if self.output_nonlinearity == "silu": + x = F.silu(x, inplace=not self.training) + else: + assert ( + not self.output_nonlinearity + ) # StyleGAN3 omits the to-RGB nonlinearity. 
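A quick check (separate from the patch) of the channel bookkeeping in `add_fourier_features` / `FourierFeatures` above: with the defaults `start=6, stop=8, step=1` there are two frequencies, so the output has `(1 + 2 * num_freqs) * C` channels.

```python
import torch

start, stop, step = 6, 8, 1
num_freqs = (stop - start) // step

x = torch.randn(1, 3, 4, 8, 8)                     # (B, C, T, H, W)
freqs = torch.arange(start, stop, step, dtype=x.dtype)
w = (torch.pow(2.0, freqs) * 2 * torch.pi).repeat(x.size(1))[None, :, None, None, None]
h = w * x.repeat_interleave(num_freqs, dim=1)      # scale each channel copy by its frequency
out = torch.cat([x, torch.sin(h), torch.cos(h)], dim=1)
assert out.size(1) == (1 + 2 * num_freqs) * x.size(1)   # 3 -> 15 channels
```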
+ + return self.output_proj(x).contiguous() + + +def make_broadcastable( + tensor: torch.Tensor, + axis: int, + ndim: int, +) -> torch.Tensor: + """ + Reshapes the input tensor to have singleton dimensions in all axes except the specified axis. + + Args: + tensor (torch.Tensor): The tensor to reshape. Typically 1D. + axis (int): The axis along which the tensor should retain its original size. + ndim (int): The total number of dimensions the reshaped tensor should have. + + Returns: + torch.Tensor: The reshaped tensor with shape suitable for broadcasting. + """ + if tensor.dim() != 1: + raise ValueError(f"Expected tensor to be 1D, but got {tensor.dim()}D tensor.") + + axis = (axis + ndim) % ndim # Ensure the axis is within the tensor dimensions + shape = [1] * ndim # Start with all dimensions as 1 + shape[axis] = tensor.size(0) # Set the specified axis to the size of the tensor + return tensor.view(*shape) + + +def blend(a: torch.Tensor, b: torch.Tensor, axis: int) -> torch.Tensor: + """ + Blends two tensors `a` and `b` along the specified axis using linear interpolation. + + Args: + a (torch.Tensor): The first tensor. + b (torch.Tensor): The second tensor. Must have the same shape as `a`. + axis (int): The axis along which to perform the blending. + + Returns: + torch.Tensor: The blended tensor. + """ + assert ( + a.shape == b.shape + ), f"Tensors must have the same shape, got {a.shape} and {b.shape}" + steps = a.size(axis) + + # Create a weight tensor that linearly interpolates from 0 to 1 + start = 1 / (steps + 1) + end = steps / (steps + 1) + weight = torch.linspace(start, end, steps=steps, device=a.device, dtype=a.dtype) + + # Make the weight tensor broadcastable across all dimensions + weight = make_broadcastable(weight, axis, a.dim()) + + # Perform the blending + return a * (1 - weight) + b * weight + + +def blend_horizontal(a: torch.Tensor, b: torch.Tensor, overlap: int) -> torch.Tensor: + if overlap == 0: + return torch.cat([a, b], dim=-1) + + assert a.size(-1) >= overlap + assert b.size(-1) >= overlap + a_left, a_overlap = a[..., :-overlap], a[..., -overlap:] + b_overlap, b_right = b[..., :overlap], b[..., overlap:] + return torch.cat([a_left, blend(a_overlap, b_overlap, -1), b_right], dim=-1) + + +def blend_vertical(a: torch.Tensor, b: torch.Tensor, overlap: int) -> torch.Tensor: + if overlap == 0: + return torch.cat([a, b], dim=-2) + + assert a.size(-2) >= overlap + assert b.size(-2) >= overlap + a_top, a_overlap = a[..., :-overlap, :], a[..., -overlap:, :] + b_overlap, b_bottom = b[..., :overlap, :], b[..., overlap:, :] + return torch.cat([a_top, blend(a_overlap, b_overlap, -2), b_bottom], dim=-2) + + +def nearest_multiple(x: int, multiple: int) -> int: + return round(x / multiple) * multiple + + +def apply_tiled( + fn: Callable[[torch.Tensor], torch.Tensor], + x: torch.Tensor, + num_tiles_w: int, + num_tiles_h: int, + overlap: int = 0, # Number of pixel of overlap between adjacent tiles. + # Use a factor of 2 times the latent downsample factor. + min_block_size: int = 1, # Minimum number of pixels in each dimension when subdividing. 
+): + if num_tiles_w == 1 and num_tiles_h == 1: + return fn(x) + + assert ( + num_tiles_w & (num_tiles_w - 1) == 0 + ), f"num_tiles_w={num_tiles_w} must be a power of 2" + assert ( + num_tiles_h & (num_tiles_h - 1) == 0 + ), f"num_tiles_h={num_tiles_h} must be a power of 2" + + H, W = x.shape[-2:] + assert H % min_block_size == 0 + assert W % min_block_size == 0 + ov = overlap // 2 + assert ov % min_block_size == 0 + + if num_tiles_w >= 2: + # Subdivide horizontally. + half_W = nearest_multiple(W // 2, min_block_size) + left = x[..., :, : half_W + ov] + right = x[..., :, half_W - ov :] + + assert num_tiles_w % 2 == 0, f"num_tiles_w={num_tiles_w} must be even" + left = apply_tiled( + fn, left, num_tiles_w // 2, num_tiles_h, overlap, min_block_size + ) + right = apply_tiled( + fn, right, num_tiles_w // 2, num_tiles_h, overlap, min_block_size + ) + if left is None or right is None: + return None + + # If `fn` changed the resolution, adjust the overlap. + resample_factor = left.size(-1) / (half_W + ov) + out_overlap = int(overlap * resample_factor) + + return blend_horizontal(left, right, out_overlap) + + if num_tiles_h >= 2: + # Subdivide vertically. + half_H = nearest_multiple(H // 2, min_block_size) + top = x[..., : half_H + ov, :] + bottom = x[..., half_H - ov :, :] + + assert num_tiles_h % 2 == 0, f"num_tiles_h={num_tiles_h} must be even" + top = apply_tiled( + fn, top, num_tiles_w, num_tiles_h // 2, overlap, min_block_size + ) + bottom = apply_tiled( + fn, bottom, num_tiles_w, num_tiles_h // 2, overlap, min_block_size + ) + if top is None or bottom is None: + return None + + # If `fn` changed the resolution, adjust the overlap. + resample_factor = top.size(-2) / (half_H + ov) + out_overlap = int(overlap * resample_factor) + + return blend_vertical(top, bottom, out_overlap) + + raise ValueError(f"Invalid num_tiles_w={num_tiles_w} and num_tiles_h={num_tiles_h}") diff --git a/nodes.py b/nodes.py new file mode 100644 index 0000000..e6c5459 --- /dev/null +++ b/nodes.py @@ -0,0 +1,356 @@ +import os +import torch +import torch.nn as nn +import folder_paths +import comfy.model_management as mm +from comfy.utils import ProgressBar, load_torch_file +from einops import rearrange + +from contextlib import nullcontext + +from PIL import Image +import numpy as np +import json + +import logging +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +log = logging.getLogger(__name__) + +from .mochi_preview.t2v_synth_mochi import T2VSynthMochiModel +from .mochi_preview.vae.model import Decoder + +script_directory = os.path.dirname(os.path.abspath(__file__)) + +def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None): + if linear_steps is None: + linear_steps = num_steps // 2 + linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)] + threshold_noise_step_diff = linear_steps - threshold_noise * num_steps + quadratic_steps = num_steps - linear_steps + quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps ** 2) + linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps ** 2) + const = quadratic_coef * (linear_steps ** 2) + quadratic_sigma_schedule = [ + quadratic_coef * (i ** 2) + linear_coef * i + const + for i in range(linear_steps, num_steps) + ] + sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0] + sigma_schedule = [1.0 - x for x in sigma_schedule] + return sigma_schedule + +class DownloadAndLoadMochiModel: + @classmethod + def 
INPUT_TYPES(s): + return { + "required": { + "model": ( + [ + "mochi_preview_dit_fp8_e4m3fn.safetensors", + ], + ), + "vae": ( + [ + "mochi_preview_vae_bf16.safetensors", + ], + ), + "precision": (["fp8_e4m3fn","fp16", "fp32", "bf16"], + {"default": "bf16", "tooltip": "official recommendation is that 2b model should be fp16, 5b model should be bf16"} + ), + }, + } + + RETURN_TYPES = ("MOCHIMODEL", "MOCHIVAE",) + RETURN_NAMES = ("mochi_model", "mochi_vae" ) + FUNCTION = "loadmodel" + CATEGORY = "MochiWrapper" + DESCRIPTION = "Downloads and loads the selected Mochi model from Huggingface" + + def loadmodel(self, model, vae, precision): + + device = mm.get_torch_device() + offload_device = mm.unet_offload_device() + mm.soft_empty_cache() + + dtype = {"fp8_e4m3fn": torch.float8_e4m3fn, "bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision] + + # Transformer model + model_download_path = os.path.join(folder_paths.models_dir, 'diffusion_models', 'mochi') + model_path = os.path.join(model_download_path, model) + + repo_id = "kijai/Mochi_preview_comfy" + + if not os.path.exists(model_path): + log.info(f"Downloading mochi model to: {model_path}") + from huggingface_hub import snapshot_download + snapshot_download( + repo_id=repo_id, + allow_patterns=[f"*{model}*"], + local_dir=model_download_path, + local_dir_use_symlinks=False, + ) + # VAE + vae_download_path = os.path.join(folder_paths.models_dir, 'vae', 'mochi') + vae_path = os.path.join(vae_download_path, vae) + + if not os.path.exists(vae_path): + log.info(f"Downloading mochi VAE to: {vae_path}") + from huggingface_hub import snapshot_download + snapshot_download( + repo_id=repo_id, + allow_patterns=[f"*{vae}*"], + local_dir=model_download_path, + local_dir_use_symlinks=False, + ) + + model = T2VSynthMochiModel( + device_id=0, + vae_stats_path=os.path.join(script_directory, "configs", "vae_stats.json"), + dit_checkpoint_path=model_path, + ) + vae = Decoder( + out_channels=3, + base_channels=128, + channel_multipliers=[1, 2, 4, 6], + temporal_expansions=[1, 2, 3], + spatial_expansions=[2, 2, 2], + num_res_blocks=[3, 3, 4, 6, 3], + latent_dim=12, + has_attention=[False, False, False, False, False], + padding_mode="replicate", + output_norm=False, + nonlinearity="silu", + output_nonlinearity="silu", + causal=True, + ) + decoder_sd = load_torch_file(vae_path) + vae.load_state_dict(decoder_sd, strict=True) + vae.eval().to(torch.bfloat16).to("cpu") + del decoder_sd + + return (model, vae,) + +class MochiTextEncode: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "clip": ("CLIP",), + "prompt": ("STRING", {"default": "", "multiline": True} ), + }, + "optional": { + "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "force_offload": ("BOOLEAN", {"default": True}), + } + } + + RETURN_TYPES = ("CONDITIONING",) + RETURN_NAMES = ("conditioning",) + FUNCTION = "process" + CATEGORY = "MochiWrapper" + + def process(self, clip, prompt, strength=1.0, force_offload=True): + max_tokens = 256 + load_device = mm.text_encoder_device() + offload_device = mm.text_encoder_offload_device() + #print(clip.tokenizer.t5xxl) + clip.tokenizer.t5xxl.pad_to_max_length = True + clip.tokenizer.t5xxl.max_length = max_tokens + clip.cond_stage_model.t5xxl.return_attention_masks = True + clip.cond_stage_model.t5xxl.enable_attention_masks = True + clip.cond_stage_model.t5_attention_mask = True + clip.cond_stage_model.to(load_device) + tokens = clip.tokenizer.t5xxl.tokenize_with_weights(prompt, 
return_word_ids=True) + + embeds, _, attention_mask = clip.cond_stage_model.t5xxl.encode_token_weights(tokens) + + + if embeds.shape[1] > 256: + raise ValueError(f"Prompt is too long, max tokens supported is {max_tokens} or less, got {embeds.shape[1]}") + embeds *= strength + if force_offload: + clip.cond_stage_model.to(offload_device) + + t5_embeds = { + "embeds": embeds, + "attention_mask": attention_mask["attention_mask"].bool(), + } + return (t5_embeds, ) + + +class MochiSampler: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ("MOCHIMODEL",), + "positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "width": ("INT", {"default": 848, "min": 128, "max": 2048, "step": 8}), + "height": ("INT", {"default": 480, "min": 128, "max": 2048, "step": 8}), + "num_frames": ("INT", {"default": 49, "min": 7, "max": 1024, "step": 6}), + "steps": ("INT", {"default": 50, "min": 2}), + "cfg": ("FLOAT", {"default": 4.5, "min": 0.0, "max": 30.0, "step": 0.01}), + "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), + }, + } + + RETURN_TYPES = ("LATENT",) + RETURN_NAMES = ("model", "samples",) + FUNCTION = "process" + CATEGORY = "MochiWrapper" + + def process(self, model, positive, negative, steps, cfg, seed, height, width, num_frames): + mm.soft_empty_cache() + + device = mm.get_torch_device() + offload_device = mm.unet_offload_device() + + args = { + "height": height, + "width": width, + "num_frames": num_frames, + "mochi_args": { + "sigma_schedule": linear_quadratic_schedule(steps, 0.025), + "cfg_schedule": [cfg] * steps, + "num_inference_steps": steps, + "batch_cfg": False, + }, + "positive_embeds": positive, + "negative_embeds": negative, + "seed": seed, + } + latents = model.run(args, stream_results=False) + + mm.soft_empty_cache() + + return ({"samples": latents},) + +class MochiDecode: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "vae": ("MOCHIVAE",), + "samples": ("LATENT", ), + "enable_vae_tiling": ("BOOLEAN", {"default": False, "tooltip": "Drastically reduces memory use but may introduce seams"}), + }, + "optional": { + "tile_sample_min_height": ("INT", {"default": 240, "min": 16, "max": 2048, "step": 8, "tooltip": "Minimum tile height, default is half the height"}), + "tile_sample_min_width": ("INT", {"default": 424, "min": 16, "max": 2048, "step": 8, "tooltip": "Minimum tile width, default is half the width"}), + "tile_overlap_factor_height": ("FLOAT", {"default": 0.1666, "min": 0.0, "max": 1.0, "step": 0.001}), + "tile_overlap_factor_width": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001}), + "auto_tile_size": ("BOOLEAN", {"default": True, "tooltip": "Auto size based on height and width, default is half the size"}), + } + } + + RETURN_TYPES = ("IMAGE",) + RETURN_NAMES = ("images",) + FUNCTION = "decode" + CATEGORY = "MochiWrapper" + + def decode(self, vae, samples, enable_vae_tiling, tile_sample_min_height, tile_sample_min_width, tile_overlap_factor_height, tile_overlap_factor_width, auto_tile_size=True): + device = mm.get_torch_device() + offload_device = mm.unet_offload_device() + samples = samples["samples"] + + def blend_v(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( + y / blend_extent + ) + return b + + def blend_h(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: 
+ blend_extent = min(a.shape[4], b.shape[4], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( + x / blend_extent + ) + return b + + self.tile_overlap_factor_height = tile_overlap_factor_height if not auto_tile_size else 1 / 6 + self.tile_overlap_factor_width = tile_overlap_factor_width if not auto_tile_size else 1 / 5 + #7, 13, 19, 25, 31, 37, 43, 49, 55, 61, 67, 73, 79, 85, 91, 97, 103, 109, 115, 121, 127, 133, 139, 145, 151, 157, 163, 169, 175, 181, 187, 193, 199 + self.num_latent_frames_batch_size = 6 + + self.tile_sample_min_height = tile_sample_min_height if not auto_tile_size else samples.shape[3] // 2 * 8 + self.tile_sample_min_width = tile_sample_min_width if not auto_tile_size else samples.shape[4] // 2 * 8 + + self.tile_latent_min_height = int(self.tile_sample_min_height / 8) + self.tile_latent_min_width = int(self.tile_sample_min_width / 8) + + vae.to(device) + with torch.amp.autocast("cuda", dtype=torch.bfloat16): + if not enable_vae_tiling: + samples = vae(samples) + else: + batch_size, num_channels, num_frames, height, width = samples.shape + overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height)) + overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width)) + blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height) + blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width) + row_limit_height = self.tile_sample_min_height - blend_extent_height + row_limit_width = self.tile_sample_min_width - blend_extent_width + frame_batch_size = self.num_latent_frames_batch_size + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. 
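The tiled decode below stitches tiles together with `blend_v` / `blend_h`, which linearly cross-fade the overlap region. A 1-D sketch of that cross-fade (separate from the patch):

```python
import torch

overlap = 4
a = torch.ones(10)          # trailing edge of the tile above / to the left
b = torch.zeros(10)         # leading edge of the current tile

blended = b.clone()
for x in range(overlap):
    # Outgoing tile ramps from weight 1 down to 0, incoming tile ramps from 0 up to 1.
    blended[x] = a[-overlap + x] * (1 - x / overlap) + b[x] * (x / overlap)

print(blended[:overlap])    # tensor([1.0000, 0.7500, 0.5000, 0.2500])
```

The opposing weight ramps across the seam are what hide the tile boundaries in the decoded frames.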
+ rows = [] + for i in range(0, height, overlap_height): + row = [] + for j in range(0, width, overlap_width): + time = [] + for k in range(num_frames // frame_batch_size): + remaining_frames = num_frames % frame_batch_size + start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames) + end_frame = frame_batch_size * (k + 1) + remaining_frames + tile = samples[ + :, + :, + start_frame:end_frame, + i : i + self.tile_latent_min_height, + j : j + self.tile_latent_min_width, + ] + tile = vae(tile) + time.append(tile) + row.append(torch.cat(time, dim=2)) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = blend_v(rows[i - 1][j], tile, blend_extent_height) + if j > 0: + tile = blend_h(row[j - 1], tile, blend_extent_width) + result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width]) + result_rows.append(torch.cat(result_row, dim=4)) + + samples = torch.cat(result_rows, dim=3) + vae.to(offload_device) + #print("samples", samples.shape, samples.dtype, samples.device) + + samples = samples.float() + samples = (samples + 1.0) / 2.0 + samples.clamp_(0.0, 1.0) + + frames = rearrange(samples, "b c t h w -> (t b) h w c").cpu().float() + #print(frames.shape) + + return (frames,) + + +NODE_CLASS_MAPPINGS = { + "DownloadAndLoadMochiModel": DownloadAndLoadMochiModel, + "MochiSampler": MochiSampler, + "MochiDecode": MochiDecode, + "MochiTextEncode": MochiTextEncode, +} +NODE_DISPLAY_NAME_MAPPINGS = { + "DownloadAndLoadMochiModel": "(Down)load Mochi Model", + "MochiSampler": "Mochi Sampler", + "MochiDecode": "Mochi Decode", + "MochiTextEncode": "Mochi TextEncode", + } diff --git a/readme.md b/readme.md new file mode 100644 index 0000000..549eee6 --- /dev/null +++ b/readme.md @@ -0,0 +1,18 @@ +# ComfyUI wrapper nodes for Mochi video gen: https://github.com/genmoai/models + + +# WORK IN PROGRESS + +## Requires flash_attn ! + +Depending on frame count can fit under 20GB, VAE decoding is heavy and there is experimental tiled decoder (taken from CogVideoX -diffusers code) which allows higher frame counts, so far highest I've done is 97 with the default tile size 2x2 grid. + +Models: + +https://huggingface.co/Kijai/Mochi_preview_comfy/tree/main + +model to: `ComfyUI/models/diffusion_models/mochi` + +vae to: `ComfyUI/models/vae/mochi` + +There is autodownload node (also will be normal loader node) \ No newline at end of file
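For reference, a small sketch of the size bookkeeping implied by the sampler above and the frame counts mentioned in this README: frame counts must satisfy `(num_frames - 1) % 6 == 0` (7, 13, ..., 97, ...), and latents are 8x smaller spatially and roughly 6x shorter temporally. The `latent_shape` helper is hypothetical, shown only to illustrate the arithmetic.

```python
# Size bookkeeping used by MochiSampler / T2VSynthMochiModel.run (sketch only).
def latent_shape(width: int, height: int, num_frames: int, channels: int = 12):
    assert (num_frames - 1) % 6 == 0, "num_frames must be 6k + 1 (7, 13, ..., 97, ...)"
    return (channels, (num_frames - 1) // 6 + 1, height // 8, width // 8)

print(latent_shape(848, 480, 49))           # (12, 9, 60, 106)
print([t for t in range(7, 100, 6)][-3:])   # a few valid frame counts: [85, 91, 97]
```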