import torch
import torch.nn as nn
import numpy as np
from typing import Tuple, Union, Optional


def get_1d_rotary_pos_embed(
    dim: int,
    pos: Union[np.ndarray, int],
    theta: float = 10000.0,
    use_real=False,
    linear_factor=1.0,
    ntk_factor=1.0,
    repeat_interleave_real=True,
    freqs_dtype=torch.float32,  # torch.float32, torch.float64 (flux)
):
"""
|
|
Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
|
|
|
|
This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end
|
|
index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64
|
|
data type.
|
|
|
|
Args:
|
|
dim (`int`): Dimension of the frequency tensor.
|
|
pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
|
|
theta (`float`, *optional*, defaults to 10000.0):
|
|
Scaling factor for frequency computation. Defaults to 10000.0.
|
|
use_real (`bool`, *optional*):
|
|
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
|
|
linear_factor (`float`, *optional*, defaults to 1.0):
|
|
Scaling factor for the context extrapolation. Defaults to 1.0.
|
|
ntk_factor (`float`, *optional*, defaults to 1.0):
|
|
Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
|
|
repeat_interleave_real (`bool`, *optional*, defaults to `True`):
|
|
If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
|
|
Otherwise, they are concateanted with themselves.
|
|
freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
|
|
the dtype of the frequency tensor.
|
|
Returns:
|
|
`torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
|
|
"""
|
|
    assert dim % 2 == 0

    if isinstance(pos, int):
        pos = torch.arange(pos)
    if isinstance(pos, np.ndarray):
        pos = torch.from_numpy(pos)  # type: ignore  # [S]

    theta = theta * ntk_factor
    freqs = (
        1.0
        / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
        / linear_factor
    )  # [D/2]
    freqs = torch.outer(pos, freqs)  # type: ignore  # [S, D/2]
    if use_real and repeat_interleave_real:
        # flux, hunyuan-dit, cogvideox
        freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float()  # [S, D]
        freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float()  # [S, D]
        return freqs_cos, freqs_sin
    elif use_real:
        # stable audio
        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()  # [S, D]
        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()  # [S, D]
        return freqs_cos, freqs_sin
    else:
        # lumina
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64  # [S, D/2]
        return freqs_cis

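# Illustrative sketch: a minimal shape check for `get_1d_rotary_pos_embed`, assuming a head dimension
# of 64 and 16 positions. The helper name `_demo_1d_rotary_pos_embed` is hypothetical and used here
# for illustration only; it is not referenced elsewhere in this module.
def _demo_1d_rotary_pos_embed():
    # With `use_real=True` (the RoPE variant used by CogVideoX), a (cos, sin) pair is returned,
    # each of shape [num_positions, dim].
    cos, sin = get_1d_rotary_pos_embed(dim=64, pos=16, use_real=True)
    assert cos.shape == (16, 64) and sin.shape == (16, 64)
    # With `use_real=False`, a single complex tensor of shape [num_positions, dim // 2] is returned.
    freqs_cis = get_1d_rotary_pos_embed(dim=64, pos=16, use_real=False)
    assert freqs_cis.shape == (16, 32) and freqs_cis.dtype == torch.complex64
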
def get_3d_rotary_pos_embed(
    embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
|
|
RoPE for video tokens with 3D structure.
|
|
|
|
Args:
|
|
embed_dim: (`int`):
|
|
The embedding dimension size, corresponding to hidden_size_head.
|
|
crops_coords (`Tuple[int]`):
|
|
The top-left and bottom-right coordinates of the crop.
|
|
grid_size (`Tuple[int]`):
|
|
The grid size of the spatial positional embedding (height, width).
|
|
temporal_size (`int`):
|
|
The size of the temporal dimension.
|
|
theta (`float`):
|
|
Scaling factor for frequency computation.
|
|
|
|
Returns:
|
|
`torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
|
|
"""
|
|
    if use_real is not True:
        raise ValueError("`use_real = False` is not currently supported for get_3d_rotary_pos_embed")
    start, stop = crops_coords
    grid_size_h, grid_size_w = grid_size
    grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
    grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
    grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)

    # Compute dimensions for each axis
    dim_t = embed_dim // 4
    dim_h = embed_dim // 8 * 3
    dim_w = embed_dim // 8 * 3

    # Temporal frequencies
    freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True)
    # Spatial frequencies for height and width
    freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True)
    freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True)

    # Broadcast and concatenate the temporal and spatial frequencies (height and width) into a 3D tensor
    def combine_time_height_width(freqs_t, freqs_h, freqs_w):
        freqs_t = freqs_t[:, None, None, :].expand(
            -1, grid_size_h, grid_size_w, -1
        )  # temporal_size, grid_size_h, grid_size_w, dim_t
        freqs_h = freqs_h[None, :, None, :].expand(
            temporal_size, -1, grid_size_w, -1
        )  # temporal_size, grid_size_h, grid_size_w, dim_h
        freqs_w = freqs_w[None, None, :, :].expand(
            temporal_size, grid_size_h, -1, -1
        )  # temporal_size, grid_size_h, grid_size_w, dim_w

        freqs = torch.cat(
            [freqs_t, freqs_h, freqs_w], dim=-1
        )  # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w)
        freqs = freqs.view(
            temporal_size * grid_size_h * grid_size_w, -1
        )  # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w)
        return freqs

    t_cos, t_sin = freqs_t  # both t_cos and t_sin have shape: temporal_size, dim_t
    h_cos, h_sin = freqs_h  # both h_cos and h_sin have shape: grid_size_h, dim_h
    w_cos, w_sin = freqs_w  # both w_cos and w_sin have shape: grid_size_w, dim_w
    cos = combine_time_height_width(t_cos, h_cos, w_cos)
    sin = combine_time_height_width(t_sin, h_sin, w_sin)
    return cos, sin

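# Illustrative sketch: a minimal shape check for `get_3d_rotary_pos_embed`. The crop coordinates,
# grid size, and frame count below are arbitrary example values; `_demo_3d_rotary_pos_embed` is a
# hypothetical helper used only for illustration.
def _demo_3d_rotary_pos_embed():
    cos, sin = get_3d_rotary_pos_embed(
        embed_dim=64,
        crops_coords=((0, 0), (30, 45)),
        grid_size=(30, 45),
        temporal_size=13,
    )
    # Both tensors flatten (time, height, width) into a single axis: 13 * 30 * 45 = 17550 positions.
    assert cos.shape == (13 * 30 * 45, 64)
    assert sin.shape == (13 * 30 * 45, 64)
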
def get_3d_sincos_pos_embed(
    embed_dim: int,
    spatial_size: Union[int, Tuple[int, int]],
    temporal_size: int,
    spatial_interpolation_scale: float = 1.0,
    temporal_interpolation_scale: float = 1.0,
) -> np.ndarray:
r"""
|
|
Args:
|
|
embed_dim (`int`):
|
|
spatial_size (`int` or `Tuple[int, int]`):
|
|
temporal_size (`int`):
|
|
spatial_interpolation_scale (`float`, defaults to 1.0):
|
|
temporal_interpolation_scale (`float`, defaults to 1.0):
|
|
"""
|
|
    if embed_dim % 4 != 0:
        raise ValueError("`embed_dim` must be divisible by 4")
    if isinstance(spatial_size, int):
        spatial_size = (spatial_size, spatial_size)

    embed_dim_spatial = 3 * embed_dim // 4
    embed_dim_temporal = embed_dim // 4

    # 1. Spatial
    grid_h = np.arange(spatial_size[1], dtype=np.float32) / spatial_interpolation_scale
    grid_w = np.arange(spatial_size[0], dtype=np.float32) / spatial_interpolation_scale
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]])
    pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid)

    # 2. Temporal
    grid_t = np.arange(temporal_size, dtype=np.float32) / temporal_interpolation_scale
    pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t)

    # 3. Concat
    pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :]
    pos_embed_spatial = np.repeat(pos_embed_spatial, temporal_size, axis=0)  # [T, H*W, D // 4 * 3]

    pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :]
    pos_embed_temporal = np.repeat(pos_embed_temporal, spatial_size[0] * spatial_size[1], axis=1)  # [T, H*W, D // 4]

    pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1)  # [T, H*W, D]
    return pos_embed

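# Illustrative sketch: a minimal shape check for `get_3d_sincos_pos_embed`, using small example
# sizes. `_demo_3d_sincos_pos_embed` is a hypothetical helper used only for illustration.
def _demo_3d_sincos_pos_embed():
    # embed_dim must be divisible by 4: one quarter goes to the temporal axis and three quarters to
    # the spatial axes. The result is a NumPy array of shape [temporal_size, height * width, embed_dim].
    pos_embed = get_3d_sincos_pos_embed(embed_dim=64, spatial_size=(45, 30), temporal_size=13)
    assert pos_embed.shape == (13, 45 * 30, 64)
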
def get_2d_sincos_pos_embed(
    embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
):
"""
|
|
grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
|
|
[1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
|
"""
|
|
    if isinstance(grid_size, int):
        grid_size = (grid_size, grid_size)

    grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
    grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
    return pos_embed

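# Illustrative sketch: a minimal shape check for `get_2d_sincos_pos_embed`; the grid size below is an
# arbitrary example value and `_demo_2d_sincos_pos_embed` is a hypothetical helper used only for
# illustration.
def _demo_2d_sincos_pos_embed():
    pos_embed = get_2d_sincos_pos_embed(embed_dim=16, grid_size=8)
    # Without a cls token, there is one row per grid position: [grid_size * grid_size, embed_dim].
    assert pos_embed.shape == (8 * 8, 16)
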
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb

def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb

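# Illustrative sketch: a small numerical check of the sinusoidal formula above, where
# omega_i = 1 / 10000 ** (2i / embed_dim) and the output concatenates the sin and cos halves.
# `_demo_1d_sincos_pos_embed_from_grid` is a hypothetical helper used only for illustration.
def _demo_1d_sincos_pos_embed_from_grid():
    pos = np.array([0.0, 1.0, 2.0])
    emb = get_1d_sincos_pos_embed_from_grid(embed_dim=4, pos=pos)
    # embed_dim=4 gives omega = [1.0, 0.01]; row m is [sin(m), sin(0.01 m), cos(m), cos(0.01 m)].
    assert emb.shape == (3, 4)
    assert np.allclose(emb[:, 0], np.sin(pos)) and np.allclose(emb[:, 2], np.cos(pos))
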
class CogVideoXPatchEmbed(nn.Module):
    def __init__(
        self,
        patch_size: int = 2,
        patch_size_t: Optional[int] = None,
        in_channels: int = 16,
        embed_dim: int = 1920,
        text_embed_dim: int = 4096,
        bias: bool = True,
        sample_width: int = 90,
        sample_height: int = 60,
        sample_frames: int = 49,
        temporal_compression_ratio: int = 4,
        max_text_seq_length: int = 226,
        spatial_interpolation_scale: float = 1.875,
        temporal_interpolation_scale: float = 1.0,
        use_positional_embeddings: bool = True,
        use_learned_positional_embeddings: bool = True,
    ) -> None:
        super().__init__()

        self.patch_size = patch_size
        self.patch_size_t = patch_size_t
        self.embed_dim = embed_dim
        self.sample_height = sample_height
        self.sample_width = sample_width
        self.sample_frames = sample_frames
        self.temporal_compression_ratio = temporal_compression_ratio
        self.max_text_seq_length = max_text_seq_length
        self.spatial_interpolation_scale = spatial_interpolation_scale
        self.temporal_interpolation_scale = temporal_interpolation_scale
        self.use_positional_embeddings = use_positional_embeddings
        self.use_learned_positional_embeddings = use_learned_positional_embeddings

        if patch_size_t is None:
            # CogVideoX 1.0 checkpoints
            self.proj = nn.Conv2d(
                in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
            )
        else:
            # CogVideoX 1.5 checkpoints
            self.proj = nn.Linear(in_channels * patch_size * patch_size * patch_size_t, embed_dim)

        self.text_proj = nn.Linear(text_embed_dim, embed_dim)

        if use_positional_embeddings or use_learned_positional_embeddings:
            persistent = use_learned_positional_embeddings
            pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
            self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)

    def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor:
        post_patch_height = sample_height // self.patch_size
        post_patch_width = sample_width // self.patch_size
        post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
        num_patches = post_patch_height * post_patch_width * post_time_compression_frames

        pos_embedding = get_3d_sincos_pos_embed(
            self.embed_dim,
            (post_patch_width, post_patch_height),
            post_time_compression_frames,
            self.spatial_interpolation_scale,
            self.temporal_interpolation_scale,
        )
        pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1)
        joint_pos_embedding = torch.zeros(
            1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
        )
        joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding)

        return joint_pos_embedding

    def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
        r"""
        Args:
            text_embeds (`torch.Tensor`):
                Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
            image_embeds (`torch.Tensor`):
                Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
        """
        text_embeds = self.text_proj(text_embeds)

        batch_size, num_frames, channels, height, width = image_embeds.shape

        if self.patch_size_t is None:
            image_embeds = image_embeds.reshape(-1, channels, height, width)
            image_embeds = self.proj(image_embeds)
            image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:])
            image_embeds = image_embeds.flatten(3).transpose(2, 3)  # [batch, num_frames, height x width, channels]
            image_embeds = image_embeds.flatten(1, 2)  # [batch, num_frames x height x width, channels]
        else:
            p = self.patch_size
            p_t = self.patch_size_t

            image_embeds = image_embeds.permute(0, 1, 3, 4, 2)
            image_embeds = image_embeds.reshape(
                batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels
            )
            image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
            image_embeds = self.proj(image_embeds)

        embeds = torch.cat(
            [text_embeds, image_embeds], dim=1
        ).contiguous()  # [batch, seq_length + num_frames x height x width, channels]

        if self.use_positional_embeddings or self.use_learned_positional_embeddings:
            if self.use_learned_positional_embeddings and (self.sample_width != width or self.sample_height != height):
                raise ValueError(
                    "It is currently not possible to generate videos at a different resolution than the defaults. This should only be the case with 'THUDM/CogVideoX-5b-I2V'. "
                    "If you think this is incorrect, please open an issue at https://github.com/huggingface/diffusers/issues."
                )

            pre_time_compression_frames = (num_frames - 1) * self.temporal_compression_ratio + 1

            if (
                self.sample_height != height
                or self.sample_width != width
                or self.sample_frames != pre_time_compression_frames
            ):
                pos_embedding = self._get_positional_embeddings(height, width, pre_time_compression_frames)
                pos_embedding = pos_embedding.to(embeds.device, dtype=embeds.dtype)
            else:
                pos_embedding = self.pos_embedding

            embeds = embeds + pos_embedding

        return embeds

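# Illustrative sketch: a minimal end-to-end check of `CogVideoXPatchEmbed` with small, arbitrary
# sizes (the real CogVideoX checkpoints use much larger defaults). Sinusoidal (non-learned)
# positional embeddings are used so the buffer is built by `get_3d_sincos_pos_embed` above; the
# demo helpers called first are the hypothetical examples defined earlier in this module.
if __name__ == "__main__":
    _demo_1d_rotary_pos_embed()
    _demo_3d_rotary_pos_embed()
    _demo_3d_sincos_pos_embed()
    _demo_2d_sincos_pos_embed()
    _demo_1d_sincos_pos_embed_from_grid()

    patch_embed = CogVideoXPatchEmbed(
        patch_size=2,
        in_channels=4,
        embed_dim=64,
        text_embed_dim=32,
        sample_width=16,
        sample_height=16,
        sample_frames=9,
        temporal_compression_ratio=4,
        max_text_seq_length=8,
        use_positional_embeddings=True,
        use_learned_positional_embeddings=False,
    )
    # 9 input frames compress to (9 - 1) // 4 + 1 = 3 latent frames; 16x16 latents with patch_size=2
    # give 8x8 = 64 patches per frame, so 3 * 64 = 192 image tokens plus 8 text tokens.
    text_embeds = torch.zeros(1, 8, 32)
    image_embeds = torch.zeros(1, 3, 4, 16, 16)
    out = patch_embed(text_embeds, image_embeds)
    assert out.shape == (1, 8 + 192, 64)
    print("All embedding smoke tests passed:", tuple(out.shape))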