Implement temporal rolling VAE (Major VRAM reductions in Hunyuan and Kandinsky) (#10995)

* hunyuan upsampler: rework imports

Remove the transitive import of VideoConv3d and Resnet and take these
from the actual implementation source.

* model: remove unused give_pre_end

According to git grep, this is not used now, and was not used in the
initial commit that introduced it (see below).

This semantic is difficult to support in the temporal rolling VAE (and
would defeat its purpose). Rather than implement the complex conditional,
just delete the unused feature.

(venv) rattus@rattus-box2:~/ComfyUI$ git log --oneline
220afe33 (HEAD) Initial commit.
(venv) rattus@rattus-box2:~/ComfyUI$ git grep give_pre
comfy/ldm/modules/diffusionmodules/model.py:                 resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
comfy/ldm/modules/diffusionmodules/model.py:        self.give_pre_end = give_pre_end
comfy/ldm/modules/diffusionmodules/model.py:        if self.give_pre_end:

(venv) rattus@rattus-box2:~/ComfyUI$ git co origin/master
Previous HEAD position was 220afe33 Initial commit.
HEAD is now at 9d8a8179 Enable async offloading by default on Nvidia. (#10953)
(venv) rattus@rattus-box2:~/ComfyUI$ git grep give_pre
comfy/ldm/modules/diffusionmodules/model.py:                 resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
comfy/ldm/modules/diffusionmodules/model.py:        self.give_pre_end = give_pre_end
comfy/ldm/modules/diffusionmodules/model.py:        if self.give_pre_end:

* move refiner VAE temporal roller to core

Move the carrying conv op to the common VAE code and give it a better
name. Roll the carry implementation logic for Resnet into the base
class and scrap the Hunyuan specific subclass.

* model: Add temporal roll to main VAE decoder

If there are no attention layers, it's a standard resnet, and VideoConv3d
is asked for, substitute in the temporal rolling VAE algorithm. This
reduces VAE memory usage by the temporal dimension (can be huge VRAM savings).

* model: Add temporal roll to main VAE encoder

If there are no attention layers, it's a standard resnet, and VideoConv3d
is asked for, substitute in the temporal rolling VAE algorithm. This
reduces VAE memory usage by the temporal dimension (can be huge VRAM savings).
This commit is contained in:
rattus 2025-12-03 13:49:29 +10:00 committed by GitHub
parent 3f512f5659
commit 73f5649196
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 174 additions and 130 deletions

View File

@ -1,7 +1,8 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm, ResnetBlock, VideoConv3d from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d
from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm
import model_management, model_patcher import model_management, model_patcher
class SRResidualCausalBlock3D(nn.Module): class SRResidualCausalBlock3D(nn.Module):

View File

@ -1,42 +1,12 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d, Normalize from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, CarriedConv3d, Normalize, conv_carry_causal_3d, torch_cat_if_needed
import comfy.ops import comfy.ops
import comfy.ldm.models.autoencoder import comfy.ldm.models.autoencoder
import comfy.model_management import comfy.model_management
ops = comfy.ops.disable_weight_init ops = comfy.ops.disable_weight_init
class NoPadConv3d(nn.Module):
    """Thin wrapper around ``ops.Conv3d`` that drops the ``padding`` argument.

    ``padding`` is accepted in the signature but never handed to the wrapped
    convolution. ``conv_carry_causal_3d`` checks for this class with
    ``isinstance`` and pads the input itself before calling the op.
    """

    def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, **kwargs):
        super().__init__()
        # `padding` is intentionally left out of the forwarded options.
        conv_options = dict(kwargs, stride=stride, dilation=dilation)
        self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, **conv_options)

    def forward(self, x):
        """Apply the wrapped convolution to ``x`` as-is (no padding added here)."""
        return self.conv(x)
def conv_carry_causal_3d(xl, op, conv_carry_in=None, conv_carry_out=None):
    """Run ``op`` on the tensor held in ``xl``, carrying frames between chunks.

    Args:
        xl: single-element list holding the input tensor. The list is cleared
            so the caller does not keep a reference through it.
        op: the convolution module to apply. Padding / carry handling is only
            done when it is a ``NoPadConv3d``.
        conv_carry_in: ``None``, or a list whose first entry is the carry saved
            by the previous chunk; that entry is popped and consumed.
        conv_carry_out: ``None``, or a list that receives a clone of the last
            two entries of the (unpadded) input along dim 2.

    Returns:
        The result of ``op`` on the (possibly padded / extended) input.
    """
    x = xl[0]
    xl.clear()

    # Save the tail of the raw input for the next chunk before touching it.
    if conv_carry_out is not None:
        conv_carry_out.append(x[:, :, -2:, :, :].clone())

    if isinstance(op, NoPadConv3d):
        replicate_pad = torch.nn.functional.pad
        if conv_carry_in is None:
            # No carry: pad 1 on each side of the last two dims and
            # 2 in front of dim 2 (presumably the temporal axis).
            x = replicate_pad(x, (1, 1, 1, 1, 2, 0), mode = 'replicate')
        else:
            carry_len = conv_carry_in[0].shape[2]
            x = torch.cat([conv_carry_in.pop(0), x], dim=2)
            # Only pad the front of dim 2 by whatever the carry did not cover.
            x = replicate_pad(x, (1, 1, 1, 1, 2 - carry_len, 0), mode = 'replicate')

    return op(x)
class RMS_norm(nn.Module): class RMS_norm(nn.Module):
def __init__(self, dim): def __init__(self, dim):
@ -49,7 +19,7 @@ class RMS_norm(nn.Module):
return F.normalize(x, dim=1) * self.scale * comfy.model_management.cast_to(self.gamma, dtype=x.dtype, device=x.device) return F.normalize(x, dim=1) * self.scale * comfy.model_management.cast_to(self.gamma, dtype=x.dtype, device=x.device)
class DnSmpl(nn.Module): class DnSmpl(nn.Module):
def __init__(self, ic, oc, tds=True, refiner_vae=True, op=VideoConv3d): def __init__(self, ic, oc, tds, refiner_vae, op):
super().__init__() super().__init__()
fct = 2 * 2 * 2 if tds else 1 * 2 * 2 fct = 2 * 2 * 2 if tds else 1 * 2 * 2
assert oc % fct == 0 assert oc % fct == 0
@ -109,7 +79,7 @@ class DnSmpl(nn.Module):
class UpSmpl(nn.Module): class UpSmpl(nn.Module):
def __init__(self, ic, oc, tus=True, refiner_vae=True, op=VideoConv3d): def __init__(self, ic, oc, tus, refiner_vae, op):
super().__init__() super().__init__()
fct = 2 * 2 * 2 if tus else 1 * 2 * 2 fct = 2 * 2 * 2 if tus else 1 * 2 * 2
self.conv = op(ic, oc * fct, kernel_size=3, stride=1, padding=1) self.conv = op(ic, oc * fct, kernel_size=3, stride=1, padding=1)
@ -163,23 +133,6 @@ class UpSmpl(nn.Module):
return h + x return h + x
class HunyuanRefinerResnetBlock(ResnetBlock):
    """``ResnetBlock`` whose two convolutions are routed through
    ``conv_carry_causal_3d`` so carry lists can be threaded through them.

    Constructed with ``temb_channels=0``; ``forward`` takes no ``temb``.
    """

    def __init__(self, in_channels, out_channels, conv_op=NoPadConv3d, norm_op=RMS_norm):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=0,
            conv_op=conv_op,
            norm_op=norm_op,
        )

    def forward(self, x, conv_carry_in=None, conv_carry_out=None):
        """Apply norm/swish/conv twice and add the (optionally projected) input.

        ``conv_carry_in`` / ``conv_carry_out`` are passed unchanged to both
        ``conv_carry_causal_3d`` calls.
        """
        # Wrap in a single-element list: conv_carry_causal_3d takes the tensor
        # out of the list and clears it.
        hidden = [self.swish(self.norm1(x))]
        hidden = conv_carry_causal_3d(hidden, self.conv1, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)

        hidden = [self.dropout(self.swish(self.norm2(hidden)))]
        hidden = conv_carry_causal_3d(hidden, self.conv2, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)

        # Project the skip path only when the channel counts differ.
        skip = x if self.in_channels == self.out_channels else self.nin_shortcut(x)
        return skip + hidden
class Encoder(nn.Module): class Encoder(nn.Module):
def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks, def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
ffactor_spatial, ffactor_temporal, downsample_match_channel=True, refiner_vae=True, **_): ffactor_spatial, ffactor_temporal, downsample_match_channel=True, refiner_vae=True, **_):
@ -191,7 +144,7 @@ class Encoder(nn.Module):
self.refiner_vae = refiner_vae self.refiner_vae = refiner_vae
if self.refiner_vae: if self.refiner_vae:
conv_op = NoPadConv3d conv_op = CarriedConv3d
norm_op = RMS_norm norm_op = RMS_norm
else: else:
conv_op = ops.Conv3d conv_op = ops.Conv3d
@ -206,9 +159,10 @@ class Encoder(nn.Module):
for i, tgt in enumerate(block_out_channels): for i, tgt in enumerate(block_out_channels):
stage = nn.Module() stage = nn.Module()
stage.block = nn.ModuleList([HunyuanRefinerResnetBlock(in_channels=ch if j == 0 else tgt, stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
out_channels=tgt, out_channels=tgt,
conv_op=conv_op, norm_op=norm_op) temb_channels=0,
conv_op=conv_op, norm_op=norm_op)
for j in range(num_res_blocks)]) for j in range(num_res_blocks)])
ch = tgt ch = tgt
if i < depth: if i < depth:
@ -218,9 +172,9 @@ class Encoder(nn.Module):
self.down.append(stage) self.down.append(stage)
self.mid = nn.Module() self.mid = nn.Module()
self.mid.block_1 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op) self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
self.mid.block_2 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
self.norm_out = norm_op(ch) self.norm_out = norm_op(ch)
self.conv_out = conv_op(ch, z_channels << 1, 3, 1, 1) self.conv_out = conv_op(ch, z_channels << 1, 3, 1, 1)
@ -246,22 +200,20 @@ class Encoder(nn.Module):
conv_carry_out = [] conv_carry_out = []
if i == len(x) - 1: if i == len(x) - 1:
conv_carry_out = None conv_carry_out = None
x1 = [ x1 ] x1 = [ x1 ]
x1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out) x1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out)
for stage in self.down: for stage in self.down:
for blk in stage.block: for blk in stage.block:
x1 = blk(x1, conv_carry_in, conv_carry_out) x1 = blk(x1, None, conv_carry_in, conv_carry_out)
if hasattr(stage, 'downsample'): if hasattr(stage, 'downsample'):
x1 = stage.downsample(x1, conv_carry_in, conv_carry_out) x1 = stage.downsample(x1, conv_carry_in, conv_carry_out)
out.append(x1) out.append(x1)
conv_carry_in = conv_carry_out conv_carry_in = conv_carry_out
if len(out) > 1: out = torch_cat_if_needed(out, dim=2)
out = torch.cat(out, dim=2)
else:
out = out[0]
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(out))) x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(out)))
del out del out
@ -288,7 +240,7 @@ class Decoder(nn.Module):
self.refiner_vae = refiner_vae self.refiner_vae = refiner_vae
if self.refiner_vae: if self.refiner_vae:
conv_op = NoPadConv3d conv_op = CarriedConv3d
norm_op = RMS_norm norm_op = RMS_norm
else: else:
conv_op = ops.Conv3d conv_op = ops.Conv3d
@ -298,9 +250,9 @@ class Decoder(nn.Module):
self.conv_in = conv_op(z_channels, ch, kernel_size=3, stride=1, padding=1) self.conv_in = conv_op(z_channels, ch, kernel_size=3, stride=1, padding=1)
self.mid = nn.Module() self.mid = nn.Module()
self.mid.block_1 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op) self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
self.mid.block_2 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
self.up = nn.ModuleList() self.up = nn.ModuleList()
depth = (ffactor_spatial >> 1).bit_length() depth = (ffactor_spatial >> 1).bit_length()
@ -308,9 +260,10 @@ class Decoder(nn.Module):
for i, tgt in enumerate(block_out_channels): for i, tgt in enumerate(block_out_channels):
stage = nn.Module() stage = nn.Module()
stage.block = nn.ModuleList([HunyuanRefinerResnetBlock(in_channels=ch if j == 0 else tgt, stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
out_channels=tgt, out_channels=tgt,
conv_op=conv_op, norm_op=norm_op) temb_channels=0,
conv_op=conv_op, norm_op=norm_op)
for j in range(num_res_blocks + 1)]) for j in range(num_res_blocks + 1)])
ch = tgt ch = tgt
if i < depth: if i < depth:
@ -340,7 +293,7 @@ class Decoder(nn.Module):
conv_carry_out = None conv_carry_out = None
for stage in self.up: for stage in self.up:
for blk in stage.block: for blk in stage.block:
x1 = blk(x1, conv_carry_in, conv_carry_out) x1 = blk(x1, None, conv_carry_in, conv_carry_out)
if hasattr(stage, 'upsample'): if hasattr(stage, 'upsample'):
x1 = stage.upsample(x1, conv_carry_in, conv_carry_out) x1 = stage.upsample(x1, conv_carry_in, conv_carry_out)
@ -350,10 +303,7 @@ class Decoder(nn.Module):
conv_carry_in = conv_carry_out conv_carry_in = conv_carry_out
del x del x
if len(out) > 1: out = torch_cat_if_needed(out, dim=2)
out = torch.cat(out, dim=2)
else:
out = out[0]
if not self.refiner_vae: if not self.refiner_vae:
if z.shape[-3] == 1: if z.shape[-3] == 1:

View File

@ -13,6 +13,12 @@ if model_management.xformers_enabled_vae():
import xformers import xformers
import xformers.ops import xformers.ops
def torch_cat_if_needed(xl, dim):
    """Concatenate the tensors in ``xl`` along ``dim``, skipping the copy for one.

    When ``xl`` holds more than one entry the result of ``torch.cat`` is
    returned; otherwise ``xl[0]`` is returned directly (the same object, no
    concatenation). An empty ``xl`` raises ``IndexError``.
    """
    if len(xl) <= 1:
        return xl[0]
    return torch.cat(xl, dim)
def get_timestep_embedding(timesteps, embedding_dim): def get_timestep_embedding(timesteps, embedding_dim):
""" """
This matches the implementation in Denoising Diffusion Probabilistic Models: This matches the implementation in Denoising Diffusion Probabilistic Models:
@ -43,6 +49,37 @@ def Normalize(in_channels, num_groups=32):
return ops.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) return ops.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
class CarriedConv3d(nn.Module):
    """Thin wrapper around ``ops.Conv3d`` that drops the ``padding`` argument.

    ``padding`` is accepted in the signature but never handed to the wrapped
    convolution. ``conv_carry_causal_3d`` recognises this class with
    ``isinstance`` and applies replicate padding / carried frames itself
    before calling the op.
    """

    def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, **kwargs):
        super().__init__()
        # `padding` is intentionally left out of the forwarded options.
        conv_options = dict(kwargs, stride=stride, dilation=dilation)
        self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, **conv_options)

    def forward(self, x):
        """Apply the wrapped convolution to ``x`` as-is (no padding added here)."""
        return self.conv(x)
def conv_carry_causal_3d(xl, op, conv_carry_in=None, conv_carry_out=None):
    """Run ``op`` on the tensor held in ``xl``, carrying frames between chunks.

    Args:
        xl: single-element list holding the input tensor. The list is cleared
            so the caller does not keep a reference through it.
        op: the convolution module to apply. Padding and carry handling only
            happen when it is a ``CarriedConv3d``; any other op receives the
            input untouched.
        conv_carry_in: ``None``, or a list whose first entry is the carry saved
            by the previous chunk; that entry is popped and prepended on dim 2.
        conv_carry_out: ``None``, or a list that receives a clone of the last
            two entries along dim 2 of the padded / extended input.

    Returns:
        The result of ``op`` on the prepared input.
    """
    x = xl[0]
    xl.clear()

    if not isinstance(op, CarriedConv3d):
        return op(x)

    replicate_pad = torch.nn.functional.pad
    if conv_carry_in is None:
        # No carry: pad 1 on each side of the last two dims and
        # 2 in front of dim 2 (presumably the temporal axis).
        x = replicate_pad(x, (1, 1, 1, 1, 2, 0), mode = 'replicate')
    else:
        carry_len = conv_carry_in[0].shape[2]
        # Pad first so the carry (saved from an already padded tensor)
        # matches the padded size of the last two dims when concatenated.
        x = replicate_pad(x, (1, 1, 1, 1, 2 - carry_len, 0), mode = 'replicate')
        x = torch.cat([conv_carry_in.pop(0), x], dim=2)

    if conv_carry_out is not None:
        # Clone so the saved tail does not keep the whole tensor alive.
        conv_carry_out.append(x[:, :, -2:, :, :].clone())

    return op(x)
class VideoConv3d(nn.Module): class VideoConv3d(nn.Module):
def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding_mode='replicate', padding=1, **kwargs): def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding_mode='replicate', padding=1, **kwargs):
super().__init__() super().__init__()
@ -89,29 +126,24 @@ class Upsample(nn.Module):
stride=1, stride=1,
padding=1) padding=1)
def forward(self, x): def forward(self, x, conv_carry_in=None, conv_carry_out=None):
scale_factor = self.scale_factor scale_factor = self.scale_factor
if isinstance(scale_factor, (int, float)): if isinstance(scale_factor, (int, float)):
scale_factor = (scale_factor,) * (x.ndim - 2) scale_factor = (scale_factor,) * (x.ndim - 2)
if x.ndim == 5 and scale_factor[0] > 1.0: if x.ndim == 5 and scale_factor[0] > 1.0:
t = x.shape[2] results = []
if t > 1: if conv_carry_in is None:
a, b = x.split((1, t - 1), dim=2) first = x[:, :, :1, :, :]
del x results.append(interpolate_up(first.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2))
b = interpolate_up(b, scale_factor) x = x[:, :, 1:, :, :]
else: if x.shape[2] > 0:
a = x results.append(interpolate_up(x, scale_factor))
x = torch_cat_if_needed(results, dim=2)
a = interpolate_up(a.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2)
if t > 1:
x = torch.cat((a, b), dim=2)
else:
x = a
else: else:
x = interpolate_up(x, scale_factor) x = interpolate_up(x, scale_factor)
if self.with_conv: if self.with_conv:
x = self.conv(x) x = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
return x return x
@ -127,17 +159,20 @@ class Downsample(nn.Module):
stride=stride, stride=stride,
padding=0) padding=0)
def forward(self, x): def forward(self, x, conv_carry_in=None, conv_carry_out=None):
if self.with_conv: if self.with_conv:
if x.ndim == 4: if isinstance(self.conv, CarriedConv3d):
x = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
elif x.ndim == 4:
pad = (0, 1, 0, 1) pad = (0, 1, 0, 1)
mode = "constant" mode = "constant"
x = torch.nn.functional.pad(x, pad, mode=mode, value=0) x = torch.nn.functional.pad(x, pad, mode=mode, value=0)
x = self.conv(x)
elif x.ndim == 5: elif x.ndim == 5:
pad = (1, 1, 1, 1, 2, 0) pad = (1, 1, 1, 1, 2, 0)
mode = "replicate" mode = "replicate"
x = torch.nn.functional.pad(x, pad, mode=mode) x = torch.nn.functional.pad(x, pad, mode=mode)
x = self.conv(x) x = self.conv(x)
else: else:
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
return x return x
@ -183,23 +218,23 @@ class ResnetBlock(nn.Module):
stride=1, stride=1,
padding=0) padding=0)
def forward(self, x, temb=None): def forward(self, x, temb=None, conv_carry_in=None, conv_carry_out=None):
h = x h = x
h = self.norm1(h) h = self.norm1(h)
h = self.swish(h) h = [ self.swish(h) ]
h = self.conv1(h) h = conv_carry_causal_3d(h, self.conv1, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
if temb is not None: if temb is not None:
h = h + self.temb_proj(self.swish(temb))[:,:,None,None] h = h + self.temb_proj(self.swish(temb))[:,:,None,None]
h = self.norm2(h) h = self.norm2(h)
h = self.swish(h) h = self.swish(h)
h = self.dropout(h) h = [ self.dropout(h) ]
h = self.conv2(h) h = conv_carry_causal_3d(h, self.conv2, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
if self.in_channels != self.out_channels: if self.in_channels != self.out_channels:
if self.use_conv_shortcut: if self.use_conv_shortcut:
x = self.conv_shortcut(x) x = conv_carry_causal_3d([x], self.conv_shortcut, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
else: else:
x = self.nin_shortcut(x) x = self.nin_shortcut(x)
@ -520,9 +555,14 @@ class Encoder(nn.Module):
self.num_res_blocks = num_res_blocks self.num_res_blocks = num_res_blocks
self.resolution = resolution self.resolution = resolution
self.in_channels = in_channels self.in_channels = in_channels
self.carried = False
if conv3d: if conv3d:
conv_op = VideoConv3d if not attn_resolutions:
conv_op = CarriedConv3d
self.carried = True
else:
conv_op = VideoConv3d
mid_attn_conv_op = ops.Conv3d mid_attn_conv_op = ops.Conv3d
else: else:
conv_op = ops.Conv2d conv_op = ops.Conv2d
@ -535,6 +575,7 @@ class Encoder(nn.Module):
stride=1, stride=1,
padding=1) padding=1)
self.time_compress = 1
curr_res = resolution curr_res = resolution
in_ch_mult = (1,)+tuple(ch_mult) in_ch_mult = (1,)+tuple(ch_mult)
self.in_ch_mult = in_ch_mult self.in_ch_mult = in_ch_mult
@ -561,10 +602,15 @@ class Encoder(nn.Module):
if time_compress is not None: if time_compress is not None:
if (self.num_resolutions - 1 - i_level) > math.log2(time_compress): if (self.num_resolutions - 1 - i_level) > math.log2(time_compress):
stride = (1, 2, 2) stride = (1, 2, 2)
else:
self.time_compress *= 2
down.downsample = Downsample(block_in, resamp_with_conv, stride=stride, conv_op=conv_op) down.downsample = Downsample(block_in, resamp_with_conv, stride=stride, conv_op=conv_op)
curr_res = curr_res // 2 curr_res = curr_res // 2
self.down.append(down) self.down.append(down)
if time_compress is not None:
self.time_compress = time_compress
# middle # middle
self.mid = nn.Module() self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in, self.mid.block_1 = ResnetBlock(in_channels=block_in,
@ -590,15 +636,42 @@ class Encoder(nn.Module):
def forward(self, x): def forward(self, x):
# timestep embedding # timestep embedding
temb = None temb = None
# downsampling
h = self.conv_in(x) if self.carried:
for i_level in range(self.num_resolutions): xl = [x[:, :, :1, :, :]]
for i_block in range(self.num_res_blocks): if x.shape[2] > self.time_compress:
h = self.down[i_level].block[i_block](h, temb) tc = self.time_compress
if len(self.down[i_level].attn) > 0: xl += torch.split(x[:, :, 1: 1 + ((x.shape[2] - 1) // tc) * tc, :, :], tc * 2, dim = 2)
h = self.down[i_level].attn[i_block](h) x = xl
if i_level != self.num_resolutions-1: else:
h = self.down[i_level].downsample(h) x = [x]
out = []
conv_carry_in = None
for i, x1 in enumerate(x):
conv_carry_out = []
if i == len(x) - 1:
conv_carry_out = None
# downsampling
x1 = [ x1 ]
h1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out)
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h1 = self.down[i_level].block[i_block](h1, temb, conv_carry_in, conv_carry_out)
if len(self.down[i_level].attn) > 0:
assert i == 0 #carried should not happen if attn exists
h1 = self.down[i_level].attn[i_block](h1)
if i_level != self.num_resolutions-1:
h1 = self.down[i_level].downsample(h1, conv_carry_in, conv_carry_out)
out.append(h1)
conv_carry_in = conv_carry_out
h = torch_cat_if_needed(out, dim=2)
del out
# middle # middle
h = self.mid.block_1(h, temb) h = self.mid.block_1(h, temb)
@ -607,15 +680,15 @@ class Encoder(nn.Module):
# end # end
h = self.norm_out(h) h = self.norm_out(h)
h = nonlinearity(h) h = [ nonlinearity(h) ]
h = self.conv_out(h) h = conv_carry_causal_3d(h, self.conv_out)
return h return h
class Decoder(nn.Module): class Decoder(nn.Module):
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, resolution, z_channels, tanh_out=False, use_linear_attn=False,
conv_out_op=ops.Conv2d, conv_out_op=ops.Conv2d,
resnet_op=ResnetBlock, resnet_op=ResnetBlock,
attn_op=AttnBlock, attn_op=AttnBlock,
@ -629,12 +702,18 @@ class Decoder(nn.Module):
self.num_res_blocks = num_res_blocks self.num_res_blocks = num_res_blocks
self.resolution = resolution self.resolution = resolution
self.in_channels = in_channels self.in_channels = in_channels
self.give_pre_end = give_pre_end
self.tanh_out = tanh_out self.tanh_out = tanh_out
self.carried = False
if conv3d: if conv3d:
conv_op = VideoConv3d if not attn_resolutions and resnet_op == ResnetBlock:
conv_out_op = VideoConv3d conv_op = CarriedConv3d
conv_out_op = CarriedConv3d
self.carried = True
else:
conv_op = VideoConv3d
conv_out_op = VideoConv3d
mid_attn_conv_op = ops.Conv3d mid_attn_conv_op = ops.Conv3d
else: else:
conv_op = ops.Conv2d conv_op = ops.Conv2d
@ -709,29 +788,43 @@ class Decoder(nn.Module):
temb = None temb = None
# z to block_in # z to block_in
h = self.conv_in(z) h = conv_carry_causal_3d([z], self.conv_in)
# middle # middle
h = self.mid.block_1(h, temb, **kwargs) h = self.mid.block_1(h, temb, **kwargs)
h = self.mid.attn_1(h, **kwargs) h = self.mid.attn_1(h, **kwargs)
h = self.mid.block_2(h, temb, **kwargs) h = self.mid.block_2(h, temb, **kwargs)
if self.carried:
h = torch.split(h, 2, dim=2)
else:
h = [ h ]
out = []
conv_carry_in = None
# upsampling # upsampling
for i_level in reversed(range(self.num_resolutions)): for i, h1 in enumerate(h):
for i_block in range(self.num_res_blocks+1): conv_carry_out = []
h = self.up[i_level].block[i_block](h, temb, **kwargs) if i == len(h) - 1:
if len(self.up[i_level].attn) > 0: conv_carry_out = None
h = self.up[i_level].attn[i_block](h, **kwargs) for i_level in reversed(range(self.num_resolutions)):
if i_level != 0: for i_block in range(self.num_res_blocks+1):
h = self.up[i_level].upsample(h) h1 = self.up[i_level].block[i_block](h1, temb, conv_carry_in, conv_carry_out, **kwargs)
if len(self.up[i_level].attn) > 0:
assert i == 0 #carried should not happen if attn exists
h1 = self.up[i_level].attn[i_block](h1, **kwargs)
if i_level != 0:
h1 = self.up[i_level].upsample(h1, conv_carry_in, conv_carry_out)
# end h1 = self.norm_out(h1)
if self.give_pre_end: h1 = [ nonlinearity(h1) ]
return h h1 = conv_carry_causal_3d(h1, self.conv_out, conv_carry_in, conv_carry_out)
if self.tanh_out:
h1 = torch.tanh(h1)
out.append(h1)
conv_carry_in = conv_carry_out
h = self.norm_out(h) out = torch_cat_if_needed(out, dim=2)
h = nonlinearity(h)
h = self.conv_out(h, **kwargs) return out
if self.tanh_out:
h = torch.tanh(h)
return h