diff --git a/requirements.txt b/requirements.txt
index 4f00dac..0dbc732 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,4 @@
 huggingface_hub
 diffusers>=0.30.3
-accelerate>=0.33.0
\ No newline at end of file
+accelerate>=0.33.0
+einops
\ No newline at end of file
diff --git a/videosys/modules/embeddings.py b/videosys/modules/embeddings.py
index 13dd629..04eba82 100644
--- a/videosys/modules/embeddings.py
+++ b/videosys/modules/embeddings.py
@@ -1,4 +1,3 @@
-import functools
 import math
 from typing import Optional, Tuple, Union
 
@@ -8,8 +7,6 @@ import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.checkpoint
 from einops import rearrange
-from timm.models.vision_transformer import Mlp
-
 
 class CogVideoXPatchEmbed(nn.Module):
     def __init__(
@@ -179,107 +176,6 @@ class SizeEmbedder(TimestepEmbedder):
     def dtype(self):
         return next(self.parameters()).dtype
 
-
-class OpenSoraCaptionEmbedder(nn.Module):
-    """
-    Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
-    """
-
-    def __init__(
-        self,
-        in_channels,
-        hidden_size,
-        uncond_prob,
-        act_layer=nn.GELU(approximate="tanh"),
-        token_num=120,
-    ):
-        super().__init__()
-        self.y_proj = Mlp(
-            in_features=in_channels,
-            hidden_features=hidden_size,
-            out_features=hidden_size,
-            act_layer=act_layer,
-            drop=0,
-        )
-        self.register_buffer(
-            "y_embedding",
-            torch.randn(token_num, in_channels) / in_channels**0.5,
-        )
-        self.uncond_prob = uncond_prob
-
-    def token_drop(self, caption, force_drop_ids=None):
-        """
-        Drops labels to enable classifier-free guidance.
-        """
-        if force_drop_ids is None:
-            drop_ids = torch.rand(caption.shape[0]).cuda() < self.uncond_prob
-        else:
-            drop_ids = force_drop_ids == 1
-        caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
-        return caption
-
-    def forward(self, caption, train, force_drop_ids=None):
-        if train:
-            assert caption.shape[2:] == self.y_embedding.shape
-        use_dropout = self.uncond_prob > 0
-        if (train and use_dropout) or (force_drop_ids is not None):
-            caption = self.token_drop(caption, force_drop_ids)
-        caption = self.y_proj(caption)
-        return caption
-
-
-class OpenSoraPositionEmbedding2D(nn.Module):
-    def __init__(self, dim: int) -> None:
-        super().__init__()
-        self.dim = dim
-        assert dim % 4 == 0, "dim must be divisible by 4"
-        half_dim = dim // 2
-        inv_freq = 1.0 / (10000 ** (torch.arange(0, half_dim, 2).float() / half_dim))
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-
-    def _get_sin_cos_emb(self, t: torch.Tensor):
-        out = torch.einsum("i,d->id", t, self.inv_freq)
-        emb_cos = torch.cos(out)
-        emb_sin = torch.sin(out)
-        return torch.cat((emb_sin, emb_cos), dim=-1)
-
-    @functools.lru_cache(maxsize=512)
-    def _get_cached_emb(
-        self,
-        device: torch.device,
-        dtype: torch.dtype,
-        h: int,
-        w: int,
-        scale: float = 1.0,
-        base_size: Optional[int] = None,
-    ):
-        grid_h = torch.arange(h, device=device) / scale
-        grid_w = torch.arange(w, device=device) / scale
-        if base_size is not None:
-            grid_h *= base_size / h
-            grid_w *= base_size / w
-        grid_h, grid_w = torch.meshgrid(
-            grid_w,
-            grid_h,
-            indexing="ij",
-        )  # here w goes first
-        grid_h = grid_h.t().reshape(-1)
-        grid_w = grid_w.t().reshape(-1)
-        emb_h = self._get_sin_cos_emb(grid_h)
-        emb_w = self._get_sin_cos_emb(grid_w)
-        return torch.concat([emb_h, emb_w], dim=-1).unsqueeze(0).to(dtype)
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        h: int,
-        w: int,
-        scale: Optional[float] = 1.0,
-        base_size: Optional[int] = None,
-    ) -> torch.Tensor:
-        return self._get_cached_emb(x.device, x.dtype, h, w, scale, base_size)
-
-
 def get_3d_rotary_pos_embed(
     embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: