import torch

# Channel-wise mean and standard deviation of VAE encoder latents
STATS = {
    "mean": torch.Tensor([
        -0.06730895953510081,
        -0.038011381506090416,
        -0.07477820912866141,
        -0.05565264470995561,
        0.012767231469026969,
        -0.04703542746246419,
        0.043896967884726704,
        -0.09346305707025976,
        -0.09918314763016893,
        -0.008729793427399178,
        -0.011931556316503654,
        -0.0321993391887285,
    ]),
    "std": torch.Tensor([
        0.9263795028493863,
        0.9248894543193766,
        0.9393059390890617,
        0.959253732819592,
        0.8244560132752793,
        0.917259975397747,
        0.9294154431013696,
        1.3720942357788521,
        0.881393668867029,
        0.9168315692124348,
        0.9185249279345552,
        0.9274757570805041,
    ]),
}


def dit_latents_to_vae_latents(dit_outputs: torch.Tensor) -> torch.Tensor:
    """Unnormalize latents output by Mochi's DiT to be compatible with the VAE.

    Run this on sampled latents before calling the VAE decoder.

    Args:
        dit_outputs (torch.Tensor): [B, C_z, T_z, H_z, W_z], float

    Returns:
        torch.Tensor: [B, C_z, T_z, H_z, W_z], float
    """
    # Reshape per-channel stats to [C_z, 1, 1, 1] so they broadcast over T, H, W.
    mean = STATS["mean"][:, None, None, None]
    std = STATS["std"][:, None, None, None]

    assert dit_outputs.ndim == 5
    assert dit_outputs.size(1) == mean.size(0) == std.size(0)
    return dit_outputs * std.to(dit_outputs) + mean.to(dit_outputs)


def vae_latents_to_dit_latents(vae_latents: torch.Tensor) -> torch.Tensor:
    """Normalize latents output by the VAE encoder to be compatible with Mochi's DiT.

    E.g., for fine-tuning or video-to-video.
    """
    mean = STATS["mean"][:, None, None, None]
    std = STATS["std"][:, None, None, None]

    assert vae_latents.ndim == 5
    assert vae_latents.size(1) == mean.size(0) == std.size(0)
    return (vae_latents - mean.to(vae_latents)) / std.to(vae_latents)
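

# Minimal usage sketch (illustrative only, not part of the module's API): round-trips a
# random latent tensor through both helpers and checks that normalization followed by
# unnormalization recovers the input. The channel count of 12 is taken from STATS; the
# batch, temporal, and spatial sizes below are arbitrary placeholders.
if __name__ == "__main__":
    vae_latents = torch.randn(1, 12, 4, 30, 53)  # [B, C_z, T_z, H_z, W_z]
    dit_latents = vae_latents_to_dit_latents(vae_latents)
    recovered = dit_latents_to_vae_latents(dit_latents)
    assert torch.allclose(recovered, vae_latents, atol=1e-5)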