mirror of
https://git.datalinker.icu/comfyanonymous/ComfyUI
synced 2025-12-09 05:54:24 +08:00
* hunyuan upsampler: rework imports Remove the transitive import of VideoConv3d and Resnet and take these from actual implementation source. * model: remove unused give_pre_end According to git grep, this is not used now, and was not used in the initial commit that introduced it (see below). This semantic is difficult to implement temporal roll VAE for (and would defeat the purpose). Rather than implement the complex if, just delete the unused feature. (venv) rattus@rattus-box2:~/ComfyUI$ git log --oneline 220afe33 (HEAD) Initial commit. (venv) rattus@rattus-box2:~/ComfyUI$ git grep give_pre comfy/ldm/modules/diffusionmodules/model.py: resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, comfy/ldm/modules/diffusionmodules/model.py: self.give_pre_end = give_pre_end comfy/ldm/modules/diffusionmodules/model.py: if self.give_pre_end: (venv) rattus@rattus-box2:~/ComfyUI$ git co origin/master Previous HEAD position was 220afe33 Initial commit. HEAD is now at 9d8a8179 Enable async offloading by default on Nvidia. (#10953) (venv) rattus@rattus-box2:~/ComfyUI$ git grep give_pre comfy/ldm/modules/diffusionmodules/model.py: resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, comfy/ldm/modules/diffusionmodules/model.py: self.give_pre_end = give_pre_end comfy/ldm/modules/diffusionmodules/model.py: if self.give_pre_end: * move refiner VAE temporal roller to core Move the carrying conv op to the common VAE code and give it a better name. Roll the carry implementation logic for Resnet into the base class and scrap the Hunyuan specific subclass. * model: Add temporal roll to main VAE decoder If there are no attention layers, it's a standard resnet and VideoConv3d is asked for, substitute in the temporal rolling VAE algorithm. This reduces VAE usage by the temporal dimension (can be huge VRAM savings).
* model: Add temporal roll to main VAE encoder If there are no attention layers, it's a standard resnet and VideoConv3d is asked for, substitute in the temporal rolling VAE algorithm. This reduces VAE usage by the temporal dimension (can be huge VRAM savings).
314 lines
12 KiB
Python
314 lines
12 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, CarriedConv3d, Normalize, conv_carry_causal_3d, torch_cat_if_needed
|
|
import comfy.ops
|
|
import comfy.ldm.models.autoencoder
|
|
import comfy.model_management
|
|
ops = comfy.ops.disable_weight_init
|
|
|
|
|
|
class RMS_norm(nn.Module):
    """RMS normalization over the channel dimension with a learned gain.

    Each location's channel vector is L2-normalized, rescaled by sqrt(dim)
    (turning the unit-norm result into an RMS normalization), and multiplied
    by a learned per-channel ``gamma`` broadcast over the remaining axes.
    """

    def __init__(self, dim):
        super().__init__()
        # sqrt(dim) factor converts unit L2 norm into RMS normalization.
        self.scale = dim**0.5
        # Per-channel gain, shaped to broadcast over (t, h, w); values are
        # expected to be loaded from a checkpoint (left uninitialized here).
        self.gamma = nn.Parameter(torch.empty((dim, 1, 1, 1)))

    def forward(self, x):
        gain = comfy.model_management.cast_to(self.gamma, dtype=x.dtype, device=x.device)
        normalized = F.normalize(x, dim=1)
        return normalized * self.scale * gain
class DnSmpl(nn.Module):
    """Spatial (and optionally temporal) pixel-unshuffle downsampler with a
    channel-averaged residual skip.

    Both the conv output and the raw input are space-to-channel packed by
    2x2 (and 2x in time when ``tds``); the input's packed channels are then
    group-averaged down to the conv output's channel count and added as a
    residual skip.

    In refiner mode, the first frame of the first chunk (no carry present)
    is downsampled spatially only, so the causal first frame is never merged
    with a later frame.
    """

    def __init__(self, ic, oc, tds, refiner_vae, op):
        # ic: input channels; oc: output channels after downsampling.
        # tds: also downsample the temporal dimension by 2.
        # op: conv constructor (CarriedConv3d in refiner mode, ops.Conv3d otherwise).
        super().__init__()
        # Total pixel-unshuffle factor: 2 (t) * 2 (h) * 2 (w) when temporal,
        # otherwise only the 2x2 spatial factor.
        fct = 2 * 2 * 2 if tds else 1 * 2 * 2
        assert oc % fct == 0
        # Conv emits oc // fct channels; the shuffle multiplies them back to oc.
        self.conv = op(ic, oc // fct, kernel_size=3, stride=1, padding=1)
        self.refiner_vae = refiner_vae

        self.tds = tds
        # Group size used to average the packed input channels down to oc
        # for the skip path.
        self.gs = fct * ic // oc

    def forward(self, x, conv_carry_in=None, conv_carry_out=None):
        # Temporal packing factor: 2 when downsampling time, else 1.
        r1 = 2 if self.tds else 1
        h = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)

        # Refiner mode, first chunk only (conv_carry_in is None): split off
        # frame 0 and downsample it spatially without a temporal partner.
        if self.tds and self.refiner_vae and conv_carry_in is None:

            hf = h[:, :, :1, :, :]
            b, c, f, ht, wd = hf.shape
            # 2x2 space-to-channel pack of the conv output's first frame.
            hf = hf.reshape(b, c, f, ht // 2, 2, wd // 2, 2)
            hf = hf.permute(0, 4, 6, 1, 2, 3, 5)
            hf = hf.reshape(b, 2 * 2 * c, f, ht // 2, wd // 2)
            # Duplicate channels so the first frame matches the channel count
            # of temporally packed frames (which carry a 2x temporal factor).
            hf = torch.cat([hf, hf], dim=1)

            h = h[:, :, 1:, :, :]

            # Same spatial pack on the raw input's first frame for the skip.
            xf = x[:, :, :1, :, :]
            b, ci, f, ht, wd = xf.shape
            xf = xf.reshape(b, ci, f, ht // 2, 2, wd // 2, 2)
            xf = xf.permute(0, 4, 6, 1, 2, 3, 5)
            xf = xf.reshape(b, 2 * 2 * ci, f, ht // 2, wd // 2)
            B, C, T, H, W = xf.shape
            # Average channel groups down to hf's width; gs // 2 because only
            # the spatial factor applies to the first frame.
            xf = xf.view(B, hf.shape[1], self.gs // 2, T, H, W).mean(dim=2)

            x = x[:, :, 1:, :, :]

            # Chunk contained only the first frame: nothing left to pack.
            if h.shape[2] == 0:
                return hf + xf

        # Space(-time)-to-channel pack of the conv output.
        b, c, frms, ht, wd = h.shape
        nf = frms // r1
        h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
        h = h.permute(0, 3, 5, 7, 1, 2, 4, 6)
        h = h.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2)

        # Identical pack on the raw input, then group-average the channels
        # down to h's width to form the residual skip.
        b, ci, frms, ht, wd = x.shape
        nf = frms // r1
        x = x.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2)
        x = x.permute(0, 3, 5, 7, 1, 2, 4, 6)
        x = x.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2)
        B, C, T, H, W = x.shape
        x = x.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2)

        # Re-attach the separately processed first frame.
        if self.tds and self.refiner_vae and conv_carry_in is None:
            h = torch.cat([hf, h], dim=2)
            x = torch.cat([xf, x], dim=2)

        return h + x
class UpSmpl(nn.Module):
    """Spatial (and optionally temporal) pixel-shuffle upsampler with a
    channel-replicated residual skip.

    The conv expands channels by the shuffle factor, a depth-to-space (and
    depth-to-time when ``tus``) rearrangement produces the upsampled output,
    and the raw input — channel-repeated to matching width and shuffled the
    same way — is added as a residual skip.

    In refiner mode, the first frame of the first chunk (no carry present)
    is upsampled spatially only, so the causal first frame is not duplicated
    in time.
    """

    def __init__(self, ic, oc, tus, refiner_vae, op):
        # ic: input channels; oc: output channels after upsampling.
        # tus: also upsample the temporal dimension by 2.
        # op: conv constructor (CarriedConv3d in refiner mode, ops.Conv3d otherwise).
        super().__init__()
        # Total pixel-shuffle factor: 2 (t) * 2 (h) * 2 (w) when temporal,
        # otherwise only the 2x2 spatial factor.
        fct = 2 * 2 * 2 if tus else 1 * 2 * 2
        # Conv emits oc * fct channels; the shuffle divides them back to oc.
        self.conv = op(ic, oc * fct, kernel_size=3, stride=1, padding=1)
        self.refiner_vae = refiner_vae

        self.tus = tus
        # Channel repeat factor mapping the input to the conv output's width
        # for the skip path.
        self.rp = fct * oc // ic

    def forward(self, x, conv_carry_in=None, conv_carry_out=None):
        # Temporal unpacking factor: 2 when upsampling time, else 1.
        r1 = 2 if self.tus else 1
        h = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)

        # Refiner mode, first chunk only (conv_carry_in is None): depth-to-space
        # frame 0 without creating a second temporal copy of it.
        if self.tus and self.refiner_vae and conv_carry_in is None:

            hf = h[:, :, :1, :, :]
            b, c, f, ht, wd = hf.shape
            nc = c // (2 * 2)
            # 2x2 depth-to-space unshuffle of the conv output's first frame.
            hf = hf.reshape(b, 2, 2, nc, f, ht, wd)
            hf = hf.permute(0, 3, 4, 5, 1, 6, 2)
            hf = hf.reshape(b, nc, f, ht * 2, wd * 2)
            # Drop the channels that would have formed the duplicate time step.
            hf = hf[:, : hf.shape[1] // 2]

            h = h[:, :, 1:, :, :]

            # Matching skip for the first frame: channel-repeat the raw input
            # (spatial factor only: rp // 2), then the same shuffle.
            xf = x[:, :, :1, :, :]
            b, ci, f, ht, wd = xf.shape
            xf = xf.repeat_interleave(repeats=self.rp // 2, dim=1)
            b, c, f, ht, wd = xf.shape
            nc = c // (2 * 2)
            xf = xf.reshape(b, 2, 2, nc, f, ht, wd)
            xf = xf.permute(0, 3, 4, 5, 1, 6, 2)
            xf = xf.reshape(b, nc, f, ht * 2, wd * 2)

            x = x[:, :, 1:, :, :]

        # Depth-to-space(-time) shuffle of the conv output.
        b, c, frms, ht, wd = h.shape
        nc = c // (r1 * 2 * 2)
        h = h.reshape(b, r1, 2, 2, nc, frms, ht, wd)
        h = h.permute(0, 4, 5, 1, 6, 2, 7, 3)
        h = h.reshape(b, nc, frms * r1, ht * 2, wd * 2)

        # Skip path: channel-repeat the raw input, then shuffle identically.
        x = x.repeat_interleave(repeats=self.rp, dim=1)
        b, c, frms, ht, wd = x.shape
        nc = c // (r1 * 2 * 2)
        x = x.reshape(b, r1, 2, 2, nc, frms, ht, wd)
        x = x.permute(0, 4, 5, 1, 6, 2, 7, 3)
        x = x.reshape(b, nc, frms * r1, ht * 2, wd * 2)

        # Re-attach the separately processed first frame.
        if self.tus and self.refiner_vae and conv_carry_in is None:
            h = torch.cat([hf, h], dim=2)
            x = torch.cat([xf, x], dim=2)

        return h + x
class Encoder(nn.Module):
    """Video VAE encoder (refiner and standard variants).

    Downsamples the input spatially by ``ffactor_spatial`` and temporally by
    ``ffactor_temporal`` through resnet stages, runs mid resnet/attention
    blocks, and projects to 2 * z_channels (mean/logvar) with a
    channel-averaged residual skip.

    In refiner mode the input is processed in temporal chunks with causal
    conv carries (CarriedConv3d), so peak memory does not scale with the
    full temporal extent, and the output is sampled through a diagonal
    Gaussian regularizer.
    """

    def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
                 ffactor_spatial, ffactor_temporal, downsample_match_channel=True, refiner_vae=True, **_):
        super().__init__()
        self.z_channels = z_channels
        self.block_out_channels = block_out_channels
        self.num_res_blocks = num_res_blocks
        self.ffactor_temporal = ffactor_temporal

        # Refiner mode swaps in temporally-carried causal convs and RMS norm.
        self.refiner_vae = refiner_vae
        if self.refiner_vae:
            conv_op = CarriedConv3d
            norm_op = RMS_norm
        else:
            conv_op = ops.Conv3d
            norm_op = Normalize

        self.conv_in = conv_op(in_channels, block_out_channels[0], 3, 1, 1)

        self.down = nn.ModuleList()
        ch = block_out_channels[0]
        # Number of 2x spatial downsamples; the first depth_temporal stages
        # downsample spatially only, the remainder also downsample time.
        depth = (ffactor_spatial >> 1).bit_length()
        depth_temporal = ((ffactor_spatial // self.ffactor_temporal) >> 1).bit_length()

        for i, tgt in enumerate(block_out_channels):
            stage = nn.Module()
            stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
                                                     out_channels=tgt,
                                                     temb_channels=0,
                                                     conv_op=conv_op, norm_op=norm_op)
                                         for j in range(num_res_blocks)])
            ch = tgt
            if i < depth:
                # Downsample between stages; optionally step the channel count
                # toward the next stage's width.
                nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and downsample_match_channel else ch
                stage.downsample = DnSmpl(ch, nxt, tds=i >= depth_temporal, refiner_vae=self.refiner_vae, op=conv_op)
                ch = nxt
            self.down.append(stage)

        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
        # Attention convs are plain (non-carried) Conv3d even in refiner mode.
        self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
        self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)

        self.norm_out = norm_op(ch)
        # Output holds mean and logvar: 2 * z_channels.
        self.conv_out = conv_op(ch, z_channels << 1, 3, 1, 1)

        self.regul = comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer()

    def forward(self, x):
        # Standard (non-refiner) single-image path: repeat the lone frame to a
        # full temporal window so temporal downsampling divides evenly.
        if not self.refiner_vae and x.shape[2] == 1:
            x = x.expand(-1, -1, self.ffactor_temporal, -1, -1)

        if self.refiner_vae:
            # Chunk along time: the first frame alone (causal start), then
            # chunks of ffactor_temporal * 2 frames. Trailing frames that do
            # not fill a whole temporal factor are dropped.
            xl = [x[:, :, :1, :, :]]
            if x.shape[2] > self.ffactor_temporal:
                xl += torch.split(x[:, :, 1: 1 + ((x.shape[2] - 1) // self.ffactor_temporal) * self.ffactor_temporal, :, :], self.ffactor_temporal * 2, dim=2)
            x = xl
        else:
            x = [x]
        out = []

        # Causal conv state carried from one temporal chunk into the next.
        conv_carry_in = None

        for i, x1 in enumerate(x):
            conv_carry_out = []
            if i == len(x) - 1:
                # Last chunk: no later chunk will consume the carry.
                conv_carry_out = None

            x1 = [ x1 ]
            x1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out)

            for stage in self.down:
                for blk in stage.block:
                    x1 = blk(x1, None, conv_carry_in, conv_carry_out)
                if hasattr(stage, 'downsample'):
                    x1 = stage.downsample(x1, conv_carry_in, conv_carry_out)

            out.append(x1)
            conv_carry_in = conv_carry_out

        # Reassemble the downsampled chunks along time (no-op for one chunk).
        out = torch_cat_if_needed(out, dim=2)

        # Mid blocks run on the full downsampled sequence (attention is not
        # chunked).
        x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(out)))
        del out

        # Residual skip: group-average channels down to the output width.
        b, c, t, h, w = x.shape
        grp = c // (self.z_channels << 1)
        skip = x.view(b, c // grp, grp, t, h, w).mean(2)

        out = conv_carry_causal_3d([F.silu(self.norm_out(x))], self.conv_out) + skip

        if self.refiner_vae:
            # Regularizer returns a (latent, aux) pair; keep the latent.
            out = self.regul(out)[0]

        return out
class Decoder(nn.Module):
    """Video VAE decoder (refiner and standard variants).

    Mirrors Encoder: projects the latent up to the widest block with a
    channel-repeated skip, runs mid resnet/attention blocks, then upsamples
    spatially by ``ffactor_spatial`` and temporally by ``ffactor_temporal``
    through resnet stages.

    In refiner mode the post-mid stages run on temporal chunks with causal
    conv carries (CarriedConv3d) to bound peak memory.
    """

    def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks,
                 ffactor_spatial, ffactor_temporal, upsample_match_channel=True, refiner_vae=True, **_):
        super().__init__()
        # The decoder walks the channel list widest-first.
        block_out_channels = block_out_channels[::-1]
        self.z_channels = z_channels
        self.block_out_channels = block_out_channels
        self.num_res_blocks = num_res_blocks

        # Refiner mode swaps in temporally-carried causal convs and RMS norm.
        self.refiner_vae = refiner_vae
        if self.refiner_vae:
            conv_op = CarriedConv3d
            norm_op = RMS_norm
        else:
            conv_op = ops.Conv3d
            norm_op = Normalize

        ch = block_out_channels[0]
        self.conv_in = conv_op(z_channels, ch, kernel_size=3, stride=1, padding=1)

        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
        # Attention convs are plain (non-carried) Conv3d even in refiner mode.
        self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
        self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)

        self.up = nn.ModuleList()
        # Number of 2x spatial upsamples; the first depth_temporal of them
        # also upsample time.
        depth = (ffactor_spatial >> 1).bit_length()
        depth_temporal = (ffactor_temporal >> 1).bit_length()

        for i, tgt in enumerate(block_out_channels):
            stage = nn.Module()
            # One more resnet block per stage than the encoder's stages use.
            stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
                                                     out_channels=tgt,
                                                     temb_channels=0,
                                                     conv_op=conv_op, norm_op=norm_op)
                                         for j in range(num_res_blocks + 1)])
            ch = tgt
            if i < depth:
                # Upsample between stages; optionally step the channel count
                # toward the next stage's width.
                nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and upsample_match_channel else ch
                stage.upsample = UpSmpl(ch, nxt, tus=i < depth_temporal, refiner_vae=self.refiner_vae, op=conv_op)
                ch = nxt
            self.up.append(stage)

        self.norm_out = norm_op(ch)
        self.conv_out = conv_op(ch, out_channels, 3, stride=1, padding=1)

    def forward(self, z):
        # conv_in plus a channel-repeated skip from the latent itself.
        x = conv_carry_causal_3d([z], self.conv_in) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
        # Mid blocks run on the whole latent (attention is not chunked).
        x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))

        if self.refiner_vae:
            # Run the upsampling stages two latent frames at a time, carrying
            # causal conv state between chunks to bound peak memory.
            x = torch.split(x, 2, dim=2)
        else:
            x = [ x ]
        out = []

        # Causal conv state carried from one temporal chunk into the next.
        conv_carry_in = None

        for i, x1 in enumerate(x):
            conv_carry_out = []
            if i == len(x) - 1:
                # Last chunk: no later chunk will consume the carry.
                conv_carry_out = None
            for stage in self.up:
                for blk in stage.block:
                    x1 = blk(x1, None, conv_carry_in, conv_carry_out)
                if hasattr(stage, 'upsample'):
                    x1 = stage.upsample(x1, conv_carry_in, conv_carry_out)

            x1 = [ F.silu(self.norm_out(x1)) ]
            x1 = conv_carry_causal_3d(x1, self.conv_out, conv_carry_in, conv_carry_out)
            out.append(x1)
            conv_carry_in = conv_carry_out
        del x

        # Reassemble the decoded chunks along time (no-op for one chunk).
        out = torch_cat_if_needed(out, dim=2)

        # Standard single-image path: the frame was expanded on encode, so
        # keep only the final decoded frame.
        if not self.refiner_vae:
            if z.shape[-3] == 1:
                out = out[:, :, -1:]

        return out