63 lines
1.9 KiB
Python
63 lines
1.9 KiB
Python
import torch
|
|
|
|
# Channel-wise mean and standard deviation of VAE encoder latents
|
|
STATS = {
|
|
"mean": torch.Tensor([
|
|
-0.06730895953510081,
|
|
-0.038011381506090416,
|
|
-0.07477820912866141,
|
|
-0.05565264470995561,
|
|
0.012767231469026969,
|
|
-0.04703542746246419,
|
|
0.043896967884726704,
|
|
-0.09346305707025976,
|
|
-0.09918314763016893,
|
|
-0.008729793427399178,
|
|
-0.011931556316503654,
|
|
-0.0321993391887285,
|
|
]),
|
|
"std": torch.Tensor([
|
|
0.9263795028493863,
|
|
0.9248894543193766,
|
|
0.9393059390890617,
|
|
0.959253732819592,
|
|
0.8244560132752793,
|
|
0.917259975397747,
|
|
0.9294154431013696,
|
|
1.3720942357788521,
|
|
0.881393668867029,
|
|
0.9168315692124348,
|
|
0.9185249279345552,
|
|
0.9274757570805041,
|
|
]),
|
|
}
|
|
|
|
def dit_latents_to_vae_latents(dit_outputs: torch.Tensor) -> torch.Tensor:
|
|
"""Unnormalize latents output by Mochi's DiT to be compatible with VAE.
|
|
Run this on sampled latents before calling the VAE decoder.
|
|
|
|
Args:
|
|
latents (torch.Tensor): [B, C_z, T_z, H_z, W_z], float
|
|
|
|
Returns:
|
|
torch.Tensor: [B, C_z, T_z, H_z, W_z], float
|
|
"""
|
|
mean = STATS["mean"][:, None, None, None]
|
|
std = STATS["std"][:, None, None, None]
|
|
|
|
assert dit_outputs.ndim == 5
|
|
assert dit_outputs.size(1) == mean.size(0) == std.size(0)
|
|
return dit_outputs * std.to(dit_outputs) + mean.to(dit_outputs)
|
|
|
|
|
|
def vae_latents_to_dit_latents(vae_latents: torch.Tensor):
|
|
"""Normalize latents output by the VAE encoder to be compatible with Mochi's DiT.
|
|
E.g, for fine-tuning or video-to-video.
|
|
"""
|
|
mean = STATS["mean"][:, None, None, None]
|
|
std = STATS["std"][:, None, None, None]
|
|
|
|
assert vae_latents.ndim == 5
|
|
assert vae_latents.size(1) == mean.size(0) == std.size(0)
|
|
return (vae_latents - mean.to(vae_latents)) / std.to(vae_latents)
|