from typing import Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F class LlamaRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ LlamaRMSNorm is equivalent to T5LayerNorm """ super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps def forward(self, hidden_states): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) class CogVideoXLayerNormZero(nn.Module): def __init__( self, conditioning_dim: int, embedding_dim: int, elementwise_affine: bool = True, eps: float = 1e-5, bias: bool = True, ) -> None: super().__init__() self.silu = nn.SiLU() self.linear = nn.Linear(conditioning_dim, 6 * embedding_dim, bias=bias) self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine) def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1) hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :] encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :] return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :] class AdaLayerNorm(nn.Module): r""" Norm layer modified to incorporate timestep embeddings. Parameters: embedding_dim (`int`): The size of each embedding vector. num_embeddings (`int`, *optional*): The size of the embeddings dictionary. output_dim (`int`, *optional*): norm_elementwise_affine (`bool`, defaults to `False): norm_eps (`bool`, defaults to `False`): chunk_dim (`int`, defaults to `0`): """ def __init__( self, embedding_dim: int, num_embeddings: Optional[int] = None, output_dim: Optional[int] = None, norm_elementwise_affine: bool = False, norm_eps: float = 1e-5, chunk_dim: int = 0, ): super().__init__() self.chunk_dim = chunk_dim output_dim = output_dim or embedding_dim * 2 if num_embeddings is not None: self.emb = nn.Embedding(num_embeddings, embedding_dim) else: self.emb = None self.silu = nn.SiLU() self.linear = nn.Linear(embedding_dim, output_dim) self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine) def forward( self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None ) -> torch.Tensor: if self.emb is not None: temb = self.emb(timestep) temb = self.linear(self.silu(temb)) if self.chunk_dim == 1: # This is a bit weird why we have the order of "shift, scale" here and "scale, shift" in the # other if-branch. This branch is specific to CogVideoX for now. shift, scale = temb.chunk(2, dim=1) shift = shift[:, None, :] scale = scale[:, None, :] else: scale, shift = temb.chunk(2, dim=0) x = self.norm(x) * (1 + scale) + shift return x class VchitectSpatialNorm(nn.Module): """ Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002. Args: f_channels (`int`): The number of channels for input to group normalization layer, and output of the spatial norm layer. zq_channels (`int`): The number of channels for the quantized vector as described in the paper. """ def __init__( self, f_channels: int, zq_channels: int, ): super().__init__() self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True) self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor: f_size = f.shape[-2:] zq = F.interpolate(zq, size=f_size, mode="nearest") norm_f = self.norm_layer(f) new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) return new_f