diff --git a/videosys/cogvideox_transformer_3d.py b/videosys/cogvideox_transformer_3d.py
index 5df9a3c..b39831f 100644
--- a/videosys/cogvideox_transformer_3d.py
+++ b/videosys/cogvideox_transformer_3d.py
@@ -14,7 +14,7 @@ import torch
 import torch.nn.functional as F
 from diffusers.configuration_utils import ConfigMixin, register_to_config
 from diffusers.models.attention import Attention, FeedForward
-from diffusers.models.embeddings import TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed
+from diffusers.models.embeddings import TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed, CogVideoXPatchEmbed
 from diffusers.models.modeling_outputs import Transformer2DModelOutput
 from diffusers.models.modeling_utils import ModelMixin
 from diffusers.utils import is_torch_version
@@ -24,7 +24,7 @@ from torch import nn
 
 from .core.pab_mgr import enable_pab, if_broadcast_spatial
 from .modules.embeddings import apply_rotary_emb
-from .modules.embeddings import CogVideoXPatchEmbed
+#from .modules.embeddings import CogVideoXPatchEmbed
 from .modules.normalization import AdaLayerNorm, CogVideoXLayerNormZero
 
 
@@ -407,6 +407,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
         spatial_interpolation_scale: float = 1.875,
         temporal_interpolation_scale: float = 1.0,
         use_rotary_positional_embeddings: bool = False,
+        use_learned_positional_embeddings: bool = False,
     ):
         super().__init__()
         inner_dim = num_attention_heads * attention_head_dim
@@ -417,7 +418,22 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
         self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames
 
         # 1. Patch embedding
-        self.patch_embed = CogVideoXPatchEmbed(patch_size, in_channels, inner_dim, text_embed_dim, bias=True)
+        self.patch_embed = CogVideoXPatchEmbed(
+            patch_size=patch_size,
+            in_channels=in_channels,
+            embed_dim=inner_dim,
+            text_embed_dim=text_embed_dim,
+            bias=True,
+            sample_width=sample_width,
+            sample_height=sample_height,
+            sample_frames=sample_frames,
+            temporal_compression_ratio=temporal_compression_ratio,
+            max_text_seq_length=max_text_seq_length,
+            spatial_interpolation_scale=spatial_interpolation_scale,
+            temporal_interpolation_scale=temporal_interpolation_scale,
+            use_positional_embeddings=not use_rotary_positional_embeddings,
+            use_learned_positional_embeddings=use_learned_positional_embeddings,
+        )
         self.embedding_dropout = nn.Dropout(dropout)
 
         # 2. 3D positional embeddings
@@ -590,7 +606,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
 
         # 6. Unpatchify
         p = self.config.patch_size
-        output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
+        output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
         output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
 
         #if self.parallel_manager.cp_size > 1:
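For reference, a minimal standalone sketch of the unpatchify step changed above: replacing `channels` with `-1` lets the channel dimension be inferred from the hidden size rather than tracked separately. All shapes below are illustrative, not taken from the real model config.

import torch

batch_size, num_frames, channels, height, width, p = 2, 4, 16, 32, 32, 2
hidden_states = torch.randn(batch_size, num_frames * (height // p) * (width // p), channels * p * p)

# Same reshape/permute as the patched forward(): the -1 infers `channels`.
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
assert output.shape == (batch_size, num_frames, channels, height, width)

The permute moves the inferred channel axis ahead of the spatial axes before the two flattens stitch the p x p patches back into full rows and columns.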
diff --git a/videosys/core/comm.py b/videosys/core/comm.py
deleted file mode 100644
index 3b44402..0000000
--- a/videosys/core/comm.py
+++ /dev/null
@@ -1,406 +0,0 @@
-from typing import Any, Optional, Tuple
-
-import torch
-import torch.distributed as dist
-import torch.nn.functional as F
-from einops import rearrange
-from torch import Tensor
-from torch.distributed import ProcessGroup
-
-# ======================================================
-# Model
-# ======================================================
-
-
-def model_sharding(model: torch.nn.Module):
-    global_rank = dist.get_rank()
-    world_size = dist.get_world_size()
-    for _, param in model.named_parameters():
-        padding_size = (world_size - param.numel() % world_size) % world_size
-        if padding_size > 0:
-            padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
-        else:
-            padding_param = param.data.view(-1)
-        splited_params = padding_param.split(padding_param.numel() // world_size)
-        splited_params = splited_params[global_rank]
-        param.data = splited_params
-
-
-# ======================================================
-# AllGather & ReduceScatter
-# ======================================================
-
-
-class AsyncAllGatherForTwo(torch.autograd.Function):
-    @staticmethod
-    def forward(
-        ctx: Any,
-        inputs: Tensor,
-        weight: Tensor,
-        bias: Tensor,
-        sp_rank: int,
-        sp_size: int,
-        group: Optional[ProcessGroup] = None,
-    ) -> Tuple[Tensor, Any]:
-        """
-        Returns:
-            outputs: Tensor
-            handle: Optional[Work], if overlap is True
-        """
-        from torch.distributed._functional_collectives import all_gather_tensor
-
-        ctx.group = group
-        ctx.sp_rank = sp_rank
-        ctx.sp_size = sp_size
-
-        # all gather inputs
-        all_inputs = all_gather_tensor(inputs.unsqueeze(0), 0, group)
-        # compute local qkv
-        local_qkv = F.linear(inputs, weight, bias).unsqueeze(0)
-
-        # remote compute
-        remote_inputs = all_inputs[1 - sp_rank].view(list(local_qkv.shape[:-1]) + [-1])
-        # compute remote qkv
-        remote_qkv = F.linear(remote_inputs, weight, bias)
-
-        # concat local and remote qkv
-        if sp_rank == 0:
-            qkv = torch.cat([local_qkv, remote_qkv], dim=0)
-        else:
-            qkv = torch.cat([remote_qkv, local_qkv], dim=0)
-        qkv = rearrange(qkv, "sp b n c -> b (sp n) c")
-
-        ctx.save_for_backward(inputs, weight, remote_inputs)
-        return qkv
-
-    @staticmethod
-    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
-        from torch.distributed._functional_collectives import reduce_scatter_tensor
-
-        group = ctx.group
-        sp_rank = ctx.sp_rank
-        sp_size = ctx.sp_size
-        inputs, weight, remote_inputs = ctx.saved_tensors
-
-        # split qkv_grad
-        qkv_grad = grad_outputs[0]
-        qkv_grad = rearrange(qkv_grad, "b (sp n) c -> sp b n c", sp=sp_size)
-        qkv_grad = torch.chunk(qkv_grad, 2, dim=0)
-        if sp_rank == 0:
-            local_qkv_grad, remote_qkv_grad = qkv_grad
-        else:
-            remote_qkv_grad, local_qkv_grad = qkv_grad
-
-        # compute remote grad
-        remote_inputs_grad = torch.matmul(remote_qkv_grad, weight).squeeze(0)
-        weight_grad = torch.matmul(remote_qkv_grad.transpose(-1, -2), remote_inputs).squeeze(0).sum(0)
-        bias_grad = remote_qkv_grad.squeeze(0).sum(0).sum(0)
-
-        # launch async reduce scatter
-        remote_inputs_grad_zero = torch.zeros_like(remote_inputs_grad)
-        if sp_rank == 0:
-            remote_inputs_grad = torch.cat([remote_inputs_grad_zero, remote_inputs_grad], dim=0)
-        else:
-            remote_inputs_grad = torch.cat([remote_inputs_grad, remote_inputs_grad_zero], dim=0)
-        remote_inputs_grad = reduce_scatter_tensor(remote_inputs_grad, "sum", 0, group)
-
-        # compute local grad and wait for reduce scatter
-        local_input_grad = torch.matmul(local_qkv_grad, weight).squeeze(0)
-        weight_grad += torch.matmul(local_qkv_grad.transpose(-1, -2), inputs).squeeze(0).sum(0)
-        bias_grad += local_qkv_grad.squeeze(0).sum(0).sum(0)
-
-        # sum remote and local grad
-        inputs_grad = remote_inputs_grad + local_input_grad
-        return inputs_grad, weight_grad, bias_grad, None, None, None
-
-
-class AllGather(torch.autograd.Function):
-    @staticmethod
-    def forward(
-        ctx: Any,
-        inputs: Tensor,
-        group: Optional[ProcessGroup] = None,
-        overlap: bool = False,
-    ) -> Tuple[Tensor, Any]:
-        """
-        Returns:
-            outputs: Tensor
-            handle: Optional[Work], if overlap is True
-        """
-        assert ctx is not None or not overlap
-
-        if ctx is not None:
-            ctx.comm_grp = group
-
-        comm_size = dist.get_world_size(group)
-        if comm_size == 1:
-            return inputs.unsqueeze(0), None
-
-        buffer_shape = (comm_size,) + inputs.shape
-        outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device)
-        buffer_list = list(torch.chunk(outputs, comm_size, dim=0))
-        if not overlap:
-            dist.all_gather(buffer_list, inputs, group=group)
-            return outputs, None
-        else:
-            handle = dist.all_gather(buffer_list, inputs, group=group, async_op=True)
-            return outputs, handle
-
-    @staticmethod
-    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
-        return (
-            ReduceScatter.forward(None, grad_outputs[0], ctx.comm_grp, False)[0],
-            None,
-            None,
-        )
-
-
-class ReduceScatter(torch.autograd.Function):
-    @staticmethod
-    def forward(
-        ctx: Any,
-        inputs: Tensor,
-        group: ProcessGroup,
-        overlap: bool = False,
-    ) -> Tuple[Tensor, Any]:
-        """
-        Returns:
-            outputs: Tensor
-            handle: Optional[Work], if overlap is True
-        """
-        assert ctx is not None or not overlap
-
-        if ctx is not None:
-            ctx.comm_grp = group
-
-        comm_size = dist.get_world_size(group)
-        if comm_size == 1:
-            return inputs.squeeze(0), None
-
-        if not inputs.is_contiguous():
-            inputs = inputs.contiguous()
-
-        output_shape = inputs.shape[1:]
-        outputs = torch.empty(output_shape, dtype=inputs.dtype, device=inputs.device)
-        buffer_list = list(torch.chunk(inputs, comm_size, dim=0))
-        if not overlap:
-            dist.reduce_scatter(outputs, buffer_list, group=group)
-            return outputs, None
-        else:
-            handle = dist.reduce_scatter(outputs, buffer_list, group=group, async_op=True)
-            return outputs, handle
-
-    @staticmethod
-    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
-        # TODO: support async backward
-        return (
-            AllGather.forward(None, grad_outputs[0], ctx.comm_grp, False)[0],
-            None,
-            None,
-        )
-
-
-# ======================================================
-# AlltoAll
-# ======================================================
-
-
-def _all_to_all_func(input_, world_size, group, scatter_dim, gather_dim):
-    input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
-    output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
-    dist.all_to_all(output_list, input_list, group=group)
-    return torch.cat(output_list, dim=gather_dim).contiguous()
-
-
-class _AllToAll(torch.autograd.Function):
-    """All-to-all communication.
-
-    Args:
-        input_: input matrix
-        process_group: communication group
-        scatter_dim: scatter dimension
-        gather_dim: gather dimension
-    """
-
-    @staticmethod
-    def forward(ctx, input_, process_group, scatter_dim, gather_dim):
-        ctx.process_group = process_group
-        ctx.scatter_dim = scatter_dim
-        ctx.gather_dim = gather_dim
-        world_size = dist.get_world_size(process_group)
-
-        return _all_to_all_func(input_, world_size, process_group, scatter_dim, gather_dim)
-
-    @staticmethod
-    def backward(ctx, *grad_output):
-        process_group = ctx.process_group
-        scatter_dim = ctx.gather_dim
-        gather_dim = ctx.scatter_dim
-        return_grad = _AllToAll.apply(*grad_output, process_group, scatter_dim, gather_dim)
-        return (return_grad, None, None, None)
-
-
-def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1):
-    return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim)
-
-
-# ======================================================
-# Sequence Gather & Split
-# ======================================================
-
-
-def _split_sequence_func(input_, pg: dist.ProcessGroup, dim: int, pad: int):
-    # skip if only one rank involved
-    world_size = dist.get_world_size(pg)
-    rank = dist.get_rank(pg)
-    if world_size == 1:
-        return input_
-
-    if pad > 0:
-        pad_size = list(input_.shape)
-        pad_size[dim] = pad
-        input_ = torch.cat([input_, torch.zeros(pad_size, dtype=input_.dtype, device=input_.device)], dim=dim)
-
-    dim_size = input_.size(dim)
-    assert dim_size % world_size == 0, f"dim_size ({dim_size}) is not divisible by world_size ({world_size})"
-
-    tensor_list = torch.split(input_, dim_size // world_size, dim=dim)
-    output = tensor_list[rank].contiguous()
-    return output
-
-
-def _gather_sequence_func(input_, pg: dist.ProcessGroup, dim: int, pad: int):
-    # skip if only one rank involved
-    input_ = input_.contiguous()
-    world_size = dist.get_world_size(pg)
-    dist.get_rank(pg)
-
-    if world_size == 1:
-        return input_
-
-    # all gather
-    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
-    assert input_.device.type == "cuda"
-    torch.distributed.all_gather(tensor_list, input_, group=pg)
-
-    # concat
-    output = torch.cat(tensor_list, dim=dim)
-
-    if pad > 0:
-        output = output.narrow(dim, 0, output.size(dim) - pad)
-
-    return output
-
-
-class _GatherForwardSplitBackward(torch.autograd.Function):
-    """
-    Gather the input sequence.
-
-    Args:
-        input_: input matrix.
-        process_group: process group.
-        dim: dimension
-    """
-
-    @staticmethod
-    def symbolic(graph, input_):
-        return _gather_sequence_func(input_)
-
-    @staticmethod
-    def forward(ctx, input_, process_group, dim, grad_scale, pad):
-        ctx.process_group = process_group
-        ctx.dim = dim
-        ctx.grad_scale = grad_scale
-        ctx.pad = pad
-        return _gather_sequence_func(input_, process_group, dim, pad)
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        if ctx.grad_scale == "up":
-            grad_output = grad_output * dist.get_world_size(ctx.process_group)
-        elif ctx.grad_scale == "down":
-            grad_output = grad_output / dist.get_world_size(ctx.process_group)
-
-        return _split_sequence_func(grad_output, ctx.process_group, ctx.dim, ctx.pad), None, None, None, None
-
-
-class _SplitForwardGatherBackward(torch.autograd.Function):
-    """
-    Split sequence.
-
-    Args:
-        input_: input matrix.
-        process_group: parallel mode.
-        dim: dimension
-    """
-
-    @staticmethod
-    def symbolic(graph, input_):
-        return _split_sequence_func(input_)
-
-    @staticmethod
-    def forward(ctx, input_, process_group, dim, grad_scale, pad):
-        ctx.process_group = process_group
-        ctx.dim = dim
-        ctx.grad_scale = grad_scale
-        ctx.pad = pad
-        return _split_sequence_func(input_, process_group, dim, pad)
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        if ctx.grad_scale == "up":
-            grad_output = grad_output * dist.get_world_size(ctx.process_group)
-        elif ctx.grad_scale == "down":
-            grad_output = grad_output / dist.get_world_size(ctx.process_group)
-        return _gather_sequence_func(grad_output, ctx.process_group, ctx.pad), None, None, None, None
-
-
-def split_sequence(input_, process_group, dim, grad_scale=1.0, pad=0):
-    return _SplitForwardGatherBackward.apply(input_, process_group, dim, grad_scale, pad)
-
-
-def gather_sequence(input_, process_group, dim, grad_scale=1.0, pad=0):
-    return _GatherForwardSplitBackward.apply(input_, process_group, dim, grad_scale, pad)
-
-
-# ==============================
-# Pad
-# ==============================
-
-PAD_DICT = {}
-
-
-def set_pad(name: str, dim_size: int, parallel_group: dist.ProcessGroup):
-    sp_size = dist.get_world_size(parallel_group)
-    pad = (sp_size - (dim_size % sp_size)) % sp_size
-    global PAD_DICT
-    PAD_DICT[name] = pad
-
-
-def get_pad(name) -> int:
-    return PAD_DICT[name]
-
-
-def all_to_all_with_pad(
-    input_: torch.Tensor,
-    process_group: dist.ProcessGroup,
-    scatter_dim: int = 2,
-    gather_dim: int = 1,
-    scatter_pad: int = 0,
-    gather_pad: int = 0,
-):
-    if scatter_pad > 0:
-        pad_shape = list(input_.shape)
-        pad_shape[scatter_dim] = scatter_pad
-        pad_tensor = torch.zeros(pad_shape, device=input_.device, dtype=input_.dtype)
-        input_ = torch.cat([input_, pad_tensor], dim=scatter_dim)
-
-    assert (
-        input_.shape[scatter_dim] % dist.get_world_size(process_group) == 0
-    ), f"Dimension to scatter ({input_.shape[scatter_dim]}) is not divisible by world size ({dist.get_world_size(process_group)})"
-    input_ = _AllToAll.apply(input_, process_group, scatter_dim, gather_dim)
-
-    if gather_pad > 0:
-        input_ = input_.narrow(gather_dim, 0, input_.size(gather_dim) - gather_pad)
-
-    return input_
diff --git a/videosys/core/pipeline.py b/videosys/core/pipeline.py
index 75b79d3..3244749 100644
--- a/videosys/core/pipeline.py
+++ b/videosys/core/pipeline.py
@@ -1,11 +1,8 @@
 import inspect
 from abc import abstractmethod
-from dataclasses import dataclass
 
 import torch
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline
-from diffusers.utils import BaseOutput
-
 
 class VideoSysPipeline(DiffusionPipeline):
     def __init__(self):
@@ -45,8 +42,3 @@ class VideoSysPipeline(DiffusionPipeline):
                 optional_parameters.remove(name)
 
         return expected_modules, optional_parameters
-
-
-@dataclass
-class VideoSysPipelineOutput(BaseOutput):
-    video: torch.Tensor
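The deleted videosys/core/comm.py above held the sequence-parallel communication helpers (all_to_all_comm, split_sequence/gather_sequence, set_pad/get_pad). As a reading aid, here is a single-process sketch of the padding rule those helpers relied on; pad_for_sp is a hypothetical stand-in and no process group is involved.

import torch

def pad_for_sp(x: torch.Tensor, sp_size: int, dim: int):
    # Same formula as the removed set_pad(): pad up to the next multiple of sp_size.
    pad = (sp_size - x.size(dim) % sp_size) % sp_size
    if pad > 0:
        pad_shape = list(x.shape)
        pad_shape[dim] = pad
        x = torch.cat([x, torch.zeros(pad_shape, dtype=x.dtype)], dim=dim)
    return x, pad

x = torch.randn(2, 30, 64)                      # (batch, seq_len, hidden)
x_padded, pad = pad_for_sp(x, sp_size=4, dim=1)
chunks = torch.chunk(x_padded, 4, dim=1)        # the per-rank shards split_sequence would produce
assert x_padded.size(1) == 32 and pad == 2 and chunks[0].shape == (2, 8, 64)

In the removed module, gather_sequence undid this split with an all_gather and then dropped the pad columns again via narrow().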
diff --git a/videosys/modules/attentions.py b/videosys/modules/attentions.py
deleted file mode 100644
index 8e2c20c..0000000
--- a/videosys/modules/attentions.py
+++ /dev/null
@@ -1,205 +0,0 @@
-from dataclasses import dataclass
-from typing import Iterable, List, Tuple
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.utils.checkpoint
-
-from videosys.models.modules.normalization import LlamaRMSNorm
-
-
-class OpenSoraAttention(nn.Module):
-    def __init__(
-        self,
-        dim: int,
-        num_heads: int = 8,
-        qkv_bias: bool = False,
-        qk_norm: bool = False,
-        attn_drop: float = 0.0,
-        proj_drop: float = 0.0,
-        norm_layer: nn.Module = LlamaRMSNorm,
-        enable_flash_attn: bool = False,
-        rope=None,
-        qk_norm_legacy: bool = False,
-    ) -> None:
-        super().__init__()
-        assert dim % num_heads == 0, "dim should be divisible by num_heads"
-        self.dim = dim
-        self.num_heads = num_heads
-        self.head_dim = dim // num_heads
-        self.scale = self.head_dim**-0.5
-        self.enable_flash_attn = enable_flash_attn
-
-        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
-        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
-        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
-        self.qk_norm_legacy = qk_norm_legacy
-        self.attn_drop = nn.Dropout(attn_drop)
-        self.proj = nn.Linear(dim, dim)
-        self.proj_drop = nn.Dropout(proj_drop)
-
-        self.rope = False
-        if rope is not None:
-            self.rope = True
-            self.rotary_emb = rope
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        B, N, C = x.shape
-        # flash attn is not memory efficient for small sequences, this is empirical
-        enable_flash_attn = self.enable_flash_attn and (N > B)
-        qkv = self.qkv(x)
-        qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
-
-        qkv = qkv.view(qkv_shape).permute(2, 0, 3, 1, 4)
-        q, k, v = qkv.unbind(0)
-        if self.qk_norm_legacy:
-            # WARNING: this may be a bug
-            if self.rope:
-                q = self.rotary_emb(q)
-                k = self.rotary_emb(k)
-            q, k = self.q_norm(q), self.k_norm(k)
-        else:
-            q, k = self.q_norm(q), self.k_norm(k)
-            if self.rope:
-                q = self.rotary_emb(q)
-                k = self.rotary_emb(k)
-
-        if enable_flash_attn:
-            from flash_attn import flash_attn_func
-
-            # (B, #heads, N, #dim) -> (B, N, #heads, #dim)
-            q = q.permute(0, 2, 1, 3)
-            k = k.permute(0, 2, 1, 3)
-            v = v.permute(0, 2, 1, 3)
-            x = flash_attn_func(
-                q,
-                k,
-                v,
-                dropout_p=self.attn_drop.p if self.training else 0.0,
-                softmax_scale=self.scale,
-            )
-        else:
-            x = F.scaled_dot_product_attention(q, k, v)
-
-        x_output_shape = (B, N, C)
-        if not enable_flash_attn:
-            x = x.transpose(1, 2)
-        x = x.reshape(x_output_shape)
-        x = self.proj(x)
-        x = self.proj_drop(x)
-        return x
-
-
-class OpenSoraMultiHeadCrossAttention(nn.Module):
-    def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0, enable_flash_attn=False):
-        super(OpenSoraMultiHeadCrossAttention, self).__init__()
-        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
-
-        self.d_model = d_model
-        self.num_heads = num_heads
-        self.head_dim = d_model // num_heads
-
-        self.q_linear = nn.Linear(d_model, d_model)
-        self.kv_linear = nn.Linear(d_model, d_model * 2)
-        self.attn_drop = nn.Dropout(attn_drop)
-        self.proj = nn.Linear(d_model, d_model)
-        self.proj_drop = nn.Dropout(proj_drop)
-        self.enable_flash_attn = enable_flash_attn
-
-    def forward(self, x, cond, mask=None):
-        # query/value: img tokens; key: condition; mask: if padding tokens
-        B, N, C = x.shape
-
-        q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
-        kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
-        k, v = kv.unbind(2)
-
-        if self.enable_flash_attn:
-            x = self.flash_attn_impl(q, k, v, mask, B, N, C)
-        else:
-            x = self.torch_impl(q, k, v, mask, B, N, C)
-
-        x = self.proj(x)
-        x = self.proj_drop(x)
-        return x
-
-    def flash_attn_impl(self, q, k, v, mask, B, N, C):
-        from flash_attn import flash_attn_varlen_func
-
-        q_seqinfo = _SeqLenInfo.from_seqlens([N] * B)
-        k_seqinfo = _SeqLenInfo.from_seqlens(mask)
-
-        x = flash_attn_varlen_func(
-            q.view(-1, self.num_heads, self.head_dim),
-            k.view(-1, self.num_heads, self.head_dim),
-            v.view(-1, self.num_heads, self.head_dim),
-            cu_seqlens_q=q_seqinfo.seqstart.cuda(),
-            cu_seqlens_k=k_seqinfo.seqstart.cuda(),
-            max_seqlen_q=q_seqinfo.max_seqlen,
-            max_seqlen_k=k_seqinfo.max_seqlen,
-            dropout_p=self.attn_drop.p if self.training else 0.0,
-        )
-        x = x.view(B, N, C)
-        return x
-
-    def torch_impl(self, q, k, v, mask, B, N, C):
-        q = q.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
-        k = k.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
-        v = v.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
-
-        attn_mask = torch.zeros(B, 1, N, k.shape[2], dtype=torch.bool, device=q.device)
-        for i, m in enumerate(mask):
-            attn_mask[i, :, :, :m] = -1e9
-
-        out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
-        x = out.transpose(1, 2).contiguous().view(B, N, C)
-        return x
-
-
-@dataclass
-class _SeqLenInfo:
-    """
-    from xformers
-
-    (Internal) Represents the division of a dimension into blocks.
-    For example, to represents a dimension of length 7 divided into
-    three blocks of lengths 2, 3 and 2, use `from_seqlength([2, 3, 2])`.
-    The members will be:
-        max_seqlen: 3
-        min_seqlen: 2
-        seqstart_py: [0, 2, 5, 7]
-        seqstart: torch.IntTensor([0, 2, 5, 7])
-    """
-
-    seqstart: torch.Tensor
-    max_seqlen: int
-    min_seqlen: int
-    seqstart_py: List[int]
-
-    def to(self, device: torch.device) -> None:
-        self.seqstart = self.seqstart.to(device, non_blocking=True)
-
-    def intervals(self) -> Iterable[Tuple[int, int]]:
-        yield from zip(self.seqstart_py, self.seqstart_py[1:])
-
-    @classmethod
-    def from_seqlens(cls, seqlens: Iterable[int]) -> "_SeqLenInfo":
-        """
-        Input tensors are assumed to be in shape [B, M, *]
-        """
-        assert not isinstance(seqlens, torch.Tensor)
-        seqstart_py = [0]
-        max_seqlen = -1
-        min_seqlen = -1
-        for seqlen in seqlens:
-            min_seqlen = min(min_seqlen, seqlen) if min_seqlen != -1 else seqlen
-            max_seqlen = max(max_seqlen, seqlen)
-            seqstart_py.append(seqstart_py[len(seqstart_py) - 1] + seqlen)
-        seqstart = torch.tensor(seqstart_py, dtype=torch.int32)
-        return cls(
-            max_seqlen=max_seqlen,
-            min_seqlen=min_seqlen,
-            seqstart=seqstart,
-            seqstart_py=seqstart_py,
-        )
diff --git a/videosys/modules/normalization.py b/videosys/modules/normalization.py
index 7985e56..216d0cc 100644
--- a/videosys/modules/normalization.py
+++ b/videosys/modules/normalization.py
@@ -4,23 +4,6 @@ import torch
 import torch.nn as nn
 
 
-class LlamaRMSNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
-        """
-        LlamaRMSNorm is equivalent to T5LayerNorm
-        """
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.variance_epsilon = eps
-
-    def forward(self, hidden_states):
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-        return self.weight * hidden_states.to(input_dtype)
-
-
 class CogVideoXLayerNormZero(nn.Module):
     def __init__(
         self,
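The deleted OpenSoraMultiHeadCrossAttention masked out padded condition tokens before attention. Below is a generic sketch of that idea with a boolean key-padding mask for F.scaled_dot_product_attention (True = may attend); it illustrates the masking pattern, not the removed implementation verbatim, and all shapes are made up.

import torch
import torch.nn.functional as F

B, H, N, M, D = 2, 8, 16, 12, 64          # batch, heads, query tokens, key tokens, head dim
q, k, v = (torch.randn(B, H, L, D) for L in (N, M, M))
valid_kv_lens = [12, 7]                   # real (non-padded) condition tokens per sample

attn_mask = torch.zeros(B, 1, N, M, dtype=torch.bool)   # broadcasts over heads
for i, m in enumerate(valid_kv_lens):
    attn_mask[i, :, :, :m] = True         # only attend to the real keys

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
assert out.shape == (B, H, N, D)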
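For completeness, the computation performed by the removed LlamaRMSNorm, written as a small functional sketch: statistics are taken in float32 and the result is cast back to the input dtype, as in the deleted module. Shapes and dtypes here are illustrative only.

import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    input_dtype = x.dtype
    x = x.to(torch.float32)
    variance = x.pow(2).mean(-1, keepdim=True)      # mean of squares over the hidden dim
    x = x * torch.rsqrt(variance + eps)
    return weight * x.to(input_dtype)

x = torch.randn(2, 4, 8, dtype=torch.float16)
y = rms_norm(x, torch.ones(8, dtype=torch.float16))
assert y.shape == x.shape and y.dtype == torch.float16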