cleanup

parent b80cb4a691
commit 4efb7c85df
.gitignore (vendored, new file)
@@ -0,0 +1,2 @@
+*.pyc
+__pycache__
Binary file not shown.
Binary file not shown.
@@ -1,78 +1,22 @@
import json
import os
import random
from functools import partial
from typing import Dict, List

from safetensors.torch import load_file
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import yaml
from einops import rearrange, repeat
from omegaconf import OmegaConf
from torch import nn
from torch.distributed.fsdp import (
    BackwardPrefetch,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
)
from torch.distributed.fsdp.wrap import (
    lambda_auto_wrap_policy,
    transformer_auto_wrap_policy,
)
from transformers import T5EncoderModel, T5Tokenizer
from transformers.models.t5.modeling_t5 import T5Block

from .dit.joint_model.context_parallel import get_cp_rank_size
from .utils import Timer
from tqdm import tqdm
from comfy.utils import ProgressBar


T5_MODEL = "weights/T5"
MAX_T5_TOKEN_LENGTH = 256


class T5_Tokenizer:
    """Wrapper around the Hugging Face T5 tokenizer.

    The tokenizer is loaded from the fixed T5_MODEL path, so no
    arguments are required.
    """

    def __init__(self):
        self.tokenizer = T5Tokenizer.from_pretrained(T5_MODEL, legacy=False)

    def __call__(self, prompt, padding, truncation, return_tensors, max_length=None):
        """
        Args:
            prompt (str): The input text to tokenize.
            padding (str): The padding strategy.
            truncation (bool): Whether to truncate the tokens.
            return_tensors (str): Tensor format to return (e.g. "pt").
            max_length (int): Maximum token length; must equal
                MAX_T5_TOKEN_LENGTH if given.
        """
        assert (
            not max_length or max_length == MAX_T5_TOKEN_LENGTH
        ), f"Max length must be {MAX_T5_TOKEN_LENGTH} for T5."

        tokenized_output = self.tokenizer(
            prompt,
            padding=padding,
            max_length=MAX_T5_TOKEN_LENGTH,  # Max token length for T5 is set here.
            truncation=truncation,
            return_tensors=return_tensors,
            return_attention_mask=True,
        )

        return tokenized_output
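
As a usage sketch (hypothetical, not part of the diff; assumes the tokenizer files exist under weights/T5):

tok = T5_Tokenizer()
out = tok("a prompt", padding="max_length", truncation=True, return_tensors="pt")
input_ids = out["input_ids"]            # shape (1, 256): padded to MAX_T5_TOKEN_LENGTH
attention_mask = out["attention_mask"]  # 1 for real tokens, 0 for padding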


def unnormalize_latents(
    z: torch.Tensor,
    mean: torch.Tensor,
@@ -93,27 +37,6 @@ def unnormalize_latents(
    assert z.size(1) == mean.size(0) == std.size(0)
    return z * std.to(z) + mean.to(z)
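
For intuition, this is the inverse of the usual per-channel latent normalization. A minimal sketch of the forward direction (hypothetical; the normalizing counterpart is not shown in this diff):

def normalize_latents(z: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    # Inverse of unnormalize_latents: map raw VAE latents into the
    # normalized space the diffusion model operates in.
    return (z - mean.to(z)) / std.to(z)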


def setup_fsdp_sync(model, device_id, *, param_dtype, auto_wrap_policy) -> FSDP:
    model = FSDP(
        model,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        mixed_precision=MixedPrecision(
            param_dtype=param_dtype,
            reduce_dtype=torch.float32,
            buffer_dtype=torch.float32,
        ),
        auto_wrap_policy=auto_wrap_policy,
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        limit_all_gathers=True,
        device_id=device_id,
        sync_module_states=True,
        use_orig_params=True,
    )
    torch.cuda.synchronize()
    return model
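
A hypothetical call site for setup_fsdp_sync (not part of the diff; assumes torch.distributed is already initialized and reuses partial, transformer_auto_wrap_policy, and T5Block from the imports above; t5_enc and local_rank are placeholder names):

t5_wrap_policy = partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={T5Block},  # each T5Block becomes its own FSDP unit
)
t5_enc = setup_fsdp_sync(
    t5_enc,
    device_id=local_rank,
    param_dtype=torch.float32,
    auto_wrap_policy=t5_wrap_policy,
)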


def compute_packed_indices(
    N: int,
    text_mask: List[torch.Tensor],
@@ -189,12 +112,6 @@ class T2VSynthMochiModel:
        t = Timer()
        self.device = torch.device(device_id)

        #self.t5_tokenizer = T5_Tokenizer()

        # with t("load_text_encs"):
        #     t5_enc = T5EncoderModel.from_pretrained(T5_MODEL)
        #     self.t5_enc = t5_enc.eval().to(torch.bfloat16).to("cpu")

        with t("construct_dit"):
            from .dit.joint_model.asymm_models_joint import (
                AsymmDiTJoint,
@@ -223,15 +140,15 @@ class T2VSynthMochiModel:

        model.load_state_dict(load_file(dit_checkpoint_path))

-        with t("fsdp_dit"):
-            self.dit = model
-            self.dit.eval()
-            for name, param in self.dit.named_parameters():
-                params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"}
-                if not any(keyword in name for keyword in params_to_keep):
-                    param.data = param.data.to(torch.float8_e4m3fn)
-                else:
-                    param.data = param.data.to(torch.bfloat16)
+        #with t("fsdp_dit"):
+        self.dit = model
+        self.dit.eval()
+        for name, param in self.dit.named_parameters():
+            params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"}
+            if not any(keyword in name for keyword in params_to_keep):
+                param.data = param.data.to(torch.float8_e4m3fn)
+            else:
+                param.data = param.data.to(torch.bfloat16)

        vae_stats = json.load(open(vae_stats_path))
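
A note on the selective cast above (an observation, not from the commit): everything except the modules named in params_to_keep (embedders, positional frequencies, norms) is stored as float8_e4m3fn, while those precision-sensitive pieces stay in bfloat16. A quick size check (requires PyTorch >= 2.1 for the float8 dtype):

# float8 elements are 1 byte vs 2 bytes for bfloat16, so the cast roughly
# halves the memory footprint of the bulk of the DiT weights.
assert torch.empty((), dtype=torch.float8_e4m3fn).element_size() == 1
assert torch.empty((), dtype=torch.bfloat16).element_size() == 2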
nodes.py (19 lines changed)
@@ -1,17 +1,10 @@
import os
import torch
import torch.nn as nn
import folder_paths
import comfy.model_management as mm
from comfy.utils import ProgressBar, load_torch_file
from einops import rearrange

from contextlib import nullcontext

from PIL import Image
import numpy as np
import json

from tqdm import tqdm
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
log = logging.getLogger(__name__)
@@ -295,11 +288,11 @@ class MochiDecode:
        # Split z into overlapping tiles and decode them separately.
        # The tiles have an overlap to avoid seams between tiles.
        rows = []
-        for i in range(0, height, overlap_height):
+        for i in tqdm(range(0, height, overlap_height), desc="Processing rows"):
            row = []
-            for j in range(0, width, overlap_width):
+            for j in tqdm(range(0, width, overlap_width), desc="Processing columns", leave=False):
                time = []
-                for k in range(num_frames // frame_batch_size):
+                for k in tqdm(range(num_frames // frame_batch_size), desc="Processing frames", leave=False):
                    remaining_frames = num_frames % frame_batch_size
                    start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
                    end_frame = frame_batch_size * (k + 1) + remaining_frames
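
To make the frame batching above concrete, a worked example with hypothetical numbers: with num_frames = 25 and frame_batch_size = 8, remaining_frames = 25 % 8 = 1 and the loop runs num_frames // frame_batch_size = 3 times:

# k=0: start_frame = 0,  end_frame = 9   (the first batch absorbs the remainder)
# k=1: start_frame = 9,  end_frame = 17
# k=2: start_frame = 17, end_frame = 25

so all 25 frames are covered exactly once.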
@@ -316,9 +309,9 @@ class MochiDecode:
            rows.append(row)

        result_rows = []
-        for i, row in enumerate(rows):
+        for i, row in enumerate(tqdm(rows, desc="Blending rows")):
            result_row = []
-            for j, tile in enumerate(row):
+            for j, tile in enumerate(tqdm(row, desc="Blending tiles", leave=False)):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
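
The hunk is cut off here. For context, the blend referenced in the comments is typically implemented like this (a sketch of the common tiled-VAE blending pattern, not code from this commit; blend_v, blend_h, and blend_extent are hypothetical names):

def blend_v(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    # Linearly fade the bottom edge of the tile above (a) into the top edge
    # of the current tile (b) across blend_extent rows.
    blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
    for y in range(blend_extent):
        w = y / blend_extent
        b[..., y, :] = a[..., -blend_extent + y, :] * (1 - w) + b[..., y, :] * w
    return b

def blend_h(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    # Same fade along the vertical seam, across blend_extent columns.
    blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
    for x in range(blend_extent):
        w = x / blend_extent
        b[..., :, x] = a[..., :, -blend_extent + x] * (1 - w) + b[..., :, x] * w
    return b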