import torch
import torch.nn as nn
import torch.nn.functional as F

from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, CarriedConv3d, Normalize, conv_carry_causal_3d, torch_cat_if_needed
import comfy.ops
import comfy.ldm.models.autoencoder
import comfy.model_management

ops = comfy.ops.disable_weight_init


class RMS_norm(nn.Module):
    # RMS normalization over the channel dimension of (B, C, T, H, W) tensors.
    def __init__(self, dim):
        super().__init__()
        shape = (dim, 1, 1, 1)
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.empty(shape))

    def forward(self, x):
        return F.normalize(x, dim=1) * self.scale * comfy.model_management.cast_to(self.gamma, dtype=x.dtype, device=x.device)


class DnSmpl(nn.Module):
    # Downsampling block: a causal 3D conv followed by a space(-time)-to-depth fold,
    # plus a parameter-free shortcut built by folding the input and averaging channel groups.
    def __init__(self, ic, oc, tds, refiner_vae, op):
        super().__init__()
        fct = 2 * 2 * 2 if tds else 1 * 2 * 2
        assert oc % fct == 0
        self.conv = op(ic, oc // fct, kernel_size=3, stride=1, padding=1)
        self.refiner_vae = refiner_vae
        self.tds = tds
        self.gs = fct * ic // oc

    def forward(self, x, conv_carry_in=None, conv_carry_out=None):
        r1 = 2 if self.tds else 1
        h = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)

        if self.tds and self.refiner_vae and conv_carry_in is None:
            # The first frame of the chunk is folded spatially only (it has no temporal
            # partner) and is handled separately from the remaining frames.
            hf = h[:, :, :1, :, :]
            b, c, f, ht, wd = hf.shape
            hf = hf.reshape(b, c, f, ht // 2, 2, wd // 2, 2)
            hf = hf.permute(0, 4, 6, 1, 2, 3, 5)
            hf = hf.reshape(b, 2 * 2 * c, f, ht // 2, wd // 2)
            hf = torch.cat([hf, hf], dim=1)
            h = h[:, :, 1:, :, :]

            xf = x[:, :, :1, :, :]
            b, ci, f, ht, wd = xf.shape
            xf = xf.reshape(b, ci, f, ht // 2, 2, wd // 2, 2)
            xf = xf.permute(0, 4, 6, 1, 2, 3, 5)
            xf = xf.reshape(b, 2 * 2 * ci, f, ht // 2, wd // 2)
            B, C, T, H, W = xf.shape
            xf = xf.view(B, hf.shape[1], self.gs // 2, T, H, W).mean(dim=2)
            x = x[:, :, 1:, :, :]

            if h.shape[2] == 0:
                return hf + xf

        b, c, frms, ht, wd = h.shape
        nf = frms // r1
        h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
        h = h.permute(0, 3, 5, 7, 1, 2, 4, 6)
        h = h.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2)

        b, ci, frms, ht, wd = x.shape
        nf = frms // r1
        x = x.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2)
        x = x.permute(0, 3, 5, 7, 1, 2, 4, 6)
        x = x.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2)
        B, C, T, H, W = x.shape
        x = x.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2)

        if self.tds and self.refiner_vae and conv_carry_in is None:
            h = torch.cat([hf, h], dim=2)
            x = torch.cat([xf, x], dim=2)

        return h + x


class UpSmpl(nn.Module):
    # Upsampling block: a causal 3D conv followed by a depth-to-space(-time) unfold,
    # plus a parameter-free shortcut built by repeating input channels and unfolding them.
    def __init__(self, ic, oc, tus, refiner_vae, op):
        super().__init__()
        fct = 2 * 2 * 2 if tus else 1 * 2 * 2
        self.conv = op(ic, oc * fct, kernel_size=3, stride=1, padding=1)
        self.refiner_vae = refiner_vae
        self.tus = tus
        self.rp = fct * oc // ic

    def forward(self, x, conv_carry_in=None, conv_carry_out=None):
        r1 = 2 if self.tus else 1
        h = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)

        if self.tus and self.refiner_vae and conv_carry_in is None:
            # The first frame of the chunk is unfolded spatially only (it is not duplicated in time).
            hf = h[:, :, :1, :, :]
            b, c, f, ht, wd = hf.shape
            nc = c // (2 * 2)
            hf = hf.reshape(b, 2, 2, nc, f, ht, wd)
            hf = hf.permute(0, 3, 4, 5, 1, 6, 2)
            hf = hf.reshape(b, nc, f, ht * 2, wd * 2)
            hf = hf[:, : hf.shape[1] // 2]
            h = h[:, :, 1:, :, :]

            xf = x[:, :, :1, :, :]
            b, ci, f, ht, wd = xf.shape
            xf = xf.repeat_interleave(repeats=self.rp // 2, dim=1)
            b, c, f, ht, wd = xf.shape
            nc = c // (2 * 2)
            xf = xf.reshape(b, 2, 2, nc, f, ht, wd)
            xf = xf.permute(0, 3, 4, 5, 1, 6, 2)
            xf = xf.reshape(b, nc, f, ht * 2, wd * 2)
            x = x[:, :, 1:, :, :]

        b, c, frms, ht, wd = h.shape
        nc = c // (r1 * 2 * 2)
        h = h.reshape(b, r1, 2, 2, nc, frms, ht, wd)
        h = h.permute(0, 4, 5, 1, 6, 2, 7, 3)
        h = h.reshape(b, nc, frms * r1, ht * 2, wd * 2)

        x = x.repeat_interleave(repeats=self.rp, dim=1)
        b, c, frms, ht, wd = x.shape
        nc = c // (r1 * 2 * 2)
        x = x.reshape(b, r1, 2, 2, nc, frms, ht, wd)
        x = x.permute(0, 4, 5, 1, 6, 2, 7, 3)
        x = x.reshape(b, nc, frms * r1, ht * 2, wd * 2)

        if self.tus and self.refiner_vae and conv_carry_in is None:
            h = torch.cat([hf, h], dim=2)
            x = torch.cat([xf, x], dim=2)

        return h + x
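
# The fold used by DnSmpl packs each (2, 2, 2) block of (frames, rows, cols) into the
# channel dimension, and the unfold used by UpSmpl reverses it. Illustrative shapes only
# (the sizes below are made up for the sketch, not taken from a real config):
#
#   x = torch.randn(1, 8, 4, 16, 16)                                    # (B, C, T, H, W)
#   y = x.reshape(1, 8, 2, 2, 8, 2, 8, 2).permute(0, 3, 5, 7, 1, 2, 4, 6)
#   y = y.reshape(1, 64, 2, 8, 8)                                       # 8x channels, /2 in T, H, W
#
# A runnable round-trip check of the same rearrangement is sketched at the bottom of this file.
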
class Encoder(nn.Module):
    def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
                 ffactor_spatial, ffactor_temporal, downsample_match_channel=True,
                 refiner_vae=True, **_):
        super().__init__()
        self.z_channels = z_channels
        self.block_out_channels = block_out_channels
        self.num_res_blocks = num_res_blocks
        self.ffactor_temporal = ffactor_temporal
        self.refiner_vae = refiner_vae

        if self.refiner_vae:
            conv_op = CarriedConv3d
            norm_op = RMS_norm
        else:
            conv_op = ops.Conv3d
            norm_op = Normalize

        self.conv_in = conv_op(in_channels, block_out_channels[0], 3, 1, 1)

        self.down = nn.ModuleList()
        ch = block_out_channels[0]
        depth = (ffactor_spatial >> 1).bit_length()
        depth_temporal = ((ffactor_spatial // self.ffactor_temporal) >> 1).bit_length()

        for i, tgt in enumerate(block_out_channels):
            stage = nn.Module()
            stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
                                                     out_channels=tgt,
                                                     temb_channels=0,
                                                     conv_op=conv_op,
                                                     norm_op=norm_op)
                                         for j in range(num_res_blocks)])
            ch = tgt
            if i < depth:
                nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and downsample_match_channel else ch
                stage.downsample = DnSmpl(ch, nxt, tds=i >= depth_temporal, refiner_vae=self.refiner_vae, op=conv_op)
                ch = nxt
            self.down.append(stage)

        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
        self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
        self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)

        self.norm_out = norm_op(ch)
        self.conv_out = conv_op(ch, z_channels << 1, 3, 1, 1)

        self.regul = comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer()

    def forward(self, x):
        if not self.refiner_vae and x.shape[2] == 1:
            x = x.expand(-1, -1, self.ffactor_temporal, -1, -1)

        if self.refiner_vae:
            # Split the frame axis into a single leading frame plus chunks of
            # 2 * ffactor_temporal frames; trailing frames that do not complete a
            # temporal group are dropped.
            xl = [x[:, :, :1, :, :]]
            if x.shape[2] > self.ffactor_temporal:
                xl += torch.split(x[:, :, 1: 1 + ((x.shape[2] - 1) // self.ffactor_temporal) * self.ffactor_temporal, :, :], self.ffactor_temporal * 2, dim=2)
            x = xl
        else:
            x = [x]

        out = []
        conv_carry_in = None
        for i, x1 in enumerate(x):
            conv_carry_out = []
            if i == len(x) - 1:
                conv_carry_out = None

            x1 = [x1]
            x1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out)
            for stage in self.down:
                for blk in stage.block:
                    x1 = blk(x1, None, conv_carry_in, conv_carry_out)
                if hasattr(stage, 'downsample'):
                    x1 = stage.downsample(x1, conv_carry_in, conv_carry_out)
            out.append(x1)
            conv_carry_in = conv_carry_out

        out = torch_cat_if_needed(out, dim=2)

        x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(out)))
        del out

        # Parameter-free skip: average channel groups of the features down to the
        # output width of conv_out (2 * z_channels) and add them to its output.
        b, c, t, h, w = x.shape
        grp = c // (self.z_channels << 1)
        skip = x.view(b, c // grp, grp, t, h, w).mean(2)

        out = conv_carry_causal_3d([F.silu(self.norm_out(x))], self.conv_out) + skip
        if self.refiner_vae:
            out = self.regul(out)[0]
        return out
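
# In refiner mode the encoder streams the frame axis in chunks: the first frame is its own
# chunk and the remaining frames are split into chunks of 2 * ffactor_temporal frames, with
# per-layer conv carries threaded from one chunk to the next so the causal convolutions see
# the same left context they would in a single full-sequence pass. Illustrative arithmetic
# only (ffactor_temporal = 4 is an assumed example value, not taken from a real config): a
# 17-frame input splits into chunks of 1, 8 and 8 frames. The decoder below streams
# analogously, splitting the latent frame axis into chunks of 2.
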
class Decoder(nn.Module):
    def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks,
                 ffactor_spatial, ffactor_temporal, upsample_match_channel=True,
                 refiner_vae=True, **_):
        super().__init__()
        block_out_channels = block_out_channels[::-1]
        self.z_channels = z_channels
        self.block_out_channels = block_out_channels
        self.num_res_blocks = num_res_blocks
        self.refiner_vae = refiner_vae

        if self.refiner_vae:
            conv_op = CarriedConv3d
            norm_op = RMS_norm
        else:
            conv_op = ops.Conv3d
            norm_op = Normalize

        ch = block_out_channels[0]
        self.conv_in = conv_op(z_channels, ch, kernel_size=3, stride=1, padding=1)

        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
        self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
        self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)

        self.up = nn.ModuleList()
        depth = (ffactor_spatial >> 1).bit_length()
        depth_temporal = (ffactor_temporal >> 1).bit_length()

        for i, tgt in enumerate(block_out_channels):
            stage = nn.Module()
            stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
                                                     out_channels=tgt,
                                                     temb_channels=0,
                                                     conv_op=conv_op,
                                                     norm_op=norm_op)
                                         for j in range(num_res_blocks + 1)])
            ch = tgt
            if i < depth:
                nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and upsample_match_channel else ch
                stage.upsample = UpSmpl(ch, nxt, tus=i < depth_temporal, refiner_vae=self.refiner_vae, op=conv_op)
                ch = nxt
            self.up.append(stage)

        self.norm_out = norm_op(ch)
        self.conv_out = conv_op(ch, out_channels, 3, stride=1, padding=1)

    def forward(self, z):
        # Parameter-free skip: repeat the latent channels up to the width of conv_in.
        x = conv_carry_causal_3d([z], self.conv_in) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
        x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))

        if self.refiner_vae:
            # Stream the latent frame axis in chunks of 2, threading conv carries between chunks.
            x = torch.split(x, 2, dim=2)
        else:
            x = [x]

        out = []
        conv_carry_in = None
        for i, x1 in enumerate(x):
            conv_carry_out = []
            if i == len(x) - 1:
                conv_carry_out = None

            for stage in self.up:
                for blk in stage.block:
                    x1 = blk(x1, None, conv_carry_in, conv_carry_out)
                if hasattr(stage, 'upsample'):
                    x1 = stage.upsample(x1, conv_carry_in, conv_carry_out)

            x1 = [F.silu(self.norm_out(x1))]
            x1 = conv_carry_causal_3d(x1, self.conv_out, conv_carry_in, conv_carry_out)
            out.append(x1)
            conv_carry_in = conv_carry_out

        del x
        out = torch_cat_if_needed(out, dim=2)

        if not self.refiner_vae:
            if z.shape[-3] == 1:
                out = out[:, :, -1:]

        return out
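

if __name__ == "__main__":
    # Minimal sanity sketch (plain torch, not part of the model; shapes are made up): the
    # space(-time)-to-depth fold used in DnSmpl and the depth-to-space(-time) unfold used
    # in UpSmpl are exact inverses of one another, ignoring the convolutions and the
    # channel averaging / repetition on the shortcut paths.
    b, c, t, hgt, wdt = 1, 4, 6, 8, 8
    x = torch.randn(b, c, t, hgt, wdt)

    # DnSmpl-style fold: (B, C, T, H, W) -> (B, 8C, T/2, H/2, W/2)
    d = x.reshape(b, c, t // 2, 2, hgt // 2, 2, wdt // 2, 2)
    d = d.permute(0, 3, 5, 7, 1, 2, 4, 6).reshape(b, 8 * c, t // 2, hgt // 2, wdt // 2)

    # UpSmpl-style unfold: (B, 8C, T/2, H/2, W/2) -> (B, C, T, H, W)
    u = d.reshape(b, 2, 2, 2, c, t // 2, hgt // 2, wdt // 2)
    u = u.permute(0, 4, 5, 1, 6, 2, 7, 3).reshape(b, c, t, hgt, wdt)

    assert torch.equal(u, x)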