remove timm as unnecessary dependency

parent 81adb0220b
commit 513a2ab090
@@ -1,3 +1,4 @@
+huggingface_hub
 diffusers>=0.30.3
 accelerate>=0.33.0
 einops
@@ -1,4 +1,3 @@
-import functools
 import math
 from typing import Optional, Tuple, Union
 
@@ -8,8 +7,6 @@ import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.checkpoint
 from einops import rearrange
-from timm.models.vision_transformer import Mlp
-
 
 class CogVideoXPatchEmbed(nn.Module):
     def __init__(
@@ -179,107 +176,6 @@ class SizeEmbedder(TimestepEmbedder):
     def dtype(self):
         return next(self.parameters()).dtype
 
-
-class OpenSoraCaptionEmbedder(nn.Module):
-    """
-    Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
-    """
-
-    def __init__(
-        self,
-        in_channels,
-        hidden_size,
-        uncond_prob,
-        act_layer=nn.GELU(approximate="tanh"),
-        token_num=120,
-    ):
-        super().__init__()
-        self.y_proj = Mlp(
-            in_features=in_channels,
-            hidden_features=hidden_size,
-            out_features=hidden_size,
-            act_layer=act_layer,
-            drop=0,
-        )
-        self.register_buffer(
-            "y_embedding",
-            torch.randn(token_num, in_channels) / in_channels**0.5,
-        )
-        self.uncond_prob = uncond_prob
-
-    def token_drop(self, caption, force_drop_ids=None):
-        """
-        Drops labels to enable classifier-free guidance.
-        """
-        if force_drop_ids is None:
-            drop_ids = torch.rand(caption.shape[0]).cuda() < self.uncond_prob
-        else:
-            drop_ids = force_drop_ids == 1
-        caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
-        return caption
-
-    def forward(self, caption, train, force_drop_ids=None):
-        if train:
-            assert caption.shape[2:] == self.y_embedding.shape
-        use_dropout = self.uncond_prob > 0
-        if (train and use_dropout) or (force_drop_ids is not None):
-            caption = self.token_drop(caption, force_drop_ids)
-        caption = self.y_proj(caption)
-        return caption
-
-
-class OpenSoraPositionEmbedding2D(nn.Module):
-    def __init__(self, dim: int) -> None:
-        super().__init__()
-        self.dim = dim
-        assert dim % 4 == 0, "dim must be divisible by 4"
-        half_dim = dim // 2
-        inv_freq = 1.0 / (10000 ** (torch.arange(0, half_dim, 2).float() / half_dim))
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-
-    def _get_sin_cos_emb(self, t: torch.Tensor):
-        out = torch.einsum("i,d->id", t, self.inv_freq)
-        emb_cos = torch.cos(out)
-        emb_sin = torch.sin(out)
-        return torch.cat((emb_sin, emb_cos), dim=-1)
-
-    @functools.lru_cache(maxsize=512)
-    def _get_cached_emb(
-        self,
-        device: torch.device,
-        dtype: torch.dtype,
-        h: int,
-        w: int,
-        scale: float = 1.0,
-        base_size: Optional[int] = None,
-    ):
-        grid_h = torch.arange(h, device=device) / scale
-        grid_w = torch.arange(w, device=device) / scale
-        if base_size is not None:
-            grid_h *= base_size / h
-            grid_w *= base_size / w
-        grid_h, grid_w = torch.meshgrid(
-            grid_w,
-            grid_h,
-            indexing="ij",
-        )  # here w goes first
-        grid_h = grid_h.t().reshape(-1)
-        grid_w = grid_w.t().reshape(-1)
-        emb_h = self._get_sin_cos_emb(grid_h)
-        emb_w = self._get_sin_cos_emb(grid_w)
-        return torch.concat([emb_h, emb_w], dim=-1).unsqueeze(0).to(dtype)
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        h: int,
-        w: int,
-        scale: Optional[float] = 1.0,
-        base_size: Optional[int] = None,
-    ) -> torch.Tensor:
-        return self._get_cached_emb(x.device, x.dtype, h, w, scale, base_size)
-
-
 def get_3d_rotary_pos_embed(
     embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
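Note on the removal: the timm import deleted above appears to have been needed only for the Mlp block used by the OpenSoraCaptionEmbedder removed in the same hunk, which is why the dependency can be dropped. Purely as an illustration (not part of this commit), a torch-only stand-in for what timm's Mlp provided might look like the sketch below; the name SimpleMlp and the example sizes are hypothetical, and the argument handling is simplified relative to timm's implementation.

import torch
import torch.nn as nn

# Hypothetical torch-only stand-in for the timm Mlp used by the removed
# OpenSoraCaptionEmbedder: Linear -> GELU(tanh) -> Dropout -> Linear -> Dropout.
class SimpleMlp(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, drop=0.0):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU(approximate="tanh")  # the activation the removed code passed as act_layer
        self.drop = nn.Dropout(drop)
        self.fc2 = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        x = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(x))

# Quick shape check with arbitrary example sizes (batch of 2, 120 tokens).
if __name__ == "__main__":
    mlp = SimpleMlp(in_features=4096, hidden_features=1152, out_features=1152)
    out = mlp(torch.randn(2, 120, 4096))
    print(out.shape)  # torch.Size([2, 120, 1152])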