diff --git a/comfy/ldm/hunyuan_video/upsampler.py b/comfy/ldm/hunyuan_video/upsampler.py index 9f5e91a59..85f515f67 100644 --- a/comfy/ldm/hunyuan_video/upsampler.py +++ b/comfy/ldm/hunyuan_video/upsampler.py @@ -1,7 +1,8 @@ import torch import torch.nn as nn import torch.nn.functional as F -from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm, ResnetBlock, VideoConv3d +from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d +from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm import model_management, model_patcher class SRResidualCausalBlock3D(nn.Module): diff --git a/comfy/ldm/hunyuan_video/vae_refiner.py b/comfy/ldm/hunyuan_video/vae_refiner.py index 9f750dcc4..ddf77cd0e 100644 --- a/comfy/ldm/hunyuan_video/vae_refiner.py +++ b/comfy/ldm/hunyuan_video/vae_refiner.py @@ -1,42 +1,12 @@ import torch import torch.nn as nn import torch.nn.functional as F -from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d, Normalize +from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, CarriedConv3d, Normalize, conv_carry_causal_3d, torch_cat_if_needed import comfy.ops import comfy.ldm.models.autoencoder import comfy.model_management ops = comfy.ops.disable_weight_init -class NoPadConv3d(nn.Module): - def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, **kwargs): - super().__init__() - self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs) - - def forward(self, x): - return self.conv(x) - - -def conv_carry_causal_3d(xl, op, conv_carry_in=None, conv_carry_out=None): - - x = xl[0] - xl.clear() - - if conv_carry_out is not None: - to_push = x[:, :, -2:, :, :].clone() - conv_carry_out.append(to_push) - - if isinstance(op, NoPadConv3d): - if conv_carry_in is None: - x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2, 0), mode = 'replicate') - else: - carry_len = conv_carry_in[0].shape[2] - x = torch.cat([conv_carry_in.pop(0), x], dim=2) - x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2 - carry_len, 0), mode = 'replicate') - - out = op(x) - - return out - class RMS_norm(nn.Module): def __init__(self, dim): @@ -49,7 +19,7 @@ class RMS_norm(nn.Module): return F.normalize(x, dim=1) * self.scale * comfy.model_management.cast_to(self.gamma, dtype=x.dtype, device=x.device) class DnSmpl(nn.Module): - def __init__(self, ic, oc, tds=True, refiner_vae=True, op=VideoConv3d): + def __init__(self, ic, oc, tds, refiner_vae, op): super().__init__() fct = 2 * 2 * 2 if tds else 1 * 2 * 2 assert oc % fct == 0 @@ -109,7 +79,7 @@ class DnSmpl(nn.Module): class UpSmpl(nn.Module): - def __init__(self, ic, oc, tus=True, refiner_vae=True, op=VideoConv3d): + def __init__(self, ic, oc, tus, refiner_vae, op): super().__init__() fct = 2 * 2 * 2 if tus else 1 * 2 * 2 self.conv = op(ic, oc * fct, kernel_size=3, stride=1, padding=1) @@ -163,23 +133,6 @@ class UpSmpl(nn.Module): return h + x -class HunyuanRefinerResnetBlock(ResnetBlock): - def __init__(self, in_channels, out_channels, conv_op=NoPadConv3d, norm_op=RMS_norm): - super().__init__(in_channels=in_channels, out_channels=out_channels, temb_channels=0, conv_op=conv_op, norm_op=norm_op) - - def forward(self, x, conv_carry_in=None, conv_carry_out=None): - h = x - h = [ self.swish(self.norm1(x)) ] - h = conv_carry_causal_3d(h, self.conv1, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out) - - h = [ self.dropout(self.swish(self.norm2(h))) ] - h = conv_carry_causal_3d(h, self.conv2, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out) - - if self.in_channels != self.out_channels: - x = self.nin_shortcut(x) - - return x+h - class Encoder(nn.Module): def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks, ffactor_spatial, ffactor_temporal, downsample_match_channel=True, refiner_vae=True, **_): @@ -191,7 +144,7 @@ class Encoder(nn.Module): self.refiner_vae = refiner_vae if self.refiner_vae: - conv_op = NoPadConv3d + conv_op = CarriedConv3d norm_op = RMS_norm else: conv_op = ops.Conv3d @@ -206,9 +159,10 @@ class Encoder(nn.Module): for i, tgt in enumerate(block_out_channels): stage = nn.Module() - stage.block = nn.ModuleList([HunyuanRefinerResnetBlock(in_channels=ch if j == 0 else tgt, - out_channels=tgt, - conv_op=conv_op, norm_op=norm_op) + stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt, + out_channels=tgt, + temb_channels=0, + conv_op=conv_op, norm_op=norm_op) for j in range(num_res_blocks)]) ch = tgt if i < depth: @@ -218,9 +172,9 @@ class Encoder(nn.Module): self.down.append(stage) self.mid = nn.Module() - self.mid.block_1 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) + self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op) - self.mid.block_2 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) + self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) self.norm_out = norm_op(ch) self.conv_out = conv_op(ch, z_channels << 1, 3, 1, 1) @@ -246,22 +200,20 @@ class Encoder(nn.Module): conv_carry_out = [] if i == len(x) - 1: conv_carry_out = None + x1 = [ x1 ] x1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out) for stage in self.down: for blk in stage.block: - x1 = blk(x1, conv_carry_in, conv_carry_out) + x1 = blk(x1, None, conv_carry_in, conv_carry_out) if hasattr(stage, 'downsample'): x1 = stage.downsample(x1, conv_carry_in, conv_carry_out) out.append(x1) conv_carry_in = conv_carry_out - if len(out) > 1: - out = torch.cat(out, dim=2) - else: - out = out[0] + out = torch_cat_if_needed(out, dim=2) x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(out))) del out @@ -288,7 +240,7 @@ class Decoder(nn.Module): self.refiner_vae = refiner_vae if self.refiner_vae: - conv_op = NoPadConv3d + conv_op = CarriedConv3d norm_op = RMS_norm else: conv_op = ops.Conv3d @@ -298,9 +250,9 @@ class Decoder(nn.Module): self.conv_in = conv_op(z_channels, ch, kernel_size=3, stride=1, padding=1) self.mid = nn.Module() - self.mid.block_1 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) + self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op) - self.mid.block_2 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) + self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) self.up = nn.ModuleList() depth = (ffactor_spatial >> 1).bit_length() @@ -308,9 +260,10 @@ class Decoder(nn.Module): for i, tgt in enumerate(block_out_channels): stage = nn.Module() - stage.block = nn.ModuleList([HunyuanRefinerResnetBlock(in_channels=ch if j == 0 else tgt, - out_channels=tgt, - conv_op=conv_op, norm_op=norm_op) + stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt, + out_channels=tgt, + temb_channels=0, + conv_op=conv_op, norm_op=norm_op) for j in range(num_res_blocks + 1)]) ch = tgt if i < depth: @@ -340,7 +293,7 @@ class Decoder(nn.Module): conv_carry_out = None for stage in self.up: for blk in stage.block: - x1 = blk(x1, conv_carry_in, conv_carry_out) + x1 = blk(x1, None, conv_carry_in, conv_carry_out) if hasattr(stage, 'upsample'): x1 = stage.upsample(x1, conv_carry_in, conv_carry_out) @@ -350,10 +303,7 @@ class Decoder(nn.Module): conv_carry_in = conv_carry_out del x - if len(out) > 1: - out = torch.cat(out, dim=2) - else: - out = out[0] + out = torch_cat_if_needed(out, dim=2) if not self.refiner_vae: if z.shape[-3] == 1: diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index de1e01cc8..681a55db5 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -13,6 +13,12 @@ if model_management.xformers_enabled_vae(): import xformers import xformers.ops +def torch_cat_if_needed(xl, dim): + if len(xl) > 1: + return torch.cat(xl, dim) + else: + return xl[0] + def get_timestep_embedding(timesteps, embedding_dim): """ This matches the implementation in Denoising Diffusion Probabilistic Models: @@ -43,6 +49,37 @@ def Normalize(in_channels, num_groups=32): return ops.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) +class CarriedConv3d(nn.Module): + def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, **kwargs): + super().__init__() + self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs) + + def forward(self, x): + return self.conv(x) + + +def conv_carry_causal_3d(xl, op, conv_carry_in=None, conv_carry_out=None): + + x = xl[0] + xl.clear() + + if isinstance(op, CarriedConv3d): + if conv_carry_in is None: + x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2, 0), mode = 'replicate') + else: + carry_len = conv_carry_in[0].shape[2] + x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2 - carry_len, 0), mode = 'replicate') + x = torch.cat([conv_carry_in.pop(0), x], dim=2) + + if conv_carry_out is not None: + to_push = x[:, :, -2:, :, :].clone() + conv_carry_out.append(to_push) + + out = op(x) + + return out + + class VideoConv3d(nn.Module): def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding_mode='replicate', padding=1, **kwargs): super().__init__() @@ -89,29 +126,24 @@ class Upsample(nn.Module): stride=1, padding=1) - def forward(self, x): + def forward(self, x, conv_carry_in=None, conv_carry_out=None): scale_factor = self.scale_factor if isinstance(scale_factor, (int, float)): scale_factor = (scale_factor,) * (x.ndim - 2) if x.ndim == 5 and scale_factor[0] > 1.0: - t = x.shape[2] - if t > 1: - a, b = x.split((1, t - 1), dim=2) - del x - b = interpolate_up(b, scale_factor) - else: - a = x - - a = interpolate_up(a.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2) - if t > 1: - x = torch.cat((a, b), dim=2) - else: - x = a + results = [] + if conv_carry_in is None: + first = x[:, :, :1, :, :] + results.append(interpolate_up(first.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2)) + x = x[:, :, 1:, :, :] + if x.shape[2] > 0: + results.append(interpolate_up(x, scale_factor)) + x = torch_cat_if_needed(results, dim=2) else: x = interpolate_up(x, scale_factor) if self.with_conv: - x = self.conv(x) + x = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out) return x @@ -127,17 +159,20 @@ class Downsample(nn.Module): stride=stride, padding=0) - def forward(self, x): + def forward(self, x, conv_carry_in=None, conv_carry_out=None): if self.with_conv: - if x.ndim == 4: + if isinstance(self.conv, CarriedConv3d): + x = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out) + elif x.ndim == 4: pad = (0, 1, 0, 1) mode = "constant" x = torch.nn.functional.pad(x, pad, mode=mode, value=0) + x = self.conv(x) elif x.ndim == 5: pad = (1, 1, 1, 1, 2, 0) mode = "replicate" x = torch.nn.functional.pad(x, pad, mode=mode) - x = self.conv(x) + x = self.conv(x) else: x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) return x @@ -183,23 +218,23 @@ class ResnetBlock(nn.Module): stride=1, padding=0) - def forward(self, x, temb=None): + def forward(self, x, temb=None, conv_carry_in=None, conv_carry_out=None): h = x h = self.norm1(h) - h = self.swish(h) - h = self.conv1(h) + h = [ self.swish(h) ] + h = conv_carry_causal_3d(h, self.conv1, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out) if temb is not None: h = h + self.temb_proj(self.swish(temb))[:,:,None,None] h = self.norm2(h) h = self.swish(h) - h = self.dropout(h) - h = self.conv2(h) + h = [ self.dropout(h) ] + h = conv_carry_causal_3d(h, self.conv2, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out) if self.in_channels != self.out_channels: if self.use_conv_shortcut: - x = self.conv_shortcut(x) + x = conv_carry_causal_3d([x], self.conv_shortcut, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out) else: x = self.nin_shortcut(x) @@ -520,9 +555,14 @@ class Encoder(nn.Module): self.num_res_blocks = num_res_blocks self.resolution = resolution self.in_channels = in_channels + self.carried = False if conv3d: - conv_op = VideoConv3d + if not attn_resolutions: + conv_op = CarriedConv3d + self.carried = True + else: + conv_op = VideoConv3d mid_attn_conv_op = ops.Conv3d else: conv_op = ops.Conv2d @@ -535,6 +575,7 @@ class Encoder(nn.Module): stride=1, padding=1) + self.time_compress = 1 curr_res = resolution in_ch_mult = (1,)+tuple(ch_mult) self.in_ch_mult = in_ch_mult @@ -561,10 +602,15 @@ class Encoder(nn.Module): if time_compress is not None: if (self.num_resolutions - 1 - i_level) > math.log2(time_compress): stride = (1, 2, 2) + else: + self.time_compress *= 2 down.downsample = Downsample(block_in, resamp_with_conv, stride=stride, conv_op=conv_op) curr_res = curr_res // 2 self.down.append(down) + if time_compress is not None: + self.time_compress = time_compress + # middle self.mid = nn.Module() self.mid.block_1 = ResnetBlock(in_channels=block_in, @@ -590,15 +636,42 @@ class Encoder(nn.Module): def forward(self, x): # timestep embedding temb = None - # downsampling - h = self.conv_in(x) - for i_level in range(self.num_resolutions): - for i_block in range(self.num_res_blocks): - h = self.down[i_level].block[i_block](h, temb) - if len(self.down[i_level].attn) > 0: - h = self.down[i_level].attn[i_block](h) - if i_level != self.num_resolutions-1: - h = self.down[i_level].downsample(h) + + if self.carried: + xl = [x[:, :, :1, :, :]] + if x.shape[2] > self.time_compress: + tc = self.time_compress + xl += torch.split(x[:, :, 1: 1 + ((x.shape[2] - 1) // tc) * tc, :, :], tc * 2, dim = 2) + x = xl + else: + x = [x] + out = [] + + conv_carry_in = None + + for i, x1 in enumerate(x): + conv_carry_out = [] + if i == len(x) - 1: + conv_carry_out = None + + # downsampling + x1 = [ x1 ] + h1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out) + + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h1 = self.down[i_level].block[i_block](h1, temb, conv_carry_in, conv_carry_out) + if len(self.down[i_level].attn) > 0: + assert i == 0 #carried should not happen if attn exists + h1 = self.down[i_level].attn[i_block](h1) + if i_level != self.num_resolutions-1: + h1 = self.down[i_level].downsample(h1, conv_carry_in, conv_carry_out) + + out.append(h1) + conv_carry_in = conv_carry_out + + h = torch_cat_if_needed(out, dim=2) + del out # middle h = self.mid.block_1(h, temb) @@ -607,15 +680,15 @@ class Encoder(nn.Module): # end h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) + h = [ nonlinearity(h) ] + h = conv_carry_causal_3d(h, self.conv_out) return h class Decoder(nn.Module): def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, - resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, + resolution, z_channels, tanh_out=False, use_linear_attn=False, conv_out_op=ops.Conv2d, resnet_op=ResnetBlock, attn_op=AttnBlock, @@ -629,12 +702,18 @@ class Decoder(nn.Module): self.num_res_blocks = num_res_blocks self.resolution = resolution self.in_channels = in_channels - self.give_pre_end = give_pre_end self.tanh_out = tanh_out + self.carried = False if conv3d: - conv_op = VideoConv3d - conv_out_op = VideoConv3d + if not attn_resolutions and resnet_op == ResnetBlock: + conv_op = CarriedConv3d + conv_out_op = CarriedConv3d + self.carried = True + else: + conv_op = VideoConv3d + conv_out_op = VideoConv3d + mid_attn_conv_op = ops.Conv3d else: conv_op = ops.Conv2d @@ -709,29 +788,43 @@ class Decoder(nn.Module): temb = None # z to block_in - h = self.conv_in(z) + h = conv_carry_causal_3d([z], self.conv_in) # middle h = self.mid.block_1(h, temb, **kwargs) h = self.mid.attn_1(h, **kwargs) h = self.mid.block_2(h, temb, **kwargs) + if self.carried: + h = torch.split(h, 2, dim=2) + else: + h = [ h ] + out = [] + + conv_carry_in = None + # upsampling - for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks+1): - h = self.up[i_level].block[i_block](h, temb, **kwargs) - if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h, **kwargs) - if i_level != 0: - h = self.up[i_level].upsample(h) + for i, h1 in enumerate(h): + conv_carry_out = [] + if i == len(h) - 1: + conv_carry_out = None + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h1 = self.up[i_level].block[i_block](h1, temb, conv_carry_in, conv_carry_out, **kwargs) + if len(self.up[i_level].attn) > 0: + assert i == 0 #carried should not happen if attn exists + h1 = self.up[i_level].attn[i_block](h1, **kwargs) + if i_level != 0: + h1 = self.up[i_level].upsample(h1, conv_carry_in, conv_carry_out) - # end - if self.give_pre_end: - return h + h1 = self.norm_out(h1) + h1 = [ nonlinearity(h1) ] + h1 = conv_carry_causal_3d(h1, self.conv_out, conv_carry_in, conv_carry_out) + if self.tanh_out: + h1 = torch.tanh(h1) + out.append(h1) + conv_carry_in = conv_carry_out - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h, **kwargs) - if self.tanh_out: - h = torch.tanh(h) - return h + out = torch_cat_if_needed(out, dim=2) + + return out