rotary embed fix

kijai 2024-11-15 02:37:44 +02:00
parent 832dad94bc
commit 75e98906a3
4 changed files with 137 additions and 25 deletions

View File

@@ -2,7 +2,7 @@ import torch
import torch.nn as nn
import numpy as np
from typing import Tuple, Union, Optional
-from diffusers.models.embeddings import get_3d_sincos_pos_embed
+from diffusers.models.embeddings import get_3d_sincos_pos_embed, get_1d_rotary_pos_embed
class CogVideoXPatchEmbed(nn.Module):
@@ -132,3 +132,95 @@ class CogVideoXPatchEmbed(nn.Module):
return embeds
def get_3d_rotary_pos_embed(
embed_dim,
crops_coords,
grid_size,
temporal_size,
theta: int = 10000,
use_real: bool = True,
grid_type: str = "linspace",
max_size: Optional[Tuple[int, int]] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
RoPE for video tokens with 3D structure.
Args:
embed_dim: (`int`):
The embedding dimension size, corresponding to hidden_size_head.
crops_coords (`Tuple[int]`):
The top-left and bottom-right coordinates of the crop.
grid_size (`Tuple[int]`):
The grid size of the spatial positional embedding (height, width).
temporal_size (`int`):
The size of the temporal dimension.
theta (`float`):
Scaling factor for frequency computation.
grid_type (`str`):
Whether to use "linspace" or "slice" to compute grids.
Returns:
`torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
"""
if use_real is not True:
raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed")
if grid_type == "linspace":
start, stop = crops_coords
grid_size_h, grid_size_w = grid_size
grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
grid_t = np.arange(temporal_size, dtype=np.float32)
grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
elif grid_type == "slice":
max_h, max_w = max_size
grid_size_h, grid_size_w = grid_size
grid_h = np.arange(max_h, dtype=np.float32)
grid_w = np.arange(max_w, dtype=np.float32)
grid_t = np.arange(temporal_size, dtype=np.float32)
else:
raise ValueError("Invalid value passed for `grid_type`.")
# Compute dimensions for each axis
dim_t = embed_dim // 4
dim_h = embed_dim // 8 * 3
dim_w = embed_dim // 8 * 3
# Temporal frequencies
freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True)
# Spatial frequencies for height and width
freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True)
freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True)
# Broadcast and concatenate temporal and spatial frequencies (height and width) into a 3D tensor
def combine_time_height_width(freqs_t, freqs_h, freqs_w):
freqs_t = freqs_t[:, None, None, :].expand(
-1, grid_size_h, grid_size_w, -1
) # temporal_size, grid_size_h, grid_size_w, dim_t
freqs_h = freqs_h[None, :, None, :].expand(
temporal_size, -1, grid_size_w, -1
) # temporal_size, grid_size_h, grid_size_w, dim_h
freqs_w = freqs_w[None, None, :, :].expand(
temporal_size, grid_size_h, -1, -1
) # temporal_size, grid_size_h, grid_size_w, dim_w
freqs = torch.cat(
[freqs_t, freqs_h, freqs_w], dim=-1
) # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w)
freqs = freqs.view(
temporal_size * grid_size_h * grid_size_w, -1
) # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w)
return freqs
t_cos, t_sin = freqs_t # both t_cos and t_sin have shape: temporal_size, dim_t
h_cos, h_sin = freqs_h # both h_cos and h_sin have shape: grid_size_h, dim_h
w_cos, w_sin = freqs_w # both w_cos and w_sin have shape: grid_size_w, dim_w
if grid_type == "slice":
t_cos, t_sin = t_cos[:temporal_size], t_sin[:temporal_size]
h_cos, h_sin = h_cos[:grid_size_h], h_sin[:grid_size_h]
w_cos, w_sin = w_cos[:grid_size_w], w_sin[:grid_size_w]
cos = combine_time_height_width(t_cos, h_cos, w_cos)
sin = combine_time_height_width(t_sin, h_sin, w_sin)
return cos, sin
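
For orientation, a minimal sketch of how the function above could be exercised on its own. The sizes are made up for illustration, the import is assumed to come from the wrapper's own embeddings module added in this commit, and get_1d_rotary_pos_embed must be available from diffusers>=0.31.0.

# illustrative only: tiny grid, CogVideoX-style head dim of 64 (split 16 + 24 + 24)
cos, sin = get_3d_rotary_pos_embed(
    embed_dim=64,
    crops_coords=((0, 0), (30, 45)),  # top-left / bottom-right of the crop region, in patches
    grid_size=(30, 45),               # latent height x width, in patches
    temporal_size=13,                 # latent frames
    grid_type="linspace",
)
print(cos.shape, sin.shape)  # both torch.Size([17550, 64]) = (frames * height * width patches, head_dim)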

View File

@@ -182,16 +182,19 @@ class DownloadAndLoadCogVideoModel:
local_dir_use_symlinks=False,
)
-# transformer
+#transformer
if "Fun" in model:
transformer = CogVideoXTransformer3DModelFun.from_pretrained(base_path, subfolder=subfolder)
else:
transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder=subfolder)
transformer = transformer.to(dtype).to(transformer_load_device)
transformer.attention_mode = attention_mode
if "1.5" in model:
transformer.config.sample_height = 300
transformer.config.sample_width = 300
if block_edit is not None:
transformer = remove_specific_blocks(transformer, block_edit)
@@ -199,7 +202,7 @@ class DownloadAndLoadCogVideoModel:
scheduler_config = json.load(f)
scheduler = CogVideoXDDIMScheduler.from_config(scheduler_config)
-# VAE
+#VAE
if "Fun" in model:
vae = AutoencoderKLCogVideoXFun.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)
if "Pose" in model:
@@ -393,8 +396,8 @@ class DownloadAndLoadCogVideoGGUFModel:
transformer_config["use_learned_positional_embeddings"] = False
transformer_config["patch_size_t"] = 2
transformer_config["patch_bias"] = False
-transformer_config["sample_height"] = 96
-transformer_config["sample_width"] = 170
+transformer_config["sample_height"] = 300
+transformer_config["sample_width"] = 300
transformer = CogVideoXTransformer3DModel.from_config(transformer_config)
else:
transformer_config["in_channels"] = 16

View File

@@ -26,9 +26,10 @@ from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
from diffusers.utils import logging
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor
-from diffusers.models.embeddings import get_3d_rotary_pos_embed
+#from diffusers.models.embeddings import get_3d_rotary_pos_embed
from diffusers.loaders import CogVideoXLoraLoaderMixin
from .embeddings import get_3d_rotary_pos_embed
from .custom_cogvideox_transformer_3d import CogVideoXTransformer3DModel
from comfy.utils import ProgressBar
@@ -293,11 +294,12 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
p = self.transformer.config.patch_size
-p_t = self.transformer.config.patch_size_t or 1
+p_t = self.transformer.config.patch_size_t
if p_t is None:
# CogVideoX 1.0 I2V
base_size_width = self.transformer.config.sample_width // p
base_size_height = self.transformer.config.sample_height // p
base_num_frames = (num_frames + p_t - 1) // p_t
grid_crops_coords = get_resize_crop_region_for_grid(
(grid_height, grid_width), base_size_width, base_size_height
@@ -306,7 +308,21 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
embed_dim=self.transformer.config.attention_head_dim,
crops_coords=grid_crops_coords,
grid_size=(grid_height, grid_width),
-temporal_size=base_num_frames
+temporal_size=num_frames,
)
else:
# CogVideoX 1.5 I2V
base_size_width = self.transformer.config.sample_width // p
base_size_height = self.transformer.config.sample_height // p
base_num_frames = (num_frames + p_t - 1) // p_t
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
embed_dim=self.transformer.config.attention_head_dim,
crops_coords=None,
grid_size=(grid_height, grid_width),
temporal_size=base_num_frames,
grid_type="slice",
max_size=(base_size_height, base_size_width),
)
freqs_cos = freqs_cos.to(device=device)
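
To make the two branches concrete: when patch_size_t is unset (CogVideoX 1.0) the crop-based "linspace" grid is kept, while 1.5 models (patch_size_t set) build the frequency table for the full base grid and slice it down to the actual latent grid. A small worked example of the bookkeeping, with illustrative numbers (768x768 output, 21 latent frames) rather than values from the commit:

# illustrative numbers only
height, width = 768, 768
num_latent_frames = 21     # e.g. (81 video frames - 1) // 4 + 1 after the causal VAE
vae_scale_spatial = 8
p, p_t = 2, 2              # patch_size, patch_size_t for a 1.5 model

grid_height = height // (vae_scale_spatial * p)            # 48
grid_width = width // (vae_scale_spatial * p)              # 48
base_num_frames = (num_latent_frames + p_t - 1) // p_t     # 11

# 1.5 branch then calls get_3d_rotary_pos_embed(..., grid_size=(48, 48), temporal_size=11,
#                                                grid_type="slice", max_size=(150, 150))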
@@ -532,7 +548,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
-# 7. context schedule and temporal tiling
+# 7. context schedule
if context_schedule is not None:
if image_cond_latents is not None:
raise NotImplementedError("Context schedule not currently supported with image conditioning")
@@ -544,7 +560,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
else:
use_context_schedule = False
-logger.info("Temporal tiling and context schedule disabled")
+logger.info("Context schedule disabled")
# 7.5. Create rotary embeds if required
image_rotary_emb = (
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)

View File

@@ -3,3 +3,4 @@ diffusers>=0.31.0
accelerate>=0.33.0
einops
peft
opencv-python