diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..a295864
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+*.pyc
+__pycache__
diff --git a/mochi_preview/__pycache__/t2v_synth_mochi.cpython-312.pyc b/mochi_preview/__pycache__/t2v_synth_mochi.cpython-312.pyc
index 4b28b7e..b7e593d 100644
Binary files a/mochi_preview/__pycache__/t2v_synth_mochi.cpython-312.pyc and b/mochi_preview/__pycache__/t2v_synth_mochi.cpython-312.pyc differ
diff --git a/mochi_preview/dit/joint_model/__pycache__/asymm_models_joint.cpython-312.pyc b/mochi_preview/dit/joint_model/__pycache__/asymm_models_joint.cpython-312.pyc
index 63e495a..68ad0e8 100644
Binary files a/mochi_preview/dit/joint_model/__pycache__/asymm_models_joint.cpython-312.pyc and b/mochi_preview/dit/joint_model/__pycache__/asymm_models_joint.cpython-312.pyc differ
diff --git a/mochi_preview/t2v_synth_mochi.py b/mochi_preview/t2v_synth_mochi.py
index 066998f..7ceb8d9 100644
--- a/mochi_preview/t2v_synth_mochi.py
+++ b/mochi_preview/t2v_synth_mochi.py
@@ -1,78 +1,22 @@
 import json
-import os
 import random
-from functools import partial
 from typing import Dict, List
 
 from safetensors.torch import load_file
 
 import numpy as np
 import torch
-import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.data
-import yaml
-from einops import rearrange, repeat
-from omegaconf import OmegaConf
 from torch import nn
-from torch.distributed.fsdp import (
-    BackwardPrefetch,
-    MixedPrecision,
-    ShardingStrategy,
-)
-from torch.distributed.fsdp import (
-    FullyShardedDataParallel as FSDP,
-)
-from torch.distributed.fsdp.wrap import (
-    lambda_auto_wrap_policy,
-    transformer_auto_wrap_policy,
-)
-from transformers import T5EncoderModel, T5Tokenizer
-from transformers.models.t5.modeling_t5 import T5Block
 
 from .dit.joint_model.context_parallel import get_cp_rank_size
 from .utils import Timer
 from tqdm import tqdm
 from comfy.utils import ProgressBar
 
-T5_MODEL = "weights/T5"
 MAX_T5_TOKEN_LENGTH = 256
 
 
-class T5_Tokenizer:
-    """Wrapper around Hugging Face tokenizer for T5
-
-    Args:
-        model_name(str): Name of tokenizer to load.
-    """
-
-    def __init__(self):
-        self.tokenizer = T5Tokenizer.from_pretrained(T5_MODEL, legacy=False)
-
-    def __call__(self, prompt, padding, truncation, return_tensors, max_length=None):
-        """
-        Args:
-            prompt (str): The input text to tokenize.
-            padding (str): The padding strategy.
-            truncation (bool): Flag indicating whether to truncate the tokens.
-            return_tensors (str): Flag indicating whether to return tensors.
-            max_length (int): The max length of the tokens.
-        """
-        assert (
-            not max_length or max_length == MAX_T5_TOKEN_LENGTH
-        ), f"Max length must be {MAX_T5_TOKEN_LENGTH} for T5."
-
-        tokenized_output = self.tokenizer(
-            prompt,
-            padding=padding,
-            max_length=MAX_T5_TOKEN_LENGTH,  # Max token length for T5 is set here.
-            truncation=truncation,
-            return_tensors=return_tensors,
-            return_attention_mask=True,
-        )
-
-        return tokenized_output
-
-
 def unnormalize_latents(
     z: torch.Tensor,
     mean: torch.Tensor,
@@ -93,27 +37,6 @@ def unnormalize_latents(
     assert z.size(1) == mean.size(0) == std.size(0)
     return z * std.to(z) + mean.to(z)
 
-
-def setup_fsdp_sync(model, device_id, *, param_dtype, auto_wrap_policy) -> FSDP:
-    model = FSDP(
-        model,
-        sharding_strategy=ShardingStrategy.FULL_SHARD,
-        mixed_precision=MixedPrecision(
-            param_dtype=param_dtype,
-            reduce_dtype=torch.float32,
-            buffer_dtype=torch.float32,
-        ),
-        auto_wrap_policy=auto_wrap_policy,
-        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
-        limit_all_gathers=True,
-        device_id=device_id,
-        sync_module_states=True,
-        use_orig_params=True,
-    )
-    torch.cuda.synchronize()
-    return model
-
-
 def compute_packed_indices(
     N: int,
     text_mask: List[torch.Tensor],
@@ -189,12 +112,6 @@ class T2VSynthMochiModel:
         t = Timer()
         self.device = torch.device(device_id)
 
-        #self.t5_tokenizer = T5_Tokenizer()
-
-        # with t("load_text_encs"):
-        #     t5_enc = T5EncoderModel.from_pretrained(T5_MODEL)
-        #     self.t5_enc = t5_enc.eval().to(torch.bfloat16).to("cpu")
-
         with t("construct_dit"):
             from .dit.joint_model.asymm_models_joint import (
                 AsymmDiTJoint,
@@ -223,15 +140,15 @@ class T2VSynthMochiModel:
 
         model.load_state_dict(load_file(dit_checkpoint_path))
 
-        with t("fsdp_dit"):
-            self.dit = model
-            self.dit.eval()
-            for name, param in self.dit.named_parameters():
-                params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"}
-                if not any(keyword in name for keyword in params_to_keep):
-                    param.data = param.data.to(torch.float8_e4m3fn)
-                else:
-                    param.data = param.data.to(torch.bfloat16)
+        #with t("fsdp_dit"):
+        self.dit = model
+        self.dit.eval()
+        for name, param in self.dit.named_parameters():
+            params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"}
+            if not any(keyword in name for keyword in params_to_keep):
+                param.data = param.data.to(torch.float8_e4m3fn)
+            else:
+                param.data = param.data.to(torch.bfloat16)
 
         vae_stats = json.load(open(vae_stats_path))
diff --git a/nodes.py b/nodes.py
index e6c5459..ac77a53 100644
--- a/nodes.py
+++ b/nodes.py
@@ -1,17 +1,10 @@
 import os
 import torch
-import torch.nn as nn
 import folder_paths
 import comfy.model_management as mm
 from comfy.utils import ProgressBar, load_torch_file
 from einops import rearrange
-
-from contextlib import nullcontext
-
-from PIL import Image
-import numpy as np
-import json
-
+from tqdm import tqdm
 import logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 log = logging.getLogger(__name__)
@@ -295,11 +288,11 @@ class MochiDecode:
         # Split z into overlapping tiles and decode them separately.
         # The tiles have an overlap to avoid seams between tiles.
         rows = []
-        for i in range(0, height, overlap_height):
+        for i in tqdm(range(0, height, overlap_height), desc="Processing rows"):
             row = []
-            for j in range(0, width, overlap_width):
+            for j in tqdm(range(0, width, overlap_width), desc="Processing columns", leave=False):
                 time = []
-                for k in range(num_frames // frame_batch_size):
+                for k in tqdm(range(num_frames // frame_batch_size), desc="Processing frames", leave=False):
                     remaining_frames = num_frames % frame_batch_size
                     start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
                     end_frame = frame_batch_size * (k + 1) + remaining_frames
@@ -316,9 +309,9 @@ class MochiDecode:
             rows.append(row)
 
         result_rows = []
-        for i, row in enumerate(rows):
+        for i, row in enumerate(tqdm(rows, desc="Blending rows")):
             result_row = []
-            for j, tile in enumerate(row):
+            for j, tile in enumerate(tqdm(row, desc="Blending tiles", leave=False)):
                 # blend the above tile and the left tile
                 # to the current tile and add the current tile to the result row
                 if i > 0:
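
The t2v_synth_mochi.py change replaces the deleted FSDP setup with an in-place precision cast: the bulk of the DiT weights are stored as fp8 (e4m3) to cut VRAM roughly in half, while numerically sensitive tensors (embedders, positional frequencies, norms, and the t5 projection) stay in bf16. A minimal standalone sketch of the same pattern, assuming PyTorch >= 2.1 for torch.float8_e4m3fn and an arbitrary module (the helper name cast_weights_fp8 is illustrative, not from the repository):

    import torch
    import torch.nn as nn

    def cast_weights_fp8(model: nn.Module,
                         keep_bf16=("t_embedder", "x_embedder", "pos_frequencies", "t5", "norm")):
        # Cast most parameter storage to fp8 (e4m3); keep parameters whose
        # names contain any of the listed substrings in bf16.
        for name, param in model.named_parameters():
            if any(keyword in name for keyword in keep_bf16):
                param.data = param.data.to(torch.bfloat16)
            else:
                param.data = param.data.to(torch.float8_e4m3fn)
        return model

Note that fp8 tensors are storage only for most standard PyTorch ops, so the model's forward pass has to upcast them (e.g. to bf16) before matmuls, or use dedicated scaled-matmul kernels.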
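The frame loop in MochiDecode splits num_frames into frame_batch_size chunks along the time axis, folding any remainder into the first batch so every later batch is exactly frame_batch_size long. A small sketch of just that index arithmetic, with a worked example (the helper name frame_batches is illustrative):

    def frame_batches(num_frames: int, frame_batch_size: int):
        # Mirrors the loop above: the first batch absorbs
        # num_frames % frame_batch_size extra frames.
        remaining_frames = num_frames % frame_batch_size
        for k in range(num_frames // frame_batch_size):
            start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
            end_frame = frame_batch_size * (k + 1) + remaining_frames
            yield start_frame, end_frame

    # 28 frames in batches of 6 -> the first batch takes the 4 leftover frames
    assert list(frame_batches(28, 6)) == [(0, 10), (10, 16), (16, 22), (22, 28)]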
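The blending referenced by the "# blend the above tile and the left tile" comment happens outside the hunks shown here. In tiled VAE decoders of this style it is typically a linear crossfade over the overlap region; a sketch under that assumption (blend_v and blend_h are illustrative names, not necessarily the repository's actual helpers):

    import torch

    def blend_v(above: torch.Tensor, tile: torch.Tensor, overlap: int) -> torch.Tensor:
        # Crossfade the bottom `overlap` rows of the tile above into the
        # top `overlap` rows of the current tile (layout ..., H, W).
        for y in range(overlap):
            w = y / overlap
            tile[..., y, :] = above[..., -overlap + y, :] * (1 - w) + tile[..., y, :] * w
        return tile

    def blend_h(left: torch.Tensor, tile: torch.Tensor, overlap: int) -> torch.Tensor:
        # Same crossfade along the width axis for the tile to the left.
        for x in range(overlap):
            w = x / overlap
            tile[..., :, x] = left[..., :, -overlap + x] * (1 - w) + tile[..., :, x] * w
        return tile

Each tile's seam-facing edge is thus a weighted mix that ramps from the neighbour's pixels to its own, which is what hides the tile boundaries after cropping and concatenation.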