Author: kijai
Date: 2024-10-23 15:45:20 +03:00
Parent: b80cb4a691
Commit: 4efb7c85df

5 changed files with 17 additions and 105 deletions

.gitignore (new file)

@@ -0,0 +1,2 @@
+*.pyc
+__pycache__


@@ -1,78 +1,22 @@
import json
import os
import random
from functools import partial
from typing import Dict, List
from safetensors.torch import load_file
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import yaml
from einops import rearrange, repeat
from omegaconf import OmegaConf
from torch import nn
from torch.distributed.fsdp import (
    BackwardPrefetch,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
)
from torch.distributed.fsdp.wrap import (
    lambda_auto_wrap_policy,
    transformer_auto_wrap_policy,
)
from transformers import T5EncoderModel, T5Tokenizer
from transformers.models.t5.modeling_t5 import T5Block
from .dit.joint_model.context_parallel import get_cp_rank_size
from .utils import Timer
from tqdm import tqdm
from comfy.utils import ProgressBar
T5_MODEL = "weights/T5"
MAX_T5_TOKEN_LENGTH = 256
class T5_Tokenizer:
    """Wrapper around the Hugging Face tokenizer for T5, loaded from T5_MODEL."""

    def __init__(self):
        self.tokenizer = T5Tokenizer.from_pretrained(T5_MODEL, legacy=False)

    def __call__(self, prompt, padding, truncation, return_tensors, max_length=None):
        """
        Args:
            prompt (str): The input text to tokenize.
            padding (str): The padding strategy.
            truncation (bool): Whether to truncate the tokens.
            return_tensors (str): The tensor format to return (e.g. "pt").
            max_length (int): The maximum token length; must equal MAX_T5_TOKEN_LENGTH.
        """
        assert (
            not max_length or max_length == MAX_T5_TOKEN_LENGTH
        ), f"Max length must be {MAX_T5_TOKEN_LENGTH} for T5."
        tokenized_output = self.tokenizer(
            prompt,
            padding=padding,
            max_length=MAX_T5_TOKEN_LENGTH,  # Max token length for T5 is set here.
            truncation=truncation,
            return_tensors=return_tensors,
            return_attention_mask=True,
        )
        return tokenized_output
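For reference, the wrapper above is a thin veneer over the Hugging Face tokenizer; an equivalent direct call would look roughly like this (the prompt and the commented shapes are illustrative):

# Sketch of what T5_Tokenizer does under the hood (illustrative prompt).
tokenizer = T5Tokenizer.from_pretrained(T5_MODEL, legacy=False)
tokens = tokenizer(
    "a cinematic shot of a fox in the snow",
    padding="max_length",
    truncation=True,
    return_tensors="pt",
    max_length=MAX_T5_TOKEN_LENGTH,  # always 256 here; the assert enforces it
    return_attention_mask=True,
)
input_ids = tokens["input_ids"]            # shape (1, 256)
attention_mask = tokens["attention_mask"]  # 1 for real tokens, 0 for padding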
def unnormalize_latents(
    z: torch.Tensor,
    mean: torch.Tensor,
@@ -93,27 +37,6 @@ def unnormalize_latents(
    assert z.size(1) == mean.size(0) == std.size(0)
    return z * std.to(z) + mean.to(z)
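The inverse (whitening) direction is not shown in this diff, but assuming the same per-channel broadcastable mean/std, it would be the mirror image of the return statement above (a hypothetical sketch, not the repo's actual helper):

def normalize_latents(
    z: torch.Tensor,
    mean: torch.Tensor,
    std: torch.Tensor,
) -> torch.Tensor:
    # Hypothetical mirror of unnormalize_latents: whiten each latent channel,
    # assuming mean/std broadcast against z (e.g. shape (C, 1, 1, 1)).
    assert z.size(1) == mean.size(0) == std.size(0)
    return (z - mean.to(z)) / std.to(z)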
-def setup_fsdp_sync(model, device_id, *, param_dtype, auto_wrap_policy) -> FSDP:
-    model = FSDP(
-        model,
-        sharding_strategy=ShardingStrategy.FULL_SHARD,
-        mixed_precision=MixedPrecision(
-            param_dtype=param_dtype,
-            reduce_dtype=torch.float32,
-            buffer_dtype=torch.float32,
-        ),
-        auto_wrap_policy=auto_wrap_policy,
-        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
-        limit_all_gathers=True,
-        device_id=device_id,
-        sync_module_states=True,
-        use_orig_params=True,
-    )
-    torch.cuda.synchronize()
-    return model
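Before its removal, a caller would drive setup_fsdp_sync with an auto-wrap policy built from the imports at the top of the file, along these lines (a sketch; the wrapped encoder and device id are illustrative, not taken from this diff):

# Hypothetical call site: shard a T5 encoder, wrapping each T5Block
# as its own FSDP unit.
wrap_policy = partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={T5Block},
)
t5_enc = setup_fsdp_sync(
    t5_enc,                      # a constructed T5EncoderModel
    device_id=0,                 # illustrative local rank
    param_dtype=torch.bfloat16,
    auto_wrap_policy=wrap_policy,
)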
def compute_packed_indices(
    N: int,
    text_mask: List[torch.Tensor],
@@ -189,12 +112,6 @@ class T2VSynthMochiModel:
        t = Timer()
        self.device = torch.device(device_id)
-        #self.t5_tokenizer = T5_Tokenizer()
-        # with t("load_text_encs"):
-        #     t5_enc = T5EncoderModel.from_pretrained(T5_MODEL)
-        #     self.t5_enc = t5_enc.eval().to(torch.bfloat16).to("cpu")
with t("construct_dit"):
from .dit.joint_model.asymm_models_joint import (
AsymmDiTJoint,
@@ -223,7 +140,7 @@ class T2VSynthMochiModel:
            model.load_state_dict(load_file(dit_checkpoint_path))
-        with t("fsdp_dit"):
+        #with t("fsdp_dit"):
        self.dit = model
        self.dit.eval()
        for name, param in self.dit.named_parameters():


@@ -1,17 +1,10 @@
import os
import torch
import torch.nn as nn
import folder_paths
import comfy.model_management as mm
from comfy.utils import ProgressBar, load_torch_file
from einops import rearrange
from contextlib import nullcontext
from PIL import Image
import numpy as np
import json
from tqdm import tqdm
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
log = logging.getLogger(__name__)
@@ -295,11 +288,11 @@ class MochiDecode:
        # Split z into overlapping tiles and decode them separately.
        # The tiles have an overlap to avoid seams between tiles.
        rows = []
-        for i in range(0, height, overlap_height):
+        for i in tqdm(range(0, height, overlap_height), desc="Processing rows"):
            row = []
-            for j in range(0, width, overlap_width):
+            for j in tqdm(range(0, width, overlap_width), desc="Processing columns", leave=False):
                time = []
-                for k in range(num_frames // frame_batch_size):
+                for k in tqdm(range(num_frames // frame_batch_size), desc="Processing frames", leave=False):
                    remaining_frames = num_frames % frame_batch_size
                    start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
                    end_frame = frame_batch_size * (k + 1) + remaining_frames
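                    # Worked example of the indexing above (illustrative numbers):
                    # num_frames = 25, frame_batch_size = 6 -> remaining_frames = 1,
                    # and the first batch absorbs the remainder:
                    #   k = 0: start_frame = 0,  end_frame = 6*1 + 1 = 7
                    #   k = 1: start_frame = 7,  end_frame = 6*2 + 1 = 13
                    #   k = 2: start_frame = 13, end_frame = 19
                    #   k = 3: start_frame = 19, end_frame = 25  (all frames covered)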
@@ -316,9 +309,9 @@ class MochiDecode:
                rows.append(row)

        result_rows = []
-        for i, row in enumerate(rows):
+        for i, row in enumerate(tqdm(rows, desc="Blending rows")):
            result_row = []
-            for j, tile in enumerate(row):
+            for j, tile in enumerate(tqdm(row, desc="Blending tiles", leave=False)):
                # Blend the tile above and the tile to the left into the
                # current tile, then append the result to the current row.
                if i > 0:
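The blend applied under these branches typically crossfades each tile with its upper and left neighbors across the overlap region; a minimal sketch of the vertical case, modeled on the diffusers tiled-VAE helpers rather than this repo's exact code:

def blend_v(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    # Linearly crossfade the bottom `blend_extent` rows of tile `a`
    # into the top rows of tile `b` (tensors are ... x H x W).
    blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
    for y in range(blend_extent):
        w = y / blend_extent
        b[..., y, :] = a[..., -blend_extent + y, :] * (1 - w) + b[..., y, :] * w
    return b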