import json
from typing import Dict, List, Optional, Union

# Temporary patch to fix a bug on Windows
def patched_write_atomic(
    path_: str,
    content: Union[str, bytes],
    make_dirs: bool = False,
    encode_utf_8: bool = False,
) -> None:
    # Write into a temporary file first to avoid conflicts between threads
    # Avoid using a named temporary file, as those have restricted permissions
    from pathlib import Path
    import os
    import shutil
    import threading

    assert isinstance(
        content, (str, bytes)
    ), "Only strings and byte arrays can be saved in the cache"

    path = Path(path_)
    if make_dirs:
        path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = path.parent / f".{os.getpid()}.{threading.get_ident()}.tmp"
    write_mode = "w" if isinstance(content, str) else "wb"
    with tmp_path.open(write_mode, encoding="utf-8" if encode_utf_8 else None) as f:
        f.write(content)
    shutil.copy2(src=tmp_path, dst=path)  # To allow overwriting cache files
    os.remove(tmp_path)

try:
    import torch._inductor.codecache
    torch._inductor.codecache.write_atomic = patched_write_atomic
except:
    pass

import torch
import torch.nn.functional as F
import torch.utils.data
from einops import rearrange, repeat

from .dit.joint_model.context_parallel import get_cp_rank_size
from tqdm import tqdm
from comfy.utils import ProgressBar, load_torch_file
import comfy.model_management as mm
from contextlib import nullcontext

try:
    from accelerate import init_empty_weights
    from accelerate.utils import set_module_tensor_to_device
    is_accelerate_available = True
except:
    is_accelerate_available = False

from .dit.joint_model.asymm_models_joint import AsymmDiTJoint

import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
log = logging.getLogger(__name__)

MAX_T5_TOKEN_LENGTH = 256


def unnormalize_latents(
    z: torch.Tensor,
    mean: torch.Tensor,
    std: torch.Tensor,
) -> torch.Tensor:
    """Unnormalize latents. Useful for decoding DiT samples.

    Args:
        z (torch.Tensor): [B, C_z, T_z, H_z, W_z], float

    Returns:
        torch.Tensor: [B, C_z, T_z, H_z, W_z], float
    """
    mean = mean[:, None, None, None]
    std = std[:, None, None, None]

    assert z.ndim == 5
    assert z.size(1) == mean.size(0) == std.size(0)
    return z * std.to(z) + mean.to(z)


def compute_packed_indices(
    N: int,
    text_mask: List[torch.Tensor],
) -> Dict[str, torch.Tensor]:
    """
    Based on https://github.com/Dao-AILab/flash-attention/blob/765741c1eeb86c96ee71a3291ad6968cfbf4e4a1/flash_attn/bert_padding.py#L60-L80

    Args:
        N: Number of visual tokens.
        text_mask: (B, L) List of boolean tensor indicating which text tokens are not padding.

    Returns:
        packed_indices: Dict with keys for Flash Attention:
            - valid_token_indices_kv: up to (B * (N + L),) tensor of valid token indices
              (non-padding) in the packed sequence.
            - cu_seqlens_kv: (B + 1,) tensor of cumulative sequence lengths in the packed sequence.
            - max_seqlen_in_batch_kv: int of the maximum sequence length in the batch.
    """
    # Create an expanded token mask saying which tokens are valid across both visual and text tokens.
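    # Small worked example (illustration only): with N=2 visual tokens and
    # text_mask=[tensor([[True, False]])], the padded mask is [[True, True, True, False]],
    # giving seqlens_in_batch=[3], valid_token_indices=[0, 1, 2], cu_seqlens=[0, 3]
    # and max_seqlen_in_batch=3.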
    assert N > 0 and len(text_mask) == 1
    text_mask = text_mask[0]

    mask = F.pad(text_mask, (N, 0), value=True)  # (B, N + L)
    seqlens_in_batch = mask.sum(dim=-1, dtype=torch.int32)  # (B,)
    valid_token_indices = torch.nonzero(
        mask.flatten(), as_tuple=False
    ).flatten()  # up to (B * (N + L),)
    assert valid_token_indices.size(0) >= text_mask.size(0) * N  # At least (B * N,)

    cu_seqlens = F.pad(
        torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
    )
    max_seqlen_in_batch = seqlens_in_batch.max().item()

    return {
        "cu_seqlens_kv": cu_seqlens,
        "max_seqlen_in_batch_kv": max_seqlen_in_batch,
        "valid_token_indices_kv": valid_token_indices,
    }


class T2VSynthMochiModel:
    def __init__(
        self,
        *,
        device: torch.device,
        offload_device: torch.device,
        vae_stats_path: str,
        dit_checkpoint_path: str,
        weight_dtype: torch.dtype = torch.float8_e4m3fn,
        fp8_fastmode: bool = False,
        attention_mode: str = "sdpa",
        compile_args: Optional[Dict] = None,
    ):
        super().__init__()
        self.device = device
        self.offload_device = offload_device

        logging.info("Initializing model...")
        with (init_empty_weights() if is_accelerate_available else nullcontext()):
            model = AsymmDiTJoint(
                depth=48,
                patch_size=2,
                num_heads=24,
                hidden_size_x=3072,
                hidden_size_y=1536,
                mlp_ratio_x=4.0,
                mlp_ratio_y=4.0,
                in_channels=12,
                qk_norm=True,
                qkv_bias=False,
                out_bias=True,
                patch_embed_bias=True,
                timestep_mlp_bias=True,
                timestep_scale=1000.0,
                t5_feat_dim=4096,
                t5_token_length=256,
                rope_theta=10000.0,
                attention_mode=attention_mode,
            )

        params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"}
        logging.info(f"Loading model state_dict from {dit_checkpoint_path}...")
        dit_sd = load_torch_file(dit_checkpoint_path)

        if "gguf" in dit_checkpoint_path.lower():
            logging.info("Loading GGUF model state_dict...")
            from .. import mz_gguf_loader
            import importlib
            importlib.reload(mz_gguf_loader)

            with mz_gguf_loader.quantize_lazy_load():
                model = mz_gguf_loader.quantize_load_state_dict(model, dit_sd, device="cpu")
        elif is_accelerate_available:
            logging.info("Using accelerate to load and assign model weights to device...")
            for name, param in model.named_parameters():
                if not any(keyword in name for keyword in params_to_keep):
                    set_module_tensor_to_device(model, name, dtype=weight_dtype, device=self.device, value=dit_sd[name])
                else:
                    set_module_tensor_to_device(model, name, dtype=torch.bfloat16, device=self.device, value=dit_sd[name])
        else:
            logging.info("Loading state_dict without accelerate...")
            model.load_state_dict(dit_sd)
            for name, param in model.named_parameters():
                if not any(keyword in name for keyword in params_to_keep):
                    param.data = param.data.to(weight_dtype)
                else:
                    param.data = param.data.to(torch.bfloat16)

        if fp8_fastmode:
            from ..fp8_optimization import convert_fp8_linear
            convert_fp8_linear(model, torch.bfloat16)

        model = model.eval().to(self.device)

        # torch.compile
        if compile_args is not None:
            if compile_args["compile_dit"]:
                for i, block in enumerate(model.blocks):
                    model.blocks[i] = torch.compile(block, fullgraph=compile_args["fullgraph"], dynamic=False, backend=compile_args["backend"])
            if compile_args["compile_final_layer"]:
                model.final_layer = torch.compile(model.final_layer, fullgraph=compile_args["fullgraph"], dynamic=False, backend=compile_args["backend"])

        self.dit = model

        vae_stats = json.load(open(vae_stats_path))
        self.vae_mean = torch.Tensor(vae_stats["mean"]).to(self.device)
        self.vae_std = torch.Tensor(vae_stats["std"]).to(self.device)

    def get_conditioning(self, prompts, *, zero_last_n_prompts: int):
        B = len(prompts)
        assert (
            0 <= zero_last_n_prompts <= B
        ), f"zero_last_n_prompts should be between 0 and {B}, got {zero_last_n_prompts}"

        tokenize_kwargs = dict(
            prompt=prompts,
            padding="max_length",
            return_tensors="pt",
            truncation=True,
        )
        t5_toks = self.t5_tokenizer(**tokenize_kwargs, max_length=MAX_T5_TOKEN_LENGTH)
        caption_input_ids_t5 = t5_toks["input_ids"]
        caption_attention_mask_t5 = t5_toks["attention_mask"].bool()
        del t5_toks

        assert caption_input_ids_t5.shape == (B, MAX_T5_TOKEN_LENGTH)
        assert caption_attention_mask_t5.shape == (B, MAX_T5_TOKEN_LENGTH)

        if zero_last_n_prompts > 0:
            # Zero the last N prompts
            caption_input_ids_t5[-zero_last_n_prompts:] = 0
            caption_attention_mask_t5[-zero_last_n_prompts:] = False

        caption_input_ids_t5 = caption_input_ids_t5.to(self.device, non_blocking=True)
        caption_attention_mask_t5 = caption_attention_mask_t5.to(
            self.device, non_blocking=True
        )

        y_mask = [caption_attention_mask_t5]
        y_feat = []

        self.t5_enc.to(self.device)
        y_feat.append(
            self.t5_enc(
                caption_input_ids_t5, caption_attention_mask_t5
            ).last_hidden_state.detach().to(torch.float32)
        )
        self.t5_enc.to(self.offload_device)
        # Sometimes returns a tensor, other times a tuple, not sure why
        # See: https://huggingface.co/genmo/mochi-1-preview/discussions/3
        assert tuple(y_feat[-1].shape) == (B, MAX_T5_TOKEN_LENGTH, 4096)
        return dict(y_mask=y_mask, y_feat=y_feat)

    def get_packed_indices(self, y_mask, *, lT, lW, lH):
        patch_size = 2
        N = lT * lH * lW // (patch_size**2)
        assert len(y_mask) == 1
        packed_indices = compute_packed_indices(N, y_mask)
        self.move_to_device_(packed_indices)
        return packed_indices

    def move_to_device_(self, sample):
        if isinstance(sample, dict):
            for key in sample.keys():
                if isinstance(sample[key], torch.Tensor):
                    sample[key] = sample[key].to(self.device, non_blocking=True)

    def run(self, args):
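        # Expected layout of `args`, inferred from how the dict is consumed below
        # (a descriptive sketch, not an exhaustive schema):
        #   "seed": int used for torch / CUDA / generator seeding
        #   "num_frames", "height", "width": output video dimensions
        #   "positive_embeds" / "negative_embeds": dicts holding precomputed T5
        #       "embeds" and "attention_mask" tensors
        #   "mochi_args": dict with "batch_cfg", "num_inference_steps",
        #       "cfg_schedule" and "sigma_schedule"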
        torch.manual_seed(args["seed"])
        torch.cuda.manual_seed(args["seed"])
        generator = torch.Generator(device=self.device)
        generator.manual_seed(args["seed"])

        num_frames = args["num_frames"]
        height = args["height"]
        width = args["width"]

        batch_cfg = args["mochi_args"]["batch_cfg"]
        sample_steps = args["mochi_args"]["num_inference_steps"]
        cfg_schedule = args["mochi_args"].get("cfg_schedule")
        assert (
            len(cfg_schedule) == sample_steps
        ), f"cfg_schedule must have length {sample_steps}, got {len(cfg_schedule)}"
        sigma_schedule = args["mochi_args"].get("sigma_schedule")
        if sigma_schedule:
            assert (
                len(sigma_schedule) == sample_steps + 1
            ), f"sigma_schedule must have length {sample_steps + 1}, got {len(sigma_schedule)}"
        assert (num_frames - 1) % 6 == 0, f"num_frames - 1 must be divisible by 6, got {num_frames - 1}"

        # if batch_cfg:
        #     sample_batched = self.get_conditioning(
        #         [prompt] + [neg_prompt], zero_last_n_prompts=B if neg_prompt == "" else 0
        #     )
        # else:
        #     sample = self.get_conditioning([prompt], zero_last_n_prompts=0)
        #     sample_null = self.get_conditioning([neg_prompt] * B, zero_last_n_prompts=B if neg_prompt == "" else 0)

        # Create z
        spatial_downsample = 8
        temporal_downsample = 6
        in_channels = 12
        B = 1
        C = in_channels
        T = (num_frames - 1) // temporal_downsample + 1
        H = height // spatial_downsample
        W = width // spatial_downsample
        latent_dims = dict(lT=T, lW=W, lH=H)
        z = torch.randn(
            (B, C, T, H, W),
            device=self.device,
            generator=generator,
            dtype=torch.float32,
        )

        if batch_cfg:  # WIP
            pos_embeds = args["positive_embeds"]["embeds"].to(self.device)
            neg_embeds = args["negative_embeds"]["embeds"].to(self.device)
            pos_attention_mask = args["positive_embeds"]["attention_mask"].to(self.device)
            neg_attention_mask = args["negative_embeds"]["attention_mask"].to(self.device)
            print(neg_embeds.shape)
            y_feat = torch.cat((pos_embeds, neg_embeds))
            y_mask = torch.cat((pos_attention_mask, neg_attention_mask))
            zero_last_n_prompts = B  # if neg_prompt == "" else 0
            y_feat[-zero_last_n_prompts:] = 0
            y_mask[-zero_last_n_prompts:] = False
            sample_batched = {
                "y_mask": [y_mask],
                "y_feat": [y_feat]
            }
            sample_batched["packed_indices"] = self.get_packed_indices(
                sample_batched["y_mask"], **latent_dims
            )
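            # The latent is duplicated along the batch dimension so the conditional
            # and unconditional branches share a single batched forward pass;
            # model_fn later splits the output back apart with torch.chunk.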
            z = repeat(z, "b ... -> (repeat b) ...", repeat=2)
            print("sample_batched y_mask", sample_batched["y_mask"])
            print("y_mask type", type(sample_batched["y_mask"]))
            print("ymask 0 shape", sample_batched["y_mask"][0].shape)  # torch.Size([2, 256])
        else:
            sample = {
                "y_mask": [args["positive_embeds"]["attention_mask"].to(self.device)],
                "y_feat": [args["positive_embeds"]["embeds"].to(self.device)]
            }
            sample_null = {
                "y_mask": [args["negative_embeds"]["attention_mask"].to(self.device)],
                "y_feat": [args["negative_embeds"]["embeds"].to(self.device)]
            }

            sample["packed_indices"] = self.get_packed_indices(
                sample["y_mask"], **latent_dims
            )
            sample_null["packed_indices"] = self.get_packed_indices(
                sample_null["y_mask"], **latent_dims
            )

        def model_fn(*, z, sigma, cfg_scale):
            self.dit.to(self.device)
            if batch_cfg:
                with torch.autocast(mm.get_autocast_device(self.device), dtype=torch.bfloat16):
                    out = self.dit(z, sigma, **sample_batched)
                out_cond, out_uncond = torch.chunk(out, chunks=2, dim=0)
            else:
                nonlocal sample, sample_null
                with torch.autocast(mm.get_autocast_device(self.device), dtype=torch.bfloat16):
                    out_cond = self.dit(z, sigma, **sample)
                    out_uncond = self.dit(z, sigma, **sample_null)
            assert out_cond.shape == out_uncond.shape
            return out_uncond + cfg_scale * (out_cond - out_uncond), out_cond

        comfy_pbar = ProgressBar(sample_steps)
        for i in tqdm(range(0, sample_steps), desc="Processing Samples", total=sample_steps):
            sigma = sigma_schedule[i]
            dsigma = sigma - sigma_schedule[i + 1]

            # `pred` estimates `z_0 - eps`.
            pred, output_cond = model_fn(
                z=z,
                sigma=torch.full(
                    [B] if not batch_cfg else [B * 2], sigma, device=z.device
                ),
                cfg_scale=cfg_schedule[i],
            )
            pred = pred.to(z)
            output_cond = output_cond.to(z)
            z = z + dsigma * pred
            comfy_pbar.update(1)

        cp_rank, cp_size = get_cp_rank_size()
        if batch_cfg:
            z = z[:B]
        z = z.tensor_split(cp_size, dim=2)[cp_rank]  # split along temporal dim

        self.dit.to(self.offload_device)

        samples = unnormalize_latents(z.float(), self.vae_mean, self.vae_std)
        logging.info(f"samples shape: {samples.shape}")
        return samples
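# Example usage (a minimal sketch for reference only, not executed by this module):
# the checkpoint paths, resolution, and schedules below are hypothetical
# placeholders; in practice the surrounding ComfyUI nodes construct the model
# and supply the precomputed T5 `positive_embeds` / `negative_embeds`.
#
#   model = T2VSynthMochiModel(
#       device=torch.device("cuda"),
#       offload_device=torch.device("cpu"),
#       vae_stats_path="path/to/vae_stats.json",
#       dit_checkpoint_path="path/to/mochi_dit.safetensors",
#   )
#   latents = model.run({
#       "seed": 0,
#       "num_frames": 49,   # (num_frames - 1) must be divisible by 6
#       "height": 480,
#       "width": 848,
#       "positive_embeds": {"embeds": pos_embeds, "attention_mask": pos_mask},
#       "negative_embeds": {"embeds": neg_embeds, "attention_mask": neg_mask},
#       "mochi_args": {
#           "batch_cfg": False,
#           "num_inference_steps": 50,
#           "cfg_schedule": [4.5] * 50,
#           "sigma_schedule": sigma_schedule,  # length num_inference_steps + 1
#       },
#   })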