initial commit

kijai 2024-10-23 15:34:22 +03:00
parent 6fa487d3b9
commit b80cb4a691
52 changed files with 3416 additions and 0 deletions

3
__init__.py Normal file

@@ -0,0 +1,3 @@
from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]

(binary files not shown)

4
configs/vae_stats.json Normal file

@@ -0,0 +1,4 @@
{
"mean": [-0.06730895953510081, -0.038011381506090416, -0.07477820912866141, -0.05565264470995561, 0.012767231469026969, -0.04703542746246419, 0.043896967884726704, -0.09346305707025976, -0.09918314763016893, -0.008729793427399178, -0.011931556316503654, -0.0321993391887285],
"std": [0.9263795028493863, 0.9248894543193766, 0.9393059390890617, 0.959253732819592, 0.8244560132752793, 0.917259975397747, 0.9294154431013696, 1.3720942357788521, 0.881393668867029, 0.9168315692124348, 0.9185249279345552, 0.9274757570805041]
}

213
infer.py Normal file

@@ -0,0 +1,213 @@
import json
import os
import tempfile
import time
import click
import numpy as np
#import ray
from einops import rearrange
from PIL import Image
from tqdm import tqdm
from mochi_preview.t2v_synth_mochi import T2VSynthMochiModel
model = None
model_path = "weights"
def noexcept(f):
try:
return f()
except:
pass
# class MochiWrapper:
# def __init__(self, *, num_workers, **actor_kwargs):
# super().__init__()
# RemoteClass = ray.remote(T2VSynthMochiModel)
# self.workers = [
# RemoteClass.options(num_gpus=1).remote(
# device_id=0, world_size=num_workers, local_rank=i, **actor_kwargs
# )
# for i in range(num_workers)
# ]
# # Ensure the __init__ method has finished on all workers
# for worker in self.workers:
# ray.get(worker.__ray_ready__.remote())
# self.is_loaded = True
# def __call__(self, args):
# work_refs = [
# worker.run.remote(args, i == 0) for i, worker in enumerate(self.workers)
# ]
# try:
# for result in work_refs[0]:
# yield ray.get(result)
# # Handle the (very unlikely) edge-case where a worker that's not the 1st one
# # fails (don't want an uncaught error)
# for result in work_refs[1:]:
# ray.get(result)
# except Exception as e:
# # Get exception from other workers
# for ref in work_refs[1:]:
# noexcept(lambda: ray.get(ref))
# raise e
def set_model_path(path):
global model_path
model_path = path
def load_model():
global model, model_path
if model is None:
#ray.init()
MOCHI_DIR = model_path
VAE_CHECKPOINT_PATH = f"{MOCHI_DIR}/mochi_preview_vae_bf16.safetensors"
MODEL_CONFIG_PATH = f"{MOCHI_DIR}/dit-config.yaml"
MODEL_CHECKPOINT_PATH = f"{MOCHI_DIR}/mochi_preview_dit_fp8_e4m3fn.safetensors"
model = T2VSynthMochiModel(
device_id=0,
world_size=1,
local_rank=0,
vae_stats_path=f"{MOCHI_DIR}/vae_stats.json",
vae_checkpoint_path=VAE_CHECKPOINT_PATH,
dit_config_path=MODEL_CONFIG_PATH,
dit_checkpoint_path=MODEL_CHECKPOINT_PATH,
)
def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None):
if linear_steps is None:
linear_steps = num_steps // 2
linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)]
threshold_noise_step_diff = linear_steps - threshold_noise * num_steps
quadratic_steps = num_steps - linear_steps
quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps ** 2)
linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps ** 2)
const = quadratic_coef * (linear_steps ** 2)
quadratic_sigma_schedule = [
quadratic_coef * (i ** 2) + linear_coef * i + const
for i in range(linear_steps, num_steps)
]
sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0]
sigma_schedule = [1.0 - x for x in sigma_schedule]
return sigma_schedule
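# Added note (not part of the original file): with the defaults used below
# (num_inference_steps=64, threshold_noise=0.025) this returns 65 sigmas that start
# at 1.0, decrease monotonically, and end at 0.0; the first 32 entries decay linearly
# and the remainder quadratically, which is exactly the shape generate_video expects.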
def generate_video(
prompt,
negative_prompt,
width,
height,
num_frames,
seed,
cfg_scale,
num_inference_steps,
):
load_model()
# sigma_schedule should be a list of floats of length (num_inference_steps + 1),
# such that sigma_schedule[0] == 1.0 and sigma_schedule[-1] == 0.0 and monotonically decreasing.
sigma_schedule = linear_quadratic_schedule(num_inference_steps, 0.025)
# cfg_schedule should be a list of floats of length num_inference_steps.
# For simplicity, we just use the same cfg scale at all timesteps,
# but more optimal schedules may use varying cfg, e.g:
# [5.0] * (num_inference_steps // 2) + [4.5] * (num_inference_steps // 2)
cfg_schedule = [cfg_scale] * num_inference_steps
args = {
"height": height,
"width": width,
"num_frames": num_frames,
"mochi_args": {
"sigma_schedule": sigma_schedule,
"cfg_schedule": cfg_schedule,
"num_inference_steps": num_inference_steps,
"batch_cfg": False,
},
"prompt": [prompt],
"negative_prompt": [negative_prompt],
"seed": seed,
}
final_frames = None
for cur_progress, frames, finished in tqdm(model.run(args, stream_results=True), total=num_inference_steps + 1):
final_frames = frames
assert isinstance(final_frames, np.ndarray)
assert final_frames.dtype == np.float32
final_frames = rearrange(final_frames, "t b h w c -> b t h w c")
final_frames = final_frames[0]
os.makedirs("outputs", exist_ok=True)
output_path = os.path.join("outputs", f"output_{int(time.time())}.mp4")
with tempfile.TemporaryDirectory() as tmpdir:
frame_paths = []
for i, frame in enumerate(final_frames):
frame = (frame * 255).astype(np.uint8)
frame_img = Image.fromarray(frame)
frame_path = os.path.join(tmpdir, f"frame_{i:04d}.png")
frame_img.save(frame_path)
frame_paths.append(frame_path)
frame_pattern = os.path.join(tmpdir, "frame_%04d.png")
ffmpeg_cmd = f"ffmpeg -y -r 30 -i {frame_pattern} -vcodec libx264 -pix_fmt yuv420p {output_path}"
os.system(ffmpeg_cmd)
json_path = os.path.splitext(output_path)[0] + ".json"
with open(json_path, "w") as f:
json.dump(args, f, indent=4)
return output_path
@click.command()
@click.option("--prompt", default="""
a high-motion drone POV flying at high speed through a vast desert environment, with dynamic camera movements capturing sweeping sand dunes,
rocky terrain, and the occasional dry brush. The camera smoothly glides over the rugged landscape, weaving between towering rock formations and
diving low across the sand. As the drone zooms forward, the motion gradually slows down, shifting into a close-up, hyper-detailed shot of a spider
resting on a sunlit rock. The scene emphasizes cinematic motion, natural lighting, and intricate texture details on both the rock and the spider's body,
with a shallow depth of field to focus on the fine details of the spider's legs and the rough surface beneath it. The atmosphere should feel immersive and alive,
with the wind subtly blowing sand grains across the frame."""
, required=False, help="Prompt for video generation.")
@click.option(
"--negative_prompt", default="", help="Negative prompt for video generation."
)
@click.option("--width", default=848, type=int, help="Width of the video.")
@click.option("--height", default=480, type=int, help="Height of the video.")
@click.option("--num_frames", default=163, type=int, help="Number of frames.")
@click.option("--seed", default=12345, type=int, help="Random seed.")
@click.option("--cfg_scale", default=4.5, type=float, help="CFG Scale.")
@click.option(
"--num_steps", default=64, type=int, help="Number of inference steps."
)
@click.option("--model_dir", required=True, help="Path to the model directory.")
def generate_cli(
prompt,
negative_prompt,
width,
height,
num_frames,
seed,
cfg_scale,
num_steps,
model_dir,
):
set_model_path(model_dir)
output = generate_video(
prompt,
negative_prompt,
width,
height,
num_frames,
seed,
cfg_scale,
num_steps,
)
click.echo(f"Video generated at: {output}")
if __name__ == "__main__":
generate_cli()
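# Example invocation (added for illustration; the directory passed to --model_dir is
# whatever holds the checkpoints listed in load_model, shown here as "weights"):
#   python infer.py --model_dir weights --num_steps 64 --width 848 --height 480 --num_frames 163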

(binary files not shown)

675
mochi_preview/dit/joint_model/asymm_models_joint.py Normal file

@@ -0,0 +1,675 @@
import os
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.attention import sdpa_kernel
from .context_parallel import all_to_all_collect_tokens, all_to_all_collect_heads, all_gather, get_cp_rank_size, is_cp_active
from .layers import (
FeedForward,
PatchEmbed,
RMSNorm,
TimestepEmbedder,
)
from .mod_rmsnorm import modulated_rmsnorm
from .residual_tanh_gated_rmsnorm import (
residual_tanh_gated_rmsnorm,
)
from .rope_mixed import (
compute_mixed_rotation,
create_position_matrix,
)
from .temporal_rope import apply_rotary_emb_qk_real
from .utils import (
AttentionPool,
modulate,
pad_and_split_xy,
unify_streams,
)
try:
from flash_attn import flash_attn_varlen_qkvpacked_func
FLASH_ATTN_IS_AVAILABLE = True
except ImportError:
FLASH_ATTN_IS_AVAILABLE = False
COMPILE_FINAL_LAYER = False #os.environ.get("COMPILE_DIT") == "1"
COMPILE_MMDIT_BLOCK = False #os.environ.get("COMPILE_DIT") == "1"
class AsymmetricAttention(nn.Module):
def __init__(
self,
dim_x: int,
dim_y: int,
num_heads: int = 8,
qkv_bias: bool = True,
qk_norm: bool = False,
attn_drop: float = 0.0,
update_y: bool = True,
out_bias: bool = True,
attend_to_padding: bool = False,
softmax_scale: Optional[float] = None,
device: Optional[torch.device] = None,
clip_feat_dim: Optional[int] = None,
pooled_caption_mlp_bias: bool = True,
use_transformer_engine: bool = False,
):
super().__init__()
self.dim_x = dim_x
self.dim_y = dim_y
self.num_heads = num_heads
self.head_dim = dim_x // num_heads
self.attn_drop = attn_drop
self.update_y = update_y
self.attend_to_padding = attend_to_padding
self.softmax_scale = softmax_scale
if dim_x % num_heads != 0:
raise ValueError(
f"dim_x={dim_x} should be divisible by num_heads={num_heads}"
)
# Input layers.
self.qkv_bias = qkv_bias
self.qkv_x = nn.Linear(dim_x, 3 * dim_x, bias=qkv_bias, device=device)
# Project text features to match visual features (dim_y -> dim_x)
self.qkv_y = nn.Linear(dim_y, 3 * dim_x, bias=qkv_bias, device=device)
# Query and key normalization for stability.
assert qk_norm
self.q_norm_x = RMSNorm(self.head_dim, device=device)
self.k_norm_x = RMSNorm(self.head_dim, device=device)
self.q_norm_y = RMSNorm(self.head_dim, device=device)
self.k_norm_y = RMSNorm(self.head_dim, device=device)
# Output layers. y features go back down from dim_x -> dim_y.
self.proj_x = nn.Linear(dim_x, dim_x, bias=out_bias, device=device)
self.proj_y = (
nn.Linear(dim_x, dim_y, bias=out_bias, device=device)
if update_y
else nn.Identity()
)
def run_qkv_y(self, y):
cp_rank, cp_size = get_cp_rank_size()
local_heads = self.num_heads // cp_size
if is_cp_active():
# Only predict local heads.
assert not self.qkv_bias
W_qkv_y = self.qkv_y.weight.view(
3, self.num_heads, self.head_dim, self.dim_y
)
W_qkv_y = W_qkv_y.narrow(1, cp_rank * local_heads, local_heads)
W_qkv_y = W_qkv_y.reshape(3 * local_heads * self.head_dim, self.dim_y)
qkv_y = F.linear(y, W_qkv_y, None) # (B, L, 3 * local_h * head_dim)
else:
qkv_y = self.qkv_y(y) # (B, L, 3 * dim)
qkv_y = qkv_y.view(qkv_y.size(0), qkv_y.size(1), 3, local_heads, self.head_dim)
q_y, k_y, v_y = qkv_y.unbind(2)
return q_y, k_y, v_y
def prepare_qkv(
self,
x: torch.Tensor, # (B, N, dim_x)
y: torch.Tensor, # (B, L, dim_y)
*,
scale_x: torch.Tensor,
scale_y: torch.Tensor,
rope_cos: torch.Tensor,
rope_sin: torch.Tensor,
valid_token_indices: torch.Tensor,
):
# Pre-norm for visual features
x = modulated_rmsnorm(x, scale_x) # (B, M, dim_x) where M = N / cp_group_size
#print("x in attn", x.dtype, x.device)
# Process visual features
qkv_x = self.qkv_x(x) # (B, M, 3 * dim_x)
assert qkv_x.dtype == torch.bfloat16
qkv_x = all_to_all_collect_tokens(
qkv_x, self.num_heads
) # (3, B, N, local_h, head_dim)
# Process text features
y = modulated_rmsnorm(y, scale_y) # (B, L, dim_y)
q_y, k_y, v_y = self.run_qkv_y(y) # (B, L, local_heads, head_dim)
#print("y in attn", y.dtype, y.device)
#print(q_y.dtype, q_y.device)
#print(self.q_norm_y.weight.dtype, self.q_norm_y.weight.device)
# self.q_norm_y.weight = self.q_norm_y.weight.to(q_y.dtype)
# self.q_norm_y.bias = self.q_norm_y.bias.to(q_y.dtype)
# self.k_norm_y.weight = self.k_norm_y.weight.to(k_y.dtype)
# self.k_norm_y.bias = self.k_norm_y.bias.to(k_y.dtype)
q_y = self.q_norm_y(q_y)
k_y = self.k_norm_y(k_y)
# Split qkv_x into q, k, v
q_x, k_x, v_x = qkv_x.unbind(0) # (B, N, local_h, head_dim)
q_x = self.q_norm_x(q_x)
q_x = apply_rotary_emb_qk_real(q_x, rope_cos, rope_sin)
k_x = self.k_norm_x(k_x)
k_x = apply_rotary_emb_qk_real(k_x, rope_cos, rope_sin)
# Unite streams
qkv = unify_streams(
q_x,
k_x,
v_x,
q_y,
k_y,
v_y,
valid_token_indices,
)
return qkv
@torch.compiler.disable()
def run_attention(
self,
qkv: torch.Tensor, # (total <= B * (N + L), 3, local_heads, head_dim)
*,
B: int,
L: int,
M: int,
cu_seqlens: torch.Tensor,
max_seqlen_in_batch: int,
valid_token_indices: torch.Tensor,
):
_, cp_size = get_cp_rank_size()
N = cp_size * M
assert self.num_heads % cp_size == 0
local_heads = self.num_heads // cp_size
local_dim = local_heads * self.head_dim
total = qkv.size(0)
if FLASH_ATTN_IS_AVAILABLE:
with torch.autocast("cuda", enabled=False):
out: torch.Tensor = flash_attn_varlen_qkvpacked_func(
qkv,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen_in_batch,
dropout_p=0.0,
softmax_scale=self.softmax_scale,
) # (total, local_heads, head_dim)
out = out.view(total, local_dim)
else:
raise NotImplementedError("Flash attention is currently required.")
print("qkv: ",qkv.shape, qkv.dtype, qkv.device)
expected_size = 2 * 44520 * 3 * 24 * 128
actual_size = qkv.numel()
print(f"Expected size: {expected_size}, Actual size: {actual_size}")
q, k, v = qkv.reshape(B, N, 3, local_heads, self.head_dim).permute(2, 0, 3, 1, 4)
with torch.autocast("cuda", enabled=False):
out = F.scaled_dot_product_attention(
q,
k,
v,
attn_mask=None,
dropout_p=0.0,
is_causal=False
)
out = out.transpose(1, 2).reshape(B, -1, local_heads * self.head_dim)
x, y = pad_and_split_xy(out, valid_token_indices, B, N, L, qkv.dtype)
assert x.size() == (B, N, local_dim)
assert y.size() == (B, L, local_dim)
x = x.view(B, N, local_heads, self.head_dim)
x = all_to_all_collect_heads(x) # (B, M, dim_x = num_heads * head_dim)
x = self.proj_x(x) # (B, M, dim_x)
if is_cp_active():
y = all_gather(y) # (cp_size * B, L, local_heads * head_dim)
y = rearrange(
y, "(G B) L D -> B L (G D)", G=cp_size, D=local_dim
) # (B, L, dim_x)
y = self.proj_y(y) # (B, L, dim_y)
return x, y
def forward(
self,
x: torch.Tensor, # (B, N, dim_x)
y: torch.Tensor, # (B, L, dim_y)
*,
scale_x: torch.Tensor, # (B, dim_x), modulation for pre-RMSNorm.
scale_y: torch.Tensor, # (B, dim_y), modulation for pre-RMSNorm.
packed_indices: Dict[str, torch.Tensor] = None,
**rope_rotation,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward pass of asymmetric multi-modal attention.
Args:
x: (B, N, dim_x) tensor for visual tokens
y: (B, L, dim_y) tensor of text token features
packed_indices: Dict with keys for Flash Attention
num_frames: Number of frames in the video. N = num_frames * num_spatial_tokens
Returns:
x: (B, N, dim_x) tensor of visual tokens after multi-modal attention
y: (B, L, dim_y) tensor of text token features after multi-modal attention
"""
B, L, _ = y.shape
_, M, _ = x.shape
# Predict a packed QKV tensor from visual and text features.
# Don't checkpoint the all_to_all.
qkv = self.prepare_qkv(
x=x,
y=y,
scale_x=scale_x,
scale_y=scale_y,
rope_cos=rope_rotation.get("rope_cos"),
rope_sin=rope_rotation.get("rope_sin"),
valid_token_indices=packed_indices["valid_token_indices_kv"],
) # (total <= B * (N + L), 3, local_heads, head_dim)
x, y = self.run_attention(
qkv,
B=B,
L=L,
M=M,
cu_seqlens=packed_indices["cu_seqlens_kv"],
max_seqlen_in_batch=packed_indices["max_seqlen_in_batch_kv"],
valid_token_indices=packed_indices["valid_token_indices_kv"],
)
return x, y
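# Added note: prepare_qkv packs the visual stream (N tokens) and the text stream
# (L tokens) into one variable-length sequence selected by valid_token_indices;
# run_attention then makes a single flash_attn_varlen_qkvpacked_func call over the
# packed QKV, and pad_and_split_xy restores separate (B, N, ...) and (B, L, ...)
# streams before the proj_x / proj_y output projections.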
#@torch.compile(disable=not COMPILE_MMDIT_BLOCK)
class AsymmetricJointBlock(nn.Module):
def __init__(
self,
hidden_size_x: int,
hidden_size_y: int,
num_heads: int,
*,
mlp_ratio_x: float = 8.0, # Ratio of hidden size to d_model for MLP for visual tokens.
mlp_ratio_y: float = 4.0, # Ratio of hidden size to d_model for MLP for text tokens.
update_y: bool = True, # Whether to update text tokens in this block.
device: Optional[torch.device] = None,
**block_kwargs,
):
super().__init__()
self.update_y = update_y
self.hidden_size_x = hidden_size_x
self.hidden_size_y = hidden_size_y
self.mod_x = nn.Linear(hidden_size_x, 4 * hidden_size_x, device=device)
if self.update_y:
self.mod_y = nn.Linear(hidden_size_x, 4 * hidden_size_y, device=device)
else:
self.mod_y = nn.Linear(hidden_size_x, hidden_size_y, device=device)
# Self-attention:
self.attn = AsymmetricAttention(
hidden_size_x,
hidden_size_y,
num_heads=num_heads,
update_y=update_y,
device=device,
**block_kwargs,
)
# MLP.
mlp_hidden_dim_x = int(hidden_size_x * mlp_ratio_x)
assert mlp_hidden_dim_x == int(1536 * 8)
self.mlp_x = FeedForward(
in_features=hidden_size_x,
hidden_size=mlp_hidden_dim_x,
multiple_of=256,
ffn_dim_multiplier=None,
device=device,
)
# MLP for text not needed in last block.
if self.update_y:
mlp_hidden_dim_y = int(hidden_size_y * mlp_ratio_y)
self.mlp_y = FeedForward(
in_features=hidden_size_y,
hidden_size=mlp_hidden_dim_y,
multiple_of=256,
ffn_dim_multiplier=None,
device=device,
)
def forward(
self,
x: torch.Tensor,
c: torch.Tensor,
y: torch.Tensor,
**attn_kwargs,
):
"""Forward pass of a block.
Args:
x: (B, N, dim) tensor of visual tokens
c: (B, dim) tensor of conditioned features
y: (B, L, dim) tensor of text tokens
num_frames: Number of frames in the video. N = num_frames * num_spatial_tokens
Returns:
x: (B, N, dim) tensor of visual tokens after block
y: (B, L, dim) tensor of text tokens after block
"""
N = x.size(1)
c = F.silu(c)
mod_x = self.mod_x(c)
scale_msa_x, gate_msa_x, scale_mlp_x, gate_mlp_x = mod_x.chunk(4, dim=1)
mod_y = self.mod_y(c)
if self.update_y:
scale_msa_y, gate_msa_y, scale_mlp_y, gate_mlp_y = mod_y.chunk(4, dim=1)
else:
scale_msa_y = mod_y
# Self-attention block.
x_attn, y_attn = self.attn(
x,
y,
scale_x=scale_msa_x,
scale_y=scale_msa_y,
**attn_kwargs,
)
assert x_attn.size(1) == N
x = residual_tanh_gated_rmsnorm(x, x_attn, gate_msa_x)
if self.update_y:
y = residual_tanh_gated_rmsnorm(y, y_attn, gate_msa_y)
# MLP block.
x = self.ff_block_x(x, scale_mlp_x, gate_mlp_x)
if self.update_y:
y = self.ff_block_y(y, scale_mlp_y, gate_mlp_y)
return x, y
def ff_block_x(self, x, scale_x, gate_x):
x_mod = modulated_rmsnorm(x, scale_x)
x_res = self.mlp_x(x_mod)
x = residual_tanh_gated_rmsnorm(x, x_res, gate_x) # Sandwich norm
return x
def ff_block_y(self, y, scale_y, gate_y):
y_mod = modulated_rmsnorm(y, scale_y)
y_res = self.mlp_y(y_mod)
y = residual_tanh_gated_rmsnorm(y, y_res, gate_y) # Sandwich norm
return y
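# Added note: each block derives four modulation signals per stream from the
# conditioning vector c (scale_msa, gate_msa, scale_mlp, gate_mlp), in the spirit of
# DiT's adaLN modulation; the scales drive the pre-attention / pre-MLP RMSNorm
# (modulated_rmsnorm) and the gates drive the tanh-gated residual updates
# (residual_tanh_gated_rmsnorm).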
#@torch.compile(disable=not COMPILE_FINAL_LAYER)
class FinalLayer(nn.Module):
"""
The final layer of DiT.
"""
def __init__(
self,
hidden_size,
patch_size,
out_channels,
device: Optional[torch.device] = None,
):
super().__init__()
self.norm_final = nn.LayerNorm(
hidden_size, elementwise_affine=False, eps=1e-6, device=device
)
self.mod = nn.Linear(hidden_size, 2 * hidden_size, device=device)
self.linear = nn.Linear(
hidden_size, patch_size * patch_size * out_channels, device=device
)
def forward(self, x, c):
c = F.silu(c)
shift, scale = self.mod(c).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class AsymmDiTJoint(nn.Module):
"""
Diffusion model with a Transformer backbone.
Ingests text embeddings instead of a label.
"""
def __init__(
self,
*,
patch_size=2,
in_channels=4,
hidden_size_x=1152,
hidden_size_y=1152,
depth=48,
num_heads=16,
mlp_ratio_x=8.0,
mlp_ratio_y=4.0,
t5_feat_dim: int = 4096,
t5_token_length: int = 256,
patch_embed_bias: bool = True,
timestep_mlp_bias: bool = True,
timestep_scale: Optional[float] = None,
use_extended_posenc: bool = False,
rope_theta: float = 10000.0,
device: Optional[torch.device] = None,
**block_kwargs,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = in_channels
self.patch_size = patch_size
self.num_heads = num_heads
self.hidden_size_x = hidden_size_x
self.hidden_size_y = hidden_size_y
self.head_dim = (
hidden_size_x // num_heads
) # Head dimension and count is determined by visual.
self.use_extended_posenc = use_extended_posenc
self.t5_token_length = t5_token_length
self.t5_feat_dim = t5_feat_dim
self.rope_theta = (
rope_theta # Scaling factor for frequency computation for temporal RoPE.
)
self.x_embedder = PatchEmbed(
patch_size=patch_size,
in_chans=in_channels,
embed_dim=hidden_size_x,
bias=patch_embed_bias,
device=device,
)
# Conditionings
# Timestep
self.t_embedder = TimestepEmbedder(
hidden_size_x, bias=timestep_mlp_bias, timestep_scale=timestep_scale
)
# Caption Pooling (T5)
self.t5_y_embedder = AttentionPool(
t5_feat_dim, num_heads=8, output_dim=hidden_size_x, device=device
)
# Dense Embedding Projection (T5)
self.t5_yproj = nn.Linear(
t5_feat_dim, hidden_size_y, bias=True, device=device
)
# Initialize pos_frequencies as an empty parameter.
self.pos_frequencies = nn.Parameter(
torch.empty(3, self.num_heads, self.head_dim // 2, device=device)
)
# for depth 48:
# b = 0: AsymmetricJointBlock, update_y=True
# b = 1: AsymmetricJointBlock, update_y=True
# ...
# b = 46: AsymmetricJointBlock, update_y=True
# b = 47: AsymmetricJointBlock, update_y=False. No need to update text features.
blocks = []
for b in range(depth):
# Joint multi-modal block
update_y = b < depth - 1
block = AsymmetricJointBlock(
hidden_size_x,
hidden_size_y,
num_heads,
mlp_ratio_x=mlp_ratio_x,
mlp_ratio_y=mlp_ratio_y,
update_y=update_y,
device=device,
**block_kwargs,
)
blocks.append(block)
self.blocks = nn.ModuleList(blocks)
self.final_layer = FinalLayer(
hidden_size_x, patch_size, self.out_channels, device=device
)
def embed_x(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: (B, C=12, T, H, W) tensor of visual tokens
Returns:
            x: (B, N, C=3072) tensor of embedded visual tokens.
        """
        return self.x_embedder(x)  # Convert (B, C, T, H, W) -> (B, N, C)
#@torch.compile(disable=not COMPILE_MMDIT_BLOCK)
def prepare(
self,
x: torch.Tensor,
sigma: torch.Tensor,
t5_feat: torch.Tensor,
t5_mask: torch.Tensor,
):
"""Prepare input and conditioning embeddings."""
#("X", x.shape)
with torch.profiler.record_function("x_emb_pe"):
# Visual patch embeddings with positional encoding.
T, H, W = x.shape[-3:]
pH, pW = H // self.patch_size, W // self.patch_size
x = self.embed_x(x) # (B, N, D), where N = T * H * W / patch_size ** 2
assert x.ndim == 3
B = x.size(0)
with torch.profiler.record_function("rope_cis"):
# Construct position array of size [N, 3].
# pos[:, 0] is the frame index for each location,
# pos[:, 1] is the row index for each location, and
# pos[:, 2] is the column index for each location.
pH, pW = H // self.patch_size, W // self.patch_size
N = T * pH * pW
assert x.size(1) == N
pos = create_position_matrix(
T, pH=pH, pW=pW, device=x.device, dtype=torch.float32
) # (N, 3)
rope_cos, rope_sin = compute_mixed_rotation(
freqs=self.pos_frequencies, pos=pos
) # Each are (N, num_heads, dim // 2)
with torch.profiler.record_function("t_emb"):
# Global vector embedding for conditionings.
c_t = self.t_embedder(1 - sigma) # (B, D)
with torch.profiler.record_function("t5_pool"):
# Pool T5 tokens using attention pooler
# Note y_feat[1] contains T5 token features.
# print("B", B)
# print("t5 feat shape",t5_feat.shape)
# print("t5 mask shape", t5_mask.shape)
assert (
t5_feat.size(1) == self.t5_token_length
), f"Expected L={self.t5_token_length}, got {t5_feat.shape} for y_feat."
t5_y_pool = self.t5_y_embedder(t5_feat, t5_mask) # (B, D)
assert (
t5_y_pool.size(0) == B
), f"Expected B={B}, got {t5_y_pool.shape} for t5_y_pool."
c = c_t + t5_y_pool
y_feat = self.t5_yproj(t5_feat) # (B, L, t5_feat_dim) --> (B, L, D)
return x, c, y_feat, rope_cos, rope_sin
def forward(
self,
x: torch.Tensor,
sigma: torch.Tensor,
y_feat: List[torch.Tensor],
y_mask: List[torch.Tensor],
packed_indices: Dict[str, torch.Tensor] = None,
rope_cos: torch.Tensor = None,
rope_sin: torch.Tensor = None,
):
"""Forward pass of DiT.
Args:
x: (B, C, T, H, W) tensor of spatial inputs (images or latent representations of images)
sigma: (B,) tensor of noise standard deviations
y_feat: List((B, L, y_feat_dim) tensor of caption token features. For SDXL text encoders: L=77, y_feat_dim=2048)
y_mask: List((B, L) boolean tensor indicating which tokens are not padding)
packed_indices: Dict with keys for Flash Attention. Result of compute_packed_indices.
"""
B, _, T, H, W = x.shape
# Use EFFICIENT_ATTENTION backend for T5 pooling, since we have a mask.
# Have to call sdpa_kernel outside of a torch.compile region.
with sdpa_kernel(torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION):
x, c, y_feat, rope_cos, rope_sin = self.prepare(
x, sigma, y_feat[0], y_mask[0]
)
del y_mask
cp_rank, cp_size = get_cp_rank_size()
N = x.size(1)
M = N // cp_size
assert (
N % cp_size == 0
), f"Visual sequence length ({x.shape[1]}) must be divisible by cp_size ({cp_size})."
if cp_size > 1:
x = x.narrow(1, cp_rank * M, M)
assert self.num_heads % cp_size == 0
local_heads = self.num_heads // cp_size
rope_cos = rope_cos.narrow(1, cp_rank * local_heads, local_heads)
rope_sin = rope_sin.narrow(1, cp_rank * local_heads, local_heads)
for i, block in enumerate(self.blocks):
x, y_feat = block(
x,
c,
y_feat,
rope_cos=rope_cos,
rope_sin=rope_sin,
packed_indices=packed_indices,
) # (B, M, D), (B, L, D)
del y_feat # Final layers don't use dense text features.
x = self.final_layer(x, c) # (B, M, patch_size ** 2 * out_channels)
patch = x.size(2)
x = all_gather(x)
x = rearrange(x, "(G B) M P -> B (G M) P", G=cp_size, P=patch)
x = rearrange(
x,
"B (T hp wp) (p1 p2 c) -> B c T (hp p1) (wp p2)",
T=T,
hp=H // self.patch_size,
wp=W // self.patch_size,
p1=self.patch_size,
p2=self.patch_size,
c=self.out_channels,
)
return x
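# Added note: the closing rearrange un-patchifies the token sequence; each token holds
# patch_size * patch_size * out_channels values and is scattered back to its
# (frame, row, column) position, yielding a (B, C, T, H, W) latent prediction.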

163
mochi_preview/dit/joint_model/context_parallel.py Normal file

@@ -0,0 +1,163 @@
import torch
import torch.distributed as dist
from einops import rearrange
_CONTEXT_PARALLEL_GROUP = None
_CONTEXT_PARALLEL_RANK = None
_CONTEXT_PARALLEL_GROUP_SIZE = None
_CONTEXT_PARALLEL_GROUP_RANKS = None
def local_shard(x: torch.Tensor, dim: int = 2) -> torch.Tensor:
if not _CONTEXT_PARALLEL_GROUP:
return x
cp_rank, cp_size = get_cp_rank_size()
return x.tensor_split(cp_size, dim=dim)[cp_rank]
def set_cp_group(cp_group, ranks, global_rank):
global \
_CONTEXT_PARALLEL_GROUP, \
_CONTEXT_PARALLEL_RANK, \
_CONTEXT_PARALLEL_GROUP_SIZE, \
_CONTEXT_PARALLEL_GROUP_RANKS
if _CONTEXT_PARALLEL_GROUP is not None:
raise RuntimeError("CP group already initialized.")
_CONTEXT_PARALLEL_GROUP = cp_group
_CONTEXT_PARALLEL_RANK = dist.get_rank(cp_group)
_CONTEXT_PARALLEL_GROUP_SIZE = dist.get_world_size(cp_group)
_CONTEXT_PARALLEL_GROUP_RANKS = ranks
assert (
_CONTEXT_PARALLEL_RANK == ranks.index(global_rank)
), f"Rank mismatch: {global_rank} in {ranks} does not have position {_CONTEXT_PARALLEL_RANK} "
assert _CONTEXT_PARALLEL_GROUP_SIZE == len(
ranks
), f"Group size mismatch: {_CONTEXT_PARALLEL_GROUP_SIZE} != len({ranks})"
def get_cp_group():
if _CONTEXT_PARALLEL_GROUP is None:
raise RuntimeError("CP group not initialized")
return _CONTEXT_PARALLEL_GROUP
def is_cp_active():
return _CONTEXT_PARALLEL_GROUP is not None
def get_cp_rank_size():
if _CONTEXT_PARALLEL_GROUP:
return _CONTEXT_PARALLEL_RANK, _CONTEXT_PARALLEL_GROUP_SIZE
else:
return 0, 1
class AllGatherIntoTensorFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x: torch.Tensor, reduce_dtype, group: dist.ProcessGroup):
ctx.reduce_dtype = reduce_dtype
ctx.group = group
ctx.batch_size = x.size(0)
group_size = dist.get_world_size(group)
x = x.contiguous()
output = torch.empty(
group_size * x.size(0), *x.shape[1:], dtype=x.dtype, device=x.device
)
dist.all_gather_into_tensor(output, x, group=group)
return output
def all_gather(tensor: torch.Tensor) -> torch.Tensor:
if not _CONTEXT_PARALLEL_GROUP:
return tensor
return AllGatherIntoTensorFunction.apply(
tensor, torch.float32, _CONTEXT_PARALLEL_GROUP
)
@torch.compiler.disable()
def _all_to_all_single(output, input, group):
# Disable compilation since torch compile changes contiguity.
assert input.is_contiguous(), "Input tensor must be contiguous."
assert output.is_contiguous(), "Output tensor must be contiguous."
return dist.all_to_all_single(output, input, group=group)
class CollectTokens(torch.autograd.Function):
@staticmethod
def forward(ctx, qkv: torch.Tensor, group: dist.ProcessGroup, num_heads: int):
"""Redistribute heads and receive tokens.
Args:
qkv: query, key or value. Shape: [B, M, 3 * num_heads * head_dim]
Returns:
qkv: shape: [3, B, N, local_heads, head_dim]
where M is the number of local tokens,
N = cp_size * M is the number of global tokens,
local_heads = num_heads // cp_size is the number of local heads.
"""
ctx.group = group
ctx.num_heads = num_heads
cp_size = dist.get_world_size(group)
assert num_heads % cp_size == 0
ctx.local_heads = num_heads // cp_size
qkv = rearrange(
qkv,
"B M (qkv G h d) -> G M h B (qkv d)",
qkv=3,
G=cp_size,
h=ctx.local_heads,
).contiguous()
output_chunks = torch.empty_like(qkv)
_all_to_all_single(output_chunks, qkv, group=group)
return rearrange(output_chunks, "G M h B (qkv d) -> qkv B (G M) h d", qkv=3)
def all_to_all_collect_tokens(x: torch.Tensor, num_heads: int) -> torch.Tensor:
if not _CONTEXT_PARALLEL_GROUP:
# Move QKV dimension to the front.
# B M (3 H d) -> 3 B M H d
B, M, _ = x.size()
x = x.view(B, M, 3, num_heads, -1)
return x.permute(2, 0, 1, 3, 4)
return CollectTokens.apply(x, _CONTEXT_PARALLEL_GROUP, num_heads)
class CollectHeads(torch.autograd.Function):
@staticmethod
def forward(ctx, x: torch.Tensor, group: dist.ProcessGroup):
"""Redistribute tokens and receive heads.
Args:
x: Output of attention. Shape: [B, N, local_heads, head_dim]
Returns:
Shape: [B, M, num_heads * head_dim]
"""
ctx.group = group
ctx.local_heads = x.size(2)
ctx.head_dim = x.size(3)
group_size = dist.get_world_size(group)
x = rearrange(x, "B (G M) h D -> G h M B D", G=group_size).contiguous()
output = torch.empty_like(x)
_all_to_all_single(output, x, group=group)
del x
return rearrange(output, "G h M B D -> B M (G h D)")
def all_to_all_collect_heads(x: torch.Tensor) -> torch.Tensor:
if not _CONTEXT_PARALLEL_GROUP:
# Merge heads.
return x.view(x.size(0), x.size(1), x.size(2) * x.size(3))
return CollectHeads.apply(x, _CONTEXT_PARALLEL_GROUP)
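# Added shape example (illustrative; num_heads=24 comes from the DiT construction in
# t2v_synth_mochi.py): with cp_size=2, all_to_all_collect_tokens turns a local
# (B, M, 3 * 24 * head_dim) QKV into (3, B, 2 * M, 12, head_dim), so every rank sees
# all tokens but only its 12 local heads; all_to_all_collect_heads inverts this after
# attention, returning (B, M, 24 * head_dim).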

178
mochi_preview/dit/joint_model/layers.py Normal file

@@ -0,0 +1,178 @@
import collections.abc
import math
from itertools import repeat
from typing import Callable, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
# From PyTorch internals
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
return tuple(x)
return tuple(repeat(x, n))
return parse
to_2tuple = _ntuple(2)
class TimestepEmbedder(nn.Module):
def __init__(
self,
hidden_size: int,
frequency_embedding_size: int = 256,
*,
bias: bool = True,
timestep_scale: Optional[float] = None,
device: Optional[torch.device] = None,
):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=bias, device=device),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=bias, device=device),
)
self.frequency_embedding_size = frequency_embedding_size
self.timestep_scale = timestep_scale
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
half = dim // 2
freqs = torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
freqs.mul_(-math.log(max_period) / half).exp_()
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat(
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
)
return embedding
def forward(self, t):
if self.timestep_scale is not None:
t = t * self.timestep_scale
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq)
return t_emb
class PooledCaptionEmbedder(nn.Module):
def __init__(
self,
caption_feature_dim: int,
hidden_size: int,
*,
bias: bool = True,
device: Optional[torch.device] = None,
):
super().__init__()
self.caption_feature_dim = caption_feature_dim
self.hidden_size = hidden_size
self.mlp = nn.Sequential(
nn.Linear(caption_feature_dim, hidden_size, bias=bias, device=device),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=bias, device=device),
)
def forward(self, x):
return self.mlp(x)
class FeedForward(nn.Module):
def __init__(
self,
in_features: int,
hidden_size: int,
multiple_of: int,
ffn_dim_multiplier: Optional[float],
device: Optional[torch.device] = None,
):
super().__init__()
# keep parameter count and computation constant compared to standard FFN
hidden_size = int(2 * hidden_size / 3)
# custom dim factor multiplier
if ffn_dim_multiplier is not None:
hidden_size = int(ffn_dim_multiplier * hidden_size)
hidden_size = multiple_of * ((hidden_size + multiple_of - 1) // multiple_of)
self.hidden_dim = hidden_size
self.w1 = nn.Linear(in_features, 2 * hidden_size, bias=False, device=device)
self.w2 = nn.Linear(hidden_size, in_features, bias=False, device=device)
def forward(self, x):
x, gate = self.w1(x).chunk(2, dim=-1)
x = self.w2(F.silu(x) * gate)
return x
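# Worked example (added; the widths follow from the DiT construction elsewhere in this
# commit, hidden_size_x=3072 with an MLP width of 12288): hidden_size is rescaled to
# int(2 * 12288 / 3) = 8192, already a multiple of 256, so w1 is Linear(3072, 16384)
# holding the value and gate halves and w2 is Linear(8192, 3072); forward applies the
# SwiGLU-style gating silu(x) * gate.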
class PatchEmbed(nn.Module):
def __init__(
self,
patch_size: int = 16,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer: Optional[Callable] = None,
flatten: bool = True,
bias: bool = True,
dynamic_img_pad: bool = False,
device: Optional[torch.device] = None,
):
super().__init__()
self.patch_size = to_2tuple(patch_size)
self.flatten = flatten
self.dynamic_img_pad = dynamic_img_pad
self.proj = nn.Conv2d(
in_chans,
embed_dim,
kernel_size=patch_size,
stride=patch_size,
bias=bias,
device=device,
)
assert norm_layer is None
self.norm = (
norm_layer(embed_dim, device=device) if norm_layer else nn.Identity()
)
def forward(self, x):
B, _C, T, H, W = x.shape
if not self.dynamic_img_pad:
assert H % self.patch_size[0] == 0, f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})."
assert W % self.patch_size[1] == 0, f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
else:
pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
x = F.pad(x, (0, pad_w, 0, pad_h))
x = rearrange(x, "B C T H W -> (B T) C H W", B=B, T=T)
#print("x",x.dtype, x.device)
#print(self.proj.weight.dtype, self.proj.weight.device)
x = self.proj(x)
# Flatten temporal and spatial dimensions.
if not self.flatten:
raise NotImplementedError("Must flatten output.")
x = rearrange(x, "(B T) C H W -> B (T H W) C", B=B, T=T)
x = self.norm(x)
return x
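# Worked example (added, using the defaults from infer.py): an 848x480, 163-frame
# request becomes a 12-channel 28x60x106 latent (temporal /6, spatial /8); with
# patch_size=2 the Conv2d above yields 53 * 30 = 1590 tokens per latent frame, so the
# DiT sees N = 28 * 1590 = 44520 visual tokens.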
class RMSNorm(torch.nn.Module):
def __init__(self, hidden_size, eps=1e-5, device=None):
super().__init__()
self.eps = eps
self.weight = torch.nn.Parameter(torch.empty(hidden_size, device=device))
self.register_parameter("bias", None)
def forward(self, x):
x_fp32 = x.float()
x_normed = x_fp32 * torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + self.eps)
return (x_normed * self.weight).type_as(x)

23
mochi_preview/dit/joint_model/mod_rmsnorm.py Normal file

@@ -0,0 +1,23 @@
import torch
class ModulatedRMSNorm(torch.autograd.Function):
@staticmethod
def forward(ctx, x, scale, eps=1e-6):
# Convert to fp32 for precision
x_fp32 = x.float()
scale_fp32 = scale.float()
# Compute RMS
mean_square = x_fp32.pow(2).mean(-1, keepdim=True)
inv_rms = torch.rsqrt(mean_square + eps)
# Normalize and modulate
x_normed = x_fp32 * inv_rms
x_modulated = x_normed * (1 + scale_fp32.unsqueeze(1))
return x_modulated.type_as(x)
def modulated_rmsnorm(x, scale, eps=1e-6):
return ModulatedRMSNorm.apply(x, scale, eps)
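# Added note: modulated_rmsnorm(x, scale) computes (x / rms(x)) * (1 + scale), with
# scale broadcast from (B, D) over the sequence dimension; a zero modulation vector
# therefore reduces to plain RMS normalization.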

27
mochi_preview/dit/joint_model/residual_tanh_gated_rmsnorm.py Normal file

@@ -0,0 +1,27 @@
import torch
class ResidualTanhGatedRMSNorm(torch.autograd.Function):
@staticmethod
def forward(ctx, x, x_res, gate, eps=1e-6):
# Convert to fp32 for precision
x_res_fp32 = x_res.float()
# Compute RMS
mean_square = x_res_fp32.pow(2).mean(-1, keepdim=True)
scale = torch.rsqrt(mean_square + eps)
# Apply tanh to gate
tanh_gate = torch.tanh(gate).unsqueeze(1)
# Normalize and apply gated scaling
x_normed = x_res_fp32 * scale * tanh_gate
# Apply residual connection
output = x + x_normed.type_as(x)
return output
def residual_tanh_gated_rmsnorm(x, x_res, gate, eps=1e-6):
return ResidualTanhGatedRMSNorm.apply(x, x_res, gate, eps)
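# Added note: residual_tanh_gated_rmsnorm(x, x_res, gate) returns
# x + tanh(gate) * (x_res / rms(x_res)); the tanh keeps the gate in (-1, 1), so each
# residual branch can be smoothly attenuated or effectively switched off.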

88
mochi_preview/dit/joint_model/rope_mixed.py Normal file

@@ -0,0 +1,88 @@
import functools
import math
import torch
def centers(start: float, stop, num, dtype=None, device=None):
"""linspace through bin centers.
Args:
start (float): Start of the range.
stop (float): End of the range.
num (int): Number of points.
dtype (torch.dtype): Data type of the points.
device (torch.device): Device of the points.
Returns:
centers (Tensor): Centers of the bins. Shape: (num,).
"""
edges = torch.linspace(start, stop, num + 1, dtype=dtype, device=device)
return (edges[:-1] + edges[1:]) / 2
@functools.lru_cache(maxsize=1)
def create_position_matrix(
T: int,
pH: int,
pW: int,
device: torch.device,
dtype: torch.dtype,
*,
target_area: float = 36864,
):
"""
Args:
T: int - Temporal dimension
pH: int - Height dimension after patchify
pW: int - Width dimension after patchify
Returns:
pos: [T * pH * pW, 3] - position matrix
"""
with torch.no_grad():
# Create 1D tensors for each dimension
t = torch.arange(T, dtype=dtype)
# Positionally interpolate to area 36864.
# (3072x3072 frame with 16x16 patches = 192x192 latents).
# This automatically scales rope positions when the resolution changes.
# We use a large target area so the model is more sensitive
# to changes in the learned pos_frequencies matrix.
scale = math.sqrt(target_area / (pW * pH))
w = centers(-pW * scale / 2, pW * scale / 2, pW)
h = centers(-pH * scale / 2, pH * scale / 2, pH)
# Use meshgrid to create 3D grids
grid_t, grid_h, grid_w = torch.meshgrid(t, h, w, indexing="ij")
# Stack and reshape the grids.
pos = torch.stack([grid_t, grid_h, grid_w], dim=-1) # [T, pH, pW, 3]
pos = pos.view(-1, 3) # [T * pH * pW, 3]
pos = pos.to(dtype=dtype, device=device)
return pos
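# Added note: at the reference case mentioned above (a 3072x3072 frame with 16x16
# patches, i.e. 192x192 latents) pH * pW = 36864 and scale = 1, so positions are used
# as-is; smaller grids get scale > 1, which spreads their centers over a comparable
# coordinate range so the learned pos_frequencies transfer across resolutions.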
def compute_mixed_rotation(
freqs: torch.Tensor,
pos: torch.Tensor,
):
"""
Project each 3-dim position into per-head, per-head-dim 1D frequencies.
Args:
freqs: [3, num_heads, num_freqs] - learned rotation frequency (for t, row, col) for each head position
pos: [N, 3] - position of each token
num_heads: int
Returns:
freqs_cos: [N, num_heads, num_freqs] - cosine components
freqs_sin: [N, num_heads, num_freqs] - sine components
"""
with torch.autocast("cuda", enabled=False):
assert freqs.ndim == 3
freqs_sum = torch.einsum("Nd,dhf->Nhf", pos.to(freqs), freqs)
freqs_cos = torch.cos(freqs_sum)
freqs_sin = torch.sin(freqs_sum)
return freqs_cos, freqs_sin

34
mochi_preview/dit/joint_model/temporal_rope.py Normal file

@@ -0,0 +1,34 @@
# Based on Llama3 Implementation.
import torch
def apply_rotary_emb_qk_real(
xqk: torch.Tensor,
freqs_cos: torch.Tensor,
freqs_sin: torch.Tensor,
) -> torch.Tensor:
"""
Apply rotary embeddings to input tensors using the given frequency tensor without complex numbers.
Args:
xqk (torch.Tensor): Query and/or Key tensors to apply rotary embeddings. Shape: (B, S, *, num_heads, D)
Can be either just query or just key, or both stacked along some batch or * dim.
freqs_cos (torch.Tensor): Precomputed cosine frequency tensor.
freqs_sin (torch.Tensor): Precomputed sine frequency tensor.
Returns:
torch.Tensor: The input tensor with rotary embeddings applied.
"""
assert xqk.dtype == torch.bfloat16
# Split the last dimension into even and odd parts
xqk_even = xqk[..., 0::2]
xqk_odd = xqk[..., 1::2]
# Apply rotation
cos_part = (xqk_even * freqs_cos - xqk_odd * freqs_sin).type_as(xqk)
sin_part = (xqk_even * freqs_sin + xqk_odd * freqs_cos).type_as(xqk)
# Interleave the results back into the original shape
out = torch.stack([cos_part, sin_part], dim=-1).flatten(-2)
assert out.dtype == torch.bfloat16
return out
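# Added note: this is the standard real-valued RoPE rotation applied to (even, odd)
# channel pairs: out_even = x_even * cos - x_odd * sin and
# out_odd = x_even * sin + x_odd * cos, with the per-token, per-head angles coming
# from compute_mixed_rotation.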

189
mochi_preview/dit/joint_model/utils.py Normal file

@@ -0,0 +1,189 @@
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch.Tensor:
"""
Pool tokens in x using mask.
NOTE: We assume x does not require gradients.
Args:
x: (B, L, D) tensor of tokens.
mask: (B, L) boolean tensor indicating which tokens are not padding.
Returns:
pooled: (B, D) tensor of pooled tokens.
"""
assert x.size(1) == mask.size(1) # Expected mask to have same length as tokens.
assert x.size(0) == mask.size(0) # Expected mask to have same batch size as tokens.
mask = mask[:, :, None].to(dtype=x.dtype)
mask = mask / mask.sum(dim=1, keepdim=True).clamp(min=1)
pooled = (x * mask).sum(dim=1, keepdim=keepdim)
return pooled
class AttentionPool(nn.Module):
def __init__(
self,
embed_dim: int,
num_heads: int,
output_dim: int = None,
device: Optional[torch.device] = None,
):
"""
Args:
spatial_dim (int): Number of tokens in sequence length.
embed_dim (int): Dimensionality of input tokens.
num_heads (int): Number of attention heads.
output_dim (int): Dimensionality of output tokens. Defaults to embed_dim.
"""
super().__init__()
self.num_heads = num_heads
self.to_kv = nn.Linear(embed_dim, 2 * embed_dim, device=device)
self.to_q = nn.Linear(embed_dim, embed_dim, device=device)
self.to_out = nn.Linear(embed_dim, output_dim or embed_dim, device=device)
def forward(self, x, mask):
"""
Args:
x (torch.Tensor): (B, L, D) tensor of input tokens.
mask (torch.Tensor): (B, L) boolean tensor indicating which tokens are not padding.
NOTE: We assume x does not require gradients.
Returns:
x (torch.Tensor): (B, D) tensor of pooled tokens.
"""
D = x.size(2)
# Construct attention mask, shape: (B, 1, num_queries=1, num_keys=1+L).
attn_mask = mask[:, None, None, :].bool() # (B, 1, 1, L).
attn_mask = F.pad(attn_mask, (1, 0), value=True) # (B, 1, 1, 1+L).
# Average non-padding token features. These will be used as the query.
x_pool = pool_tokens(x, mask, keepdim=True) # (B, 1, D)
# Concat pooled features to input sequence.
x = torch.cat([x_pool, x], dim=1) # (B, L+1, D)
# Compute queries, keys, values. Only the mean token is used to create a query.
kv = self.to_kv(x) # (B, L+1, 2 * D)
q = self.to_q(x[:, 0]) # (B, D)
# Extract heads.
head_dim = D // self.num_heads
kv = kv.unflatten(2, (2, self.num_heads, head_dim)) # (B, 1+L, 2, H, head_dim)
kv = kv.transpose(1, 3) # (B, H, 2, 1+L, head_dim)
k, v = kv.unbind(2) # (B, H, 1+L, head_dim)
q = q.unflatten(1, (self.num_heads, head_dim)) # (B, H, head_dim)
q = q.unsqueeze(2) # (B, H, 1, head_dim)
# Compute attention.
x = F.scaled_dot_product_attention(
q, k, v, attn_mask=attn_mask, dropout_p=0.0
) # (B, H, 1, head_dim)
# Concatenate heads and run output.
x = x.squeeze(2).flatten(1, 2) # (B, D = H * head_dim)
x = self.to_out(x)
return x
class PadSplitXY(torch.autograd.Function):
"""
Merge heads, pad and extract visual and text tokens,
and split along the sequence length.
"""
@staticmethod
def forward(
ctx,
xy: torch.Tensor,
indices: torch.Tensor,
B: int,
N: int,
L: int,
dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
xy: Packed tokens. Shape: (total <= B * (N + L), num_heads * head_dim).
indices: Valid token indices out of unpacked tensor. Shape: (total,)
Returns:
x: Visual tokens. Shape: (B, N, num_heads * head_dim).
y: Text tokens. Shape: (B, L, num_heads * head_dim).
"""
ctx.save_for_backward(indices)
ctx.B, ctx.N, ctx.L = B, N, L
D = xy.size(1)
# Pad sequences to (B, N + L, dim).
assert indices.ndim == 1
output = torch.zeros(B * (N + L), D, device=xy.device, dtype=dtype)
indices = indices.unsqueeze(1).expand(
-1, D
) # (total,) -> (total, num_heads * head_dim)
output.scatter_(0, indices, xy)
xy = output.view(B, N + L, D)
# Split visual and text tokens along the sequence length.
return torch.tensor_split(xy, (N,), dim=1)
def pad_and_split_xy(xy, indices, B, N, L, dtype) -> Tuple[torch.Tensor, torch.Tensor]:
return PadSplitXY.apply(xy, indices, B, N, L, dtype)
class UnifyStreams(torch.autograd.Function):
"""Unify visual and text streams."""
@staticmethod
def forward(
ctx,
q_x: torch.Tensor,
k_x: torch.Tensor,
v_x: torch.Tensor,
q_y: torch.Tensor,
k_y: torch.Tensor,
v_y: torch.Tensor,
indices: torch.Tensor,
):
"""
Args:
q_x: (B, N, num_heads, head_dim)
k_x: (B, N, num_heads, head_dim)
v_x: (B, N, num_heads, head_dim)
q_y: (B, L, num_heads, head_dim)
k_y: (B, L, num_heads, head_dim)
v_y: (B, L, num_heads, head_dim)
indices: (total <= B * (N + L))
Returns:
qkv: (total <= B * (N + L), 3, num_heads, head_dim)
"""
ctx.save_for_backward(indices)
B, N, num_heads, head_dim = q_x.size()
ctx.B, ctx.N, ctx.L = B, N, q_y.size(1)
D = num_heads * head_dim
q = torch.cat([q_x, q_y], dim=1)
k = torch.cat([k_x, k_y], dim=1)
v = torch.cat([v_x, v_y], dim=1)
qkv = torch.stack([q, k, v], dim=2).view(B * (N + ctx.L), 3, D)
indices = indices[:, None, None].expand(-1, 3, D)
qkv = torch.gather(qkv, 0, indices) # (total, 3, num_heads * head_dim)
return qkv.unflatten(2, (num_heads, head_dim))
def unify_streams(q_x, k_x, v_x, q_y, k_y, v_y, indices) -> torch.Tensor:
return UnifyStreams.apply(q_x, k_x, v_x, q_y, k_y, v_y, indices)

445
mochi_preview/t2v_synth_mochi.py Normal file

@@ -0,0 +1,445 @@
import json
import os
import random
from functools import partial
from typing import Dict, List
from safetensors.torch import load_file
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import yaml
from einops import rearrange, repeat
from omegaconf import OmegaConf
from torch import nn
from torch.distributed.fsdp import (
BackwardPrefetch,
MixedPrecision,
ShardingStrategy,
)
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
)
from torch.distributed.fsdp.wrap import (
lambda_auto_wrap_policy,
transformer_auto_wrap_policy,
)
from transformers import T5EncoderModel, T5Tokenizer
from transformers.models.t5.modeling_t5 import T5Block
from .dit.joint_model.context_parallel import get_cp_rank_size
from .utils import Timer
from tqdm import tqdm
from comfy.utils import ProgressBar
T5_MODEL = "weights/T5"
MAX_T5_TOKEN_LENGTH = 256
class T5_Tokenizer:
"""Wrapper around Hugging Face tokenizer for T5
Args:
model_name(str): Name of tokenizer to load.
"""
def __init__(self):
self.tokenizer = T5Tokenizer.from_pretrained(T5_MODEL, legacy=False)
def __call__(self, prompt, padding, truncation, return_tensors, max_length=None):
"""
Args:
prompt (str): The input text to tokenize.
padding (str): The padding strategy.
truncation (bool): Flag indicating whether to truncate the tokens.
return_tensors (str): Flag indicating whether to return tensors.
max_length (int): The max length of the tokens.
"""
assert (
not max_length or max_length == MAX_T5_TOKEN_LENGTH
), f"Max length must be {MAX_T5_TOKEN_LENGTH} for T5."
tokenized_output = self.tokenizer(
prompt,
padding=padding,
max_length=MAX_T5_TOKEN_LENGTH, # Max token length for T5 is set here.
truncation=truncation,
return_tensors=return_tensors,
return_attention_mask=True,
)
return tokenized_output
def unnormalize_latents(
z: torch.Tensor,
mean: torch.Tensor,
std: torch.Tensor,
) -> torch.Tensor:
"""Unnormalize latents. Useful for decoding DiT samples.
Args:
z (torch.Tensor): [B, C_z, T_z, H_z, W_z], float
Returns:
torch.Tensor: [B, C_z, T_z, H_z, W_z], float
"""
mean = mean[:, None, None, None]
std = std[:, None, None, None]
assert z.ndim == 5
assert z.size(1) == mean.size(0) == std.size(0)
return z * std.to(z) + mean.to(z)
def setup_fsdp_sync(model, device_id, *, param_dtype, auto_wrap_policy) -> FSDP:
model = FSDP(
model,
sharding_strategy=ShardingStrategy.FULL_SHARD,
mixed_precision=MixedPrecision(
param_dtype=param_dtype,
reduce_dtype=torch.float32,
buffer_dtype=torch.float32,
),
auto_wrap_policy=auto_wrap_policy,
backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
limit_all_gathers=True,
device_id=device_id,
sync_module_states=True,
use_orig_params=True,
)
torch.cuda.synchronize()
return model
def compute_packed_indices(
N: int,
text_mask: List[torch.Tensor],
) -> Dict[str, torch.Tensor]:
"""
Based on https://github.com/Dao-AILab/flash-attention/blob/765741c1eeb86c96ee71a3291ad6968cfbf4e4a1/flash_attn/bert_padding.py#L60-L80
Args:
N: Number of visual tokens.
text_mask: (B, L) List of boolean tensor indicating which text tokens are not padding.
Returns:
packed_indices: Dict with keys for Flash Attention:
- valid_token_indices_kv: up to (B * (N + L),) tensor of valid token indices (non-padding)
in the packed sequence.
- cu_seqlens_kv: (B + 1,) tensor of cumulative sequence lengths in the packed sequence.
- max_seqlen_in_batch_kv: int of the maximum sequence length in the batch.
"""
# Create an expanded token mask saying which tokens are valid across both visual and text tokens.
assert N > 0 and len(text_mask) == 1
text_mask = text_mask[0]
mask = F.pad(text_mask, (N, 0), value=True) # (B, N + L)
seqlens_in_batch = mask.sum(dim=-1, dtype=torch.int32) # (B,)
valid_token_indices = torch.nonzero(
mask.flatten(), as_tuple=False
).flatten() # up to (B * (N + L),)
assert valid_token_indices.size(0) >= text_mask.size(0) * N # At least (B * N,)
cu_seqlens = F.pad(
        torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
)
max_seqlen_in_batch = seqlens_in_batch.max().item()
return {
"cu_seqlens_kv": cu_seqlens,
"max_seqlen_in_batch_kv": max_seqlen_in_batch,
"valid_token_indices_kv": valid_token_indices,
}
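# Worked example (added; toy sizes): with N=3 visual tokens and
# text_mask = [[True, False], [True, True]] (B=2, L=2), the padded mask is
# [[T, T, T, T, F], [T, T, T, T, T]], so valid_token_indices_kv = [0,1,2,3,5,6,7,8,9],
# cu_seqlens_kv = [0, 4, 9] and max_seqlen_in_batch_kv = 5.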
def shift_sigma(
sigma: np.ndarray,
shift: float,
):
"""Shift noise standard deviation toward higher values.
Useful for training a model at high resolutions,
or sampling more finely at high noise levels.
Equivalent to:
sigma_shift = shift / (shift + 1 / sigma - 1)
except for sigma = 0.
Args:
sigma: noise standard deviation in [0, 1]
shift: shift factor >= 1.
For shift > 1, shifts sigma to higher values.
For shift = 1, identity function.
"""
return shift * sigma / (shift * sigma + 1 - sigma)
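# Added note: shift=1 is the identity (1*s / (1*s + 1 - s) = s); for shift=3,
# sigma=0.5 maps to 1.5 / (1.5 + 0.5) = 0.75, i.e. the shifted schedule spends more of
# the trajectory at high noise levels.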
class T2VSynthMochiModel:
def __init__(
self,
*,
device_id: int,
vae_stats_path: str,
dit_checkpoint_path: str,
):
super().__init__()
t = Timer()
self.device = torch.device(device_id)
#self.t5_tokenizer = T5_Tokenizer()
# with t("load_text_encs"):
# t5_enc = T5EncoderModel.from_pretrained(T5_MODEL)
# self.t5_enc = t5_enc.eval().to(torch.bfloat16).to("cpu")
with t("construct_dit"):
from .dit.joint_model.asymm_models_joint import (
AsymmDiTJoint,
)
model: nn.Module = torch.nn.utils.skip_init(
AsymmDiTJoint,
depth=48,
patch_size=2,
num_heads=24,
hidden_size_x=3072,
hidden_size_y=1536,
mlp_ratio_x=4.0,
mlp_ratio_y=4.0,
in_channels=12,
qk_norm=True,
qkv_bias=False,
out_bias=True,
patch_embed_bias=True,
timestep_mlp_bias=True,
timestep_scale=1000.0,
t5_feat_dim=4096,
t5_token_length=256,
rope_theta=10000.0,
)
with t("dit_load_checkpoint"):
model.load_state_dict(load_file(dit_checkpoint_path))
with t("fsdp_dit"):
self.dit = model
self.dit.eval()
for name, param in self.dit.named_parameters():
params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"}
if not any(keyword in name for keyword in params_to_keep):
param.data = param.data.to(torch.float8_e4m3fn)
else:
param.data = param.data.to(torch.bfloat16)
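        # Added note: this keeps the embedders, positional frequencies, T5-related and
        # norm parameters in bf16 and stores all remaining DiT weights as
        # float8_e4m3fn, matching the mochi_preview_dit_fp8_e4m3fn.safetensors
        # checkpoint path referenced in infer.py.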
vae_stats = json.load(open(vae_stats_path))
self.vae_mean = torch.Tensor(vae_stats["mean"]).to(self.device)
self.vae_std = torch.Tensor(vae_stats["std"]).to(self.device)
t.print_stats()
def get_conditioning(self, prompts, *, zero_last_n_prompts: int):
B = len(prompts)
print(f"Getting conditioning for {B} prompts")
assert (
0 <= zero_last_n_prompts <= B
), f"zero_last_n_prompts should be between 0 and {B}, got {zero_last_n_prompts}"
tokenize_kwargs = dict(
prompt=prompts,
padding="max_length",
return_tensors="pt",
truncation=True,
)
t5_toks = self.t5_tokenizer(**tokenize_kwargs, max_length=MAX_T5_TOKEN_LENGTH)
caption_input_ids_t5 = t5_toks["input_ids"]
caption_attention_mask_t5 = t5_toks["attention_mask"].bool()
del t5_toks
assert caption_input_ids_t5.shape == (B, MAX_T5_TOKEN_LENGTH)
assert caption_attention_mask_t5.shape == (B, MAX_T5_TOKEN_LENGTH)
if zero_last_n_prompts > 0:
# Zero the last N prompts
caption_input_ids_t5[-zero_last_n_prompts:] = 0
caption_attention_mask_t5[-zero_last_n_prompts:] = False
caption_input_ids_t5 = caption_input_ids_t5.to(self.device, non_blocking=True)
caption_attention_mask_t5 = caption_attention_mask_t5.to(
self.device, non_blocking=True
)
y_mask = [caption_attention_mask_t5]
y_feat = []
self.t5_enc.to(self.device)
y_feat.append(
self.t5_enc(
caption_input_ids_t5, caption_attention_mask_t5
).last_hidden_state.detach().to(torch.float32)
)
        print(y_feat[-1].shape)
print(y_feat[0])
self.t5_enc.to("cpu")
        # Sometimes returns a tensor, other times a tuple; not sure why.
# See: https://huggingface.co/genmo/mochi-1-preview/discussions/3
assert tuple(y_feat[-1].shape) == (B, MAX_T5_TOKEN_LENGTH, 4096)
return dict(y_mask=y_mask, y_feat=y_feat)
def get_packed_indices(self, y_mask, *, lT, lW, lH):
patch_size = 2
N = lT * lH * lW // (patch_size**2)
assert len(y_mask) == 1
packed_indices = compute_packed_indices(N, y_mask)
self.move_to_device_(packed_indices)
return packed_indices
def move_to_device_(self, sample):
if isinstance(sample, dict):
for key in sample.keys():
if isinstance(sample[key], torch.Tensor):
sample[key] = sample[key].to(self.device, non_blocking=True)
@torch.inference_mode(mode=True)
def run(self, args, stream_results):
random.seed(args["seed"])
np.random.seed(args["seed"])
torch.manual_seed(args["seed"])
generator = torch.Generator(device=self.device)
generator.manual_seed(args["seed"])
# assert (
# len(args["prompt"]) == 1
# ), f"Expected exactly one prompt, got {len(args['prompt'])}"
#prompt = args["prompt"][0]
#neg_prompt = args["negative_prompt"][0] if len(args["negative_prompt"]) else ""
B = 1
w = args["width"]
h = args["height"]
t = args["num_frames"]
batch_cfg = args["mochi_args"]["batch_cfg"]
sample_steps = args["mochi_args"]["num_inference_steps"]
cfg_schedule = args["mochi_args"].get("cfg_schedule")
assert (
len(cfg_schedule) == sample_steps
), f"cfg_schedule must have length {sample_steps}, got {len(cfg_schedule)}"
sigma_schedule = args["mochi_args"].get("sigma_schedule")
if sigma_schedule:
assert (
len(sigma_schedule) == sample_steps + 1
), f"sigma_schedule must have length {sample_steps + 1}, got {len(sigma_schedule)}"
assert (t - 1) % 6 == 0, f"t - 1 must be divisible by 6, got {t - 1}"
# if batch_cfg:
# sample_batched = self.get_conditioning(
# [prompt] + [neg_prompt], zero_last_n_prompts=B if neg_prompt == "" else 0
# )
# else:
# sample = self.get_conditioning([prompt], zero_last_n_prompts=0)
# sample_null = self.get_conditioning([neg_prompt] * B, zero_last_n_prompts=B if neg_prompt == "" else 0)
spatial_downsample = 8
temporal_downsample = 6
latent_t = (t - 1) // temporal_downsample + 1
latent_w, latent_h = w // spatial_downsample, h // spatial_downsample
latent_dims = dict(lT=latent_t, lW=latent_w, lH=latent_h)
in_channels = 12
z = torch.randn(
(B, in_channels, latent_t, latent_h, latent_w),
device=self.device,
generator=generator,
dtype=torch.float32,
)
# if batch_cfg:
# sample_batched["packed_indices"] = self.get_packed_indices(
# sample_batched["y_mask"], **latent_dims
# )
# z = repeat(z, "b ... -> (repeat b) ...", repeat=2)
# else:
sample = {
"y_mask": [args["positive_embeds"]["attention_mask"].to(self.device)],
"y_feat": [args["positive_embeds"]["embeds"].to(self.device)]
}
sample_null = {
"y_mask": [args["negative_embeds"]["attention_mask"].to(self.device)],
"y_feat": [args["negative_embeds"]["embeds"].to(self.device)]
}
# print(sample["y_mask"])
# print(type(sample["y_mask"]))
# print(sample["y_mask"][0].shape)
# print(sample["y_feat"])
# print(type(sample["y_feat"]))
# print(sample["y_feat"][0].shape)
print(sample_null["y_mask"])
print(type(sample_null["y_mask"]))
print(sample_null["y_mask"][0].shape)
print(sample_null["y_feat"])
print(type(sample_null["y_feat"]))
print(sample_null["y_feat"][0].shape)
sample["packed_indices"] = self.get_packed_indices(
sample["y_mask"], **latent_dims
)
sample_null["packed_indices"] = self.get_packed_indices(
sample_null["y_mask"], **latent_dims
)
def model_fn(*, z, sigma, cfg_scale):
#print("z", z.dtype, z.device)
#print("sigma", sigma.dtype, sigma.device)
self.dit.to(self.device)
# if batch_cfg:
# with torch.autocast("cuda", dtype=torch.bfloat16):
# out = self.dit(z, sigma, **sample_batched)
# out_cond, out_uncond = torch.chunk(out, chunks=2, dim=0)
#else:
nonlocal sample, sample_null
with torch.autocast("cuda", dtype=torch.bfloat16):
out_cond = self.dit(z, sigma, **sample)
out_uncond = self.dit(z, sigma, **sample_null)
assert out_cond.shape == out_uncond.shape
return out_uncond + cfg_scale * (out_cond - out_uncond), out_cond
comfy_pbar = ProgressBar(sample_steps)
for i in tqdm(range(0, sample_steps), desc="Processing Samples", total=sample_steps):
sigma = sigma_schedule[i]
dsigma = sigma - sigma_schedule[i + 1]
# `pred` estimates `z_0 - eps`.
pred, output_cond = model_fn(
z=z,
sigma=torch.full(
[B] if not batch_cfg else [B * 2], sigma, device=z.device
),
cfg_scale=cfg_schedule[i],
)
pred = pred.to(z)
output_cond = output_cond.to(z)
#if stream_results:
# yield i / sample_steps, None, False
z = z + dsigma * pred
comfy_pbar.update(1)
cp_rank, cp_size = get_cp_rank_size()
if batch_cfg:
z = z[:B]
z = z.tensor_split(cp_size, dim=2)[cp_rank] # split along temporal dim
self.dit.to("cpu")
samples = unnormalize_latents(z.float(), self.vae_mean, self.vae_std)
print("samples", samples.shape, samples.dtype, samples.device)
return samples

33
mochi_preview/utils.py Normal file
View File

@ -0,0 +1,33 @@
import time
class Timer:
def __init__(self):
self.times = {} # Dictionary to store times per stage
def __call__(self, name):
print(f"Timing {name}")
return self.TimerContextManager(self, name)
def print_stats(self):
total_time = sum(self.times.values())
# Print table header
print("{:<20} {:>10} {:>10}".format("Stage", "Time(s)", "Percent"))
for name, t in self.times.items():
percent = (t / total_time) * 100 if total_time > 0 else 0
print("{:<20} {:>10.2f} {:>9.2f}%".format(name, t, percent))
class TimerContextManager:
def __init__(self, outer, name):
self.outer = outer # Reference to the Timer instance
self.name = name
self.start_time = None
def __enter__(self):
self.start_time = time.perf_counter()
return self
def __exit__(self, exc_type, exc_value, traceback):
end_time = time.perf_counter()
elapsed = end_time - self.start_time
self.outer.times[self.name] = self.outer.times.get(self.name, 0) + elapsed
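# Example usage (sketch, with illustrative stage names):
#   timer = Timer()
#   with timer("dit"):
#       ...  # run the transformer
#   with timer("vae"):
#       ...  # decode latents
#   timer.print_stats()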

152
mochi_preview/vae/cp_conv.py Normal file
View File

@ -0,0 +1,152 @@
from typing import Tuple, Union
import torch
import torch.distributed as dist
import torch.nn.functional as F
from ..dit.joint_model.context_parallel import get_cp_group, get_cp_rank_size
def cast_tuple(t, length=1):
return t if isinstance(t, tuple) else ((t,) * length)
def cp_pass_frames(x: torch.Tensor, frames_to_send: int) -> torch.Tensor:
"""
Forward pass that handles communication between ranks for inference.
Args:
x: Tensor of shape (B, C, T, H, W)
frames_to_send: int, number of frames to communicate between ranks
Returns:
output: Tensor of shape (B, C, T', H, W)
"""
cp_rank, cp_world_size = get_cp_rank_size()
if frames_to_send == 0 or cp_world_size == 1:
return x
group = get_cp_group()
global_rank = dist.get_rank()
# Send to next rank
if cp_rank < cp_world_size - 1:
assert x.size(2) >= frames_to_send
tail = x[:, :, -frames_to_send:].contiguous()
dist.send(tail, global_rank + 1, group=group)
# Receive from previous rank
if cp_rank > 0:
B, C, _, H, W = x.shape
recv_buffer = torch.empty(
(B, C, frames_to_send, H, W),
dtype=x.dtype,
device=x.device,
)
dist.recv(recv_buffer, global_rank - 1, group=group)
x = torch.cat([recv_buffer, x], dim=2)
return x
def _pad_to_max(x: torch.Tensor, max_T: int) -> torch.Tensor:
if max_T > x.size(2):
pad_T = max_T - x.size(2)
pad_dims = (0, 0, 0, 0, 0, pad_T)
return F.pad(x, pad_dims)
return x
def gather_all_frames(x: torch.Tensor) -> torch.Tensor:
"""
Gathers all frames from all processes for inference.
Args:
x: Tensor of shape (B, C, T, H, W)
Returns:
output: Tensor of shape (B, C, T_total, H, W)
"""
cp_rank, cp_size = get_cp_rank_size()
cp_group = get_cp_group()
# Ensure the tensor is contiguous for collective operations
x = x.contiguous()
# Get the local time dimension size
local_T = x.size(2)
local_T_tensor = torch.tensor([local_T], device=x.device, dtype=torch.int64)
# Gather all T sizes from all processes
all_T = [torch.zeros(1, dtype=torch.int64, device=x.device) for _ in range(cp_size)]
dist.all_gather(all_T, local_T_tensor, group=cp_group)
all_T = [t.item() for t in all_T]
# Pad the tensor at the end of the time dimension to match max_T
max_T = max(all_T)
x = _pad_to_max(x, max_T).contiguous()
# Prepare a list to hold the gathered tensors
gathered_x = [torch.zeros_like(x).contiguous() for _ in range(cp_size)]
# Perform the all_gather operation
dist.all_gather(gathered_x, x, group=cp_group)
# Slice each gathered tensor back to its original T size
for idx, t_size in enumerate(all_T):
gathered_x[idx] = gathered_x[idx][:, :, :t_size]
return torch.cat(gathered_x, dim=2)
def excessive_memory_usage(input: torch.Tensor, max_gb: float = 2.0) -> bool:
"""Estimate memory usage based on input tensor size and data type."""
element_size = input.element_size() # Size in bytes of each element
memory_bytes = input.numel() * element_size
memory_gb = memory_bytes / 1024**3
return memory_gb > max_gb
class ContextParallelCausalConv3d(torch.nn.Conv3d):
def __init__(
self,
in_channels,
out_channels,
kernel_size: Union[int, Tuple[int, int, int]],
stride: Union[int, Tuple[int, int, int]],
**kwargs,
):
kernel_size = cast_tuple(kernel_size, 3)
stride = cast_tuple(stride, 3)
height_pad = (kernel_size[1] - 1) // 2
width_pad = (kernel_size[2] - 1) // 2
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
dilation=(1, 1, 1),
padding=(0, height_pad, width_pad),
**kwargs,
)
def forward(self, x: torch.Tensor):
cp_rank, cp_world_size = get_cp_rank_size()
context_size = self.kernel_size[0] - 1
if cp_rank == 0:
mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
x = F.pad(x, (0, 0, 0, 0, context_size, 0), mode=mode)
if cp_world_size == 1:
return super().forward(x)
if all(s == 1 for s in self.stride):
# Receive some frames from previous rank.
x = cp_pass_frames(x, context_size)
return super().forward(x)
# Less efficient implementation for strided convs.
# All gather x, infer and chunk.
x = gather_all_frames(x) # [B, C, k - 1 + global_T, H, W]
x = super().forward(x)
x_chunks = x.tensor_split(cp_world_size, dim=2)
assert len(x_chunks) == cp_world_size
return x_chunks[cp_rank]

815
mochi_preview/vae/model.py Normal file
View File

@ -0,0 +1,815 @@
from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from ..dit.joint_model.context_parallel import get_cp_rank_size, local_shard
from ..vae.cp_conv import cp_pass_frames, gather_all_frames
def cast_tuple(t, length=1):
return t if isinstance(t, tuple) else ((t,) * length)
class GroupNormSpatial(nn.GroupNorm):
"""
GroupNorm applied per-frame.
"""
def forward(self, x: torch.Tensor, *, chunk_size: int = 8):
B, C, T, H, W = x.shape
x = rearrange(x, "B C T H W -> (B T) C H W")
# Run group norm in chunks.
output = torch.empty_like(x)
for b in range(0, B * T, chunk_size):
output[b : b + chunk_size] = super().forward(x[b : b + chunk_size])
return rearrange(output, "(B T) C H W -> B C T H W", B=B, T=T)
class SafeConv3d(torch.nn.Conv3d):
"""
NOTE: No support for padding along time dimension.
Input must already be padded along time.
"""
def forward(self, input):
memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3
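# memory_count is a rough activation-size estimate in GiB, assuming 2-byte (bf16/fp16)
# elements; inputs above ~2 GiB are processed in temporal chunks below.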
if memory_count > 2:
part_num = int(memory_count / 2) + 1
k = self.kernel_size[0]
input_idx = torch.arange(k - 1, input.size(2))
input_chunks_idx = torch.chunk(input_idx, part_num, dim=0)
# assert self.kernel_size == (3, 3, 3), f"kernel_size {self.kernel_size} != (3, 3, 3)"
assert self.stride[0] == 1, f"stride {self.stride}"
assert self.dilation[0] == 1, f"dilation {self.dilation}"
assert self.padding[0] == 0, f"padding {self.padding}"
# Compute output size
assert not input.requires_grad
B, _, T_in, H_in, W_in = input.shape
output_size = (
B,
self.out_channels,
T_in - k + 1,
H_in // self.stride[1],
W_in // self.stride[2],
)
output = torch.empty(output_size, dtype=input.dtype, device=input.device)
for input_chunk_idx in input_chunks_idx:
input_s = input_chunk_idx[0] - k + 1
input_e = input_chunk_idx[-1] + 1
input_chunk = input[:, :, input_s:input_e, :, :]
output_chunk = super(SafeConv3d, self).forward(input_chunk)
output_s = input_s
output_e = output_s + output_chunk.size(2)
output[:, :, output_s:output_e, :, :] = output_chunk
return output
else:
return super(SafeConv3d, self).forward(input)
class StridedSafeConv3d(torch.nn.Conv3d):
def forward(self, input, use_local_shard: bool = False):
assert self.stride[0] == self.kernel_size[0]
assert self.dilation[0] == 1
assert self.padding[0] == 0
kernel_size = self.kernel_size[0]
stride = self.stride[0]
T_in = input.size(2)
T_out = T_in // kernel_size
# Parallel implementation.
if use_local_shard:
idx = torch.arange(T_out)
idx = local_shard(idx, dim=0)
start = idx.min() * stride
end = idx.max() * stride + kernel_size
local_input = input[:, :, start:end, :, :]
return torch.nn.Conv3d.forward(self, local_input)
raise NotImplementedError
class ContextParallelConv3d(SafeConv3d):
def __init__(
self,
in_channels,
out_channels,
kernel_size: Union[int, Tuple[int, int, int]],
stride: Union[int, Tuple[int, int, int]],
causal: bool = True,
context_parallel: bool = True,
**kwargs,
):
self.causal = causal
self.context_parallel = context_parallel
kernel_size = cast_tuple(kernel_size, 3)
stride = cast_tuple(stride, 3)
height_pad = (kernel_size[1] - 1) // 2
width_pad = (kernel_size[2] - 1) // 2
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
dilation=(1, 1, 1),
padding=(0, height_pad, width_pad),
**kwargs,
)
def forward(self, x: torch.Tensor):
cp_rank, cp_world_size = get_cp_rank_size()
# Compute padding amounts.
context_size = self.kernel_size[0] - 1
if self.causal:
pad_front = context_size
pad_back = 0
else:
pad_front = context_size // 2
pad_back = context_size - pad_front
# Apply padding.
assert self.padding_mode == "replicate" # DEBUG
mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
if self.context_parallel and cp_world_size == 1:
x = F.pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode)
else:
if cp_rank == 0:
x = F.pad(x, (0, 0, 0, 0, pad_front, 0), mode=mode)
elif cp_rank == cp_world_size - 1 and pad_back:
x = F.pad(x, (0, 0, 0, 0, 0, pad_back), mode=mode)
if self.context_parallel and cp_world_size == 1:
return super().forward(x)
if self.stride[0] == 1:
# Receive some frames from previous rank.
x = cp_pass_frames(x, context_size)
return super().forward(x)
# Less efficient implementation for strided convs.
# All gather x, infer and chunk.
assert (
x.dtype == torch.bfloat16
), f"Expected x to be of type torch.bfloat16, got {x.dtype}"
x = gather_all_frames(x) # [B, C, k - 1 + global_T, H, W]
return StridedSafeConv3d.forward(self, x, use_local_shard=True)
class Conv1x1(nn.Linear):
"""*1x1 Conv implemented with a linear layer."""
def __init__(self, in_features: int, out_features: int, *args, **kwargs):
super().__init__(in_features, out_features, *args, **kwargs)
def forward(self, x: torch.Tensor):
"""Forward pass.
Args:
x: Input tensor. Shape: [B, C, *] or [B, *, C].
Returns:
x: Output tensor. Shape: [B, C', *] or [B, *, C'].
"""
x = x.movedim(1, -1)
x = super().forward(x)
x = x.movedim(-1, 1)
return x
class DepthToSpaceTime(nn.Module):
def __init__(
self,
temporal_expansion: int,
spatial_expansion: int,
):
super().__init__()
self.temporal_expansion = temporal_expansion
self.spatial_expansion = spatial_expansion
# When printed, this module should show the temporal and spatial expansion factors.
def extra_repr(self):
return f"texp={self.temporal_expansion}, sexp={self.spatial_expansion}"
def forward(self, x: torch.Tensor):
"""Forward pass.
Args:
x: Input tensor. Shape: [B, C, T, H, W].
Returns:
x: Rearranged tensor. Shape: [B, C/(st*s*s), T*st, H*s, W*s].
"""
x = rearrange(
x,
"B (C st sh sw) T H W -> B C (T st) (H sh) (W sw)",
st=self.temporal_expansion,
sh=self.spatial_expansion,
sw=self.spatial_expansion,
)
cp_rank, _ = get_cp_rank_size()
if self.temporal_expansion > 1 and cp_rank == 0:
# Drop the first self.temporal_expansion - 1 frames.
# This is because we always want the 3x3x3 conv filter to only apply
# to the first frame, and the first frame doesn't need to be repeated.
assert all(x.shape)
x = x[:, :, self.temporal_expansion - 1 :]
assert all(x.shape)
return x
def norm_fn(
in_channels: int,
affine: bool = True,
):
return GroupNormSpatial(affine=affine, num_groups=32, num_channels=in_channels)
class ResBlock(nn.Module):
"""Residual block that preserves the spatial dimensions."""
def __init__(
self,
channels: int,
*,
affine: bool = True,
attn_block: Optional[nn.Module] = None,
padding_mode: str = "replicate",
causal: bool = True,
):
super().__init__()
self.channels = channels
assert causal
self.stack = nn.Sequential(
norm_fn(channels, affine=affine),
nn.SiLU(inplace=True),
ContextParallelConv3d(
in_channels=channels,
out_channels=channels,
kernel_size=(3, 3, 3),
stride=(1, 1, 1),
padding_mode=padding_mode,
bias=True,
causal=causal,
),
norm_fn(channels, affine=affine),
nn.SiLU(inplace=True),
ContextParallelConv3d(
in_channels=channels,
out_channels=channels,
kernel_size=(3, 3, 3),
stride=(1, 1, 1),
padding_mode=padding_mode,
bias=True,
causal=causal,
),
)
self.attn_block = attn_block if attn_block else nn.Identity()
def forward(self, x: torch.Tensor):
"""Forward pass.
Args:
x: Input tensor. Shape: [B, C, T, H, W].
"""
residual = x
x = self.stack(x)
x = x + residual
del residual
return self.attn_block(x)
def prepare_for_attention(qkv: torch.Tensor, head_dim: int, qk_norm: bool = True):
"""Prepare qkv tensor for attention and normalize qk.
Args:
qkv: Input tensor. Shape: [B, L, 3 * num_heads * head_dim].
Returns:
q, k, v: qkv tensor split into q, k, v. Shape: [B, num_heads, L, head_dim].
"""
assert qkv.ndim == 3 # [B, L, 3 * num_heads * head_dim]
assert qkv.size(2) % (3 * head_dim) == 0
num_heads = qkv.size(2) // (3 * head_dim)
qkv = qkv.unflatten(2, (3, num_heads, head_dim))
q, k, v = qkv.unbind(2) # [B, L, num_heads, head_dim]
q = q.transpose(1, 2) # [B, num_heads, L, head_dim]
k = k.transpose(1, 2) # [B, num_heads, L, head_dim]
v = v.transpose(1, 2) # [B, num_heads, L, head_dim]
if qk_norm:
q = F.normalize(q, p=2, dim=-1)
k = F.normalize(k, p=2, dim=-1)
# Mixed precision can change the dtype of normed q/k to float32.
q = q.to(dtype=qkv.dtype)
k = k.to(dtype=qkv.dtype)
return q, k, v
class Attention(nn.Module):
def __init__(
self,
dim: int,
head_dim: int = 32,
qkv_bias: bool = False,
out_bias: bool = True,
qk_norm: bool = True,
) -> None:
super().__init__()
self.head_dim = head_dim
self.num_heads = dim // head_dim
self.qk_norm = qk_norm
self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)
self.out = nn.Linear(dim, dim, bias=out_bias)
def forward(
self,
x: torch.Tensor,
*,
chunk_size=2**15,
) -> torch.Tensor:
"""Compute temporal self-attention.
Args:
x: Input tensor. Shape: [B, C, T, H, W].
chunk_size: Chunk size for large tensors.
Returns:
x: Output tensor. Shape: [B, C, T, H, W].
"""
B, _, T, H, W = x.shape
if T == 1:
# No attention for single frame.
x = x.movedim(1, -1) # [B, C, T, H, W] -> [B, T, H, W, C]
qkv = self.qkv(x)
_, _, x = qkv.chunk(3, dim=-1) # Throw away queries and keys.
x = self.out(x)
return x.movedim(-1, 1) # [B, T, H, W, C] -> [B, C, T, H, W]
# 1D temporal attention.
x = rearrange(x, "B C t h w -> (B h w) t C")
qkv = self.qkv(x)
# Input: qkv with shape [B, t, 3 * num_heads * head_dim]
# Output: x with shape [B, num_heads, t, head_dim]
q, k, v = prepare_for_attention(qkv, self.head_dim, qk_norm=self.qk_norm)
attn_kwargs = dict(
attn_mask=None,
dropout_p=0.0,
is_causal=True,
scale=self.head_dim**-0.5,
)
if q.size(0) <= chunk_size:
x = F.scaled_dot_product_attention(
q, k, v, **attn_kwargs
) # [B, num_heads, t, head_dim]
else:
# Evaluate in chunks to avoid `RuntimeError: CUDA error: invalid configuration argument.`
# Chunks of 2**16 and up cause an error.
x = torch.empty_like(q)
for i in range(0, q.size(0), chunk_size):
qc = q[i : i + chunk_size]
kc = k[i : i + chunk_size]
vc = v[i : i + chunk_size]
chunk = F.scaled_dot_product_attention(qc, kc, vc, **attn_kwargs)
x[i : i + chunk_size].copy_(chunk)
assert x.size(0) == q.size(0)
x = x.transpose(1, 2) # [B, t, num_heads, head_dim]
x = x.flatten(2) # [B, t, num_heads * head_dim]
x = self.out(x)
x = rearrange(x, "(B h w) t C -> B C t h w", B=B, h=H, w=W)
return x
class AttentionBlock(nn.Module):
def __init__(
self,
dim: int,
**attn_kwargs,
) -> None:
super().__init__()
self.norm = norm_fn(dim)
self.attn = Attention(dim, **attn_kwargs)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x + self.attn(self.norm(x))
class CausalUpsampleBlock(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
num_res_blocks: int,
*,
temporal_expansion: int = 2,
spatial_expansion: int = 2,
**block_kwargs,
):
super().__init__()
blocks = []
for _ in range(num_res_blocks):
blocks.append(block_fn(in_channels, **block_kwargs))
self.blocks = nn.Sequential(*blocks)
self.temporal_expansion = temporal_expansion
self.spatial_expansion = spatial_expansion
# Change channels in the final convolution layer.
self.proj = Conv1x1(
in_channels,
out_channels * temporal_expansion * (spatial_expansion**2),
)
self.d2st = DepthToSpaceTime(
temporal_expansion=temporal_expansion, spatial_expansion=spatial_expansion
)
def forward(self, x):
x = self.blocks(x)
x = self.proj(x)
x = self.d2st(x)
return x
def block_fn(channels, *, has_attention: bool = False, **block_kwargs):
attn_block = AttentionBlock(channels) if has_attention else None
return ResBlock(
channels, affine=True, attn_block=attn_block, **block_kwargs
)
class DownsampleBlock(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
num_res_blocks,
*,
temporal_reduction=2,
spatial_reduction=2,
**block_kwargs,
):
"""
Downsample block for the VAE encoder.
Args:
in_channels: Number of input channels.
out_channels: Number of output channels.
num_res_blocks: Number of residual blocks.
temporal_reduction: Temporal reduction factor.
spatial_reduction: Spatial reduction factor.
"""
super().__init__()
layers = []
# Change the channel count in the strided convolution.
# This lets the ResBlock have uniform channel count,
# as in ConvNeXt.
assert in_channels != out_channels
layers.append(
ContextParallelConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=(temporal_reduction, spatial_reduction, spatial_reduction),
stride=(temporal_reduction, spatial_reduction, spatial_reduction),
padding_mode="replicate",
bias=True,
)
)
for _ in range(num_res_blocks):
layers.append(block_fn(out_channels, **block_kwargs))
self.layers = nn.Sequential(*layers)
def forward(self, x):
return self.layers(x)
def add_fourier_features(inputs: torch.Tensor, start=6, stop=8, step=1):
num_freqs = (stop - start) // step
assert inputs.ndim == 5
C = inputs.size(1)
# Create Base 2 Fourier features.
freqs = torch.arange(start, stop, step, dtype=inputs.dtype, device=inputs.device)
assert num_freqs == len(freqs)
w = torch.pow(2.0, freqs) * (2 * torch.pi) # [num_freqs]
w = w.repeat(C)[None, :, None, None, None] # [1, C * num_freqs, 1, 1, 1]
# Interleaved repeat of input channels to match w.
h = inputs.repeat_interleave(num_freqs, dim=1) # [B, C * num_freqs, T, H, W]
# Scale channels by frequency.
h = w * h
return torch.cat(
[
inputs,
torch.sin(h),
torch.cos(h),
],
dim=1,
)
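# With the defaults (start=6, stop=8, step=1) there are 2 frequencies, so the channel
# count grows from C to C * (1 + 2 * 2) = 5 * C.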
class FourierFeatures(nn.Module):
def __init__(self, start: int = 6, stop: int = 8, step: int = 1):
super().__init__()
self.start = start
self.stop = stop
self.step = step
def forward(self, inputs):
"""Add Fourier features to inputs.
Args:
inputs: Input tensor. Shape: [B, C, T, H, W]
Returns:
h: Output tensor. Shape: [B, (1 + 2 * num_freqs) * C, T, H, W]
"""
return add_fourier_features(inputs, self.start, self.stop, self.step)
class Decoder(nn.Module):
def __init__(
self,
*,
out_channels: int = 3,
latent_dim: int,
base_channels: int,
channel_multipliers: List[int],
num_res_blocks: List[int],
temporal_expansions: Optional[List[int]] = None,
spatial_expansions: Optional[List[int]] = None,
has_attention: List[bool],
output_norm: bool = True,
nonlinearity: str = "silu",
output_nonlinearity: str = "silu",
causal: bool = True,
**block_kwargs,
):
super().__init__()
self.input_channels = latent_dim
self.base_channels = base_channels
self.channel_multipliers = channel_multipliers
self.num_res_blocks = num_res_blocks
self.output_nonlinearity = output_nonlinearity
assert nonlinearity == "silu"
assert causal
ch = [mult * base_channels for mult in channel_multipliers]
self.num_up_blocks = len(ch) - 1
assert len(num_res_blocks) == self.num_up_blocks + 2
blocks = []
first_block = [
nn.Conv3d(latent_dim, ch[-1], kernel_size=(1, 1, 1))
] # Input layer.
# First set of blocks preserve channel count.
for _ in range(num_res_blocks[-1]):
first_block.append(
block_fn(
ch[-1],
has_attention=has_attention[-1],
causal=causal,
**block_kwargs,
)
)
blocks.append(nn.Sequential(*first_block))
assert len(temporal_expansions) == len(spatial_expansions) == self.num_up_blocks
assert len(num_res_blocks) == len(has_attention) == self.num_up_blocks + 2
upsample_block_fn = CausalUpsampleBlock
for i in range(self.num_up_blocks):
block = upsample_block_fn(
ch[-i - 1],
ch[-i - 2],
num_res_blocks=num_res_blocks[-i - 2],
has_attention=has_attention[-i - 2],
temporal_expansion=temporal_expansions[-i - 1],
spatial_expansion=spatial_expansions[-i - 1],
causal=causal,
**block_kwargs,
)
blocks.append(block)
assert not output_norm
# Last block. Preserve channel count.
last_block = []
for _ in range(num_res_blocks[0]):
last_block.append(
block_fn(
ch[0], has_attention=has_attention[0], causal=causal, **block_kwargs
)
)
blocks.append(nn.Sequential(*last_block))
self.blocks = nn.ModuleList(blocks)
self.output_proj = Conv1x1(ch[0], out_channels)
def forward(self, x):
"""Forward pass.
Args:
x: Latent tensor. Shape: [B, input_channels, t, h, w]. Scaled [-1, 1].
Returns:
x: Reconstructed video tensor. Shape: [B, C, T, H, W]. Scaled to [-1, 1].
T + 1 = (t - 1) * 4.
H = h * 16, W = w * 16.
"""
for block in self.blocks:
x = block(x)
if self.output_nonlinearity == "silu":
x = F.silu(x, inplace=not self.training)
else:
assert (
not self.output_nonlinearity
) # StyleGAN3 omits the to-RGB nonlinearity.
return self.output_proj(x).contiguous()
def make_broadcastable(
tensor: torch.Tensor,
axis: int,
ndim: int,
) -> torch.Tensor:
"""
Reshapes the input tensor to have singleton dimensions in all axes except the specified axis.
Args:
tensor (torch.Tensor): The tensor to reshape. Typically 1D.
axis (int): The axis along which the tensor should retain its original size.
ndim (int): The total number of dimensions the reshaped tensor should have.
Returns:
torch.Tensor: The reshaped tensor with shape suitable for broadcasting.
"""
if tensor.dim() != 1:
raise ValueError(f"Expected tensor to be 1D, but got {tensor.dim()}D tensor.")
axis = (axis + ndim) % ndim # Ensure the axis is within the tensor dimensions
shape = [1] * ndim # Start with all dimensions as 1
shape[axis] = tensor.size(0) # Set the specified axis to the size of the tensor
return tensor.view(*shape)
def blend(a: torch.Tensor, b: torch.Tensor, axis: int) -> torch.Tensor:
"""
Blends two tensors `a` and `b` along the specified axis using linear interpolation.
Args:
a (torch.Tensor): The first tensor.
b (torch.Tensor): The second tensor. Must have the same shape as `a`.
axis (int): The axis along which to perform the blending.
Returns:
torch.Tensor: The blended tensor.
"""
assert (
a.shape == b.shape
), f"Tensors must have the same shape, got {a.shape} and {b.shape}"
steps = a.size(axis)
# Create a weight tensor that linearly interpolates from 0 to 1
start = 1 / (steps + 1)
end = steps / (steps + 1)
weight = torch.linspace(start, end, steps=steps, device=a.device, dtype=a.dtype)
# Make the weight tensor broadcastable across all dimensions
weight = make_broadcastable(weight, axis, a.dim())
# Perform the blending
return a * (1 - weight) + b * weight
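# For example, with an overlap of 3 along the blend axis the weights for `b` are
# [0.25, 0.5, 0.75] (and the complements for `a`), so neither endpoint gets weight 0 or 1.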
def blend_horizontal(a: torch.Tensor, b: torch.Tensor, overlap: int) -> torch.Tensor:
if overlap == 0:
return torch.cat([a, b], dim=-1)
assert a.size(-1) >= overlap
assert b.size(-1) >= overlap
a_left, a_overlap = a[..., :-overlap], a[..., -overlap:]
b_overlap, b_right = b[..., :overlap], b[..., overlap:]
return torch.cat([a_left, blend(a_overlap, b_overlap, -1), b_right], dim=-1)
def blend_vertical(a: torch.Tensor, b: torch.Tensor, overlap: int) -> torch.Tensor:
if overlap == 0:
return torch.cat([a, b], dim=-2)
assert a.size(-2) >= overlap
assert b.size(-2) >= overlap
a_top, a_overlap = a[..., :-overlap, :], a[..., -overlap:, :]
b_overlap, b_bottom = b[..., :overlap, :], b[..., overlap:, :]
return torch.cat([a_top, blend(a_overlap, b_overlap, -2), b_bottom], dim=-2)
def nearest_multiple(x: int, multiple: int) -> int:
return round(x / multiple) * multiple
def apply_tiled(
fn: Callable[[torch.Tensor], torch.Tensor],
x: torch.Tensor,
num_tiles_w: int,
num_tiles_h: int,
overlap: int = 0, # Number of pixel of overlap between adjacent tiles.
# Use a factor of 2 times the latent downsample factor.
min_block_size: int = 1, # Minimum number of pixels in each dimension when subdividing.
):
if num_tiles_w == 1 and num_tiles_h == 1:
return fn(x)
assert (
num_tiles_w & (num_tiles_w - 1) == 0
), f"num_tiles_w={num_tiles_w} must be a power of 2"
assert (
num_tiles_h & (num_tiles_h - 1) == 0
), f"num_tiles_h={num_tiles_h} must be a power of 2"
H, W = x.shape[-2:]
assert H % min_block_size == 0
assert W % min_block_size == 0
ov = overlap // 2
assert ov % min_block_size == 0
if num_tiles_w >= 2:
# Subdivide horizontally.
half_W = nearest_multiple(W // 2, min_block_size)
left = x[..., :, : half_W + ov]
right = x[..., :, half_W - ov :]
assert num_tiles_w % 2 == 0, f"num_tiles_w={num_tiles_w} must be even"
left = apply_tiled(
fn, left, num_tiles_w // 2, num_tiles_h, overlap, min_block_size
)
right = apply_tiled(
fn, right, num_tiles_w // 2, num_tiles_h, overlap, min_block_size
)
if left is None or right is None:
return None
# If `fn` changed the resolution, adjust the overlap.
resample_factor = left.size(-1) / (half_W + ov)
out_overlap = int(overlap * resample_factor)
return blend_horizontal(left, right, out_overlap)
if num_tiles_h >= 2:
# Subdivide vertically.
half_H = nearest_multiple(H // 2, min_block_size)
top = x[..., : half_H + ov, :]
bottom = x[..., half_H - ov :, :]
assert num_tiles_h % 2 == 0, f"num_tiles_h={num_tiles_h} must be even"
top = apply_tiled(
fn, top, num_tiles_w, num_tiles_h // 2, overlap, min_block_size
)
bottom = apply_tiled(
fn, bottom, num_tiles_w, num_tiles_h // 2, overlap, min_block_size
)
if top is None or bottom is None:
return None
# If `fn` changed the resolution, adjust the overlap.
resample_factor = top.size(-2) / (half_H + ov)
out_overlap = int(overlap * resample_factor)
return blend_vertical(top, bottom, out_overlap)
raise ValueError(f"Invalid num_tiles_w={num_tiles_w} and num_tiles_h={num_tiles_h}")

356
nodes.py Normal file
View File

@ -0,0 +1,356 @@
import os
import torch
import torch.nn as nn
import folder_paths
import comfy.model_management as mm
from comfy.utils import ProgressBar, load_torch_file
from einops import rearrange
from contextlib import nullcontext
from PIL import Image
import numpy as np
import json
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
log = logging.getLogger(__name__)
from .mochi_preview.t2v_synth_mochi import T2VSynthMochiModel
from .mochi_preview.vae.model import Decoder
script_directory = os.path.dirname(os.path.abspath(__file__))
def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None):
if linear_steps is None:
linear_steps = num_steps // 2
linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)]
threshold_noise_step_diff = linear_steps - threshold_noise * num_steps
quadratic_steps = num_steps - linear_steps
quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps ** 2)
linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps ** 2)
const = quadratic_coef * (linear_steps ** 2)
quadratic_sigma_schedule = [
quadratic_coef * (i ** 2) + linear_coef * i + const
for i in range(linear_steps, num_steps)
]
sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0]
sigma_schedule = [1.0 - x for x in sigma_schedule]
return sigma_schedule
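# The schedule has num_steps + 1 sigmas decreasing from 1.0 to 0.0: linear over the first
# `linear_steps` values, then quadratic. With the MochiSampler defaults (50 steps,
# threshold_noise=0.025) the linear segment only lowers sigma to about 0.976 before the
# quadratic tail takes over.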
class DownloadAndLoadMochiModel:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": (
[
"mochi_preview_dit_fp8_e4m3fn.safetensors",
],
),
"vae": (
[
"mochi_preview_vae_bf16.safetensors",
],
),
"precision": (["fp8_e4m3fn","fp16", "fp32", "bf16"],
{"default": "bf16", "tooltip": "official recommendation is that 2b model should be fp16, 5b model should be bf16"}
),
},
}
RETURN_TYPES = ("MOCHIMODEL", "MOCHIVAE",)
RETURN_NAMES = ("mochi_model", "mochi_vae" )
FUNCTION = "loadmodel"
CATEGORY = "MochiWrapper"
DESCRIPTION = "Downloads and loads the selected Mochi model from Huggingface"
def loadmodel(self, model, vae, precision):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
mm.soft_empty_cache()
dtype = {"fp8_e4m3fn": torch.float8_e4m3fn, "bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
# Transformer model
model_download_path = os.path.join(folder_paths.models_dir, 'diffusion_models', 'mochi')
model_path = os.path.join(model_download_path, model)
repo_id = "kijai/Mochi_preview_comfy"
if not os.path.exists(model_path):
log.info(f"Downloading mochi model to: {model_path}")
from huggingface_hub import snapshot_download
snapshot_download(
repo_id=repo_id,
allow_patterns=[f"*{model}*"],
local_dir=model_download_path,
local_dir_use_symlinks=False,
)
# VAE
vae_download_path = os.path.join(folder_paths.models_dir, 'vae', 'mochi')
vae_path = os.path.join(vae_download_path, vae)
if not os.path.exists(vae_path):
log.info(f"Downloading mochi VAE to: {vae_path}")
from huggingface_hub import snapshot_download
snapshot_download(
repo_id=repo_id,
allow_patterns=[f"*{vae}*"],
local_dir=model_download_path,
local_dir_use_symlinks=False,
)
model = T2VSynthMochiModel(
device_id=0,
vae_stats_path=os.path.join(script_directory, "configs", "vae_stats.json"),
dit_checkpoint_path=model_path,
)
vae = Decoder(
out_channels=3,
base_channels=128,
channel_multipliers=[1, 2, 4, 6],
temporal_expansions=[1, 2, 3],
spatial_expansions=[2, 2, 2],
num_res_blocks=[3, 3, 4, 6, 3],
latent_dim=12,
has_attention=[False, False, False, False, False],
padding_mode="replicate",
output_norm=False,
nonlinearity="silu",
output_nonlinearity="silu",
causal=True,
)
decoder_sd = load_torch_file(vae_path)
vae.load_state_dict(decoder_sd, strict=True)
vae.eval().to(torch.bfloat16).to("cpu")
del decoder_sd
return (model, vae,)
class MochiTextEncode:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"clip": ("CLIP",),
"prompt": ("STRING", {"default": "", "multiline": True} ),
},
"optional": {
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"force_offload": ("BOOLEAN", {"default": True}),
}
}
RETURN_TYPES = ("CONDITIONING",)
RETURN_NAMES = ("conditioning",)
FUNCTION = "process"
CATEGORY = "MochiWrapper"
def process(self, clip, prompt, strength=1.0, force_offload=True):
max_tokens = 256
load_device = mm.text_encoder_device()
offload_device = mm.text_encoder_offload_device()
#print(clip.tokenizer.t5xxl)
clip.tokenizer.t5xxl.pad_to_max_length = True
clip.tokenizer.t5xxl.max_length = max_tokens
clip.cond_stage_model.t5xxl.return_attention_masks = True
clip.cond_stage_model.t5xxl.enable_attention_masks = True
clip.cond_stage_model.t5_attention_mask = True
clip.cond_stage_model.to(load_device)
tokens = clip.tokenizer.t5xxl.tokenize_with_weights(prompt, return_word_ids=True)
embeds, _, attention_mask = clip.cond_stage_model.t5xxl.encode_token_weights(tokens)
if embeds.shape[1] > max_tokens:
raise ValueError(f"Prompt is too long: got {embeds.shape[1]} tokens, maximum supported is {max_tokens}")
embeds *= strength
if force_offload:
clip.cond_stage_model.to(offload_device)
t5_embeds = {
"embeds": embeds,
"attention_mask": attention_mask["attention_mask"].bool(),
}
return (t5_embeds, )
class MochiSampler:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": ("MOCHIMODEL",),
"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"width": ("INT", {"default": 848, "min": 128, "max": 2048, "step": 8}),
"height": ("INT", {"default": 480, "min": 128, "max": 2048, "step": 8}),
"num_frames": ("INT", {"default": 49, "min": 7, "max": 1024, "step": 6}),
"steps": ("INT", {"default": 50, "min": 2}),
"cfg": ("FLOAT", {"default": 4.5, "min": 0.0, "max": 30.0, "step": 0.01}),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
},
}
RETURN_TYPES = ("LATENT",)
RETURN_NAMES = ("model", "samples",)
FUNCTION = "process"
CATEGORY = "MochiWrapper"
def process(self, model, positive, negative, steps, cfg, seed, height, width, num_frames):
mm.soft_empty_cache()
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
args = {
"height": height,
"width": width,
"num_frames": num_frames,
"mochi_args": {
"sigma_schedule": linear_quadratic_schedule(steps, 0.025),
"cfg_schedule": [cfg] * steps,
"num_inference_steps": steps,
"batch_cfg": False,
},
"positive_embeds": positive,
"negative_embeds": negative,
"seed": seed,
}
latents = model.run(args, stream_results=False)
mm.soft_empty_cache()
return ({"samples": latents},)
class MochiDecode:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"vae": ("MOCHIVAE",),
"samples": ("LATENT", ),
"enable_vae_tiling": ("BOOLEAN", {"default": False, "tooltip": "Drastically reduces memory use but may introduce seams"}),
},
"optional": {
"tile_sample_min_height": ("INT", {"default": 240, "min": 16, "max": 2048, "step": 8, "tooltip": "Minimum tile height, default is half the height"}),
"tile_sample_min_width": ("INT", {"default": 424, "min": 16, "max": 2048, "step": 8, "tooltip": "Minimum tile width, default is half the width"}),
"tile_overlap_factor_height": ("FLOAT", {"default": 0.1666, "min": 0.0, "max": 1.0, "step": 0.001}),
"tile_overlap_factor_width": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001}),
"auto_tile_size": ("BOOLEAN", {"default": True, "tooltip": "Auto size based on height and width, default is half the size"}),
}
}
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("images",)
FUNCTION = "decode"
CATEGORY = "MochiWrapper"
def decode(self, vae, samples, enable_vae_tiling, tile_sample_min_height, tile_sample_min_width, tile_overlap_factor_height, tile_overlap_factor_width, auto_tile_size=True):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
samples = samples["samples"]
def blend_v(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
for y in range(blend_extent):
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
y / blend_extent
)
return b
def blend_h(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[4], b.shape[4], blend_extent)
for x in range(blend_extent):
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
x / blend_extent
)
return b
self.tile_overlap_factor_height = tile_overlap_factor_height if not auto_tile_size else 1 / 6
self.tile_overlap_factor_width = tile_overlap_factor_width if not auto_tile_size else 1 / 5
#7, 13, 19, 25, 31, 37, 43, 49, 55, 61, 67, 73, 79, 85, 91, 97, 103, 109, 115, 121, 127, 133, 139, 145, 151, 157, 163, 169, 175, 181, 187, 193, 199
self.num_latent_frames_batch_size = 6
self.tile_sample_min_height = tile_sample_min_height if not auto_tile_size else samples.shape[3] // 2 * 8
self.tile_sample_min_width = tile_sample_min_width if not auto_tile_size else samples.shape[4] // 2 * 8
self.tile_latent_min_height = int(self.tile_sample_min_height / 8)
self.tile_latent_min_width = int(self.tile_sample_min_width / 8)
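# tile_latent_min_* are in latent space (1/8 of the pixel tile size); the overlap below is
# computed in latent space for slicing the tiles, while the blend extents are in pixel space
# and are applied to the decoded tiles.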
vae.to(device)
with torch.amp.autocast("cuda", dtype=torch.bfloat16):
if not enable_vae_tiling:
samples = vae(samples)
else:
batch_size, num_channels, num_frames, height, width = samples.shape
overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height))
overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width))
blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height)
blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
row_limit_height = self.tile_sample_min_height - blend_extent_height
row_limit_width = self.tile_sample_min_width - blend_extent_width
frame_batch_size = self.num_latent_frames_batch_size
# Split z into overlapping tiles and decode them separately.
# The tiles have an overlap to avoid seams between tiles.
rows = []
for i in range(0, height, overlap_height):
row = []
for j in range(0, width, overlap_width):
time = []
for k in range(num_frames // frame_batch_size):
remaining_frames = num_frames % frame_batch_size
start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
end_frame = frame_batch_size * (k + 1) + remaining_frames
tile = samples[
:,
:,
start_frame:end_frame,
i : i + self.tile_latent_min_height,
j : j + self.tile_latent_min_width,
]
tile = vae(tile)
time.append(tile)
row.append(torch.cat(time, dim=2))
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = blend_v(rows[i - 1][j], tile, blend_extent_height)
if j > 0:
tile = blend_h(row[j - 1], tile, blend_extent_width)
result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
result_rows.append(torch.cat(result_row, dim=4))
samples = torch.cat(result_rows, dim=3)
vae.to(offload_device)
#print("samples", samples.shape, samples.dtype, samples.device)
samples = samples.float()
samples = (samples + 1.0) / 2.0
samples.clamp_(0.0, 1.0)
frames = rearrange(samples, "b c t h w -> (t b) h w c").cpu().float()
#print(frames.shape)
return (frames,)
NODE_CLASS_MAPPINGS = {
"DownloadAndLoadMochiModel": DownloadAndLoadMochiModel,
"MochiSampler": MochiSampler,
"MochiDecode": MochiDecode,
"MochiTextEncode": MochiTextEncode,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"DownloadAndLoadMochiModel": "(Down)load Mochi Model",
"MochiSampler": "Mochi Sampler",
"MochiDecode": "Mochi Decode",
"MochiTextEncode": "Mochi TextEncode",
}

18
readme.md Normal file
View File

@ -0,0 +1,18 @@
# ComfyUI wrapper nodes for Mochi video gen: https://github.com/genmoai/models
# WORK IN PROGRESS
## Requires flash_attn!
Depending on the frame count this can fit under 20GB of VRAM. VAE decoding is the heaviest part; there is an experimental tiled decoder (adapted from the CogVideoX diffusers code) that allows higher frame counts. The highest I have managed so far is 97 frames with the default 2x2 tile grid.
Models:
https://huggingface.co/Kijai/Mochi_preview_comfy/tree/main
model to: `ComfyUI/models/diffusion_models/mochi`
vae to: `ComfyUI/models/vae/mochi`
There is an autodownload node (a regular loader node will also be added).
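For example, with the default filenames used by the loader node the expected layout is:
`ComfyUI/models/diffusion_models/mochi/mochi_preview_dit_fp8_e4m3fn.safetensors`
`ComfyUI/models/vae/mochi/mochi_preview_vae_bf16.safetensors`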