Add VAE encoder

commit ac5de728ad (parent ebd0f62d53)
mochi_preview/t2v_synth_mochi.py

@@ -233,6 +233,7 @@ class T2VSynthMochiModel:
         num_frames = args["num_frames"]
         height = args["height"]
         width = args["width"]
+        in_samples = args["samples"]

         sample_steps = args["mochi_args"]["num_inference_steps"]
         cfg_schedule = args["mochi_args"].get("cfg_schedule")
@@ -263,6 +264,8 @@ class T2VSynthMochiModel:
             generator=generator,
             dtype=torch.float32,
         )
+        if in_samples is not None:
+            z = z * sigma_schedule[0] + in_samples.to(self.device) * sigma_schedule[-1]

         sample = {
             "y_mask": [args["positive_embeds"]["attention_mask"].to(self.device)],
@@ -314,6 +317,6 @@ class T2VSynthMochiModel:
                 comfy_pbar.update(1)

         self.dit.to(self.offload_device)
-        samples = unnormalize_latents(z.float(), self.vae_mean, self.vae_std)
-        logging.info(f"samples shape: {samples.shape}")
-        return samples
+        #samples = unnormalize_latents(z.float(), self.vae_mean, self.vae_std)
+        logging.info(f"samples shape: {z.shape}")
+        return z
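Note (not part of the diff): with this change the sampler accepts an optional input latent ("samples") and returns the raw DiT-space latents; the un-normalization that used to happen here is now done by the decode nodes via dit_latents_to_vae_latents (see vae_stats.py below). A minimal sketch of the new initialization, under the assumption that sigma_schedule is Mochi's linear schedule running from 1.0 down to 0.0:

    # Hedged sketch, not repository code.
    import torch

    num_steps = 30
    sigma_schedule = torch.linspace(1.0, 0.0, num_steps + 1)
    in_samples = torch.randn(1, 12, 28, 60, 106)   # e.g. latents from the new MochiImageEncode node
    z = torch.randn(1, 12, 28, 60, 106)            # fresh noise in DiT latent space
    # The input latents are weighted by the last sigma of the schedule; with the full default
    # schedule that weight is 0, so this path looks intended for use together with a truncated
    # sigma schedule supplied via opt_sigmas.
    z = z * sigma_schedule[0] + in_samples * sigma_schedule[-1]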
mochi_preview/vae/latent_dist.py (new file, 35 lines)

@@ -0,0 +1,35 @@
+"""Container for latent space posterior."""
+
+import torch
+
+
+class LatentDistribution:
+    def __init__(self, mean: torch.Tensor, logvar: torch.Tensor):
+        """Initialize latent distribution.
+
+        Args:
+            mean: Mean of the distribution. Shape: [B, C, T, H, W].
+            logvar: Logarithm of variance of the distribution. Shape: [B, C, T, H, W].
+        """
+        assert mean.shape == logvar.shape
+        self.mean = mean
+        self.logvar = logvar
+
+    def sample(self, temperature=1.0, generator: torch.Generator = None, noise=None):
+        if temperature == 0.0:
+            return self.mean
+
+        if noise is None:
+            noise = torch.randn(self.mean.shape, device=self.mean.device, dtype=self.mean.dtype, generator=generator)
+        else:
+            assert noise.device == self.mean.device
+            noise = noise.to(self.mean.dtype)
+
+        if temperature != 1.0:
+            raise NotImplementedError(f"Temperature {temperature} is not supported.")
+
+        # Just Gaussian sample with no scaling of variance.
+        return noise * torch.exp(self.logvar * 0.5) + self.mean
+
+    def mode(self):
+        return self.mean
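Note (not part of the diff): LatentDistribution is the usual diagonal-Gaussian VAE posterior; sample() is the reparameterization trick and mode() returns the mean. A small usage sketch:

    # Hedged sketch, not repository code; the import path assumes the repo root is on sys.path.
    import torch
    from mochi_preview.vae.latent_dist import LatentDistribution

    mean = torch.zeros(1, 12, 5, 60, 106)
    logvar = torch.zeros_like(mean)                               # log-variance 0 -> unit variance
    dist = LatentDistribution(mean, logvar)
    z = dist.sample(generator=torch.Generator().manual_seed(0))   # mean + exp(logvar / 2) * noise
    z_det = dist.sample(temperature=0.0)                          # deterministic: just the mean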
mochi_preview/vae/model.py

@@ -1,5 +1,6 @@
 from typing import Callable, List, Optional, Tuple, Union
+from functools import partial
+import math
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -7,7 +8,7 @@ from einops import rearrange

 #from ..dit.joint_model.context_parallel import get_cp_rank_size
 #from ..vae.cp_conv import cp_pass_frames, gather_all_frames
+from .latent_dist import LatentDistribution

 def cast_tuple(t, length=1):
     return t if isinstance(t, tuple) else ((t,) * length)
@@ -135,9 +136,13 @@ class ContextParallelConv3d(SafeConv3d):
         pad_back = context_size - pad_front

         # Apply padding.
-        assert self.padding_mode == "replicate"  # DEBUG
         mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
-        x = F.pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode)
+        if self.context_parallel:
+            x = F.pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode)
+        else:
+            x = F.pad(x, (0, 0, 0, 0, pad_front, 0), mode=mode)

         return super().forward(x)
@@ -221,8 +226,10 @@ class ResBlock(nn.Module):
         *,
         affine: bool = True,
         attn_block: Optional[nn.Module] = None,
-        padding_mode: str = "replicate",
         causal: bool = True,
+        prune_bottleneck: bool = False,
+        padding_mode: str,
+        bias: bool = True,
     ):
         super().__init__()
         self.channels = channels
@@ -233,22 +240,22 @@ class ResBlock(nn.Module):
             nn.SiLU(inplace=True),
             ContextParallelConv3d(
                 in_channels=channels,
-                out_channels=channels,
+                out_channels=channels // 2 if prune_bottleneck else channels,
                 kernel_size=(3, 3, 3),
                 stride=(1, 1, 1),
                 padding_mode=padding_mode,
-                bias=True,
+                bias=bias,
                 causal=causal,
             ),
             norm_fn(channels, affine=affine),
             nn.SiLU(inplace=True),
             ContextParallelConv3d(
-                in_channels=channels,
+                in_channels=channels // 2 if prune_bottleneck else channels,
                 out_channels=channels,
                 kernel_size=(3, 3, 3),
                 stride=(1, 1, 1),
                 padding_mode=padding_mode,
-                bias=True,
+                bias=bias,
                 causal=causal,
             ),
         )
@@ -357,9 +364,7 @@ class Attention(nn.Module):
         )

         if q.size(0) <= chunk_size:
-            x = F.scaled_dot_product_attention(
-                q, k, v, **attn_kwargs
-            )  # [B, num_heads, t, head_dim]
+            x = F.scaled_dot_product_attention(q, k, v, **attn_kwargs)  # [B, num_heads, t, head_dim]
         else:
             # Evaluate in chunks to avoid `RuntimeError: CUDA error: invalid configuration argument.`
             # Chunks of 2**16 and up cause an error.
@@ -421,9 +426,7 @@ class CausalUpsampleBlock(nn.Module):
             out_channels * temporal_expansion * (spatial_expansion**2),
         )

-        self.d2st = DepthToSpaceTime(
-            temporal_expansion=temporal_expansion, spatial_expansion=spatial_expansion
-        )
+        self.d2st = DepthToSpaceTime(temporal_expansion=temporal_expansion, spatial_expansion=spatial_expansion)

     def forward(self, x):
         x = self.blocks(x)
@@ -432,60 +435,9 @@ class CausalUpsampleBlock(nn.Module):
         return x


-def block_fn(channels, *, has_attention: bool = False, **block_kwargs):
-    #attn_block = AttentionBlock(channels) if has_attention else None
-    return ResBlock(
-        channels, affine=True, attn_block=None, **block_kwargs
-    )
+def block_fn(channels, *, affine: bool = True, has_attention: bool = False, **block_kwargs):
+    attn_block = AttentionBlock(channels) if has_attention else None
+    return ResBlock(channels, affine=affine, attn_block=attn_block, **block_kwargs)


-class DownsampleBlock(nn.Module):
-    def __init__(
-        self,
-        in_channels: int,
-        out_channels: int,
-        num_res_blocks,
-        *,
-        temporal_reduction=2,
-        spatial_reduction=2,
-        **block_kwargs,
-    ):
-        """
-        Downsample block for the VAE encoder.
-
-        Args:
-            in_channels: Number of input channels.
-            out_channels: Number of output channels.
-            num_res_blocks: Number of residual blocks.
-            temporal_reduction: Temporal reduction factor.
-            spatial_reduction: Spatial reduction factor.
-        """
-        super().__init__()
-        layers = []
-
-        # Change the channel count in the strided convolution.
-        # This lets the ResBlock have uniform channel count,
-        # as in ConvNeXt.
-        assert in_channels != out_channels
-        layers.append(
-            ContextParallelConv3d(
-                in_channels=in_channels,
-                out_channels=out_channels,
-                kernel_size=(temporal_reduction, spatial_reduction, spatial_reduction),
-                stride=(temporal_reduction, spatial_reduction, spatial_reduction),
-                padding_mode="replicate",
-                bias=True,
-            )
-        )
-
-        for _ in range(num_res_blocks):
-            layers.append(block_fn(out_channels, **block_kwargs))
-
-        self.layers = nn.Sequential(*layers)
-
-    def forward(self, x):
-        return self.layers(x)
-
-
 def add_fourier_features(inputs: torch.Tensor, start=6, stop=8, step=1):
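Note (not part of the diff): block_fn previously ignored its has_attention flag (attn_block was always None); it now actually builds an AttentionBlock and forwards affine, which matters because the encoder configuration below enables attention in four of its five stages. A minimal sketch of the new call shape, with illustrative argument values:

    # Hedged sketch, not repository code; the import path assumes the repo root is on sys.path.
    from mochi_preview.vae.model import block_fn

    res_block = block_fn(
        256,
        affine=True,
        has_attention=True,          # now instantiates AttentionBlock(256)
        padding_mode="replicate",    # required: padding_mode no longer has a default
        bias=True,
        causal=False,
        prune_bottleneck=False,
    )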
@@ -568,14 +520,13 @@ class Decoder(nn.Module):
         assert len(num_res_blocks) == self.num_up_blocks + 2

         blocks = []
+        new_block_fn = partial(block_fn, padding_mode="replicate")

-        first_block = [
-            nn.Conv3d(latent_dim, ch[-1], kernel_size=(1, 1, 1))
-        ]  # Input layer.
+        first_block = [nn.Conv3d(latent_dim, ch[-1], kernel_size=(1, 1, 1))]  # Input layer.
         # First set of blocks preserve channel count.
         for _ in range(num_res_blocks[-1]):
             first_block.append(
-                block_fn(
+                new_block_fn(
                     ch[-1],
                     has_attention=has_attention[-1],
                     causal=causal,
@@ -598,6 +549,7 @@ class Decoder(nn.Module):
                 temporal_expansion=temporal_expansions[-i - 1],
                 spatial_expansion=spatial_expansions[-i - 1],
                 causal=causal,
+                padding_mode="replicate",
                 **block_kwargs,
             )
             blocks.append(block)
@@ -607,11 +559,7 @@ class Decoder(nn.Module):
         # Last block. Preserve channel count.
         last_block = []
         for _ in range(num_res_blocks[0]):
-            last_block.append(
-                block_fn(
-                    ch[0], has_attention=has_attention[0], causal=causal, **block_kwargs
-                )
-            )
+            last_block.append(new_block_fn(ch[0], has_attention=has_attention[0], causal=causal, **block_kwargs))
         blocks.append(nn.Sequential(*last_block))

         self.blocks = nn.ModuleList(blocks)
@@ -634,9 +582,7 @@ class Decoder(nn.Module):
         if self.output_nonlinearity == "silu":
             x = F.silu(x, inplace=not self.training)
         else:
-            assert (
-                not self.output_nonlinearity
-            )  # StyleGAN3 omits the to-RGB nonlinearity.
+            assert not self.output_nonlinearity  # StyleGAN3 omits the to-RGB nonlinearity.

         return self.output_proj(x).contiguous()

@@ -678,9 +624,7 @@ def blend(a: torch.Tensor, b: torch.Tensor, axis: int) -> torch.Tensor:
     Returns:
         torch.Tensor: The blended tensor.
     """
-    assert (
-        a.shape == b.shape
-    ), f"Tensors must have the same shape, got {a.shape} and {b.shape}"
+    assert a.shape == b.shape, f"Tensors must have the same shape, got {a.shape} and {b.shape}"
     steps = a.size(axis)

     # Create a weight tensor that linearly interpolates from 0 to 1
@@ -800,3 +744,270 @@ def apply_tiled(
             return blend_vertical(top, bottom, out_overlap)

     raise ValueError(f"Invalid num_tiles_w={num_tiles_w} and num_tiles_h={num_tiles_h}")
+
+
+class DownsampleBlock(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        num_res_blocks,
+        *,
+        temporal_reduction=2,
+        spatial_reduction=2,
+        **block_kwargs,
+    ):
+        """
+        Downsample block for the VAE encoder.
+
+        Args:
+            in_channels: Number of input channels.
+            out_channels: Number of output channels.
+            num_res_blocks: Number of residual blocks.
+            temporal_reduction: Temporal reduction factor.
+            spatial_reduction: Spatial reduction factor.
+        """
+        super().__init__()
+        layers = []
+
+        assert in_channels != out_channels
+        layers.append(
+            ContextParallelConv3d(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=(temporal_reduction, spatial_reduction, spatial_reduction),
+                stride=(temporal_reduction, spatial_reduction, spatial_reduction),
+                # First layer in each block always uses replicate padding
+                padding_mode="replicate",
+                bias=block_kwargs["bias"],
+            )
+        )
+
+        for _ in range(num_res_blocks):
+            layers.append(block_fn(out_channels, **block_kwargs))
+
+        self.layers = nn.Sequential(*layers)
+
+    def forward(self, x):
+        return self.layers(x)
+
+
+class Encoder(nn.Module):
+    def __init__(
+        self,
+        *,
+        in_channels: int,
+        base_channels: int,
+        channel_multipliers: List[int],
+        num_res_blocks: List[int],
+        latent_dim: int,
+        temporal_reductions: List[int],
+        spatial_reductions: List[int],
+        prune_bottlenecks: List[bool],
+        has_attentions: List[bool],
+        affine: bool = True,
+        bias: bool = True,
+        input_is_conv_1x1: bool = False,
+        padding_mode: str,
+        dtype: torch.dtype = torch.float32,
+    ):
+        super().__init__()
+        self.temporal_reductions = temporal_reductions
+        self.spatial_reductions = spatial_reductions
+        self.base_channels = base_channels
+        self.channel_multipliers = channel_multipliers
+        self.num_res_blocks = num_res_blocks
+        self.latent_dim = latent_dim
+        self.dtype = dtype
+
+        ch = [mult * base_channels for mult in channel_multipliers]
+        num_down_blocks = len(ch) - 1
+        assert len(num_res_blocks) == num_down_blocks + 2
+
+        layers = (
+            [nn.Conv3d(in_channels, ch[0], kernel_size=(1, 1, 1), bias=True)]
+            if not input_is_conv_1x1
+            else [Conv1x1(in_channels, ch[0])]
+        )
+
+        assert len(prune_bottlenecks) == num_down_blocks + 2
+        assert len(has_attentions) == num_down_blocks + 2
+        block = partial(block_fn, padding_mode=padding_mode, affine=affine, bias=bias)
+
+        for _ in range(num_res_blocks[0]):
+            layers.append(block(ch[0], has_attention=has_attentions[0], prune_bottleneck=prune_bottlenecks[0]))
+        prune_bottlenecks = prune_bottlenecks[1:]
+        has_attentions = has_attentions[1:]
+
+        assert len(temporal_reductions) == len(spatial_reductions) == len(ch) - 1
+        for i in range(num_down_blocks):
+            layer = DownsampleBlock(
+                ch[i],
+                ch[i + 1],
+                num_res_blocks=num_res_blocks[i + 1],
+                temporal_reduction=temporal_reductions[i],
+                spatial_reduction=spatial_reductions[i],
+                prune_bottleneck=prune_bottlenecks[i],
+                has_attention=has_attentions[i],
+                affine=affine,
+                bias=bias,
+                padding_mode=padding_mode,
+            )
+
+            layers.append(layer)
+
+        # Additional blocks.
+        for _ in range(num_res_blocks[-1]):
+            layers.append(block(ch[-1], has_attention=has_attentions[-1], prune_bottleneck=prune_bottlenecks[-1]))
+
+        self.layers = nn.Sequential(*layers)
+
+        # Output layers.
+        self.output_norm = norm_fn(ch[-1])
+        self.output_proj = Conv1x1(ch[-1], 2 * latent_dim, bias=False)
+
+    @property
+    def temporal_downsample(self):
+        return math.prod(self.temporal_reductions)
+
+    @property
+    def spatial_downsample(self):
+        return math.prod(self.spatial_reductions)
+
+    def forward(self, x) -> LatentDistribution:
+        """Forward pass.
+
+        Args:
+            x: Input video tensor. Shape: [B, C, T, H, W]. Scaled to [-1, 1]
+
+        Returns:
+            means: Latent tensor. Shape: [B, latent_dim, t, h, w]. Scaled [-1, 1].
+                   h = H // 8, w = W // 8, t - 1 = (T - 1) // 6
+            logvar: Shape: [B, latent_dim, t, h, w].
+        """
+        assert x.ndim == 5, f"Expected 5D input, got {x.shape}"
+
+        x = self.layers(x)
+
+        x = self.output_norm(x)
+        x = F.silu(x, inplace=True)
+        x = self.output_proj(x)
+
+        means, logvar = torch.chunk(x, 2, dim=1)
+
+        assert means.ndim == 5
+        assert logvar.shape == means.shape
+        assert means.size(1) == self.latent_dim
+
+        return LatentDistribution(means, logvar)
+
+
+def normalize_decoded_frames(samples):
+    samples = samples.float()
+    samples = (samples + 1.0) / 2.0
+    samples.clamp_(0.0, 1.0)
+    frames = rearrange(samples, "b c t h w -> b t h w c")
+    return frames
+
+
+@torch.inference_mode()
+def decode_latents_tiled_full(
+    decoder,
+    z,
+    *,
+    tile_sample_min_height: int = 240,
+    tile_sample_min_width: int = 424,
+    tile_overlap_factor_height: float = 0.1666,
+    tile_overlap_factor_width: float = 0.2,
+    auto_tile_size: bool = True,
+    frame_batch_size: int = 6,
+):
+    B, C, T, H, W = z.shape
+    assert frame_batch_size <= T, f"frame_batch_size must be <= T, got {frame_batch_size} > {T}"
+
+    tile_sample_min_height = tile_sample_min_height if not auto_tile_size else H // 2 * 8
+    tile_sample_min_width = tile_sample_min_width if not auto_tile_size else W // 2 * 8
+
+    tile_latent_min_height = int(tile_sample_min_height / 8)
+    tile_latent_min_width = int(tile_sample_min_width / 8)
+
+    def blend_v(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[3], b.shape[3], blend_extent)
+        for y in range(blend_extent):
+            b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
+                y / blend_extent
+            )
+        return b
+
+    def blend_h(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[4], b.shape[4], blend_extent)
+        for x in range(blend_extent):
+            b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
+                x / blend_extent
+            )
+        return b
+
+    overlap_height = int(tile_latent_min_height * (1 - tile_overlap_factor_height))
+    overlap_width = int(tile_latent_min_width * (1 - tile_overlap_factor_width))
+    blend_extent_height = int(tile_sample_min_height * tile_overlap_factor_height)
+    blend_extent_width = int(tile_sample_min_width * tile_overlap_factor_width)
+    row_limit_height = tile_sample_min_height - blend_extent_height
+    row_limit_width = tile_sample_min_width - blend_extent_width
+
+    # Split z into overlapping tiles and decode them separately.
+    # The tiles have an overlap to avoid seams between tiles.
+    pbar = tqdm(
+        desc="Decoding latent tiles",
+        total=len(range(0, H, overlap_height)) * len(range(0, W, overlap_width)) * len(range(T // frame_batch_size)),
+    )
+    rows = []
+    for i in range(0, H, overlap_height):
+        row = []
+        for j in range(0, W, overlap_width):
+            temporal = []
+            for k in range(T // frame_batch_size):
+                remaining_frames = T % frame_batch_size
+                start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
+                end_frame = frame_batch_size * (k + 1) + remaining_frames
+                tile = z[
+                    :,
+                    :,
+                    start_frame:end_frame,
+                    i : i + tile_latent_min_height,
+                    j : j + tile_latent_min_width,
+                ]
+                tile = decoder(tile)
+                temporal.append(tile)
+                pbar.update(1)
+            row.append(torch.cat(temporal, dim=2))
+        rows.append(row)
+
+    result_rows = []
+    for i, row in enumerate(rows):
+        result_row = []
+        for j, tile in enumerate(row):
+            # blend the above tile and the left tile
+            # to the current tile and add the current tile to the result row
+            if i > 0:
+                tile = blend_v(rows[i - 1][j], tile, blend_extent_height)
+            if j > 0:
+                tile = blend_h(row[j - 1], tile, blend_extent_width)
+            result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
+        result_rows.append(torch.cat(result_row, dim=4))
+
+    return normalize_decoded_frames(torch.cat(result_rows, dim=3))
+
+
+@torch.inference_mode()
+def decode_latents_tiled_spatial(
+    decoder,
+    z,
+    *,
+    num_tiles_w: int,
+    num_tiles_h: int,
+    overlap: int = 0,  # Number of pixel of overlap between adjacent tiles.
+    # Use a factor of 2 times the latent downsample factor.
+    min_block_size: int = 1,  # Minimum number of pixels in each dimension when subdividing.
+):
+    decoded = apply_tiled(decoder, z, num_tiles_w, num_tiles_h, overlap, min_block_size)
+    assert decoded is not None, f"Failed to decode latents with tiled spatial method"
+    return normalize_decoded_frames(decoded)
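Note (not part of the diff): the new Encoder mirrors the existing Decoder and compresses video by 6x in time and 8x in space into a 12-channel latent, as described in its forward() docstring. A rough usage sketch, assuming the repo root is importable and ignoring weight loading (random init):

    # Hedged sketch, not repository code.
    import torch
    from mochi_preview.vae.model import Encoder, add_fourier_features

    encoder = Encoder(
        in_channels=15, base_channels=64, channel_multipliers=[1, 2, 4, 6],
        num_res_blocks=[3, 3, 4, 6, 3], latent_dim=12,
        temporal_reductions=[1, 2, 3], spatial_reductions=[2, 2, 2],
        prune_bottlenecks=[False] * 5, has_attentions=[False, True, True, True, True],
        affine=True, bias=True, input_is_conv_1x1=True, padding_mode="replicate",
    )  # same configuration the MochiVAEEncoderLoader node below uses

    video = torch.rand(1, 3, 7, 64, 64) * 2 - 1      # [B, C, T, H, W] scaled to [-1, 1]
    video = add_fourier_features(video)              # 3 -> 15 channels, matching in_channels=15
    with torch.no_grad():
        dist = encoder(video)                        # LatentDistribution
    latents = dist.sample()                          # [1, 12, 2, 8, 8]: t = (T - 1) // 6 + 1, h = H // 8, w = W // 8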
mochi_preview/vae/vae_stats.py (new file, 62 lines)

@@ -0,0 +1,62 @@
+import torch
+
+# Channel-wise mean and standard deviation of VAE encoder latents
+STATS = {
+    "mean": torch.Tensor([
+        -0.06730895953510081,
+        -0.038011381506090416,
+        -0.07477820912866141,
+        -0.05565264470995561,
+        0.012767231469026969,
+        -0.04703542746246419,
+        0.043896967884726704,
+        -0.09346305707025976,
+        -0.09918314763016893,
+        -0.008729793427399178,
+        -0.011931556316503654,
+        -0.0321993391887285,
+    ]),
+    "std": torch.Tensor([
+        0.9263795028493863,
+        0.9248894543193766,
+        0.9393059390890617,
+        0.959253732819592,
+        0.8244560132752793,
+        0.917259975397747,
+        0.9294154431013696,
+        1.3720942357788521,
+        0.881393668867029,
+        0.9168315692124348,
+        0.9185249279345552,
+        0.9274757570805041,
+    ]),
+}
+
+def dit_latents_to_vae_latents(dit_outputs: torch.Tensor) -> torch.Tensor:
+    """Unnormalize latents output by Mochi's DiT to be compatible with VAE.
+    Run this on sampled latents before calling the VAE decoder.
+
+    Args:
+        latents (torch.Tensor): [B, C_z, T_z, H_z, W_z], float
+
+    Returns:
+        torch.Tensor: [B, C_z, T_z, H_z, W_z], float
+    """
+    mean = STATS["mean"][:, None, None, None]
+    std = STATS["std"][:, None, None, None]
+
+    assert dit_outputs.ndim == 5
+    assert dit_outputs.size(1) == mean.size(0) == std.size(0)
+    return dit_outputs * std.to(dit_outputs) + mean.to(dit_outputs)
+
+
+def vae_latents_to_dit_latents(vae_latents: torch.Tensor):
+    """Normalize latents output by the VAE encoder to be compatible with Mochi's DiT.
+    E.g, for fine-tuning or video-to-video.
+    """
+    mean = STATS["mean"][:, None, None, None]
+    std = STATS["std"][:, None, None, None]
+
+    assert vae_latents.ndim == 5
+    assert vae_latents.size(1) == mean.size(0) == std.size(0)
+    return (vae_latents - mean.to(vae_latents)) / std.to(vae_latents)
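Note (not part of the diff): these two helpers are inverses of each other, applying the per-channel latent statistics above; the sampler now returns DiT-space latents, so the decode nodes call dit_latents_to_vae_latents before decoding, while MochiImageEncode calls vae_latents_to_dit_latents after encoding. A quick round-trip sketch:

    # Hedged sketch, not repository code; the import path assumes the repo root is on sys.path.
    import torch
    from mochi_preview.vae.vae_stats import vae_latents_to_dit_latents, dit_latents_to_vae_latents

    vae_latents = torch.randn(1, 12, 5, 60, 106)
    dit_latents = vae_latents_to_dit_latents(vae_latents)    # (x - mean) / std
    restored = dit_latents_to_vae_latents(dit_latents)       # x * std + mean
    assert torch.allclose(restored, vae_latents, atol=1e-5)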
nodes.py (186 changed lines)

@@ -13,7 +13,8 @@ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(
 log = logging.getLogger(__name__)

 from .mochi_preview.t2v_synth_mochi import T2VSynthMochiModel
-from .mochi_preview.vae.model import Decoder
+from .mochi_preview.vae.model import Decoder, Encoder, add_fourier_features
+from .mochi_preview.vae.vae_stats import vae_latents_to_dit_latents, dit_latents_to_vae_latents

 from contextlib import nullcontext
 try:
@@ -140,7 +141,6 @@ class DownloadAndLoadMochiModel:
             num_res_blocks=[3, 3, 4, 6, 3],
             latent_dim=12,
             has_attention=[False, False, False, False, False],
-            padding_mode="replicate",
             output_norm=False,
             nonlinearity="silu",
             output_nonlinearity="silu",
@@ -269,7 +269,6 @@ class MochiVAELoader:
             num_res_blocks=[3, 3, 4, 6, 3],
             latent_dim=12,
             has_attention=[False, False, False, False, False],
-            padding_mode="replicate",
             output_norm=False,
             nonlinearity="silu",
             output_nonlinearity="silu",
@@ -292,6 +291,74 @@ class MochiVAELoader:

         return (vae,)

+class MochiVAEEncoderLoader:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "model_name": (folder_paths.get_filename_list("vae"), {"tooltip": "The name of the checkpoint (vae) to load."}),
+            },
+            "optional": {
+                "torch_compile_args": ("MOCHICOMPILEARGS", {"tooltip": "Optional torch.compile arguments",}),
+                "precision": (["fp16", "fp32", "bf16"], {"default": "bf16"}),
+            },
+        }
+
+    RETURN_TYPES = ("MOCHIVAE",)
+    RETURN_NAMES = ("mochi_vae", )
+    FUNCTION = "loadmodel"
+    CATEGORY = "MochiWrapper"
+
+    def loadmodel(self, model_name, torch_compile_args=None, precision="bf16"):
+
+        device = mm.get_torch_device()
+        offload_device = mm.unet_offload_device()
+        mm.soft_empty_cache()
+
+        dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
+
+        config = dict(
+            prune_bottlenecks=[False, False, False, False, False],
+            has_attentions=[False, True, True, True, True],
+            affine=True,
+            bias=True,
+            input_is_conv_1x1=True,
+            padding_mode="replicate"
+        )
+
+        vae_path = folder_paths.get_full_path_or_raise("vae", model_name)
+
+        # Create VAE encoder
+        with (init_empty_weights() if is_accelerate_available else nullcontext()):
+            encoder = Encoder(
+                in_channels=15,
+                base_channels=64,
+                channel_multipliers=[1, 2, 4, 6],
+                num_res_blocks=[3, 3, 4, 6, 3],
+                latent_dim=12,
+                temporal_reductions=[1, 2, 3],
+                spatial_reductions=[2, 2, 2],
+                dtype = dtype,
+                **config,
+            )
+
+        encoder_sd = load_torch_file(vae_path)
+        if is_accelerate_available:
+            for name, param in encoder.named_parameters():
+                set_module_tensor_to_device(encoder, name, dtype=dtype, device=offload_device, value=encoder_sd[name])
+        else:
+            encoder.load_state_dict(encoder_sd, strict=True)
+            encoder.to(dtype).to(offload_device)
+        encoder.eval()
+        del encoder_sd
+
+        if torch_compile_args is not None:
+            encoder.to(device)
+            encoder = torch.compile(encoder, fullgraph=torch_compile_args["fullgraph"], mode=torch_compile_args["mode"], dynamic=False, backend=torch_compile_args["backend"])
+
+        return (encoder,)
+
 class MochiTextEncode:
     @classmethod
     def INPUT_TYPES(s):
@@ -365,6 +432,7 @@ class MochiSampler:
             "optional": {
                 "cfg_schedule": ("FLOAT", {"forceInput": True, "tooltip": "Override cfg schedule with a list of ints"}),
                 "opt_sigmas": ("SIGMAS", {"tooltip": "Override sigma schedule and steps"}),
+                "samples": ("LATENT", ),
             }
         }

@@ -373,7 +441,7 @@ class MochiSampler:
     FUNCTION = "process"
     CATEGORY = "MochiWrapper"

-    def process(self, model, positive, negative, steps, cfg, seed, height, width, num_frames, cfg_schedule=None, opt_sigmas=None):
+    def process(self, model, positive, negative, steps, cfg, seed, height, width, num_frames, cfg_schedule=None, opt_sigmas=None, samples=None):
         mm.soft_empty_cache()

         if opt_sigmas is not None:
@@ -413,6 +481,7 @@ class MochiSampler:
             "positive_embeds": positive,
             "negative_embeds": negative,
             "seed": seed,
+            "samples": samples["samples"] if samples is not None else None,
         }
         latents = model.run(args)

@@ -447,6 +516,7 @@ class MochiDecode:
         offload_device = mm.unet_offload_device()
         intermediate_device = mm.intermediate_device()
         samples = samples["samples"]
+        samples = dit_latents_to_vae_latents(samples)
         samples = samples.to(vae.dtype).to(device)

         B, C, T, H, W = samples.shape
@@ -574,6 +644,7 @@ class MochiDecodeSpatialTiling:
         offload_device = mm.unet_offload_device()
         intermediate_device = mm.intermediate_device()
         samples = samples["samples"]
+        samples = dit_latents_to_vae_latents(samples)
         samples = samples.to(vae.dtype).to(device)

         B, C, T, H, W = samples.shape
@@ -616,6 +687,101 @@ class MochiDecodeSpatialTiling:

         return (frames,)

+class MochiImageEncode:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {
+            "encoder": ("MOCHIVAE",),
+            "images": ("IMAGE", ),
+            },
+        }
+
+    RETURN_TYPES = ("LATENT",)
+    RETURN_NAMES = ("samples",)
+    FUNCTION = "decode"
+    CATEGORY = "MochiWrapper"
+
+    def decode(self, encoder, images):
+        device = mm.get_torch_device()
+        offload_device = mm.unet_offload_device()
+        intermediate_device = mm.intermediate_device()
+
+        B, H, W, C = images.shape
+
+        images = images.unsqueeze(0) * 2 - 1
+        images = rearrange(images, "t b h w c -> t c b h w")
+        images = images.to(encoder.dtype).to(device)
+        print(images.shape)
+
+        encoder.to(device)
+        print("images before encoding", images.shape)
+        with torch.autocast(mm.get_autocast_device(device), dtype=encoder.dtype):
+            video = add_fourier_features(images)
+            latents = encoder(video).sample()
+            latents = vae_latents_to_dit_latents(latents)
+        print("encoder output",latents.shape)
+
+        return ({"samples": latents},)
+
+class MochiLatentPreview:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "samples": ("LATENT",),
+                # "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
+                # "min_val": ("FLOAT", {"default": -0.15, "min": -1.0, "max": 0.0, "step": 0.001}),
+                # "max_val": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001}),
+            },
+        }
+
+    RETURN_TYPES = ("IMAGE", )
+    RETURN_NAMES = ("images", )
+    FUNCTION = "sample"
+    CATEGORY = "PyramidFlowWrapper"
+
+    def sample(self, samples):#, seed, min_val, max_val):
+        mm.soft_empty_cache()
+
+        latents = samples["samples"].clone()
+        print("in sample", latents.shape)
+
+        device = mm.get_torch_device()
+        offload_device = mm.unet_offload_device()
+
+        latent_rgb_factors = [[0.1236769792512748, 0.11775175335219157, -0.17700629766423637], [-0.08504104329270078, 0.026605813147523694, -0.006843165704926019], [-0.17093308616366876, 0.027991854696200386, 0.14179146288816308], [-0.17179555328757623, 0.09844317368603078, 0.14470997015982784], [-0.16975067171668484, -0.10739852629856643, -0.1894254942909962], [-0.19315259266769888, -0.011029760569485209, -0.08519702054654255], [-0.08399895091432583, -0.0964246452052032, -0.033622359523655665], [0.08148916330842498, 0.027500645903400067, -0.06593099749891196], [0.0456603103902293, -0.17844808072462398, 0.04204775167149785], [0.001751626383204502, -0.030567890189647867, -0.022078082809772193], [0.05110631095056278, -0.0709677393548804, 0.08963683539504264], [0.010515800868829, -0.18382052841762514, -0.08554553339721907]]
+
+        # import random
+        # random.seed(seed)
+        # latent_rgb_factors = [[random.uniform(min_val, max_val) for _ in range(3)] for _ in range(12)]
+        # out_factors = latent_rgb_factors
+        # print(latent_rgb_factors)
+
+        latent_rgb_factors_bias = [0,0,0]
+
+        latent_rgb_factors = torch.tensor(latent_rgb_factors, device=latents.device, dtype=latents.dtype).transpose(0, 1)
+        latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device=latents.device, dtype=latents.dtype)
+
+        print("latent_rgb_factors", latent_rgb_factors.shape)
+
+        latent_images = []
+        for t in range(latents.shape[2]):
+            latent = latents[:, :, t, :, :]
+            latent = latent[0].permute(1, 2, 0)
+            latent_image = torch.nn.functional.linear(
+                latent,
+                latent_rgb_factors,
+                bias=latent_rgb_factors_bias
+            )
+            latent_images.append(latent_image)
+        latent_images = torch.stack(latent_images, dim=0)
+        print("latent_images", latent_images.shape)
+        latent_images_min = latent_images.min()
+        latent_images_max = latent_images.max()
+        latent_images = (latent_images - latent_images_min) / (latent_images_max - latent_images_min)
+
+        return (latent_images.float().cpu(),)
+
 NODE_CLASS_MAPPINGS = {
     "DownloadAndLoadMochiModel": DownloadAndLoadMochiModel,
@@ -624,8 +790,11 @@ NODE_CLASS_MAPPINGS = {
     "MochiTextEncode": MochiTextEncode,
     "MochiModelLoader": MochiModelLoader,
     "MochiVAELoader": MochiVAELoader,
+    "MochiVAEEncoderLoader": MochiVAEEncoderLoader,
     "MochiDecodeSpatialTiling": MochiDecodeSpatialTiling,
-    "MochiTorchCompileSettings": MochiTorchCompileSettings
+    "MochiTorchCompileSettings": MochiTorchCompileSettings,
+    "MochiImageEncode": MochiImageEncode,
+    "MochiLatentPreview": MochiLatentPreview
 }
 NODE_DISPLAY_NAME_MAPPINGS = {
     "DownloadAndLoadMochiModel": "(Down)load Mochi Model",
@@ -633,7 +802,10 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "MochiDecode": "Mochi Decode",
     "MochiTextEncode": "Mochi TextEncode",
     "MochiModelLoader": "Mochi Model Loader",
-    "MochiVAELoader": "Mochi VAE Loader",
+    "MochiVAELoader": "Mochi VAE Decoder Loader",
+    "MochiVAEEncoderLoader": "Mochi VAE Encoder Loader",
     "MochiDecodeSpatialTiling": "Mochi VAE Decode Spatial Tiling",
-    "MochiTorchCompileSettings": "Mochi Torch Compile Settings"
+    "MochiTorchCompileSettings": "Mochi Torch Compile Settings",
+    "MochiImageEncode": "Mochi Image Encode",
+    "MochiLatentPreview": "Mochi Latent Preview"
 }
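Note (not part of the diff): taken together, the new nodes add an encode path to the existing graph. MochiVAEEncoderLoader produces a MOCHIVAE encoder, MochiImageEncode turns input images into LATENT "samples" (scale to [-1, 1], add Fourier features, encode, then normalize with vae_latents_to_dit_latents), those samples can feed MochiSampler's new optional "samples" input, and MochiLatentPreview gives a cheap RGB projection of any latent for inspection; the decode nodes now expect DiT-space latents and un-normalize them with dit_latents_to_vae_latents themselves. The gist of the preview projection, as a standalone sketch:

    # Hedged sketch, not repository code: a per-pixel linear 12 -> RGB projection like MochiLatentPreview's.
    import torch

    latents = torch.randn(1, 12, 2, 8, 8)                             # [B, 12, t, h, w]
    factors = torch.randn(12, 3)                                       # the node uses a fixed 12 -> RGB matrix
    frames = torch.einsum("bcthw,cr->bthwr", latents, factors)         # project latent channels to RGB
    frames = (frames - frames.min()) / (frames.max() - frames.min())   # min-max normalize to [0, 1]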