initial commit
parent 6fa487d3b9
commit b80cb4a691
3  __init__.py  Normal file
@@ -0,0 +1,3 @@
from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS

__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
BIN  __pycache__/__init__.cpython-312.pyc  Normal file  (Binary file not shown.)
BIN  __pycache__/nodes.cpython-312.pyc  Normal file  (Binary file not shown.)
4  configs/vae_stats.json  Normal file
@@ -0,0 +1,4 @@
{
    "mean": [-0.06730895953510081, -0.038011381506090416, -0.07477820912866141, -0.05565264470995561, 0.012767231469026969, -0.04703542746246419, 0.043896967884726704, -0.09346305707025976, -0.09918314763016893, -0.008729793427399178, -0.011931556316503654, -0.0321993391887285],
    "std": [0.9263795028493863, 0.9248894543193766, 0.9393059390890617, 0.959253732819592, 0.8244560132752793, 0.917259975397747, 0.9294154431013696, 1.3720942357788521, 0.881393668867029, 0.9168315692124348, 0.9185249279345552, 0.9274757570805041]
}
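These per-channel statistics cover the 12 latent channels. How they are actually consumed lives in t2v_synth_mochi.py (the file is passed as vae_stats_path in infer.py); the snippet below is only an illustrative sketch of the usual normalize/denormalize pattern with stored mean/std, not the wrapper's real code.

import json
import torch

with open("configs/vae_stats.json") as f:
    stats = json.load(f)

# Broadcast the per-channel stats over (B, C=12, T, H, W) latents.
mean = torch.tensor(stats["mean"]).view(1, -1, 1, 1, 1)
std = torch.tensor(stats["std"]).view(1, -1, 1, 1, 1)

def normalize(latents):
    # e.g. before feeding latents to the diffusion transformer
    return (latents - mean) / std

def denormalize(latents):
    # e.g. before handing latents back to the VAE decoder
    return latents * std + mean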
213  infer.py  Normal file
@@ -0,0 +1,213 @@
import json
import os
import tempfile
import time

import click
import numpy as np
#import ray
from einops import rearrange
from PIL import Image
from tqdm import tqdm

from mochi_preview.t2v_synth_mochi import T2VSynthMochiModel

model = None
model_path = "weights"

def noexcept(f):
    try:
        return f()
    except:
        pass

# class MochiWrapper:
#     def __init__(self, *, num_workers, **actor_kwargs):
#         super().__init__()
#         RemoteClass = ray.remote(T2VSynthMochiModel)
#         self.workers = [
#             RemoteClass.options(num_gpus=1).remote(
#                 device_id=0, world_size=num_workers, local_rank=i, **actor_kwargs
#             )
#             for i in range(num_workers)
#         ]
#         # Ensure the __init__ method has finished on all workers
#         for worker in self.workers:
#             ray.get(worker.__ray_ready__.remote())
#         self.is_loaded = True

#     def __call__(self, args):
#         work_refs = [
#             worker.run.remote(args, i == 0) for i, worker in enumerate(self.workers)
#         ]

#         try:
#             for result in work_refs[0]:
#                 yield ray.get(result)

#             # Handle the (very unlikely) edge-case where a worker that's not the 1st one
#             # fails (don't want an uncaught error)
#             for result in work_refs[1:]:
#                 ray.get(result)
#         except Exception as e:
#             # Get exception from other workers
#             for ref in work_refs[1:]:
#                 noexcept(lambda: ray.get(ref))
#             raise e

def set_model_path(path):
    global model_path
    model_path = path


def load_model():
    global model, model_path
    if model is None:
        #ray.init()
        MOCHI_DIR = model_path
        VAE_CHECKPOINT_PATH = f"{MOCHI_DIR}/mochi_preview_vae_bf16.safetensors"
        MODEL_CONFIG_PATH = f"{MOCHI_DIR}/dit-config.yaml"
        MODEL_CHECKPOINT_PATH = f"{MOCHI_DIR}/mochi_preview_dit_fp8_e4m3fn.safetensors"

        model = T2VSynthMochiModel(
            device_id=0,
            world_size=1,
            local_rank=0,
            vae_stats_path=f"{MOCHI_DIR}/vae_stats.json",
            vae_checkpoint_path=VAE_CHECKPOINT_PATH,
            dit_config_path=MODEL_CONFIG_PATH,
            dit_checkpoint_path=MODEL_CHECKPOINT_PATH,
        )

def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None):
    if linear_steps is None:
        linear_steps = num_steps // 2
    linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)]
    threshold_noise_step_diff = linear_steps - threshold_noise * num_steps
    quadratic_steps = num_steps - linear_steps
    quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps ** 2)
    linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps ** 2)
    const = quadratic_coef * (linear_steps ** 2)
    quadratic_sigma_schedule = [
        quadratic_coef * (i ** 2) + linear_coef * i + const
        for i in range(linear_steps, num_steps)
    ]
    sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0]
    sigma_schedule = [1.0 - x for x in sigma_schedule]
    return sigma_schedule


def generate_video(
    prompt,
    negative_prompt,
    width,
    height,
    num_frames,
    seed,
    cfg_scale,
    num_inference_steps,
):
    load_model()

    # sigma_schedule should be a list of floats of length (num_inference_steps + 1),
    # such that sigma_schedule[0] == 1.0 and sigma_schedule[-1] == 0.0 and monotonically decreasing.
    sigma_schedule = linear_quadratic_schedule(num_inference_steps, 0.025)

    # cfg_schedule should be a list of floats of length num_inference_steps.
    # For simplicity, we just use the same cfg scale at all timesteps,
    # but more optimal schedules may use varying cfg, e.g:
    # [5.0] * (num_inference_steps // 2) + [4.5] * (num_inference_steps // 2)
    cfg_schedule = [cfg_scale] * num_inference_steps

    args = {
        "height": height,
        "width": width,
        "num_frames": num_frames,
        "mochi_args": {
            "sigma_schedule": sigma_schedule,
            "cfg_schedule": cfg_schedule,
            "num_inference_steps": num_inference_steps,
            "batch_cfg": False,
        },
        "prompt": [prompt],
        "negative_prompt": [negative_prompt],
        "seed": seed,
    }

    final_frames = None
    for cur_progress, frames, finished in tqdm(model.run(args, stream_results=True), total=num_inference_steps + 1):
        final_frames = frames

    assert isinstance(final_frames, np.ndarray)
    assert final_frames.dtype == np.float32

    final_frames = rearrange(final_frames, "t b h w c -> b t h w c")
    final_frames = final_frames[0]

    os.makedirs("outputs", exist_ok=True)
    output_path = os.path.join("outputs", f"output_{int(time.time())}.mp4")

    with tempfile.TemporaryDirectory() as tmpdir:
        frame_paths = []
        for i, frame in enumerate(final_frames):
            frame = (frame * 255).astype(np.uint8)
            frame_img = Image.fromarray(frame)
            frame_path = os.path.join(tmpdir, f"frame_{i:04d}.png")
            frame_img.save(frame_path)
            frame_paths.append(frame_path)

        frame_pattern = os.path.join(tmpdir, "frame_%04d.png")
        ffmpeg_cmd = f"ffmpeg -y -r 30 -i {frame_pattern} -vcodec libx264 -pix_fmt yuv420p {output_path}"
        os.system(ffmpeg_cmd)

        json_path = os.path.splitext(output_path)[0] + ".json"
        with open(json_path, "w") as f:
            json.dump(args, f, indent=4)

    return output_path


@click.command()
@click.option("--prompt", default="""
a high-motion drone POV flying at high speed through a vast desert environment, with dynamic camera movements capturing sweeping sand dunes,
rocky terrain, and the occasional dry brush. The camera smoothly glides over the rugged landscape, weaving between towering rock formations and
diving low across the sand. As the drone zooms forward, the motion gradually slows down, shifting into a close-up, hyper-detailed shot of a spider
resting on a sunlit rock. The scene emphasizes cinematic motion, natural lighting, and intricate texture details on both the rock and the spider’s body,
with a shallow depth of field to focus on the fine details of the spider’s legs and the rough surface beneath it. The atmosphere should feel immersive and alive,
with the wind subtly blowing sand grains across the frame."""
, required=False, help="Prompt for video generation.")
@click.option(
    "--negative_prompt", default="", help="Negative prompt for video generation."
)
@click.option("--width", default=848, type=int, help="Width of the video.")
@click.option("--height", default=480, type=int, help="Height of the video.")
@click.option("--num_frames", default=163, type=int, help="Number of frames.")
@click.option("--seed", default=12345, type=int, help="Random seed.")
@click.option("--cfg_scale", default=4.5, type=float, help="CFG Scale.")
@click.option(
    "--num_steps", default=64, type=int, help="Number of inference steps."
)
@click.option("--model_dir", required=True, help="Path to the model directory.")
def generate_cli(
    prompt,
    negative_prompt,
    width,
    height,
    num_frames,
    seed,
    cfg_scale,
    num_steps,
    model_dir,
):
    set_model_path(model_dir)
    output = generate_video(
        prompt,
        negative_prompt,
        width,
        height,
        num_frames,
        seed,
        cfg_scale,
        num_steps,
    )
    click.echo(f"Video generated at: {output}")


if __name__ == "__main__":
    generate_cli()
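A minimal sketch of driving infer.py programmatically instead of through the click CLI. It assumes infer.py is importable from the working directory and that the weights directory contains the files load_model() expects (mochi_preview_vae_bf16.safetensors, dit-config.yaml, mochi_preview_dit_fp8_e4m3fn.safetensors and vae_stats.json); the argument values simply mirror the CLI defaults above.

import infer

infer.set_model_path("weights")   # directory holding the checkpoints listed in load_model()
out = infer.generate_video(
    prompt="a high-motion drone POV flying through a vast desert",
    negative_prompt="",
    width=848,
    height=480,
    num_frames=163,
    seed=12345,
    cfg_scale=4.5,
    num_inference_steps=64,
)
print(out)   # outputs/output_<timestamp>.mp4, with a .json sidecar of the run arguments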
0  mochi_preview/__init__.py  Normal file
BIN  mochi_preview/__pycache__/__init__.cpython-311.pyc  Normal file  (Binary file not shown.)
BIN  mochi_preview/__pycache__/__init__.cpython-312.pyc  Normal file  (Binary file not shown.)
BIN  mochi_preview/__pycache__/t2v_synth_mochi.cpython-311.pyc  Normal file  (Binary file not shown.)
BIN  mochi_preview/__pycache__/t2v_synth_mochi.cpython-312.pyc  Normal file  (Binary file not shown.)
BIN  mochi_preview/__pycache__/utils.cpython-311.pyc  Normal file  (Binary file not shown.)
BIN  mochi_preview/__pycache__/utils.cpython-312.pyc  Normal file  (Binary file not shown.)
0  mochi_preview/dit/joint_model/__init__.py  Normal file
BIN  mochi_preview/dit/joint_model/__pycache__/layers.cpython-311.pyc  Normal file  (Binary file not shown.)
BIN  mochi_preview/dit/joint_model/__pycache__/layers.cpython-312.pyc  Normal file  (Binary file not shown.)
BIN  mochi_preview/dit/joint_model/__pycache__/utils.cpython-311.pyc  Normal file  (Binary file not shown.)
BIN  mochi_preview/dit/joint_model/__pycache__/utils.cpython-312.pyc  Normal file  (Binary file not shown.)
675  mochi_preview/dit/joint_model/asymm_models_joint.py  Normal file
@@ -0,0 +1,675 @@
import os
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from torch.nn.attention import sdpa_kernel

from .context_parallel import all_to_all_collect_tokens, all_to_all_collect_heads, all_gather, get_cp_rank_size, is_cp_active
from .layers import (
    FeedForward,
    PatchEmbed,
    RMSNorm,
    TimestepEmbedder,
)
from .mod_rmsnorm import modulated_rmsnorm
from .residual_tanh_gated_rmsnorm import (
    residual_tanh_gated_rmsnorm,
)
from .rope_mixed import (
    compute_mixed_rotation,
    create_position_matrix,
)
from .temporal_rope import apply_rotary_emb_qk_real
from .utils import (
    AttentionPool,
    modulate,
    pad_and_split_xy,
    unify_streams,
)

try:
    from flash_attn import flash_attn_varlen_qkvpacked_func
    FLASH_ATTN_IS_AVAILABLE = True
except ImportError:
    FLASH_ATTN_IS_AVAILABLE = False

COMPILE_FINAL_LAYER = False  #os.environ.get("COMPILE_DIT") == "1"
COMPILE_MMDIT_BLOCK = False  #os.environ.get("COMPILE_DIT") == "1"


class AsymmetricAttention(nn.Module):
    def __init__(
        self,
        dim_x: int,
        dim_y: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        attn_drop: float = 0.0,
        update_y: bool = True,
        out_bias: bool = True,
        attend_to_padding: bool = False,
        softmax_scale: Optional[float] = None,
        device: Optional[torch.device] = None,
        clip_feat_dim: Optional[int] = None,
        pooled_caption_mlp_bias: bool = True,
        use_transformer_engine: bool = False,
    ):
        super().__init__()
        self.dim_x = dim_x
        self.dim_y = dim_y
        self.num_heads = num_heads
        self.head_dim = dim_x // num_heads
        self.attn_drop = attn_drop
        self.update_y = update_y
        self.attend_to_padding = attend_to_padding
        self.softmax_scale = softmax_scale
        if dim_x % num_heads != 0:
            raise ValueError(
                f"dim_x={dim_x} should be divisible by num_heads={num_heads}"
            )

        # Input layers.
        self.qkv_bias = qkv_bias
        self.qkv_x = nn.Linear(dim_x, 3 * dim_x, bias=qkv_bias, device=device)
        # Project text features to match visual features (dim_y -> dim_x)
        self.qkv_y = nn.Linear(dim_y, 3 * dim_x, bias=qkv_bias, device=device)

        # Query and key normalization for stability.
        assert qk_norm
        self.q_norm_x = RMSNorm(self.head_dim, device=device)
        self.k_norm_x = RMSNorm(self.head_dim, device=device)
        self.q_norm_y = RMSNorm(self.head_dim, device=device)
        self.k_norm_y = RMSNorm(self.head_dim, device=device)

        # Output layers. y features go back down from dim_x -> dim_y.
        self.proj_x = nn.Linear(dim_x, dim_x, bias=out_bias, device=device)
        self.proj_y = (
            nn.Linear(dim_x, dim_y, bias=out_bias, device=device)
            if update_y
            else nn.Identity()
        )

    def run_qkv_y(self, y):
        cp_rank, cp_size = get_cp_rank_size()
        local_heads = self.num_heads // cp_size

        if is_cp_active():
            # Only predict local heads.
            assert not self.qkv_bias
            W_qkv_y = self.qkv_y.weight.view(
                3, self.num_heads, self.head_dim, self.dim_y
            )
            W_qkv_y = W_qkv_y.narrow(1, cp_rank * local_heads, local_heads)
            W_qkv_y = W_qkv_y.reshape(3 * local_heads * self.head_dim, self.dim_y)
            qkv_y = F.linear(y, W_qkv_y, None)  # (B, L, 3 * local_h * head_dim)
        else:
            qkv_y = self.qkv_y(y)  # (B, L, 3 * dim)

        qkv_y = qkv_y.view(qkv_y.size(0), qkv_y.size(1), 3, local_heads, self.head_dim)
        q_y, k_y, v_y = qkv_y.unbind(2)
        return q_y, k_y, v_y

    def prepare_qkv(
        self,
        x: torch.Tensor,  # (B, N, dim_x)
        y: torch.Tensor,  # (B, L, dim_y)
        *,
        scale_x: torch.Tensor,
        scale_y: torch.Tensor,
        rope_cos: torch.Tensor,
        rope_sin: torch.Tensor,
        valid_token_indices: torch.Tensor,
    ):
        # Pre-norm for visual features
        x = modulated_rmsnorm(x, scale_x)  # (B, M, dim_x) where M = N / cp_group_size
        #print("x in attn", x.dtype, x.device)

        # Process visual features
        qkv_x = self.qkv_x(x)  # (B, M, 3 * dim_x)
        assert qkv_x.dtype == torch.bfloat16
        qkv_x = all_to_all_collect_tokens(
            qkv_x, self.num_heads
        )  # (3, B, N, local_h, head_dim)

        # Process text features
        y = modulated_rmsnorm(y, scale_y)  # (B, L, dim_y)
        q_y, k_y, v_y = self.run_qkv_y(y)  # (B, L, local_heads, head_dim)
        #print("y in attn", y.dtype, y.device)
        #print(q_y.dtype, q_y.device)
        #print(self.q_norm_y.weight.dtype, self.q_norm_y.weight.device)
        # self.q_norm_y.weight = self.q_norm_y.weight.to(q_y.dtype)
        # self.q_norm_y.bias = self.q_norm_y.bias.to(q_y.dtype)
        # self.k_norm_y.weight = self.k_norm_y.weight.to(k_y.dtype)
        # self.k_norm_y.bias = self.k_norm_y.bias.to(k_y.dtype)
        q_y = self.q_norm_y(q_y)
        k_y = self.k_norm_y(k_y)

        # Split qkv_x into q, k, v
        q_x, k_x, v_x = qkv_x.unbind(0)  # (B, N, local_h, head_dim)
        q_x = self.q_norm_x(q_x)
        q_x = apply_rotary_emb_qk_real(q_x, rope_cos, rope_sin)
        k_x = self.k_norm_x(k_x)
        k_x = apply_rotary_emb_qk_real(k_x, rope_cos, rope_sin)

        # Unite streams
        qkv = unify_streams(
            q_x,
            k_x,
            v_x,
            q_y,
            k_y,
            v_y,
            valid_token_indices,
        )

        return qkv

    @torch.compiler.disable()
    def run_attention(
        self,
        qkv: torch.Tensor,  # (total <= B * (N + L), 3, local_heads, head_dim)
        *,
        B: int,
        L: int,
        M: int,
        cu_seqlens: torch.Tensor,
        max_seqlen_in_batch: int,
        valid_token_indices: torch.Tensor,
    ):
        _, cp_size = get_cp_rank_size()
        N = cp_size * M
        assert self.num_heads % cp_size == 0
        local_heads = self.num_heads // cp_size
        local_dim = local_heads * self.head_dim
        total = qkv.size(0)

        if FLASH_ATTN_IS_AVAILABLE:
            with torch.autocast("cuda", enabled=False):
                out: torch.Tensor = flash_attn_varlen_qkvpacked_func(
                    qkv,
                    cu_seqlens=cu_seqlens,
                    max_seqlen=max_seqlen_in_batch,
                    dropout_p=0.0,
                    softmax_scale=self.softmax_scale,
                )  # (total, local_heads, head_dim)
            out = out.view(total, local_dim)
        else:
            raise NotImplementedError("Flash attention is currently required.")
            print("qkv: ", qkv.shape, qkv.dtype, qkv.device)
            expected_size = 2 * 44520 * 3 * 24 * 128
            actual_size = qkv.numel()
            print(f"Expected size: {expected_size}, Actual size: {actual_size}")
            q, k, v = qkv.reshape(B, N, 3, local_heads, self.head_dim).permute(2, 0, 3, 1, 4)
            with torch.autocast("cuda", enabled=False):
                out = F.scaled_dot_product_attention(
                    q,
                    k,
                    v,
                    attn_mask=None,
                    dropout_p=0.0,
                    is_causal=False
                )
            out = out.transpose(1, 2).reshape(B, -1, local_heads * self.head_dim)

        x, y = pad_and_split_xy(out, valid_token_indices, B, N, L, qkv.dtype)
        assert x.size() == (B, N, local_dim)
        assert y.size() == (B, L, local_dim)

        x = x.view(B, N, local_heads, self.head_dim)
        x = all_to_all_collect_heads(x)  # (B, M, dim_x = num_heads * head_dim)
        x = self.proj_x(x)  # (B, M, dim_x)

        if is_cp_active():
            y = all_gather(y)  # (cp_size * B, L, local_heads * head_dim)
            y = rearrange(
                y, "(G B) L D -> B L (G D)", G=cp_size, D=local_dim
            )  # (B, L, dim_x)
        y = self.proj_y(y)  # (B, L, dim_y)
        return x, y

    def forward(
        self,
        x: torch.Tensor,  # (B, N, dim_x)
        y: torch.Tensor,  # (B, L, dim_y)
        *,
        scale_x: torch.Tensor,  # (B, dim_x), modulation for pre-RMSNorm.
        scale_y: torch.Tensor,  # (B, dim_y), modulation for pre-RMSNorm.
        packed_indices: Dict[str, torch.Tensor] = None,
        **rope_rotation,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass of asymmetric multi-modal attention.

        Args:
            x: (B, N, dim_x) tensor for visual tokens
            y: (B, L, dim_y) tensor of text token features
            packed_indices: Dict with keys for Flash Attention
            num_frames: Number of frames in the video. N = num_frames * num_spatial_tokens

        Returns:
            x: (B, N, dim_x) tensor of visual tokens after multi-modal attention
            y: (B, L, dim_y) tensor of text token features after multi-modal attention
        """
        B, L, _ = y.shape
        _, M, _ = x.shape

        # Predict a packed QKV tensor from visual and text features.
        # Don't checkpoint the all_to_all.
        qkv = self.prepare_qkv(
            x=x,
            y=y,
            scale_x=scale_x,
            scale_y=scale_y,
            rope_cos=rope_rotation.get("rope_cos"),
            rope_sin=rope_rotation.get("rope_sin"),
            valid_token_indices=packed_indices["valid_token_indices_kv"],
        )  # (total <= B * (N + L), 3, local_heads, head_dim)

        x, y = self.run_attention(
            qkv,
            B=B,
            L=L,
            M=M,
            cu_seqlens=packed_indices["cu_seqlens_kv"],
            max_seqlen_in_batch=packed_indices["max_seqlen_in_batch_kv"],
            valid_token_indices=packed_indices["valid_token_indices_kv"],
        )
        return x, y


#@torch.compile(disable=not COMPILE_MMDIT_BLOCK)
class AsymmetricJointBlock(nn.Module):
    def __init__(
        self,
        hidden_size_x: int,
        hidden_size_y: int,
        num_heads: int,
        *,
        mlp_ratio_x: float = 8.0,  # Ratio of hidden size to d_model for MLP for visual tokens.
        mlp_ratio_y: float = 4.0,  # Ratio of hidden size to d_model for MLP for text tokens.
        update_y: bool = True,  # Whether to update text tokens in this block.
        device: Optional[torch.device] = None,
        **block_kwargs,
    ):
        super().__init__()
        self.update_y = update_y
        self.hidden_size_x = hidden_size_x
        self.hidden_size_y = hidden_size_y
        self.mod_x = nn.Linear(hidden_size_x, 4 * hidden_size_x, device=device)
        if self.update_y:
            self.mod_y = nn.Linear(hidden_size_x, 4 * hidden_size_y, device=device)
        else:
            self.mod_y = nn.Linear(hidden_size_x, hidden_size_y, device=device)

        # Self-attention:
        self.attn = AsymmetricAttention(
            hidden_size_x,
            hidden_size_y,
            num_heads=num_heads,
            update_y=update_y,
            device=device,
            **block_kwargs,
        )

        # MLP.
        mlp_hidden_dim_x = int(hidden_size_x * mlp_ratio_x)
        assert mlp_hidden_dim_x == int(1536 * 8)
        self.mlp_x = FeedForward(
            in_features=hidden_size_x,
            hidden_size=mlp_hidden_dim_x,
            multiple_of=256,
            ffn_dim_multiplier=None,
            device=device,
        )

        # MLP for text not needed in last block.
        if self.update_y:
            mlp_hidden_dim_y = int(hidden_size_y * mlp_ratio_y)
            self.mlp_y = FeedForward(
                in_features=hidden_size_y,
                hidden_size=mlp_hidden_dim_y,
                multiple_of=256,
                ffn_dim_multiplier=None,
                device=device,
            )

    def forward(
        self,
        x: torch.Tensor,
        c: torch.Tensor,
        y: torch.Tensor,
        **attn_kwargs,
    ):
        """Forward pass of a block.

        Args:
            x: (B, N, dim) tensor of visual tokens
            c: (B, dim) tensor of conditioned features
            y: (B, L, dim) tensor of text tokens
            num_frames: Number of frames in the video. N = num_frames * num_spatial_tokens

        Returns:
            x: (B, N, dim) tensor of visual tokens after block
            y: (B, L, dim) tensor of text tokens after block
        """
        N = x.size(1)

        c = F.silu(c)
        mod_x = self.mod_x(c)
        scale_msa_x, gate_msa_x, scale_mlp_x, gate_mlp_x = mod_x.chunk(4, dim=1)

        mod_y = self.mod_y(c)
        if self.update_y:
            scale_msa_y, gate_msa_y, scale_mlp_y, gate_mlp_y = mod_y.chunk(4, dim=1)
        else:
            scale_msa_y = mod_y

        # Self-attention block.
        x_attn, y_attn = self.attn(
            x,
            y,
            scale_x=scale_msa_x,
            scale_y=scale_msa_y,
            **attn_kwargs,
        )

        assert x_attn.size(1) == N
        x = residual_tanh_gated_rmsnorm(x, x_attn, gate_msa_x)
        if self.update_y:
            y = residual_tanh_gated_rmsnorm(y, y_attn, gate_msa_y)

        # MLP block.
        x = self.ff_block_x(x, scale_mlp_x, gate_mlp_x)
        if self.update_y:
            y = self.ff_block_y(y, scale_mlp_y, gate_mlp_y)

        return x, y

    def ff_block_x(self, x, scale_x, gate_x):
        x_mod = modulated_rmsnorm(x, scale_x)
        x_res = self.mlp_x(x_mod)
        x = residual_tanh_gated_rmsnorm(x, x_res, gate_x)  # Sandwich norm
        return x

    def ff_block_y(self, y, scale_y, gate_y):
        y_mod = modulated_rmsnorm(y, scale_y)
        y_res = self.mlp_y(y_mod)
        y = residual_tanh_gated_rmsnorm(y, y_res, gate_y)  # Sandwich norm
        return y


#@torch.compile(disable=not COMPILE_FINAL_LAYER)
class FinalLayer(nn.Module):
    """
    The final layer of DiT.
    """

    def __init__(
        self,
        hidden_size,
        patch_size,
        out_channels,
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        self.norm_final = nn.LayerNorm(
            hidden_size, elementwise_affine=False, eps=1e-6, device=device
        )
        self.mod = nn.Linear(hidden_size, 2 * hidden_size, device=device)
        self.linear = nn.Linear(
            hidden_size, patch_size * patch_size * out_channels, device=device
        )

    def forward(self, x, c):
        c = F.silu(c)
        shift, scale = self.mod(c).chunk(2, dim=1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x


class AsymmDiTJoint(nn.Module):
    """
    Diffusion model with a Transformer backbone.

    Ingests text embeddings instead of a label.
    """

    def __init__(
        self,
        *,
        patch_size=2,
        in_channels=4,
        hidden_size_x=1152,
        hidden_size_y=1152,
        depth=48,
        num_heads=16,
        mlp_ratio_x=8.0,
        mlp_ratio_y=4.0,
        t5_feat_dim: int = 4096,
        t5_token_length: int = 256,
        patch_embed_bias: bool = True,
        timestep_mlp_bias: bool = True,
        timestep_scale: Optional[float] = None,
        use_extended_posenc: bool = False,
        rope_theta: float = 10000.0,
        device: Optional[torch.device] = None,
        **block_kwargs,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = in_channels
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.hidden_size_x = hidden_size_x
        self.hidden_size_y = hidden_size_y
        self.head_dim = (
            hidden_size_x // num_heads
        )  # Head dimension and count is determined by visual.
        self.use_extended_posenc = use_extended_posenc
        self.t5_token_length = t5_token_length
        self.t5_feat_dim = t5_feat_dim
        self.rope_theta = (
            rope_theta  # Scaling factor for frequency computation for temporal RoPE.
        )

        self.x_embedder = PatchEmbed(
            patch_size=patch_size,
            in_chans=in_channels,
            embed_dim=hidden_size_x,
            bias=patch_embed_bias,
            device=device,
        )
        # Conditionings
        # Timestep
        self.t_embedder = TimestepEmbedder(
            hidden_size_x, bias=timestep_mlp_bias, timestep_scale=timestep_scale
        )

        # Caption Pooling (T5)
        self.t5_y_embedder = AttentionPool(
            t5_feat_dim, num_heads=8, output_dim=hidden_size_x, device=device
        )

        # Dense Embedding Projection (T5)
        self.t5_yproj = nn.Linear(
            t5_feat_dim, hidden_size_y, bias=True, device=device
        )

        # Initialize pos_frequencies as an empty parameter.
        self.pos_frequencies = nn.Parameter(
            torch.empty(3, self.num_heads, self.head_dim // 2, device=device)
        )

        # for depth 48:
        #   b = 0: AsymmetricJointBlock, update_y=True
        #   b = 1: AsymmetricJointBlock, update_y=True
        #   ...
        #   b = 46: AsymmetricJointBlock, update_y=True
        #   b = 47: AsymmetricJointBlock, update_y=False. No need to update text features.
        blocks = []
        for b in range(depth):
            # Joint multi-modal block
            update_y = b < depth - 1
            block = AsymmetricJointBlock(
                hidden_size_x,
                hidden_size_y,
                num_heads,
                mlp_ratio_x=mlp_ratio_x,
                mlp_ratio_y=mlp_ratio_y,
                update_y=update_y,
                device=device,
                **block_kwargs,
            )

            blocks.append(block)
        self.blocks = nn.ModuleList(blocks)

        self.final_layer = FinalLayer(
            hidden_size_x, patch_size, self.out_channels, device=device
        )

    def embed_x(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, C=12, T, H, W) tensor of visual tokens

        Returns:
            x: (B, C=3072, N) tensor of visual tokens with positional embedding.
        """
        return self.x_embedder(x)  # Convert BcTHW to BCN

    #@torch.compile(disable=not COMPILE_MMDIT_BLOCK)
    def prepare(
        self,
        x: torch.Tensor,
        sigma: torch.Tensor,
        t5_feat: torch.Tensor,
        t5_mask: torch.Tensor,
    ):
        """Prepare input and conditioning embeddings."""
        #("X", x.shape)
        with torch.profiler.record_function("x_emb_pe"):
            # Visual patch embeddings with positional encoding.
            T, H, W = x.shape[-3:]
            pH, pW = H // self.patch_size, W // self.patch_size
            x = self.embed_x(x)  # (B, N, D), where N = T * H * W / patch_size ** 2
            assert x.ndim == 3
            B = x.size(0)

        with torch.profiler.record_function("rope_cis"):
            # Construct position array of size [N, 3].
            # pos[:, 0] is the frame index for each location,
            # pos[:, 1] is the row index for each location, and
            # pos[:, 2] is the column index for each location.
            pH, pW = H // self.patch_size, W // self.patch_size
            N = T * pH * pW
            assert x.size(1) == N
            pos = create_position_matrix(
                T, pH=pH, pW=pW, device=x.device, dtype=torch.float32
            )  # (N, 3)
            rope_cos, rope_sin = compute_mixed_rotation(
                freqs=self.pos_frequencies, pos=pos
            )  # Each are (N, num_heads, dim // 2)

        with torch.profiler.record_function("t_emb"):
            # Global vector embedding for conditionings.
            c_t = self.t_embedder(1 - sigma)  # (B, D)

        with torch.profiler.record_function("t5_pool"):
            # Pool T5 tokens using attention pooler
            # Note y_feat[1] contains T5 token features.
            # print("B", B)
            # print("t5 feat shape", t5_feat.shape)
            # print("t5 mask shape", t5_mask.shape)
            assert (
                t5_feat.size(1) == self.t5_token_length
            ), f"Expected L={self.t5_token_length}, got {t5_feat.shape} for y_feat."
            t5_y_pool = self.t5_y_embedder(t5_feat, t5_mask)  # (B, D)
            assert (
                t5_y_pool.size(0) == B
            ), f"Expected B={B}, got {t5_y_pool.shape} for t5_y_pool."

        c = c_t + t5_y_pool

        y_feat = self.t5_yproj(t5_feat)  # (B, L, t5_feat_dim) --> (B, L, D)

        return x, c, y_feat, rope_cos, rope_sin

    def forward(
        self,
        x: torch.Tensor,
        sigma: torch.Tensor,
        y_feat: List[torch.Tensor],
        y_mask: List[torch.Tensor],
        packed_indices: Dict[str, torch.Tensor] = None,
        rope_cos: torch.Tensor = None,
        rope_sin: torch.Tensor = None,
    ):
        """Forward pass of DiT.

        Args:
            x: (B, C, T, H, W) tensor of spatial inputs (images or latent representations of images)
            sigma: (B,) tensor of noise standard deviations
            y_feat: List((B, L, y_feat_dim) tensor of caption token features. For SDXL text encoders: L=77, y_feat_dim=2048)
            y_mask: List((B, L) boolean tensor indicating which tokens are not padding)
            packed_indices: Dict with keys for Flash Attention. Result of compute_packed_indices.
        """
        B, _, T, H, W = x.shape

        # Use EFFICIENT_ATTENTION backend for T5 pooling, since we have a mask.
        # Have to call sdpa_kernel outside of a torch.compile region.
        with sdpa_kernel(torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION):
            x, c, y_feat, rope_cos, rope_sin = self.prepare(
                x, sigma, y_feat[0], y_mask[0]
            )
        del y_mask

        cp_rank, cp_size = get_cp_rank_size()
        N = x.size(1)
        M = N // cp_size
        assert (
            N % cp_size == 0
        ), f"Visual sequence length ({x.shape[1]}) must be divisible by cp_size ({cp_size})."

        if cp_size > 1:
            x = x.narrow(1, cp_rank * M, M)

            assert self.num_heads % cp_size == 0
            local_heads = self.num_heads // cp_size
            rope_cos = rope_cos.narrow(1, cp_rank * local_heads, local_heads)
            rope_sin = rope_sin.narrow(1, cp_rank * local_heads, local_heads)

        for i, block in enumerate(self.blocks):
            x, y_feat = block(
                x,
                c,
                y_feat,
                rope_cos=rope_cos,
                rope_sin=rope_sin,
                packed_indices=packed_indices,
            )  # (B, M, D), (B, L, D)
        del y_feat  # Final layers don't use dense text features.

        x = self.final_layer(x, c)  # (B, M, patch_size ** 2 * out_channels)

        patch = x.size(2)
        x = all_gather(x)
        x = rearrange(x, "(G B) M P -> B (G M) P", G=cp_size, P=patch)
        x = rearrange(
            x,
            "B (T hp wp) (p1 p2 c) -> B c T (hp p1) (wp p2)",
            T=T,
            hp=H // self.patch_size,
            wp=W // self.patch_size,
            p1=self.patch_size,
            p2=self.patch_size,
            c=self.out_channels,
        )

        return x
163  mochi_preview/dit/joint_model/context_parallel.py  Normal file
@@ -0,0 +1,163 @@
import torch
import torch.distributed as dist
from einops import rearrange

_CONTEXT_PARALLEL_GROUP = None
_CONTEXT_PARALLEL_RANK = None
_CONTEXT_PARALLEL_GROUP_SIZE = None
_CONTEXT_PARALLEL_GROUP_RANKS = None


def local_shard(x: torch.Tensor, dim: int = 2) -> torch.Tensor:
    if not _CONTEXT_PARALLEL_GROUP:
        return x

    cp_rank, cp_size = get_cp_rank_size()
    return x.tensor_split(cp_size, dim=dim)[cp_rank]


def set_cp_group(cp_group, ranks, global_rank):
    global \
        _CONTEXT_PARALLEL_GROUP, \
        _CONTEXT_PARALLEL_RANK, \
        _CONTEXT_PARALLEL_GROUP_SIZE, \
        _CONTEXT_PARALLEL_GROUP_RANKS
    if _CONTEXT_PARALLEL_GROUP is not None:
        raise RuntimeError("CP group already initialized.")
    _CONTEXT_PARALLEL_GROUP = cp_group
    _CONTEXT_PARALLEL_RANK = dist.get_rank(cp_group)
    _CONTEXT_PARALLEL_GROUP_SIZE = dist.get_world_size(cp_group)
    _CONTEXT_PARALLEL_GROUP_RANKS = ranks

    assert (
        _CONTEXT_PARALLEL_RANK == ranks.index(global_rank)
    ), f"Rank mismatch: {global_rank} in {ranks} does not have position {_CONTEXT_PARALLEL_RANK} "
    assert _CONTEXT_PARALLEL_GROUP_SIZE == len(
        ranks
    ), f"Group size mismatch: {_CONTEXT_PARALLEL_GROUP_SIZE} != len({ranks})"


def get_cp_group():
    if _CONTEXT_PARALLEL_GROUP is None:
        raise RuntimeError("CP group not initialized")
    return _CONTEXT_PARALLEL_GROUP


def is_cp_active():
    return _CONTEXT_PARALLEL_GROUP is not None


def get_cp_rank_size():
    if _CONTEXT_PARALLEL_GROUP:
        return _CONTEXT_PARALLEL_RANK, _CONTEXT_PARALLEL_GROUP_SIZE
    else:
        return 0, 1


class AllGatherIntoTensorFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, reduce_dtype, group: dist.ProcessGroup):
        ctx.reduce_dtype = reduce_dtype
        ctx.group = group
        ctx.batch_size = x.size(0)
        group_size = dist.get_world_size(group)

        x = x.contiguous()
        output = torch.empty(
            group_size * x.size(0), *x.shape[1:], dtype=x.dtype, device=x.device
        )
        dist.all_gather_into_tensor(output, x, group=group)
        return output


def all_gather(tensor: torch.Tensor) -> torch.Tensor:
    if not _CONTEXT_PARALLEL_GROUP:
        return tensor

    return AllGatherIntoTensorFunction.apply(
        tensor, torch.float32, _CONTEXT_PARALLEL_GROUP
    )


@torch.compiler.disable()
def _all_to_all_single(output, input, group):
    # Disable compilation since torch compile changes contiguity.
    assert input.is_contiguous(), "Input tensor must be contiguous."
    assert output.is_contiguous(), "Output tensor must be contiguous."
    return dist.all_to_all_single(output, input, group=group)


class CollectTokens(torch.autograd.Function):
    @staticmethod
    def forward(ctx, qkv: torch.Tensor, group: dist.ProcessGroup, num_heads: int):
        """Redistribute heads and receive tokens.

        Args:
            qkv: query, key or value. Shape: [B, M, 3 * num_heads * head_dim]

        Returns:
            qkv: shape: [3, B, N, local_heads, head_dim]

        where M is the number of local tokens,
        N = cp_size * M is the number of global tokens,
        local_heads = num_heads // cp_size is the number of local heads.
        """
        ctx.group = group
        ctx.num_heads = num_heads
        cp_size = dist.get_world_size(group)
        assert num_heads % cp_size == 0
        ctx.local_heads = num_heads // cp_size

        qkv = rearrange(
            qkv,
            "B M (qkv G h d) -> G M h B (qkv d)",
            qkv=3,
            G=cp_size,
            h=ctx.local_heads,
        ).contiguous()

        output_chunks = torch.empty_like(qkv)
        _all_to_all_single(output_chunks, qkv, group=group)

        return rearrange(output_chunks, "G M h B (qkv d) -> qkv B (G M) h d", qkv=3)


def all_to_all_collect_tokens(x: torch.Tensor, num_heads: int) -> torch.Tensor:
    if not _CONTEXT_PARALLEL_GROUP:
        # Move QKV dimension to the front.
        # B M (3 H d) -> 3 B M H d
        B, M, _ = x.size()
        x = x.view(B, M, 3, num_heads, -1)
        return x.permute(2, 0, 1, 3, 4)

    return CollectTokens.apply(x, _CONTEXT_PARALLEL_GROUP, num_heads)


class CollectHeads(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, group: dist.ProcessGroup):
        """Redistribute tokens and receive heads.

        Args:
            x: Output of attention. Shape: [B, N, local_heads, head_dim]

        Returns:
            Shape: [B, M, num_heads * head_dim]
        """
        ctx.group = group
        ctx.local_heads = x.size(2)
        ctx.head_dim = x.size(3)
        group_size = dist.get_world_size(group)
        x = rearrange(x, "B (G M) h D -> G h M B D", G=group_size).contiguous()
        output = torch.empty_like(x)
        _all_to_all_single(output, x, group=group)
        del x
        return rearrange(output, "G h M B D -> B M (G h D)")


def all_to_all_collect_heads(x: torch.Tensor) -> torch.Tensor:
    if not _CONTEXT_PARALLEL_GROUP:
        # Merge heads.
        return x.view(x.size(0), x.size(1), x.size(2) * x.size(3))

    return CollectHeads.apply(x, _CONTEXT_PARALLEL_GROUP)
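A minimal sketch of the single-process path (no CP group initialized), showing the pure reshapes that all_to_all_collect_tokens and all_to_all_collect_heads reduce to. The import path assumes the package layout of this commit and the shapes are illustrative only.

import torch
from mochi_preview.dit.joint_model import context_parallel as cp

B, M, H, d = 2, 16, 4, 8                     # batch, local tokens, heads, head_dim
qkv = torch.randn(B, M, 3 * H * d)

packed = cp.all_to_all_collect_tokens(qkv, num_heads=H)
print(packed.shape)                           # torch.Size([3, 2, 16, 4, 8]) -> (3, B, N=M, H, d)

attn_out = torch.randn(B, M, H, d)            # per-head attention output
merged = cp.all_to_all_collect_heads(attn_out)
print(merged.shape)                           # torch.Size([2, 16, 32]) -> (B, M, H * d)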
178  mochi_preview/dit/joint_model/layers.py  Normal file
@@ -0,0 +1,178 @@
import collections.abc
import math
from itertools import repeat
from typing import Callable, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


# From PyTorch internals
def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return tuple(x)
        return tuple(repeat(x, n))

    return parse


to_2tuple = _ntuple(2)


class TimestepEmbedder(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        frequency_embedding_size: int = 256,
        *,
        bias: bool = True,
        timestep_scale: Optional[float] = None,
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=bias, device=device),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=bias, device=device),
        )
        self.frequency_embedding_size = frequency_embedding_size
        self.timestep_scale = timestep_scale

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        half = dim // 2
        freqs = torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
        freqs.mul_(-math.log(max_period) / half).exp_()
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
        return embedding

    def forward(self, t):
        if self.timestep_scale is not None:
            t = t * self.timestep_scale
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb


class PooledCaptionEmbedder(nn.Module):
    def __init__(
        self,
        caption_feature_dim: int,
        hidden_size: int,
        *,
        bias: bool = True,
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        self.caption_feature_dim = caption_feature_dim
        self.hidden_size = hidden_size
        self.mlp = nn.Sequential(
            nn.Linear(caption_feature_dim, hidden_size, bias=bias, device=device),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=bias, device=device),
        )

    def forward(self, x):
        return self.mlp(x)


class FeedForward(nn.Module):
    def __init__(
        self,
        in_features: int,
        hidden_size: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        # keep parameter count and computation constant compared to standard FFN
        hidden_size = int(2 * hidden_size / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_size = int(ffn_dim_multiplier * hidden_size)
        hidden_size = multiple_of * ((hidden_size + multiple_of - 1) // multiple_of)

        self.hidden_dim = hidden_size
        self.w1 = nn.Linear(in_features, 2 * hidden_size, bias=False, device=device)
        self.w2 = nn.Linear(hidden_size, in_features, bias=False, device=device)

    def forward(self, x):
        x, gate = self.w1(x).chunk(2, dim=-1)
        x = self.w2(F.silu(x) * gate)
        return x


class PatchEmbed(nn.Module):
    def __init__(
        self,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten: bool = True,
        bias: bool = True,
        dynamic_img_pad: bool = False,
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        self.patch_size = to_2tuple(patch_size)
        self.flatten = flatten
        self.dynamic_img_pad = dynamic_img_pad

        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=bias,
            device=device,
        )
        assert norm_layer is None
        self.norm = (
            norm_layer(embed_dim, device=device) if norm_layer else nn.Identity()
        )

    def forward(self, x):
        B, _C, T, H, W = x.shape
        if not self.dynamic_img_pad:
            assert H % self.patch_size[0] == 0, f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})."
            assert W % self.patch_size[1] == 0, f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
        else:
            pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
            pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
            x = F.pad(x, (0, pad_w, 0, pad_h))

        x = rearrange(x, "B C T H W -> (B T) C H W", B=B, T=T)
        #print("x", x.dtype, x.device)
        #print(self.proj.weight.dtype, self.proj.weight.device)
        x = self.proj(x)

        # Flatten temporal and spatial dimensions.
        if not self.flatten:
            raise NotImplementedError("Must flatten output.")
        x = rearrange(x, "(B T) C H W -> B (T H W) C", B=B, T=T)

        x = self.norm(x)
        return x


class RMSNorm(torch.nn.Module):
    def __init__(self, hidden_size, eps=1e-5, device=None):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, device=device))
        self.register_parameter("bias", None)

    def forward(self, x):
        x_fp32 = x.float()
        x_normed = x_fp32 * torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + self.eps)
        return (x_normed * self.weight).type_as(x)
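For orientation, a small worked example of how FeedForward resolves its hidden width. The starting value 12288 mirrors the visual-stream MLP asserted in asymm_models_joint.py (mlp_hidden_dim_x == 1536 * 8); the rest is just the arithmetic from the constructor above, not values read from any config.

requested = 1536 * 8                                   # 12288, the hidden_size passed in for the visual MLP
hidden_size = int(2 * requested / 3)                   # 8192: SwiGLU keeps parameters comparable to a plain FFN
hidden_size = 256 * ((hidden_size + 256 - 1) // 256)   # 8192: round up to a multiple of 256
# w1: Linear(in_features, 2 * 8192) produces the value and gate halves;
# w2: Linear(8192, in_features) projects back after F.silu(value) * gate.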
23  mochi_preview/dit/joint_model/mod_rmsnorm.py  Normal file
@@ -0,0 +1,23 @@
import torch


class ModulatedRMSNorm(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale, eps=1e-6):
        # Convert to fp32 for precision
        x_fp32 = x.float()
        scale_fp32 = scale.float()

        # Compute RMS
        mean_square = x_fp32.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + eps)

        # Normalize and modulate
        x_normed = x_fp32 * inv_rms
        x_modulated = x_normed * (1 + scale_fp32.unsqueeze(1))

        return x_modulated.type_as(x)


def modulated_rmsnorm(x, scale, eps=1e-6):
    return ModulatedRMSNorm.apply(x, scale, eps)
27  mochi_preview/dit/joint_model/residual_tanh_gated_rmsnorm.py  Normal file
@@ -0,0 +1,27 @@
import torch


class ResidualTanhGatedRMSNorm(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, x_res, gate, eps=1e-6):
        # Convert to fp32 for precision
        x_res_fp32 = x_res.float()

        # Compute RMS
        mean_square = x_res_fp32.pow(2).mean(-1, keepdim=True)
        scale = torch.rsqrt(mean_square + eps)

        # Apply tanh to gate
        tanh_gate = torch.tanh(gate).unsqueeze(1)

        # Normalize and apply gated scaling
        x_normed = x_res_fp32 * scale * tanh_gate

        # Apply residual connection
        output = x + x_normed.type_as(x)

        return output


def residual_tanh_gated_rmsnorm(x, x_res, gate, eps=1e-6):
    return ResidualTanhGatedRMSNorm.apply(x, x_res, gate, eps)
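A compact restatement of what the forward pass above computes, assuming a (B, L, D) input and a (B, D) gate as used by the blocks in asymm_models_joint.py; it is a reference sketch, not an alternative implementation.

import torch

def residual_tanh_gated_rmsnorm_ref(x, x_res, gate, eps=1e-6):
    # RMS-normalize the branch output x_res per token over the channel dim,
    # scale it by tanh(gate) broadcast over the sequence, and add it onto x.
    inv_rms = torch.rsqrt(x_res.float().pow(2).mean(-1, keepdim=True) + eps)
    gated = x_res.float() * inv_rms * torch.tanh(gate.float()).unsqueeze(1)
    return x + gated.type_as(x)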
88  mochi_preview/dit/joint_model/rope_mixed.py  Normal file
@@ -0,0 +1,88 @@
import functools
import math

import torch


def centers(start: float, stop, num, dtype=None, device=None):
    """linspace through bin centers.

    Args:
        start (float): Start of the range.
        stop (float): End of the range.
        num (int): Number of points.
        dtype (torch.dtype): Data type of the points.
        device (torch.device): Device of the points.

    Returns:
        centers (Tensor): Centers of the bins. Shape: (num,).
    """
    edges = torch.linspace(start, stop, num + 1, dtype=dtype, device=device)
    return (edges[:-1] + edges[1:]) / 2


@functools.lru_cache(maxsize=1)
def create_position_matrix(
    T: int,
    pH: int,
    pW: int,
    device: torch.device,
    dtype: torch.dtype,
    *,
    target_area: float = 36864,
):
    """
    Args:
        T: int - Temporal dimension
        pH: int - Height dimension after patchify
        pW: int - Width dimension after patchify

    Returns:
        pos: [T * pH * pW, 3] - position matrix
    """
    with torch.no_grad():
        # Create 1D tensors for each dimension
        t = torch.arange(T, dtype=dtype)

        # Positionally interpolate to area 36864.
        # (3072x3072 frame with 16x16 patches = 192x192 latents).
        # This automatically scales rope positions when the resolution changes.
        # We use a large target area so the model is more sensitive
        # to changes in the learned pos_frequencies matrix.
        scale = math.sqrt(target_area / (pW * pH))
        w = centers(-pW * scale / 2, pW * scale / 2, pW)
        h = centers(-pH * scale / 2, pH * scale / 2, pH)

        # Use meshgrid to create 3D grids
        grid_t, grid_h, grid_w = torch.meshgrid(t, h, w, indexing="ij")

        # Stack and reshape the grids.
        pos = torch.stack([grid_t, grid_h, grid_w], dim=-1)  # [T, pH, pW, 3]
        pos = pos.view(-1, 3)  # [T * pH * pW, 3]
        pos = pos.to(dtype=dtype, device=device)

    return pos


def compute_mixed_rotation(
    freqs: torch.Tensor,
    pos: torch.Tensor,
):
    """
    Project each 3-dim position into per-head, per-head-dim 1D frequencies.

    Args:
        freqs: [3, num_heads, num_freqs] - learned rotation frequency (for t, row, col) for each head position
        pos: [N, 3] - position of each token
        num_heads: int

    Returns:
        freqs_cos: [N, num_heads, num_freqs] - cosine components
        freqs_sin: [N, num_heads, num_freqs] - sine components
    """
    with torch.autocast("cuda", enabled=False):
        assert freqs.ndim == 3
        freqs_sum = torch.einsum("Nd,dhf->Nhf", pos.to(freqs), freqs)
        freqs_cos = torch.cos(freqs_sum)
        freqs_sin = torch.sin(freqs_sum)
    return freqs_cos, freqs_sin
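As a quick sanity check of what centers() returns, using the default dtype and device:

import torch

# centers(0.0, 1.0, 4): the edges are linspace(0, 1, 5) = [0.00, 0.25, 0.50, 0.75, 1.00],
# so the bin centers come out as [0.125, 0.375, 0.625, 0.875].
edges = torch.linspace(0.0, 1.0, 5)
print((edges[:-1] + edges[1:]) / 2)   # tensor([0.1250, 0.3750, 0.6250, 0.8750])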
34  mochi_preview/dit/joint_model/temporal_rope.py  Normal file
@@ -0,0 +1,34 @@
# Based on Llama3 Implementation.
import torch


def apply_rotary_emb_qk_real(
    xqk: torch.Tensor,
    freqs_cos: torch.Tensor,
    freqs_sin: torch.Tensor,
) -> torch.Tensor:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor without complex numbers.

    Args:
        xqk (torch.Tensor): Query and/or Key tensors to apply rotary embeddings. Shape: (B, S, *, num_heads, D)
            Can be either just query or just key, or both stacked along some batch or * dim.
        freqs_cos (torch.Tensor): Precomputed cosine frequency tensor.
        freqs_sin (torch.Tensor): Precomputed sine frequency tensor.

    Returns:
        torch.Tensor: The input tensor with rotary embeddings applied.
    """
    assert xqk.dtype == torch.bfloat16
    # Split the last dimension into even and odd parts
    xqk_even = xqk[..., 0::2]
    xqk_odd = xqk[..., 1::2]

    # Apply rotation
    cos_part = (xqk_even * freqs_cos - xqk_odd * freqs_sin).type_as(xqk)
    sin_part = (xqk_even * freqs_sin + xqk_odd * freqs_cos).type_as(xqk)

    # Interleave the results back into the original shape
    out = torch.stack([cos_part, sin_part], dim=-1).flatten(-2)
    assert out.dtype == torch.bfloat16
    return out
189  mochi_preview/dit/joint_model/utils.py  Normal file
@@ -0,0 +1,189 @@
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch.Tensor:
    """
    Pool tokens in x using mask.

    NOTE: We assume x does not require gradients.

    Args:
        x: (B, L, D) tensor of tokens.
        mask: (B, L) boolean tensor indicating which tokens are not padding.

    Returns:
        pooled: (B, D) tensor of pooled tokens.
    """
    assert x.size(1) == mask.size(1)  # Expected mask to have same length as tokens.
    assert x.size(0) == mask.size(0)  # Expected mask to have same batch size as tokens.
    mask = mask[:, :, None].to(dtype=x.dtype)
    mask = mask / mask.sum(dim=1, keepdim=True).clamp(min=1)
    pooled = (x * mask).sum(dim=1, keepdim=keepdim)
    return pooled


class AttentionPool(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        output_dim: int = None,
        device: Optional[torch.device] = None,
    ):
        """
        Args:
            spatial_dim (int): Number of tokens in sequence length.
            embed_dim (int): Dimensionality of input tokens.
            num_heads (int): Number of attention heads.
            output_dim (int): Dimensionality of output tokens. Defaults to embed_dim.
        """
        super().__init__()
        self.num_heads = num_heads
        self.to_kv = nn.Linear(embed_dim, 2 * embed_dim, device=device)
        self.to_q = nn.Linear(embed_dim, embed_dim, device=device)
        self.to_out = nn.Linear(embed_dim, output_dim or embed_dim, device=device)

    def forward(self, x, mask):
        """
        Args:
            x (torch.Tensor): (B, L, D) tensor of input tokens.
            mask (torch.Tensor): (B, L) boolean tensor indicating which tokens are not padding.

        NOTE: We assume x does not require gradients.

        Returns:
            x (torch.Tensor): (B, D) tensor of pooled tokens.
        """
        D = x.size(2)

        # Construct attention mask, shape: (B, 1, num_queries=1, num_keys=1+L).
        attn_mask = mask[:, None, None, :].bool()  # (B, 1, 1, L).
        attn_mask = F.pad(attn_mask, (1, 0), value=True)  # (B, 1, 1, 1+L).

        # Average non-padding token features. These will be used as the query.
        x_pool = pool_tokens(x, mask, keepdim=True)  # (B, 1, D)

        # Concat pooled features to input sequence.
        x = torch.cat([x_pool, x], dim=1)  # (B, L+1, D)

        # Compute queries, keys, values. Only the mean token is used to create a query.
        kv = self.to_kv(x)  # (B, L+1, 2 * D)
        q = self.to_q(x[:, 0])  # (B, D)

        # Extract heads.
        head_dim = D // self.num_heads
        kv = kv.unflatten(2, (2, self.num_heads, head_dim))  # (B, 1+L, 2, H, head_dim)
        kv = kv.transpose(1, 3)  # (B, H, 2, 1+L, head_dim)
        k, v = kv.unbind(2)  # (B, H, 1+L, head_dim)
        q = q.unflatten(1, (self.num_heads, head_dim))  # (B, H, head_dim)
        q = q.unsqueeze(2)  # (B, H, 1, head_dim)

        # Compute attention.
        x = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, dropout_p=0.0
        )  # (B, H, 1, head_dim)

        # Concatenate heads and run output.
        x = x.squeeze(2).flatten(1, 2)  # (B, D = H * head_dim)
        x = self.to_out(x)
        return x


class PadSplitXY(torch.autograd.Function):
    """
    Merge heads, pad and extract visual and text tokens,
    and split along the sequence length.
    """

    @staticmethod
    def forward(
        ctx,
        xy: torch.Tensor,
        indices: torch.Tensor,
        B: int,
        N: int,
        L: int,
        dtype: torch.dtype,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            xy: Packed tokens. Shape: (total <= B * (N + L), num_heads * head_dim).
            indices: Valid token indices out of unpacked tensor. Shape: (total,)

        Returns:
            x: Visual tokens. Shape: (B, N, num_heads * head_dim).
            y: Text tokens. Shape: (B, L, num_heads * head_dim).
        """
        ctx.save_for_backward(indices)
        ctx.B, ctx.N, ctx.L = B, N, L
        D = xy.size(1)

        # Pad sequences to (B, N + L, dim).
        assert indices.ndim == 1
        output = torch.zeros(B * (N + L), D, device=xy.device, dtype=dtype)
        indices = indices.unsqueeze(1).expand(
            -1, D
        )  # (total,) -> (total, num_heads * head_dim)
        output.scatter_(0, indices, xy)
        xy = output.view(B, N + L, D)

        # Split visual and text tokens along the sequence length.
        return torch.tensor_split(xy, (N,), dim=1)


def pad_and_split_xy(xy, indices, B, N, L, dtype) -> Tuple[torch.Tensor, torch.Tensor]:
    return PadSplitXY.apply(xy, indices, B, N, L, dtype)


class UnifyStreams(torch.autograd.Function):
    """Unify visual and text streams."""

    @staticmethod
    def forward(
        ctx,
        q_x: torch.Tensor,
        k_x: torch.Tensor,
        v_x: torch.Tensor,
        q_y: torch.Tensor,
        k_y: torch.Tensor,
        v_y: torch.Tensor,
        indices: torch.Tensor,
    ):
        """
        Args:
            q_x: (B, N, num_heads, head_dim)
            k_x: (B, N, num_heads, head_dim)
            v_x: (B, N, num_heads, head_dim)
            q_y: (B, L, num_heads, head_dim)
            k_y: (B, L, num_heads, head_dim)
            v_y: (B, L, num_heads, head_dim)
            indices: (total <= B * (N + L))

        Returns:
            qkv: (total <= B * (N + L), 3, num_heads, head_dim)
        """
        ctx.save_for_backward(indices)
        B, N, num_heads, head_dim = q_x.size()
        ctx.B, ctx.N, ctx.L = B, N, q_y.size(1)
        D = num_heads * head_dim

        q = torch.cat([q_x, q_y], dim=1)
        k = torch.cat([k_x, k_y], dim=1)
        v = torch.cat([v_x, v_y], dim=1)
        qkv = torch.stack([q, k, v], dim=2).view(B * (N + ctx.L), 3, D)

        indices = indices[:, None, None].expand(-1, 3, D)
        qkv = torch.gather(qkv, 0, indices)  # (total, 3, num_heads * head_dim)
        return qkv.unflatten(2, (num_heads, head_dim))


def unify_streams(q_x, k_x, v_x, q_y, k_y, v_y, indices) -> torch.Tensor:
    return UnifyStreams.apply(q_x, k_x, v_x, q_y, k_y, v_y, indices)
445
mochi_preview/t2v_synth_mochi.py
Normal file
@ -0,0 +1,445 @@
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
from functools import partial
|
||||
from typing import Dict, List
|
||||
|
||||
from safetensors.torch import load_file
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.data
|
||||
import yaml
|
||||
from einops import rearrange, repeat
|
||||
from omegaconf import OmegaConf
|
||||
from torch import nn
|
||||
from torch.distributed.fsdp import (
|
||||
BackwardPrefetch,
|
||||
MixedPrecision,
|
||||
ShardingStrategy,
|
||||
)
|
||||
from torch.distributed.fsdp import (
|
||||
FullyShardedDataParallel as FSDP,
|
||||
)
|
||||
from torch.distributed.fsdp.wrap import (
|
||||
lambda_auto_wrap_policy,
|
||||
transformer_auto_wrap_policy,
|
||||
)
|
||||
from transformers import T5EncoderModel, T5Tokenizer
|
||||
from transformers.models.t5.modeling_t5 import T5Block
|
||||
|
||||
from .dit.joint_model.context_parallel import get_cp_rank_size
|
||||
from .utils import Timer
|
||||
from tqdm import tqdm
|
||||
from comfy.utils import ProgressBar
|
||||
|
||||
T5_MODEL = "weights/T5"
|
||||
MAX_T5_TOKEN_LENGTH = 256
|
||||
|
||||
class T5_Tokenizer:
|
||||
"""Wrapper around Hugging Face tokenizer for T5
|
||||
|
||||
Note:
|
||||
The tokenizer is always loaded from the local T5_MODEL directory.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.tokenizer = T5Tokenizer.from_pretrained(T5_MODEL, legacy=False)
|
||||
|
||||
def __call__(self, prompt, padding, truncation, return_tensors, max_length=None):
|
||||
"""
|
||||
Args:
|
||||
prompt (str): The input text to tokenize.
|
||||
padding (str): The padding strategy.
|
||||
truncation (bool): Flag indicating whether to truncate the tokens.
|
||||
return_tensors (str): Framework for the returned tensors (e.g. "pt").
|
||||
max_length (int): The max length of the tokens.
|
||||
"""
|
||||
assert (
|
||||
not max_length or max_length == MAX_T5_TOKEN_LENGTH
|
||||
), f"Max length must be {MAX_T5_TOKEN_LENGTH} for T5."
|
||||
|
||||
tokenized_output = self.tokenizer(
|
||||
prompt,
|
||||
padding=padding,
|
||||
max_length=MAX_T5_TOKEN_LENGTH, # Max token length for T5 is set here.
|
||||
truncation=truncation,
|
||||
return_tensors=return_tensors,
|
||||
return_attention_mask=True,
|
||||
)
|
||||
|
||||
return tokenized_output
|
||||
|
||||
|
||||
def unnormalize_latents(
|
||||
z: torch.Tensor,
|
||||
mean: torch.Tensor,
|
||||
std: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Unnormalize latents. Useful for decoding DiT samples.
|
||||
|
||||
Args:
|
||||
z (torch.Tensor): [B, C_z, T_z, H_z, W_z], float
|
||||
|
||||
Returns:
|
||||
torch.Tensor: [B, C_z, T_z, H_z, W_z], float
|
||||
"""
|
||||
mean = mean[:, None, None, None]
|
||||
std = std[:, None, None, None]
|
||||
|
||||
assert z.ndim == 5
|
||||
assert z.size(1) == mean.size(0) == std.size(0)
|
||||
return z * std.to(z) + mean.to(z)
|
||||
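# Illustrative sketch, not part of the original file: mean and std are per-channel vectors
# (12 entries for Mochi's latent space, matching configs/vae_stats.json) broadcast over
# T, H, W. Demonstration only; never called.
def _unnormalize_latents_example():
    z = torch.randn(1, 12, 3, 4, 4)
    mean = torch.zeros(12)
    std = torch.ones(12) * 2.0
    out = unnormalize_latents(z, mean, std)
    assert torch.allclose(out, z * 2.0)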
|
||||
|
||||
def setup_fsdp_sync(model, device_id, *, param_dtype, auto_wrap_policy) -> FSDP:
|
||||
model = FSDP(
|
||||
model,
|
||||
sharding_strategy=ShardingStrategy.FULL_SHARD,
|
||||
mixed_precision=MixedPrecision(
|
||||
param_dtype=param_dtype,
|
||||
reduce_dtype=torch.float32,
|
||||
buffer_dtype=torch.float32,
|
||||
),
|
||||
auto_wrap_policy=auto_wrap_policy,
|
||||
backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
|
||||
limit_all_gathers=True,
|
||||
device_id=device_id,
|
||||
sync_module_states=True,
|
||||
use_orig_params=True,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
return model
|
||||
|
||||
|
||||
def compute_packed_indices(
|
||||
N: int,
|
||||
text_mask: List[torch.Tensor],
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Based on https://github.com/Dao-AILab/flash-attention/blob/765741c1eeb86c96ee71a3291ad6968cfbf4e4a1/flash_attn/bert_padding.py#L60-L80
|
||||
|
||||
Args:
|
||||
N: Number of visual tokens.
|
||||
text_mask: (B, L) List of boolean tensor indicating which text tokens are not padding.
|
||||
|
||||
Returns:
|
||||
packed_indices: Dict with keys for Flash Attention:
|
||||
- valid_token_indices_kv: up to (B * (N + L),) tensor of valid token indices (non-padding)
|
||||
in the packed sequence.
|
||||
- cu_seqlens_kv: (B + 1,) tensor of cumulative sequence lengths in the packed sequence.
|
||||
- max_seqlen_in_batch_kv: int of the maximum sequence length in the batch.
|
||||
"""
|
||||
# Create an expanded token mask saying which tokens are valid across both visual and text tokens.
|
||||
assert N > 0 and len(text_mask) == 1
|
||||
text_mask = text_mask[0]
|
||||
|
||||
mask = F.pad(text_mask, (N, 0), value=True) # (B, N + L)
|
||||
seqlens_in_batch = mask.sum(dim=-1, dtype=torch.int32) # (B,)
|
||||
valid_token_indices = torch.nonzero(
|
||||
mask.flatten(), as_tuple=False
|
||||
).flatten() # up to (B * (N + L),)
|
||||
|
||||
assert valid_token_indices.size(0) >= text_mask.size(0) * N # At least (B * N,)
|
||||
cu_seqlens = F.pad(
|
||||
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
|
||||
)
|
||||
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
||||
|
||||
return {
|
||||
"cu_seqlens_kv": cu_seqlens,
|
||||
"max_seqlen_in_batch_kv": max_seqlen_in_batch,
|
||||
"valid_token_indices_kv": valid_token_indices,
|
||||
}
|
||||
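# Illustrative sketch, not part of the original file: with N = 4 visual tokens and a
# single text mask [True, True, False], the packed mask is [T, T, T, T, T, T, F], so six
# valid positions are kept and cu_seqlens_kv is [0, 6]. Demonstration only; never called.
def _compute_packed_indices_example():
    text_mask = [torch.tensor([[True, True, False]])]
    packed = compute_packed_indices(4, text_mask)
    assert packed["cu_seqlens_kv"].tolist() == [0, 6]
    assert packed["max_seqlen_in_batch_kv"] == 6
    assert packed["valid_token_indices_kv"].tolist() == [0, 1, 2, 3, 4, 5]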
|
||||
|
||||
def shift_sigma(
|
||||
sigma: np.ndarray,
|
||||
shift: float,
|
||||
):
|
||||
"""Shift noise standard deviation toward higher values.
|
||||
|
||||
Useful for training a model at high resolutions,
|
||||
or sampling more finely at high noise levels.
|
||||
|
||||
Equivalent to:
|
||||
sigma_shift = shift / (shift + 1 / sigma - 1)
|
||||
except for sigma = 0.
|
||||
|
||||
Args:
|
||||
sigma: noise standard deviation in [0, 1]
|
||||
shift: shift factor >= 1.
|
||||
For shift > 1, shifts sigma to higher values.
|
||||
For shift = 1, identity function.
|
||||
"""
|
||||
return shift * sigma / (shift * sigma + 1 - sigma)
|
||||
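# Illustrative sketch, not part of the original file: a quick numerical check of the
# shift. For shift=3 and sigma=0.5, 3*0.5 / (3*0.5 + 1 - 0.5) = 0.75, which agrees with
# the equivalent form 3 / (3 + 1/0.5 - 1). Demonstration only; never called.
def _shift_sigma_example():
    sigma = np.array([0.0, 0.25, 0.5, 1.0])
    shifted = shift_sigma(sigma, shift=3.0)
    assert np.allclose(shifted, [0.0, 0.5, 0.75, 1.0])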
|
||||
|
||||
class T2VSynthMochiModel:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
device_id: int,
|
||||
vae_stats_path: str,
|
||||
dit_checkpoint_path: str,
|
||||
):
|
||||
super().__init__()
|
||||
t = Timer()
|
||||
self.device = torch.device(device_id)
|
||||
|
||||
#self.t5_tokenizer = T5_Tokenizer()
|
||||
|
||||
# with t("load_text_encs"):
|
||||
# t5_enc = T5EncoderModel.from_pretrained(T5_MODEL)
|
||||
# self.t5_enc = t5_enc.eval().to(torch.bfloat16).to("cpu")
|
||||
|
||||
with t("construct_dit"):
|
||||
from .dit.joint_model.asymm_models_joint import (
|
||||
AsymmDiTJoint,
|
||||
)
|
||||
model: nn.Module = torch.nn.utils.skip_init(
|
||||
AsymmDiTJoint,
|
||||
depth=48,
|
||||
patch_size=2,
|
||||
num_heads=24,
|
||||
hidden_size_x=3072,
|
||||
hidden_size_y=1536,
|
||||
mlp_ratio_x=4.0,
|
||||
mlp_ratio_y=4.0,
|
||||
in_channels=12,
|
||||
qk_norm=True,
|
||||
qkv_bias=False,
|
||||
out_bias=True,
|
||||
patch_embed_bias=True,
|
||||
timestep_mlp_bias=True,
|
||||
timestep_scale=1000.0,
|
||||
t5_feat_dim=4096,
|
||||
t5_token_length=256,
|
||||
rope_theta=10000.0,
|
||||
)
|
||||
with t("dit_load_checkpoint"):
|
||||
|
||||
model.load_state_dict(load_file(dit_checkpoint_path))
|
||||
|
||||
with t("fsdp_dit"):
|
||||
self.dit = model
|
||||
self.dit.eval()
|
||||
for name, param in self.dit.named_parameters():
|
||||
params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"}
|
||||
if not any(keyword in name for keyword in params_to_keep):
|
||||
param.data = param.data.to(torch.float8_e4m3fn)
|
||||
else:
|
||||
param.data = param.data.to(torch.bfloat16)
|
||||
|
||||
|
||||
vae_stats = json.load(open(vae_stats_path))
|
||||
self.vae_mean = torch.Tensor(vae_stats["mean"]).to(self.device)
|
||||
self.vae_std = torch.Tensor(vae_stats["std"]).to(self.device)
|
||||
|
||||
t.print_stats()
|
||||
|
||||
def get_conditioning(self, prompts, *, zero_last_n_prompts: int):
|
||||
B = len(prompts)
|
||||
print(f"Getting conditioning for {B} prompts")
|
||||
assert (
|
||||
0 <= zero_last_n_prompts <= B
|
||||
), f"zero_last_n_prompts should be between 0 and {B}, got {zero_last_n_prompts}"
|
||||
tokenize_kwargs = dict(
|
||||
prompt=prompts,
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
truncation=True,
|
||||
)
|
||||
|
||||
t5_toks = self.t5_tokenizer(**tokenize_kwargs, max_length=MAX_T5_TOKEN_LENGTH)
|
||||
caption_input_ids_t5 = t5_toks["input_ids"]
|
||||
caption_attention_mask_t5 = t5_toks["attention_mask"].bool()
|
||||
del t5_toks
|
||||
|
||||
assert caption_input_ids_t5.shape == (B, MAX_T5_TOKEN_LENGTH)
|
||||
assert caption_attention_mask_t5.shape == (B, MAX_T5_TOKEN_LENGTH)
|
||||
|
||||
if zero_last_n_prompts > 0:
|
||||
# Zero the last N prompts
|
||||
caption_input_ids_t5[-zero_last_n_prompts:] = 0
|
||||
caption_attention_mask_t5[-zero_last_n_prompts:] = False
|
||||
|
||||
caption_input_ids_t5 = caption_input_ids_t5.to(self.device, non_blocking=True)
|
||||
caption_attention_mask_t5 = caption_attention_mask_t5.to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
|
||||
y_mask = [caption_attention_mask_t5]
|
||||
y_feat = []
|
||||
|
||||
self.t5_enc.to(self.device)
|
||||
y_feat.append(
|
||||
self.t5_enc(
|
||||
caption_input_ids_t5, caption_attention_mask_t5
|
||||
).last_hidden_state.detach().to(torch.float32)
|
||||
)
|
||||
print(y_feat[-1].shape)
|
||||
print(y_feat[0])
|
||||
self.t5_enc.to("cpu")
|
||||
# Sometimes returns a tensor, other times a tuple; not sure why
|
||||
# See: https://huggingface.co/genmo/mochi-1-preview/discussions/3
|
||||
assert tuple(y_feat[-1].shape) == (B, MAX_T5_TOKEN_LENGTH, 4096)
|
||||
return dict(y_mask=y_mask, y_feat=y_feat)
|
||||
|
||||
def get_packed_indices(self, y_mask, *, lT, lW, lH):
|
||||
patch_size = 2
|
||||
N = lT * lH * lW // (patch_size**2)
|
||||
assert len(y_mask) == 1
|
||||
packed_indices = compute_packed_indices(N, y_mask)
|
||||
self.move_to_device_(packed_indices)
|
||||
return packed_indices
|
||||
|
||||
def move_to_device_(self, sample):
|
||||
if isinstance(sample, dict):
|
||||
for key in sample.keys():
|
||||
if isinstance(sample[key], torch.Tensor):
|
||||
sample[key] = sample[key].to(self.device, non_blocking=True)
|
||||
|
||||
@torch.inference_mode(mode=True)
|
||||
def run(self, args, stream_results):
|
||||
random.seed(args["seed"])
|
||||
np.random.seed(args["seed"])
|
||||
torch.manual_seed(args["seed"])
|
||||
|
||||
generator = torch.Generator(device=self.device)
|
||||
generator.manual_seed(args["seed"])
|
||||
|
||||
# assert (
|
||||
# len(args["prompt"]) == 1
|
||||
# ), f"Expected exactly one prompt, got {len(args['prompt'])}"
|
||||
#prompt = args["prompt"][0]
|
||||
#neg_prompt = args["negative_prompt"][0] if len(args["negative_prompt"]) else ""
|
||||
B = 1
|
||||
|
||||
w = args["width"]
|
||||
h = args["height"]
|
||||
t = args["num_frames"]
|
||||
batch_cfg = args["mochi_args"]["batch_cfg"]
|
||||
sample_steps = args["mochi_args"]["num_inference_steps"]
|
||||
cfg_schedule = args["mochi_args"].get("cfg_schedule")
|
||||
assert (
|
||||
len(cfg_schedule) == sample_steps
|
||||
), f"cfg_schedule must have length {sample_steps}, got {len(cfg_schedule)}"
|
||||
sigma_schedule = args["mochi_args"].get("sigma_schedule")
|
||||
if sigma_schedule:
|
||||
assert (
|
||||
len(sigma_schedule) == sample_steps + 1
|
||||
), f"sigma_schedule must have length {sample_steps + 1}, got {len(sigma_schedule)}"
|
||||
assert (t - 1) % 6 == 0, f"t - 1 must be divisible by 6, got {t - 1}"
|
||||
|
||||
# if batch_cfg:
|
||||
# sample_batched = self.get_conditioning(
|
||||
# [prompt] + [neg_prompt], zero_last_n_prompts=B if neg_prompt == "" else 0
|
||||
# )
|
||||
# else:
|
||||
# sample = self.get_conditioning([prompt], zero_last_n_prompts=0)
|
||||
# sample_null = self.get_conditioning([neg_prompt] * B, zero_last_n_prompts=B if neg_prompt == "" else 0)
|
||||
|
||||
spatial_downsample = 8
|
||||
temporal_downsample = 6
|
||||
latent_t = (t - 1) // temporal_downsample + 1
|
||||
latent_w, latent_h = w // spatial_downsample, h // spatial_downsample
|
||||
|
||||
latent_dims = dict(lT=latent_t, lW=latent_w, lH=latent_h)
|
||||
in_channels = 12
|
||||
z = torch.randn(
|
||||
(B, in_channels, latent_t, latent_h, latent_w),
|
||||
device=self.device,
|
||||
generator=generator,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
# if batch_cfg:
|
||||
# sample_batched["packed_indices"] = self.get_packed_indices(
|
||||
# sample_batched["y_mask"], **latent_dims
|
||||
# )
|
||||
# z = repeat(z, "b ... -> (repeat b) ...", repeat=2)
|
||||
# else:
|
||||
|
||||
sample = {
|
||||
"y_mask": [args["positive_embeds"]["attention_mask"].to(self.device)],
|
||||
"y_feat": [args["positive_embeds"]["embeds"].to(self.device)]
|
||||
}
|
||||
sample_null = {
|
||||
"y_mask": [args["negative_embeds"]["attention_mask"].to(self.device)],
|
||||
"y_feat": [args["negative_embeds"]["embeds"].to(self.device)]
|
||||
}
|
||||
|
||||
# print(sample["y_mask"])
|
||||
# print(type(sample["y_mask"]))
|
||||
# print(sample["y_mask"][0].shape)
|
||||
|
||||
# print(sample["y_feat"])
|
||||
# print(type(sample["y_feat"]))
|
||||
# print(sample["y_feat"][0].shape)
|
||||
|
||||
print(sample_null["y_mask"])
|
||||
print(type(sample_null["y_mask"]))
|
||||
print(sample_null["y_mask"][0].shape)
|
||||
|
||||
print(sample_null["y_feat"])
|
||||
print(type(sample_null["y_feat"]))
|
||||
print(sample_null["y_feat"][0].shape)
|
||||
|
||||
sample["packed_indices"] = self.get_packed_indices(
|
||||
sample["y_mask"], **latent_dims
|
||||
)
|
||||
sample_null["packed_indices"] = self.get_packed_indices(
|
||||
sample_null["y_mask"], **latent_dims
|
||||
)
|
||||
|
||||
def model_fn(*, z, sigma, cfg_scale):
|
||||
#print("z", z.dtype, z.device)
|
||||
#print("sigma", sigma.dtype, sigma.device)
|
||||
self.dit.to(self.device)
|
||||
# if batch_cfg:
|
||||
# with torch.autocast("cuda", dtype=torch.bfloat16):
|
||||
# out = self.dit(z, sigma, **sample_batched)
|
||||
# out_cond, out_uncond = torch.chunk(out, chunks=2, dim=0)
|
||||
#else:
|
||||
|
||||
nonlocal sample, sample_null
|
||||
with torch.autocast("cuda", dtype=torch.bfloat16):
|
||||
out_cond = self.dit(z, sigma, **sample)
|
||||
out_uncond = self.dit(z, sigma, **sample_null)
|
||||
assert out_cond.shape == out_uncond.shape
|
||||
|
||||
return out_uncond + cfg_scale * (out_cond - out_uncond), out_cond
|
||||
|
||||
comfy_pbar = ProgressBar(sample_steps)
|
||||
for i in tqdm(range(0, sample_steps), desc="Processing Samples", total=sample_steps):
|
||||
sigma = sigma_schedule[i]
|
||||
dsigma = sigma - sigma_schedule[i + 1]
|
||||
|
||||
# `pred` estimates `z_0 - eps`.
|
||||
pred, output_cond = model_fn(
|
||||
z=z,
|
||||
sigma=torch.full(
|
||||
[B] if not batch_cfg else [B * 2], sigma, device=z.device
|
||||
),
|
||||
cfg_scale=cfg_schedule[i],
|
||||
)
|
||||
pred = pred.to(z)
|
||||
output_cond = output_cond.to(z)
|
||||
|
||||
#if stream_results:
|
||||
# yield i / sample_steps, None, False
|
||||
z = z + dsigma * pred
|
||||
comfy_pbar.update(1)
|
||||
|
||||
cp_rank, cp_size = get_cp_rank_size()
|
||||
if batch_cfg:
|
||||
z = z[:B]
|
||||
z = z.tensor_split(cp_size, dim=2)[cp_rank] # split along temporal dim
|
||||
self.dit.to("cpu")
|
||||
|
||||
samples = unnormalize_latents(z.float(), self.vae_mean, self.vae_std)
|
||||
print("samples", samples.shape, samples.dtype, samples.device)
|
||||
return samples
|
||||
33
mochi_preview/utils.py
Normal file
@ -0,0 +1,33 @@
|
||||
import time
|
||||
|
||||
|
||||
class Timer:
|
||||
def __init__(self):
|
||||
self.times = {} # Dictionary to store times per stage
|
||||
|
||||
def __call__(self, name):
|
||||
print(f"Timing {name}")
|
||||
return self.TimerContextManager(self, name)
|
||||
|
||||
def print_stats(self):
|
||||
total_time = sum(self.times.values())
|
||||
# Print table header
|
||||
print("{:<20} {:>10} {:>10}".format("Stage", "Time(s)", "Percent"))
|
||||
for name, t in self.times.items():
|
||||
percent = (t / total_time) * 100 if total_time > 0 else 0
|
||||
print("{:<20} {:>10.2f} {:>9.2f}%".format(name, t, percent))
|
||||
|
||||
class TimerContextManager:
|
||||
def __init__(self, outer, name):
|
||||
self.outer = outer # Reference to the Timer instance
|
||||
self.name = name
|
||||
self.start_time = None
|
||||
|
||||
def __enter__(self):
|
||||
self.start_time = time.perf_counter()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
end_time = time.perf_counter()
|
||||
elapsed = end_time - self.start_time
|
||||
self.outer.times[self.name] = self.outer.times.get(self.name, 0) + elapsed
|
||||
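# Illustrative usage, not part of the original file: time a named stage and print a table
# of per-stage durations. Demonstration only; never called.
def _timer_example():
    timer = Timer()
    with timer("sleep"):
        time.sleep(0.05)
    timer.print_stats()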
0
mochi_preview/vae/__init__.py
Normal file
BIN
mochi_preview/vae/__pycache__/__init__.cpython-311.pyc
Normal file
Binary file not shown.
BIN
mochi_preview/vae/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
mochi_preview/vae/__pycache__/cp_conv.cpython-311.pyc
Normal file
Binary file not shown.
BIN
mochi_preview/vae/__pycache__/cp_conv.cpython-312.pyc
Normal file
Binary file not shown.
BIN
mochi_preview/vae/__pycache__/model.cpython-311.pyc
Normal file
Binary file not shown.
BIN
mochi_preview/vae/__pycache__/model.cpython-312.pyc
Normal file
Binary file not shown.
152
mochi_preview/vae/cp_conv.py
Normal file
@ -0,0 +1,152 @@
|
||||
from typing import Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..dit.joint_model.context_parallel import get_cp_group, get_cp_rank_size
|
||||
|
||||
|
||||
def cast_tuple(t, length=1):
|
||||
return t if isinstance(t, tuple) else ((t,) * length)
|
||||
|
||||
|
||||
def cp_pass_frames(x: torch.Tensor, frames_to_send: int) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass that handles communication between ranks for inference.
|
||||
Args:
|
||||
x: Tensor of shape (B, C, T, H, W)
|
||||
frames_to_send: int, number of frames to communicate between ranks
|
||||
Returns:
|
||||
output: Tensor of shape (B, C, T', H, W)
|
||||
"""
|
||||
cp_rank, cp_world_size = get_cp_rank_size()
|
||||
if frames_to_send == 0 or cp_world_size == 1:
|
||||
return x
|
||||
|
||||
group = get_cp_group()
|
||||
global_rank = dist.get_rank()
|
||||
|
||||
# Send to next rank
|
||||
if cp_rank < cp_world_size - 1:
|
||||
assert x.size(2) >= frames_to_send
|
||||
tail = x[:, :, -frames_to_send:].contiguous()
|
||||
dist.send(tail, global_rank + 1, group=group)
|
||||
|
||||
# Receive from previous rank
|
||||
if cp_rank > 0:
|
||||
B, C, _, H, W = x.shape
|
||||
recv_buffer = torch.empty(
|
||||
(B, C, frames_to_send, H, W),
|
||||
dtype=x.dtype,
|
||||
device=x.device,
|
||||
)
|
||||
dist.recv(recv_buffer, global_rank - 1, group=group)
|
||||
x = torch.cat([recv_buffer, x], dim=2)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def _pad_to_max(x: torch.Tensor, max_T: int) -> torch.Tensor:
|
||||
if max_T > x.size(2):
|
||||
pad_T = max_T - x.size(2)
|
||||
pad_dims = (0, 0, 0, 0, 0, pad_T)
|
||||
return F.pad(x, pad_dims)
|
||||
return x
|
||||
|
||||
|
||||
def gather_all_frames(x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Gathers all frames from all processes for inference.
|
||||
Args:
|
||||
x: Tensor of shape (B, C, T, H, W)
|
||||
Returns:
|
||||
output: Tensor of shape (B, C, T_total, H, W)
|
||||
"""
|
||||
cp_rank, cp_size = get_cp_rank_size()
|
||||
cp_group = get_cp_group()
|
||||
|
||||
# Ensure the tensor is contiguous for collective operations
|
||||
x = x.contiguous()
|
||||
|
||||
# Get the local time dimension size
|
||||
local_T = x.size(2)
|
||||
local_T_tensor = torch.tensor([local_T], device=x.device, dtype=torch.int64)
|
||||
|
||||
# Gather all T sizes from all processes
|
||||
all_T = [torch.zeros(1, dtype=torch.int64, device=x.device) for _ in range(cp_size)]
|
||||
dist.all_gather(all_T, local_T_tensor, group=cp_group)
|
||||
all_T = [t.item() for t in all_T]
|
||||
|
||||
# Pad the tensor at the end of the time dimension to match max_T
|
||||
max_T = max(all_T)
|
||||
x = _pad_to_max(x, max_T).contiguous()
|
||||
|
||||
# Prepare a list to hold the gathered tensors
|
||||
gathered_x = [torch.zeros_like(x).contiguous() for _ in range(cp_size)]
|
||||
|
||||
# Perform the all_gather operation
|
||||
dist.all_gather(gathered_x, x, group=cp_group)
|
||||
|
||||
# Slice each gathered tensor back to its original T size
|
||||
for idx, t_size in enumerate(all_T):
|
||||
gathered_x[idx] = gathered_x[idx][:, :, :t_size]
|
||||
|
||||
return torch.cat(gathered_x, dim=2)
|
||||
|
||||
|
||||
def excessive_memory_usage(input: torch.Tensor, max_gb: float = 2.0) -> bool:
|
||||
"""Estimate memory usage based on input tensor size and data type."""
|
||||
element_size = input.element_size() # Size in bytes of each element
|
||||
memory_bytes = input.numel() * element_size
|
||||
memory_gb = memory_bytes / 1024**3
|
||||
return memory_gb > max_gb
|
||||
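# Illustrative check, not part of the original file: a (1, 3, 28, 480, 848) float32 tensor
# is roughly 0.13 GB, well under the 2 GB default, so no chunking would be triggered.
# Demonstration only; never called.
def _excessive_memory_usage_example():
    x = torch.empty(1, 3, 28, 480, 848, dtype=torch.float32)
    assert not excessive_memory_usage(x)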
|
||||
|
||||
class ContextParallelCausalConv3d(torch.nn.Conv3d):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size: Union[int, Tuple[int, int, int]],
|
||||
stride: Union[int, Tuple[int, int, int]],
|
||||
**kwargs,
|
||||
):
|
||||
kernel_size = cast_tuple(kernel_size, 3)
|
||||
stride = cast_tuple(stride, 3)
|
||||
height_pad = (kernel_size[1] - 1) // 2
|
||||
width_pad = (kernel_size[2] - 1) // 2
|
||||
|
||||
super().__init__(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
dilation=(1, 1, 1),
|
||||
padding=(0, height_pad, width_pad),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
cp_rank, cp_world_size = get_cp_rank_size()
|
||||
|
||||
context_size = self.kernel_size[0] - 1
|
||||
if cp_rank == 0:
|
||||
mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
|
||||
x = F.pad(x, (0, 0, 0, 0, context_size, 0), mode=mode)
|
||||
|
||||
if cp_world_size == 1:
|
||||
return super().forward(x)
|
||||
|
||||
if all(s == 1 for s in self.stride):
|
||||
# Receive some frames from previous rank.
|
||||
x = cp_pass_frames(x, context_size)
|
||||
return super().forward(x)
|
||||
|
||||
# Less efficient implementation for strided convs.
|
||||
# All gather x, infer and chunk.
|
||||
x = gather_all_frames(x) # [B, C, k - 1 + global_T, H, W]
|
||||
x = super().forward(x)
|
||||
x_chunks = x.tensor_split(cp_world_size, dim=2)
|
||||
assert len(x_chunks) == cp_world_size
|
||||
return x_chunks[cp_rank]
|
||||
815
mochi_preview/vae/model.py
Normal file
@ -0,0 +1,815 @@
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
from ..dit.joint_model.context_parallel import get_cp_rank_size, local_shard
|
||||
from ..vae.cp_conv import cp_pass_frames, gather_all_frames
|
||||
|
||||
|
||||
def cast_tuple(t, length=1):
|
||||
return t if isinstance(t, tuple) else ((t,) * length)
|
||||
|
||||
|
||||
class GroupNormSpatial(nn.GroupNorm):
|
||||
"""
|
||||
GroupNorm applied per-frame.
|
||||
"""
|
||||
|
||||
def forward(self, x: torch.Tensor, *, chunk_size: int = 8):
|
||||
B, C, T, H, W = x.shape
|
||||
x = rearrange(x, "B C T H W -> (B T) C H W")
|
||||
# Run group norm in chunks.
|
||||
output = torch.empty_like(x)
|
||||
for b in range(0, B * T, chunk_size):
|
||||
output[b : b + chunk_size] = super().forward(x[b : b + chunk_size])
|
||||
return rearrange(output, "(B T) C H W -> B C T H W", B=B, T=T)
|
||||
|
||||
|
||||
class SafeConv3d(torch.nn.Conv3d):
|
||||
"""
|
||||
NOTE: No support for padding along time dimension.
|
||||
Input must already be padded along time.
|
||||
"""
|
||||
|
||||
def forward(self, input):
|
||||
memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3
|
||||
if memory_count > 2:
|
||||
part_num = int(memory_count / 2) + 1
|
||||
k = self.kernel_size[0]
|
||||
input_idx = torch.arange(k - 1, input.size(2))
|
||||
input_chunks_idx = torch.chunk(input_idx, part_num, dim=0)
|
||||
|
||||
# assert self.kernel_size == (3, 3, 3), f"kernel_size {self.kernel_size} != (3, 3, 3)"
|
||||
assert self.stride[0] == 1, f"stride {self.stride}"
|
||||
assert self.dilation[0] == 1, f"dilation {self.dilation}"
|
||||
assert self.padding[0] == 0, f"padding {self.padding}"
|
||||
|
||||
# Compute output size
|
||||
assert not input.requires_grad
|
||||
B, _, T_in, H_in, W_in = input.shape
|
||||
output_size = (
|
||||
B,
|
||||
self.out_channels,
|
||||
T_in - k + 1,
|
||||
H_in // self.stride[1],
|
||||
W_in // self.stride[2],
|
||||
)
|
||||
output = torch.empty(output_size, dtype=input.dtype, device=input.device)
|
||||
for input_chunk_idx in input_chunks_idx:
|
||||
input_s = input_chunk_idx[0] - k + 1
|
||||
input_e = input_chunk_idx[-1] + 1
|
||||
input_chunk = input[:, :, input_s:input_e, :, :]
|
||||
output_chunk = super(SafeConv3d, self).forward(input_chunk)
|
||||
|
||||
output_s = input_s
|
||||
output_e = output_s + output_chunk.size(2)
|
||||
output[:, :, output_s:output_e, :, :] = output_chunk
|
||||
|
||||
return output
|
||||
else:
|
||||
return super(SafeConv3d, self).forward(input)
|
||||
|
||||
|
||||
class StridedSafeConv3d(torch.nn.Conv3d):
|
||||
def forward(self, input, use_local_shard: bool = False):
|
||||
assert self.stride[0] == self.kernel_size[0]
|
||||
assert self.dilation[0] == 1
|
||||
assert self.padding[0] == 0
|
||||
|
||||
kernel_size = self.kernel_size[0]
|
||||
stride = self.stride[0]
|
||||
T_in = input.size(2)
|
||||
T_out = T_in // kernel_size
|
||||
|
||||
# Parallel implementation.
|
||||
if use_local_shard:
|
||||
idx = torch.arange(T_out)
|
||||
idx = local_shard(idx, dim=0)
|
||||
start = idx.min() * stride
|
||||
end = idx.max() * stride + kernel_size
|
||||
local_input = input[:, :, start:end, :, :]
|
||||
return torch.nn.Conv3d.forward(self, local_input)
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class ContextParallelConv3d(SafeConv3d):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size: Union[int, Tuple[int, int, int]],
|
||||
stride: Union[int, Tuple[int, int, int]],
|
||||
causal: bool = True,
|
||||
context_parallel: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
self.causal = causal
|
||||
self.context_parallel = context_parallel
|
||||
kernel_size = cast_tuple(kernel_size, 3)
|
||||
stride = cast_tuple(stride, 3)
|
||||
height_pad = (kernel_size[1] - 1) // 2
|
||||
width_pad = (kernel_size[2] - 1) // 2
|
||||
|
||||
super().__init__(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
dilation=(1, 1, 1),
|
||||
padding=(0, height_pad, width_pad),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
cp_rank, cp_world_size = get_cp_rank_size()
|
||||
|
||||
# Compute padding amounts.
|
||||
context_size = self.kernel_size[0] - 1
|
||||
if self.causal:
|
||||
pad_front = context_size
|
||||
pad_back = 0
|
||||
else:
|
||||
pad_front = context_size // 2
|
||||
pad_back = context_size - pad_front
|
||||
|
||||
# Apply padding.
|
||||
assert self.padding_mode == "replicate" # DEBUG
|
||||
mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
|
||||
if self.context_parallel and cp_world_size == 1:
|
||||
x = F.pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode)
|
||||
else:
|
||||
if cp_rank == 0:
|
||||
x = F.pad(x, (0, 0, 0, 0, pad_front, 0), mode=mode)
|
||||
elif cp_rank == cp_world_size - 1 and pad_back:
|
||||
x = F.pad(x, (0, 0, 0, 0, 0, pad_back), mode=mode)
|
||||
|
||||
if self.context_parallel and cp_world_size == 1:
|
||||
return super().forward(x)
|
||||
|
||||
if self.stride[0] == 1:
|
||||
# Receive some frames from previous rank.
|
||||
x = cp_pass_frames(x, context_size)
|
||||
return super().forward(x)
|
||||
|
||||
# Less efficient implementation for strided convs.
|
||||
# All gather x, infer and chunk.
|
||||
assert (
|
||||
x.dtype == torch.bfloat16
|
||||
), f"Expected x to be of type torch.bfloat16, got {x.dtype}"
|
||||
|
||||
x = gather_all_frames(x) # [B, C, k - 1 + global_T, H, W]
|
||||
return StridedSafeConv3d.forward(self, x, use_local_shard=True)
|
||||
|
||||
|
||||
class Conv1x1(nn.Linear):
|
||||
"""*1x1 Conv implemented with a linear layer."""
|
||||
|
||||
def __init__(self, in_features: int, out_features: int, *args, **kwargs):
|
||||
super().__init__(in_features, out_features, *args, **kwargs)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
x: Input tensor. Shape: [B, C, *] or [B, *, C].
|
||||
|
||||
Returns:
|
||||
x: Output tensor. Shape: [B, C', *] or [B, *, C'].
|
||||
"""
|
||||
x = x.movedim(1, -1)
|
||||
x = super().forward(x)
|
||||
x = x.movedim(-1, 1)
|
||||
return x
|
||||
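# Illustrative sketch, not part of the original file: Conv1x1 behaves like a pointwise
# convolution; channels are moved to the last dim, passed through nn.Linear, then moved
# back. Demonstration only; never called.
def _conv1x1_example():
    conv = Conv1x1(4, 8)
    x = torch.randn(2, 4, 3, 16, 16)
    assert conv(x).shape == (2, 8, 3, 16, 16)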
|
||||
|
||||
class DepthToSpaceTime(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
temporal_expansion: int,
|
||||
spatial_expansion: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.temporal_expansion = temporal_expansion
|
||||
self.spatial_expansion = spatial_expansion
|
||||
|
||||
# When printed, this module should show the temporal and spatial expansion factors.
|
||||
def extra_repr(self):
|
||||
return f"texp={self.temporal_expansion}, sexp={self.spatial_expansion}"
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
x: Input tensor. Shape: [B, C, T, H, W].
|
||||
|
||||
Returns:
|
||||
x: Rearranged tensor. Shape: [B, C/(st*s*s), T*st, H*s, W*s].
|
||||
"""
|
||||
x = rearrange(
|
||||
x,
|
||||
"B (C st sh sw) T H W -> B C (T st) (H sh) (W sw)",
|
||||
st=self.temporal_expansion,
|
||||
sh=self.spatial_expansion,
|
||||
sw=self.spatial_expansion,
|
||||
)
|
||||
|
||||
cp_rank, _ = get_cp_rank_size()
|
||||
if self.temporal_expansion > 1 and cp_rank == 0:
|
||||
# Drop the first self.temporal_expansion - 1 frames.
|
||||
# This is because we always want the 3x3x3 conv filter to only apply
|
||||
# to the first frame, and the first frame doesn't need to be repeated.
|
||||
assert all(x.shape)
|
||||
x = x[:, :, self.temporal_expansion - 1 :]
|
||||
assert all(x.shape)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def norm_fn(
|
||||
in_channels: int,
|
||||
affine: bool = True,
|
||||
):
|
||||
return GroupNormSpatial(affine=affine, num_groups=32, num_channels=in_channels)
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
"""Residual block that preserves the spatial dimensions."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
*,
|
||||
affine: bool = True,
|
||||
attn_block: Optional[nn.Module] = None,
|
||||
padding_mode: str = "replicate",
|
||||
causal: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
|
||||
assert causal
|
||||
self.stack = nn.Sequential(
|
||||
norm_fn(channels, affine=affine),
|
||||
nn.SiLU(inplace=True),
|
||||
ContextParallelConv3d(
|
||||
in_channels=channels,
|
||||
out_channels=channels,
|
||||
kernel_size=(3, 3, 3),
|
||||
stride=(1, 1, 1),
|
||||
padding_mode=padding_mode,
|
||||
bias=True,
|
||||
causal=causal,
|
||||
),
|
||||
norm_fn(channels, affine=affine),
|
||||
nn.SiLU(inplace=True),
|
||||
ContextParallelConv3d(
|
||||
in_channels=channels,
|
||||
out_channels=channels,
|
||||
kernel_size=(3, 3, 3),
|
||||
stride=(1, 1, 1),
|
||||
padding_mode=padding_mode,
|
||||
bias=True,
|
||||
causal=causal,
|
||||
),
|
||||
)
|
||||
|
||||
self.attn_block = attn_block if attn_block else nn.Identity()
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
x: Input tensor. Shape: [B, C, T, H, W].
|
||||
"""
|
||||
residual = x
|
||||
x = self.stack(x)
|
||||
x = x + residual
|
||||
del residual
|
||||
|
||||
return self.attn_block(x)
|
||||
|
||||
|
||||
def prepare_for_attention(qkv: torch.Tensor, head_dim: int, qk_norm: bool = True):
|
||||
"""Prepare qkv tensor for attention and normalize qk.
|
||||
|
||||
Args:
|
||||
qkv: Input tensor. Shape: [B, L, 3 * num_heads * head_dim].
|
||||
|
||||
Returns:
|
||||
q, k, v: qkv tensor split into q, k, v. Shape: [B, num_heads, L, head_dim].
|
||||
"""
|
||||
assert qkv.ndim == 3 # [B, L, 3 * num_heads * head_dim]
|
||||
assert qkv.size(2) % (3 * head_dim) == 0
|
||||
num_heads = qkv.size(2) // (3 * head_dim)
|
||||
qkv = qkv.unflatten(2, (3, num_heads, head_dim))
|
||||
|
||||
q, k, v = qkv.unbind(2) # [B, L, num_heads, head_dim]
|
||||
q = q.transpose(1, 2) # [B, num_heads, L, head_dim]
|
||||
k = k.transpose(1, 2) # [B, num_heads, L, head_dim]
|
||||
v = v.transpose(1, 2) # [B, num_heads, L, head_dim]
|
||||
|
||||
if qk_norm:
|
||||
q = F.normalize(q, p=2, dim=-1)
|
||||
k = F.normalize(k, p=2, dim=-1)
|
||||
|
||||
# Mixed precision can change the dtype of normed q/k to float32.
|
||||
q = q.to(dtype=qkv.dtype)
|
||||
k = k.to(dtype=qkv.dtype)
|
||||
|
||||
return q, k, v
|
||||
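# Illustrative shape check, not part of the original file: a packed qkv of size
# 3 * num_heads * head_dim is split into per-head q, k, v with normalized q and k.
# Demonstration only; never called.
def _prepare_for_attention_example():
    qkv = torch.randn(2, 10, 3 * 4 * 32)  # B=2, L=10, num_heads=4, head_dim=32
    q, k, v = prepare_for_attention(qkv, head_dim=32)
    assert q.shape == k.shape == v.shape == (2, 4, 10, 32)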
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
head_dim: int = 32,
|
||||
qkv_bias: bool = False,
|
||||
out_bias: bool = True,
|
||||
qk_norm: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.head_dim = head_dim
|
||||
self.num_heads = dim // head_dim
|
||||
self.qk_norm = qk_norm
|
||||
|
||||
self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)
|
||||
self.out = nn.Linear(dim, dim, bias=out_bias)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
*,
|
||||
chunk_size=2**15,
|
||||
) -> torch.Tensor:
|
||||
"""Compute temporal self-attention.
|
||||
|
||||
Args:
|
||||
x: Input tensor. Shape: [B, C, T, H, W].
|
||||
chunk_size: Chunk size for large tensors.
|
||||
|
||||
Returns:
|
||||
x: Output tensor. Shape: [B, C, T, H, W].
|
||||
"""
|
||||
B, _, T, H, W = x.shape
|
||||
|
||||
if T == 1:
|
||||
# No attention for single frame.
|
||||
x = x.movedim(1, -1) # [B, C, T, H, W] -> [B, T, H, W, C]
|
||||
qkv = self.qkv(x)
|
||||
_, _, x = qkv.chunk(3, dim=-1) # Throw away queries and keys.
|
||||
x = self.out(x)
|
||||
return x.movedim(-1, 1) # [B, T, H, W, C] -> [B, C, T, H, W]
|
||||
|
||||
# 1D temporal attention.
|
||||
x = rearrange(x, "B C t h w -> (B h w) t C")
|
||||
qkv = self.qkv(x)
|
||||
|
||||
# Input: qkv with shape [B, t, 3 * num_heads * head_dim]
|
||||
# Output: x with shape [B, num_heads, t, head_dim]
|
||||
q, k, v = prepare_for_attention(qkv, self.head_dim, qk_norm=self.qk_norm)
|
||||
|
||||
attn_kwargs = dict(
|
||||
attn_mask=None,
|
||||
dropout_p=0.0,
|
||||
is_causal=True,
|
||||
scale=self.head_dim**-0.5,
|
||||
)
|
||||
|
||||
if q.size(0) <= chunk_size:
|
||||
x = F.scaled_dot_product_attention(
|
||||
q, k, v, **attn_kwargs
|
||||
) # [B, num_heads, t, head_dim]
|
||||
else:
|
||||
# Evaluate in chunks to avoid `RuntimeError: CUDA error: invalid configuration argument.`
|
||||
# Chunks of 2**16 and up cause an error.
|
||||
x = torch.empty_like(q)
|
||||
for i in range(0, q.size(0), chunk_size):
|
||||
qc = q[i : i + chunk_size]
|
||||
kc = k[i : i + chunk_size]
|
||||
vc = v[i : i + chunk_size]
|
||||
chunk = F.scaled_dot_product_attention(qc, kc, vc, **attn_kwargs)
|
||||
x[i : i + chunk_size].copy_(chunk)
|
||||
|
||||
assert x.size(0) == q.size(0)
|
||||
x = x.transpose(1, 2) # [B, t, num_heads, head_dim]
|
||||
x = x.flatten(2) # [B, t, num_heads * head_dim]
|
||||
|
||||
x = self.out(x)
|
||||
x = rearrange(x, "(B h w) t C -> B C t h w", B=B, h=H, w=W)
|
||||
return x
|
||||
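# Illustrative sketch, not part of the original file: Attention runs causal 1D
# self-attention over the time axis independently for every spatial location, so the
# output keeps the input shape. Demonstration only; never called.
def _temporal_attention_example():
    attn = Attention(dim=64, head_dim=32)
    x = torch.randn(1, 64, 5, 2, 2)  # [B, C, T, H, W]
    assert attn(x).shape == (1, 64, 5, 2, 2)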
|
||||
|
||||
class AttentionBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
**attn_kwargs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.norm = norm_fn(dim)
|
||||
self.attn = Attention(dim, **attn_kwargs)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x + self.attn(self.norm(x))
|
||||
|
||||
|
||||
class CausalUpsampleBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
num_res_blocks: int,
|
||||
*,
|
||||
temporal_expansion: int = 2,
|
||||
spatial_expansion: int = 2,
|
||||
**block_kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
blocks = []
|
||||
for _ in range(num_res_blocks):
|
||||
blocks.append(block_fn(in_channels, **block_kwargs))
|
||||
self.blocks = nn.Sequential(*blocks)
|
||||
|
||||
self.temporal_expansion = temporal_expansion
|
||||
self.spatial_expansion = spatial_expansion
|
||||
|
||||
# Change channels in the final convolution layer.
|
||||
self.proj = Conv1x1(
|
||||
in_channels,
|
||||
out_channels * temporal_expansion * (spatial_expansion**2),
|
||||
)
|
||||
|
||||
self.d2st = DepthToSpaceTime(
|
||||
temporal_expansion=temporal_expansion, spatial_expansion=spatial_expansion
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.blocks(x)
|
||||
x = self.proj(x)
|
||||
x = self.d2st(x)
|
||||
return x
|
||||
|
||||
|
||||
def block_fn(channels, *, has_attention: bool = False, **block_kwargs):
|
||||
attn_block = AttentionBlock(channels) if has_attention else None
|
||||
|
||||
return ResBlock(
|
||||
channels, affine=True, attn_block=attn_block, **block_kwargs
|
||||
)
|
||||
|
||||
|
||||
class DownsampleBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
num_res_blocks,
|
||||
*,
|
||||
temporal_reduction=2,
|
||||
spatial_reduction=2,
|
||||
**block_kwargs,
|
||||
):
|
||||
"""
|
||||
Downsample block for the VAE encoder.
|
||||
|
||||
Args:
|
||||
in_channels: Number of input channels.
|
||||
out_channels: Number of output channels.
|
||||
num_res_blocks: Number of residual blocks.
|
||||
temporal_reduction: Temporal reduction factor.
|
||||
spatial_reduction: Spatial reduction factor.
|
||||
"""
|
||||
super().__init__()
|
||||
layers = []
|
||||
|
||||
# Change the channel count in the strided convolution.
|
||||
# This lets the ResBlock have uniform channel count,
|
||||
# as in ConvNeXt.
|
||||
assert in_channels != out_channels
|
||||
layers.append(
|
||||
ContextParallelConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=(temporal_reduction, spatial_reduction, spatial_reduction),
|
||||
stride=(temporal_reduction, spatial_reduction, spatial_reduction),
|
||||
padding_mode="replicate",
|
||||
bias=True,
|
||||
)
|
||||
)
|
||||
|
||||
for _ in range(num_res_blocks):
|
||||
layers.append(block_fn(out_channels, **block_kwargs))
|
||||
|
||||
self.layers = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
return self.layers(x)
|
||||
|
||||
|
||||
def add_fourier_features(inputs: torch.Tensor, start=6, stop=8, step=1):
|
||||
num_freqs = (stop - start) // step
|
||||
assert inputs.ndim == 5
|
||||
C = inputs.size(1)
|
||||
|
||||
# Create Base 2 Fourier features.
|
||||
freqs = torch.arange(start, stop, step, dtype=inputs.dtype, device=inputs.device)
|
||||
assert num_freqs == len(freqs)
|
||||
w = torch.pow(2.0, freqs) * (2 * torch.pi) # [num_freqs]
|
||||
w = w.repeat(C)[None, :, None, None, None] # [1, C * num_freqs, 1, 1, 1]
|
||||
|
||||
# Interleaved repeat of input channels to match w.
|
||||
h = inputs.repeat_interleave(num_freqs, dim=1) # [B, C * num_freqs, T, H, W]
|
||||
# Scale channels by frequency.
|
||||
h = w * h
|
||||
|
||||
return torch.cat(
|
||||
[
|
||||
inputs,
|
||||
torch.sin(h),
|
||||
torch.cos(h),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
|
||||
class FourierFeatures(nn.Module):
|
||||
def __init__(self, start: int = 6, stop: int = 8, step: int = 1):
|
||||
super().__init__()
|
||||
self.start = start
|
||||
self.stop = stop
|
||||
self.step = step
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Add Fourier features to inputs.
|
||||
|
||||
Args:
|
||||
inputs: Input tensor. Shape: [B, C, T, H, W]
|
||||
|
||||
Returns:
|
||||
h: Output tensor. Shape: [B, (1 + 2 * num_freqs) * C, T, H, W]
|
||||
"""
|
||||
return add_fourier_features(inputs, self.start, self.stop, self.step)
|
||||
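# Illustrative check, not part of the original file: with start=6, stop=8, step=1 there
# are two frequencies, so a 3-channel input grows to 3 + 2 * 2 * 3 = 15 channels
# (original + sin + cos per frequency). Demonstration only; never called.
def _fourier_features_example():
    x = torch.randn(1, 3, 4, 8, 8)
    y = FourierFeatures()(x)
    assert y.shape == (1, 15, 4, 8, 8)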
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
out_channels: int = 3,
|
||||
latent_dim: int,
|
||||
base_channels: int,
|
||||
channel_multipliers: List[int],
|
||||
num_res_blocks: List[int],
|
||||
temporal_expansions: Optional[List[int]] = None,
|
||||
spatial_expansions: Optional[List[int]] = None,
|
||||
has_attention: List[bool],
|
||||
output_norm: bool = True,
|
||||
nonlinearity: str = "silu",
|
||||
output_nonlinearity: str = "silu",
|
||||
causal: bool = True,
|
||||
**block_kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.input_channels = latent_dim
|
||||
self.base_channels = base_channels
|
||||
self.channel_multipliers = channel_multipliers
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.output_nonlinearity = output_nonlinearity
|
||||
assert nonlinearity == "silu"
|
||||
assert causal
|
||||
|
||||
ch = [mult * base_channels for mult in channel_multipliers]
|
||||
self.num_up_blocks = len(ch) - 1
|
||||
assert len(num_res_blocks) == self.num_up_blocks + 2
|
||||
|
||||
blocks = []
|
||||
|
||||
first_block = [
|
||||
nn.Conv3d(latent_dim, ch[-1], kernel_size=(1, 1, 1))
|
||||
] # Input layer.
|
||||
# First set of blocks preserve channel count.
|
||||
for _ in range(num_res_blocks[-1]):
|
||||
first_block.append(
|
||||
block_fn(
|
||||
ch[-1],
|
||||
has_attention=has_attention[-1],
|
||||
causal=causal,
|
||||
**block_kwargs,
|
||||
)
|
||||
)
|
||||
blocks.append(nn.Sequential(*first_block))
|
||||
|
||||
assert len(temporal_expansions) == len(spatial_expansions) == self.num_up_blocks
|
||||
assert len(num_res_blocks) == len(has_attention) == self.num_up_blocks + 2
|
||||
|
||||
upsample_block_fn = CausalUpsampleBlock
|
||||
|
||||
for i in range(self.num_up_blocks):
|
||||
block = upsample_block_fn(
|
||||
ch[-i - 1],
|
||||
ch[-i - 2],
|
||||
num_res_blocks=num_res_blocks[-i - 2],
|
||||
has_attention=has_attention[-i - 2],
|
||||
temporal_expansion=temporal_expansions[-i - 1],
|
||||
spatial_expansion=spatial_expansions[-i - 1],
|
||||
causal=causal,
|
||||
**block_kwargs,
|
||||
)
|
||||
blocks.append(block)
|
||||
|
||||
assert not output_norm
|
||||
|
||||
# Last block. Preserve channel count.
|
||||
last_block = []
|
||||
for _ in range(num_res_blocks[0]):
|
||||
last_block.append(
|
||||
block_fn(
|
||||
ch[0], has_attention=has_attention[0], causal=causal, **block_kwargs
|
||||
)
|
||||
)
|
||||
blocks.append(nn.Sequential(*last_block))
|
||||
|
||||
self.blocks = nn.ModuleList(blocks)
|
||||
self.output_proj = Conv1x1(ch[0], out_channels)
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
x: Latent tensor. Shape: [B, input_channels, t, h, w]. Scaled [-1, 1].
|
||||
|
||||
Returns:
|
||||
x: Reconstructed video tensor. Shape: [B, C, T, H, W]. Scaled to [-1, 1].
|
||||
T + 1 = (t - 1) * 4.
|
||||
H = h * 16, W = w * 16.
|
||||
"""
|
||||
for block in self.blocks:
|
||||
x = block(x)
|
||||
|
||||
if self.output_nonlinearity == "silu":
|
||||
x = F.silu(x, inplace=not self.training)
|
||||
else:
|
||||
assert (
|
||||
not self.output_nonlinearity
|
||||
) # StyleGAN3 omits the to-RGB nonlinearity.
|
||||
|
||||
return self.output_proj(x).contiguous()
|
||||
|
||||
|
||||
def make_broadcastable(
|
||||
tensor: torch.Tensor,
|
||||
axis: int,
|
||||
ndim: int,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Reshapes the input tensor to have singleton dimensions in all axes except the specified axis.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): The tensor to reshape. Typically 1D.
|
||||
axis (int): The axis along which the tensor should retain its original size.
|
||||
ndim (int): The total number of dimensions the reshaped tensor should have.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The reshaped tensor with shape suitable for broadcasting.
|
||||
"""
|
||||
if tensor.dim() != 1:
|
||||
raise ValueError(f"Expected tensor to be 1D, but got {tensor.dim()}D tensor.")
|
||||
|
||||
axis = (axis + ndim) % ndim # Ensure the axis is within the tensor dimensions
|
||||
shape = [1] * ndim # Start with all dimensions as 1
|
||||
shape[axis] = tensor.size(0) # Set the specified axis to the size of the tensor
|
||||
return tensor.view(*shape)
|
||||
|
||||
|
||||
def blend(a: torch.Tensor, b: torch.Tensor, axis: int) -> torch.Tensor:
|
||||
"""
|
||||
Blends two tensors `a` and `b` along the specified axis using linear interpolation.
|
||||
|
||||
Args:
|
||||
a (torch.Tensor): The first tensor.
|
||||
b (torch.Tensor): The second tensor. Must have the same shape as `a`.
|
||||
axis (int): The axis along which to perform the blending.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The blended tensor.
|
||||
"""
|
||||
assert (
|
||||
a.shape == b.shape
|
||||
), f"Tensors must have the same shape, got {a.shape} and {b.shape}"
|
||||
steps = a.size(axis)
|
||||
|
||||
# Create a weight tensor that linearly interpolates from 0 to 1
|
||||
start = 1 / (steps + 1)
|
||||
end = steps / (steps + 1)
|
||||
weight = torch.linspace(start, end, steps=steps, device=a.device, dtype=a.dtype)
|
||||
|
||||
# Make the weight tensor broadcastable across all dimensions
|
||||
weight = make_broadcastable(weight, axis, a.dim())
|
||||
|
||||
# Perform the blending
|
||||
return a * (1 - weight) + b * weight
|
||||
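# Illustrative example, not part of the original file: blending zeros with ones over a
# 4-element axis produces the interpolation weights 1/5 .. 4/5. Demonstration only;
# never called.
def _blend_example():
    a, b = torch.zeros(1, 4), torch.ones(1, 4)
    out = blend(a, b, axis=-1)
    assert torch.allclose(out, torch.tensor([[0.2, 0.4, 0.6, 0.8]]))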
|
||||
|
||||
def blend_horizontal(a: torch.Tensor, b: torch.Tensor, overlap: int) -> torch.Tensor:
|
||||
if overlap == 0:
|
||||
return torch.cat([a, b], dim=-1)
|
||||
|
||||
assert a.size(-1) >= overlap
|
||||
assert b.size(-1) >= overlap
|
||||
a_left, a_overlap = a[..., :-overlap], a[..., -overlap:]
|
||||
b_overlap, b_right = b[..., :overlap], b[..., overlap:]
|
||||
return torch.cat([a_left, blend(a_overlap, b_overlap, -1), b_right], dim=-1)
|
||||
|
||||
|
||||
def blend_vertical(a: torch.Tensor, b: torch.Tensor, overlap: int) -> torch.Tensor:
|
||||
if overlap == 0:
|
||||
return torch.cat([a, b], dim=-2)
|
||||
|
||||
assert a.size(-2) >= overlap
|
||||
assert b.size(-2) >= overlap
|
||||
a_top, a_overlap = a[..., :-overlap, :], a[..., -overlap:, :]
|
||||
b_overlap, b_bottom = b[..., :overlap, :], b[..., overlap:, :]
|
||||
return torch.cat([a_top, blend(a_overlap, b_overlap, -2), b_bottom], dim=-2)
|
||||
|
||||
|
||||
def nearest_multiple(x: int, multiple: int) -> int:
|
||||
return round(x / multiple) * multiple
|
||||
|
||||
|
||||
def apply_tiled(
|
||||
fn: Callable[[torch.Tensor], torch.Tensor],
|
||||
x: torch.Tensor,
|
||||
num_tiles_w: int,
|
||||
num_tiles_h: int,
|
||||
overlap: int = 0, # Number of pixel of overlap between adjacent tiles.
|
||||
# Use a factor of 2 times the latent downsample factor.
|
||||
min_block_size: int = 1, # Minimum number of pixels in each dimension when subdividing.
|
||||
):
|
||||
if num_tiles_w == 1 and num_tiles_h == 1:
|
||||
return fn(x)
|
||||
|
||||
assert (
|
||||
num_tiles_w & (num_tiles_w - 1) == 0
|
||||
), f"num_tiles_w={num_tiles_w} must be a power of 2"
|
||||
assert (
|
||||
num_tiles_h & (num_tiles_h - 1) == 0
|
||||
), f"num_tiles_h={num_tiles_h} must be a power of 2"
|
||||
|
||||
H, W = x.shape[-2:]
|
||||
assert H % min_block_size == 0
|
||||
assert W % min_block_size == 0
|
||||
ov = overlap // 2
|
||||
assert ov % min_block_size == 0
|
||||
|
||||
if num_tiles_w >= 2:
|
||||
# Subdivide horizontally.
|
||||
half_W = nearest_multiple(W // 2, min_block_size)
|
||||
left = x[..., :, : half_W + ov]
|
||||
right = x[..., :, half_W - ov :]
|
||||
|
||||
assert num_tiles_w % 2 == 0, f"num_tiles_w={num_tiles_w} must be even"
|
||||
left = apply_tiled(
|
||||
fn, left, num_tiles_w // 2, num_tiles_h, overlap, min_block_size
|
||||
)
|
||||
right = apply_tiled(
|
||||
fn, right, num_tiles_w // 2, num_tiles_h, overlap, min_block_size
|
||||
)
|
||||
if left is None or right is None:
|
||||
return None
|
||||
|
||||
# If `fn` changed the resolution, adjust the overlap.
|
||||
resample_factor = left.size(-1) / (half_W + ov)
|
||||
out_overlap = int(overlap * resample_factor)
|
||||
|
||||
return blend_horizontal(left, right, out_overlap)
|
||||
|
||||
if num_tiles_h >= 2:
|
||||
# Subdivide vertically.
|
||||
half_H = nearest_multiple(H // 2, min_block_size)
|
||||
top = x[..., : half_H + ov, :]
|
||||
bottom = x[..., half_H - ov :, :]
|
||||
|
||||
assert num_tiles_h % 2 == 0, f"num_tiles_h={num_tiles_h} must be even"
|
||||
top = apply_tiled(
|
||||
fn, top, num_tiles_w, num_tiles_h // 2, overlap, min_block_size
|
||||
)
|
||||
bottom = apply_tiled(
|
||||
fn, bottom, num_tiles_w, num_tiles_h // 2, overlap, min_block_size
|
||||
)
|
||||
if top is None or bottom is None:
|
||||
return None
|
||||
|
||||
# If `fn` changed the resolution, adjust the overlap.
|
||||
resample_factor = top.size(-2) / (half_H + ov)
|
||||
out_overlap = int(overlap * resample_factor)
|
||||
|
||||
return blend_vertical(top, bottom, out_overlap)
|
||||
|
||||
raise ValueError(f"Invalid num_tiles_w={num_tiles_w} and num_tiles_h={num_tiles_h}")
|
||||
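# Illustrative sketch, not part of the original file: apply a function over a 2x2 grid of
# spatial tiles with 8 pixels of overlap; the overlapping strips are linearly blended so
# the stitched output keeps the input resolution. Demonstration only; never called.
def _apply_tiled_example():
    x = torch.randn(1, 3, 4, 64, 64)
    out = apply_tiled(lambda t: t * 2.0, x, num_tiles_w=2, num_tiles_h=2, overlap=8, min_block_size=4)
    assert out.shape == x.shape
    assert torch.allclose(out, x * 2.0)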
356
nodes.py
Normal file
@ -0,0 +1,356 @@
|
||||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import folder_paths
|
||||
import comfy.model_management as mm
|
||||
from comfy.utils import ProgressBar, load_torch_file
|
||||
from einops import rearrange
|
||||
|
||||
from contextlib import nullcontext
|
||||
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import json
|
||||
|
||||
import logging
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
from .mochi_preview.t2v_synth_mochi import T2VSynthMochiModel
|
||||
from .mochi_preview.vae.model import Decoder
|
||||
|
||||
script_directory = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None):
|
||||
if linear_steps is None:
|
||||
linear_steps = num_steps // 2
|
||||
linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)]
|
||||
threshold_noise_step_diff = linear_steps - threshold_noise * num_steps
|
||||
quadratic_steps = num_steps - linear_steps
|
||||
quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps ** 2)
|
||||
linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps ** 2)
|
||||
const = quadratic_coef * (linear_steps ** 2)
|
||||
quadratic_sigma_schedule = [
|
||||
quadratic_coef * (i ** 2) + linear_coef * i + const
|
||||
for i in range(linear_steps, num_steps)
|
||||
]
|
||||
sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0]
|
||||
sigma_schedule = [1.0 - x for x in sigma_schedule]
|
||||
return sigma_schedule
|
||||
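# Illustrative check, not part of the original file: the schedule has steps + 1 entries
# and runs from 1.0 down to 0.0, which is what T2VSynthMochiModel.run expects for its
# sigma_schedule length assert. Demonstration only; never called.
def _linear_quadratic_schedule_example():
    sched = linear_quadratic_schedule(50, 0.025)
    assert len(sched) == 51
    assert sched[0] == 1.0 and sched[-1] == 0.0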
|
||||
class DownloadAndLoadMochiModel:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"model": (
|
||||
[
|
||||
"mochi_preview_dit_fp8_e4m3fn.safetensors",
|
||||
],
|
||||
),
|
||||
"vae": (
|
||||
[
|
||||
"mochi_preview_vae_bf16.safetensors",
|
||||
],
|
||||
),
|
||||
"precision": (["fp8_e4m3fn","fp16", "fp32", "bf16"],
|
||||
{"default": "bf16", "tooltip": "official recommendation is that 2b model should be fp16, 5b model should be bf16"}
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("MOCHIMODEL", "MOCHIVAE",)
|
||||
RETURN_NAMES = ("mochi_model", "mochi_vae" )
|
||||
FUNCTION = "loadmodel"
|
||||
CATEGORY = "MochiWrapper"
|
||||
DESCRIPTION = "Downloads and loads the selected Mochi model from Huggingface"
|
||||
|
||||
def loadmodel(self, model, vae, precision):
|
||||
|
||||
device = mm.get_torch_device()
|
||||
offload_device = mm.unet_offload_device()
|
||||
mm.soft_empty_cache()
|
||||
|
||||
dtype = {"fp8_e4m3fn": torch.float8_e4m3fn, "bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
|
||||
|
||||
# Transformer model
|
||||
model_download_path = os.path.join(folder_paths.models_dir, 'diffusion_models', 'mochi')
|
||||
model_path = os.path.join(model_download_path, model)
|
||||
|
||||
repo_id = "kijai/Mochi_preview_comfy"
|
||||
|
||||
if not os.path.exists(model_path):
|
||||
log.info(f"Downloading mochi model to: {model_path}")
|
||||
from huggingface_hub import snapshot_download
|
||||
snapshot_download(
|
||||
repo_id=repo_id,
|
||||
allow_patterns=[f"*{model}*"],
|
||||
local_dir=model_download_path,
|
||||
local_dir_use_symlinks=False,
|
||||
)
|
||||
# VAE
|
||||
vae_download_path = os.path.join(folder_paths.models_dir, 'vae', 'mochi')
|
||||
vae_path = os.path.join(vae_download_path, vae)
|
||||
|
||||
if not os.path.exists(vae_path):
|
||||
log.info(f"Downloading mochi VAE to: {vae_path}")
|
||||
from huggingface_hub import snapshot_download
|
||||
snapshot_download(
|
||||
repo_id=repo_id,
|
||||
allow_patterns=[f"*{vae}*"],
|
||||
local_dir=vae_download_path,
|
||||
local_dir_use_symlinks=False,
|
||||
)
|
||||
|
||||
model = T2VSynthMochiModel(
|
||||
device_id=0,
|
||||
vae_stats_path=os.path.join(script_directory, "configs", "vae_stats.json"),
|
||||
dit_checkpoint_path=model_path,
|
||||
)
|
||||
vae = Decoder(
|
||||
out_channels=3,
|
||||
base_channels=128,
|
||||
channel_multipliers=[1, 2, 4, 6],
|
||||
temporal_expansions=[1, 2, 3],
|
||||
spatial_expansions=[2, 2, 2],
|
||||
num_res_blocks=[3, 3, 4, 6, 3],
|
||||
latent_dim=12,
|
||||
has_attention=[False, False, False, False, False],
|
||||
padding_mode="replicate",
|
||||
output_norm=False,
|
||||
nonlinearity="silu",
|
||||
output_nonlinearity="silu",
|
||||
causal=True,
|
||||
)
|
||||
decoder_sd = load_torch_file(vae_path)
|
||||
vae.load_state_dict(decoder_sd, strict=True)
|
||||
vae.eval().to(torch.bfloat16).to("cpu")
|
||||
del decoder_sd
|
||||
|
||||
return (model, vae,)
|
||||
|
||||
class MochiTextEncode:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "clip": ("CLIP",),
                "prompt": ("STRING", {"default": "", "multiline": True}),
            },
            "optional": {
                "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                "force_offload": ("BOOLEAN", {"default": True}),
            }
        }

    RETURN_TYPES = ("CONDITIONING",)
    RETURN_NAMES = ("conditioning",)
    FUNCTION = "process"
    CATEGORY = "MochiWrapper"

    def process(self, clip, prompt, strength=1.0, force_offload=True):
        max_tokens = 256
        load_device = mm.text_encoder_device()
        offload_device = mm.text_encoder_offload_device()
        #print(clip.tokenizer.t5xxl)
        # Configure ComfyUI's T5-XXL tokenizer/encoder the way Mochi expects:
        # pad to a fixed 256-token length and return the attention mask
        clip.tokenizer.t5xxl.pad_to_max_length = True
        clip.tokenizer.t5xxl.max_length = max_tokens
        clip.cond_stage_model.t5xxl.return_attention_masks = True
        clip.cond_stage_model.t5xxl.enable_attention_masks = True
        clip.cond_stage_model.t5_attention_mask = True
        clip.cond_stage_model.to(load_device)
        tokens = clip.tokenizer.t5xxl.tokenize_with_weights(prompt, return_word_ids=True)

        embeds, _, attention_mask = clip.cond_stage_model.t5xxl.encode_token_weights(tokens)

        if embeds.shape[1] > max_tokens:
            raise ValueError(f"Prompt is too long: got {embeds.shape[1]} tokens, the maximum supported is {max_tokens}")
        embeds *= strength
        if force_offload:
            clip.cond_stage_model.to(offload_device)

        t5_embeds = {
            "embeds": embeds,
            "attention_mask": attention_mask["attention_mask"].bool(),
        }
        return (t5_embeds, )


class MochiSampler:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MOCHIMODEL",),
                "positive": ("CONDITIONING", ),
                "negative": ("CONDITIONING", ),
                "width": ("INT", {"default": 848, "min": 128, "max": 2048, "step": 8}),
                "height": ("INT", {"default": 480, "min": 128, "max": 2048, "step": 8}),
                "num_frames": ("INT", {"default": 49, "min": 7, "max": 1024, "step": 6}),
                "steps": ("INT", {"default": 50, "min": 2}),
                "cfg": ("FLOAT", {"default": 4.5, "min": 0.0, "max": 30.0, "step": 0.01}),
                "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
            },
        }

RETURN_TYPES = ("LATENT",)
|
||||
RETURN_NAMES = ("model", "samples",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "MochiWrapper"
|
||||
|
||||
def process(self, model, positive, negative, steps, cfg, seed, height, width, num_frames):
|
||||
mm.soft_empty_cache()
|
||||
|
||||
device = mm.get_torch_device()
|
||||
offload_device = mm.unet_offload_device()
|
||||
|
||||
args = {
|
||||
"height": height,
|
||||
"width": width,
|
||||
"num_frames": num_frames,
|
||||
"mochi_args": {
|
||||
"sigma_schedule": linear_quadratic_schedule(steps, 0.025),
|
||||
"cfg_schedule": [cfg] * steps,
|
||||
"num_inference_steps": steps,
|
||||
"batch_cfg": False,
|
||||
},
|
||||
"positive_embeds": positive,
|
||||
"negative_embeds": negative,
|
||||
"seed": seed,
|
||||
}
|
||||
latents = model.run(args, stream_results=False)
|
||||
|
||||
mm.soft_empty_cache()
|
||||
|
||||
return ({"samples": latents},)
|
||||
|
||||
class MochiDecode:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "vae": ("MOCHIVAE",),
                "samples": ("LATENT", ),
                "enable_vae_tiling": ("BOOLEAN", {"default": False, "tooltip": "Drastically reduces memory use but may introduce seams"}),
            },
            "optional": {
                "tile_sample_min_height": ("INT", {"default": 240, "min": 16, "max": 2048, "step": 8, "tooltip": "Minimum tile height, default is half the height"}),
                "tile_sample_min_width": ("INT", {"default": 424, "min": 16, "max": 2048, "step": 8, "tooltip": "Minimum tile width, default is half the width"}),
                "tile_overlap_factor_height": ("FLOAT", {"default": 0.1666, "min": 0.0, "max": 1.0, "step": 0.001}),
                "tile_overlap_factor_width": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001}),
                "auto_tile_size": ("BOOLEAN", {"default": True, "tooltip": "Auto size based on height and width, default is half the size"}),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("images",)
    FUNCTION = "decode"
    CATEGORY = "MochiWrapper"

    # Defaults mirror the optional widget defaults so the method also works when those inputs are absent
    def decode(self, vae, samples, enable_vae_tiling, tile_sample_min_height=240, tile_sample_min_width=424,
               tile_overlap_factor_height=0.1666, tile_overlap_factor_width=0.2, auto_tile_size=True):
        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()
        samples = samples["samples"]

        # Blend the bottom edge of tile a into the top edge of tile b (height axis)
        def blend_v(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
            blend_extent = min(a.shape[3], b.shape[3], blend_extent)
            for y in range(blend_extent):
                b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
                    y / blend_extent
                )
            return b

        # Blend the right edge of tile a into the left edge of tile b (width axis)
        def blend_h(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
            blend_extent = min(a.shape[4], b.shape[4], blend_extent)
            for x in range(blend_extent):
                b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
                    x / blend_extent
                )
            return b

        self.tile_overlap_factor_height = tile_overlap_factor_height if not auto_tile_size else 1 / 6
        self.tile_overlap_factor_width = tile_overlap_factor_width if not auto_tile_size else 1 / 5
        # Supported frame counts (6*k + 1): 7, 13, 19, 25, 31, 37, 43, 49, 55, 61, 67, 73, 79, 85, 91, 97, 103, 109, 115, 121, 127, 133, 139, 145, 151, 157, 163, 169, 175, 181, 187, 193, 199
        self.num_latent_frames_batch_size = 6

        self.tile_sample_min_height = tile_sample_min_height if not auto_tile_size else samples.shape[3] // 2 * 8
        self.tile_sample_min_width = tile_sample_min_width if not auto_tile_size else samples.shape[4] // 2 * 8

        self.tile_latent_min_height = int(self.tile_sample_min_height / 8)
        self.tile_latent_min_width = int(self.tile_sample_min_width / 8)

        vae.to(device)
        with torch.amp.autocast("cuda", dtype=torch.bfloat16):
            if not enable_vae_tiling:
                samples = vae(samples)
            else:
                batch_size, num_channels, num_frames, height, width = samples.shape
                overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height))
                overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width))
                blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height)
                blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
                row_limit_height = self.tile_sample_min_height - blend_extent_height
                row_limit_width = self.tile_sample_min_width - blend_extent_width
                frame_batch_size = self.num_latent_frames_batch_size

                # Split z into overlapping tiles and decode them separately.
                # The tiles have an overlap to avoid seams between tiles.
                rows = []
                for i in range(0, height, overlap_height):
                    row = []
                    for j in range(0, width, overlap_width):
                        time = []
                        for k in range(num_frames // frame_batch_size):
                            remaining_frames = num_frames % frame_batch_size
                            start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
                            end_frame = frame_batch_size * (k + 1) + remaining_frames
                            tile = samples[
                                :,
                                :,
                                start_frame:end_frame,
                                i : i + self.tile_latent_min_height,
                                j : j + self.tile_latent_min_width,
                            ]
                            tile = vae(tile)
                            time.append(tile)
                        row.append(torch.cat(time, dim=2))
                    rows.append(row)

                result_rows = []
                for i, row in enumerate(rows):
                    result_row = []
                    for j, tile in enumerate(row):
                        # blend the above tile and the left tile
                        # to the current tile and add the current tile to the result row
                        if i > 0:
                            tile = blend_v(rows[i - 1][j], tile, blend_extent_height)
                        if j > 0:
                            tile = blend_h(row[j - 1], tile, blend_extent_width)
                        result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
                    result_rows.append(torch.cat(result_row, dim=4))

                samples = torch.cat(result_rows, dim=3)
        vae.to(offload_device)
        #print("samples", samples.shape, samples.dtype, samples.device)

        samples = samples.float()
        samples = (samples + 1.0) / 2.0
        samples.clamp_(0.0, 1.0)

        frames = rearrange(samples, "b c t h w -> (t b) h w c").cpu().float()
        #print(frames.shape)

        return (frames,)


NODE_CLASS_MAPPINGS = {
    "DownloadAndLoadMochiModel": DownloadAndLoadMochiModel,
    "MochiSampler": MochiSampler,
    "MochiDecode": MochiDecode,
    "MochiTextEncode": MochiTextEncode,
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "DownloadAndLoadMochiModel": "(Down)load Mochi Model",
    "MochiSampler": "Mochi Sampler",
    "MochiDecode": "Mochi Decode",
    "MochiTextEncode": "Mochi TextEncode",
}
18
readme.md
Normal file
18
readme.md
Normal file
@ -0,0 +1,18 @@
# ComfyUI wrapper nodes for Mochi video gen: https://github.com/genmoai/models

# WORK IN PROGRESS

## Requires flash_attn!

Depending on the frame count, generation can fit under 20GB of VRAM. VAE decoding is the heaviest step; there is an experimental tiled decoder (adapted from the CogVideoX diffusers code) which allows higher frame counts. The highest I have managed so far is 97 frames with the default tile size (a 2x2 grid).

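For reference, here is a minimal sketch of how the tiled decoder picks its tile size when `auto_tile_size` is enabled; it mirrors the logic in the MochiDecode node in `nodes.py`, and the 848x480 resolution is just an example:

```python
# Sketch of MochiDecode's auto tile sizing (illustrative; numbers assume 848x480 output).
# Latents are 1/8 of the pixel resolution: 848x480 -> 106x60 latent cells.
latent_height, latent_width = 60, 106

# auto_tile_size=True: each tile covers half the output in each dimension
tile_sample_min_height = latent_height // 2 * 8   # 240 pixels
tile_sample_min_width = latent_width // 2 * 8     # 424 pixels

tile_latent_min_height = tile_sample_min_height // 8  # 30 latent rows
tile_latent_min_width = tile_sample_min_width // 8    # 53 latent columns

# Overlap factors of 1/6 (height) and 1/5 (width) set the stride between tiles
overlap_height = int(tile_latent_min_height * (1 - 1 / 6))  # 25
overlap_width = int(tile_latent_min_width * (1 - 1 / 5))    # 42

print(tile_sample_min_height, tile_sample_min_width, overlap_height, overlap_width)
```
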
Models:

https://huggingface.co/Kijai/Mochi_preview_comfy/tree/main

model to: `ComfyUI/models/diffusion_models/mochi`

vae to: `ComfyUI/models/vae/mochi`

There is an autodownload node (a regular loader node will be added as well).
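
If you would rather fetch the files by hand, something along these lines should work; this is a sketch that assumes `huggingface_hub` is installed and that you run it from your ComfyUI directory, using the same `snapshot_download` call as the autodownload node:

```python
from huggingface_hub import snapshot_download

# Download the transformer and the VAE into the folders listed above
for patterns, target in [
    (["*mochi_preview_dit_fp8_e4m3fn*"], "models/diffusion_models/mochi"),
    (["*mochi_preview_vae_bf16*"], "models/vae/mochi"),
]:
    snapshot_download(
        repo_id="Kijai/Mochi_preview_comfy",
        allow_patterns=patterns,
        local_dir=target,
        local_dir_use_symlinks=False,
    )
```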