diff --git a/embeddings.py b/embeddings.py
index 908c67f..111ba04 100644
--- a/embeddings.py
+++ b/embeddings.py
@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 import numpy as np
 from typing import Tuple, Union, Optional
-from diffusers.models.embeddings import get_3d_sincos_pos_embed
+from diffusers.models.embeddings import get_3d_sincos_pos_embed, get_1d_rotary_pos_embed


 class CogVideoXPatchEmbed(nn.Module):
@@ -131,4 +131,96 @@ class CogVideoXPatchEmbed(nn.Module):
             embeds = embeds + pos_embedding

         return embeds
-        
\ No newline at end of file
+
+def get_3d_rotary_pos_embed(
+    embed_dim,
+    crops_coords,
+    grid_size,
+    temporal_size,
+    theta: int = 10000,
+    use_real: bool = True,
+    grid_type: str = "linspace",
+    max_size: Optional[Tuple[int, int]] = None,
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    """
+    RoPE for video tokens with 3D structure.
+
+    Args:
+        embed_dim (`int`):
+            The embedding dimension size, corresponding to hidden_size_head.
+        crops_coords (`Tuple[int]`):
+            The top-left and bottom-right coordinates of the crop.
+        grid_size (`Tuple[int]`):
+            The grid size of the spatial positional embedding (height, width).
+        temporal_size (`int`):
+            The size of the temporal dimension.
+        theta (`float`):
+            Scaling factor for frequency computation.
+        grid_type (`str`):
+            Whether to use "linspace" or "slice" to compute grids.
+
+    Returns:
+        Tuple of `torch.Tensor`: the cosine and sine rotary embeddings, each with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim)`.
+    """
+    if use_real is not True:
+        raise ValueError("`use_real = False` is not currently supported for get_3d_rotary_pos_embed")
+
+    if grid_type == "linspace":
+        start, stop = crops_coords
+        grid_size_h, grid_size_w = grid_size
+        grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
+        grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
+        grid_t = np.arange(temporal_size, dtype=np.float32)
+        grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
+    elif grid_type == "slice":
+        max_h, max_w = max_size
+        grid_size_h, grid_size_w = grid_size
+        grid_h = np.arange(max_h, dtype=np.float32)
+        grid_w = np.arange(max_w, dtype=np.float32)
+        grid_t = np.arange(temporal_size, dtype=np.float32)
+    else:
+        raise ValueError("Invalid value passed for `grid_type`.")
+
+    # Compute dimensions for each axis
+    dim_t = embed_dim // 4
+    dim_h = embed_dim // 8 * 3
+    dim_w = embed_dim // 8 * 3
+
+    # Temporal frequencies
+    freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True)
+    # Spatial frequencies for height and width
+    freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True)
+    freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True)
+
+    # Broadcast and concatenate the temporal and spatial frequencies (height and width) into a 3D tensor
+    def combine_time_height_width(freqs_t, freqs_h, freqs_w):
+        freqs_t = freqs_t[:, None, None, :].expand(
+            -1, grid_size_h, grid_size_w, -1
+        )  # temporal_size, grid_size_h, grid_size_w, dim_t
+        freqs_h = freqs_h[None, :, None, :].expand(
+            temporal_size, -1, grid_size_w, -1
+        )  # temporal_size, grid_size_h, grid_size_w, dim_h
+        freqs_w = freqs_w[None, None, :, :].expand(
+            temporal_size, grid_size_h, -1, -1
+        )  # temporal_size, grid_size_h, grid_size_w, dim_w
+
+        freqs = torch.cat(
+            [freqs_t, freqs_h, freqs_w], dim=-1
+        )  # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w)
+        freqs = freqs.view(
+            temporal_size * grid_size_h * grid_size_w, -1
+        )  # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w)
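
For reference, the new get_3d_rotary_pos_embed can be exercised on its own to sanity-check the shapes produced by the grid_type="slice" path used for CogVideoX 1.5. The sketch below is not part of the patch; it assumes embeddings.py imports cleanly outside ComfyUI, and the head dimension, grid sizes, and frame count are illustrative values only.

# Minimal sketch, not part of the patch: shape check for the "slice" RoPE path.
from embeddings import get_3d_rotary_pos_embed

attention_head_dim = 64        # illustrative per-head dim; must be divisible by 8 here
grid_h, grid_w = 30, 45        # latent grid for a 480x720 input with VAE factor 8 and patch size 2
base_num_frames = 7            # e.g. ceil(13 frames / patch_size_t) with patch_size_t = 2

cos, sin = get_3d_rotary_pos_embed(
    embed_dim=attention_head_dim,
    crops_coords=None,                 # ignored when grid_type="slice"
    grid_size=(grid_h, grid_w),
    temporal_size=base_num_frames,
    grid_type="slice",
    max_size=(150, 150),               # sample_height // p, sample_width // p with the new 300 / 2 values
)
assert cos.shape == (base_num_frames * grid_h * grid_w, attention_head_dim)
assert sin.shape == cos.shape
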
+        return freqs
+
+    t_cos, t_sin = freqs_t  # both t_cos and t_sin have shape: temporal_size, dim_t
+    h_cos, h_sin = freqs_h  # both h_cos and h_sin have shape: grid_size_h, dim_h
+    w_cos, w_sin = freqs_w  # both w_cos and w_sin have shape: grid_size_w, dim_w
+
+    if grid_type == "slice":
+        t_cos, t_sin = t_cos[:temporal_size], t_sin[:temporal_size]
+        h_cos, h_sin = h_cos[:grid_size_h], h_sin[:grid_size_h]
+        w_cos, w_sin = w_cos[:grid_size_w], w_sin[:grid_size_w]
+
+    cos = combine_time_height_width(t_cos, h_cos, w_cos)
+    sin = combine_time_height_width(t_sin, h_sin, w_sin)
+    return cos, sin
\ No newline at end of file
diff --git a/model_loading.py b/model_loading.py
index 532d6aa..45b6c1b 100644
--- a/model_loading.py
+++ b/model_loading.py
@@ -182,16 +182,19 @@ class DownloadAndLoadCogVideoModel:
                 local_dir_use_symlinks=False,
             )


-        # transformer
+        #transformer
         if "Fun" in model:
             transformer = CogVideoXTransformer3DModelFun.from_pretrained(base_path, subfolder=subfolder)
         else:
             transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder=subfolder)

         transformer = transformer.to(dtype).to(transformer_load_device)
-        transformer.attention_mode = attention_mode

+        if "1.5" in model:
+            transformer.config.sample_height = 300
+            transformer.config.sample_width = 300
+
         if block_edit is not None:
             transformer = remove_specific_blocks(transformer, block_edit)

@@ -199,7 +202,7 @@ class DownloadAndLoadCogVideoModel:
             scheduler_config = json.load(f)
         scheduler = CogVideoXDDIMScheduler.from_config(scheduler_config)

-        # VAE
+        #VAE
         if "Fun" in model:
             vae = AutoencoderKLCogVideoXFun.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)
             if "Pose" in model:
@@ -393,8 +396,8 @@ class DownloadAndLoadCogVideoGGUFModel:
             transformer_config["use_learned_positional_embeddings"] = False
             transformer_config["patch_size_t"] = 2
             transformer_config["patch_bias"] = False
-            transformer_config["sample_height"] = 96
-            transformer_config["sample_width"] = 170
+            transformer_config["sample_height"] = 300
+            transformer_config["sample_width"] = 300
             transformer = CogVideoXTransformer3DModel.from_config(transformer_config)
         else:
             transformer_config["in_channels"] = 16
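
With the slice-based RoPE above, spatial frequencies are precomputed for up to sample_size // patch_size positions and then sliced down to the actual latent grid, so the 300x300 config values set the largest grid the rotary tables can cover. A rough sketch of that arithmetic, assuming the usual CogVideoX VAE spatial factor of 8 and spatial patch size 2 (both assumptions here, not read from the configs):

# Rough sketch, not part of the patch; all constants below are assumed, not loaded from a config.
vae_scale_factor_spatial = 8            # assumed CogVideoX VAE spatial downscale
patch_size = 2                          # assumed spatial patch size of the 1.5 transformer

sample_height = sample_width = 300      # values set by this patch (latent resolution)
max_grid = sample_height // patch_size  # 150 precomputed RoPE positions per spatial axis

# e.g. a 1360x768 generation needs 1360 // (8 * 2) = 85 and 768 // (8 * 2) = 48 positions,
# both well below 150, so the sliced tables cover it.
print(max_grid, 1360 // (vae_scale_factor_spatial * patch_size), 768 // (vae_scale_factor_spatial * patch_size))
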
diff --git a/pipeline_cogvideox.py b/pipeline_cogvideox.py
index 09e9103..472a308 100644
--- a/pipeline_cogvideox.py
+++ b/pipeline_cogvideox.py
@@ -26,9 +26,10 @@ from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
 from diffusers.utils import logging
 from diffusers.utils.torch_utils import randn_tensor
 from diffusers.video_processor import VideoProcessor
-from diffusers.models.embeddings import get_3d_rotary_pos_embed
+#from diffusers.models.embeddings import get_3d_rotary_pos_embed
 from diffusers.loaders import CogVideoXLoraLoaderMixin

+from .embeddings import get_3d_rotary_pos_embed
 from .custom_cogvideox_transformer_3d import CogVideoXTransformer3DModel
 from comfy.utils import ProgressBar

@@ -293,21 +294,36 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
         grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)

         p = self.transformer.config.patch_size
-        p_t = self.transformer.config.patch_size_t or 1
+        p_t = self.transformer.config.patch_size_t

-        base_size_width = self.transformer.config.sample_width // p
-        base_size_height = self.transformer.config.sample_height // p
-        base_num_frames = (num_frames + p_t - 1) // p_t
-
-        grid_crops_coords = get_resize_crop_region_for_grid(
-            (grid_height, grid_width), base_size_width, base_size_height
-        )
-        freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
-            embed_dim=self.transformer.config.attention_head_dim,
-            crops_coords=grid_crops_coords,
-            grid_size=(grid_height, grid_width),
-            temporal_size=base_num_frames
-        )
+        if p_t is None:
+            # CogVideoX 1.0 I2V
+            base_size_width = self.transformer.config.sample_width // p
+            base_size_height = self.transformer.config.sample_height // p
+
+            grid_crops_coords = get_resize_crop_region_for_grid(
+                (grid_height, grid_width), base_size_width, base_size_height
+            )
+            freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+                embed_dim=self.transformer.config.attention_head_dim,
+                crops_coords=grid_crops_coords,
+                grid_size=(grid_height, grid_width),
+                temporal_size=num_frames,
+            )
+        else:
+            # CogVideoX 1.5 I2V
+            base_size_width = self.transformer.config.sample_width // p
+            base_size_height = self.transformer.config.sample_height // p
+            base_num_frames = (num_frames + p_t - 1) // p_t
+
+            freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+                embed_dim=self.transformer.config.attention_head_dim,
+                crops_coords=None,
+                grid_size=(grid_height, grid_width),
+                temporal_size=base_num_frames,
+                grid_type="slice",
+                max_size=(base_size_height, base_size_width),
+            )

         freqs_cos = freqs_cos.to(device=device)
         freqs_sin = freqs_sin.to(device=device)
@@ -532,7 +548,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):

         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

-        # 7. context schedule and temporal tiling
+        # 7. context schedule
         if context_schedule is not None:
             if image_cond_latents is not None:
                 raise NotImplementedError("Context schedule not currently supported with image conditioning")
@@ -544,7 +560,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):

         else:
             use_context_schedule = False
-            logger.info("Temporal tiling and context schedule disabled")
+            logger.info("Context schedule disabled")
         # 7.5. Create rotary embeds if required
         image_rotary_emb = (
             self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
diff --git a/requirements.txt b/requirements.txt
index 2b24b6a..8ab8109 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,4 +2,5 @@ huggingface_hub
 diffusers>=0.31.0
 accelerate>=0.33.0
 einops
-peft
\ No newline at end of file
+peft
+opencv-python
\ No newline at end of file
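
The branching that _prepare_rotary_positional_embeddings now performs can be summarized as: models without patch_size_t (CogVideoX 1.0) keep the original crop-based "linspace" grid over the raw frame count, while models with temporal patching (CogVideoX 1.5) round the frame count up to whole temporal patches and use the "slice" grid bounded by max_size. The small sketch below only mirrors that decision; pick_rope_path is an illustrative name, not part of the codebase.

# Illustrative sketch, not part of the patch; mirrors the new branch on patch_size_t.
def pick_rope_path(patch_size_t, num_frames):
    if patch_size_t is None:
        # CogVideoX 1.0: no temporal patching, crop-based "linspace" grid
        return "linspace", num_frames
    # CogVideoX 1.5: frames grouped into temporal patches, "slice" grid over max_size
    base_num_frames = (num_frames + patch_size_t - 1) // patch_size_t
    return "slice", base_num_frames

print(pick_rope_path(None, 13))  # ('linspace', 13)
print(pick_rope_path(2, 13))     # ('slice', 7)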