From 9aab678a9eba983f15189cf05b7a09a8ea5e2f13 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com>
Date: Sat, 9 Nov 2024 03:15:21 +0200
Subject: [PATCH] Re-enable SageAttention fallback, drop debug NaN checks, use
 diffusers embedding helpers

Removes the temporary NaN/Inf debugging from the attention processor,
restores the SageAttention-or-SDPA dispatch, replaces the local copies of
the rotary/sincos embedding helpers with the diffusers implementations,
and comments out a leftover per-block debug print.
---
 custom_cogvideox_transformer_3d.py |  38 ++---
 embeddings.py                      | 233 +----------------------
 pipeline_cogvideox.py              |   3 +-
 3 files changed, 16 insertions(+), 258 deletions(-)

diff --git a/custom_cogvideox_transformer_3d.py b/custom_cogvideox_transformer_3d.py
index ebad39e..c751d13 100644
--- a/custom_cogvideox_transformer_3d.py
+++ b/custom_cogvideox_transformer_3d.py
@@ -109,37 +109,28 @@ class CogVideoXAttnProcessor2_0:
         if attn.norm_k is not None:
             key = attn.norm_k(key)
 
+        # Apply RoPE if needed
         if image_rotary_emb is not None:
             from diffusers.models.embeddings import apply_rotary_emb
-            has_nan = torch.isnan(query).any()
-            if has_nan:
-                raise ValueError(f"query before rope has nan: {has_nan}")
-            query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
+            query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
             if not attn.is_cross_attention:
                 key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
 
-        #if SAGEATTN_IS_AVAILABLE:
-        #    hidden_states = sageattn(query, key, value, is_causal=False)
-        #else:
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-        has_nan = torch.isnan(hidden_states).any()
-        if has_nan:
-            raise ValueError(f"hs after scaled_dot_product_attention has nan: {has_nan}")
-        has_inf = torch.isinf(hidden_states).any()
-        if has_inf:
-            raise ValueError(f"hs after scaled_dot_product_attention has inf: {has_inf}")
+        if SAGEATTN_IS_AVAILABLE:
+            hidden_states = sageattn(query, key, value, is_causal=False)
+        else:
+            hidden_states = F.scaled_dot_product_attention(
+                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+            )
 
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
 
         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
-        has_nan = torch.isnan(hidden_states).any()
-
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
+
         encoder_hidden_states, hidden_states = hidden_states.split(
             [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
         )
@@ -322,7 +313,6 @@ class CogVideoXBlock(nn.Module):
         norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
             hidden_states, encoder_hidden_states, temb
         )
-
         # Tora Motion-guidance Fuser
         if video_flow_feature is not None:
             H, W = video_flow_feature.shape[-2:]
@@ -747,8 +737,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                 output = hidden_states.reshape(
                     batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
                 )
-                output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
-                output = output[:, remaining_frames:]
+            output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
+            output = output[:, remaining_frames:]
 
             (bb, tt, cc, hh, ww) = output.shape
             cond = rearrange(output, "B T C H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
@@ -770,7 +760,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
             output = torch.cat([output, recovered_uncond])
         else:
             for i, block in enumerate(self.transformer_blocks):
-                print("block", i)
+                #print("block", i)
                 hidden_states, encoder_hidden_states = block(
                     hidden_states=hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
@@ -820,8 +810,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                 output = hidden_states.reshape(
                     batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
                 )
-                output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
-                output = output[:, remaining_frames:]
+            output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
+            output = output[:, remaining_frames:]
 
             if self.fastercache_counter >= self.fastercache_start_step + 1:
                 (bb, tt, cc, hh, ww) = output.shape
diff --git a/embeddings.py b/embeddings.py
index bc3bf7f..908c67f 100644
--- a/embeddings.py
+++ b/embeddings.py
@@ -2,239 +2,8 @@ import torch
 import torch.nn as nn
 import numpy as np
 from typing import Tuple, Union, Optional
+from diffusers.models.embeddings import get_3d_sincos_pos_embed
 
-def get_1d_rotary_pos_embed(
-    dim: int,
-    pos: Union[np.ndarray, int],
-    theta: float = 10000.0,
-    use_real=False,
-    linear_factor=1.0,
-    ntk_factor=1.0,
-    repeat_interleave_real=True,
-    freqs_dtype=torch.float32,  # torch.float32, torch.float64 (flux)
-):
-    """
-    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
-
-    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end
-    index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64
-    data type.
-
-    Args:
-        dim (`int`): Dimension of the frequency tensor.
-        pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
-        theta (`float`, *optional*, defaults to 10000.0):
-            Scaling factor for frequency computation. Defaults to 10000.0.
-        use_real (`bool`, *optional*):
-            If True, return real part and imaginary part separately. Otherwise, return complex numbers.
-        linear_factor (`float`, *optional*, defaults to 1.0):
-            Scaling factor for the context extrapolation. Defaults to 1.0.
-        ntk_factor (`float`, *optional*, defaults to 1.0):
-            Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
-        repeat_interleave_real (`bool`, *optional*, defaults to `True`):
-            If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
-            Otherwise, they are concateanted with themselves.
-        freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
-            the dtype of the frequency tensor.
-    Returns:
-        `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
-    """
-    assert dim % 2 == 0
-
-    if isinstance(pos, int):
-        pos = torch.arange(pos)
-    if isinstance(pos, np.ndarray):
-        pos = torch.from_numpy(pos)  # type: ignore  # [S]
-
-    theta = theta * ntk_factor
-    freqs = (
-        1.0
-        / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
-        / linear_factor
-    )  # [D/2]
-    freqs = torch.outer(pos, freqs)  # type: ignore  # [S, D/2]
-    if use_real and repeat_interleave_real:
-        # flux, hunyuan-dit, cogvideox
-        freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float()  # [S, D]
-        freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float()  # [S, D]
-        return freqs_cos, freqs_sin
-    elif use_real:
-        # stable audio
-        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()  # [S, D]
-        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()  # [S, D]
-        return freqs_cos, freqs_sin
-    else:
-        # lumina
-        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64  # [S, D/2]
-        return freqs_cis
-
-def get_3d_rotary_pos_embed(
-    embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
-) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-    """
-    RoPE for video tokens with 3D structure.
-
-    Args:
-        embed_dim: (`int`):
-            The embedding dimension size, corresponding to hidden_size_head.
-        crops_coords (`Tuple[int]`):
-            The top-left and bottom-right coordinates of the crop.
-        grid_size (`Tuple[int]`):
-            The grid size of the spatial positional embedding (height, width).
-        temporal_size (`int`):
-            The size of the temporal dimension.
-        theta (`float`):
-            Scaling factor for frequency computation.
-
-    Returns:
-        `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
-    """
-    if use_real is not True:
-        raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed")
-    start, stop = crops_coords
-    grid_size_h, grid_size_w = grid_size
-    grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
-    grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
-    grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
-
-    # Compute dimensions for each axis
-    dim_t = embed_dim // 4
-    dim_h = embed_dim // 8 * 3
-    dim_w = embed_dim // 8 * 3
-
-    # Temporal frequencies
-    freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True)
-    # Spatial frequencies for height and width
-    freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True)
-    freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True)
-
-    # BroadCast and concatenate temporal and spaial frequencie (height and width) into a 3d tensor
-    def combine_time_height_width(freqs_t, freqs_h, freqs_w):
-        freqs_t = freqs_t[:, None, None, :].expand(
-            -1, grid_size_h, grid_size_w, -1
-        )  # temporal_size, grid_size_h, grid_size_w, dim_t
-        freqs_h = freqs_h[None, :, None, :].expand(
-            temporal_size, -1, grid_size_w, -1
-        )  # temporal_size, grid_size_h, grid_size_2, dim_h
-        freqs_w = freqs_w[None, None, :, :].expand(
-            temporal_size, grid_size_h, -1, -1
-        )  # temporal_size, grid_size_h, grid_size_2, dim_w
-
-        freqs = torch.cat(
-            [freqs_t, freqs_h, freqs_w], dim=-1
-        )  # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w)
-        freqs = freqs.view(
-            temporal_size * grid_size_h * grid_size_w, -1
-        )  # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w)
-        return freqs
-
-    t_cos, t_sin = freqs_t  # both t_cos and t_sin has shape: temporal_size, dim_t
-    h_cos, h_sin = freqs_h  # both h_cos and h_sin has shape: grid_size_h, dim_h
-    w_cos, w_sin = freqs_w  # both w_cos and w_sin has shape: grid_size_w, dim_w
-    cos = combine_time_height_width(t_cos, h_cos, w_cos)
-    sin = combine_time_height_width(t_sin, h_sin, w_sin)
-    return cos, sin
-
-def get_3d_sincos_pos_embed(
-    embed_dim: int,
-    spatial_size: Union[int, Tuple[int, int]],
-    temporal_size: int,
-    spatial_interpolation_scale: float = 1.0,
-    temporal_interpolation_scale: float = 1.0,
-) -> np.ndarray:
-    r"""
-    Args:
-        embed_dim (`int`):
-        spatial_size (`int` or `Tuple[int, int]`):
-        temporal_size (`int`):
-        spatial_interpolation_scale (`float`, defaults to 1.0):
-        temporal_interpolation_scale (`float`, defaults to 1.0):
-    """
-    if embed_dim % 4 != 0:
-        raise ValueError("`embed_dim` must be divisible by 4")
-    if isinstance(spatial_size, int):
-        spatial_size = (spatial_size, spatial_size)
-
-    embed_dim_spatial = 3 * embed_dim // 4
-    embed_dim_temporal = embed_dim // 4
-
-    # 1. Spatial
-    grid_h = np.arange(spatial_size[1], dtype=np.float32) / spatial_interpolation_scale
-    grid_w = np.arange(spatial_size[0], dtype=np.float32) / spatial_interpolation_scale
-    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
-    grid = np.stack(grid, axis=0)
-
-    grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]])
-    pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid)
-
-    # 2. Temporal
-    grid_t = np.arange(temporal_size, dtype=np.float32) / temporal_interpolation_scale
-    pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t)
-
-    # 3. Concat
-    pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :]
-    pos_embed_spatial = np.repeat(pos_embed_spatial, temporal_size, axis=0)  # [T, H*W, D // 4 * 3]
-
-    pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :]
-    pos_embed_temporal = np.repeat(pos_embed_temporal, spatial_size[0] * spatial_size[1], axis=1)  # [T, H*W, D // 4]
-
-    pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1)  # [T, H*W, D]
-    return pos_embed
-
-
-def get_2d_sincos_pos_embed(
-    embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
-):
-    """
-    grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
-    [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
-    """
-    if isinstance(grid_size, int):
-        grid_size = (grid_size, grid_size)
-
-    grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
-    grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
-    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
-    grid = np.stack(grid, axis=0)
-
-    grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
-    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
-    if cls_token and extra_tokens > 0:
-        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
-    return pos_embed
-
-
-def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
-    if embed_dim % 2 != 0:
-        raise ValueError("embed_dim must be divisible by 2")
-
-    # use half of dimensions to encode grid_h
-    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
-    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
-
-    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
-    return emb
-
-
-def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
-    """
-    embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
-    """
-    if embed_dim % 2 != 0:
-        raise ValueError("embed_dim must be divisible by 2")
-
-    omega = np.arange(embed_dim // 2, dtype=np.float64)
-    omega /= embed_dim / 2.0
-    omega = 1.0 / 10000**omega  # (D/2,)
-
-    pos = pos.reshape(-1)  # (M,)
-    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
-
-    emb_sin = np.sin(out)  # (M, D/2)
-    emb_cos = np.cos(out)  # (M, D/2)
-
-    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
-    return emb
 
 
 class CogVideoXPatchEmbed(nn.Module):
     def __init__(
diff --git a/pipeline_cogvideox.py b/pipeline_cogvideox.py
index a563b73..f2fb927 100644
--- a/pipeline_cogvideox.py
+++ b/pipeline_cogvideox.py
@@ -26,9 +26,8 @@ from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
 from diffusers.utils import logging
 from diffusers.utils.torch_utils import randn_tensor
 from diffusers.video_processor import VideoProcessor
-#from diffusers.models.embeddings import get_3d_rotary_pos_embed
+from diffusers.models.embeddings import get_3d_rotary_pos_embed
 from diffusers.loaders import CogVideoXLoraLoaderMixin
-from .embeddings import get_3d_rotary_pos_embed
 from .custom_cogvideox_transformer_3d import CogVideoXTransformer3DModel
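
The first hunk in custom_cogvideox_transformer_3d.py swaps the debug-era code
path (unconditional SDPA plus NaN/Inf checks) back to the optional SageAttention
dispatch. Below is a minimal sketch of that pattern, assuming the module-level
SAGEATTN_IS_AVAILABLE guard defined elsewhere in this file; dispatch_attention
is a hypothetical helper, and the attention_mask check is added here (the patch
itself does not have it) because sageattn accepts no mask argument:

import torch
import torch.nn.functional as F

try:
    from sageattention import sageattn
    SAGEATTN_IS_AVAILABLE = True
except ImportError:
    SAGEATTN_IS_AVAILABLE = False

def dispatch_attention(query, key, value, attention_mask=None):
    # Prefer SageAttention's fused kernel when the package is installed;
    # otherwise fall back to PyTorch's scaled dot-product attention.
    if SAGEATTN_IS_AVAILABLE and attention_mask is None:
        return sageattn(query, key, value, is_causal=False)
    return F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
    )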
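
The same hunk keeps the RoPE slicing intact: the joint sequence is laid out as
[text tokens | video tokens], and only the video tokens carry 3D positions, so
rotary embeddings are applied past text_seq_length only. A small shape check
with made-up sizes, using the public diffusers helpers this patch moves to:

import torch
from diffusers.models.embeddings import apply_rotary_emb, get_1d_rotary_pos_embed

batch, heads, text_len, video_len, head_dim = 1, 2, 4, 6, 16
query = torch.randn(batch, heads, text_len + video_len, head_dim)

# (cos, sin) pair covering only the video positions
image_rotary_emb = get_1d_rotary_pos_embed(head_dim, video_len, use_real=True)

# Rotate the video tokens in place; the text prefix keeps its original values
query[:, :, text_len:] = apply_rotary_emb(query[:, :, text_len:], image_rotary_emb)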
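
The two whitespace-only hunks re-indent the unpatchify step, which folds the
temporal (p_t) and spatial (p) patches back into frames. A worked shape example
with illustrative sizes (every number below is a stand-in, not a real model
dimension):

import torch

batch_size, num_frames, height, width = 1, 8, 32, 32
p, p_t, channels, remaining_frames = 2, 2, 16, 0

tokens = (num_frames // p_t) * (height // p) * (width // p)
hidden_states = torch.randn(batch_size, tokens, channels * p_t * p * p)

# (B, T/p_t, H/p, W/p, C, p_t, p, p) -> interleave patches back into pixels
output = hidden_states.reshape(
    batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
)
output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
output = output[:, remaining_frames:]
print(output.shape)  # torch.Size([1, 8, 16, 32, 32]), i.e. (B, T, C, H, W)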