Mirror of https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git (synced 2025-12-10 05:14:22 +08:00)

parent 832dad94bc
commit 75e98906a3
@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 import numpy as np
 from typing import Tuple, Union, Optional
-from diffusers.models.embeddings import get_3d_sincos_pos_embed
+from diffusers.models.embeddings import get_3d_sincos_pos_embed, get_1d_rotary_pos_embed


 class CogVideoXPatchEmbed(nn.Module):
@@ -132,3 +132,95 @@ class CogVideoXPatchEmbed(nn.Module):

         return embeds

+
+def get_3d_rotary_pos_embed(
+    embed_dim,
+    crops_coords,
+    grid_size,
+    temporal_size,
+    theta: int = 10000,
+    use_real: bool = True,
+    grid_type: str = "linspace",
+    max_size: Optional[Tuple[int, int]] = None,
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    """
+    RoPE for video tokens with 3D structure.
+
+    Args:
+        embed_dim (`int`):
+            The embedding dimension size, corresponding to hidden_size_head.
+        crops_coords (`Tuple[int]`):
+            The top-left and bottom-right coordinates of the crop.
+        grid_size (`Tuple[int]`):
+            The grid size of the spatial positional embedding (height, width).
+        temporal_size (`int`):
+            The size of the temporal dimension.
+        theta (`float`):
+            Scaling factor for frequency computation.
+        grid_type (`str`):
+            Whether to use "linspace" or "slice" to compute grids.
+        max_size (`Tuple[int, int]`, *optional*):
+            The maximum spatial grid size (height, width); required when `grid_type` is "slice".
+
+    Returns:
+        `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
+    """
+    if use_real is not True:
+        raise ValueError("`use_real = False` is not currently supported for get_3d_rotary_pos_embed")
+
+    if grid_type == "linspace":
+        start, stop = crops_coords
+        grid_size_h, grid_size_w = grid_size
+        grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
+        grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
+        grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
+    elif grid_type == "slice":
+        max_h, max_w = max_size
+        grid_size_h, grid_size_w = grid_size
+        grid_h = np.arange(max_h, dtype=np.float32)
+        grid_w = np.arange(max_w, dtype=np.float32)
+        grid_t = np.arange(temporal_size, dtype=np.float32)
+    else:
+        raise ValueError("Invalid value passed for `grid_type`.")
+
+    # Compute dimensions for each axis
+    dim_t = embed_dim // 4
+    dim_h = embed_dim // 8 * 3
+    dim_w = embed_dim // 8 * 3
+
+    # Temporal frequencies
+    freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True)
+    # Spatial frequencies for height and width
+    freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True)
+    freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True)
+
+    # Broadcast and concatenate temporal and spatial frequencies (height and width) into a 3d tensor
+    def combine_time_height_width(freqs_t, freqs_h, freqs_w):
+        freqs_t = freqs_t[:, None, None, :].expand(
+            -1, grid_size_h, grid_size_w, -1
+        )  # temporal_size, grid_size_h, grid_size_w, dim_t
+        freqs_h = freqs_h[None, :, None, :].expand(
+            temporal_size, -1, grid_size_w, -1
+        )  # temporal_size, grid_size_h, grid_size_w, dim_h
+        freqs_w = freqs_w[None, None, :, :].expand(
+            temporal_size, grid_size_h, -1, -1
+        )  # temporal_size, grid_size_h, grid_size_w, dim_w
+
+        freqs = torch.cat(
+            [freqs_t, freqs_h, freqs_w], dim=-1
+        )  # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w)
+        freqs = freqs.view(
+            temporal_size * grid_size_h * grid_size_w, -1
+        )  # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w)
+        return freqs
+
+    t_cos, t_sin = freqs_t  # both t_cos and t_sin have shape: temporal_size, dim_t
+    h_cos, h_sin = freqs_h  # both h_cos and h_sin have shape: grid_size_h, dim_h
+    w_cos, w_sin = freqs_w  # both w_cos and w_sin have shape: grid_size_w, dim_w
+
+    if grid_type == "slice":
+        t_cos, t_sin = t_cos[:temporal_size], t_sin[:temporal_size]
+        h_cos, h_sin = h_cos[:grid_size_h], h_sin[:grid_size_h]
+        w_cos, w_sin = w_cos[:grid_size_w], w_sin[:grid_size_w]

+    cos = combine_time_height_width(t_cos, h_cos, w_cos)
+    sin = combine_time_height_width(t_sin, h_sin, w_sin)
+    return cos, sin
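For orientation, here is a minimal usage sketch of the helper added above. The sizes and the `embeddings` import path are hypothetical, not values the commit prescribes; the head dimension is split 1/4 for time and 3/8 each for height and width, and "slice" mode builds frequencies over the full `max_size` grid before slicing them down to `grid_size`.

# Hypothetical sizes; assumes the module above is importable as `embeddings`.
from embeddings import get_3d_rotary_pos_embed

cos, sin = get_3d_rotary_pos_embed(
    embed_dim=64,        # split as 16 (time) + 24 (height) + 24 (width)
    crops_coords=None,   # only read by the "linspace" grid type
    grid_size=(30, 45),  # latent patch grid: height x width
    temporal_size=13,    # latent frames
    grid_type="slice",
    max_size=(48, 85),   # full base grid; frequencies are sliced down to grid_size
)
print(cos.shape, sin.shape)  # both torch.Size([17550, 64]), i.e. (13 * 30 * 45, 64)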
@@ -182,16 +182,19 @@ class DownloadAndLoadCogVideoModel:
                 local_dir_use_symlinks=False,
             )

-        # transformer
+        #transformer
         if "Fun" in model:
             transformer = CogVideoXTransformer3DModelFun.from_pretrained(base_path, subfolder=subfolder)
         else:
             transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder=subfolder)

         transformer = transformer.to(dtype).to(transformer_load_device)

         transformer.attention_mode = attention_mode

+        if "1.5" in model:
+            transformer.config.sample_height = 300
+            transformer.config.sample_width = 300

         if block_edit is not None:
             transformer = remove_specific_blocks(transformer, block_edit)

@@ -199,7 +202,7 @@ class DownloadAndLoadCogVideoModel:
             scheduler_config = json.load(f)
         scheduler = CogVideoXDDIMScheduler.from_config(scheduler_config)

-        # VAE
+        #VAE
         if "Fun" in model:
             vae = AutoencoderKLCogVideoXFun.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)
             if "Pose" in model:
@@ -393,8 +396,8 @@ class DownloadAndLoadCogVideoGGUFModel:
             transformer_config["use_learned_positional_embeddings"] = False
             transformer_config["patch_size_t"] = 2
             transformer_config["patch_bias"] = False
-            transformer_config["sample_height"] = 96
-            transformer_config["sample_width"] = 170
+            transformer_config["sample_height"] = 300
+            transformer_config["sample_width"] = 300
             transformer = CogVideoXTransformer3DModel.from_config(transformer_config)
         else:
             transformer_config["in_channels"] = 16
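The two `sample_*` bumps above mirror the non-GGUF loader: for 1.5 models the pipeline divides these values by the spatial patch size to get the maximum RoPE base grid that "slice" mode can serve. A rough sketch of that relationship, with hypothetical render sizes (variable names follow the pipeline code):

# Hypothetical numbers: how sample_height/sample_width bound the RoPE grid.
sample_height, sample_width = 300, 300  # transformer config (CogVideoX 1.5)
p = 2                                   # spatial patch size
vae_scale_factor_spatial = 8

base_size_height = sample_height // p   # 150 = max grid rows for "slice" mode
base_size_width = sample_width // p     # 150 = max grid cols

height, width = 768, 1360               # e.g. a 768x1360 render
grid_height = height // (vae_scale_factor_spatial * p)  # 48
grid_width = width // (vae_scale_factor_spatial * p)    # 85
assert grid_height <= base_size_height and grid_width <= base_size_width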
@@ -26,9 +26,10 @@ from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
 from diffusers.utils import logging
 from diffusers.utils.torch_utils import randn_tensor
 from diffusers.video_processor import VideoProcessor
-from diffusers.models.embeddings import get_3d_rotary_pos_embed
+#from diffusers.models.embeddings import get_3d_rotary_pos_embed
 from diffusers.loaders import CogVideoXLoraLoaderMixin

+from .embeddings import get_3d_rotary_pos_embed
 from .custom_cogvideox_transformer_3d import CogVideoXTransformer3DModel

 from comfy.utils import ProgressBar
@@ -293,11 +294,12 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
         grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)

         p = self.transformer.config.patch_size
-        p_t = self.transformer.config.patch_size_t or 1
+        p_t = self.transformer.config.patch_size_t

+        if p_t is None:
+            # CogVideoX 1.0 I2V
             base_size_width = self.transformer.config.sample_width // p
             base_size_height = self.transformer.config.sample_height // p
-        base_num_frames = (num_frames + p_t - 1) // p_t

             grid_crops_coords = get_resize_crop_region_for_grid(
                 (grid_height, grid_width), base_size_width, base_size_height
@@ -306,7 +308,21 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
                 embed_dim=self.transformer.config.attention_head_dim,
                 crops_coords=grid_crops_coords,
                 grid_size=(grid_height, grid_width),
-                temporal_size=base_num_frames
+                temporal_size=num_frames,
+            )
+        else:
+            # CogVideoX 1.5 I2V
+            base_size_width = self.transformer.config.sample_width // p
+            base_size_height = self.transformer.config.sample_height // p
+            base_num_frames = (num_frames + p_t - 1) // p_t
+
+            freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+                embed_dim=self.transformer.config.attention_head_dim,
+                crops_coords=None,
+                grid_size=(grid_height, grid_width),
+                temporal_size=base_num_frames,
+                grid_type="slice",
+                max_size=(base_size_height, base_size_width),
             )

         freqs_cos = freqs_cos.to(device=device)
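Condensed, the dispatch above keys off `patch_size_t`: CogVideoX 1.0 checkpoints have no temporal patching (`patch_size_t` is None) and keep the cropped "linspace" RoPE with one position per latent frame, while 1.5 checkpoints group frames by `p_t` and switch to the "slice" grid. A simplified restatement follows (not a drop-in replacement; config access is shortened):

# Sketch of the branch above; cfg stands in for self.transformer.config.
def rope_grid_args(cfg, num_frames):
    p, p_t = cfg.patch_size, cfg.patch_size_t
    if p_t is None:
        # CogVideoX 1.0: cropped linspace grid over all latent frames
        return {"grid_type": "linspace", "temporal_size": num_frames}
    # CogVideoX 1.5: slice grid over the full base size, frames grouped by p_t
    return {
        "grid_type": "slice",
        "temporal_size": (num_frames + p_t - 1) // p_t,  # ceil division
        "max_size": (cfg.sample_height // p, cfg.sample_width // p),
    }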
@@ -532,7 +548,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):

         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

-        # 7. context schedule and temporal tiling
+        # 7. context schedule
         if context_schedule is not None:
             if image_cond_latents is not None:
                 raise NotImplementedError("Context schedule not currently supported with image conditioning")
@@ -544,7 +560,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):

         else:
             use_context_schedule = False
-            logger.info("Temporal tiling and context schedule disabled")
+            logger.info("Context schedule disabled")
         # 7.5. Create rotary embeds if required
         image_rotary_emb = (
             self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
@@ -3,3 +3,4 @@ diffusers>=0.31.0
 accelerate>=0.33.0
 einops
 peft
+opencv-python