from typing import Dict, List, Optional, Union

from einops import rearrange


# Temporary patch to fix a torch.compile cache-write bug on Windows
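# The stock torch._inductor.codecache.write_atomic renames its temp file into
# place, which can fail on Windows when the destination cache file already
# exists; the patched version below copies over the destination instead
# (allowing overwrites) and then removes the temp file.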
def patched_write_atomic(
    path_: str,
    content: Union[str, bytes],
    make_dirs: bool = False,
    encode_utf_8: bool = False,
) -> None:
    # Write into a temporary file first to avoid conflicts between threads.
    # Avoid using a named temporary file, as those have restricted permissions.
    from pathlib import Path
    import os
    import shutil
    import threading

    assert isinstance(
        content, (str, bytes)
    ), "Only strings and byte arrays can be saved in the cache"
    path = Path(path_)
    if make_dirs:
        path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = path.parent / f".{os.getpid()}.{threading.get_ident()}.tmp"
    write_mode = "w" if isinstance(content, str) else "wb"
    with tmp_path.open(write_mode, encoding="utf-8" if encode_utf_8 else None) as f:
        f.write(content)
    shutil.copy2(src=tmp_path, dst=path)  # changed to allow overwriting cache files
    os.remove(tmp_path)


try:
    import torch._inductor.codecache
    torch._inductor.codecache.write_atomic = patched_write_atomic
except Exception:
    pass


import torch
import torch.utils.data

from tqdm import tqdm
from comfy.utils import ProgressBar, load_torch_file
import comfy.model_management as mm

from contextlib import nullcontext

try:
    from accelerate import init_empty_weights
    from accelerate.utils import set_module_tensor_to_device
    is_accelerate_available = True
except Exception:
    is_accelerate_available = False

from .dit.joint_model.asymm_models_joint import AsymmDiTJoint

import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
log = logging.getLogger(__name__)

MAX_T5_TOKEN_LENGTH = 256

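
# FasterCache helper: fft() below splits a spatial feature map into low- and
# high-frequency FFT components using a circular mask of radius min(H, W) // 5
# around the spectrum center. model_fn() caches the (uncond - cond) delta in
# each band and reuses it on later steps to skip the unconditional forward pass.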
def fft(tensor):
    tensor_fft = torch.fft.fft2(tensor)
    tensor_fft_shifted = torch.fft.fftshift(tensor_fft)
    B, C, H, W = tensor.size()
    radius = min(H, W) // 5

    Y, X = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij")
    center_x, center_y = W // 2, H // 2
    mask = (X - center_x) ** 2 + (Y - center_y) ** 2 <= radius ** 2
    low_freq_mask = mask.unsqueeze(0).unsqueeze(0).to(tensor.device)
    high_freq_mask = ~low_freq_mask

    low_freq_fft = tensor_fft_shifted * low_freq_mask
    high_freq_fft = tensor_fft_shifted * high_freq_mask

    return low_freq_fft, high_freq_fft


class T2VSynthMochiModel:
    """Wrapper that loads the Mochi AsymmDiTJoint video DiT and runs sampling."""

    def __init__(
        self,
        *,
        device: torch.device,
        offload_device: torch.device,
        dit_checkpoint_path: str,
        weight_dtype: torch.dtype = torch.float8_e4m3fn,
        fp8_fastmode: bool = False,
        attention_mode: str = "sdpa",
        rms_norm_func: str = "default",
        compile_args: Optional[Dict] = None,
        cublas_ops: Optional[bool] = False,
    ):
        super().__init__()
        self.device = device
        self.weight_dtype = weight_dtype
        self.offload_device = offload_device

        log.info("Initializing model...")
        # With accelerate, build the model on the meta device (no memory is
        # allocated); weights are materialized below when loading the state dict
        with (init_empty_weights() if is_accelerate_available else nullcontext()):
            model = AsymmDiTJoint(
                depth=48,
                patch_size=2,
                num_heads=24,
                hidden_size_x=3072,
                hidden_size_y=1536,
                mlp_ratio_x=4.0,
                mlp_ratio_y=4.0,
                in_channels=12,
                qk_norm=True,
                qkv_bias=False,
                out_bias=True,
                patch_embed_bias=True,
                timestep_mlp_bias=True,
                timestep_scale=1000.0,
                t5_feat_dim=4096,
                t5_token_length=256,
                rope_theta=10000.0,
                attention_mode=attention_mode,
                rms_norm_func=rms_norm_func,
            )
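
        # Mixed-precision loading policy: parameters whose names contain any of
        # the keywords below (embedders, positional frequencies, T5-related
        # layers, norms) are kept in bfloat16; everything else is cast to
        # weight_dtype (float8_e4m3fn by default).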
        params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"}
        log.info(f"Loading model state_dict from {dit_checkpoint_path}...")
        dit_sd = load_torch_file(dit_checkpoint_path)

        # ComfyUI checkpoint format: strip the "model.diffusion_model." prefix
        prefix = "model.diffusion_model."
        first_key = next(iter(dit_sd), None)
        if first_key and first_key.startswith(prefix):
            new_dit_sd = {
                key[len(prefix):] if key.startswith(prefix) else key: value
                for key, value in dit_sd.items()
            }
            dit_sd = new_dit_sd

        if "gguf" in dit_checkpoint_path.lower():
            log.info("Loading GGUF model state_dict...")
            from .. import mz_gguf_loader
            import importlib
            importlib.reload(mz_gguf_loader)
            with mz_gguf_loader.quantize_lazy_load():
                model = mz_gguf_loader.quantize_load_state_dict(model, dit_sd, device="cpu", cublas_ops=cublas_ops)
        elif is_accelerate_available:
            log.info("Using accelerate to load and assign model weights to device...")
            for name, param in model.named_parameters():
                if not any(keyword in name for keyword in params_to_keep):
                    set_module_tensor_to_device(model, name, dtype=weight_dtype, device=self.device, value=dit_sd[name])
                else:
                    set_module_tensor_to_device(model, name, dtype=torch.bfloat16, device=self.device, value=dit_sd[name])
        else:
            log.info("Loading state_dict without accelerate...")
            model.load_state_dict(dit_sd)
            for name, param in model.named_parameters():
                if not any(keyword in name for keyword in params_to_keep):
                    param.data = param.data.to(weight_dtype)
                else:
                    param.data = param.data.to(torch.bfloat16)

        if fp8_fastmode:
            from ..fp8_optimization import convert_fp8_linear
            convert_fp8_linear(model, torch.bfloat16)

        model = model.eval().to(self.device)

        # torch.compile: compile each transformer block (and optionally the
        # final layer) individually rather than the whole model
        if compile_args is not None:
            if compile_args["compile_dit"]:
                for i, block in enumerate(model.blocks):
                    model.blocks[i] = torch.compile(block, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"])
            if compile_args["compile_final_layer"]:
                model.final_layer = torch.compile(model.final_layer, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"])

        self.dit = model

    def move_to_device_(self, sample):
        if isinstance(sample, dict):
            for key in sample.keys():
                if isinstance(sample[key], torch.Tensor):
                    sample[key] = sample[key].to(self.device, non_blocking=True)

    def run(self, args):
        torch.manual_seed(args["seed"])
        torch.cuda.manual_seed(args["seed"])

        generator = torch.Generator(device=torch.device("cpu"))
        generator.manual_seed(args["seed"])

        num_frames = args["num_frames"]
        height = args["height"]
        width = args["width"]
        in_samples = args["samples"]
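
        # Schedule semantics: cfg_schedule provides one guidance scale per step,
        # while sigma_schedule has sample_steps + 1 noise levels (one per step
        # boundary), so step i integrates from sigma_schedule[i] to sigma_schedule[i + 1].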
        sample_steps = args["mochi_args"]["num_inference_steps"]
        cfg_schedule = args["mochi_args"].get("cfg_schedule")
        assert (
            len(cfg_schedule) == sample_steps
        ), f"cfg_schedule must have length {sample_steps}, got {len(cfg_schedule)}"
        sigma_schedule = args["mochi_args"].get("sigma_schedule")
        if sigma_schedule:
            assert (
                len(sigma_schedule) == sample_steps + 1
            ), f"sigma_schedule must have length {sample_steps + 1}, got {len(sigma_schedule)}"
        assert (num_frames - 1) % 6 == 0, f"num_frames - 1 must be divisible by 6, got {num_frames - 1}"

        from ..latent_preview import prepare_callback
        callback = prepare_callback(self.dit, sample_steps)

        # Create the initial noise latents z (spatially 8x, temporally 6x downsampled)
        spatial_downsample = 8
        temporal_downsample = 6
        in_channels = 12
        B = 1
        C = in_channels
        T = (num_frames - 1) // temporal_downsample + 1
        H = height // spatial_downsample
        W = width // spatial_downsample
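
        # For example, num_frames=163 at 848x480 gives
        # T = (163 - 1) // 6 + 1 = 28, H = 480 // 8 = 60, W = 848 // 8 = 106,
        # so z has shape (1, 12, 28, 60, 106).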
        z = torch.randn(
            (B, C, T, H, W),
            device=torch.device("cpu"),
            generator=generator,
            dtype=torch.float32,
        ).to(self.device)
        if in_samples is not None:
            # Blend the provided input latents with the noise according to the starting sigma
            z = z * sigma_schedule[0] + (1 - sigma_schedule[0]) * in_samples.to(self.device)
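
        # Conditioning for classifier-free guidance: `sample` holds the positive
        # (prompt) T5 embeddings and attention mask, `sample_null` the negative ones.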
        sample = {
            "y_mask": [args["positive_embeds"]["attention_mask"].to(self.device)],
            "y_feat": [args["positive_embeds"]["embeds"].to(self.device)],
        }
        sample_null = {
            "y_mask": [args["negative_embeds"]["attention_mask"].to(self.device)],
            "y_feat": [args["negative_embeds"]["embeds"].to(self.device)],
        }

        print(args["fastercache"])
        if args["fastercache"]:
            print("Using fastercache")
            self.fastercache_start_step = args["fastercache"]["start_step"]
            self.fastercache_lf_step = args["fastercache"]["lf_step"]
            self.fastercache_hf_step = args["fastercache"]["hf_step"]
        else:
            self.fastercache_start_step = 1000
        self.fastercache_counter = 0
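
        # model_fn computes one guided prediction. With cfg_scale != 1.0 it normally
        # runs the DiT twice (conditional and unconditional) and combines them as
        # uncond + cfg_scale * (cond - uncond). Once FasterCache is active
        # (counter >= start_step + 3) and on steps where counter % 5 != 0, the
        # unconditional pass is skipped and reconstructed from the cached
        # low/high-frequency deltas computed by fft().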
        def model_fn(*, z, sigma, cfg_scale):
            nonlocal sample, sample_null
            if cfg_scale != 1.0:
                if args["fastercache"]:
                    self.fastercache_counter += 1
                if self.fastercache_counter >= self.fastercache_start_step + 3 and self.fastercache_counter % 5 != 0:
                    out_cond = self.dit(
                        z,
                        sigma,
                        **sample,
                        fastercache=args["fastercache"],
                        fastercache_counter=self.fastercache_counter)

                    (bb, cc, tt, hh, ww) = out_cond.shape
                    cond = rearrange(out_cond, "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
                    lf_c, hf_c = fft(cond.float())
                    if self.fastercache_counter <= self.fastercache_lf_step:
                        self.delta_lf = self.delta_lf * 1.1
                    if self.fastercache_counter >= self.fastercache_hf_step:
                        self.delta_hf = self.delta_hf * 1.1

                    new_hf_uc = self.delta_hf + hf_c
                    new_lf_uc = self.delta_lf + lf_c

                    combine_uc = new_lf_uc + new_hf_uc
                    combined_fft = torch.fft.ifftshift(combine_uc)
                    recovered_uncond = torch.fft.ifft2(combined_fft).real
                    recovered_uncond = rearrange(recovered_uncond.to(out_cond.dtype), "(B T) C H W -> B C T H W", B=bb, C=cc, T=tt, H=hh, W=ww)

                    return recovered_uncond + cfg_scale * (out_cond - recovered_uncond)
                else:
                    out_cond = self.dit(
                        z,
                        sigma,
                        **sample,
                        fastercache=args["fastercache"],
                        fastercache_counter=self.fastercache_counter)

                    out_uncond = self.dit(
                        z,
                        sigma,
                        **sample_null,
                        fastercache=args["fastercache"],
                        fastercache_counter=self.fastercache_counter)

                    if self.fastercache_counter >= self.fastercache_start_step + 1:
                        # Cache the frequency-domain difference between the
                        # unconditional and conditional predictions for later reuse
                        (bb, cc, tt, hh, ww) = out_cond.shape
                        cond = rearrange(out_cond.float(), "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
                        uncond = rearrange(out_uncond.float(), "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)

                        lf_c, hf_c = fft(cond)
                        lf_uc, hf_uc = fft(uncond)

                        self.delta_lf = lf_uc - lf_c
                        self.delta_hf = hf_uc - hf_c

                    return out_uncond + cfg_scale * (out_cond - out_uncond)
            else:  # handle cfg_scale == 1.0: single conditional pass, no guidance
                out_cond = self.dit(
                    z,
                    sigma,
                    **sample,
                    fastercache=args["fastercache"],
                    fastercache_counter=self.fastercache_counter)
                return out_cond

        comfy_pbar = ProgressBar(sample_steps)

        # Use fp16 autocast when the cuBLAS half-precision matmul path is enabled, bf16 otherwise
        if hasattr(self.dit, "cublas_half_matmul") and self.dit.cublas_half_matmul:
            autocast_dtype = torch.float16
        else:
            autocast_dtype = torch.bfloat16

        self.dit.to(self.device)
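
        # Denoising loop: Euler integration over the sigma schedule. Each step
        # moves z by dsigma * pred, where pred estimates the flow direction
        # (z_0 - eps), and progress is reported via the preview callback or the
        # ComfyUI progress bar.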
        with torch.autocast(mm.get_autocast_device(self.device), dtype=autocast_dtype):
            for i in tqdm(range(0, sample_steps), desc="Processing Samples", total=sample_steps):
                sigma = sigma_schedule[i]
                dsigma = sigma - sigma_schedule[i + 1]

                # `pred` estimates `z_0 - eps`.
                pred = model_fn(
                    z=z,
                    sigma=torch.full([B], sigma, device=z.device),
                    cfg_scale=cfg_schedule[i],
                )
                z = z + dsigma * pred.to(z)
                if callback is not None:
                    callback(i, z.detach()[0].permute(1, 0, 2, 3), None, sample_steps)
                else:
                    comfy_pbar.update(1)

        if args["fastercache"] is not None:
            # Clear FasterCache's cached attention so it does not leak into the next run
            for block in self.dit.blocks:
                if hasattr(block, "cached_x_attention") and block.cached_x_attention is not None:
                    block.cached_x_attention = None
                    block.cached_y_attention = None

        self.dit.to(self.offload_device)
        mm.soft_empty_cache()
        log.info(f"samples shape: {z.shape}")
        return z