Add VAE encoder

kijai 2024-11-01 05:22:49 +02:00
parent ebd0f62d53
commit ac5de728ad
5 changed files with 577 additions and 94 deletions

mochi_preview/t2v_synth_mochi.py

@@ -233,6 +233,7 @@ class T2VSynthMochiModel:
num_frames = args["num_frames"]
height = args["height"]
width = args["width"]
+ in_samples = args["samples"]
sample_steps = args["mochi_args"]["num_inference_steps"]
cfg_schedule = args["mochi_args"].get("cfg_schedule")
@@ -263,6 +264,8 @@ class T2VSynthMochiModel:
generator=generator,
dtype=torch.float32,
)
+ if in_samples is not None:
+ z = z * sigma_schedule[0] + in_samples.to(self.device) * sigma_schedule[-1]
sample = {
"y_mask": [args["positive_embeds"]["attention_mask"].to(self.device)],
@@ -314,6 +317,6 @@ class T2VSynthMochiModel:
comfy_pbar.update(1)
self.dit.to(self.offload_device)
- samples = unnormalize_latents(z.float(), self.vae_mean, self.vae_std)
- logging.info(f"samples shape: {samples.shape}")
- return samples
+ #samples = unnormalize_latents(z.float(), self.vae_mean, self.vae_std)
+ logging.info(f"samples shape: {z.shape}")
+ return z
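The change above is what enables feeding latents back into the sampler: instead of starting from pure noise, the initial latent blends Gaussian noise with the supplied latents, weighted by the first and last values of the sigma schedule. A minimal sketch of that blending step (shapes and schedule here are made up for illustration):

    import torch

    # Hypothetical Mochi-style latent shape: [B, 12, t, h, w]
    B, C, T, H, W = 1, 12, 28, 60, 106
    sigma_schedule = torch.linspace(1.0, 0.0, 51)  # e.g. 50 steps from 1.0 down to 0.0

    z = torch.randn(B, C, T, H, W)           # pure noise
    in_samples = torch.randn(B, C, T, H, W)  # stand-in for encoded input latents

    # Same blend as the commit: noise weighted by the first sigma,
    # input latents weighted by the last sigma.
    z = z * sigma_schedule[0] + in_samples * sigma_schedule[-1]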

mochi_preview/vae/latent_dist.py

@@ -0,0 +1,35 @@
"""Container for latent space posterior."""
import torch
class LatentDistribution:
def __init__(self, mean: torch.Tensor, logvar: torch.Tensor):
"""Initialize latent distribution.
Args:
mean: Mean of the distribution. Shape: [B, C, T, H, W].
logvar: Logarithm of variance of the distribution. Shape: [B, C, T, H, W].
"""
assert mean.shape == logvar.shape
self.mean = mean
self.logvar = logvar
def sample(self, temperature=1.0, generator: torch.Generator = None, noise=None):
if temperature == 0.0:
return self.mean
if noise is None:
noise = torch.randn(self.mean.shape, device=self.mean.device, dtype=self.mean.dtype, generator=generator)
else:
assert noise.device == self.mean.device
noise = noise.to(self.mean.dtype)
if temperature != 1.0:
raise NotImplementedError(f"Temperature {temperature} is not supported.")
# Just Gaussian sample with no scaling of variance.
return noise * torch.exp(self.logvar * 0.5) + self.mean
def mode(self):
return self.mean
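The sample() method above is the standard reparameterization trick for a diagonal Gaussian posterior. A self-contained sketch of what it computes (shapes are made up):

    import torch

    # Stand-in for what the encoder produces: per-element mean and log-variance
    # of a 12-channel latent posterior.
    mean = torch.randn(1, 12, 5, 30, 53)
    logvar = torch.full_like(mean, -2.0)

    # Reparameterized sample, matching LatentDistribution.sample(temperature=1.0):
    noise = torch.randn_like(mean)
    sample = noise * torch.exp(0.5 * logvar) + mean

    # LatentDistribution.mode() would simply return `mean`.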

mochi_preview/vae/model.py

@@ -1,5 +1,6 @@
from typing import Callable, List, Optional, Tuple, Union
+ from functools import partial
+ import math
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -7,7 +8,7 @@ from einops import rearrange
#from ..dit.joint_model.context_parallel import get_cp_rank_size
#from ..vae.cp_conv import cp_pass_frames, gather_all_frames
+ from .latent_dist import LatentDistribution
def cast_tuple(t, length=1):
return t if isinstance(t, tuple) else ((t,) * length)
@@ -135,9 +136,13 @@ class ContextParallelConv3d(SafeConv3d):
pad_back = context_size - pad_front
# Apply padding.
+ assert self.padding_mode == "replicate" # DEBUG
mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
+ if self.context_parallel:
x = F.pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode)
+ else:
+ x = F.pad(x, (0, 0, 0, 0, pad_front, 0), mode=mode)
return super().forward(x)
@@ -221,8 +226,10 @@ class ResBlock(nn.Module):
*,
affine: bool = True,
attn_block: Optional[nn.Module] = None,
- padding_mode: str = "replicate",
causal: bool = True,
+ prune_bottleneck: bool = False,
+ padding_mode: str,
+ bias: bool = True,
):
super().__init__()
self.channels = channels
@@ -233,22 +240,22 @@ class ResBlock(nn.Module):
nn.SiLU(inplace=True),
ContextParallelConv3d(
in_channels=channels,
- out_channels=channels,
+ out_channels=channels // 2 if prune_bottleneck else channels,
kernel_size=(3, 3, 3),
stride=(1, 1, 1),
padding_mode=padding_mode,
- bias=True,
+ bias=bias,
causal=causal,
),
norm_fn(channels, affine=affine),
nn.SiLU(inplace=True),
ContextParallelConv3d(
- in_channels=channels,
+ in_channels=channels // 2 if prune_bottleneck else channels,
out_channels=channels,
kernel_size=(3, 3, 3),
stride=(1, 1, 1),
padding_mode=padding_mode,
- bias=True,
+ bias=bias,
causal=causal,
),
)
@@ -357,9 +364,7 @@ class Attention(nn.Module):
)
if q.size(0) <= chunk_size:
- x = F.scaled_dot_product_attention(
- q, k, v, **attn_kwargs
- ) # [B, num_heads, t, head_dim]
+ x = F.scaled_dot_product_attention(q, k, v, **attn_kwargs) # [B, num_heads, t, head_dim]
else:
# Evaluate in chunks to avoid `RuntimeError: CUDA error: invalid configuration argument.`
# Chunks of 2**16 and up cause an error.
@@ -421,9 +426,7 @@ class CausalUpsampleBlock(nn.Module):
out_channels * temporal_expansion * (spatial_expansion**2),
)
- self.d2st = DepthToSpaceTime(
- temporal_expansion=temporal_expansion, spatial_expansion=spatial_expansion
- )
+ self.d2st = DepthToSpaceTime(temporal_expansion=temporal_expansion, spatial_expansion=spatial_expansion)
def forward(self, x):
x = self.blocks(x)
@@ -432,60 +435,9 @@ class CausalUpsampleBlock(nn.Module):
return x
- def block_fn(channels, *, has_attention: bool = False, **block_kwargs):
- #attn_block = AttentionBlock(channels) if has_attention else None
- return ResBlock(
- channels, affine=True, attn_block=None, **block_kwargs
- )
+ def block_fn(channels, *, affine: bool = True, has_attention: bool = False, **block_kwargs):
+ attn_block = AttentionBlock(channels) if has_attention else None
+ return ResBlock(channels, affine=affine, attn_block=attn_block, **block_kwargs)
- class DownsampleBlock(nn.Module):
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- num_res_blocks,
- *,
- temporal_reduction=2,
- spatial_reduction=2,
- **block_kwargs,
- ):
- """
- Downsample block for the VAE encoder.
- Args:
- in_channels: Number of input channels.
- out_channels: Number of output channels.
- num_res_blocks: Number of residual blocks.
- temporal_reduction: Temporal reduction factor.
- spatial_reduction: Spatial reduction factor.
- """
- super().__init__()
- layers = []
- # Change the channel count in the strided convolution.
- # This lets the ResBlock have uniform channel count,
- # as in ConvNeXt.
- assert in_channels != out_channels
- layers.append(
- ContextParallelConv3d(
- in_channels=in_channels,
- out_channels=out_channels,
- kernel_size=(temporal_reduction, spatial_reduction, spatial_reduction),
- stride=(temporal_reduction, spatial_reduction, spatial_reduction),
- padding_mode="replicate",
- bias=True,
- )
- )
- for _ in range(num_res_blocks):
- layers.append(block_fn(out_channels, **block_kwargs))
- self.layers = nn.Sequential(*layers)
- def forward(self, x):
- return self.layers(x)
def add_fourier_features(inputs: torch.Tensor, start=6, stop=8, step=1):
@@ -568,14 +520,13 @@ class Decoder(nn.Module):
assert len(num_res_blocks) == self.num_up_blocks + 2
blocks = []
+ new_block_fn = partial(block_fn, padding_mode="replicate")
- first_block = [
- nn.Conv3d(latent_dim, ch[-1], kernel_size=(1, 1, 1))
- ] # Input layer.
+ first_block = [nn.Conv3d(latent_dim, ch[-1], kernel_size=(1, 1, 1))] # Input layer.
# First set of blocks preserve channel count.
for _ in range(num_res_blocks[-1]):
first_block.append(
- block_fn(
+ new_block_fn(
ch[-1],
has_attention=has_attention[-1],
causal=causal,
@@ -598,6 +549,7 @@ class Decoder(nn.Module):
temporal_expansion=temporal_expansions[-i - 1],
spatial_expansion=spatial_expansions[-i - 1],
causal=causal,
+ padding_mode="replicate",
**block_kwargs,
)
blocks.append(block)
@@ -607,11 +559,7 @@ class Decoder(nn.Module):
# Last block. Preserve channel count.
last_block = []
for _ in range(num_res_blocks[0]):
- last_block.append(
- block_fn(
- ch[0], has_attention=has_attention[0], causal=causal, **block_kwargs
- )
- )
+ last_block.append(new_block_fn(ch[0], has_attention=has_attention[0], causal=causal, **block_kwargs))
blocks.append(nn.Sequential(*last_block))
self.blocks = nn.ModuleList(blocks)
@@ -634,9 +582,7 @@ class Decoder(nn.Module):
if self.output_nonlinearity == "silu":
x = F.silu(x, inplace=not self.training)
else:
- assert (
- not self.output_nonlinearity
- ) # StyleGAN3 omits the to-RGB nonlinearity.
+ assert not self.output_nonlinearity # StyleGAN3 omits the to-RGB nonlinearity.
return self.output_proj(x).contiguous()
@@ -678,9 +624,7 @@ def blend(a: torch.Tensor, b: torch.Tensor, axis: int) -> torch.Tensor:
Returns:
torch.Tensor: The blended tensor.
"""
- assert (
- a.shape == b.shape
- ), f"Tensors must have the same shape, got {a.shape} and {b.shape}"
+ assert a.shape == b.shape, f"Tensors must have the same shape, got {a.shape} and {b.shape}"
steps = a.size(axis)
# Create a weight tensor that linearly interpolates from 0 to 1
@@ -800,3 +744,270 @@ def apply_tiled(
return blend_vertical(top, bottom, out_overlap)
raise ValueError(f"Invalid num_tiles_w={num_tiles_w} and num_tiles_h={num_tiles_h}")
class DownsampleBlock(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
num_res_blocks,
*,
temporal_reduction=2,
spatial_reduction=2,
**block_kwargs,
):
"""
Downsample block for the VAE encoder.
Args:
in_channels: Number of input channels.
out_channels: Number of output channels.
num_res_blocks: Number of residual blocks.
temporal_reduction: Temporal reduction factor.
spatial_reduction: Spatial reduction factor.
"""
super().__init__()
layers = []
assert in_channels != out_channels
layers.append(
ContextParallelConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=(temporal_reduction, spatial_reduction, spatial_reduction),
stride=(temporal_reduction, spatial_reduction, spatial_reduction),
# First layer in each block always uses replicate padding
padding_mode="replicate",
bias=block_kwargs["bias"],
)
)
for _ in range(num_res_blocks):
layers.append(block_fn(out_channels, **block_kwargs))
self.layers = nn.Sequential(*layers)
def forward(self, x):
return self.layers(x)
class Encoder(nn.Module):
def __init__(
self,
*,
in_channels: int,
base_channels: int,
channel_multipliers: List[int],
num_res_blocks: List[int],
latent_dim: int,
temporal_reductions: List[int],
spatial_reductions: List[int],
prune_bottlenecks: List[bool],
has_attentions: List[bool],
affine: bool = True,
bias: bool = True,
input_is_conv_1x1: bool = False,
padding_mode: str,
dtype: torch.dtype = torch.float32,
):
super().__init__()
self.temporal_reductions = temporal_reductions
self.spatial_reductions = spatial_reductions
self.base_channels = base_channels
self.channel_multipliers = channel_multipliers
self.num_res_blocks = num_res_blocks
self.latent_dim = latent_dim
self.dtype = dtype
ch = [mult * base_channels for mult in channel_multipliers]
num_down_blocks = len(ch) - 1
assert len(num_res_blocks) == num_down_blocks + 2
layers = (
[nn.Conv3d(in_channels, ch[0], kernel_size=(1, 1, 1), bias=True)]
if not input_is_conv_1x1
else [Conv1x1(in_channels, ch[0])]
)
assert len(prune_bottlenecks) == num_down_blocks + 2
assert len(has_attentions) == num_down_blocks + 2
block = partial(block_fn, padding_mode=padding_mode, affine=affine, bias=bias)
for _ in range(num_res_blocks[0]):
layers.append(block(ch[0], has_attention=has_attentions[0], prune_bottleneck=prune_bottlenecks[0]))
prune_bottlenecks = prune_bottlenecks[1:]
has_attentions = has_attentions[1:]
assert len(temporal_reductions) == len(spatial_reductions) == len(ch) - 1
for i in range(num_down_blocks):
layer = DownsampleBlock(
ch[i],
ch[i + 1],
num_res_blocks=num_res_blocks[i + 1],
temporal_reduction=temporal_reductions[i],
spatial_reduction=spatial_reductions[i],
prune_bottleneck=prune_bottlenecks[i],
has_attention=has_attentions[i],
affine=affine,
bias=bias,
padding_mode=padding_mode,
)
layers.append(layer)
# Additional blocks.
for _ in range(num_res_blocks[-1]):
layers.append(block(ch[-1], has_attention=has_attentions[-1], prune_bottleneck=prune_bottlenecks[-1]))
self.layers = nn.Sequential(*layers)
# Output layers.
self.output_norm = norm_fn(ch[-1])
self.output_proj = Conv1x1(ch[-1], 2 * latent_dim, bias=False)
@property
def temporal_downsample(self):
return math.prod(self.temporal_reductions)
@property
def spatial_downsample(self):
return math.prod(self.spatial_reductions)
def forward(self, x) -> LatentDistribution:
"""Forward pass.
Args:
x: Input video tensor. Shape: [B, C, T, H, W]. Scaled to [-1, 1]
Returns:
means: Latent tensor. Shape: [B, latent_dim, t, h, w]. Scaled [-1, 1].
h = H // 8, w = W // 8, t - 1 = (T - 1) // 6
logvar: Shape: [B, latent_dim, t, h, w].
"""
assert x.ndim == 5, f"Expected 5D input, got {x.shape}"
x = self.layers(x)
x = self.output_norm(x)
x = F.silu(x, inplace=True)
x = self.output_proj(x)
means, logvar = torch.chunk(x, 2, dim=1)
assert means.ndim == 5
assert logvar.shape == means.shape
assert means.size(1) == self.latent_dim
return LatentDistribution(means, logvar)
def normalize_decoded_frames(samples):
samples = samples.float()
samples = (samples + 1.0) / 2.0
samples.clamp_(0.0, 1.0)
frames = rearrange(samples, "b c t h w -> b t h w c")
return frames
@torch.inference_mode()
def decode_latents_tiled_full(
decoder,
z,
*,
tile_sample_min_height: int = 240,
tile_sample_min_width: int = 424,
tile_overlap_factor_height: float = 0.1666,
tile_overlap_factor_width: float = 0.2,
auto_tile_size: bool = True,
frame_batch_size: int = 6,
):
B, C, T, H, W = z.shape
assert frame_batch_size <= T, f"frame_batch_size must be <= T, got {frame_batch_size} > {T}"
tile_sample_min_height = tile_sample_min_height if not auto_tile_size else H // 2 * 8
tile_sample_min_width = tile_sample_min_width if not auto_tile_size else W // 2 * 8
tile_latent_min_height = int(tile_sample_min_height / 8)
tile_latent_min_width = int(tile_sample_min_width / 8)
def blend_v(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
for y in range(blend_extent):
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
y / blend_extent
)
return b
def blend_h(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[4], b.shape[4], blend_extent)
for x in range(blend_extent):
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
x / blend_extent
)
return b
overlap_height = int(tile_latent_min_height * (1 - tile_overlap_factor_height))
overlap_width = int(tile_latent_min_width * (1 - tile_overlap_factor_width))
blend_extent_height = int(tile_sample_min_height * tile_overlap_factor_height)
blend_extent_width = int(tile_sample_min_width * tile_overlap_factor_width)
row_limit_height = tile_sample_min_height - blend_extent_height
row_limit_width = tile_sample_min_width - blend_extent_width
# Split z into overlapping tiles and decode them separately.
# The tiles have an overlap to avoid seams between tiles.
pbar = tqdm(
desc="Decoding latent tiles",
total=len(range(0, H, overlap_height)) * len(range(0, W, overlap_width)) * len(range(T // frame_batch_size)),
)
rows = []
for i in range(0, H, overlap_height):
row = []
for j in range(0, W, overlap_width):
temporal = []
for k in range(T // frame_batch_size):
remaining_frames = T % frame_batch_size
start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
end_frame = frame_batch_size * (k + 1) + remaining_frames
tile = z[
:,
:,
start_frame:end_frame,
i : i + tile_latent_min_height,
j : j + tile_latent_min_width,
]
tile = decoder(tile)
temporal.append(tile)
pbar.update(1)
row.append(torch.cat(temporal, dim=2))
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = blend_v(rows[i - 1][j], tile, blend_extent_height)
if j > 0:
tile = blend_h(row[j - 1], tile, blend_extent_width)
result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
result_rows.append(torch.cat(result_row, dim=4))
return normalize_decoded_frames(torch.cat(result_rows, dim=3))
@torch.inference_mode()
def decode_latents_tiled_spatial(
decoder,
z,
*,
num_tiles_w: int,
num_tiles_h: int,
overlap: int = 0, # Number of pixel of overlap between adjacent tiles.
# Use a factor of 2 times the latent downsample factor.
min_block_size: int = 1, # Minimum number of pixels in each dimension when subdividing.
):
decoded = apply_tiled(decoder, z, num_tiles_w, num_tiles_h, overlap, min_block_size)
assert decoded is not None, f"Failed to decode latents with tiled spatial method"
return normalize_decoded_frames(decoded)
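For orientation: with the reductions used by the new encoder loader node below (temporal [1, 2, 3], spatial [2, 2, 2]), the encoder downsamples 6x in time and 8x in space, matching the shape relation given in Encoder.forward's docstring. A small sketch of that arithmetic (the frame count and resolution are only an example):

    import math

    temporal_reductions = [1, 2, 3]  # prod = 6x temporal downsample
    spatial_reductions = [2, 2, 2]   # prod = 8x spatial downsample

    def latent_shape(T, H, W, latent_dim=12):
        # h = H // 8, w = W // 8, t - 1 = (T - 1) // 6, per the docstring above
        t = (T - 1) // math.prod(temporal_reductions) + 1
        h = H // math.prod(spatial_reductions)
        w = W // math.prod(spatial_reductions)
        return (latent_dim, t, h, w)

    print(latent_shape(163, 480, 848))  # -> (12, 28, 60, 106)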

mochi_preview/vae/vae_stats.py

@@ -0,0 +1,62 @@
import torch
# Channel-wise mean and standard deviation of VAE encoder latents
STATS = {
"mean": torch.Tensor([
-0.06730895953510081,
-0.038011381506090416,
-0.07477820912866141,
-0.05565264470995561,
0.012767231469026969,
-0.04703542746246419,
0.043896967884726704,
-0.09346305707025976,
-0.09918314763016893,
-0.008729793427399178,
-0.011931556316503654,
-0.0321993391887285,
]),
"std": torch.Tensor([
0.9263795028493863,
0.9248894543193766,
0.9393059390890617,
0.959253732819592,
0.8244560132752793,
0.917259975397747,
0.9294154431013696,
1.3720942357788521,
0.881393668867029,
0.9168315692124348,
0.9185249279345552,
0.9274757570805041,
]),
}
def dit_latents_to_vae_latents(dit_outputs: torch.Tensor) -> torch.Tensor:
"""Unnormalize latents output by Mochi's DiT to be compatible with VAE.
Run this on sampled latents before calling the VAE decoder.
Args:
latents (torch.Tensor): [B, C_z, T_z, H_z, W_z], float
Returns:
torch.Tensor: [B, C_z, T_z, H_z, W_z], float
"""
mean = STATS["mean"][:, None, None, None]
std = STATS["std"][:, None, None, None]
assert dit_outputs.ndim == 5
assert dit_outputs.size(1) == mean.size(0) == std.size(0)
return dit_outputs * std.to(dit_outputs) + mean.to(dit_outputs)
def vae_latents_to_dit_latents(vae_latents: torch.Tensor):
"""Normalize latents output by the VAE encoder to be compatible with Mochi's DiT.
E.g, for fine-tuning or video-to-video.
"""
mean = STATS["mean"][:, None, None, None]
std = STATS["std"][:, None, None, None]
assert vae_latents.ndim == 5
assert vae_latents.size(1) == mean.size(0) == std.size(0)
return (vae_latents - mean.to(vae_latents)) / std.to(vae_latents)
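A quick sketch of how these two helpers compose: latents from the VAE encoder are normalized for the DiT, and DiT outputs are unnormalized again before decoding, so a round trip should be (numerically) the identity. The import path is assumed from this commit's layout:

    import torch
    from mochi_preview.vae.vae_stats import vae_latents_to_dit_latents, dit_latents_to_vae_latents  # path assumed

    vae_latents = torch.randn(1, 12, 28, 60, 106)          # dummy encoder output
    dit_latents = vae_latents_to_dit_latents(vae_latents)  # normalize for the DiT
    restored = dit_latents_to_vae_latents(dit_latents)     # unnormalize before decoding
    assert torch.allclose(restored, vae_latents, atol=1e-5)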

nodes.py

@@ -13,7 +13,8 @@ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(
log = logging.getLogger(__name__)
from .mochi_preview.t2v_synth_mochi import T2VSynthMochiModel
- from .mochi_preview.vae.model import Decoder
+ from .mochi_preview.vae.model import Decoder, Encoder, add_fourier_features
+ from .mochi_preview.vae.vae_stats import vae_latents_to_dit_latents, dit_latents_to_vae_latents
from contextlib import nullcontext
try:
@@ -140,7 +141,6 @@ class DownloadAndLoadMochiModel:
num_res_blocks=[3, 3, 4, 6, 3],
latent_dim=12,
has_attention=[False, False, False, False, False],
- padding_mode="replicate",
output_norm=False,
nonlinearity="silu",
output_nonlinearity="silu",
@@ -269,7 +269,6 @@ class MochiVAELoader:
num_res_blocks=[3, 3, 4, 6, 3],
latent_dim=12,
has_attention=[False, False, False, False, False],
- padding_mode="replicate",
output_norm=False,
nonlinearity="silu",
output_nonlinearity="silu",
@@ -292,6 +291,74 @@ class MochiVAELoader:
return (vae,)
class MochiVAEEncoderLoader:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model_name": (folder_paths.get_filename_list("vae"), {"tooltip": "The name of the checkpoint (vae) to load."}),
},
"optional": {
"torch_compile_args": ("MOCHICOMPILEARGS", {"tooltip": "Optional torch.compile arguments",}),
"precision": (["fp16", "fp32", "bf16"], {"default": "bf16"}),
},
}
RETURN_TYPES = ("MOCHIVAE",)
RETURN_NAMES = ("mochi_vae", )
FUNCTION = "loadmodel"
CATEGORY = "MochiWrapper"
def loadmodel(self, model_name, torch_compile_args=None, precision="bf16"):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
mm.soft_empty_cache()
dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
config = dict(
prune_bottlenecks=[False, False, False, False, False],
has_attentions=[False, True, True, True, True],
affine=True,
bias=True,
input_is_conv_1x1=True,
padding_mode="replicate"
)
vae_path = folder_paths.get_full_path_or_raise("vae", model_name)
# Create VAE encoder
with (init_empty_weights() if is_accelerate_available else nullcontext()):
encoder = Encoder(
in_channels=15,
base_channels=64,
channel_multipliers=[1, 2, 4, 6],
num_res_blocks=[3, 3, 4, 6, 3],
latent_dim=12,
temporal_reductions=[1, 2, 3],
spatial_reductions=[2, 2, 2],
dtype = dtype,
**config,
)
encoder_sd = load_torch_file(vae_path)
if is_accelerate_available:
for name, param in encoder.named_parameters():
set_module_tensor_to_device(encoder, name, dtype=dtype, device=offload_device, value=encoder_sd[name])
else:
encoder.load_state_dict(encoder_sd, strict=True)
encoder.to(dtype).to(offload_device)
encoder.eval()
del encoder_sd
if torch_compile_args is not None:
encoder.to(device)
encoder = torch.compile(encoder, fullgraph=torch_compile_args["fullgraph"], mode=torch_compile_args["mode"], dynamic=False, backend=torch_compile_args["backend"])
return (encoder,)
class MochiTextEncode:
@classmethod
def INPUT_TYPES(s):
@@ -365,6 +432,7 @@ class MochiSampler:
"optional": {
"cfg_schedule": ("FLOAT", {"forceInput": True, "tooltip": "Override cfg schedule with a list of ints"}),
"opt_sigmas": ("SIGMAS", {"tooltip": "Override sigma schedule and steps"}),
+ "samples": ("LATENT", ),
}
}
@@ -373,7 +441,7 @@ class MochiSampler:
FUNCTION = "process"
CATEGORY = "MochiWrapper"
- def process(self, model, positive, negative, steps, cfg, seed, height, width, num_frames, cfg_schedule=None, opt_sigmas=None):
+ def process(self, model, positive, negative, steps, cfg, seed, height, width, num_frames, cfg_schedule=None, opt_sigmas=None, samples=None):
mm.soft_empty_cache()
if opt_sigmas is not None:
@@ -413,6 +481,7 @@ class MochiSampler:
"positive_embeds": positive,
"negative_embeds": negative,
"seed": seed,
+ "samples": samples["samples"] if samples is not None else None,
}
latents = model.run(args)
@@ -447,6 +516,7 @@ class MochiDecode:
offload_device = mm.unet_offload_device()
intermediate_device = mm.intermediate_device()
samples = samples["samples"]
+ samples = dit_latents_to_vae_latents(samples)
samples = samples.to(vae.dtype).to(device)
B, C, T, H, W = samples.shape
@@ -574,6 +644,7 @@ class MochiDecodeSpatialTiling:
offload_device = mm.unet_offload_device()
intermediate_device = mm.intermediate_device()
samples = samples["samples"]
+ samples = dit_latents_to_vae_latents(samples)
samples = samples.to(vae.dtype).to(device)
B, C, T, H, W = samples.shape
@@ -616,6 +687,101 @@ class MochiDecodeSpatialTiling:
return (frames,)
class MochiImageEncode:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"encoder": ("MOCHIVAE",),
"images": ("IMAGE", ),
},
}
RETURN_TYPES = ("LATENT",)
RETURN_NAMES = ("samples",)
FUNCTION = "decode"
CATEGORY = "MochiWrapper"
def decode(self, encoder, images):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
intermediate_device = mm.intermediate_device()
B, H, W, C = images.shape
images = images.unsqueeze(0) * 2 - 1
images = rearrange(images, "t b h w c -> t c b h w")
images = images.to(encoder.dtype).to(device)
print(images.shape)
encoder.to(device)
print("images before encoding", images.shape)
with torch.autocast(mm.get_autocast_device(device), dtype=encoder.dtype):
video = add_fourier_features(images)
latents = encoder(video).sample()
latents = vae_latents_to_dit_latents(latents)
print("encoder output",latents.shape)
return ({"samples": latents},)
class MochiLatentPreview:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"samples": ("LATENT",),
# "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
# "min_val": ("FLOAT", {"default": -0.15, "min": -1.0, "max": 0.0, "step": 0.001}),
# "max_val": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001}),
},
}
RETURN_TYPES = ("IMAGE", )
RETURN_NAMES = ("images", )
FUNCTION = "sample"
CATEGORY = "PyramidFlowWrapper"
def sample(self, samples):#, seed, min_val, max_val):
mm.soft_empty_cache()
latents = samples["samples"].clone()
print("in sample", latents.shape)
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
latent_rgb_factors = [[0.1236769792512748, 0.11775175335219157, -0.17700629766423637], [-0.08504104329270078, 0.026605813147523694, -0.006843165704926019], [-0.17093308616366876, 0.027991854696200386, 0.14179146288816308], [-0.17179555328757623, 0.09844317368603078, 0.14470997015982784], [-0.16975067171668484, -0.10739852629856643, -0.1894254942909962], [-0.19315259266769888, -0.011029760569485209, -0.08519702054654255], [-0.08399895091432583, -0.0964246452052032, -0.033622359523655665], [0.08148916330842498, 0.027500645903400067, -0.06593099749891196], [0.0456603103902293, -0.17844808072462398, 0.04204775167149785], [0.001751626383204502, -0.030567890189647867, -0.022078082809772193], [0.05110631095056278, -0.0709677393548804, 0.08963683539504264], [0.010515800868829, -0.18382052841762514, -0.08554553339721907]]
# import random
# random.seed(seed)
# latent_rgb_factors = [[random.uniform(min_val, max_val) for _ in range(3)] for _ in range(12)]
# out_factors = latent_rgb_factors
# print(latent_rgb_factors)
latent_rgb_factors_bias = [0,0,0]
latent_rgb_factors = torch.tensor(latent_rgb_factors, device=latents.device, dtype=latents.dtype).transpose(0, 1)
latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device=latents.device, dtype=latents.dtype)
print("latent_rgb_factors", latent_rgb_factors.shape)
latent_images = []
for t in range(latents.shape[2]):
latent = latents[:, :, t, :, :]
latent = latent[0].permute(1, 2, 0)
latent_image = torch.nn.functional.linear(
latent,
latent_rgb_factors,
bias=latent_rgb_factors_bias
)
latent_images.append(latent_image)
latent_images = torch.stack(latent_images, dim=0)
print("latent_images", latent_images.shape)
latent_images_min = latent_images.min()
latent_images_max = latent_images.max()
latent_images = (latent_images - latent_images_min) / (latent_images_max - latent_images_min)
return (latent_images.float().cpu(),)
NODE_CLASS_MAPPINGS = {
"DownloadAndLoadMochiModel": DownloadAndLoadMochiModel,
@@ -624,8 +790,11 @@ NODE_CLASS_MAPPINGS = {
"MochiTextEncode": MochiTextEncode,
"MochiModelLoader": MochiModelLoader,
"MochiVAELoader": MochiVAELoader,
+ "MochiVAEEncoderLoader": MochiVAEEncoderLoader,
"MochiDecodeSpatialTiling": MochiDecodeSpatialTiling,
- "MochiTorchCompileSettings": MochiTorchCompileSettings
+ "MochiTorchCompileSettings": MochiTorchCompileSettings,
+ "MochiImageEncode": MochiImageEncode,
+ "MochiLatentPreview": MochiLatentPreview
}
NODE_DISPLAY_NAME_MAPPINGS = {
"DownloadAndLoadMochiModel": "(Down)load Mochi Model",
@@ -633,7 +802,10 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"MochiDecode": "Mochi Decode",
"MochiTextEncode": "Mochi TextEncode",
"MochiModelLoader": "Mochi Model Loader",
- "MochiVAELoader": "Mochi VAE Loader",
+ "MochiVAELoader": "Mochi VAE Decoder Loader",
+ "MochiVAEEncoderLoader": "Mochi VAE Encoder Loader",
"MochiDecodeSpatialTiling": "Mochi VAE Decode Spatial Tiling",
- "MochiTorchCompileSettings": "Mochi Torch Compile Settings"
+ "MochiTorchCompileSettings": "Mochi Torch Compile Settings",
+ "MochiImageEncode": "Mochi Image Encode",
+ "MochiLatentPreview": "Mochi Latent Preview"
}
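Taken together, the new pieces give a latent input path: MochiVAEEncoderLoader loads the encoder, MochiImageEncode turns frames into a LATENT, MochiSampler accepts that latent through its new optional samples input, and the decode nodes convert DiT latents back with dit_latents_to_vae_latents. A hypothetical sketch of driving the new nodes directly from Python (inside ComfyUI you would wire them as graph nodes; the checkpoint name is made up):

    import torch

    encoder_loader = MochiVAEEncoderLoader()
    image_encode = MochiImageEncode()

    # Assumes a Mochi VAE checkpoint with this (made-up) name exists in ComfyUI's vae folder.
    (mochi_vae_encoder,) = encoder_loader.loadmodel("mochi_vae.safetensors", precision="bf16")

    images = torch.rand(25, 480, 848, 3)  # ComfyUI IMAGE: [frames, H, W, C] in 0..1
    (latent,) = image_encode.decode(mochi_vae_encoder, images)

    # latent["samples"] can now be passed to MochiSampler's optional "samples" input.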