import torch
import torch.nn as nn
import torch.nn.functional as F


class CogVideoXUpsample3D(nn.Module):
    r"""
    A 3D Upsample layer used in CogVideoX by Tsinghua University & ZhipuAI. # TODO: Wait for paper release.

    Args:
        in_channels (`int`):
            Number of channels in the input image.
        out_channels (`int`):
            Number of channels produced by the convolution.
        kernel_size (`int`, defaults to `3`):
            Size of the convolving kernel.
        stride (`int`, defaults to `1`):
            Stride of the convolution.
        padding (`int`, defaults to `1`):
            Padding added to all four sides of the input.
        compress_time (`bool`, defaults to `False`):
            Whether or not to compress the time dimension.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 1,
        compress_time: bool = False,
    ) -> None:
        super().__init__()

        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        self.compress_time = compress_time

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        # inputs: (batch, channels, frames, height, width)
        if self.compress_time:
            if inputs.shape[2] > 1 and inputs.shape[2] % 2 == 1:
                # Odd frame count: split off the first frame so it is upsampled
                # spatially only, while the remaining frames are also upsampled
                # in time (T frames -> 2T - 1 frames).
                x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:]

                x_first = F.interpolate(x_first, scale_factor=2.0)
                x_rest = F.interpolate(x_rest, scale_factor=2.0)
                x_first = x_first[:, :, None, :, :]
                inputs = torch.cat([x_first, x_rest], dim=2)
            elif inputs.shape[2] > 1:
                # Even frame count: upsample time, height, and width together.
                inputs = F.interpolate(inputs, scale_factor=2.0)
            else:
                # Single frame: drop the time dimension, upsample spatially,
                # then restore the time dimension.
                inputs = inputs.squeeze(2)
                inputs = F.interpolate(inputs, scale_factor=2.0)
                inputs = inputs[:, :, None, :, :]
        else:
            # Only interpolate 2D: fold time into the batch dimension and
            # upsample each frame spatially.
            b, c, t, h, w = inputs.shape
            inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
            inputs = F.interpolate(inputs, scale_factor=2.0)
            inputs = inputs.reshape(b, t, c, *inputs.shape[2:]).permute(0, 2, 1, 3, 4)

        # Apply the 2D convolution frame by frame, folding time into the batch.
        b, c, t, h, w = inputs.shape
        inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
        inputs = self.conv(inputs)
        inputs = inputs.reshape(b, t, *inputs.shape[1:]).permute(0, 2, 1, 3, 4)

        return inputs
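

if __name__ == "__main__":
    # Minimal usage sketch, added for illustration only; the channel counts and
    # tensor shapes below are arbitrary assumptions, not values from this repo.
    # Inputs are laid out as (batch, channels, frames, height, width).
    up = CogVideoXUpsample3D(in_channels=16, out_channels=8)
    x = torch.randn(1, 16, 4, 8, 8)
    # Spatial-only path (compress_time=False): height and width double.
    print(up(x).shape)  # torch.Size([1, 8, 4, 16, 16])

    up_t = CogVideoXUpsample3D(in_channels=16, out_channels=8, compress_time=True)
    x_odd = torch.randn(1, 16, 5, 8, 8)
    # Odd frame count: the first frame is upsampled spatially only, the rest
    # temporally as well, so 5 frames become 1 + 4 * 2 = 9.
    print(up_t(x_odd).shape)  # torch.Size([1, 8, 9, 16, 16])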