kijai 2024-10-23 15:45:20 +03:00
parent b80cb4a691
commit 4efb7c85df
5 changed files with 17 additions and 105 deletions

.gitignore (vendored, new file, +2 lines)

@@ -0,0 +1,2 @@
+*.pyc
+__pycache__

View File

@@ -1,78 +1,22 @@
 import json
-import os
 import random
-from functools import partial
 from typing import Dict, List
 from safetensors.torch import load_file
 import numpy as np
 import torch
-import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.data
-import yaml
-from einops import rearrange, repeat
-from omegaconf import OmegaConf
 from torch import nn
-from torch.distributed.fsdp import (
-    BackwardPrefetch,
-    MixedPrecision,
-    ShardingStrategy,
-)
-from torch.distributed.fsdp import (
-    FullyShardedDataParallel as FSDP,
-)
-from torch.distributed.fsdp.wrap import (
-    lambda_auto_wrap_policy,
-    transformer_auto_wrap_policy,
-)
-from transformers import T5EncoderModel, T5Tokenizer
-from transformers.models.t5.modeling_t5 import T5Block
 from .dit.joint_model.context_parallel import get_cp_rank_size
 from .utils import Timer
 from tqdm import tqdm
 from comfy.utils import ProgressBar
-T5_MODEL = "weights/T5"
 MAX_T5_TOKEN_LENGTH = 256
-class T5_Tokenizer:
-    """Wrapper around Hugging Face tokenizer for T5
-    Args:
-        model_name(str): Name of tokenizer to load.
-    """
-    def __init__(self):
-        self.tokenizer = T5Tokenizer.from_pretrained(T5_MODEL, legacy=False)
-    def __call__(self, prompt, padding, truncation, return_tensors, max_length=None):
-        """
-        Args:
-            prompt (str): The input text to tokenize.
-            padding (str): The padding strategy.
-            truncation (bool): Flag indicating whether to truncate the tokens.
-            return_tensors (str): Flag indicating whether to return tensors.
-            max_length (int): The max length of the tokens.
-        """
-        assert (
-            not max_length or max_length == MAX_T5_TOKEN_LENGTH
-        ), f"Max length must be {MAX_T5_TOKEN_LENGTH} for T5."
-        tokenized_output = self.tokenizer(
-            prompt,
-            padding=padding,
-            max_length=MAX_T5_TOKEN_LENGTH,  # Max token length for T5 is set here.
-            truncation=truncation,
-            return_tensors=return_tensors,
-            return_attention_mask=True,
-        )
-        return tokenized_output
 def unnormalize_latents(
     z: torch.Tensor,
     mean: torch.Tensor,
@@ -93,27 +37,6 @@ def unnormalize_latents
     assert z.size(1) == mean.size(0) == std.size(0)
     return z * std.to(z) + mean.to(z)
-def setup_fsdp_sync(model, device_id, *, param_dtype, auto_wrap_policy) -> FSDP:
-    model = FSDP(
-        model,
-        sharding_strategy=ShardingStrategy.FULL_SHARD,
-        mixed_precision=MixedPrecision(
-            param_dtype=param_dtype,
-            reduce_dtype=torch.float32,
-            buffer_dtype=torch.float32,
-        ),
-        auto_wrap_policy=auto_wrap_policy,
-        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
-        limit_all_gathers=True,
-        device_id=device_id,
-        sync_module_states=True,
-        use_orig_params=True,
-    )
-    torch.cuda.synchronize()
-    return model
 def compute_packed_indices(
     N: int,
     text_mask: List[torch.Tensor],
@@ -189,12 +112,6 @@ class T2VSynthMochiModel
         t = Timer()
         self.device = torch.device(device_id)
-        #self.t5_tokenizer = T5_Tokenizer()
-        # with t("load_text_encs"):
-        #     t5_enc = T5EncoderModel.from_pretrained(T5_MODEL)
-        #     self.t5_enc = t5_enc.eval().to(torch.bfloat16).to("cpu")
         with t("construct_dit"):
             from .dit.joint_model.asymm_models_joint import (
                 AsymmDiTJoint,
@@ -223,15 +140,15 @@ class T2VSynthMochiModel
         model.load_state_dict(load_file(dit_checkpoint_path))
-        with t("fsdp_dit"):
+        #with t("fsdp_dit"):
         self.dit = model
         self.dit.eval()
         for name, param in self.dit.named_parameters():
             params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"}
             if not any(keyword in name for keyword in params_to_keep):
                 param.data = param.data.to(torch.float8_e4m3fn)
             else:
                 param.data = param.data.to(torch.bfloat16)
         vae_stats = json.load(open(vae_stats_path))
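
For reference, the loop just above applies a selective-precision cast: embedder, positional-frequency and norm parameters stay in bfloat16, while everything else is stored as float8_e4m3fn. A minimal standalone sketch of that pattern (the toy module and the reduced keyword set are illustrative, not this repo's model; fp8 dtypes require PyTorch 2.1 or newer):

import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    # Stand-in module; only the parameter names matter for the cast below.
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 16)
        self.norm_out = nn.LayerNorm(16)

model = ToyBlock().eval()
params_to_keep = {"norm"}  # illustrative subset of the keyword set used above
for name, param in model.named_parameters():
    if not any(keyword in name for keyword in params_to_keep):
        param.data = param.data.to(torch.float8_e4m3fn)  # fp8 storage for bulk weights
    else:
        param.data = param.data.to(torch.bfloat16)       # keep norm weights in bf16

print({name: p.dtype for name, p in model.named_parameters()})
# proj.* -> torch.float8_e4m3fn, norm_out.* -> torch.bfloat16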

View File

@@ -1,17 +1,10 @@
 import os
 import torch
-import torch.nn as nn
 import folder_paths
 import comfy.model_management as mm
 from comfy.utils import ProgressBar, load_torch_file
 from einops import rearrange
+from tqdm import tqdm
-from contextlib import nullcontext
-from PIL import Image
-import numpy as np
-import json
 import logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 log = logging.getLogger(__name__)
@@ -295,11 +288,11 @@ class MochiDecode
         # Split z into overlapping tiles and decode them separately.
         # The tiles have an overlap to avoid seams between tiles.
         rows = []
-        for i in range(0, height, overlap_height):
+        for i in tqdm(range(0, height, overlap_height), desc="Processing rows"):
             row = []
-            for j in range(0, width, overlap_width):
+            for j in tqdm(range(0, width, overlap_width), desc="Processing columns", leave=False):
                 time = []
-                for k in range(num_frames // frame_batch_size):
+                for k in tqdm(range(num_frames // frame_batch_size), desc="Processing frames", leave=False):
                     remaining_frames = num_frames % frame_batch_size
                     start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
                     end_frame = frame_batch_size * (k + 1) + remaining_frames
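
The start/end frame arithmetic above folds any leftover frames (num_frames % frame_batch_size) into the first temporal batch, so every later batch is exactly frame_batch_size long. A worked example with illustrative values, num_frames = 25 and frame_batch_size = 8:

num_frames = 25        # illustrative value
frame_batch_size = 8   # illustrative value

remaining_frames = num_frames % frame_batch_size  # 1
for k in range(num_frames // frame_batch_size):   # 3 batches
    start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
    end_frame = frame_batch_size * (k + 1) + remaining_frames
    print(k, start_frame, end_frame)
# 0 0 9    -> first batch absorbs the 1 remainder frame (9 frames)
# 1 9 17   -> 8 frames
# 2 17 25  -> 8 frames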
@@ -316,9 +309,9 @@ class MochiDecode
             rows.append(row)
         result_rows = []
-        for i, row in enumerate(rows):
+        for i, row in enumerate(tqdm(rows, desc="Blending rows")):
             result_row = []
-            for j, tile in enumerate(row):
+            for j, tile in enumerate(tqdm(row, desc="Blending tiles", leave=False)):
                 # blend the above tile and the left tile
                 # to the current tile and add the current tile to the result row
                 if i > 0:
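
The blending that begins at the end of this hunk merges each decoded tile with the tile above it and the tile to its left across the overlap region, then crops the overlap away and concatenates the tiles back into a full frame. A generic, self-contained sketch of that pattern for (batch, channels, frames, height, width) tiles; the helper names (blend_v, blend_h) and the toy sizes are illustrative, not taken from this file:

import torch

B, C, T = 1, 3, 4                    # toy batch/channel/frame sizes
tile_h, tile_w, overlap = 8, 8, 2    # toy tile geometry

def blend_v(a, b, overlap):
    # Fade from the bottom rows of tile `a` into the top rows of tile `b`.
    b = b.clone()
    for y in range(overlap):
        w = y / overlap
        b[:, :, :, y, :] = a[:, :, :, -overlap + y, :] * (1 - w) + b[:, :, :, y, :] * w
    return b

def blend_h(a, b, overlap):
    # Fade from the right columns of tile `a` into the left columns of tile `b`.
    b = b.clone()
    for x in range(overlap):
        w = x / overlap
        b[:, :, :, :, x] = a[:, :, :, :, -overlap + x] * (1 - w) + b[:, :, :, :, x] * w
    return b

# 2 x 3 grid of decoded tiles, as the loops above would produce.
rows = [[torch.randn(B, C, T, tile_h, tile_w) for _ in range(3)] for _ in range(2)]

result_rows = []
for i, row in enumerate(rows):
    result_row = []
    for j, tile in enumerate(row):
        if i > 0:
            tile = blend_v(rows[i - 1][j], tile, overlap)  # blend with the tile above
        if j > 0:
            tile = blend_h(row[j - 1], tile, overlap)      # blend with the tile to the left
        result_row.append(tile[:, :, :, : tile_h - overlap, : tile_w - overlap])
    result_rows.append(torch.cat(result_row, dim=-1))      # stitch tiles along width

video = torch.cat(result_rows, dim=-2)                     # stitch rows along height
print(video.shape)  # torch.Size([1, 3, 4, 12, 18])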