# Mirror of https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
# Synced 2025-12-08 20:34:23 +08:00 (298 lines, 9.5 KiB, Python).
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from einops import rearrange, reduce
|
|
|
|
|
|
def avg_pool_nd(dims, *args, **kwargs):
    """Build an average-pooling layer for 1-, 2- or 3-dimensional signals.

    Args:
        dims: spatial rank of the signal; selects ``nn.AvgPool1d``,
            ``nn.AvgPool2d`` or ``nn.AvgPool3d``.
        *args, **kwargs: forwarded unchanged to the selected pooling class.

    Raises:
        ValueError: if ``dims`` is not 1, 2 or 3.
    """
    for rank, pool_cls in ((1, nn.AvgPool1d), (2, nn.AvgPool2d), (3, nn.AvgPool3d)):
        if dims == rank:
            return pool_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
|
|
|
|
|
|
def conv_nd(dims, *args, **kwargs):
    """Build a convolution layer for 1-, 2- or 3-dimensional signals.

    Args:
        dims: spatial rank of the signal; selects ``nn.Conv1d``,
            ``nn.Conv2d`` or ``nn.Conv3d``.
        *args, **kwargs: forwarded unchanged to the selected convolution class.

    Raises:
        ValueError: if ``dims`` is not 1, 2 or 3.
    """
    for rank, conv_cls in ((1, nn.Conv1d), (2, nn.Conv2d), (3, nn.Conv3d)):
        if dims == rank:
            return conv_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
|
|
|
|
|
|
class Downsample(nn.Module):
    """Halve the spatial resolution, optionally with a learned strided convolution.

    :param channels: number of channels of the input tensor.
    :param use_conv: if truthy, downsample with a 3-wide strided convolution;
        otherwise use average pooling (which requires ``out_channels == channels``).
    :param dims: spatial rank of the signal (1, 2 or 3). For 3D signals the
        leading spatial axis keeps stride 1 and only the two innermost axes are halved.
    :param out_channels: output channels; falls back to ``channels`` when falsy.
    :param padding: padding passed to the convolution (ignored for pooling).
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        # 3D signals keep their first spatial axis (stride 1) and halve the other two.
        stride = (1, 2, 2) if dims == 3 else 2
        if not use_conv:
            # Pooling cannot change the channel count.
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
        else:
            self.op = conv_nd(
                dims,
                self.channels,
                self.out_channels,
                3,
                stride=stride,
                padding=padding,
            )

    def forward(self, x):
        # Dim 1 is treated as the channel axis.
        assert x.shape[1] == self.channels
        return self.op(x)
|
|
|
|
|
|
class ResnetBlock(nn.Module):
    """Residual block with two BatchNorm -> ReLU -> Conv2d stages.

    Args:
        in_c: channels of the incoming tensor.
        out_c: channels produced by the block.
        down: when ``== True``, the input first goes through ``Downsample(in_c)``.
        ksize: kernel size of ``in_conv``, ``block2`` and ``skep``
            (``block1`` is always 3x3); padding is ``ksize // 2``.
        sk: when ``== False``, the shortcut is a learned conv (``skep``) and
            ``in_conv`` is always created. Otherwise the shortcut is the identity
            and ``in_conv`` exists only when ``in_c != out_c``.
        use_conv: forwarded to ``Downsample`` as ``use_conv``.
    """

    def __init__(self, in_c, out_c, down, ksize=3, sk=False, use_conv=True):
        super().__init__()
        pad = ksize // 2
        # Explicit comparison on purpose: e.g. sk=None must NOT count as False.
        learned_skip = sk == False  # noqa: E712
        needs_in_conv = in_c != out_c or learned_skip

        # Submodule names and creation order are kept stable (state_dict keys / RNG order).
        self.in_conv = nn.Conv2d(in_c, out_c, ksize, 1, pad) if needs_in_conv else None
        self.block1 = nn.Conv2d(out_c, out_c, 3, 1, 1)
        self.act = nn.ReLU()
        self.block2 = nn.Conv2d(out_c, out_c, ksize, 1, pad)
        self.bn1 = nn.BatchNorm2d(out_c)
        self.bn2 = nn.BatchNorm2d(out_c)
        # The shortcut runs after in_conv, so it maps out_c -> out_c.
        self.skep = nn.Conv2d(out_c, out_c, ksize, 1, pad) if learned_skip else None

        self.down = down
        if down == True:  # noqa: E712
            self.down_opt = Downsample(in_c, use_conv=use_conv)

    def forward(self, x):
        if self.down == True:  # noqa: E712
            x = self.down_opt(x)
        if self.in_conv is not None:
            x = self.in_conv(x)

        residual = self.block1(self.act(self.bn1(x)))
        residual = self.block2(self.act(self.bn2(residual)))
        shortcut = x if self.skep is None else self.skep(x)
        return residual + shortcut
|
|
|
|
|
|
class VAESpatialEmulator(nn.Module):
    """Fold non-overlapping spatial patches into the channel axis.

    Args:
        kernel_size: ``(Hp, Wp)`` patch size. Rows and columns that do not fill
            a complete patch are cropped from the bottom and right edges.
    """

    def __init__(self, kernel_size=(8, 8)):
        super().__init__()
        self.kernel_size = kernel_size

    def forward(self, x):
        """
        x: torch.Tensor: shape [B C T H W]
        returns: shape [B, Hp*Wp*C, T, H // Hp, W // Wp]
        """
        patch_h, patch_w = self.kernel_size
        height = x.shape[-2]
        width = x.shape[-1]
        # Crop to whole multiples of the patch size.
        cropped = x[..., : height - height % patch_h, : width - width % patch_w]
        return rearrange(
            cropped,
            "B C T (Nh Hp) (Nw Wp) -> B (Hp Wp C) T Nh Nw",
            Hp=patch_h,
            Wp=patch_w,
        )
|
|
|
|
|
|
class VAETemporalEmulator(nn.Module):
    """Temporally compress a tensor window by window along axis 2.

    The frame axis is split into consecutive windows of ``micro_frame_size``
    frames. In each window the first frame is kept as-is. The remaining frames
    are mean-pooled in groups of ``kernel_size``; leftover frames that do not
    fill a full group are dropped.
    """

    def __init__(self, micro_frame_size, kernel_size=4):
        super().__init__()
        self.micro_frame_size = micro_frame_size
        self.kernel_size = kernel_size

    def forward(self, x_z):
        """
        x_z: torch.Tensor: shape [B C T H W]
        """
        window = self.micro_frame_size
        group = self.kernel_size
        pieces = []
        for start in range(0, x_z.shape[2], window):
            chunk = x_z[:, :, start : start + window]
            head, tail = chunk[:, :, :1], chunk[:, :, 1:]
            # Largest multiple of `group` that fits in the tail.
            usable = tail.shape[2] // group * group
            pooled = reduce(tail[:, :, :usable], "B C (T n) H W -> B C T H W", n=group, reduction="mean")
            pieces.append(head)
            pieces.append(pooled)
        return torch.cat(pieces, dim=2)
|
|
|
|
|
|
class TrajExtractor(nn.Module):
    """Multi-stage convolutional feature extractor for trajectory maps.

    Pipeline (see ``forward``):
    1. Zero-pad the input so T, H and W are multiples of ``self.patch_size``.
    2. Fold frames into the batch axis.
    3. Space-to-depth each frame with ``nn.PixelUnshuffle(patch_size)``.
    4. Run ``conv_in`` and then ``len(channels)`` stages of ``nums_rb``
       ``ResnetBlock``s each, collecting the output of every stage.

    Args:
        vae_downsize: stored as ``self.vae_downsize``; not used elsewhere in this class.
        patch_size: spatial factor of the pixel-unshuffle step.
        channels: width of each stage. Copied into a new list, so instances never
            share (or alias the caller's) sequence. A tuple default avoids a mutable
            default argument.
        nums_rb: number of residual blocks per stage.
        cin: channels of the input tensor (before pixel-unshuffle).
        ksize, sk, use_conv: forwarded to every ``ResnetBlock``.
    """

    def __init__(
        self,
        vae_downsize=(4, 8, 8),
        patch_size=2,
        channels=(320, 640, 1280, 1280),
        nums_rb=3,
        cin=2,
        ksize=3,
        sk=False,
        use_conv=True,
    ):
        super().__init__()
        self.vae_downsize = vae_downsize
        # self.vae_spatial_emulator = VAESpatialEmulator(kernel_size=vae_downsize[-2:])
        self.downsize_patchify = nn.PixelUnshuffle(patch_size)
        # (T, H, W) patch sizes; frames are never merged.
        self.patch_size = (1, patch_size, patch_size)
        self.channels = list(channels)
        self.nums_rb = nums_rb

        # Registration order (downsize_patchify, body, conv_in) is kept stable, so
        # state_dict layout and RNG consumption during init do not change.
        blocks = []
        for i, width in enumerate(self.channels):
            for j in range(nums_rb):
                # The first block of every stage after the first widens from the
                # previous stage's channel count.
                width_in = self.channels[i - 1] if (i != 0 and j == 0) else width
                blocks.append(
                    ResnetBlock(
                        width_in,
                        width,
                        down=False,
                        ksize=ksize,
                        sk=sk,
                        use_conv=use_conv,
                    )
                )
        self.body = nn.ModuleList(blocks)
        # PixelUnshuffle multiplies the channel count by patch_size**2.
        self.conv_in = nn.Conv2d(cin * patch_size**2, self.channels[0], 3, 1, 1)

        # Initialize weights
        def conv_init(module):
            if isinstance(module, (nn.Conv2d, nn.Conv1d)):
                nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(conv_init)

    def forward(self, x):
        """
        x: torch.Tensor: shape [B C T H W]

        Returns a list with one tensor per stage. Stage ``i`` has
        ``channels[i]`` channels and a leading axis of size ``B * T_padded``.
        Spatial size is ``ceil(H / patch_size) x ceil(W / patch_size)`` when
        ``ksize`` is odd, since the ResnetBlocks then preserve spatial size.
        """
        pt, ph, pw = self.patch_size
        T, H, W = x.shape[-3:]
        # Right/bottom/end zero-padding up to the next multiple of each patch size,
        # done in a single F.pad call instead of one copy per axis.
        pad_t, pad_h, pad_w = -T % pt, -H % ph, -W % pw
        if pad_t or pad_h or pad_w:
            x = F.pad(x, (0, pad_w, 0, pad_h, 0, pad_t))
        x = rearrange(x, "B C T H W -> (B T) C H W")
        x = self.downsize_patchify(x)

        # extract features
        features = []
        x = self.conv_in(x)
        for stage in range(len(self.channels)):
            for j in range(self.nums_rb):
                x = self.body[stage * self.nums_rb + j](x)
            features.append(x)

        return features
|
|
|
|
|
|
class FloatGroupNorm(nn.GroupNorm):
    """GroupNorm that normalises in its parameters' dtype and returns the input's dtype.

    With ``affine=True`` the input is cast to ``self.bias.dtype`` before
    normalising. With ``affine=False`` there are no parameters (``self.bias`` is
    None), so the computation runs in at least float32 instead of raising
    AttributeError. The result is cast back to the input's dtype in both cases.
    """

    def forward(self, x):
        if self.bias is not None:
            compute_dtype = self.bias.dtype
        else:
            # No affine params: promote low-precision inputs, keep float64 as-is.
            compute_dtype = torch.promote_types(x.dtype, torch.float32)
        return super().forward(x.to(compute_dtype)).to(x.dtype)
|
|
|
|
|
|
def zero_module(module):
    """Set every parameter of ``module`` to zero in place.

    Gradient tracking is disabled only for the write itself; the parameters keep
    their ``requires_grad`` flag.

    Returns:
        The same ``module`` object, so the call can wrap a constructor inline.
    """
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
|
|
|
|
|
|
class MGF(nn.Module):
    """Flow-conditioned feature modulation: ``h + norm(h) * gamma + beta``.

    ``gamma`` and ``beta`` are each predicted from ``flow`` by two steps:
    - a 3x3 spatial conv (``flow_in_channel -> out_channels // 4``);
    - a kernel-3 temporal Conv1d over the frame axis
      (``out_channels // 4 -> out_channels``, replicate padding).

    The temporal convs are zero-initialised, so at init gamma and beta are zero
    and ``forward`` returns ``h`` unchanged. ``out_channels`` must be divisible
    by 32 for the GroupNorm.
    """

    def __init__(self, flow_in_channel=128, out_channels=1152):
        super().__init__()
        self.out_channels = out_channels
        hidden = self.out_channels // 4

        def make_temporal_conv():
            return zero_module(
                nn.Conv1d(
                    hidden,
                    self.out_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    padding_mode="replicate",
                )
            )

        # Creation order is kept identical (affects RNG consumption and state_dict order).
        self.flow_gamma_spatial = nn.Conv2d(flow_in_channel, hidden, 3, padding=1)
        self.flow_gamma_temporal = make_temporal_conv()
        self.flow_beta_spatial = nn.Conv2d(flow_in_channel, hidden, 3, padding=1)
        self.flow_beta_temporal = make_temporal_conv()
        self.flow_cond_norm = FloatGroupNorm(32, self.out_channels)

    def forward(self, h, flow, T):
        """Modulate ``h`` with ``flow``; return ``h`` untouched when ``flow`` is None.

        ``flow`` is 4D with its leading axis laid out as ``(b f)`` where ``f == T``
        (enforced by the rearrange). ``h`` presumably has shape
        ``[(b*T), out_channels, H, W]`` matching the flow's spatial size — TODO confirm
        against callers.
        """
        if flow is None:
            return h

        gamma = self.flow_gamma_spatial(flow)
        beta = self.flow_beta_spatial(flow)
        _, _, grid_h, grid_w = beta.shape

        # Move frames onto the last axis so Conv1d runs along time per pixel, then back.
        to_temporal = "(b f) c h w -> (b h w) c f"
        to_spatial = "(b h w) c f -> (b f) c h w"
        gamma = self.flow_gamma_temporal(rearrange(gamma, to_temporal, f=T))
        beta = self.flow_beta_temporal(rearrange(beta, to_temporal, f=T))
        gamma = rearrange(gamma, to_spatial, h=grid_h, w=grid_w)
        beta = rearrange(beta, to_spatial, h=grid_h, w=grid_w)

        return h + self.flow_cond_norm(h) * gamma + beta
|