from typing import Optional, Tuple

import torch
import torch.nn as nn


class CogVideoXLayerNormZero(nn.Module):
    def __init__(
        self,
        conditioning_dim: int,
        embedding_dim: int,
        elementwise_affine: bool = True,
        eps: float = 1e-5,
        bias: bool = True,
    ) -> None:
        super().__init__()

        self.silu = nn.SiLU()
        self.linear = nn.Linear(conditioning_dim, 6 * embedding_dim, bias=bias)
        self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)

    def forward(
        self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # Project the conditioning embedding into six modulation tensors: shift, scale and gate
        # for the visual tokens and for the text (encoder) tokens.
        shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1)
        # adaLN-Zero style modulation applied after the shared LayerNorm.
        hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
        encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :]
        return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :]
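

def _cogvideox_layer_norm_zero_example() -> None:
    # Minimal usage sketch (illustrative addition, not part of the original module; the helper
    # name and the shapes below are assumptions chosen for the example).
    norm_zero = CogVideoXLayerNormZero(conditioning_dim=64, embedding_dim=128)
    hidden_states = torch.randn(2, 16, 128)          # (batch, visual tokens, embedding_dim)
    encoder_hidden_states = torch.randn(2, 4, 128)   # (batch, text tokens, embedding_dim)
    temb = torch.randn(2, 64)                        # (batch, conditioning_dim)
    hidden_states, encoder_hidden_states, gate, enc_gate = norm_zero(
        hidden_states, encoder_hidden_states, temb
    )
    # gate / enc_gate have shape (batch, 1, embedding_dim); callers use them to scale the
    # attention / feed-forward output before the residual addition.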


class AdaLayerNorm(nn.Module):
    r"""
    Norm layer modified to incorporate timestep embeddings.

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        num_embeddings (`int`, *optional*): The size of the embeddings dictionary.
        output_dim (`int`, *optional*):
            The size of the projected conditioning; defaults to `embedding_dim * 2`.
        norm_elementwise_affine (`bool`, defaults to `False`):
            Whether the LayerNorm uses learnable elementwise affine parameters.
        norm_eps (`float`, defaults to `1e-5`):
            The epsilon used by the LayerNorm.
        chunk_dim (`int`, defaults to `0`):
            The dimension along which the projected embedding is split into scale and shift.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_embeddings: Optional[int] = None,
        output_dim: Optional[int] = None,
        norm_elementwise_affine: bool = False,
        norm_eps: float = 1e-5,
        chunk_dim: int = 0,
    ):
        super().__init__()

        self.chunk_dim = chunk_dim
        output_dim = output_dim or embedding_dim * 2

        if num_embeddings is not None:
            self.emb = nn.Embedding(num_embeddings, embedding_dim)
        else:
            self.emb = None

        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, output_dim)
        self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)

    def forward(
        self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        if self.emb is not None:
            temb = self.emb(timestep)

        temb = self.linear(self.silu(temb))

        if self.chunk_dim == 1:
            # Note: this branch uses a "shift, scale" ordering while the other branch uses
            # "scale, shift". This branch is specific to CogVideoX for now.
            shift, scale = temb.chunk(2, dim=1)
            shift = shift[:, None, :]
            scale = scale[:, None, :]
        else:
            scale, shift = temb.chunk(2, dim=0)

        x = self.norm(x) * (1 + scale) + shift
        return x
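

def _ada_layer_norm_example() -> None:
    # Minimal usage sketch (illustrative addition, not part of the original module; the helper
    # name and the shapes below are assumptions). With chunk_dim=1 the layer expects a
    # per-sample conditioning embedding, as used by CogVideoX's final norm layer.
    ada_norm = AdaLayerNorm(embedding_dim=64, output_dim=128, chunk_dim=1)
    x = torch.randn(2, 16, 64)     # (batch, tokens, output_dim // 2)
    temb = torch.randn(2, 64)      # (batch, embedding_dim)
    out = ada_norm(x, temb=temb)   # same shape as x, modulated by shift and scale from temb
    assert out.shape == x.shape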