cleanup

This commit is contained in:
parent b80cb4a691
commit 4efb7c85df
.gitignore (vendored, new file, 2 additions)
@@ -0,0 +1,2 @@
+*.pyc
+__pycache__
2 binary files changed (contents not shown)
@@ -1,78 +1,22 @@
 import json
-import os
 import random
-from functools import partial
 from typing import Dict, List

 from safetensors.torch import load_file
 import numpy as np
 import torch
-import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.data
-import yaml
-from einops import rearrange, repeat
-from omegaconf import OmegaConf
 from torch import nn
-from torch.distributed.fsdp import (
-    BackwardPrefetch,
-    MixedPrecision,
-    ShardingStrategy,
-)
-from torch.distributed.fsdp import (
-    FullyShardedDataParallel as FSDP,
-)
-from torch.distributed.fsdp.wrap import (
-    lambda_auto_wrap_policy,
-    transformer_auto_wrap_policy,
-)
-from transformers import T5EncoderModel, T5Tokenizer
-from transformers.models.t5.modeling_t5 import T5Block
-
 from .dit.joint_model.context_parallel import get_cp_rank_size
 from .utils import Timer
 from tqdm import tqdm
 from comfy.utils import ProgressBar

-T5_MODEL = "weights/T5"
 MAX_T5_TOKEN_LENGTH = 256


-class T5_Tokenizer:
-    """Wrapper around Hugging Face tokenizer for T5
-
-    Args:
-        model_name(str): Name of tokenizer to load.
-    """
-
-    def __init__(self):
-        self.tokenizer = T5Tokenizer.from_pretrained(T5_MODEL, legacy=False)
-
-    def __call__(self, prompt, padding, truncation, return_tensors, max_length=None):
-        """
-        Args:
-            prompt (str): The input text to tokenize.
-            padding (str): The padding strategy.
-            truncation (bool): Flag indicating whether to truncate the tokens.
-            return_tensors (str): Flag indicating whether to return tensors.
-            max_length (int): The max length of the tokens.
-        """
-        assert (
-            not max_length or max_length == MAX_T5_TOKEN_LENGTH
-        ), f"Max length must be {MAX_T5_TOKEN_LENGTH} for T5."
-
-        tokenized_output = self.tokenizer(
-            prompt,
-            padding=padding,
-            max_length=MAX_T5_TOKEN_LENGTH,  # Max token length for T5 is set here.
-            truncation=truncation,
-            return_tensors=return_tensors,
-            return_attention_mask=True,
-        )
-
-        return tokenized_output
-
 def unnormalize_latents(
     z: torch.Tensor,
     mean: torch.Tensor,
@@ -93,27 +37,6 @@ def unnormalize_latents(
     assert z.size(1) == mean.size(0) == std.size(0)
     return z * std.to(z) + mean.to(z)
-
-
-def setup_fsdp_sync(model, device_id, *, param_dtype, auto_wrap_policy) -> FSDP:
-    model = FSDP(
-        model,
-        sharding_strategy=ShardingStrategy.FULL_SHARD,
-        mixed_precision=MixedPrecision(
-            param_dtype=param_dtype,
-            reduce_dtype=torch.float32,
-            buffer_dtype=torch.float32,
-        ),
-        auto_wrap_policy=auto_wrap_policy,
-        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
-        limit_all_gathers=True,
-        device_id=device_id,
-        sync_module_states=True,
-        use_orig_params=True,
-    )
-    torch.cuda.synchronize()
-    return model
-
-
 def compute_packed_indices(
     N: int,
     text_mask: List[torch.Tensor],
@@ -189,12 +112,6 @@ class T2VSynthMochiModel:
         t = Timer()
         self.device = torch.device(device_id)

-        #self.t5_tokenizer = T5_Tokenizer()
-
-        # with t("load_text_encs"):
-        #     t5_enc = T5EncoderModel.from_pretrained(T5_MODEL)
-        #     self.t5_enc = t5_enc.eval().to(torch.bfloat16).to("cpu")
-
         with t("construct_dit"):
             from .dit.joint_model.asymm_models_joint import (
                 AsymmDiTJoint,
@@ -223,15 +140,15 @@ class T2VSynthMochiModel:

         model.load_state_dict(load_file(dit_checkpoint_path))

-        with t("fsdp_dit"):
+        #with t("fsdp_dit"):
         self.dit = model
         self.dit.eval()
         for name, param in self.dit.named_parameters():
             params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"}
             if not any(keyword in name for keyword in params_to_keep):
                 param.data = param.data.to(torch.float8_e4m3fn)
             else:
                 param.data = param.data.to(torch.bfloat16)


         vae_stats = json.load(open(vae_stats_path))
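Note on the hunk above: with the FSDP wrapping commented out, the DiT weights are instead downcast in place. Parameters whose names match the params_to_keep allowlist (embedders, positional frequencies, norms, t5 projections) stay in bfloat16; everything else drops to float8_e4m3fn, which requires PyTorch 2.1 or newer. A minimal standalone sketch of the same pattern, with a toy module standing in for the real DiT:

```python
# Sketch of the selective fp8 downcast from the diff above; `toy` is a
# placeholder model, and the keep-list mirrors the one in the commit.
import torch
import torch.nn as nn

def cast_weights(model: nn.Module) -> None:
    params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"}
    for name, param in model.named_parameters():
        if any(keyword in name for keyword in params_to_keep):
            # Embedders, norms, and positional tables are small but
            # numerically sensitive, so they stay in bf16.
            param.data = param.data.to(torch.bfloat16)
        else:
            # Bulk transformer weights drop to 1 byte per element, half
            # of bf16; they must be upcast to a compute dtype at use time.
            param.data = param.data.to(torch.float8_e4m3fn)

toy = nn.ModuleDict({"norm": nn.LayerNorm(8), "proj": nn.Linear(8, 8)})
cast_weights(toy)
print({n: p.dtype for n, p in toy.named_parameters()})
# norm.* -> torch.bfloat16, proj.* -> torch.float8_e4m3fn
```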
nodes.py (19 changes)
@@ -1,17 +1,10 @@
 import os
 import torch
-import torch.nn as nn
 import folder_paths
 import comfy.model_management as mm
 from comfy.utils import ProgressBar, load_torch_file
 from einops import rearrange
+from tqdm import tqdm
-from contextlib import nullcontext
-
-from PIL import Image
-import numpy as np
-import json
-
-
 import logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 log = logging.getLogger(__name__)
@@ -295,11 +288,11 @@ class MochiDecode:
         # Split z into overlapping tiles and decode them separately.
         # The tiles have an overlap to avoid seams between tiles.
         rows = []
-        for i in range(0, height, overlap_height):
+        for i in tqdm(range(0, height, overlap_height), desc="Processing rows"):
             row = []
-            for j in range(0, width, overlap_width):
+            for j in tqdm(range(0, width, overlap_width), desc="Processing columns", leave=False):
                 time = []
-                for k in range(num_frames // frame_batch_size):
+                for k in tqdm(range(num_frames // frame_batch_size), desc="Processing frames", leave=False):
                     remaining_frames = num_frames % frame_batch_size
                     start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
                     end_frame = frame_batch_size * (k + 1) + remaining_frames
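The frame indexing in this hunk folds the remainder (num_frames % frame_batch_size) into the first batch, so every frame is decoded exactly once and there is no short trailing batch. A quick standalone check with made-up numbers:

```python
# Hypothetical values: 25 frames in batches of 8 -> batch sizes 9, 8, 8.
num_frames, frame_batch_size = 25, 8
remaining_frames = num_frames % frame_batch_size  # 1
for k in range(num_frames // frame_batch_size):
    start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
    end_frame = frame_batch_size * (k + 1) + remaining_frames
    print(k, start_frame, end_frame)
# 0 0 9
# 1 9 17
# 2 17 25
```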
@@ -316,9 +309,9 @@ class MochiDecode:
                 rows.append(row)

         result_rows = []
-        for i, row in enumerate(rows):
+        for i, row in enumerate(tqdm(rows, desc="Blending rows")):
             result_row = []
-            for j, tile in enumerate(row):
+            for j, tile in enumerate(tqdm(row, desc="Blending tiles", leave=False)):
                 # blend the above tile and the left tile
                 # to the current tile and add the current tile to the result row
                 if i > 0:
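All of the new inner progress bars pass leave=False, so each nested bar erases itself when its loop finishes and only the outermost bar persists in the terminal. A minimal sketch of the nesting (the loop bounds are placeholders; in the node they come from the tile grid):

```python
from time import sleep
from tqdm import tqdm

for i in tqdm(range(4), desc="Blending rows"):
    for j in tqdm(range(3), desc="Blending tiles", leave=False):
        sleep(0.01)  # stand-in for blending one tile
```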