diff --git a/mochi_preview/t2v_synth_mochi.py b/mochi_preview/t2v_synth_mochi.py
index 476ab12..e5d074a 100644
--- a/mochi_preview/t2v_synth_mochi.py
+++ b/mochi_preview/t2v_synth_mochi.py
@@ -233,6 +233,7 @@ class T2VSynthMochiModel:
         num_frames = args["num_frames"]
         height = args["height"]
         width = args["width"]
+        in_samples = args["samples"]
 
         sample_steps = args["mochi_args"]["num_inference_steps"]
         cfg_schedule = args["mochi_args"].get("cfg_schedule")
@@ -263,6 +264,8 @@ class T2VSynthMochiModel:
             generator=generator,
             dtype=torch.float32,
         )
+        if in_samples is not None:
+            z = z * sigma_schedule[0] + in_samples.to(self.device) * sigma_schedule[-1]
 
         sample = {
             "y_mask": [args["positive_embeds"]["attention_mask"].to(self.device)],
@@ -314,6 +317,6 @@ class T2VSynthMochiModel:
                 comfy_pbar.update(1)
 
         self.dit.to(self.offload_device)
-        samples = unnormalize_latents(z.float(), self.vae_mean, self.vae_std)
-        logging.info(f"samples shape: {samples.shape}")
-        return samples
+        #samples = unnormalize_latents(z.float(), self.vae_mean, self.vae_std)
+        logging.info(f"samples shape: {z.shape}")
+        return z
diff --git a/mochi_preview/vae/latent_dist.py b/mochi_preview/vae/latent_dist.py
new file mode 100644
index 0000000..d99eac0
--- /dev/null
+++ b/mochi_preview/vae/latent_dist.py
@@ -0,0 +1,35 @@
+"""Container for latent space posterior."""
+
+import torch
+
+
+class LatentDistribution:
+    def __init__(self, mean: torch.Tensor, logvar: torch.Tensor):
+        """Initialize latent distribution.
+
+        Args:
+            mean: Mean of the distribution. Shape: [B, C, T, H, W].
+            logvar: Logarithm of variance of the distribution. Shape: [B, C, T, H, W].
+        """
+        assert mean.shape == logvar.shape
+        self.mean = mean
+        self.logvar = logvar
+
+    def sample(self, temperature=1.0, generator: torch.Generator = None, noise=None):
+        if temperature == 0.0:
+            return self.mean
+
+        if noise is None:
+            noise = torch.randn(self.mean.shape, device=self.mean.device, dtype=self.mean.dtype, generator=generator)
+        else:
+            assert noise.device == self.mean.device
+            noise = noise.to(self.mean.dtype)
+
+        if temperature != 1.0:
+            raise NotImplementedError(f"Temperature {temperature} is not supported.")
+
+        # Just Gaussian sample with no scaling of variance.
+        return noise * torch.exp(self.logvar * 0.5) + self.mean
+
+    def mode(self):
+        return self.mean
diff --git a/mochi_preview/vae/model.py b/mochi_preview/vae/model.py
index 9514df1..4085623 100644
--- a/mochi_preview/vae/model.py
+++ b/mochi_preview/vae/model.py
@@ -1,5 +1,6 @@
 from typing import Callable, List, Optional, Tuple, Union
-
+from functools import partial
+import math
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -7,7 +8,7 @@ from einops import rearrange
 
 #from ..dit.joint_model.context_parallel import get_cp_rank_size
 #from ..vae.cp_conv import cp_pass_frames, gather_all_frames
-
+from .latent_dist import LatentDistribution
 
 def cast_tuple(t, length=1):
     return t if isinstance(t, tuple) else ((t,) * length)
@@ -135,9 +136,13 @@ class ContextParallelConv3d(SafeConv3d):
             pad_back = context_size - pad_front
 
             # Apply padding.
-            assert self.padding_mode == "replicate"  # DEBUG
             mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
-            x = F.pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode)
+            if self.context_parallel:
+                x = F.pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode)
+            else:
+                x = F.pad(x, (0, 0, 0, 0, pad_front, 0), mode=mode)
+
+
         return super().forward(x)
@@ -221,8 +226,10 @@ class ResBlock(nn.Module):
         *,
         affine: bool = True,
         attn_block: Optional[nn.Module] = None,
-        padding_mode: str = "replicate",
         causal: bool = True,
+        prune_bottleneck: bool = False,
+        padding_mode: str,
+        bias: bool = True,
     ):
         super().__init__()
         self.channels = channels
@@ -233,22 +240,22 @@ class ResBlock(nn.Module):
             nn.SiLU(inplace=True),
             ContextParallelConv3d(
                 in_channels=channels,
-                out_channels=channels,
+                out_channels=channels // 2 if prune_bottleneck else channels,
                 kernel_size=(3, 3, 3),
                 stride=(1, 1, 1),
                 padding_mode=padding_mode,
-                bias=True,
+                bias=bias,
                 causal=causal,
             ),
             norm_fn(channels, affine=affine),
             nn.SiLU(inplace=True),
             ContextParallelConv3d(
-                in_channels=channels,
+                in_channels=channels // 2 if prune_bottleneck else channels,
                 out_channels=channels,
                 kernel_size=(3, 3, 3),
                 stride=(1, 1, 1),
                 padding_mode=padding_mode,
-                bias=True,
+                bias=bias,
                 causal=causal,
             ),
         )
@@ -357,9 +364,7 @@ class Attention(nn.Module):
         )
 
         if q.size(0) <= chunk_size:
-            x = F.scaled_dot_product_attention(
-                q, k, v, **attn_kwargs
-            ) # [B, num_heads, t, head_dim]
+            x = F.scaled_dot_product_attention(q, k, v, **attn_kwargs) # [B, num_heads, t, head_dim]
         else:
             # Evaluate in chunks to avoid `RuntimeError: CUDA error: invalid configuration argument.`
             # Chunks of 2**16 and up cause an error.
@@ -421,9 +426,7 @@ class CausalUpsampleBlock(nn.Module):
             out_channels * temporal_expansion * (spatial_expansion**2),
         )
 
-        self.d2st = DepthToSpaceTime(
-            temporal_expansion=temporal_expansion, spatial_expansion=spatial_expansion
-        )
+        self.d2st = DepthToSpaceTime(temporal_expansion=temporal_expansion, spatial_expansion=spatial_expansion)
 
     def forward(self, x):
         x = self.blocks(x)
@@ -432,60 +435,9 @@ class CausalUpsampleBlock(nn.Module):
         return x
 
 
-def block_fn(channels, *, has_attention: bool = False, **block_kwargs):
-    #attn_block = AttentionBlock(channels) if has_attention else None
-
-    return ResBlock(
-        channels, affine=True, attn_block=None, **block_kwargs
-    )
-
-
-class DownsampleBlock(nn.Module):
-    def __init__(
-        self,
-        in_channels: int,
-        out_channels: int,
-        num_res_blocks,
-        *,
-        temporal_reduction=2,
-        spatial_reduction=2,
-        **block_kwargs,
-    ):
-        """
-        Downsample block for the VAE encoder.
-
-        Args:
-            in_channels: Number of input channels.
-            out_channels: Number of output channels.
-            num_res_blocks: Number of residual blocks.
-            temporal_reduction: Temporal reduction factor.
-            spatial_reduction: Spatial reduction factor.
-        """
-        super().__init__()
-        layers = []
-
-        # Change the channel count in the strided convolution.
-        # This lets the ResBlock have uniform channel count,
-        # as in ConvNeXt.
-        assert in_channels != out_channels
-        layers.append(
-            ContextParallelConv3d(
-                in_channels=in_channels,
-                out_channels=out_channels,
-                kernel_size=(temporal_reduction, spatial_reduction, spatial_reduction),
-                stride=(temporal_reduction, spatial_reduction, spatial_reduction),
-                padding_mode="replicate",
-                bias=True,
-            )
-        )
-
-        for _ in range(num_res_blocks):
-            layers.append(block_fn(out_channels, **block_kwargs))
-
-        self.layers = nn.Sequential(*layers)
-
-    def forward(self, x):
-        return self.layers(x)
+def block_fn(channels, *, affine: bool = True, has_attention: bool = False, **block_kwargs):
+    attn_block = AttentionBlock(channels) if has_attention else None
+    return ResBlock(channels, affine=affine, attn_block=attn_block, **block_kwargs)
 
 
 def add_fourier_features(inputs: torch.Tensor, start=6, stop=8, step=1):
@@ -568,14 +520,13 @@ class Decoder(nn.Module):
         assert len(num_res_blocks) == self.num_up_blocks + 2
 
         blocks = []
+        new_block_fn = partial(block_fn, padding_mode="replicate")
 
-        first_block = [
-            nn.Conv3d(latent_dim, ch[-1], kernel_size=(1, 1, 1))
-        ] # Input layer.
+        first_block = [nn.Conv3d(latent_dim, ch[-1], kernel_size=(1, 1, 1))] # Input layer.
         # First set of blocks preserve channel count.
         for _ in range(num_res_blocks[-1]):
             first_block.append(
-                block_fn(
+                new_block_fn(
                     ch[-1],
                     has_attention=has_attention[-1],
                     causal=causal,
@@ -598,6 +549,7 @@ class Decoder(nn.Module):
                 temporal_expansion=temporal_expansions[-i - 1],
                 spatial_expansion=spatial_expansions[-i - 1],
                 causal=causal,
+                padding_mode="replicate",
                 **block_kwargs,
             )
             blocks.append(block)
@@ -607,11 +559,7 @@
         # Last block. Preserve channel count.
         last_block = []
         for _ in range(num_res_blocks[0]):
-            last_block.append(
-                block_fn(
-                    ch[0], has_attention=has_attention[0], causal=causal, **block_kwargs
-                )
-            )
+            last_block.append(new_block_fn(ch[0], has_attention=has_attention[0], causal=causal, **block_kwargs))
         blocks.append(nn.Sequential(*last_block))
 
         self.blocks = nn.ModuleList(blocks)
@@ -634,9 +582,7 @@
         if self.output_nonlinearity == "silu":
             x = F.silu(x, inplace=not self.training)
         else:
-            assert (
-                not self.output_nonlinearity
-            ) # StyleGAN3 omits the to-RGB nonlinearity.
+            assert not self.output_nonlinearity # StyleGAN3 omits the to-RGB nonlinearity.
 
         return self.output_proj(x).contiguous()
 
@@ -678,9 +624,7 @@ def blend(a: torch.Tensor, b: torch.Tensor, axis: int) -> torch.Tensor:
     Returns:
         torch.Tensor: The blended tensor.
     """
-    assert (
-        a.shape == b.shape
-    ), f"Tensors must have the same shape, got {a.shape} and {b.shape}"
+    assert a.shape == b.shape, f"Tensors must have the same shape, got {a.shape} and {b.shape}"
     steps = a.size(axis)
 
     # Create a weight tensor that linearly interpolates from 0 to 1
@@ -800,3 +744,270 @@ def apply_tiled(
         return blend_vertical(top, bottom, out_overlap)
 
     raise ValueError(f"Invalid num_tiles_w={num_tiles_w} and num_tiles_h={num_tiles_h}")
+
+
+class DownsampleBlock(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        num_res_blocks,
+        *,
+        temporal_reduction=2,
+        spatial_reduction=2,
+        **block_kwargs,
+    ):
+        """
+        Downsample block for the VAE encoder.
+
+        Args:
+            in_channels: Number of input channels.
+            out_channels: Number of output channels.
+            num_res_blocks: Number of residual blocks.
+            temporal_reduction: Temporal reduction factor.
+            spatial_reduction: Spatial reduction factor.
+ """ + super().__init__() + layers = [] + + assert in_channels != out_channels + layers.append( + ContextParallelConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(temporal_reduction, spatial_reduction, spatial_reduction), + stride=(temporal_reduction, spatial_reduction, spatial_reduction), + # First layer in each block always uses replicate padding + padding_mode="replicate", + bias=block_kwargs["bias"], + ) + ) + + for _ in range(num_res_blocks): + layers.append(block_fn(out_channels, **block_kwargs)) + + self.layers = nn.Sequential(*layers) + + def forward(self, x): + return self.layers(x) + + +class Encoder(nn.Module): + def __init__( + self, + *, + in_channels: int, + base_channels: int, + channel_multipliers: List[int], + num_res_blocks: List[int], + latent_dim: int, + temporal_reductions: List[int], + spatial_reductions: List[int], + prune_bottlenecks: List[bool], + has_attentions: List[bool], + affine: bool = True, + bias: bool = True, + input_is_conv_1x1: bool = False, + padding_mode: str, + dtype: torch.dtype = torch.float32, + ): + super().__init__() + self.temporal_reductions = temporal_reductions + self.spatial_reductions = spatial_reductions + self.base_channels = base_channels + self.channel_multipliers = channel_multipliers + self.num_res_blocks = num_res_blocks + self.latent_dim = latent_dim + self.dtype = dtype + + ch = [mult * base_channels for mult in channel_multipliers] + num_down_blocks = len(ch) - 1 + assert len(num_res_blocks) == num_down_blocks + 2 + + layers = ( + [nn.Conv3d(in_channels, ch[0], kernel_size=(1, 1, 1), bias=True)] + if not input_is_conv_1x1 + else [Conv1x1(in_channels, ch[0])] + ) + + assert len(prune_bottlenecks) == num_down_blocks + 2 + assert len(has_attentions) == num_down_blocks + 2 + block = partial(block_fn, padding_mode=padding_mode, affine=affine, bias=bias) + + for _ in range(num_res_blocks[0]): + layers.append(block(ch[0], has_attention=has_attentions[0], prune_bottleneck=prune_bottlenecks[0])) + prune_bottlenecks = prune_bottlenecks[1:] + has_attentions = has_attentions[1:] + + assert len(temporal_reductions) == len(spatial_reductions) == len(ch) - 1 + for i in range(num_down_blocks): + layer = DownsampleBlock( + ch[i], + ch[i + 1], + num_res_blocks=num_res_blocks[i + 1], + temporal_reduction=temporal_reductions[i], + spatial_reduction=spatial_reductions[i], + prune_bottleneck=prune_bottlenecks[i], + has_attention=has_attentions[i], + affine=affine, + bias=bias, + padding_mode=padding_mode, + ) + + layers.append(layer) + + # Additional blocks. + for _ in range(num_res_blocks[-1]): + layers.append(block(ch[-1], has_attention=has_attentions[-1], prune_bottleneck=prune_bottlenecks[-1])) + + self.layers = nn.Sequential(*layers) + + # Output layers. + self.output_norm = norm_fn(ch[-1]) + self.output_proj = Conv1x1(ch[-1], 2 * latent_dim, bias=False) + + @property + def temporal_downsample(self): + return math.prod(self.temporal_reductions) + + @property + def spatial_downsample(self): + return math.prod(self.spatial_reductions) + + def forward(self, x) -> LatentDistribution: + """Forward pass. + + Args: + x: Input video tensor. Shape: [B, C, T, H, W]. Scaled to [-1, 1] + + Returns: + means: Latent tensor. Shape: [B, latent_dim, t, h, w]. Scaled [-1, 1]. + h = H // 8, w = W // 8, t - 1 = (T - 1) // 6 + logvar: Shape: [B, latent_dim, t, h, w]. 
+ """ + assert x.ndim == 5, f"Expected 5D input, got {x.shape}" + + x = self.layers(x) + + x = self.output_norm(x) + x = F.silu(x, inplace=True) + x = self.output_proj(x) + + means, logvar = torch.chunk(x, 2, dim=1) + + assert means.ndim == 5 + assert logvar.shape == means.shape + assert means.size(1) == self.latent_dim + + return LatentDistribution(means, logvar) + + +def normalize_decoded_frames(samples): + samples = samples.float() + samples = (samples + 1.0) / 2.0 + samples.clamp_(0.0, 1.0) + frames = rearrange(samples, "b c t h w -> b t h w c") + return frames + +@torch.inference_mode() +def decode_latents_tiled_full( + decoder, + z, + *, + tile_sample_min_height: int = 240, + tile_sample_min_width: int = 424, + tile_overlap_factor_height: float = 0.1666, + tile_overlap_factor_width: float = 0.2, + auto_tile_size: bool = True, + frame_batch_size: int = 6, +): + B, C, T, H, W = z.shape + assert frame_batch_size <= T, f"frame_batch_size must be <= T, got {frame_batch_size} > {T}" + + tile_sample_min_height = tile_sample_min_height if not auto_tile_size else H // 2 * 8 + tile_sample_min_width = tile_sample_min_width if not auto_tile_size else W // 2 * 8 + + tile_latent_min_height = int(tile_sample_min_height / 8) + tile_latent_min_width = int(tile_sample_min_width / 8) + + def blend_v(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( + y / blend_extent + ) + return b + + def blend_h(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[4], b.shape[4], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( + x / blend_extent + ) + return b + + overlap_height = int(tile_latent_min_height * (1 - tile_overlap_factor_height)) + overlap_width = int(tile_latent_min_width * (1 - tile_overlap_factor_width)) + blend_extent_height = int(tile_sample_min_height * tile_overlap_factor_height) + blend_extent_width = int(tile_sample_min_width * tile_overlap_factor_width) + row_limit_height = tile_sample_min_height - blend_extent_height + row_limit_width = tile_sample_min_width - blend_extent_width + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. 
+    pbar = tqdm(
+        desc="Decoding latent tiles",
+        total=len(range(0, H, overlap_height)) * len(range(0, W, overlap_width)) * len(range(T // frame_batch_size)),
+    )
+    rows = []
+    for i in range(0, H, overlap_height):
+        row = []
+        for j in range(0, W, overlap_width):
+            temporal = []
+            for k in range(T // frame_batch_size):
+                remaining_frames = T % frame_batch_size
+                start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
+                end_frame = frame_batch_size * (k + 1) + remaining_frames
+                tile = z[
+                    :,
+                    :,
+                    start_frame:end_frame,
+                    i : i + tile_latent_min_height,
+                    j : j + tile_latent_min_width,
+                ]
+                tile = decoder(tile)
+                temporal.append(tile)
+                pbar.update(1)
+            row.append(torch.cat(temporal, dim=2))
+        rows.append(row)
+
+    result_rows = []
+    for i, row in enumerate(rows):
+        result_row = []
+        for j, tile in enumerate(row):
+            # blend the above tile and the left tile
+            # to the current tile and add the current tile to the result row
+            if i > 0:
+                tile = blend_v(rows[i - 1][j], tile, blend_extent_height)
+            if j > 0:
+                tile = blend_h(row[j - 1], tile, blend_extent_width)
+            result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
+        result_rows.append(torch.cat(result_row, dim=4))
+
+    return normalize_decoded_frames(torch.cat(result_rows, dim=3))
+
+
+@torch.inference_mode()
+def decode_latents_tiled_spatial(
+    decoder,
+    z,
+    *,
+    num_tiles_w: int,
+    num_tiles_h: int,
+    overlap: int = 0, # Number of pixels of overlap between adjacent tiles.
+    # Use a factor of 2 times the latent downsample factor.
+    min_block_size: int = 1, # Minimum number of pixels in each dimension when subdividing.
+):
+    decoded = apply_tiled(decoder, z, num_tiles_w, num_tiles_h, overlap, min_block_size)
+    assert decoded is not None, f"Failed to decode latents with tiled spatial method"
+    return normalize_decoded_frames(decoded)
\ No newline at end of file
diff --git a/mochi_preview/vae/vae_stats.py b/mochi_preview/vae/vae_stats.py
new file mode 100644
index 0000000..276db7b
--- /dev/null
+++ b/mochi_preview/vae/vae_stats.py
@@ -0,0 +1,62 @@
+import torch
+
+# Channel-wise mean and standard deviation of VAE encoder latents
+STATS = {
+    "mean": torch.Tensor([
+        -0.06730895953510081,
+        -0.038011381506090416,
+        -0.07477820912866141,
+        -0.05565264470995561,
+        0.012767231469026969,
+        -0.04703542746246419,
+        0.043896967884726704,
+        -0.09346305707025976,
+        -0.09918314763016893,
+        -0.008729793427399178,
+        -0.011931556316503654,
+        -0.0321993391887285,
+    ]),
+    "std": torch.Tensor([
+        0.9263795028493863,
+        0.9248894543193766,
+        0.9393059390890617,
+        0.959253732819592,
+        0.8244560132752793,
+        0.917259975397747,
+        0.9294154431013696,
+        1.3720942357788521,
+        0.881393668867029,
+        0.9168315692124348,
+        0.9185249279345552,
+        0.9274757570805041,
+    ]),
+}
+
+def dit_latents_to_vae_latents(dit_outputs: torch.Tensor) -> torch.Tensor:
+    """Unnormalize latents output by Mochi's DiT to be compatible with VAE.
+    Run this on sampled latents before calling the VAE decoder.
+
+    Args:
+        latents (torch.Tensor): [B, C_z, T_z, H_z, W_z], float
+
+    Returns:
+        torch.Tensor: [B, C_z, T_z, H_z, W_z], float
+    """
+    mean = STATS["mean"][:, None, None, None]
+    std = STATS["std"][:, None, None, None]
+
+    assert dit_outputs.ndim == 5
+    assert dit_outputs.size(1) == mean.size(0) == std.size(0)
+    return dit_outputs * std.to(dit_outputs) + mean.to(dit_outputs)
+
+
+def vae_latents_to_dit_latents(vae_latents: torch.Tensor):
+    """Normalize latents output by the VAE encoder to be compatible with Mochi's DiT.
+    E.g., for fine-tuning or video-to-video.
+    """
+    mean = STATS["mean"][:, None, None, None]
+    std = STATS["std"][:, None, None, None]
+
+    assert vae_latents.ndim == 5
+    assert vae_latents.size(1) == mean.size(0) == std.size(0)
+    return (vae_latents - mean.to(vae_latents)) / std.to(vae_latents)
diff --git a/nodes.py b/nodes.py
index ca971a8..8188bd8 100644
--- a/nodes.py
+++ b/nodes.py
@@ -13,7 +13,8 @@ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(
 log = logging.getLogger(__name__)
 
 from .mochi_preview.t2v_synth_mochi import T2VSynthMochiModel
-from .mochi_preview.vae.model import Decoder
+from .mochi_preview.vae.model import Decoder, Encoder, add_fourier_features
+from .mochi_preview.vae.vae_stats import vae_latents_to_dit_latents, dit_latents_to_vae_latents
 from contextlib import nullcontext
 
 try:
@@ -140,7 +141,6 @@ class DownloadAndLoadMochiModel:
                 num_res_blocks=[3, 3, 4, 6, 3],
                 latent_dim=12,
                 has_attention=[False, False, False, False, False],
-                padding_mode="replicate",
                 output_norm=False,
                 nonlinearity="silu",
                 output_nonlinearity="silu",
@@ -269,7 +269,6 @@ class MochiVAELoader:
                 num_res_blocks=[3, 3, 4, 6, 3],
                 latent_dim=12,
                 has_attention=[False, False, False, False, False],
-                padding_mode="replicate",
                 output_norm=False,
                 nonlinearity="silu",
                 output_nonlinearity="silu",
@@ -292,6 +291,74 @@ class MochiVAELoader:
 
         return (vae,)
 
+class MochiVAEEncoderLoader:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "model_name": (folder_paths.get_filename_list("vae"), {"tooltip": "The name of the checkpoint (vae) to load."}),
+            },
+            "optional": {
+                "torch_compile_args": ("MOCHICOMPILEARGS", {"tooltip": "Optional torch.compile arguments",}),
+                "precision": (["fp16", "fp32", "bf16"], {"default": "bf16"}),
+            },
+        }
+
+    RETURN_TYPES = ("MOCHIVAE",)
+    RETURN_NAMES = ("mochi_vae", )
+    FUNCTION = "loadmodel"
+    CATEGORY = "MochiWrapper"
+
+    def loadmodel(self, model_name, torch_compile_args=None, precision="bf16"):
+
+        device = mm.get_torch_device()
+        offload_device = mm.unet_offload_device()
+        mm.soft_empty_cache()
+
+        dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
+
+        config = dict(
+            prune_bottlenecks=[False, False, False, False, False],
+            has_attentions=[False, True, True, True, True],
+            affine=True,
+            bias=True,
+            input_is_conv_1x1=True,
+            padding_mode="replicate"
+        )
+
+        vae_path = folder_paths.get_full_path_or_raise("vae", model_name)
+
+
+        # Create VAE encoder
+        with (init_empty_weights() if is_accelerate_available else nullcontext()):
+            encoder = Encoder(
+                in_channels=15,
+                base_channels=64,
+                channel_multipliers=[1, 2, 4, 6],
+                num_res_blocks=[3, 3, 4, 6, 3],
+                latent_dim=12,
+                temporal_reductions=[1, 2, 3],
+                spatial_reductions=[2, 2, 2],
+                dtype = dtype,
+                **config,
+            )
+
+        encoder_sd = load_torch_file(vae_path)
+        if is_accelerate_available:
+            for name, param in encoder.named_parameters():
+                set_module_tensor_to_device(encoder, name, dtype=dtype, device=offload_device, value=encoder_sd[name])
+        else:
+            encoder.load_state_dict(encoder_sd, strict=True)
+            encoder.to(dtype).to(offload_device)
+        encoder.eval()
+        del encoder_sd
+
+        if torch_compile_args is not None:
+            encoder.to(device)
+            encoder = torch.compile(encoder, fullgraph=torch_compile_args["fullgraph"], mode=torch_compile_args["mode"], dynamic=False, backend=torch_compile_args["backend"])
+
+        return (encoder,)
+
 class MochiTextEncode:
     @classmethod
     def INPUT_TYPES(s):
@@ -365,6 +432,7 @@ class MochiSampler:
             "optional": {
                 "cfg_schedule": ("FLOAT",
{"forceInput": True, "tooltip": "Override cfg schedule with a list of ints"}), "opt_sigmas": ("SIGMAS", {"tooltip": "Override sigma schedule and steps"}), + "samples": ("LATENT", ), } } @@ -373,7 +441,7 @@ class MochiSampler: FUNCTION = "process" CATEGORY = "MochiWrapper" - def process(self, model, positive, negative, steps, cfg, seed, height, width, num_frames, cfg_schedule=None, opt_sigmas=None): + def process(self, model, positive, negative, steps, cfg, seed, height, width, num_frames, cfg_schedule=None, opt_sigmas=None, samples=None): mm.soft_empty_cache() if opt_sigmas is not None: @@ -413,6 +481,7 @@ class MochiSampler: "positive_embeds": positive, "negative_embeds": negative, "seed": seed, + "samples": samples["samples"] if samples is not None else None, } latents = model.run(args) @@ -447,6 +516,7 @@ class MochiDecode: offload_device = mm.unet_offload_device() intermediate_device = mm.intermediate_device() samples = samples["samples"] + samples = dit_latents_to_vae_latents(samples) samples = samples.to(vae.dtype).to(device) B, C, T, H, W = samples.shape @@ -574,6 +644,7 @@ class MochiDecodeSpatialTiling: offload_device = mm.unet_offload_device() intermediate_device = mm.intermediate_device() samples = samples["samples"] + samples = dit_latents_to_vae_latents(samples) samples = samples.to(vae.dtype).to(device) B, C, T, H, W = samples.shape @@ -615,7 +686,102 @@ class MochiDecodeSpatialTiling: frames = rearrange(frames, "b c t h w -> (t b) h w c").to(intermediate_device) return (frames,) + +class MochiImageEncode: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "encoder": ("MOCHIVAE",), + "images": ("IMAGE", ), + }, + } + RETURN_TYPES = ("LATENT",) + RETURN_NAMES = ("samples",) + FUNCTION = "decode" + CATEGORY = "MochiWrapper" + + def decode(self, encoder, images): + device = mm.get_torch_device() + offload_device = mm.unet_offload_device() + intermediate_device = mm.intermediate_device() + + B, H, W, C = images.shape + + images = images.unsqueeze(0) * 2 - 1 + images = rearrange(images, "t b h w c -> t c b h w") + images = images.to(encoder.dtype).to(device) + print(images.shape) + + encoder.to(device) + print("images before encoding", images.shape) + with torch.autocast(mm.get_autocast_device(device), dtype=encoder.dtype): + video = add_fourier_features(images) + latents = encoder(video).sample() + latents = vae_latents_to_dit_latents(latents) + print("encoder output",latents.shape) + + return ({"samples": latents},) + +class MochiLatentPreview: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "samples": ("LATENT",), + # "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), + # "min_val": ("FLOAT", {"default": -0.15, "min": -1.0, "max": 0.0, "step": 0.001}), + # "max_val": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001}), + }, + } + + RETURN_TYPES = ("IMAGE", ) + RETURN_NAMES = ("images", ) + FUNCTION = "sample" + CATEGORY = "PyramidFlowWrapper" + + def sample(self, samples):#, seed, min_val, max_val): + mm.soft_empty_cache() + + latents = samples["samples"].clone() + print("in sample", latents.shape) + + device = mm.get_torch_device() + offload_device = mm.unet_offload_device() + + latent_rgb_factors = [[0.1236769792512748, 0.11775175335219157, -0.17700629766423637], [-0.08504104329270078, 0.026605813147523694, -0.006843165704926019], [-0.17093308616366876, 0.027991854696200386, 0.14179146288816308], [-0.17179555328757623, 0.09844317368603078, 0.14470997015982784], [-0.16975067171668484, -0.10739852629856643, 
-0.1894254942909962], [-0.19315259266769888, -0.011029760569485209, -0.08519702054654255], [-0.08399895091432583, -0.0964246452052032, -0.033622359523655665], [0.08148916330842498, 0.027500645903400067, -0.06593099749891196], [0.0456603103902293, -0.17844808072462398, 0.04204775167149785], [0.001751626383204502, -0.030567890189647867, -0.022078082809772193], [0.05110631095056278, -0.0709677393548804, 0.08963683539504264], [0.010515800868829, -0.18382052841762514, -0.08554553339721907]] + + # import random + # random.seed(seed) + # latent_rgb_factors = [[random.uniform(min_val, max_val) for _ in range(3)] for _ in range(12)] + # out_factors = latent_rgb_factors + # print(latent_rgb_factors) + + + latent_rgb_factors_bias = [0,0,0] + + latent_rgb_factors = torch.tensor(latent_rgb_factors, device=latents.device, dtype=latents.dtype).transpose(0, 1) + latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device=latents.device, dtype=latents.dtype) + + print("latent_rgb_factors", latent_rgb_factors.shape) + + latent_images = [] + for t in range(latents.shape[2]): + latent = latents[:, :, t, :, :] + latent = latent[0].permute(1, 2, 0) + latent_image = torch.nn.functional.linear( + latent, + latent_rgb_factors, + bias=latent_rgb_factors_bias + ) + latent_images.append(latent_image) + latent_images = torch.stack(latent_images, dim=0) + print("latent_images", latent_images.shape) + latent_images_min = latent_images.min() + latent_images_max = latent_images.max() + latent_images = (latent_images - latent_images_min) / (latent_images_max - latent_images_min) + + return (latent_images.float().cpu(),) NODE_CLASS_MAPPINGS = { "DownloadAndLoadMochiModel": DownloadAndLoadMochiModel, @@ -624,8 +790,11 @@ NODE_CLASS_MAPPINGS = { "MochiTextEncode": MochiTextEncode, "MochiModelLoader": MochiModelLoader, "MochiVAELoader": MochiVAELoader, + "MochiVAEEncoderLoader": MochiVAEEncoderLoader, "MochiDecodeSpatialTiling": MochiDecodeSpatialTiling, - "MochiTorchCompileSettings": MochiTorchCompileSettings + "MochiTorchCompileSettings": MochiTorchCompileSettings, + "MochiImageEncode": MochiImageEncode, + "MochiLatentPreview": MochiLatentPreview } NODE_DISPLAY_NAME_MAPPINGS = { "DownloadAndLoadMochiModel": "(Down)load Mochi Model", @@ -633,7 +802,10 @@ NODE_DISPLAY_NAME_MAPPINGS = { "MochiDecode": "Mochi Decode", "MochiTextEncode": "Mochi TextEncode", "MochiModelLoader": "Mochi Model Loader", - "MochiVAELoader": "Mochi VAE Loader", + "MochiVAELoader": "Mochi VAE Decoder Loader", + "MochiVAEEncoderLoader": "Mochi VAE Encoder Loader", "MochiDecodeSpatialTiling": "Mochi VAE Decode Spatial Tiling", - "MochiTorchCompileSettings": "Mochi Torch Compile Settings" + "MochiTorchCompileSettings": "Mochi Torch Compile Settings", + "MochiImageEncode": "Mochi Image Encode", + "MochiLatentPreview": "Mochi Latent Preview" }
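The sketches below are review notes appended after the patch; they are illustrative only and not part of the diff. This first one exercises the new LatentDistribution container from mochi_preview/vae/latent_dist.py: sample() is a plain reparameterized Gaussian draw, mean + exp(0.5 * logvar) * eps, and mode() (or temperature=0.0) simply returns the mean. Tensor sizes are toy values, and the import path assumes the repository root is on sys.path.

import torch
from mochi_preview.vae.latent_dist import LatentDistribution

mean = torch.zeros(1, 12, 3, 4, 4)     # [B, C, T, H, W], toy sizes
logvar = torch.full_like(mean, -2.0)   # log-variance, so std = exp(-1.0)

dist = LatentDistribution(mean, logvar)
eps = torch.randn_like(mean)
drawn = dist.sample(noise=eps)         # equals eps * exp(0.5 * logvar) + mean
assert torch.allclose(drawn, eps * torch.exp(0.5 * logvar) + mean)
assert torch.equal(dist.mode(), mean)  # mode() and temperature=0.0 both return the mean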
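A second sketch of how the latent-space conversions in mochi_preview/vae/vae_stats.py and the sampler's new "samples" input fit together: MochiImageEncode normalizes encoder latents with vae_latents_to_dit_latents, the patched sampler mixes them with noise as z * sigma_schedule[0] + samples * sigma_schedule[-1], and the decode nodes map the sampled result back with dit_latents_to_vae_latents before the VAE decoder. The sigma values and shapes below are hypothetical stand-ins.

import torch
from mochi_preview.vae.vae_stats import (
    dit_latents_to_vae_latents,
    vae_latents_to_dit_latents,
)

# The two helpers are inverses: per-channel (x - mean) / std and x * std + mean.
vae_latents = torch.randn(1, 12, 4, 30, 53)            # [B, C_z, T_z, H_z, W_z]
dit_latents = vae_latents_to_dit_latents(vae_latents)
assert torch.allclose(dit_latents_to_vae_latents(dit_latents), vae_latents, atol=1e-5)

# Re-noising as done in t2v_synth_mochi.py when the optional "samples" input is connected.
sigma_schedule = [1.0, 0.75, 0.5, 0.25]                # hypothetical truncated schedule
z = torch.randn_like(dit_latents)
z = z * sigma_schedule[0] + dit_latents * sigma_schedule[-1]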
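Finally, a 1-D sketch of the linear crossfade that blend_h/blend_v in decode_latents_tiled_full apply across the overlap between adjacent decoded tiles; the helper here is a simplified stand-in for those functions, not code from the patch.

import torch

def crossfade(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    # Fade the tail of `a` into the head of `b` over `blend_extent` samples.
    blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
    b = b.clone()
    for x in range(blend_extent):
        w = x / blend_extent
        b[..., x] = a[..., -blend_extent + x] * (1 - w) + b[..., x] * w
    return b

left = torch.ones(8)    # tile that ends inside the overlap region
right = torch.zeros(8)  # neighbouring tile that starts inside the overlap region
print(crossfade(left, right, 4))  # 1.0, 0.75, 0.5, 0.25, then the untouched zeros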