Add VAE encoder

This commit is contained in:
kijai 2024-11-01 05:22:49 +02:00
parent ebd0f62d53
commit ac5de728ad
5 changed files with 577 additions and 94 deletions

View File

@ -233,6 +233,7 @@ class T2VSynthMochiModel:
num_frames = args["num_frames"]
height = args["height"]
width = args["width"]
in_samples = args["samples"]
sample_steps = args["mochi_args"]["num_inference_steps"]
cfg_schedule = args["mochi_args"].get("cfg_schedule")
@ -263,6 +264,8 @@ class T2VSynthMochiModel:
generator=generator,
dtype=torch.float32,
)
if in_samples is not None:
z = z * sigma_schedule[0] + in_samples.to(self.device) * sigma_schedule[-1]
sample = {
"y_mask": [args["positive_embeds"]["attention_mask"].to(self.device)],
@ -314,6 +317,6 @@ class T2VSynthMochiModel:
comfy_pbar.update(1)
self.dit.to(self.offload_device)
samples = unnormalize_latents(z.float(), self.vae_mean, self.vae_std)
logging.info(f"samples shape: {samples.shape}")
return samples
#samples = unnormalize_latents(z.float(), self.vae_mean, self.vae_std)
logging.info(f"samples shape: {z.shape}")
return z

View File

@ -0,0 +1,35 @@
"""Container for latent space posterior."""
import torch
class LatentDistribution:
def __init__(self, mean: torch.Tensor, logvar: torch.Tensor):
"""Initialize latent distribution.
Args:
mean: Mean of the distribution. Shape: [B, C, T, H, W].
logvar: Logarithm of variance of the distribution. Shape: [B, C, T, H, W].
"""
assert mean.shape == logvar.shape
self.mean = mean
self.logvar = logvar
def sample(self, temperature=1.0, generator: torch.Generator = None, noise=None):
if temperature == 0.0:
return self.mean
if noise is None:
noise = torch.randn(self.mean.shape, device=self.mean.device, dtype=self.mean.dtype, generator=generator)
else:
assert noise.device == self.mean.device
noise = noise.to(self.mean.dtype)
if temperature != 1.0:
raise NotImplementedError(f"Temperature {temperature} is not supported.")
# Just Gaussian sample with no scaling of variance.
return noise * torch.exp(self.logvar * 0.5) + self.mean
def mode(self):
return self.mean

View File

@ -1,5 +1,6 @@
from typing import Callable, List, Optional, Tuple, Union
from functools import partial
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
@ -7,7 +8,7 @@ from einops import rearrange
#from ..dit.joint_model.context_parallel import get_cp_rank_size
#from ..vae.cp_conv import cp_pass_frames, gather_all_frames
from .latent_dist import LatentDistribution
def cast_tuple(t, length=1):
return t if isinstance(t, tuple) else ((t,) * length)
@ -135,9 +136,13 @@ class ContextParallelConv3d(SafeConv3d):
pad_back = context_size - pad_front
# Apply padding.
assert self.padding_mode == "replicate" # DEBUG
mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
x = F.pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode)
if self.context_parallel:
x = F.pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode)
else:
x = F.pad(x, (0, 0, 0, 0, pad_front, 0), mode=mode)
return super().forward(x)
@ -221,8 +226,10 @@ class ResBlock(nn.Module):
*,
affine: bool = True,
attn_block: Optional[nn.Module] = None,
padding_mode: str = "replicate",
causal: bool = True,
prune_bottleneck: bool = False,
padding_mode: str,
bias: bool = True,
):
super().__init__()
self.channels = channels
@ -233,22 +240,22 @@ class ResBlock(nn.Module):
nn.SiLU(inplace=True),
ContextParallelConv3d(
in_channels=channels,
out_channels=channels,
out_channels=channels // 2 if prune_bottleneck else channels,
kernel_size=(3, 3, 3),
stride=(1, 1, 1),
padding_mode=padding_mode,
bias=True,
bias=bias,
causal=causal,
),
norm_fn(channels, affine=affine),
nn.SiLU(inplace=True),
ContextParallelConv3d(
in_channels=channels,
in_channels=channels // 2 if prune_bottleneck else channels,
out_channels=channels,
kernel_size=(3, 3, 3),
stride=(1, 1, 1),
padding_mode=padding_mode,
bias=True,
bias=bias,
causal=causal,
),
)
@ -357,9 +364,7 @@ class Attention(nn.Module):
)
if q.size(0) <= chunk_size:
x = F.scaled_dot_product_attention(
q, k, v, **attn_kwargs
) # [B, num_heads, t, head_dim]
x = F.scaled_dot_product_attention(q, k, v, **attn_kwargs) # [B, num_heads, t, head_dim]
else:
# Evaluate in chunks to avoid `RuntimeError: CUDA error: invalid configuration argument.`
# Chunks of 2**16 and up cause an error.
@ -421,9 +426,7 @@ class CausalUpsampleBlock(nn.Module):
out_channels * temporal_expansion * (spatial_expansion**2),
)
self.d2st = DepthToSpaceTime(
temporal_expansion=temporal_expansion, spatial_expansion=spatial_expansion
)
self.d2st = DepthToSpaceTime(temporal_expansion=temporal_expansion, spatial_expansion=spatial_expansion)
def forward(self, x):
x = self.blocks(x)
@ -432,60 +435,9 @@ class CausalUpsampleBlock(nn.Module):
return x
def block_fn(channels, *, has_attention: bool = False, **block_kwargs):
#attn_block = AttentionBlock(channels) if has_attention else None
return ResBlock(
channels, affine=True, attn_block=None, **block_kwargs
)
class DownsampleBlock(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
num_res_blocks,
*,
temporal_reduction=2,
spatial_reduction=2,
**block_kwargs,
):
"""
Downsample block for the VAE encoder.
Args:
in_channels: Number of input channels.
out_channels: Number of output channels.
num_res_blocks: Number of residual blocks.
temporal_reduction: Temporal reduction factor.
spatial_reduction: Spatial reduction factor.
"""
super().__init__()
layers = []
# Change the channel count in the strided convolution.
# This lets the ResBlock have uniform channel count,
# as in ConvNeXt.
assert in_channels != out_channels
layers.append(
ContextParallelConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=(temporal_reduction, spatial_reduction, spatial_reduction),
stride=(temporal_reduction, spatial_reduction, spatial_reduction),
padding_mode="replicate",
bias=True,
)
)
for _ in range(num_res_blocks):
layers.append(block_fn(out_channels, **block_kwargs))
self.layers = nn.Sequential(*layers)
def forward(self, x):
return self.layers(x)
def block_fn(channels, *, affine: bool = True, has_attention: bool = False, **block_kwargs):
attn_block = AttentionBlock(channels) if has_attention else None
return ResBlock(channels, affine=affine, attn_block=attn_block, **block_kwargs)
def add_fourier_features(inputs: torch.Tensor, start=6, stop=8, step=1):
@ -568,14 +520,13 @@ class Decoder(nn.Module):
assert len(num_res_blocks) == self.num_up_blocks + 2
blocks = []
new_block_fn = partial(block_fn, padding_mode="replicate")
first_block = [
nn.Conv3d(latent_dim, ch[-1], kernel_size=(1, 1, 1))
] # Input layer.
first_block = [nn.Conv3d(latent_dim, ch[-1], kernel_size=(1, 1, 1))] # Input layer.
# First set of blocks preserve channel count.
for _ in range(num_res_blocks[-1]):
first_block.append(
block_fn(
new_block_fn(
ch[-1],
has_attention=has_attention[-1],
causal=causal,
@ -598,6 +549,7 @@ class Decoder(nn.Module):
temporal_expansion=temporal_expansions[-i - 1],
spatial_expansion=spatial_expansions[-i - 1],
causal=causal,
padding_mode="replicate",
**block_kwargs,
)
blocks.append(block)
@ -607,11 +559,7 @@ class Decoder(nn.Module):
# Last block. Preserve channel count.
last_block = []
for _ in range(num_res_blocks[0]):
last_block.append(
block_fn(
ch[0], has_attention=has_attention[0], causal=causal, **block_kwargs
)
)
last_block.append(new_block_fn(ch[0], has_attention=has_attention[0], causal=causal, **block_kwargs))
blocks.append(nn.Sequential(*last_block))
self.blocks = nn.ModuleList(blocks)
@ -634,9 +582,7 @@ class Decoder(nn.Module):
if self.output_nonlinearity == "silu":
x = F.silu(x, inplace=not self.training)
else:
assert (
not self.output_nonlinearity
) # StyleGAN3 omits the to-RGB nonlinearity.
assert not self.output_nonlinearity # StyleGAN3 omits the to-RGB nonlinearity.
return self.output_proj(x).contiguous()
@ -678,9 +624,7 @@ def blend(a: torch.Tensor, b: torch.Tensor, axis: int) -> torch.Tensor:
Returns:
torch.Tensor: The blended tensor.
"""
assert (
a.shape == b.shape
), f"Tensors must have the same shape, got {a.shape} and {b.shape}"
assert a.shape == b.shape, f"Tensors must have the same shape, got {a.shape} and {b.shape}"
steps = a.size(axis)
# Create a weight tensor that linearly interpolates from 0 to 1
@ -800,3 +744,270 @@ def apply_tiled(
return blend_vertical(top, bottom, out_overlap)
raise ValueError(f"Invalid num_tiles_w={num_tiles_w} and num_tiles_h={num_tiles_h}")
class DownsampleBlock(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
num_res_blocks,
*,
temporal_reduction=2,
spatial_reduction=2,
**block_kwargs,
):
"""
Downsample block for the VAE encoder.
Args:
in_channels: Number of input channels.
out_channels: Number of output channels.
num_res_blocks: Number of residual blocks.
temporal_reduction: Temporal reduction factor.
spatial_reduction: Spatial reduction factor.
"""
super().__init__()
layers = []
assert in_channels != out_channels
layers.append(
ContextParallelConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=(temporal_reduction, spatial_reduction, spatial_reduction),
stride=(temporal_reduction, spatial_reduction, spatial_reduction),
# First layer in each block always uses replicate padding
padding_mode="replicate",
bias=block_kwargs["bias"],
)
)
for _ in range(num_res_blocks):
layers.append(block_fn(out_channels, **block_kwargs))
self.layers = nn.Sequential(*layers)
def forward(self, x):
return self.layers(x)
class Encoder(nn.Module):
def __init__(
self,
*,
in_channels: int,
base_channels: int,
channel_multipliers: List[int],
num_res_blocks: List[int],
latent_dim: int,
temporal_reductions: List[int],
spatial_reductions: List[int],
prune_bottlenecks: List[bool],
has_attentions: List[bool],
affine: bool = True,
bias: bool = True,
input_is_conv_1x1: bool = False,
padding_mode: str,
dtype: torch.dtype = torch.float32,
):
super().__init__()
self.temporal_reductions = temporal_reductions
self.spatial_reductions = spatial_reductions
self.base_channels = base_channels
self.channel_multipliers = channel_multipliers
self.num_res_blocks = num_res_blocks
self.latent_dim = latent_dim
self.dtype = dtype
ch = [mult * base_channels for mult in channel_multipliers]
num_down_blocks = len(ch) - 1
assert len(num_res_blocks) == num_down_blocks + 2
layers = (
[nn.Conv3d(in_channels, ch[0], kernel_size=(1, 1, 1), bias=True)]
if not input_is_conv_1x1
else [Conv1x1(in_channels, ch[0])]
)
assert len(prune_bottlenecks) == num_down_blocks + 2
assert len(has_attentions) == num_down_blocks + 2
block = partial(block_fn, padding_mode=padding_mode, affine=affine, bias=bias)
for _ in range(num_res_blocks[0]):
layers.append(block(ch[0], has_attention=has_attentions[0], prune_bottleneck=prune_bottlenecks[0]))
prune_bottlenecks = prune_bottlenecks[1:]
has_attentions = has_attentions[1:]
assert len(temporal_reductions) == len(spatial_reductions) == len(ch) - 1
for i in range(num_down_blocks):
layer = DownsampleBlock(
ch[i],
ch[i + 1],
num_res_blocks=num_res_blocks[i + 1],
temporal_reduction=temporal_reductions[i],
spatial_reduction=spatial_reductions[i],
prune_bottleneck=prune_bottlenecks[i],
has_attention=has_attentions[i],
affine=affine,
bias=bias,
padding_mode=padding_mode,
)
layers.append(layer)
# Additional blocks.
for _ in range(num_res_blocks[-1]):
layers.append(block(ch[-1], has_attention=has_attentions[-1], prune_bottleneck=prune_bottlenecks[-1]))
self.layers = nn.Sequential(*layers)
# Output layers.
self.output_norm = norm_fn(ch[-1])
self.output_proj = Conv1x1(ch[-1], 2 * latent_dim, bias=False)
@property
def temporal_downsample(self):
return math.prod(self.temporal_reductions)
@property
def spatial_downsample(self):
return math.prod(self.spatial_reductions)
def forward(self, x) -> LatentDistribution:
"""Forward pass.
Args:
x: Input video tensor. Shape: [B, C, T, H, W]. Scaled to [-1, 1]
Returns:
means: Latent tensor. Shape: [B, latent_dim, t, h, w]. Scaled [-1, 1].
h = H // 8, w = W // 8, t - 1 = (T - 1) // 6
logvar: Shape: [B, latent_dim, t, h, w].
"""
assert x.ndim == 5, f"Expected 5D input, got {x.shape}"
x = self.layers(x)
x = self.output_norm(x)
x = F.silu(x, inplace=True)
x = self.output_proj(x)
means, logvar = torch.chunk(x, 2, dim=1)
assert means.ndim == 5
assert logvar.shape == means.shape
assert means.size(1) == self.latent_dim
return LatentDistribution(means, logvar)
def normalize_decoded_frames(samples):
samples = samples.float()
samples = (samples + 1.0) / 2.0
samples.clamp_(0.0, 1.0)
frames = rearrange(samples, "b c t h w -> b t h w c")
return frames
@torch.inference_mode()
def decode_latents_tiled_full(
decoder,
z,
*,
tile_sample_min_height: int = 240,
tile_sample_min_width: int = 424,
tile_overlap_factor_height: float = 0.1666,
tile_overlap_factor_width: float = 0.2,
auto_tile_size: bool = True,
frame_batch_size: int = 6,
):
B, C, T, H, W = z.shape
assert frame_batch_size <= T, f"frame_batch_size must be <= T, got {frame_batch_size} > {T}"
tile_sample_min_height = tile_sample_min_height if not auto_tile_size else H // 2 * 8
tile_sample_min_width = tile_sample_min_width if not auto_tile_size else W // 2 * 8
tile_latent_min_height = int(tile_sample_min_height / 8)
tile_latent_min_width = int(tile_sample_min_width / 8)
def blend_v(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
for y in range(blend_extent):
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
y / blend_extent
)
return b
def blend_h(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[4], b.shape[4], blend_extent)
for x in range(blend_extent):
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
x / blend_extent
)
return b
overlap_height = int(tile_latent_min_height * (1 - tile_overlap_factor_height))
overlap_width = int(tile_latent_min_width * (1 - tile_overlap_factor_width))
blend_extent_height = int(tile_sample_min_height * tile_overlap_factor_height)
blend_extent_width = int(tile_sample_min_width * tile_overlap_factor_width)
row_limit_height = tile_sample_min_height - blend_extent_height
row_limit_width = tile_sample_min_width - blend_extent_width
# Split z into overlapping tiles and decode them separately.
# The tiles have an overlap to avoid seams between tiles.
pbar = tqdm(
desc="Decoding latent tiles",
total=len(range(0, H, overlap_height)) * len(range(0, W, overlap_width)) * len(range(T // frame_batch_size)),
)
rows = []
for i in range(0, H, overlap_height):
row = []
for j in range(0, W, overlap_width):
temporal = []
for k in range(T // frame_batch_size):
remaining_frames = T % frame_batch_size
start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
end_frame = frame_batch_size * (k + 1) + remaining_frames
tile = z[
:,
:,
start_frame:end_frame,
i : i + tile_latent_min_height,
j : j + tile_latent_min_width,
]
tile = decoder(tile)
temporal.append(tile)
pbar.update(1)
row.append(torch.cat(temporal, dim=2))
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = blend_v(rows[i - 1][j], tile, blend_extent_height)
if j > 0:
tile = blend_h(row[j - 1], tile, blend_extent_width)
result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
result_rows.append(torch.cat(result_row, dim=4))
return normalize_decoded_frames(torch.cat(result_rows, dim=3))
@torch.inference_mode()
def decode_latents_tiled_spatial(
decoder,
z,
*,
num_tiles_w: int,
num_tiles_h: int,
overlap: int = 0, # Number of pixel of overlap between adjacent tiles.
# Use a factor of 2 times the latent downsample factor.
min_block_size: int = 1, # Minimum number of pixels in each dimension when subdividing.
):
decoded = apply_tiled(decoder, z, num_tiles_w, num_tiles_h, overlap, min_block_size)
assert decoded is not None, f"Failed to decode latents with tiled spatial method"
return normalize_decoded_frames(decoded)

View File

@ -0,0 +1,62 @@
import torch
# Channel-wise mean and standard deviation of VAE encoder latents
STATS = {
"mean": torch.Tensor([
-0.06730895953510081,
-0.038011381506090416,
-0.07477820912866141,
-0.05565264470995561,
0.012767231469026969,
-0.04703542746246419,
0.043896967884726704,
-0.09346305707025976,
-0.09918314763016893,
-0.008729793427399178,
-0.011931556316503654,
-0.0321993391887285,
]),
"std": torch.Tensor([
0.9263795028493863,
0.9248894543193766,
0.9393059390890617,
0.959253732819592,
0.8244560132752793,
0.917259975397747,
0.9294154431013696,
1.3720942357788521,
0.881393668867029,
0.9168315692124348,
0.9185249279345552,
0.9274757570805041,
]),
}
def dit_latents_to_vae_latents(dit_outputs: torch.Tensor) -> torch.Tensor:
"""Unnormalize latents output by Mochi's DiT to be compatible with VAE.
Run this on sampled latents before calling the VAE decoder.
Args:
latents (torch.Tensor): [B, C_z, T_z, H_z, W_z], float
Returns:
torch.Tensor: [B, C_z, T_z, H_z, W_z], float
"""
mean = STATS["mean"][:, None, None, None]
std = STATS["std"][:, None, None, None]
assert dit_outputs.ndim == 5
assert dit_outputs.size(1) == mean.size(0) == std.size(0)
return dit_outputs * std.to(dit_outputs) + mean.to(dit_outputs)
def vae_latents_to_dit_latents(vae_latents: torch.Tensor):
"""Normalize latents output by the VAE encoder to be compatible with Mochi's DiT.
E.g, for fine-tuning or video-to-video.
"""
mean = STATS["mean"][:, None, None, None]
std = STATS["std"][:, None, None, None]
assert vae_latents.ndim == 5
assert vae_latents.size(1) == mean.size(0) == std.size(0)
return (vae_latents - mean.to(vae_latents)) / std.to(vae_latents)

186
nodes.py
View File

@ -13,7 +13,8 @@ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(
log = logging.getLogger(__name__)
from .mochi_preview.t2v_synth_mochi import T2VSynthMochiModel
from .mochi_preview.vae.model import Decoder
from .mochi_preview.vae.model import Decoder, Encoder, add_fourier_features
from .mochi_preview.vae.vae_stats import vae_latents_to_dit_latents, dit_latents_to_vae_latents
from contextlib import nullcontext
try:
@ -140,7 +141,6 @@ class DownloadAndLoadMochiModel:
num_res_blocks=[3, 3, 4, 6, 3],
latent_dim=12,
has_attention=[False, False, False, False, False],
padding_mode="replicate",
output_norm=False,
nonlinearity="silu",
output_nonlinearity="silu",
@ -269,7 +269,6 @@ class MochiVAELoader:
num_res_blocks=[3, 3, 4, 6, 3],
latent_dim=12,
has_attention=[False, False, False, False, False],
padding_mode="replicate",
output_norm=False,
nonlinearity="silu",
output_nonlinearity="silu",
@ -292,6 +291,74 @@ class MochiVAELoader:
return (vae,)
class MochiVAEEncoderLoader:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model_name": (folder_paths.get_filename_list("vae"), {"tooltip": "The name of the checkpoint (vae) to load."}),
},
"optional": {
"torch_compile_args": ("MOCHICOMPILEARGS", {"tooltip": "Optional torch.compile arguments",}),
"precision": (["fp16", "fp32", "bf16"], {"default": "bf16"}),
},
}
RETURN_TYPES = ("MOCHIVAE",)
RETURN_NAMES = ("mochi_vae", )
FUNCTION = "loadmodel"
CATEGORY = "MochiWrapper"
def loadmodel(self, model_name, torch_compile_args=None, precision="bf16"):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
mm.soft_empty_cache()
dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
config = dict(
prune_bottlenecks=[False, False, False, False, False],
has_attentions=[False, True, True, True, True],
affine=True,
bias=True,
input_is_conv_1x1=True,
padding_mode="replicate"
)
vae_path = folder_paths.get_full_path_or_raise("vae", model_name)
# Create VAE encoder
with (init_empty_weights() if is_accelerate_available else nullcontext()):
encoder = Encoder(
in_channels=15,
base_channels=64,
channel_multipliers=[1, 2, 4, 6],
num_res_blocks=[3, 3, 4, 6, 3],
latent_dim=12,
temporal_reductions=[1, 2, 3],
spatial_reductions=[2, 2, 2],
dtype = dtype,
**config,
)
encoder_sd = load_torch_file(vae_path)
if is_accelerate_available:
for name, param in encoder.named_parameters():
set_module_tensor_to_device(encoder, name, dtype=dtype, device=offload_device, value=encoder_sd[name])
else:
encoder.load_state_dict(encoder_sd, strict=True)
encoder.to(dtype).to(offload_device)
encoder.eval()
del encoder_sd
if torch_compile_args is not None:
encoder.to(device)
encoder = torch.compile(encoder, fullgraph=torch_compile_args["fullgraph"], mode=torch_compile_args["mode"], dynamic=False, backend=torch_compile_args["backend"])
return (encoder,)
class MochiTextEncode:
@classmethod
def INPUT_TYPES(s):
@ -365,6 +432,7 @@ class MochiSampler:
"optional": {
"cfg_schedule": ("FLOAT", {"forceInput": True, "tooltip": "Override cfg schedule with a list of ints"}),
"opt_sigmas": ("SIGMAS", {"tooltip": "Override sigma schedule and steps"}),
"samples": ("LATENT", ),
}
}
@ -373,7 +441,7 @@ class MochiSampler:
FUNCTION = "process"
CATEGORY = "MochiWrapper"
def process(self, model, positive, negative, steps, cfg, seed, height, width, num_frames, cfg_schedule=None, opt_sigmas=None):
def process(self, model, positive, negative, steps, cfg, seed, height, width, num_frames, cfg_schedule=None, opt_sigmas=None, samples=None):
mm.soft_empty_cache()
if opt_sigmas is not None:
@ -413,6 +481,7 @@ class MochiSampler:
"positive_embeds": positive,
"negative_embeds": negative,
"seed": seed,
"samples": samples["samples"] if samples is not None else None,
}
latents = model.run(args)
@ -447,6 +516,7 @@ class MochiDecode:
offload_device = mm.unet_offload_device()
intermediate_device = mm.intermediate_device()
samples = samples["samples"]
samples = dit_latents_to_vae_latents(samples)
samples = samples.to(vae.dtype).to(device)
B, C, T, H, W = samples.shape
@ -574,6 +644,7 @@ class MochiDecodeSpatialTiling:
offload_device = mm.unet_offload_device()
intermediate_device = mm.intermediate_device()
samples = samples["samples"]
samples = dit_latents_to_vae_latents(samples)
samples = samples.to(vae.dtype).to(device)
B, C, T, H, W = samples.shape
@ -615,7 +686,102 @@ class MochiDecodeSpatialTiling:
frames = rearrange(frames, "b c t h w -> (t b) h w c").to(intermediate_device)
return (frames,)
class MochiImageEncode:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"encoder": ("MOCHIVAE",),
"images": ("IMAGE", ),
},
}
RETURN_TYPES = ("LATENT",)
RETURN_NAMES = ("samples",)
FUNCTION = "decode"
CATEGORY = "MochiWrapper"
def decode(self, encoder, images):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
intermediate_device = mm.intermediate_device()
B, H, W, C = images.shape
images = images.unsqueeze(0) * 2 - 1
images = rearrange(images, "t b h w c -> t c b h w")
images = images.to(encoder.dtype).to(device)
print(images.shape)
encoder.to(device)
print("images before encoding", images.shape)
with torch.autocast(mm.get_autocast_device(device), dtype=encoder.dtype):
video = add_fourier_features(images)
latents = encoder(video).sample()
latents = vae_latents_to_dit_latents(latents)
print("encoder output",latents.shape)
return ({"samples": latents},)
class MochiLatentPreview:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"samples": ("LATENT",),
# "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
# "min_val": ("FLOAT", {"default": -0.15, "min": -1.0, "max": 0.0, "step": 0.001}),
# "max_val": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001}),
},
}
RETURN_TYPES = ("IMAGE", )
RETURN_NAMES = ("images", )
FUNCTION = "sample"
CATEGORY = "PyramidFlowWrapper"
def sample(self, samples):#, seed, min_val, max_val):
mm.soft_empty_cache()
latents = samples["samples"].clone()
print("in sample", latents.shape)
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
latent_rgb_factors = [[0.1236769792512748, 0.11775175335219157, -0.17700629766423637], [-0.08504104329270078, 0.026605813147523694, -0.006843165704926019], [-0.17093308616366876, 0.027991854696200386, 0.14179146288816308], [-0.17179555328757623, 0.09844317368603078, 0.14470997015982784], [-0.16975067171668484, -0.10739852629856643, -0.1894254942909962], [-0.19315259266769888, -0.011029760569485209, -0.08519702054654255], [-0.08399895091432583, -0.0964246452052032, -0.033622359523655665], [0.08148916330842498, 0.027500645903400067, -0.06593099749891196], [0.0456603103902293, -0.17844808072462398, 0.04204775167149785], [0.001751626383204502, -0.030567890189647867, -0.022078082809772193], [0.05110631095056278, -0.0709677393548804, 0.08963683539504264], [0.010515800868829, -0.18382052841762514, -0.08554553339721907]]
# import random
# random.seed(seed)
# latent_rgb_factors = [[random.uniform(min_val, max_val) for _ in range(3)] for _ in range(12)]
# out_factors = latent_rgb_factors
# print(latent_rgb_factors)
latent_rgb_factors_bias = [0,0,0]
latent_rgb_factors = torch.tensor(latent_rgb_factors, device=latents.device, dtype=latents.dtype).transpose(0, 1)
latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device=latents.device, dtype=latents.dtype)
print("latent_rgb_factors", latent_rgb_factors.shape)
latent_images = []
for t in range(latents.shape[2]):
latent = latents[:, :, t, :, :]
latent = latent[0].permute(1, 2, 0)
latent_image = torch.nn.functional.linear(
latent,
latent_rgb_factors,
bias=latent_rgb_factors_bias
)
latent_images.append(latent_image)
latent_images = torch.stack(latent_images, dim=0)
print("latent_images", latent_images.shape)
latent_images_min = latent_images.min()
latent_images_max = latent_images.max()
latent_images = (latent_images - latent_images_min) / (latent_images_max - latent_images_min)
return (latent_images.float().cpu(),)
NODE_CLASS_MAPPINGS = {
"DownloadAndLoadMochiModel": DownloadAndLoadMochiModel,
@ -624,8 +790,11 @@ NODE_CLASS_MAPPINGS = {
"MochiTextEncode": MochiTextEncode,
"MochiModelLoader": MochiModelLoader,
"MochiVAELoader": MochiVAELoader,
"MochiVAEEncoderLoader": MochiVAEEncoderLoader,
"MochiDecodeSpatialTiling": MochiDecodeSpatialTiling,
"MochiTorchCompileSettings": MochiTorchCompileSettings
"MochiTorchCompileSettings": MochiTorchCompileSettings,
"MochiImageEncode": MochiImageEncode,
"MochiLatentPreview": MochiLatentPreview
}
NODE_DISPLAY_NAME_MAPPINGS = {
"DownloadAndLoadMochiModel": "(Down)load Mochi Model",
@ -633,7 +802,10 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"MochiDecode": "Mochi Decode",
"MochiTextEncode": "Mochi TextEncode",
"MochiModelLoader": "Mochi Model Loader",
"MochiVAELoader": "Mochi VAE Loader",
"MochiVAELoader": "Mochi VAE Decoder Loader",
"MochiVAEEncoderLoader": "Mochi VAE Encoder Loader",
"MochiDecodeSpatialTiling": "Mochi VAE Decode Spatial Tiling",
"MochiTorchCompileSettings": "Mochi Torch Compile Settings"
"MochiTorchCompileSettings": "Mochi Torch Compile Settings",
"MochiImageEncode": "Mochi Image Encode",
"MochiLatentPreview": "Mochi Latent Preview"
}