mirror of
https://git.datalinker.icu/comfyanonymous/ComfyUI
synced 2025-12-08 21:44:33 +08:00
Implement temporal rolling VAE (Major VRAM reductions in Hunyuan and Kandinsky) (#10995)
* hunyuan upsampler: rework imports Remove the transitive import of VideoConv3d and Resnet and takes these from actual implementation source. * model: remove unused give_pre_end According to git grep, this is not used now, and was not used in the initial commit that introduced it (see below). This semantic is difficult to implement temporal roll VAE for (and would defeat the purpose). Rather than implement the complex if, just delete the unused feature. (venv) rattus@rattus-box2:~/ComfyUI$ git log --oneline 220afe33 (HEAD) Initial commit. (venv) rattus@rattus-box2:~/ComfyUI$ git grep give_pre comfy/ldm/modules/diffusionmodules/model.py: resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, comfy/ldm/modules/diffusionmodules/model.py: self.give_pre_end = give_pre_end comfy/ldm/modules/diffusionmodules/model.py: if self.give_pre_end: (venv) rattus@rattus-box2:~/ComfyUI$ git co origin/master Previous HEAD position was 220afe33 Initial commit. HEAD is now at 9d8a8179 Enable async offloading by default on Nvidia. (#10953) (venv) rattus@rattus-box2:~/ComfyUI$ git grep give_pre comfy/ldm/modules/diffusionmodules/model.py: resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, comfy/ldm/modules/diffusionmodules/model.py: self.give_pre_end = give_pre_end comfy/ldm/modules/diffusionmodules/model.py: if self.give_pre_end: * move refiner VAE temporal roller to core Move the carrying conv op to the common VAE code and give it a better name. Roll the carry implementation logic for Resnet into the base class and scrap the Hunyuan specific subclass. * model: Add temporal roll to main VAE decoder If there are no attention layers, its a standard resnet and VideoConv3d is asked for, substitute in the temporal rolloing VAE algorithm. This reduces VAE usage by the temporal dimension (can be huge VRAM savings). * model: Add temporal roll to main VAE encoder If there are no attention layers, its a standard resnet and VideoConv3d is asked for, substitute in the temporal rolling VAE algorithm. This reduces VAE usage by the temporal dimension (can be huge VRAM savings).
This commit is contained in:
parent
3f512f5659
commit
73f5649196
@ -1,7 +1,8 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm, ResnetBlock, VideoConv3d
|
||||
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d
|
||||
from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm
|
||||
import model_management, model_patcher
|
||||
|
||||
class SRResidualCausalBlock3D(nn.Module):
|
||||
|
||||
@ -1,42 +1,12 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d, Normalize
|
||||
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, CarriedConv3d, Normalize, conv_carry_causal_3d, torch_cat_if_needed
|
||||
import comfy.ops
|
||||
import comfy.ldm.models.autoencoder
|
||||
import comfy.model_management
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
class NoPadConv3d(nn.Module):
|
||||
def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, **kwargs):
|
||||
super().__init__()
|
||||
self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
def conv_carry_causal_3d(xl, op, conv_carry_in=None, conv_carry_out=None):
|
||||
|
||||
x = xl[0]
|
||||
xl.clear()
|
||||
|
||||
if conv_carry_out is not None:
|
||||
to_push = x[:, :, -2:, :, :].clone()
|
||||
conv_carry_out.append(to_push)
|
||||
|
||||
if isinstance(op, NoPadConv3d):
|
||||
if conv_carry_in is None:
|
||||
x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2, 0), mode = 'replicate')
|
||||
else:
|
||||
carry_len = conv_carry_in[0].shape[2]
|
||||
x = torch.cat([conv_carry_in.pop(0), x], dim=2)
|
||||
x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2 - carry_len, 0), mode = 'replicate')
|
||||
|
||||
out = op(x)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class RMS_norm(nn.Module):
|
||||
def __init__(self, dim):
|
||||
@ -49,7 +19,7 @@ class RMS_norm(nn.Module):
|
||||
return F.normalize(x, dim=1) * self.scale * comfy.model_management.cast_to(self.gamma, dtype=x.dtype, device=x.device)
|
||||
|
||||
class DnSmpl(nn.Module):
|
||||
def __init__(self, ic, oc, tds=True, refiner_vae=True, op=VideoConv3d):
|
||||
def __init__(self, ic, oc, tds, refiner_vae, op):
|
||||
super().__init__()
|
||||
fct = 2 * 2 * 2 if tds else 1 * 2 * 2
|
||||
assert oc % fct == 0
|
||||
@ -109,7 +79,7 @@ class DnSmpl(nn.Module):
|
||||
|
||||
|
||||
class UpSmpl(nn.Module):
|
||||
def __init__(self, ic, oc, tus=True, refiner_vae=True, op=VideoConv3d):
|
||||
def __init__(self, ic, oc, tus, refiner_vae, op):
|
||||
super().__init__()
|
||||
fct = 2 * 2 * 2 if tus else 1 * 2 * 2
|
||||
self.conv = op(ic, oc * fct, kernel_size=3, stride=1, padding=1)
|
||||
@ -163,23 +133,6 @@ class UpSmpl(nn.Module):
|
||||
|
||||
return h + x
|
||||
|
||||
class HunyuanRefinerResnetBlock(ResnetBlock):
|
||||
def __init__(self, in_channels, out_channels, conv_op=NoPadConv3d, norm_op=RMS_norm):
|
||||
super().__init__(in_channels=in_channels, out_channels=out_channels, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
|
||||
|
||||
def forward(self, x, conv_carry_in=None, conv_carry_out=None):
|
||||
h = x
|
||||
h = [ self.swish(self.norm1(x)) ]
|
||||
h = conv_carry_causal_3d(h, self.conv1, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
|
||||
|
||||
h = [ self.dropout(self.swish(self.norm2(h))) ]
|
||||
h = conv_carry_causal_3d(h, self.conv2, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return x+h
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
|
||||
ffactor_spatial, ffactor_temporal, downsample_match_channel=True, refiner_vae=True, **_):
|
||||
@ -191,7 +144,7 @@ class Encoder(nn.Module):
|
||||
|
||||
self.refiner_vae = refiner_vae
|
||||
if self.refiner_vae:
|
||||
conv_op = NoPadConv3d
|
||||
conv_op = CarriedConv3d
|
||||
norm_op = RMS_norm
|
||||
else:
|
||||
conv_op = ops.Conv3d
|
||||
@ -206,9 +159,10 @@ class Encoder(nn.Module):
|
||||
|
||||
for i, tgt in enumerate(block_out_channels):
|
||||
stage = nn.Module()
|
||||
stage.block = nn.ModuleList([HunyuanRefinerResnetBlock(in_channels=ch if j == 0 else tgt,
|
||||
out_channels=tgt,
|
||||
conv_op=conv_op, norm_op=norm_op)
|
||||
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
|
||||
out_channels=tgt,
|
||||
temb_channels=0,
|
||||
conv_op=conv_op, norm_op=norm_op)
|
||||
for j in range(num_res_blocks)])
|
||||
ch = tgt
|
||||
if i < depth:
|
||||
@ -218,9 +172,9 @@ class Encoder(nn.Module):
|
||||
self.down.append(stage)
|
||||
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
|
||||
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
|
||||
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
|
||||
self.mid.block_2 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
|
||||
|
||||
self.norm_out = norm_op(ch)
|
||||
self.conv_out = conv_op(ch, z_channels << 1, 3, 1, 1)
|
||||
@ -246,22 +200,20 @@ class Encoder(nn.Module):
|
||||
conv_carry_out = []
|
||||
if i == len(x) - 1:
|
||||
conv_carry_out = None
|
||||
|
||||
x1 = [ x1 ]
|
||||
x1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out)
|
||||
|
||||
for stage in self.down:
|
||||
for blk in stage.block:
|
||||
x1 = blk(x1, conv_carry_in, conv_carry_out)
|
||||
x1 = blk(x1, None, conv_carry_in, conv_carry_out)
|
||||
if hasattr(stage, 'downsample'):
|
||||
x1 = stage.downsample(x1, conv_carry_in, conv_carry_out)
|
||||
|
||||
out.append(x1)
|
||||
conv_carry_in = conv_carry_out
|
||||
|
||||
if len(out) > 1:
|
||||
out = torch.cat(out, dim=2)
|
||||
else:
|
||||
out = out[0]
|
||||
out = torch_cat_if_needed(out, dim=2)
|
||||
|
||||
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(out)))
|
||||
del out
|
||||
@ -288,7 +240,7 @@ class Decoder(nn.Module):
|
||||
|
||||
self.refiner_vae = refiner_vae
|
||||
if self.refiner_vae:
|
||||
conv_op = NoPadConv3d
|
||||
conv_op = CarriedConv3d
|
||||
norm_op = RMS_norm
|
||||
else:
|
||||
conv_op = ops.Conv3d
|
||||
@ -298,9 +250,9 @@ class Decoder(nn.Module):
|
||||
self.conv_in = conv_op(z_channels, ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
|
||||
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
|
||||
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
|
||||
self.mid.block_2 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
|
||||
|
||||
self.up = nn.ModuleList()
|
||||
depth = (ffactor_spatial >> 1).bit_length()
|
||||
@ -308,9 +260,10 @@ class Decoder(nn.Module):
|
||||
|
||||
for i, tgt in enumerate(block_out_channels):
|
||||
stage = nn.Module()
|
||||
stage.block = nn.ModuleList([HunyuanRefinerResnetBlock(in_channels=ch if j == 0 else tgt,
|
||||
out_channels=tgt,
|
||||
conv_op=conv_op, norm_op=norm_op)
|
||||
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
|
||||
out_channels=tgt,
|
||||
temb_channels=0,
|
||||
conv_op=conv_op, norm_op=norm_op)
|
||||
for j in range(num_res_blocks + 1)])
|
||||
ch = tgt
|
||||
if i < depth:
|
||||
@ -340,7 +293,7 @@ class Decoder(nn.Module):
|
||||
conv_carry_out = None
|
||||
for stage in self.up:
|
||||
for blk in stage.block:
|
||||
x1 = blk(x1, conv_carry_in, conv_carry_out)
|
||||
x1 = blk(x1, None, conv_carry_in, conv_carry_out)
|
||||
if hasattr(stage, 'upsample'):
|
||||
x1 = stage.upsample(x1, conv_carry_in, conv_carry_out)
|
||||
|
||||
@ -350,10 +303,7 @@ class Decoder(nn.Module):
|
||||
conv_carry_in = conv_carry_out
|
||||
del x
|
||||
|
||||
if len(out) > 1:
|
||||
out = torch.cat(out, dim=2)
|
||||
else:
|
||||
out = out[0]
|
||||
out = torch_cat_if_needed(out, dim=2)
|
||||
|
||||
if not self.refiner_vae:
|
||||
if z.shape[-3] == 1:
|
||||
|
||||
@ -13,6 +13,12 @@ if model_management.xformers_enabled_vae():
|
||||
import xformers
|
||||
import xformers.ops
|
||||
|
||||
def torch_cat_if_needed(xl, dim):
|
||||
if len(xl) > 1:
|
||||
return torch.cat(xl, dim)
|
||||
else:
|
||||
return xl[0]
|
||||
|
||||
def get_timestep_embedding(timesteps, embedding_dim):
|
||||
"""
|
||||
This matches the implementation in Denoising Diffusion Probabilistic Models:
|
||||
@ -43,6 +49,37 @@ def Normalize(in_channels, num_groups=32):
|
||||
return ops.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
|
||||
class CarriedConv3d(nn.Module):
|
||||
def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, **kwargs):
|
||||
super().__init__()
|
||||
self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
def conv_carry_causal_3d(xl, op, conv_carry_in=None, conv_carry_out=None):
|
||||
|
||||
x = xl[0]
|
||||
xl.clear()
|
||||
|
||||
if isinstance(op, CarriedConv3d):
|
||||
if conv_carry_in is None:
|
||||
x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2, 0), mode = 'replicate')
|
||||
else:
|
||||
carry_len = conv_carry_in[0].shape[2]
|
||||
x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2 - carry_len, 0), mode = 'replicate')
|
||||
x = torch.cat([conv_carry_in.pop(0), x], dim=2)
|
||||
|
||||
if conv_carry_out is not None:
|
||||
to_push = x[:, :, -2:, :, :].clone()
|
||||
conv_carry_out.append(to_push)
|
||||
|
||||
out = op(x)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class VideoConv3d(nn.Module):
|
||||
def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding_mode='replicate', padding=1, **kwargs):
|
||||
super().__init__()
|
||||
@ -89,29 +126,24 @@ class Upsample(nn.Module):
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, conv_carry_in=None, conv_carry_out=None):
|
||||
scale_factor = self.scale_factor
|
||||
if isinstance(scale_factor, (int, float)):
|
||||
scale_factor = (scale_factor,) * (x.ndim - 2)
|
||||
|
||||
if x.ndim == 5 and scale_factor[0] > 1.0:
|
||||
t = x.shape[2]
|
||||
if t > 1:
|
||||
a, b = x.split((1, t - 1), dim=2)
|
||||
del x
|
||||
b = interpolate_up(b, scale_factor)
|
||||
else:
|
||||
a = x
|
||||
|
||||
a = interpolate_up(a.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2)
|
||||
if t > 1:
|
||||
x = torch.cat((a, b), dim=2)
|
||||
else:
|
||||
x = a
|
||||
results = []
|
||||
if conv_carry_in is None:
|
||||
first = x[:, :, :1, :, :]
|
||||
results.append(interpolate_up(first.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2))
|
||||
x = x[:, :, 1:, :, :]
|
||||
if x.shape[2] > 0:
|
||||
results.append(interpolate_up(x, scale_factor))
|
||||
x = torch_cat_if_needed(results, dim=2)
|
||||
else:
|
||||
x = interpolate_up(x, scale_factor)
|
||||
if self.with_conv:
|
||||
x = self.conv(x)
|
||||
x = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
|
||||
return x
|
||||
|
||||
|
||||
@ -127,17 +159,20 @@ class Downsample(nn.Module):
|
||||
stride=stride,
|
||||
padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, conv_carry_in=None, conv_carry_out=None):
|
||||
if self.with_conv:
|
||||
if x.ndim == 4:
|
||||
if isinstance(self.conv, CarriedConv3d):
|
||||
x = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
|
||||
elif x.ndim == 4:
|
||||
pad = (0, 1, 0, 1)
|
||||
mode = "constant"
|
||||
x = torch.nn.functional.pad(x, pad, mode=mode, value=0)
|
||||
x = self.conv(x)
|
||||
elif x.ndim == 5:
|
||||
pad = (1, 1, 1, 1, 2, 0)
|
||||
mode = "replicate"
|
||||
x = torch.nn.functional.pad(x, pad, mode=mode)
|
||||
x = self.conv(x)
|
||||
x = self.conv(x)
|
||||
else:
|
||||
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
||||
return x
|
||||
@ -183,23 +218,23 @@ class ResnetBlock(nn.Module):
|
||||
stride=1,
|
||||
padding=0)
|
||||
|
||||
def forward(self, x, temb=None):
|
||||
def forward(self, x, temb=None, conv_carry_in=None, conv_carry_out=None):
|
||||
h = x
|
||||
h = self.norm1(h)
|
||||
h = self.swish(h)
|
||||
h = self.conv1(h)
|
||||
h = [ self.swish(h) ]
|
||||
h = conv_carry_causal_3d(h, self.conv1, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
|
||||
|
||||
if temb is not None:
|
||||
h = h + self.temb_proj(self.swish(temb))[:,:,None,None]
|
||||
|
||||
h = self.norm2(h)
|
||||
h = self.swish(h)
|
||||
h = self.dropout(h)
|
||||
h = self.conv2(h)
|
||||
h = [ self.dropout(h) ]
|
||||
h = conv_carry_causal_3d(h, self.conv2, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
x = self.conv_shortcut(x)
|
||||
x = conv_carry_causal_3d([x], self.conv_shortcut, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
|
||||
else:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
@ -520,9 +555,14 @@ class Encoder(nn.Module):
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.carried = False
|
||||
|
||||
if conv3d:
|
||||
conv_op = VideoConv3d
|
||||
if not attn_resolutions:
|
||||
conv_op = CarriedConv3d
|
||||
self.carried = True
|
||||
else:
|
||||
conv_op = VideoConv3d
|
||||
mid_attn_conv_op = ops.Conv3d
|
||||
else:
|
||||
conv_op = ops.Conv2d
|
||||
@ -535,6 +575,7 @@ class Encoder(nn.Module):
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
self.time_compress = 1
|
||||
curr_res = resolution
|
||||
in_ch_mult = (1,)+tuple(ch_mult)
|
||||
self.in_ch_mult = in_ch_mult
|
||||
@ -561,10 +602,15 @@ class Encoder(nn.Module):
|
||||
if time_compress is not None:
|
||||
if (self.num_resolutions - 1 - i_level) > math.log2(time_compress):
|
||||
stride = (1, 2, 2)
|
||||
else:
|
||||
self.time_compress *= 2
|
||||
down.downsample = Downsample(block_in, resamp_with_conv, stride=stride, conv_op=conv_op)
|
||||
curr_res = curr_res // 2
|
||||
self.down.append(down)
|
||||
|
||||
if time_compress is not None:
|
||||
self.time_compress = time_compress
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
||||
@ -590,15 +636,42 @@ class Encoder(nn.Module):
|
||||
def forward(self, x):
|
||||
# timestep embedding
|
||||
temb = None
|
||||
# downsampling
|
||||
h = self.conv_in(x)
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h = self.down[i_level].block[i_block](h, temb)
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
h = self.down[i_level].attn[i_block](h)
|
||||
if i_level != self.num_resolutions-1:
|
||||
h = self.down[i_level].downsample(h)
|
||||
|
||||
if self.carried:
|
||||
xl = [x[:, :, :1, :, :]]
|
||||
if x.shape[2] > self.time_compress:
|
||||
tc = self.time_compress
|
||||
xl += torch.split(x[:, :, 1: 1 + ((x.shape[2] - 1) // tc) * tc, :, :], tc * 2, dim = 2)
|
||||
x = xl
|
||||
else:
|
||||
x = [x]
|
||||
out = []
|
||||
|
||||
conv_carry_in = None
|
||||
|
||||
for i, x1 in enumerate(x):
|
||||
conv_carry_out = []
|
||||
if i == len(x) - 1:
|
||||
conv_carry_out = None
|
||||
|
||||
# downsampling
|
||||
x1 = [ x1 ]
|
||||
h1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out)
|
||||
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h1 = self.down[i_level].block[i_block](h1, temb, conv_carry_in, conv_carry_out)
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
assert i == 0 #carried should not happen if attn exists
|
||||
h1 = self.down[i_level].attn[i_block](h1)
|
||||
if i_level != self.num_resolutions-1:
|
||||
h1 = self.down[i_level].downsample(h1, conv_carry_in, conv_carry_out)
|
||||
|
||||
out.append(h1)
|
||||
conv_carry_in = conv_carry_out
|
||||
|
||||
h = torch_cat_if_needed(out, dim=2)
|
||||
del out
|
||||
|
||||
# middle
|
||||
h = self.mid.block_1(h, temb)
|
||||
@ -607,15 +680,15 @@ class Encoder(nn.Module):
|
||||
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
h = [ nonlinearity(h) ]
|
||||
h = conv_carry_causal_3d(h, self.conv_out)
|
||||
return h
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
||||
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
||||
resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
|
||||
resolution, z_channels, tanh_out=False, use_linear_attn=False,
|
||||
conv_out_op=ops.Conv2d,
|
||||
resnet_op=ResnetBlock,
|
||||
attn_op=AttnBlock,
|
||||
@ -629,12 +702,18 @@ class Decoder(nn.Module):
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.give_pre_end = give_pre_end
|
||||
self.tanh_out = tanh_out
|
||||
self.carried = False
|
||||
|
||||
if conv3d:
|
||||
conv_op = VideoConv3d
|
||||
conv_out_op = VideoConv3d
|
||||
if not attn_resolutions and resnet_op == ResnetBlock:
|
||||
conv_op = CarriedConv3d
|
||||
conv_out_op = CarriedConv3d
|
||||
self.carried = True
|
||||
else:
|
||||
conv_op = VideoConv3d
|
||||
conv_out_op = VideoConv3d
|
||||
|
||||
mid_attn_conv_op = ops.Conv3d
|
||||
else:
|
||||
conv_op = ops.Conv2d
|
||||
@ -709,29 +788,43 @@ class Decoder(nn.Module):
|
||||
temb = None
|
||||
|
||||
# z to block_in
|
||||
h = self.conv_in(z)
|
||||
h = conv_carry_causal_3d([z], self.conv_in)
|
||||
|
||||
# middle
|
||||
h = self.mid.block_1(h, temb, **kwargs)
|
||||
h = self.mid.attn_1(h, **kwargs)
|
||||
h = self.mid.block_2(h, temb, **kwargs)
|
||||
|
||||
if self.carried:
|
||||
h = torch.split(h, 2, dim=2)
|
||||
else:
|
||||
h = [ h ]
|
||||
out = []
|
||||
|
||||
conv_carry_in = None
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks+1):
|
||||
h = self.up[i_level].block[i_block](h, temb, **kwargs)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h, **kwargs)
|
||||
if i_level != 0:
|
||||
h = self.up[i_level].upsample(h)
|
||||
for i, h1 in enumerate(h):
|
||||
conv_carry_out = []
|
||||
if i == len(h) - 1:
|
||||
conv_carry_out = None
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks+1):
|
||||
h1 = self.up[i_level].block[i_block](h1, temb, conv_carry_in, conv_carry_out, **kwargs)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
assert i == 0 #carried should not happen if attn exists
|
||||
h1 = self.up[i_level].attn[i_block](h1, **kwargs)
|
||||
if i_level != 0:
|
||||
h1 = self.up[i_level].upsample(h1, conv_carry_in, conv_carry_out)
|
||||
|
||||
# end
|
||||
if self.give_pre_end:
|
||||
return h
|
||||
h1 = self.norm_out(h1)
|
||||
h1 = [ nonlinearity(h1) ]
|
||||
h1 = conv_carry_causal_3d(h1, self.conv_out, conv_carry_in, conv_carry_out)
|
||||
if self.tanh_out:
|
||||
h1 = torch.tanh(h1)
|
||||
out.append(h1)
|
||||
conv_carry_in = conv_carry_out
|
||||
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h, **kwargs)
|
||||
if self.tanh_out:
|
||||
h = torch.tanh(h)
|
||||
return h
|
||||
out = torch_cat_if_needed(out, dim=2)
|
||||
|
||||
return out
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user