remove timm as unnecessary dependency

kijai 2024-09-24 19:29:39 +03:00
parent 81adb0220b
commit 513a2ab090
2 changed files with 2 additions and 105 deletions

View File

@@ -1,3 +1,4 @@
 huggingface_hub
 diffusers>=0.30.3
 accelerate>=0.33.0
+einops

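For context, einops becomes a direct requirement because the module changed below imports rearrange (see the unchanged "from einops import rearrange" line in the second file). As a rough illustration of the kind of reshape that call covers in video models, here is a minimal, self-contained sketch; the tensor names and sizes are purely illustrative and not taken from this repository:

import torch
from einops import rearrange

# Illustrative only: a video latent laid out as (batch, frames, channels, height, width).
x = torch.randn(1, 4, 16, 8, 8)

# Fold the frame axis into the batch axis so per-frame 2D ops (e.g. patch embedding)
# can run without a manual reshape/permute chain.
x_flat = rearrange(x, "b f c h w -> (b f) c h w")
print(x_flat.shape)  # torch.Size([4, 16, 8, 8])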
View File

@@ -1,4 +1,3 @@
-import functools
 import math
 from typing import Optional, Tuple, Union
@@ -8,8 +7,6 @@ import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.checkpoint
 from einops import rearrange
-from timm.models.vision_transformer import Mlp

 class CogVideoXPatchEmbed(nn.Module):
     def __init__(
@@ -179,107 +176,6 @@ class SizeEmbedder(TimestepEmbedder):
     def dtype(self):
         return next(self.parameters()).dtype

-class OpenSoraCaptionEmbedder(nn.Module):
-    """
-    Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
-    """
-
-    def __init__(
-        self,
-        in_channels,
-        hidden_size,
-        uncond_prob,
-        act_layer=nn.GELU(approximate="tanh"),
-        token_num=120,
-    ):
-        super().__init__()
-        self.y_proj = Mlp(
-            in_features=in_channels,
-            hidden_features=hidden_size,
-            out_features=hidden_size,
-            act_layer=act_layer,
-            drop=0,
-        )
-        self.register_buffer(
-            "y_embedding",
-            torch.randn(token_num, in_channels) / in_channels**0.5,
-        )
-        self.uncond_prob = uncond_prob
-
-    def token_drop(self, caption, force_drop_ids=None):
-        """
-        Drops labels to enable classifier-free guidance.
-        """
-        if force_drop_ids is None:
-            drop_ids = torch.rand(caption.shape[0]).cuda() < self.uncond_prob
-        else:
-            drop_ids = force_drop_ids == 1
-        caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
-        return caption
-
-    def forward(self, caption, train, force_drop_ids=None):
-        if train:
-            assert caption.shape[2:] == self.y_embedding.shape
-        use_dropout = self.uncond_prob > 0
-        if (train and use_dropout) or (force_drop_ids is not None):
-            caption = self.token_drop(caption, force_drop_ids)
-        caption = self.y_proj(caption)
-        return caption
-
-
-class OpenSoraPositionEmbedding2D(nn.Module):
-    def __init__(self, dim: int) -> None:
-        super().__init__()
-        self.dim = dim
-        assert dim % 4 == 0, "dim must be divisible by 4"
-        half_dim = dim // 2
-        inv_freq = 1.0 / (10000 ** (torch.arange(0, half_dim, 2).float() / half_dim))
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-
-    def _get_sin_cos_emb(self, t: torch.Tensor):
-        out = torch.einsum("i,d->id", t, self.inv_freq)
-        emb_cos = torch.cos(out)
-        emb_sin = torch.sin(out)
-        return torch.cat((emb_sin, emb_cos), dim=-1)
-
-    @functools.lru_cache(maxsize=512)
-    def _get_cached_emb(
-        self,
-        device: torch.device,
-        dtype: torch.dtype,
-        h: int,
-        w: int,
-        scale: float = 1.0,
-        base_size: Optional[int] = None,
-    ):
-        grid_h = torch.arange(h, device=device) / scale
-        grid_w = torch.arange(w, device=device) / scale
-        if base_size is not None:
-            grid_h *= base_size / h
-            grid_w *= base_size / w
-        grid_h, grid_w = torch.meshgrid(
-            grid_w,
-            grid_h,
-            indexing="ij",
-        )  # here w goes first
-        grid_h = grid_h.t().reshape(-1)
-        grid_w = grid_w.t().reshape(-1)
-        emb_h = self._get_sin_cos_emb(grid_h)
-        emb_w = self._get_sin_cos_emb(grid_w)
-        return torch.concat([emb_h, emb_w], dim=-1).unsqueeze(0).to(dtype)
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        h: int,
-        w: int,
-        scale: Optional[float] = 1.0,
-        base_size: Optional[int] = None,
-    ) -> torch.Tensor:
-        return self._get_cached_emb(x.device, x.dtype, h, w, scale, base_size)

 def get_3d_rotary_pos_embed(
     embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
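
The timm import existed only to supply the Mlp block used by the now-deleted OpenSoraCaptionEmbedder, which is why the dependency can be dropped outright. If an equivalent projection were ever needed again without pulling timm back in, a plain torch.nn stand-in would suffice; the following is a hedged sketch (the layer layout mirrors timm's Mlp, i.e. Linear -> activation -> dropout -> Linear -> dropout; the class name SimpleMlp and anything not visible in the diff above is illustrative):

import torch.nn as nn

class SimpleMlp(nn.Module):
    # Minimal stand-in for timm.models.vision_transformer.Mlp:
    # fc1 -> activation -> dropout -> fc2 -> dropout.
    def __init__(self, in_features, hidden_features, out_features, drop=0.0):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU(approximate="tanh")  # same activation the deleted code passed as act_layer
        self.drop = nn.Dropout(drop)
        self.fc2 = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        x = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(x))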