# Adapted from Open-Sora-Plan

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan
# --------------------------------------------------------
import glob
import os
from typing import Optional, Tuple, Union

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils import logging
from einops import rearrange

logging.set_verbosity_error()

def Normalize(in_channels, num_groups=32):
    return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)


def tensor_to_video(x):
    x = x.detach().cpu()
    x = torch.clamp(x, -1, 1)
    x = (x + 1) / 2
    x = x.permute(1, 0, 2, 3).float().numpy()  # c t h w -> t c h w
    x = (255 * x).astype(np.uint8)
    return x


def nonlinearity(x):
    return x * torch.sigmoid(x)

class DiagonalGaussianDistribution(object):
    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)

    def sample(self):
        x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
        return x

    def kl(self, other=None):
        if self.deterministic:
            return torch.Tensor([0.0])
        else:
            if other is None:
                return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
            else:
                return 0.5 * torch.sum(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var
                    - 1.0
                    - self.logvar
                    + other.logvar,
                    dim=[1, 2, 3],
                )

    def nll(self, sample, dims=[1, 2, 3]):
        if self.deterministic:
            return torch.Tensor([0.0])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)

    def mode(self):
        return self.mean


def resolve_str_to_obj(str_val, append=True):
    return globals()[str_val]

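# Illustrative sketch (not part of the original file): the encoder emits "moments"
# with 2 * embed_dim channels; DiagonalGaussianDistribution splits them into mean
# and log-variance along dim=1 and samples via the reparameterization trick.
# resolve_str_to_obj simply maps a config string such as "CausalConv3d" to the
# class of the same name defined in this module.
def _example_diagonal_gaussian_sketch():
    moments = torch.randn(1, 8, 3, 16, 16)  # (b, 2 * embed_dim, t, h, w)
    posterior = DiagonalGaussianDistribution(moments)
    z = posterior.sample()  # stochastic latent, (1, 4, 3, 16, 16)
    z_mode = posterior.mode()  # deterministic latent (the mean)
    kl = posterior.kl()  # KL term against a standard normal
    return z, z_mode, kl
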
class VideoBaseAE_PL(ModelMixin, ConfigMixin):
    config_name = "config.json"

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def encode(self, x: torch.Tensor, *args, **kwargs):
        pass

    def decode(self, encoding: torch.Tensor, *args, **kwargs):
        pass

    @property
    def num_training_steps(self) -> int:
        """Total training steps inferred from datamodule and devices."""
        if self.trainer.max_steps:
            return self.trainer.max_steps

        limit_batches = self.trainer.limit_train_batches
        batches = len(self.train_dataloader())
        batches = min(batches, limit_batches) if isinstance(limit_batches, int) else int(limit_batches * batches)

        num_devices = max(1, self.trainer.num_gpus, self.trainer.num_processes)
        if self.trainer.tpu_cores:
            num_devices = max(num_devices, self.trainer.tpu_cores)

        effective_accum = self.trainer.accumulate_grad_batches * num_devices
        return (batches // effective_accum) * self.trainer.max_epochs

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
        ckpt_files = glob.glob(os.path.join(pretrained_model_name_or_path, "*.ckpt"))
        if ckpt_files:
            # Adapt to PyTorch Lightning
            last_ckpt_file = ckpt_files[-1]
            config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
            model = cls.from_config(config_file)
            print("init from {}".format(last_ckpt_file))
            model.init_from_ckpt(last_ckpt_file)
            return model
        else:
            return super().from_pretrained(pretrained_model_name_or_path, **kwargs)

class Encoder(nn.Module):
    def __init__(
        self,
        z_channels: int,
        hidden_size: int,
        hidden_size_mult: Tuple[int] = (1, 2, 4, 4),
        attn_resolutions: Tuple[int] = (16,),
        conv_in: str = "Conv2d",
        conv_out: str = "CausalConv3d",
        attention: str = "AttnBlock",
        resnet_blocks: Tuple[str] = (
            "ResnetBlock2D",
            "ResnetBlock2D",
            "ResnetBlock2D",
            "ResnetBlock3D",
        ),
        spatial_downsample: Tuple[str] = (
            "Downsample",
            "Downsample",
            "Downsample",
            "",
        ),
        temporal_downsample: Tuple[str] = ("", "", "TimeDownsampleRes2x", ""),
        mid_resnet: str = "ResnetBlock3D",
        dropout: float = 0.0,
        resolution: int = 256,
        num_res_blocks: int = 2,
        double_z: bool = True,
    ) -> None:
        super().__init__()
        assert len(resnet_blocks) == len(hidden_size_mult), print(hidden_size_mult, resnet_blocks)
        # ---- Config ----
        self.num_resolutions = len(hidden_size_mult)
        self.resolution = resolution
        self.num_res_blocks = num_res_blocks

        # ---- In ----
        self.conv_in = resolve_str_to_obj(conv_in)(3, hidden_size, kernel_size=3, stride=1, padding=1)

        # ---- Downsample ----
        curr_res = resolution
        in_ch_mult = (1,) + tuple(hidden_size_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = hidden_size * in_ch_mult[i_level]
            block_out = hidden_size * hidden_size_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    resolve_str_to_obj(resnet_blocks[i_level])(
                        in_channels=block_in,
                        out_channels=block_out,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(resolve_str_to_obj(attention)(block_in))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if spatial_downsample[i_level]:
                down.downsample = resolve_str_to_obj(spatial_downsample[i_level])(block_in, block_in)
                curr_res = curr_res // 2
            if temporal_downsample[i_level]:
                down.time_downsample = resolve_str_to_obj(temporal_downsample[i_level])(block_in, block_in)
            self.down.append(down)

        # ---- Mid ----
        self.mid = nn.Module()
        self.mid.block_1 = resolve_str_to_obj(mid_resnet)(
            in_channels=block_in,
            out_channels=block_in,
            dropout=dropout,
        )
        self.mid.attn_1 = resolve_str_to_obj(attention)(block_in)
        self.mid.block_2 = resolve_str_to_obj(mid_resnet)(
            in_channels=block_in,
            out_channels=block_in,
            dropout=dropout,
        )
        # ---- Out ----
        self.norm_out = Normalize(block_in)
        self.conv_out = resolve_str_to_obj(conv_out)(
            block_in,
            2 * z_channels if double_z else z_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, x):
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1])
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if hasattr(self.down[i_level], "downsample"):
                hs.append(self.down[i_level].downsample(hs[-1]))
            if hasattr(self.down[i_level], "time_downsample"):
                hs_down = self.down[i_level].time_downsample(hs[-1])
                hs.append(hs_down)

        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h

class Decoder(nn.Module):
    def __init__(
        self,
        z_channels: int,
        hidden_size: int,
        hidden_size_mult: Tuple[int] = (1, 2, 4, 4),
        attn_resolutions: Tuple[int] = (16,),
        conv_in: str = "Conv2d",
        conv_out: str = "CausalConv3d",
        attention: str = "AttnBlock",
        resnet_blocks: Tuple[str] = (
            "ResnetBlock3D",
            "ResnetBlock3D",
            "ResnetBlock3D",
            "ResnetBlock3D",
        ),
        spatial_upsample: Tuple[str] = (
            "",
            "SpatialUpsample2x",
            "SpatialUpsample2x",
            "SpatialUpsample2x",
        ),
        temporal_upsample: Tuple[str] = ("", "", "", "TimeUpsampleRes2x"),
        mid_resnet: str = "ResnetBlock3D",
        dropout: float = 0.0,
        resolution: int = 256,
        num_res_blocks: int = 2,
    ):
        super().__init__()
        # ---- Config ----
        self.num_resolutions = len(hidden_size_mult)
        self.resolution = resolution
        self.num_res_blocks = num_res_blocks

        # ---- In ----
        block_in = hidden_size * hidden_size_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.conv_in = resolve_str_to_obj(conv_in)(z_channels, block_in, kernel_size=3, padding=1)

        # ---- Mid ----
        self.mid = nn.Module()
        self.mid.block_1 = resolve_str_to_obj(mid_resnet)(
            in_channels=block_in,
            out_channels=block_in,
            dropout=dropout,
        )
        self.mid.attn_1 = resolve_str_to_obj(attention)(block_in)
        self.mid.block_2 = resolve_str_to_obj(mid_resnet)(
            in_channels=block_in,
            out_channels=block_in,
            dropout=dropout,
        )

        # ---- Upsample ----
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = hidden_size * hidden_size_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(
                    resolve_str_to_obj(resnet_blocks[i_level])(
                        in_channels=block_in,
                        out_channels=block_out,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(resolve_str_to_obj(attention)(block_in))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if spatial_upsample[i_level]:
                up.upsample = resolve_str_to_obj(spatial_upsample[i_level])(block_in, block_in)
                curr_res = curr_res * 2
            if temporal_upsample[i_level]:
                up.time_upsample = resolve_str_to_obj(temporal_upsample[i_level])(block_in, block_in)
            self.up.insert(0, up)

        # ---- Out ----
        self.norm_out = Normalize(block_in)
        self.conv_out = resolve_str_to_obj(conv_out)(block_in, 3, kernel_size=3, padding=1)

    def forward(self, z):
        h = self.conv_in(z)
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if hasattr(self.up[i_level], "upsample"):
                h = self.up[i_level].upsample(h)
            if hasattr(self.up[i_level], "time_upsample"):
                h = self.up[i_level].time_upsample(h)

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h

class CausalVAEModel(VideoBaseAE_PL):
    @register_to_config
    def __init__(
        self,
        lr: float = 1e-5,
        hidden_size: int = 128,
        z_channels: int = 4,
        hidden_size_mult: Tuple[int] = (1, 2, 4, 4),
        attn_resolutions: Tuple[int] = [],
        dropout: float = 0.0,
        resolution: int = 256,
        double_z: bool = True,
        embed_dim: int = 4,
        num_res_blocks: int = 2,
        loss_type: str = "opensora.models.ae.videobase.losses.LPIPSWithDiscriminator",
        loss_params: dict = {
            "kl_weight": 0.000001,
            "logvar_init": 0.0,
            "disc_start": 2001,
            "disc_weight": 0.5,
        },
        q_conv: str = "CausalConv3d",
        encoder_conv_in: str = "CausalConv3d",
        encoder_conv_out: str = "CausalConv3d",
        encoder_attention: str = "AttnBlock3D",
        encoder_resnet_blocks: Tuple[str] = (
            "ResnetBlock3D",
            "ResnetBlock3D",
            "ResnetBlock3D",
            "ResnetBlock3D",
        ),
        encoder_spatial_downsample: Tuple[str] = (
            "SpatialDownsample2x",
            "SpatialDownsample2x",
            "SpatialDownsample2x",
            "",
        ),
        encoder_temporal_downsample: Tuple[str] = (
            "",
            "TimeDownsample2x",
            "TimeDownsample2x",
            "",
        ),
        encoder_mid_resnet: str = "ResnetBlock3D",
        decoder_conv_in: str = "CausalConv3d",
        decoder_conv_out: str = "CausalConv3d",
        decoder_attention: str = "AttnBlock3D",
        decoder_resnet_blocks: Tuple[str] = (
            "ResnetBlock3D",
            "ResnetBlock3D",
            "ResnetBlock3D",
            "ResnetBlock3D",
        ),
        decoder_spatial_upsample: Tuple[str] = (
            "",
            "SpatialUpsample2x",
            "SpatialUpsample2x",
            "SpatialUpsample2x",
        ),
        decoder_temporal_upsample: Tuple[str] = ("", "", "TimeUpsample2x", "TimeUpsample2x"),
        decoder_mid_resnet: str = "ResnetBlock3D",
    ) -> None:
        super().__init__()
        self.tile_sample_min_size = 256
        self.tile_sample_min_size_t = 65
        self.tile_latent_min_size = int(self.tile_sample_min_size / (2 ** (len(hidden_size_mult) - 1)))
        t_down_ratio = [i for i in encoder_temporal_downsample if len(i) > 0]
        self.tile_latent_min_size_t = int((self.tile_sample_min_size_t - 1) / (2 ** len(t_down_ratio))) + 1
        self.tile_overlap_factor = 0.25
        self.use_tiling = False

        self.learning_rate = lr
        self.lr_g_factor = 1.0

        self.encoder = Encoder(
            z_channels=z_channels,
            hidden_size=hidden_size,
            hidden_size_mult=hidden_size_mult,
            attn_resolutions=attn_resolutions,
            conv_in=encoder_conv_in,
            conv_out=encoder_conv_out,
            attention=encoder_attention,
            resnet_blocks=encoder_resnet_blocks,
            spatial_downsample=encoder_spatial_downsample,
            temporal_downsample=encoder_temporal_downsample,
            mid_resnet=encoder_mid_resnet,
            dropout=dropout,
            resolution=resolution,
            num_res_blocks=num_res_blocks,
            double_z=double_z,
        )

        self.decoder = Decoder(
            z_channels=z_channels,
            hidden_size=hidden_size,
            hidden_size_mult=hidden_size_mult,
            attn_resolutions=attn_resolutions,
            conv_in=decoder_conv_in,
            conv_out=decoder_conv_out,
            attention=decoder_attention,
            resnet_blocks=decoder_resnet_blocks,
            spatial_upsample=decoder_spatial_upsample,
            temporal_upsample=decoder_temporal_upsample,
            mid_resnet=decoder_mid_resnet,
            dropout=dropout,
            resolution=resolution,
            num_res_blocks=num_res_blocks,
        )

        quant_conv_cls = resolve_str_to_obj(q_conv)
        self.quant_conv = quant_conv_cls(2 * z_channels, 2 * embed_dim, 1)
        self.post_quant_conv = quant_conv_cls(embed_dim, z_channels, 1)

    def encode(self, x):
        if self.use_tiling and (
            x.shape[-1] > self.tile_sample_min_size
            or x.shape[-2] > self.tile_sample_min_size
            or x.shape[-3] > self.tile_sample_min_size_t
        ):
            return self.tiled_encode(x)
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        if self.use_tiling and (
            z.shape[-1] > self.tile_latent_min_size
            or z.shape[-2] > self.tile_latent_min_size
            or z.shape[-3] > self.tile_latent_min_size_t
        ):
            return self.tiled_decode(z)
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(self, input, sample_posterior=True):
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior

    def get_input(self, batch, k):
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.to(memory_format=torch.contiguous_format).float()
        return x

    def training_step(self, batch, batch_idx):
        if hasattr(self.loss, "discriminator"):
            return self._training_step_gan(batch, batch_idx=batch_idx)
        else:
            return self._training_step(batch, batch_idx=batch_idx)

    def _training_step(self, batch, batch_idx):
        inputs = self.get_input(batch, "video")
        reconstructions, posterior = self(inputs)
        aeloss, log_dict_ae = self.loss(
            inputs,
            reconstructions,
            posterior,
            split="train",
        )
        self.log(
            "aeloss",
            aeloss,
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=True,
        )
        self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
        return aeloss

    def _training_step_gan(self, batch, batch_idx):
        inputs = self.get_input(batch, "video")
        reconstructions, posterior = self(inputs)
        opt1, opt2 = self.optimizers()

        # ---- AE Loss ----
        aeloss, log_dict_ae = self.loss(
            inputs,
            reconstructions,
            posterior,
            0,
            self.global_step,
            last_layer=self.get_last_layer(),
            split="train",
        )
        self.log(
            "aeloss",
            aeloss,
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=True,
        )
        opt1.zero_grad()
        self.manual_backward(aeloss)
        self.clip_gradients(opt1, gradient_clip_val=1, gradient_clip_algorithm="norm")
        opt1.step()
        # ---- GAN Loss ----
        discloss, log_dict_disc = self.loss(
            inputs,
            reconstructions,
            posterior,
            1,
            self.global_step,
            last_layer=self.get_last_layer(),
            split="train",
        )
        self.log(
            "discloss",
            discloss,
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=True,
        )
        opt2.zero_grad()
        self.manual_backward(discloss)
        self.clip_gradients(opt2, gradient_clip_val=1, gradient_clip_algorithm="norm")
        opt2.step()
        self.log_dict(
            {**log_dict_ae, **log_dict_disc},
            prog_bar=False,
            logger=True,
            on_step=True,
            on_epoch=False,
        )

    def configure_optimizers(self):
        from itertools import chain

        lr = self.learning_rate
        modules_to_train = [
            self.encoder.named_parameters(),
            self.decoder.named_parameters(),
            self.post_quant_conv.named_parameters(),
            self.quant_conv.named_parameters(),
        ]
        params_with_time = []
        params_without_time = []
        for name, param in chain(*modules_to_train):
            if "time" in name:
                params_with_time.append(param)
            else:
                params_without_time.append(param)
        optimizers = []
        opt_ae = torch.optim.Adam(
            [
                {"params": params_with_time, "lr": lr},
                {"params": params_without_time, "lr": lr},
            ],
            lr=lr,
            betas=(0.5, 0.9),
        )
        optimizers.append(opt_ae)

        if hasattr(self.loss, "discriminator"):
            opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9))
            optimizers.append(opt_disc)

        return optimizers, []

    def get_last_layer(self):
        if hasattr(self.decoder.conv_out, "conv"):
            return self.decoder.conv_out.conv.weight
        else:
            return self.decoder.conv_out.weight

    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        blend_extent = min(a.shape[3], b.shape[3], blend_extent)
        for y in range(blend_extent):
            b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
                y / blend_extent
            )
        return b

    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        blend_extent = min(a.shape[4], b.shape[4], blend_extent)
        for x in range(blend_extent):
            b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
                x / blend_extent
            )
        return b

    def tiled_encode(self, x):
        t = x.shape[2]
        t_chunk_idx = [i for i in range(0, t, self.tile_sample_min_size_t - 1)]
        if len(t_chunk_idx) == 1 and t_chunk_idx[0] == 0:
            t_chunk_start_end = [[0, t]]
        else:
            t_chunk_start_end = [[t_chunk_idx[i], t_chunk_idx[i + 1] + 1] for i in range(len(t_chunk_idx) - 1)]
            if t_chunk_start_end[-1][-1] > t:
                t_chunk_start_end[-1][-1] = t
            elif t_chunk_start_end[-1][-1] < t:
                last_start_end = [t_chunk_idx[-1], t]
                t_chunk_start_end.append(last_start_end)
        moments = []
        for idx, (start, end) in enumerate(t_chunk_start_end):
            chunk_x = x[:, :, start:end]
            if idx != 0:
                moment = self.tiled_encode2d(chunk_x, return_moments=True)[:, :, 1:]
            else:
                moment = self.tiled_encode2d(chunk_x, return_moments=True)
            moments.append(moment)
        moments = torch.cat(moments, dim=2)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def tiled_decode(self, x):
        t = x.shape[2]
        t_chunk_idx = [i for i in range(0, t, self.tile_latent_min_size_t - 1)]
        if len(t_chunk_idx) == 1 and t_chunk_idx[0] == 0:
            t_chunk_start_end = [[0, t]]
        else:
            t_chunk_start_end = [[t_chunk_idx[i], t_chunk_idx[i + 1] + 1] for i in range(len(t_chunk_idx) - 1)]
            if t_chunk_start_end[-1][-1] > t:
                t_chunk_start_end[-1][-1] = t
            elif t_chunk_start_end[-1][-1] < t:
                last_start_end = [t_chunk_idx[-1], t]
                t_chunk_start_end.append(last_start_end)
        dec_ = []
        for idx, (start, end) in enumerate(t_chunk_start_end):
            chunk_x = x[:, :, start:end]
            if idx != 0:
                dec = self.tiled_decode2d(chunk_x)[:, :, 1:]
            else:
                dec = self.tiled_decode2d(chunk_x)
            dec_.append(dec)
        dec_ = torch.cat(dec_, dim=2)
        return dec_

    def tiled_encode2d(self, x, return_moments=False):
        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
        row_limit = self.tile_latent_min_size - blend_extent

        # Split the frames into overlapping spatial tiles of size
        # tile_sample_min_size and encode them separately.
        rows = []
        for i in range(0, x.shape[3], overlap_size):
            row = []
            for j in range(0, x.shape[4], overlap_size):
                tile = x[
                    :,
                    :,
                    :,
                    i : i + self.tile_sample_min_size,
                    j : j + self.tile_sample_min_size,
                ]
                tile = self.encoder(tile)
                tile = self.quant_conv(tile)
                row.append(tile)
            rows.append(row)
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=4))

        moments = torch.cat(result_rows, dim=3)
        posterior = DiagonalGaussianDistribution(moments)
        if return_moments:
            return moments
        return posterior

    def tiled_decode2d(self, z):
        overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
        row_limit = self.tile_sample_min_size - blend_extent

        # Split z into overlapping latent tiles of size tile_latent_min_size and
        # decode them separately. The tiles overlap to avoid seams between tiles.
        rows = []
        for i in range(0, z.shape[3], overlap_size):
            row = []
            for j in range(0, z.shape[4], overlap_size):
                tile = z[
                    :,
                    :,
                    :,
                    i : i + self.tile_latent_min_size,
                    j : j + self.tile_latent_min_size,
                ]
                tile = self.post_quant_conv(tile)
                decoded = self.decoder(tile)
                row.append(decoded)
            rows.append(row)
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=4))

        dec = torch.cat(result_rows, dim=3)
        return dec

    def enable_tiling(self, use_tiling: bool = True):
        self.use_tiling = use_tiling

    def disable_tiling(self):
        self.enable_tiling(False)

    def init_from_ckpt(self, path, ignore_keys=list(), remove_loss=False):
        sd = torch.load(path, map_location="cpu")
        print("init from " + path)
        if "state_dict" in sd:
            sd = sd["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        self.load_state_dict(sd, strict=False)

    def validation_step(self, batch, batch_idx):
        inputs = self.get_input(batch, "video")
        latents = self.encode(inputs).sample()
        video_recon = self.decode(latents)
        for idx in range(len(video_recon)):
            self.logger.log_video(f"recon {batch_idx} {idx}", [tensor_to_video(video_recon[idx])], fps=[10])

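# Illustrative usage sketch (an assumption, not part of the original file): encode a
# short clip with the causal VAE and reconstruct it. Enabling tiling makes inputs larger
# than tile_sample_min_size (256 px) or tile_sample_min_size_t (65 frames) get processed
# tile by tile and blended with blend_h / blend_v to hide seams.
def _example_causal_vae_roundtrip():
    vae = CausalVAEModel()  # randomly initialised here; load pretrained weights in practice
    vae.enable_tiling()  # optional; only kicks in for large inputs
    video = torch.randn(1, 3, 9, 256, 256)  # (b, c, t, h, w), t = 1 image frame + 8 video frames
    with torch.no_grad():
        recon, posterior = vae(video, sample_posterior=False)
    return recon.shape, posterior.mode().shape
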
class CausalVAEModelWrapper(nn.Module):
    def __init__(self, model_path, subfolder=None, cache_dir=None, **kwargs):
        super(CausalVAEModelWrapper, self).__init__()
        # if os.path.exists(ckpt):
        #     self.vae = CausalVAEModel.load_from_checkpoint(ckpt)
        self.vae = CausalVAEModel.from_pretrained(model_path, subfolder=subfolder, cache_dir=cache_dir, **kwargs)

    def encode(self, x):  # b c t h w
        # x = self.vae.encode(x).sample()
        x = self.vae.encode(x).sample().mul_(0.18215)
        return x

    def decode(self, x):
        # x = self.vae.decode(x)
        x = self.vae.decode(x / 0.18215)
        x = rearrange(x, "b c t h w -> b t c h w").contiguous()
        return x

    def dtype(self):
        return self.vae.dtype

    # def device(self):
    #     return self.vae.device

videobase_ae_stride = {
    "CausalVAEModel_4x8x8": [4, 8, 8],
}

videobase_ae_channel = {
    "CausalVAEModel_4x8x8": 4,
}

videobase_ae = {
    "CausalVAEModel_4x8x8": CausalVAEModelWrapper,
}


ae_stride_config = {}
ae_stride_config.update(videobase_ae_stride)

ae_channel_config = {}
ae_channel_config.update(videobase_ae_channel)


def getae_wrapper(ae):
    """deprecation"""
    ae = videobase_ae.get(ae, None)
    assert ae is not None
    return ae

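# Illustrative sketch (assumption, not original code): the registry maps a model name to
# its wrapper class plus the latent stride (t, h, w) and channel count that downstream
# code uses to size latent tensors. The checkpoint path below is hypothetical.
def _example_ae_registry_sketch():
    name = "CausalVAEModel_4x8x8"
    wrapper_cls = getae_wrapper(name)  # -> CausalVAEModelWrapper
    t_stride, h_stride, w_stride = ae_stride_config[name]  # [4, 8, 8]
    latent_channels = ae_channel_config[name]  # 4
    # vae = wrapper_cls("/path/to/CausalVAEModel_4x8x8")  # hypothetical local path
    return wrapper_cls, (t_stride, h_stride, w_stride), latent_channels
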
def video_to_image(func):
    def wrapper(self, x, *args, **kwargs):
        if x.dim() == 5:
            t = x.shape[2]
            x = rearrange(x, "b c t h w -> (b t) c h w")
            x = func(self, x, *args, **kwargs)
            x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
        return x

    return wrapper

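# Illustrative sketch (not part of the original file): video_to_image lets 2D layers run
# on 5D video tensors by folding the time axis into the batch axis, applying the wrapped
# forward frame by frame, and unfolding afterwards.
class _Example2DOnVideo(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)

    @video_to_image
    def forward(self, x):  # accepts (b, c, t, h, w) thanks to the decorator
        return self.conv(x)
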
class Block(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)


class LinearAttention(Block):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
        k = k.softmax(dim=-1)
        context = torch.einsum("bhdn,bhen->bhde", k, v)
        out = torch.einsum("bhde,bhdn->bhen", context, q)
        out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
        return self.to_out(out)


class LinAttnBlock(LinearAttention):
    """to match AttnBlock usage"""

    def __init__(self, in_channels):
        super().__init__(dim=in_channels, heads=1, dim_head=in_channels)


class AttnBlock3D(Block):
    """Compatible with old versions, there are issues, use with caution."""

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
        self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
        self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
        self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, t, h, w = q.shape
        q = q.reshape(b * t, c, h * w)
        q = q.permute(0, 2, 1)  # b,hw,c
        k = k.reshape(b * t, c, h * w)  # b,c,hw
        w_ = torch.bmm(q, k)  # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b * t, c, h * w)
        w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v, w_)  # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b, c, t, h, w)

        h_ = self.proj_out(h_)

        return x + h_


class AttnBlock3DFix(nn.Module):
    """
    Thanks to https://github.com/PKU-YuanGroup/Open-Sora-Plan/pull/172.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
        self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
        self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
        self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        # q: (b c t h w) -> (b t c h w) -> (b*t c h*w) -> (b*t h*w c)
        b, c, t, h, w = q.shape
        q = q.permute(0, 2, 1, 3, 4)
        q = q.reshape(b * t, c, h * w)
        q = q.permute(0, 2, 1)

        # k: (b c t h w) -> (b t c h w) -> (b*t c h*w)
        k = k.permute(0, 2, 1, 3, 4)
        k = k.reshape(b * t, c, h * w)

        # w: (b*t hw hw)
        w_ = torch.bmm(q, k)
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        # v: (b c t h w) -> (b t c h w) -> (bt c hw)
        # w_: (bt hw hw) -> (bt hw hw)
        v = v.permute(0, 2, 1, 3, 4)
        v = v.reshape(b * t, c, h * w)
        w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v, w_)  # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]

        # h_: (b*t c hw) -> (b t c h w) -> (b c t h w)
        h_ = h_.reshape(b, t, c, h, w)
        h_ = h_.permute(0, 2, 1, 3, 4)

        h_ = self.proj_out(h_)

        return x + h_

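# Note (added for clarity, based on the two blocks above): AttnBlock3D reshapes
# (b, c, t, h, w) directly to (b*t, c, h*w) without first moving the time axis next to
# the batch axis, so frames from different samples can get mixed when b > 1.
# AttnBlock3DFix permutes to (b, t, c, h, w) before flattening, which applies the same
# per-frame spatial attention while keeping every (batch, frame) pair intact.
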
class AttnBlock(Block):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    @video_to_image
    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w)
        q = q.permute(0, 2, 1)  # b,hw,c
        k = k.reshape(b, c, h * w)  # b,c,hw
        w_ = torch.bmm(q, k)  # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h * w)
        w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v, w_)  # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)

        return x + h_


class TemporalAttnBlock(Block):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, t, h, w = q.shape
        q = rearrange(q, "b c t h w -> (b h w) t c")
        k = rearrange(k, "b c t h w -> (b h w) c t")
        v = rearrange(v, "b c t h w -> (b h w) c t")
        w_ = torch.bmm(q, k)
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        w_ = w_.permute(0, 2, 1)
        h_ = torch.bmm(v, w_)
        h_ = rearrange(h_, "(b h w) c t -> b c t h w", h=h, w=w)
        h_ = self.proj_out(h_)

        return x + h_


def make_attn(in_channels, attn_type="vanilla"):
    assert attn_type in ["vanilla", "linear", "none", "vanilla3D"], f"attn_type {attn_type} unknown"
    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
    print(attn_type)
    if attn_type == "vanilla":
        return AttnBlock(in_channels)
    elif attn_type == "vanilla3D":
        return AttnBlock3D(in_channels)
    elif attn_type == "none":
        return nn.Identity(in_channels)
    else:
        return LinAttnBlock(in_channels)

class Conv2d(nn.Conv2d):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int]] = 3,
        stride: Union[int, Tuple[int]] = 1,
        padding: Union[str, int, Tuple[int]] = 0,
        dilation: Union[int, Tuple[int]] = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
    ) -> None:
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            padding_mode,
            device,
            dtype,
        )

    @video_to_image
    def forward(self, x):
        return super().forward(x)

class CausalConv3d(nn.Module):
    def __init__(
        self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], init_method="random", **kwargs
    ):
        super().__init__()
        self.kernel_size = cast_tuple(kernel_size, 3)
        self.time_kernel_size = self.kernel_size[0]
        self.chan_in = chan_in
        self.chan_out = chan_out
        stride = kwargs.pop("stride", 1)
        padding = kwargs.pop("padding", 0)
        padding = list(cast_tuple(padding, 3))
        padding[0] = 0
        stride = cast_tuple(stride, 3)
        self.conv = nn.Conv3d(chan_in, chan_out, self.kernel_size, stride=stride, padding=padding)
        self._init_weights(init_method)

    def _init_weights(self, init_method):
        torch.tensor(self.kernel_size)
        if init_method == "avg":
            assert self.kernel_size[1] == 1 and self.kernel_size[2] == 1, "only support temporal up/down sample"
            assert self.chan_in == self.chan_out, "chan_in must be equal to chan_out"
            weight = torch.zeros((self.chan_out, self.chan_in, *self.kernel_size))

            eyes = torch.concat(
                [
                    torch.eye(self.chan_in).unsqueeze(-1) * 1 / 3,
                    torch.eye(self.chan_in).unsqueeze(-1) * 1 / 3,
                    torch.eye(self.chan_in).unsqueeze(-1) * 1 / 3,
                ],
                dim=-1,
            )
            weight[:, :, :, 0, 0] = eyes

            self.conv.weight = nn.Parameter(
                weight,
                requires_grad=True,
            )
        elif init_method == "zero":
            self.conv.weight = nn.Parameter(
                torch.zeros((self.chan_out, self.chan_in, *self.kernel_size)),
                requires_grad=True,
            )
        if self.conv.bias is not None:
            nn.init.constant_(self.conv.bias, 0)

    def forward(self, x):
        # 1 + 16   16 as video, 1 as image
        first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.time_kernel_size - 1, 1, 1))  # b c t h w
        x = torch.concatenate((first_frame_pad, x), dim=2)  # 3 + 16
        return self.conv(x)

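# Illustrative sketch (not part of the original file): CausalConv3d pads the time axis by
# repeating the first frame time_kernel_size - 1 times, so every output frame only sees
# the current and earlier frames and the temporal length is preserved.
def _example_causal_conv_sketch():
    conv = CausalConv3d(4, 4, kernel_size=3, padding=1)  # spatial padding 1, causal in time
    x = torch.randn(1, 4, 17, 32, 32)  # (b, c, t, h, w), t = 1 image frame + 16 video frames
    y = conv(x)
    return y.shape  # (1, 4, 17, 32, 32): same t because 2 pad frames are prepended
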
class GroupNorm(Block):
    def __init__(self, num_channels, num_groups=32, eps=1e-6, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.norm = torch.nn.GroupNorm(num_groups=num_groups, num_channels=num_channels, eps=1e-6, affine=True)

    def forward(self, x):
        return self.norm(x)


def Normalize(in_channels, num_groups=32):
    return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)


class ActNorm(nn.Module):
    def __init__(self, num_features, logdet=False, affine=True, allow_reverse_init=False):
        assert affine
        super().__init__()
        self.logdet = logdet
        self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
        self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
        self.allow_reverse_init = allow_reverse_init

        self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8))

    def initialize(self, input):
        with torch.no_grad():
            flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
            mean = flatten.mean(1).unsqueeze(1).unsqueeze(2).unsqueeze(3).permute(1, 0, 2, 3)
            std = flatten.std(1).unsqueeze(1).unsqueeze(2).unsqueeze(3).permute(1, 0, 2, 3)

            self.loc.data.copy_(-mean)
            self.scale.data.copy_(1 / (std + 1e-6))

    def forward(self, input, reverse=False):
        if reverse:
            return self.reverse(input)
        if len(input.shape) == 2:
            input = input[:, :, None, None]
            squeeze = True
        else:
            squeeze = False

        _, _, height, width = input.shape

        if self.training and self.initialized.item() == 0:
            self.initialize(input)
            self.initialized.fill_(1)

        h = self.scale * (input + self.loc)

        if squeeze:
            h = h.squeeze(-1).squeeze(-1)

        if self.logdet:
            log_abs = torch.log(torch.abs(self.scale))
            logdet = height * width * torch.sum(log_abs)
            logdet = logdet * torch.ones(input.shape[0]).to(input)
            return h, logdet

        return h

    def reverse(self, output):
        if self.training and self.initialized.item() == 0:
            if not self.allow_reverse_init:
                raise RuntimeError(
                    "Initializing ActNorm in reverse direction is "
                    "disabled by default. Use allow_reverse_init=True to enable."
                )
            else:
                self.initialize(output)
                self.initialized.fill_(1)

        if len(output.shape) == 2:
            output = output[:, :, None, None]
            squeeze = True
        else:
            squeeze = False

        h = output / self.scale - self.loc

        if squeeze:
            h = h.squeeze(-1).squeeze(-1)
        return h


def nonlinearity(x):
    return x * torch.sigmoid(x)


def cast_tuple(t, length=1):
    return t if isinstance(t, tuple) else ((t,) * length)

def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True):
    n_dims = len(x.shape)
    if src_dim < 0:
        src_dim = n_dims + src_dim
    if dest_dim < 0:
        dest_dim = n_dims + dest_dim
    assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims
    dims = list(range(n_dims))
    del dims[src_dim]
    permutation = []
    ctr = 0
    for i in range(n_dims):
        if i == dest_dim:
            permutation.append(src_dim)
        else:
            permutation.append(dims[ctr])
            ctr += 1
    x = x.permute(permutation)
    if make_contiguous:
        x = x.contiguous()
    return x

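# Illustrative sketch (not part of the original file): shift_dim moves one axis to a new
# position while preserving the relative order of the remaining axes, e.g. channels-first
# (b, c, t, h, w) to channels-last (b, t, h, w, c) and back.
def _example_shift_dim_sketch():
    z = torch.randn(2, 4, 3, 8, 8)  # (b, c, t, h, w)
    z_last = shift_dim(z, 1, -1)  # (b, t, h, w, c) == (2, 3, 8, 8, 4)
    z_back = shift_dim(z_last, -1, 1)  # (b, c, t, h, w) again
    return z_last.shape, torch.equal(z, z_back)
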
class Codebook(nn.Module):
    def __init__(self, n_codes, embedding_dim):
        super().__init__()
        self.register_buffer("embeddings", torch.randn(n_codes, embedding_dim))
        self.register_buffer("N", torch.zeros(n_codes))
        self.register_buffer("z_avg", self.embeddings.data.clone())

        self.n_codes = n_codes
        self.embedding_dim = embedding_dim
        self._need_init = True

    def _tile(self, x):
        d, ew = x.shape
        if d < self.n_codes:
            n_repeats = (self.n_codes + d - 1) // d
            std = 0.01 / np.sqrt(ew)
            x = x.repeat(n_repeats, 1)
            x = x + torch.randn_like(x) * std
        return x

    def _init_embeddings(self, z):
        # z: [b, c, t, h, w]
        self._need_init = False
        flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2)
        y = self._tile(flat_inputs)

        _k_rand = y[torch.randperm(y.shape[0])][: self.n_codes]
        if dist.is_initialized():
            dist.broadcast(_k_rand, 0)
        self.embeddings.data.copy_(_k_rand)
        self.z_avg.data.copy_(_k_rand)
        self.N.data.copy_(torch.ones(self.n_codes))

    def forward(self, z):
        # z: [b, c, t, h, w]
        if self._need_init and self.training:
            self._init_embeddings(z)
        flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2)
        distances = (
            (flat_inputs**2).sum(dim=1, keepdim=True)
            - 2 * flat_inputs @ self.embeddings.t()
            + (self.embeddings.t() ** 2).sum(dim=0, keepdim=True)
        )

        encoding_indices = torch.argmin(distances, dim=1)
        encode_onehot = F.one_hot(encoding_indices, self.n_codes).type_as(flat_inputs)
        encoding_indices = encoding_indices.view(z.shape[0], *z.shape[2:])

        embeddings = F.embedding(encoding_indices, self.embeddings)
        embeddings = shift_dim(embeddings, -1, 1)

        commitment_loss = 0.25 * F.mse_loss(z, embeddings.detach())

        # EMA codebook update
        if self.training:
            n_total = encode_onehot.sum(dim=0)
            encode_sum = flat_inputs.t() @ encode_onehot
            if dist.is_initialized():
                dist.all_reduce(n_total)
                dist.all_reduce(encode_sum)

            self.N.data.mul_(0.99).add_(n_total, alpha=0.01)
            self.z_avg.data.mul_(0.99).add_(encode_sum.t(), alpha=0.01)

            n = self.N.sum()
            weights = (self.N + 1e-7) / (n + self.n_codes * 1e-7) * n
            encode_normalized = self.z_avg / weights.unsqueeze(1)
            self.embeddings.data.copy_(encode_normalized)

            y = self._tile(flat_inputs)
            _k_rand = y[torch.randperm(y.shape[0])][: self.n_codes]
            if dist.is_initialized():
                dist.broadcast(_k_rand, 0)

            usage = (self.N.view(self.n_codes, 1) >= 1).float()
            self.embeddings.data.mul_(usage).add_(_k_rand * (1 - usage))

        embeddings_st = (embeddings - z).detach() + z

        avg_probs = torch.mean(encode_onehot, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        return dict(
            embeddings=embeddings_st,
            encodings=encoding_indices,
            commitment_loss=commitment_loss,
            perplexity=perplexity,
        )

    def dictionary_lookup(self, encodings):
        embeddings = F.embedding(encodings, self.embeddings)
        return embeddings

class ResnetBlock2D(Block):
    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels, self.out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = Normalize(self.out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(self.out_channels, self.out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels, self.out_channels, kernel_size=3, stride=1, padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels, self.out_channels, kernel_size=1, stride=1, padding=0)

    @video_to_image
    def forward(self, x):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)
        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)
        x = x + h
        return x


class ResnetBlock3D(Block):
    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = CausalConv3d(in_channels, self.out_channels, 3, padding=1)
        self.norm2 = Normalize(self.out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = CausalConv3d(self.out_channels, self.out_channels, 3, padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = CausalConv3d(in_channels, self.out_channels, 3, padding=1)
            else:
                self.nin_shortcut = CausalConv3d(in_channels, self.out_channels, 1, padding=0)

    def forward(self, x):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)
        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)
        return x + h

class Upsample(Block):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.with_conv = True
        if self.with_conv:
            self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

    @video_to_image
    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x


class Downsample(Block):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.with_conv = True
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)

    @video_to_image
    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x


class SpatialDownsample2x(Block):
    def __init__(
        self,
        chan_in,
        chan_out,
        kernel_size: Union[int, Tuple[int]] = (3, 3),
        stride: Union[int, Tuple[int]] = (2, 2),
    ):
        super().__init__()
        kernel_size = cast_tuple(kernel_size, 2)
        stride = cast_tuple(stride, 2)
        self.chan_in = chan_in
        self.chan_out = chan_out
        self.kernel_size = kernel_size
        self.conv = CausalConv3d(self.chan_in, self.chan_out, (1,) + self.kernel_size, stride=(1,) + stride, padding=0)

    def forward(self, x):
        pad = (0, 1, 0, 1, 0, 0)
        x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
        x = self.conv(x)
        return x

class SpatialUpsample2x(Block):
    def __init__(
        self,
        chan_in,
        chan_out,
        kernel_size: Union[int, Tuple[int]] = (3, 3),
        stride: Union[int, Tuple[int]] = (1, 1),
    ):
        super().__init__()
        self.chan_in = chan_in
        self.chan_out = chan_out
        self.kernel_size = kernel_size
        self.conv = CausalConv3d(self.chan_in, self.chan_out, (1,) + self.kernel_size, stride=(1,) + stride, padding=1)

    def forward(self, x):
        t = x.shape[2]
        x = rearrange(x, "b c t h w -> b (c t) h w")
        x = F.interpolate(x, scale_factor=(2, 2), mode="nearest")
        x = rearrange(x, "b (c t) h w -> b c t h w", t=t)
        x = self.conv(x)
        return x


class TimeDownsample2x(Block):
    def __init__(self, chan_in, chan_out, kernel_size: int = 3):
        super().__init__()
        self.kernel_size = kernel_size
        self.conv = nn.AvgPool3d((kernel_size, 1, 1), stride=(2, 1, 1))

    def forward(self, x):
        first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.kernel_size - 1, 1, 1))
        x = torch.concatenate((first_frame_pad, x), dim=2)
        return self.conv(x)


class TimeUpsample2x(Block):
    def __init__(self, chan_in, chan_out):
        super().__init__()

    def forward(self, x):
        if x.size(2) > 1:
            x, x_ = x[:, :, :1], x[:, :, 1:]
            x_ = F.interpolate(x_, scale_factor=(2, 1, 1), mode="trilinear")
            x = torch.concat([x, x_], dim=2)
        return x

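# Illustrative sketch (not part of the original file): the temporal resamplers keep the
# first frame special, so a tensor with 1 + n frames maps to 1 + n/2 frames when
# downsampling and back to 1 + 2*n frames when upsampling, matching the causal
# "image frame + video frames" layout used throughout this file.
def _example_time_resample_sketch():
    x = torch.randn(1, 4, 9, 8, 8)  # 1 + 8 frames
    down = TimeDownsample2x(4, 4)(x)  # causal pad + stride-2 average pool -> 5 frames
    up = TimeUpsample2x(4, 4)(down)  # first frame kept, remaining frames interpolated -> 9 frames
    return down.shape, up.shape
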
class TimeDownsampleRes2x(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size: int = 3,
        mix_factor: float = 2.0,
    ):
        super().__init__()
        self.kernel_size = cast_tuple(kernel_size, 3)
        self.avg_pool = nn.AvgPool3d((kernel_size, 1, 1), stride=(2, 1, 1))
        self.conv = nn.Conv3d(in_channels, out_channels, self.kernel_size, stride=(2, 1, 1), padding=(0, 1, 1))
        self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor]))

    def forward(self, x):
        alpha = torch.sigmoid(self.mix_factor)
        first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.kernel_size[0] - 1, 1, 1))
        x = torch.concatenate((first_frame_pad, x), dim=2)
        return alpha * self.avg_pool(x) + (1 - alpha) * self.conv(x)


class TimeUpsampleRes2x(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size: int = 3,
        mix_factor: float = 2.0,
    ):
        super().__init__()
        self.conv = CausalConv3d(in_channels, out_channels, kernel_size, padding=1)
        self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor]))

    def forward(self, x):
        alpha = torch.sigmoid(self.mix_factor)
        if x.size(2) > 1:
            x, x_ = x[:, :, :1], x[:, :, 1:]
            x_ = F.interpolate(x_, scale_factor=(2, 1, 1), mode="trilinear")
            x = torch.concat([x, x_], dim=2)
        return alpha * x + (1 - alpha) * self.conv(x)


class TimeDownsampleResAdv2x(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size: int = 3,
        mix_factor: float = 1.5,
    ):
        super().__init__()
        self.kernel_size = cast_tuple(kernel_size, 3)
        self.avg_pool = nn.AvgPool3d((kernel_size, 1, 1), stride=(2, 1, 1))
        self.attn = TemporalAttnBlock(in_channels)
        self.res = ResnetBlock3D(in_channels=in_channels, out_channels=in_channels, dropout=0.0)
        self.conv = nn.Conv3d(in_channels, out_channels, self.kernel_size, stride=(2, 1, 1), padding=(0, 1, 1))
        self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor]))

    def forward(self, x):
        first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.kernel_size[0] - 1, 1, 1))
        x = torch.concatenate((first_frame_pad, x), dim=2)
        alpha = torch.sigmoid(self.mix_factor)
        return alpha * self.avg_pool(x) + (1 - alpha) * self.conv(self.attn((self.res(x))))


class TimeUpsampleResAdv2x(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size: int = 3,
        mix_factor: float = 1.5,
    ):
        super().__init__()
        self.res = ResnetBlock3D(in_channels=in_channels, out_channels=in_channels, dropout=0.0)
        self.attn = TemporalAttnBlock(in_channels)
        self.norm = Normalize(in_channels=in_channels)
        self.conv = CausalConv3d(in_channels, out_channels, kernel_size, padding=1)
        self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor]))

    def forward(self, x):
        if x.size(2) > 1:
            x, x_ = x[:, :, :1], x[:, :, 1:]
            x_ = F.interpolate(x_, scale_factor=(2, 1, 1), mode="trilinear")
            x = torch.concat([x, x_], dim=2)
        alpha = torch.sigmoid(self.mix_factor)
        return alpha * x + (1 - alpha) * self.conv(self.attn(self.res(x)))