Mirror of https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git

parent 832dad94bc
commit 75e98906a3
@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 import numpy as np
 from typing import Tuple, Union, Optional
-from diffusers.models.embeddings import get_3d_sincos_pos_embed
+from diffusers.models.embeddings import get_3d_sincos_pos_embed, get_1d_rotary_pos_embed


 class CogVideoXPatchEmbed(nn.Module):
@@ -131,4 +131,96 @@ class CogVideoXPatchEmbed(nn.Module):
         embeds = embeds + pos_embedding

         return embeds

+
+def get_3d_rotary_pos_embed(
+    embed_dim,
+    crops_coords,
+    grid_size,
+    temporal_size,
+    theta: int = 10000,
+    use_real: bool = True,
+    grid_type: str = "linspace",
+    max_size: Optional[Tuple[int, int]] = None,
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    """
+    RoPE for video tokens with 3D structure.
+
+    Args:
+        embed_dim (`int`):
+            The embedding dimension size, corresponding to hidden_size_head.
+        crops_coords (`Tuple[int]`):
+            The top-left and bottom-right coordinates of the crop.
+        grid_size (`Tuple[int]`):
+            The grid size of the spatial positional embedding (height, width).
+        temporal_size (`int`):
+            The size of the temporal dimension.
+        theta (`int`):
+            Scaling factor for frequency computation.
+        use_real (`bool`):
+            Must be `True`; only the real-valued (cos, sin) form is supported.
+        grid_type (`str`):
+            Whether to use "linspace" or "slice" to compute grids.
+        max_size (`Tuple[int, int]`, *optional*):
+            The maximum (height, width) of the base grid; required when `grid_type` is "slice".
+
+    Returns:
+        `Tuple[torch.Tensor, torch.Tensor]`: the cos and sin positional embeddings, each with shape
+        `(temporal_size * grid_size[0] * grid_size[1], embed_dim)`.
+    """
+    if use_real is not True:
+        raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed")
+
+    if grid_type == "linspace":
+        start, stop = crops_coords
+        grid_size_h, grid_size_w = grid_size
+        grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
+        grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
+        grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
+    elif grid_type == "slice":
+        max_h, max_w = max_size
+        grid_size_h, grid_size_w = grid_size
+        grid_h = np.arange(max_h, dtype=np.float32)
+        grid_w = np.arange(max_w, dtype=np.float32)
+        grid_t = np.arange(temporal_size, dtype=np.float32)
+    else:
+        raise ValueError("Invalid value passed for `grid_type`.")
+
+    # Compute dimensions for each axis: 1/4 of embed_dim for time, 3/8 each for height and width
+    dim_t = embed_dim // 4
+    dim_h = embed_dim // 8 * 3
+    dim_w = embed_dim // 8 * 3
+
+    # Temporal frequencies (theta is forwarded so the argument above is not silently ignored)
+    freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, theta=theta, use_real=True)
+    # Spatial frequencies for height and width
+    freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, theta=theta, use_real=True)
+    freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, theta=theta, use_real=True)
+
+    # Broadcast and concatenate temporal and spatial frequencies (height and width) into a 3D tensor
+    def combine_time_height_width(freqs_t, freqs_h, freqs_w):
+        freqs_t = freqs_t[:, None, None, :].expand(
+            -1, grid_size_h, grid_size_w, -1
+        )  # temporal_size, grid_size_h, grid_size_w, dim_t
+        freqs_h = freqs_h[None, :, None, :].expand(
+            temporal_size, -1, grid_size_w, -1
+        )  # temporal_size, grid_size_h, grid_size_w, dim_h
+        freqs_w = freqs_w[None, None, :, :].expand(
+            temporal_size, grid_size_h, -1, -1
+        )  # temporal_size, grid_size_h, grid_size_w, dim_w
+
+        freqs = torch.cat(
+            [freqs_t, freqs_h, freqs_w], dim=-1
+        )  # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w)
+        freqs = freqs.view(
+            temporal_size * grid_size_h * grid_size_w, -1
+        )  # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w)
+        return freqs
+
+    t_cos, t_sin = freqs_t  # both t_cos and t_sin have shape: temporal_size, dim_t
+    h_cos, h_sin = freqs_h  # both h_cos and h_sin have shape: grid_size_h, dim_h
+    w_cos, w_sin = freqs_w  # both w_cos and w_sin have shape: grid_size_w, dim_w
+
+    if grid_type == "slice":
+        t_cos, t_sin = t_cos[:temporal_size], t_sin[:temporal_size]
+        h_cos, h_sin = h_cos[:grid_size_h], h_sin[:grid_size_h]
+        w_cos, w_sin = w_cos[:grid_size_w], w_sin[:grid_size_w]
+
+    cos = combine_time_height_width(t_cos, h_cos, w_cos)
+    sin = combine_time_height_width(t_sin, h_sin, w_sin)
+    return cos, sin
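For orientation, a minimal sketch of how the helper added above might be called. All sizes are illustrative assumptions (a 49-frame 480x720 video under CogVideoX's 8x-spatial / 4x-temporal VAE compression, patch size 2, and an assumed attention_head_dim of 64), not values taken from the commit:

# Hypothetical usage of get_3d_rotary_pos_embed (not part of the commit).
# A 49-frame 480x720 video -> 13 latent frames and a 30x45 token grid.
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
    embed_dim=64,                     # assumed attention_head_dim
    crops_coords=((0, 0), (30, 45)),  # top-left / bottom-right, in grid units
    grid_size=(30, 45),               # (height, width) of the token grid
    temporal_size=13,
)
print(freqs_cos.shape)  # torch.Size([17550, 64]) == (13 * 30 * 45, 64)

# The new "slice" mode (used by the CogVideoX 1.5 path further below) builds
# frequencies on a maximum base grid and slices them down, so crops_coords is unused:
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
    embed_dim=64, crops_coords=None, grid_size=(30, 45),
    temporal_size=7, grid_type="slice", max_size=(150, 150),
)
print(freqs_cos.shape)  # torch.Size([9450, 64]) == (7 * 30 * 45, 64)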
@@ -182,16 +182,19 @@ class DownloadAndLoadCogVideoModel:
                 local_dir_use_symlinks=False,
             )

-        # transformer
+        #transformer
         if "Fun" in model:
             transformer = CogVideoXTransformer3DModelFun.from_pretrained(base_path, subfolder=subfolder)
         else:
             transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder=subfolder)

         transformer = transformer.to(dtype).to(transformer_load_device)

         transformer.attention_mode = attention_mode

+        if "1.5" in model:
+            transformer.config.sample_height = 300
+            transformer.config.sample_width = 300
+
         if block_edit is not None:
             transformer = remove_specific_blocks(transformer, block_edit)

@@ -199,7 +202,7 @@ class DownloadAndLoadCogVideoModel:
             scheduler_config = json.load(f)
         scheduler = CogVideoXDDIMScheduler.from_config(scheduler_config)

-        # VAE
+        #VAE
         if "Fun" in model:
             vae = AutoencoderKLCogVideoXFun.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)
             if "Pose" in model:
@@ -393,8 +396,8 @@ class DownloadAndLoadCogVideoGGUFModel:
            transformer_config["use_learned_positional_embeddings"] = False
            transformer_config["patch_size_t"] = 2
            transformer_config["patch_bias"] = False
-           transformer_config["sample_height"] = 96
-           transformer_config["sample_width"] = 170
+           transformer_config["sample_height"] = 300
+           transformer_config["sample_width"] = 300
            transformer = CogVideoXTransformer3DModel.from_config(transformer_config)
        else:
            transformer_config["in_channels"] = 16

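Side note on these values: together with CogVideoX's patch size of 2 (assumed here, not shown in this hunk), the 300x300 sample size yields the 150x150 maximum spatial grid that the 1.5 "slice" RoPE path in the pipeline hunk below receives as max_size:

# How the 300x300 sample size feeds the RoPE base grid (patch_size = 2 assumed).
patch_size = 2
sample_height = sample_width = 300
max_size = (sample_height // patch_size, sample_width // patch_size)
print(max_size)  # (150, 150)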
@@ -26,9 +26,10 @@ from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
 from diffusers.utils import logging
 from diffusers.utils.torch_utils import randn_tensor
 from diffusers.video_processor import VideoProcessor
-from diffusers.models.embeddings import get_3d_rotary_pos_embed
+#from diffusers.models.embeddings import get_3d_rotary_pos_embed
 from diffusers.loaders import CogVideoXLoraLoaderMixin

+from .embeddings import get_3d_rotary_pos_embed
 from .custom_cogvideox_transformer_3d import CogVideoXTransformer3DModel

 from comfy.utils import ProgressBar
@@ -293,21 +294,36 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
         grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)

         p = self.transformer.config.patch_size
-        p_t = self.transformer.config.patch_size_t or 1
+        p_t = self.transformer.config.patch_size_t

-        base_size_width = self.transformer.config.sample_width // p
-        base_size_height = self.transformer.config.sample_height // p
-        base_num_frames = (num_frames + p_t - 1) // p_t
-
-        grid_crops_coords = get_resize_crop_region_for_grid(
-            (grid_height, grid_width), base_size_width, base_size_height
-        )
-        freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
-            embed_dim=self.transformer.config.attention_head_dim,
-            crops_coords=grid_crops_coords,
-            grid_size=(grid_height, grid_width),
-            temporal_size=base_num_frames
-        )
+        if p_t is None:
+            # CogVideoX 1.0 I2V
+            base_size_width = self.transformer.config.sample_width // p
+            base_size_height = self.transformer.config.sample_height // p
+
+            grid_crops_coords = get_resize_crop_region_for_grid(
+                (grid_height, grid_width), base_size_width, base_size_height
+            )
+            freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+                embed_dim=self.transformer.config.attention_head_dim,
+                crops_coords=grid_crops_coords,
+                grid_size=(grid_height, grid_width),
+                temporal_size=num_frames,
+            )
+        else:
+            # CogVideoX 1.5 I2V
+            base_size_width = self.transformer.config.sample_width // p
+            base_size_height = self.transformer.config.sample_height // p
+            base_num_frames = (num_frames + p_t - 1) // p_t
+
+            freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+                embed_dim=self.transformer.config.attention_head_dim,
+                crops_coords=None,
+                grid_size=(grid_height, grid_width),
+                temporal_size=base_num_frames,
+                grid_type="slice",
+                max_size=(base_size_height, base_size_width),
+            )

         freqs_cos = freqs_cos.to(device=device)
         freqs_sin = freqs_sin.to(device=device)
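The `(num_frames + p_t - 1) // p_t` line in the new 1.5 branch is ceiling division, rounding the latent frame count up to a whole number of temporal patches; a tiny illustration with assumed values:

# Ceiling division for temporal patching (illustrative values, p_t = 2 as in 1.5).
p_t = 2
for num_frames in (12, 13):
    base_num_frames = (num_frames + p_t - 1) // p_t
    print(num_frames, "->", base_num_frames)  # 12 -> 6, 13 -> 7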
@@ -532,7 +548,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):

         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

-        # 7. context schedule and temporal tiling
+        # 7. context schedule
         if context_schedule is not None:
             if image_cond_latents is not None:
                 raise NotImplementedError("Context schedule not currently supported with image conditioning")
@@ -544,7 +560,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):

         else:
             use_context_schedule = False
-            logger.info("Temporal tiling and context schedule disabled")
+            logger.info("Context schedule disabled")
         # 7.5. Create rotary embeds if required
         image_rotary_emb = (
             self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)

@@ -2,4 +2,5 @@ huggingface_hub
 diffusers>=0.31.0
 accelerate>=0.33.0
 einops
-peft
+peft
+opencv-python