mirror of
https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2025-12-08 20:34:23 +08:00
initial experimental PAB support (only normal text2vid for now)
This commit is contained in:
parent
3c8e939f8e
commit
93edf24631
102
nodes.py
102
nodes.py
@ -15,7 +15,6 @@ from diffusers.schedulers import (
|
||||
HeunDiscreteScheduler,
|
||||
SASolverScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DDIMInverseScheduler
|
||||
)
|
||||
|
||||
scheduler_mapping = {
|
||||
@ -50,6 +49,88 @@ log = logging.getLogger(__name__)
|
||||
|
||||
script_directory = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
class PABConfig:
|
||||
def __init__(
|
||||
self,
|
||||
steps: int,
|
||||
cross_broadcast: bool = False,
|
||||
cross_threshold: list = None,
|
||||
cross_range: int = None,
|
||||
spatial_broadcast: bool = False,
|
||||
spatial_threshold: list = None,
|
||||
spatial_range: int = None,
|
||||
temporal_broadcast: bool = False,
|
||||
temporal_threshold: list = None,
|
||||
temporal_range: int = None,
|
||||
mlp_broadcast: bool = False,
|
||||
mlp_spatial_broadcast_config: dict = None,
|
||||
mlp_temporal_broadcast_config: dict = None,
|
||||
):
|
||||
self.steps = steps
|
||||
|
||||
self.cross_broadcast = cross_broadcast
|
||||
self.cross_threshold = cross_threshold
|
||||
self.cross_range = cross_range
|
||||
|
||||
self.spatial_broadcast = spatial_broadcast
|
||||
self.spatial_threshold = spatial_threshold
|
||||
self.spatial_range = spatial_range
|
||||
|
||||
self.temporal_broadcast = temporal_broadcast
|
||||
self.temporal_threshold = temporal_threshold
|
||||
self.temporal_range = temporal_range
|
||||
|
||||
self.mlp_broadcast = mlp_broadcast
|
||||
self.mlp_spatial_broadcast_config = mlp_spatial_broadcast_config
|
||||
self.mlp_temporal_broadcast_config = mlp_temporal_broadcast_config
|
||||
self.mlp_temporal_outputs = {}
|
||||
self.mlp_spatial_outputs = {}
|
||||
|
||||
class CogVideoXPABConfig(PABConfig):
|
||||
def __init__(
|
||||
self,
|
||||
steps: int = 50,
|
||||
spatial_broadcast: bool = True,
|
||||
spatial_threshold: list = [100, 850],
|
||||
spatial_range: int = 2,
|
||||
):
|
||||
super().__init__(
|
||||
steps=steps,
|
||||
spatial_broadcast=spatial_broadcast,
|
||||
spatial_threshold=spatial_threshold,
|
||||
spatial_range=spatial_range,
|
||||
)
|
||||
|
||||
from .videosys.cogvideox_transformer_3d import CogVideoXTransformer3DModel as CogVideoXTransformer3DModelPAB
|
||||
|
||||
class CogVideoPABConfig:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"spatial_broadcast": ("BOOLEAN", {"default": True, "tooltip": "Enable Spatial PAB"}),
|
||||
"pab_threshold_start": ("INT", {"default": 850, "min": 0, "max": 1000, "tooltip": "PAB Start Timestep"} ),
|
||||
"pab_threshold_end": ("INT", {"default": 100, "min": 0, "max": 1000, "tooltip": "PAB End Timestep"} ),
|
||||
"pab_range": ("INT", {"default": 2, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range"} ),
|
||||
"steps": ("INT", {"default": 50, "min": 0, "max": 1000, "tooltip": "Steps"} ),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("PAB_CONFIG",)
|
||||
RETURN_NAMES = ("pab_config", )
|
||||
FUNCTION = "config"
|
||||
CATEGORY = "CogVideoWrapper"
|
||||
|
||||
def config(self, spatial_broadcast, pab_threshold_start, pab_threshold_end, pab_range, steps):
|
||||
|
||||
pab_config = CogVideoXPABConfig(
|
||||
steps=steps,
|
||||
spatial_broadcast=spatial_broadcast,
|
||||
spatial_threshold=[pab_threshold_end, pab_threshold_start],
|
||||
spatial_range=pab_range
|
||||
)
|
||||
|
||||
return (pab_config, )
|
||||
|
||||
class DownloadAndLoadCogVideoModel:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@ -74,6 +155,7 @@ class DownloadAndLoadCogVideoModel:
|
||||
"fp8_transformer": (['disabled', 'enabled', 'fastmode'], {"default": 'disabled', "tooltip": "enabled casts the transformer to torch.float8_e4m3fn, fastmode is only for latest nvidia GPUs"}),
|
||||
"compile": (["disabled","onediff","torch"], {"tooltip": "compile the model for faster inference, these are advanced options only available on Linux, see readme for more info"}),
|
||||
"enable_sequential_cpu_offload": ("BOOLEAN", {"default": False, "tooltip": "significantly reducing memory usage and slows down the inference"}),
|
||||
"pab_config": ("PAB_CONFIG", {"default": None}),
|
||||
}
|
||||
}
|
||||
|
||||
@ -82,7 +164,7 @@ class DownloadAndLoadCogVideoModel:
|
||||
FUNCTION = "loadmodel"
|
||||
CATEGORY = "CogVideoWrapper"
|
||||
|
||||
def loadmodel(self, model, precision, fp8_transformer="disabled", compile="disabled", enable_sequential_cpu_offload=False):
|
||||
def loadmodel(self, model, precision, fp8_transformer="disabled", compile="disabled", enable_sequential_cpu_offload=False, pab_config=None):
|
||||
device = mm.get_torch_device()
|
||||
offload_device = mm.unet_offload_device()
|
||||
mm.soft_empty_cache()
|
||||
@ -129,7 +211,10 @@ class DownloadAndLoadCogVideoModel:
|
||||
if "Fun" in model:
|
||||
transformer = CogVideoXTransformer3DModelFun.from_pretrained(base_path, subfolder="transformer")
|
||||
else:
|
||||
transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder="transformer")
|
||||
if pab_config is not None:
|
||||
transformer = CogVideoXTransformer3DModelPAB.from_pretrained(base_path, subfolder="transformer")
|
||||
else:
|
||||
transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder="transformer")
|
||||
|
||||
transformer = transformer.to(dtype).to(offload_device)
|
||||
|
||||
@ -151,7 +236,7 @@ class DownloadAndLoadCogVideoModel:
|
||||
|
||||
with open(scheduler_path) as f:
|
||||
scheduler_config = json.load(f)
|
||||
scheduler = CogVideoXDDIMScheduler.from_config(scheduler_config)
|
||||
scheduler = CogVideoXDDIMScheduler.from_config(scheduler_config)
|
||||
|
||||
# VAE
|
||||
if "Fun" in model:
|
||||
@ -159,7 +244,7 @@ class DownloadAndLoadCogVideoModel:
|
||||
pipe = CogVideoX_Fun_Pipeline_Inpaint(vae, transformer, scheduler)
|
||||
else:
|
||||
vae = AutoencoderKLCogVideoX.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)
|
||||
pipe = CogVideoXPipeline(vae, transformer, scheduler)
|
||||
pipe = CogVideoXPipeline(vae, transformer, scheduler, pab_config=pab_config)
|
||||
|
||||
if enable_sequential_cpu_offload:
|
||||
pipe.enable_sequential_cpu_offload()
|
||||
@ -729,7 +814,6 @@ class CogVideoDecode:
|
||||
|
||||
video = pipeline["pipe"].video_processor.postprocess_video(video=frames, output_type="pt")
|
||||
video = video[0].permute(0, 2, 3, 1).cpu().float()
|
||||
print(video.min(), video.max())
|
||||
|
||||
return (video,)
|
||||
|
||||
@ -979,7 +1063,8 @@ NODE_CLASS_MAPPINGS = {
|
||||
"CogVideoXFunSampler": CogVideoXFunSampler,
|
||||
"CogVideoXFunVid2VidSampler": CogVideoXFunVid2VidSampler,
|
||||
"CogVideoTextEncodeCombine": CogVideoTextEncodeCombine,
|
||||
"DownloadAndLoadCogVideoGGUFModel": DownloadAndLoadCogVideoGGUFModel
|
||||
"DownloadAndLoadCogVideoGGUFModel": DownloadAndLoadCogVideoGGUFModel,
|
||||
"CogVideoPABConfig": CogVideoPABConfig
|
||||
}
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"DownloadAndLoadCogVideoModel": "(Down)load CogVideo Model",
|
||||
@ -991,5 +1076,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"CogVideoXFunSampler": "CogVideoXFun Sampler",
|
||||
"CogVideoXFunVid2VidSampler": "CogVideoXFun Vid2Vid Sampler",
|
||||
"CogVideoTextEncodeCombine": "CogVideo TextEncode Combine",
|
||||
"DownloadAndLoadCogVideoGGUFModel": "(Down)load CogVideo GGUF Model"
|
||||
"DownloadAndLoadCogVideoGGUFModel": "(Down)load CogVideo GGUF Model",
|
||||
"CogVideoPABConfig": "CogVideo PABConfig"
|
||||
}
|
||||
|
||||
@ -32,6 +32,10 @@ from comfy.utils import ProgressBar
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
from .videosys.core.pipeline import VideoSysPipeline
|
||||
from .videosys.cogvideox_transformer_3d import CogVideoXTransformer3DModel as CogVideoXTransformer3DModelPAB
|
||||
from .videosys.core.pab_mgr import set_pab_manager
|
||||
|
||||
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
|
||||
tw = tgt_width
|
||||
th = tgt_height
|
||||
@ -108,7 +112,7 @@ def retrieve_timesteps(
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
class CogVideoXPipeline(DiffusionPipeline):
|
||||
class CogVideoXPipeline(VideoSysPipeline):
|
||||
r"""
|
||||
Pipeline for text-to-video generation using CogVideoX.
|
||||
|
||||
@ -137,9 +141,10 @@ class CogVideoXPipeline(DiffusionPipeline):
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKLCogVideoX,
|
||||
transformer: CogVideoXTransformer3DModel,
|
||||
transformer: Union[CogVideoXTransformer3DModel, CogVideoXTransformer3DModelPAB],
|
||||
scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
|
||||
original_mask = None
|
||||
original_mask = None,
|
||||
pab_config = None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -155,6 +160,10 @@ class CogVideoXPipeline(DiffusionPipeline):
|
||||
self.original_mask = original_mask
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
||||
|
||||
if pab_config is not None:
|
||||
print(pab_config)
|
||||
set_pab_manager(pab_config)
|
||||
|
||||
def prepare_latents(
|
||||
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, timesteps, denoise_strength, num_inference_steps, latents=None,
|
||||
):
|
||||
|
||||
604
videosys/cogvideox_transformer_3d.py
Normal file
604
videosys/cogvideox_transformer_3d.py
Normal file
@ -0,0 +1,604 @@
|
||||
# Adapted from CogVideo
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# --------------------------------------------------------
|
||||
# References:
|
||||
# CogVideo: https://github.com/THUDM/CogVideo
|
||||
# diffusers: https://github.com/huggingface/diffusers
|
||||
# --------------------------------------------------------
|
||||
|
||||
from functools import partial
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.models.attention import Attention, FeedForward
|
||||
from diffusers.models.embeddings import TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed
|
||||
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
from diffusers.utils import is_torch_version
|
||||
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
||||
from torch import nn
|
||||
|
||||
from .core.comm import all_to_all_comm, gather_sequence, get_pad, set_pad, split_sequence
|
||||
from .core.pab_mgr import enable_pab, if_broadcast_spatial
|
||||
#from .core.parallel_mgr import ParallelManager
|
||||
from .modules.embeddings import apply_rotary_emb
|
||||
from .utils.utils import batch_func
|
||||
|
||||
from .modules.embeddings import CogVideoXPatchEmbed
|
||||
from .modules.normalization import AdaLayerNorm, CogVideoXLayerNormZero
|
||||
|
||||
|
||||
class CogVideoXAttnProcessor2_0:
|
||||
r"""
|
||||
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
|
||||
query and key vectors, but does not include spatial normalization.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
text_seq_length = encoder_hidden_states.size(1)
|
||||
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
|
||||
# if attn.parallel_manager.sp_size > 1:
|
||||
# assert (
|
||||
# attn.heads % attn.parallel_manager.sp_size == 0
|
||||
# ), f"Number of heads {attn.heads} must be divisible by sequence parallel size {attn.parallel_manager.sp_size}"
|
||||
# attn_heads = attn.heads // attn.parallel_manager.sp_size
|
||||
# query, key, value = map(
|
||||
# lambda x: all_to_all_comm(x, attn.parallel_manager.sp_group, scatter_dim=2, gather_dim=1),
|
||||
# [query, key, value],
|
||||
# )
|
||||
|
||||
attn_heads = attn.heads
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn_heads
|
||||
|
||||
query = query.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
|
||||
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
# Apply RoPE if needed
|
||||
if image_rotary_emb is not None:
|
||||
emb_len = image_rotary_emb[0].shape[0]
|
||||
query[:, :, text_seq_length : emb_len + text_seq_length] = apply_rotary_emb(
|
||||
query[:, :, text_seq_length : emb_len + text_seq_length], image_rotary_emb
|
||||
)
|
||||
if not attn.is_cross_attention:
|
||||
key[:, :, text_seq_length : emb_len + text_seq_length] = apply_rotary_emb(
|
||||
key[:, :, text_seq_length : emb_len + text_seq_length], image_rotary_emb
|
||||
)
|
||||
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn_heads * head_dim)
|
||||
|
||||
#if attn.parallel_manager.sp_size > 1:
|
||||
# hidden_states = all_to_all_comm(hidden_states, attn.parallel_manager.sp_group, scatter_dim=1, gather_dim=2)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
encoder_hidden_states, hidden_states = hidden_states.split(
|
||||
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
|
||||
)
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class FusedCogVideoXAttnProcessor2_0:
|
||||
r"""
|
||||
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
|
||||
query and key vectors, but does not include spatial normalization.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
text_seq_length = encoder_hidden_states.size(1)
|
||||
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
qkv = attn.to_qkv(hidden_states)
|
||||
split_size = qkv.shape[-1] // 3
|
||||
query, key, value = torch.split(qkv, split_size, dim=-1)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
# Apply RoPE if needed
|
||||
if image_rotary_emb is not None:
|
||||
query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
|
||||
if not attn.is_cross_attention:
|
||||
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
|
||||
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
encoder_hidden_states, hidden_states = hidden_states.split(
|
||||
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
|
||||
)
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class CogVideoXBlock(nn.Module):
|
||||
r"""
|
||||
Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
|
||||
|
||||
Parameters:
|
||||
dim (`int`):
|
||||
The number of channels in the input and output.
|
||||
num_attention_heads (`int`):
|
||||
The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`):
|
||||
The number of channels in each head.
|
||||
time_embed_dim (`int`):
|
||||
The number of channels in timestep embedding.
|
||||
dropout (`float`, defaults to `0.0`):
|
||||
The dropout probability to use.
|
||||
activation_fn (`str`, defaults to `"gelu-approximate"`):
|
||||
Activation function to be used in feed-forward.
|
||||
attention_bias (`bool`, defaults to `False`):
|
||||
Whether or not to use bias in attention projection layers.
|
||||
qk_norm (`bool`, defaults to `True`):
|
||||
Whether or not to use normalization after query and key projections in Attention.
|
||||
norm_elementwise_affine (`bool`, defaults to `True`):
|
||||
Whether to use learnable elementwise affine parameters for normalization.
|
||||
norm_eps (`float`, defaults to `1e-5`):
|
||||
Epsilon value for normalization layers.
|
||||
final_dropout (`bool` defaults to `False`):
|
||||
Whether to apply a final dropout after the last feed-forward layer.
|
||||
ff_inner_dim (`int`, *optional*, defaults to `None`):
|
||||
Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
|
||||
ff_bias (`bool`, defaults to `True`):
|
||||
Whether or not to use bias in Feed-forward layer.
|
||||
attention_out_bias (`bool`, defaults to `True`):
|
||||
Whether or not to use bias in Attention output projection layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
time_embed_dim: int,
|
||||
dropout: float = 0.0,
|
||||
activation_fn: str = "gelu-approximate",
|
||||
attention_bias: bool = False,
|
||||
qk_norm: bool = True,
|
||||
norm_elementwise_affine: bool = True,
|
||||
norm_eps: float = 1e-5,
|
||||
final_dropout: bool = True,
|
||||
ff_inner_dim: Optional[int] = None,
|
||||
ff_bias: bool = True,
|
||||
attention_out_bias: bool = True,
|
||||
block_idx: int = 0,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# 1. Self Attention
|
||||
self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
|
||||
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
qk_norm="layer_norm" if qk_norm else None,
|
||||
eps=1e-6,
|
||||
bias=attention_bias,
|
||||
out_bias=attention_out_bias,
|
||||
processor=CogVideoXAttnProcessor2_0(),
|
||||
)
|
||||
|
||||
# parallel
|
||||
#self.attn1.parallel_manager = None
|
||||
|
||||
# 2. Feed Forward
|
||||
self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
|
||||
|
||||
self.ff = FeedForward(
|
||||
dim,
|
||||
dropout=dropout,
|
||||
activation_fn=activation_fn,
|
||||
final_dropout=final_dropout,
|
||||
inner_dim=ff_inner_dim,
|
||||
bias=ff_bias,
|
||||
)
|
||||
|
||||
# pab
|
||||
self.attn_count = 0
|
||||
self.last_attn = None
|
||||
self.block_idx = block_idx
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
timestep=None,
|
||||
) -> torch.Tensor:
|
||||
text_seq_length = encoder_hidden_states.size(1)
|
||||
|
||||
# norm & modulate
|
||||
norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
|
||||
hidden_states, encoder_hidden_states, temb
|
||||
)
|
||||
|
||||
# attention
|
||||
if enable_pab():
|
||||
broadcast_attn, self.attn_count = if_broadcast_spatial(int(timestep[0]), self.attn_count, self.block_idx)
|
||||
if enable_pab() and broadcast_attn:
|
||||
attn_hidden_states, attn_encoder_hidden_states = self.last_attn
|
||||
else:
|
||||
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
if enable_pab():
|
||||
self.last_attn = (attn_hidden_states, attn_encoder_hidden_states)
|
||||
|
||||
hidden_states = hidden_states + gate_msa * attn_hidden_states
|
||||
encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
|
||||
|
||||
# norm & modulate
|
||||
norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
|
||||
hidden_states, encoder_hidden_states, temb
|
||||
)
|
||||
|
||||
# feed-forward
|
||||
norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
|
||||
hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
|
||||
encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
|
||||
|
||||
Parameters:
|
||||
num_attention_heads (`int`, defaults to `30`):
|
||||
The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, defaults to `64`):
|
||||
The number of channels in each head.
|
||||
in_channels (`int`, defaults to `16`):
|
||||
The number of channels in the input.
|
||||
out_channels (`int`, *optional*, defaults to `16`):
|
||||
The number of channels in the output.
|
||||
flip_sin_to_cos (`bool`, defaults to `True`):
|
||||
Whether to flip the sin to cos in the time embedding.
|
||||
time_embed_dim (`int`, defaults to `512`):
|
||||
Output dimension of timestep embeddings.
|
||||
text_embed_dim (`int`, defaults to `4096`):
|
||||
Input dimension of text embeddings from the text encoder.
|
||||
num_layers (`int`, defaults to `30`):
|
||||
The number of layers of Transformer blocks to use.
|
||||
dropout (`float`, defaults to `0.0`):
|
||||
The dropout probability to use.
|
||||
attention_bias (`bool`, defaults to `True`):
|
||||
Whether or not to use bias in the attention projection layers.
|
||||
sample_width (`int`, defaults to `90`):
|
||||
The width of the input latents.
|
||||
sample_height (`int`, defaults to `60`):
|
||||
The height of the input latents.
|
||||
sample_frames (`int`, defaults to `49`):
|
||||
The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
|
||||
instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings,
|
||||
but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
|
||||
K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
|
||||
patch_size (`int`, defaults to `2`):
|
||||
The size of the patches to use in the patch embedding layer.
|
||||
temporal_compression_ratio (`int`, defaults to `4`):
|
||||
The compression ratio across the temporal dimension. See documentation for `sample_frames`.
|
||||
max_text_seq_length (`int`, defaults to `226`):
|
||||
The maximum sequence length of the input text embeddings.
|
||||
activation_fn (`str`, defaults to `"gelu-approximate"`):
|
||||
Activation function to use in feed-forward.
|
||||
timestep_activation_fn (`str`, defaults to `"silu"`):
|
||||
Activation function to use when generating the timestep embeddings.
|
||||
norm_elementwise_affine (`bool`, defaults to `True`):
|
||||
Whether or not to use elementwise affine in normalization layers.
|
||||
norm_eps (`float`, defaults to `1e-5`):
|
||||
The epsilon value to use in normalization layers.
|
||||
spatial_interpolation_scale (`float`, defaults to `1.875`):
|
||||
Scaling factor to apply in 3D positional embeddings across spatial dimensions.
|
||||
temporal_interpolation_scale (`float`, defaults to `1.0`):
|
||||
Scaling factor to apply in 3D positional embeddings across temporal dimensions.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 30,
|
||||
attention_head_dim: int = 64,
|
||||
in_channels: int = 16,
|
||||
out_channels: Optional[int] = 16,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
time_embed_dim: int = 512,
|
||||
text_embed_dim: int = 4096,
|
||||
num_layers: int = 30,
|
||||
dropout: float = 0.0,
|
||||
attention_bias: bool = True,
|
||||
sample_width: int = 90,
|
||||
sample_height: int = 60,
|
||||
sample_frames: int = 49,
|
||||
patch_size: int = 2,
|
||||
temporal_compression_ratio: int = 4,
|
||||
max_text_seq_length: int = 226,
|
||||
activation_fn: str = "gelu-approximate",
|
||||
timestep_activation_fn: str = "silu",
|
||||
norm_elementwise_affine: bool = True,
|
||||
norm_eps: float = 1e-5,
|
||||
spatial_interpolation_scale: float = 1.875,
|
||||
temporal_interpolation_scale: float = 1.0,
|
||||
use_rotary_positional_embeddings: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
post_patch_height = sample_height // patch_size
|
||||
post_patch_width = sample_width // patch_size
|
||||
post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1
|
||||
self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames
|
||||
|
||||
# 1. Patch embedding
|
||||
self.patch_embed = CogVideoXPatchEmbed(patch_size, in_channels, inner_dim, text_embed_dim, bias=True)
|
||||
self.embedding_dropout = nn.Dropout(dropout)
|
||||
|
||||
# 2. 3D positional embeddings
|
||||
spatial_pos_embedding = get_3d_sincos_pos_embed(
|
||||
inner_dim,
|
||||
(post_patch_width, post_patch_height),
|
||||
post_time_compression_frames,
|
||||
spatial_interpolation_scale,
|
||||
temporal_interpolation_scale,
|
||||
)
|
||||
spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1)
|
||||
pos_embedding = torch.zeros(1, max_text_seq_length + self.num_patches, inner_dim, requires_grad=False)
|
||||
pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding)
|
||||
self.register_buffer("pos_embedding", pos_embedding, persistent=False)
|
||||
|
||||
# 3. Time embeddings
|
||||
self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
|
||||
self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
|
||||
|
||||
# 4. Define spatio-temporal transformers blocks
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
CogVideoXBlock(
|
||||
dim=inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
time_embed_dim=time_embed_dim,
|
||||
dropout=dropout,
|
||||
activation_fn=activation_fn,
|
||||
attention_bias=attention_bias,
|
||||
norm_elementwise_affine=norm_elementwise_affine,
|
||||
norm_eps=norm_eps,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
|
||||
|
||||
# 5. Output blocks
|
||||
self.norm_out = AdaLayerNorm(
|
||||
embedding_dim=time_embed_dim,
|
||||
output_dim=2 * inner_dim,
|
||||
norm_elementwise_affine=norm_elementwise_affine,
|
||||
norm_eps=norm_eps,
|
||||
chunk_dim=1,
|
||||
)
|
||||
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
# parallel
|
||||
#self.parallel_manager = None
|
||||
|
||||
# def enable_parallel(self, dp_size, sp_size, enable_cp):
|
||||
# # update cfg parallel
|
||||
# if enable_cp and sp_size % 2 == 0:
|
||||
# sp_size = sp_size // 2
|
||||
# cp_size = 2
|
||||
# else:
|
||||
# cp_size = 1
|
||||
|
||||
# self.parallel_manager: ParallelManager = ParallelManager(dp_size, cp_size, sp_size)
|
||||
|
||||
# for _, module in self.named_modules():
|
||||
# if hasattr(module, "parallel_manager"):
|
||||
# module.parallel_manager = self.parallel_manager
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
self.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
timestep: Union[int, float, torch.LongTensor],
|
||||
timestep_cond: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
# if self.parallel_manager.cp_size > 1:
|
||||
# (
|
||||
# hidden_states,
|
||||
# encoder_hidden_states,
|
||||
# timestep,
|
||||
# timestep_cond,
|
||||
# image_rotary_emb,
|
||||
# ) = batch_func(
|
||||
# partial(split_sequence, process_group=self.parallel_manager.cp_group, dim=0),
|
||||
# hidden_states,
|
||||
# encoder_hidden_states,
|
||||
# timestep,
|
||||
# timestep_cond,
|
||||
# image_rotary_emb,
|
||||
# )
|
||||
|
||||
batch_size, num_frames, channels, height, width = hidden_states.shape
|
||||
|
||||
# 1. Time embedding
|
||||
timesteps = timestep
|
||||
t_emb = self.time_proj(timesteps)
|
||||
|
||||
# timesteps does not contain any weights and will always return f32 tensors
|
||||
# but time_embedding might actually be running in fp16. so we need to cast here.
|
||||
# there might be better ways to encapsulate this.
|
||||
t_emb = t_emb.to(dtype=hidden_states.dtype)
|
||||
emb = self.time_embedding(t_emb, timestep_cond)
|
||||
|
||||
# 2. Patch embedding
|
||||
hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
|
||||
|
||||
# 3. Position embedding
|
||||
text_seq_length = encoder_hidden_states.shape[1]
|
||||
if not self.config.use_rotary_positional_embeddings:
|
||||
seq_length = height * width * num_frames // (self.config.patch_size**2)
|
||||
|
||||
pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length]
|
||||
hidden_states = hidden_states + pos_embeds
|
||||
hidden_states = self.embedding_dropout(hidden_states)
|
||||
|
||||
encoder_hidden_states = hidden_states[:, :text_seq_length]
|
||||
hidden_states = hidden_states[:, text_seq_length:]
|
||||
|
||||
# if self.parallel_manager.sp_size > 1:
|
||||
# set_pad("pad", hidden_states.shape[1], self.parallel_manager.sp_group)
|
||||
# hidden_states = split_sequence(hidden_states, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad"))
|
||||
|
||||
# 4. Transformer blocks
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
emb,
|
||||
image_rotary_emb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
else:
|
||||
hidden_states, encoder_hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=emb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
timestep=timesteps if enable_pab() else None,
|
||||
)
|
||||
|
||||
#if self.parallel_manager.sp_size > 1:
|
||||
# hidden_states = gather_sequence(hidden_states, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad"))
|
||||
|
||||
if not self.config.use_rotary_positional_embeddings:
|
||||
# CogVideoX-2B
|
||||
hidden_states = self.norm_final(hidden_states)
|
||||
else:
|
||||
# CogVideoX-5B
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
hidden_states = self.norm_final(hidden_states)
|
||||
hidden_states = hidden_states[:, text_seq_length:]
|
||||
|
||||
# 5. Final block
|
||||
hidden_states = self.norm_out(hidden_states, temb=emb)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
# 6. Unpatchify
|
||||
p = self.config.patch_size
|
||||
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
|
||||
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
|
||||
|
||||
#if self.parallel_manager.cp_size > 1:
|
||||
# output = gather_sequence(output, self.parallel_manager.cp_group, dim=0)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
0
videosys/core/__init__.py
Normal file
0
videosys/core/__init__.py
Normal file
406
videosys/core/comm.py
Normal file
406
videosys/core/comm.py
Normal file
@ -0,0 +1,406 @@
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from torch import Tensor
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
# ======================================================
|
||||
# Model
|
||||
# ======================================================
|
||||
|
||||
|
||||
def model_sharding(model: torch.nn.Module):
|
||||
global_rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
for _, param in model.named_parameters():
|
||||
padding_size = (world_size - param.numel() % world_size) % world_size
|
||||
if padding_size > 0:
|
||||
padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
|
||||
else:
|
||||
padding_param = param.data.view(-1)
|
||||
splited_params = padding_param.split(padding_param.numel() // world_size)
|
||||
splited_params = splited_params[global_rank]
|
||||
param.data = splited_params
|
||||
|
||||
|
||||
# ======================================================
|
||||
# AllGather & ReduceScatter
|
||||
# ======================================================
|
||||
|
||||
|
||||
class AsyncAllGatherForTwo(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx: Any,
|
||||
inputs: Tensor,
|
||||
weight: Tensor,
|
||||
bias: Tensor,
|
||||
sp_rank: int,
|
||||
sp_size: int,
|
||||
group: Optional[ProcessGroup] = None,
|
||||
) -> Tuple[Tensor, Any]:
|
||||
"""
|
||||
Returns:
|
||||
outputs: Tensor
|
||||
handle: Optional[Work], if overlap is True
|
||||
"""
|
||||
from torch.distributed._functional_collectives import all_gather_tensor
|
||||
|
||||
ctx.group = group
|
||||
ctx.sp_rank = sp_rank
|
||||
ctx.sp_size = sp_size
|
||||
|
||||
# all gather inputs
|
||||
all_inputs = all_gather_tensor(inputs.unsqueeze(0), 0, group)
|
||||
# compute local qkv
|
||||
local_qkv = F.linear(inputs, weight, bias).unsqueeze(0)
|
||||
|
||||
# remote compute
|
||||
remote_inputs = all_inputs[1 - sp_rank].view(list(local_qkv.shape[:-1]) + [-1])
|
||||
# compute remote qkv
|
||||
remote_qkv = F.linear(remote_inputs, weight, bias)
|
||||
|
||||
# concat local and remote qkv
|
||||
if sp_rank == 0:
|
||||
qkv = torch.cat([local_qkv, remote_qkv], dim=0)
|
||||
else:
|
||||
qkv = torch.cat([remote_qkv, local_qkv], dim=0)
|
||||
qkv = rearrange(qkv, "sp b n c -> b (sp n) c")
|
||||
|
||||
ctx.save_for_backward(inputs, weight, remote_inputs)
|
||||
return qkv
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
|
||||
from torch.distributed._functional_collectives import reduce_scatter_tensor
|
||||
|
||||
group = ctx.group
|
||||
sp_rank = ctx.sp_rank
|
||||
sp_size = ctx.sp_size
|
||||
inputs, weight, remote_inputs = ctx.saved_tensors
|
||||
|
||||
# split qkv_grad
|
||||
qkv_grad = grad_outputs[0]
|
||||
qkv_grad = rearrange(qkv_grad, "b (sp n) c -> sp b n c", sp=sp_size)
|
||||
qkv_grad = torch.chunk(qkv_grad, 2, dim=0)
|
||||
if sp_rank == 0:
|
||||
local_qkv_grad, remote_qkv_grad = qkv_grad
|
||||
else:
|
||||
remote_qkv_grad, local_qkv_grad = qkv_grad
|
||||
|
||||
# compute remote grad
|
||||
remote_inputs_grad = torch.matmul(remote_qkv_grad, weight).squeeze(0)
|
||||
weight_grad = torch.matmul(remote_qkv_grad.transpose(-1, -2), remote_inputs).squeeze(0).sum(0)
|
||||
bias_grad = remote_qkv_grad.squeeze(0).sum(0).sum(0)
|
||||
|
||||
# launch async reduce scatter
|
||||
remote_inputs_grad_zero = torch.zeros_like(remote_inputs_grad)
|
||||
if sp_rank == 0:
|
||||
remote_inputs_grad = torch.cat([remote_inputs_grad_zero, remote_inputs_grad], dim=0)
|
||||
else:
|
||||
remote_inputs_grad = torch.cat([remote_inputs_grad, remote_inputs_grad_zero], dim=0)
|
||||
remote_inputs_grad = reduce_scatter_tensor(remote_inputs_grad, "sum", 0, group)
|
||||
|
||||
# compute local grad and wait for reduce scatter
|
||||
local_input_grad = torch.matmul(local_qkv_grad, weight).squeeze(0)
|
||||
weight_grad += torch.matmul(local_qkv_grad.transpose(-1, -2), inputs).squeeze(0).sum(0)
|
||||
bias_grad += local_qkv_grad.squeeze(0).sum(0).sum(0)
|
||||
|
||||
# sum remote and local grad
|
||||
inputs_grad = remote_inputs_grad + local_input_grad
|
||||
return inputs_grad, weight_grad, bias_grad, None, None, None
|
||||
|
||||
|
||||
class AllGather(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx: Any,
|
||||
inputs: Tensor,
|
||||
group: Optional[ProcessGroup] = None,
|
||||
overlap: bool = False,
|
||||
) -> Tuple[Tensor, Any]:
|
||||
"""
|
||||
Returns:
|
||||
outputs: Tensor
|
||||
handle: Optional[Work], if overlap is True
|
||||
"""
|
||||
assert ctx is not None or not overlap
|
||||
|
||||
if ctx is not None:
|
||||
ctx.comm_grp = group
|
||||
|
||||
comm_size = dist.get_world_size(group)
|
||||
if comm_size == 1:
|
||||
return inputs.unsqueeze(0), None
|
||||
|
||||
buffer_shape = (comm_size,) + inputs.shape
|
||||
outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device)
|
||||
buffer_list = list(torch.chunk(outputs, comm_size, dim=0))
|
||||
if not overlap:
|
||||
dist.all_gather(buffer_list, inputs, group=group)
|
||||
return outputs, None
|
||||
else:
|
||||
handle = dist.all_gather(buffer_list, inputs, group=group, async_op=True)
|
||||
return outputs, handle
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
|
||||
return (
|
||||
ReduceScatter.forward(None, grad_outputs[0], ctx.comm_grp, False)[0],
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
class ReduceScatter(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx: Any,
|
||||
inputs: Tensor,
|
||||
group: ProcessGroup,
|
||||
overlap: bool = False,
|
||||
) -> Tuple[Tensor, Any]:
|
||||
"""
|
||||
Returns:
|
||||
outputs: Tensor
|
||||
handle: Optional[Work], if overlap is True
|
||||
"""
|
||||
assert ctx is not None or not overlap
|
||||
|
||||
if ctx is not None:
|
||||
ctx.comm_grp = group
|
||||
|
||||
comm_size = dist.get_world_size(group)
|
||||
if comm_size == 1:
|
||||
return inputs.squeeze(0), None
|
||||
|
||||
if not inputs.is_contiguous():
|
||||
inputs = inputs.contiguous()
|
||||
|
||||
output_shape = inputs.shape[1:]
|
||||
outputs = torch.empty(output_shape, dtype=inputs.dtype, device=inputs.device)
|
||||
buffer_list = list(torch.chunk(inputs, comm_size, dim=0))
|
||||
if not overlap:
|
||||
dist.reduce_scatter(outputs, buffer_list, group=group)
|
||||
return outputs, None
|
||||
else:
|
||||
handle = dist.reduce_scatter(outputs, buffer_list, group=group, async_op=True)
|
||||
return outputs, handle
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
|
||||
# TODO: support async backward
|
||||
return (
|
||||
AllGather.forward(None, grad_outputs[0], ctx.comm_grp, False)[0],
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
# ======================================================
|
||||
# AlltoAll
|
||||
# ======================================================
|
||||
|
||||
|
||||
def _all_to_all_func(input_, world_size, group, scatter_dim, gather_dim):
|
||||
input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
|
||||
output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
|
||||
dist.all_to_all(output_list, input_list, group=group)
|
||||
return torch.cat(output_list, dim=gather_dim).contiguous()
|
||||
|
||||
|
||||
class _AllToAll(torch.autograd.Function):
|
||||
"""All-to-all communication.
|
||||
|
||||
Args:
|
||||
input_: input matrix
|
||||
process_group: communication group
|
||||
scatter_dim: scatter dimension
|
||||
gather_dim: gather dimension
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, process_group, scatter_dim, gather_dim):
|
||||
ctx.process_group = process_group
|
||||
ctx.scatter_dim = scatter_dim
|
||||
ctx.gather_dim = gather_dim
|
||||
world_size = dist.get_world_size(process_group)
|
||||
|
||||
return _all_to_all_func(input_, world_size, process_group, scatter_dim, gather_dim)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *grad_output):
|
||||
process_group = ctx.process_group
|
||||
scatter_dim = ctx.gather_dim
|
||||
gather_dim = ctx.scatter_dim
|
||||
return_grad = _AllToAll.apply(*grad_output, process_group, scatter_dim, gather_dim)
|
||||
return (return_grad, None, None, None)
|
||||
|
||||
|
||||
def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1):
|
||||
return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim)
|
||||
|
||||
|
||||
# ======================================================
|
||||
# Sequence Gather & Split
|
||||
# ======================================================
|
||||
|
||||
|
||||
def _split_sequence_func(input_, pg: dist.ProcessGroup, dim: int, pad: int):
|
||||
# skip if only one rank involved
|
||||
world_size = dist.get_world_size(pg)
|
||||
rank = dist.get_rank(pg)
|
||||
if world_size == 1:
|
||||
return input_
|
||||
|
||||
if pad > 0:
|
||||
pad_size = list(input_.shape)
|
||||
pad_size[dim] = pad
|
||||
input_ = torch.cat([input_, torch.zeros(pad_size, dtype=input_.dtype, device=input_.device)], dim=dim)
|
||||
|
||||
dim_size = input_.size(dim)
|
||||
assert dim_size % world_size == 0, f"dim_size ({dim_size}) is not divisible by world_size ({world_size})"
|
||||
|
||||
tensor_list = torch.split(input_, dim_size // world_size, dim=dim)
|
||||
output = tensor_list[rank].contiguous()
|
||||
return output
|
||||
|
||||
|
||||
def _gather_sequence_func(input_, pg: dist.ProcessGroup, dim: int, pad: int):
|
||||
# skip if only one rank involved
|
||||
input_ = input_.contiguous()
|
||||
world_size = dist.get_world_size(pg)
|
||||
dist.get_rank(pg)
|
||||
|
||||
if world_size == 1:
|
||||
return input_
|
||||
|
||||
# all gather
|
||||
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
|
||||
assert input_.device.type == "cuda"
|
||||
torch.distributed.all_gather(tensor_list, input_, group=pg)
|
||||
|
||||
# concat
|
||||
output = torch.cat(tensor_list, dim=dim)
|
||||
|
||||
if pad > 0:
|
||||
output = output.narrow(dim, 0, output.size(dim) - pad)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class _GatherForwardSplitBackward(torch.autograd.Function):
|
||||
"""
|
||||
Gather the input sequence.
|
||||
|
||||
Args:
|
||||
input_: input matrix.
|
||||
process_group: process group.
|
||||
dim: dimension
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def symbolic(graph, input_):
|
||||
return _gather_sequence_func(input_)
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, process_group, dim, grad_scale, pad):
|
||||
ctx.process_group = process_group
|
||||
ctx.dim = dim
|
||||
ctx.grad_scale = grad_scale
|
||||
ctx.pad = pad
|
||||
return _gather_sequence_func(input_, process_group, dim, pad)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
if ctx.grad_scale == "up":
|
||||
grad_output = grad_output * dist.get_world_size(ctx.process_group)
|
||||
elif ctx.grad_scale == "down":
|
||||
grad_output = grad_output / dist.get_world_size(ctx.process_group)
|
||||
|
||||
return _split_sequence_func(grad_output, ctx.process_group, ctx.dim, ctx.pad), None, None, None, None
|
||||
|
||||
|
||||
class _SplitForwardGatherBackward(torch.autograd.Function):
|
||||
"""
|
||||
Split sequence.
|
||||
|
||||
Args:
|
||||
input_: input matrix.
|
||||
process_group: parallel mode.
|
||||
dim: dimension
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def symbolic(graph, input_):
|
||||
return _split_sequence_func(input_)
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, process_group, dim, grad_scale, pad):
|
||||
ctx.process_group = process_group
|
||||
ctx.dim = dim
|
||||
ctx.grad_scale = grad_scale
|
||||
ctx.pad = pad
|
||||
return _split_sequence_func(input_, process_group, dim, pad)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
if ctx.grad_scale == "up":
|
||||
grad_output = grad_output * dist.get_world_size(ctx.process_group)
|
||||
elif ctx.grad_scale == "down":
|
||||
grad_output = grad_output / dist.get_world_size(ctx.process_group)
|
||||
return _gather_sequence_func(grad_output, ctx.process_group, ctx.pad), None, None, None, None
|
||||
|
||||
|
||||
def split_sequence(input_, process_group, dim, grad_scale=1.0, pad=0):
|
||||
return _SplitForwardGatherBackward.apply(input_, process_group, dim, grad_scale, pad)
|
||||
|
||||
|
||||
def gather_sequence(input_, process_group, dim, grad_scale=1.0, pad=0):
|
||||
return _GatherForwardSplitBackward.apply(input_, process_group, dim, grad_scale, pad)
|
||||
|
||||
|
||||
# ==============================
|
||||
# Pad
|
||||
# ==============================
|
||||
|
||||
PAD_DICT = {}
|
||||
|
||||
|
||||
def set_pad(name: str, dim_size: int, parallel_group: dist.ProcessGroup):
|
||||
sp_size = dist.get_world_size(parallel_group)
|
||||
pad = (sp_size - (dim_size % sp_size)) % sp_size
|
||||
global PAD_DICT
|
||||
PAD_DICT[name] = pad
|
||||
|
||||
|
||||
def get_pad(name) -> int:
|
||||
return PAD_DICT[name]
|
||||
|
||||
|
||||
def all_to_all_with_pad(
|
||||
input_: torch.Tensor,
|
||||
process_group: dist.ProcessGroup,
|
||||
scatter_dim: int = 2,
|
||||
gather_dim: int = 1,
|
||||
scatter_pad: int = 0,
|
||||
gather_pad: int = 0,
|
||||
):
|
||||
if scatter_pad > 0:
|
||||
pad_shape = list(input_.shape)
|
||||
pad_shape[scatter_dim] = scatter_pad
|
||||
pad_tensor = torch.zeros(pad_shape, device=input_.device, dtype=input_.dtype)
|
||||
input_ = torch.cat([input_, pad_tensor], dim=scatter_dim)
|
||||
|
||||
assert (
|
||||
input_.shape[scatter_dim] % dist.get_world_size(process_group) == 0
|
||||
), f"Dimension to scatter ({input_.shape[scatter_dim]}) is not divisible by world size ({dist.get_world_size(process_group)})"
|
||||
input_ = _AllToAll.apply(input_, process_group, scatter_dim, gather_dim)
|
||||
|
||||
if gather_pad > 0:
|
||||
input_ = input_.narrow(gather_dim, 0, input_.size(gather_dim) - gather_pad)
|
||||
|
||||
return input_
|
||||
232
videosys/core/pab_mgr.py
Normal file
232
videosys/core/pab_mgr.py
Normal file
@ -0,0 +1,232 @@
|
||||
|
||||
PAB_MANAGER = None
|
||||
|
||||
|
||||
class PABConfig:
|
||||
def __init__(
|
||||
self,
|
||||
steps: int,
|
||||
cross_broadcast: bool = False,
|
||||
cross_threshold: list = None,
|
||||
cross_range: int = None,
|
||||
spatial_broadcast: bool = False,
|
||||
spatial_threshold: list = None,
|
||||
spatial_range: int = None,
|
||||
temporal_broadcast: bool = False,
|
||||
temporal_threshold: list = None,
|
||||
temporal_range: int = None,
|
||||
mlp_broadcast: bool = False,
|
||||
mlp_spatial_broadcast_config: dict = None,
|
||||
mlp_temporal_broadcast_config: dict = None,
|
||||
):
|
||||
self.steps = steps
|
||||
|
||||
self.cross_broadcast = cross_broadcast
|
||||
self.cross_threshold = cross_threshold
|
||||
self.cross_range = cross_range
|
||||
|
||||
self.spatial_broadcast = spatial_broadcast
|
||||
self.spatial_threshold = spatial_threshold
|
||||
self.spatial_range = spatial_range
|
||||
|
||||
self.temporal_broadcast = temporal_broadcast
|
||||
self.temporal_threshold = temporal_threshold
|
||||
self.temporal_range = temporal_range
|
||||
|
||||
self.mlp_broadcast = mlp_broadcast
|
||||
self.mlp_spatial_broadcast_config = mlp_spatial_broadcast_config
|
||||
self.mlp_temporal_broadcast_config = mlp_temporal_broadcast_config
|
||||
self.mlp_temporal_outputs = {}
|
||||
self.mlp_spatial_outputs = {}
|
||||
|
||||
|
||||
class PABManager:
|
||||
def __init__(self, config: PABConfig):
|
||||
self.config: PABConfig = config
|
||||
|
||||
init_prompt = f"Init Pyramid Attention Broadcast. steps: {config.steps}."
|
||||
init_prompt += f" spatial broadcast: {config.spatial_broadcast}, spatial range: {config.spatial_range}, spatial threshold: {config.spatial_threshold}."
|
||||
init_prompt += f" temporal broadcast: {config.temporal_broadcast}, temporal range: {config.temporal_range}, temporal_threshold: {config.temporal_threshold}."
|
||||
init_prompt += f" cross broadcast: {config.cross_broadcast}, cross range: {config.cross_range}, cross threshold: {config.cross_threshold}."
|
||||
init_prompt += f" mlp broadcast: {config.mlp_broadcast}."
|
||||
print(init_prompt)
|
||||
|
||||
def if_broadcast_cross(self, timestep: int, count: int):
|
||||
if (
|
||||
self.config.cross_broadcast
|
||||
and (timestep is not None)
|
||||
and (count % self.config.cross_range != 0)
|
||||
and (self.config.cross_threshold[0] < timestep < self.config.cross_threshold[1])
|
||||
):
|
||||
flag = True
|
||||
else:
|
||||
flag = False
|
||||
count = (count + 1) % self.config.steps
|
||||
return flag, count
|
||||
|
||||
def if_broadcast_temporal(self, timestep: int, count: int):
|
||||
if (
|
||||
self.config.temporal_broadcast
|
||||
and (timestep is not None)
|
||||
and (count % self.config.temporal_range != 0)
|
||||
and (self.config.temporal_threshold[0] < timestep < self.config.temporal_threshold[1])
|
||||
):
|
||||
flag = True
|
||||
else:
|
||||
flag = False
|
||||
count = (count + 1) % self.config.steps
|
||||
return flag, count
|
||||
|
||||
def if_broadcast_spatial(self, timestep: int, count: int, block_idx: int):
|
||||
if (
|
||||
self.config.spatial_broadcast
|
||||
and (timestep is not None)
|
||||
and (count % self.config.spatial_range != 0)
|
||||
and (self.config.spatial_threshold[0] < timestep < self.config.spatial_threshold[1])
|
||||
):
|
||||
flag = True
|
||||
else:
|
||||
flag = False
|
||||
count = (count + 1) % self.config.steps
|
||||
return flag, count
|
||||
|
||||
@staticmethod
|
||||
def _is_t_in_skip_config(all_timesteps, timestep, config):
|
||||
is_t_in_skip_config = False
|
||||
skip_range = None
|
||||
for key in config:
|
||||
if key not in all_timesteps:
|
||||
continue
|
||||
index = all_timesteps.index(key)
|
||||
skip_range = all_timesteps[index : index + 1 + int(config[key]["skip_count"])]
|
||||
if timestep in skip_range:
|
||||
is_t_in_skip_config = True
|
||||
skip_range = [all_timesteps[index], all_timesteps[index + int(config[key]["skip_count"])]]
|
||||
break
|
||||
return is_t_in_skip_config, skip_range
|
||||
|
||||
def if_skip_mlp(self, timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False):
|
||||
if not self.config.mlp_broadcast:
|
||||
return False, None, False, None
|
||||
|
||||
if is_temporal:
|
||||
cur_config = self.config.mlp_temporal_broadcast_config
|
||||
else:
|
||||
cur_config = self.config.mlp_spatial_broadcast_config
|
||||
|
||||
is_t_in_skip_config, skip_range = self._is_t_in_skip_config(all_timesteps, timestep, cur_config)
|
||||
next_flag = False
|
||||
if (
|
||||
self.config.mlp_broadcast
|
||||
and (timestep is not None)
|
||||
and (timestep in cur_config)
|
||||
and (block_idx in cur_config[timestep]["block"])
|
||||
):
|
||||
flag = False
|
||||
next_flag = True
|
||||
count = count + 1
|
||||
elif (
|
||||
self.config.mlp_broadcast
|
||||
and (timestep is not None)
|
||||
and (is_t_in_skip_config)
|
||||
and (block_idx in cur_config[skip_range[0]]["block"])
|
||||
):
|
||||
flag = True
|
||||
count = 0
|
||||
else:
|
||||
flag = False
|
||||
|
||||
return flag, count, next_flag, skip_range
|
||||
|
||||
def save_skip_output(self, timestep, block_idx, ff_output, is_temporal=False):
|
||||
if is_temporal:
|
||||
self.config.mlp_temporal_outputs[(timestep, block_idx)] = ff_output
|
||||
else:
|
||||
self.config.mlp_spatial_outputs[(timestep, block_idx)] = ff_output
|
||||
|
||||
def get_mlp_output(self, skip_range, timestep, block_idx, is_temporal=False):
|
||||
skip_start_t = skip_range[0]
|
||||
if is_temporal:
|
||||
skip_output = (
|
||||
self.config.mlp_temporal_outputs.get((skip_start_t, block_idx), None)
|
||||
if self.config.mlp_temporal_outputs is not None
|
||||
else None
|
||||
)
|
||||
else:
|
||||
skip_output = (
|
||||
self.config.mlp_spatial_outputs.get((skip_start_t, block_idx), None)
|
||||
if self.config.mlp_spatial_outputs is not None
|
||||
else None
|
||||
)
|
||||
|
||||
if skip_output is not None:
|
||||
if timestep == skip_range[-1]:
|
||||
# TODO: save memory
|
||||
if is_temporal:
|
||||
del self.config.mlp_temporal_outputs[(skip_start_t, block_idx)]
|
||||
else:
|
||||
del self.config.mlp_spatial_outputs[(skip_start_t, block_idx)]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"No stored MLP output found | t {timestep} |[{skip_range[0]}, {skip_range[-1]}] | block {block_idx}"
|
||||
)
|
||||
|
||||
return skip_output
|
||||
|
||||
def get_spatial_mlp_outputs(self):
|
||||
return self.config.mlp_spatial_outputs
|
||||
|
||||
def get_temporal_mlp_outputs(self):
|
||||
return self.config.mlp_temporal_outputs
|
||||
|
||||
|
||||
def set_pab_manager(config: PABConfig):
|
||||
global PAB_MANAGER
|
||||
PAB_MANAGER = PABManager(config)
|
||||
|
||||
|
||||
def enable_pab():
|
||||
if PAB_MANAGER is None:
|
||||
return False
|
||||
return (
|
||||
PAB_MANAGER.config.cross_broadcast
|
||||
or PAB_MANAGER.config.spatial_broadcast
|
||||
or PAB_MANAGER.config.temporal_broadcast
|
||||
)
|
||||
|
||||
|
||||
def update_steps(steps: int):
|
||||
if PAB_MANAGER is not None:
|
||||
PAB_MANAGER.config.steps = steps
|
||||
|
||||
|
||||
def if_broadcast_cross(timestep: int, count: int):
|
||||
if not enable_pab():
|
||||
return False, count
|
||||
return PAB_MANAGER.if_broadcast_cross(timestep, count)
|
||||
|
||||
|
||||
def if_broadcast_temporal(timestep: int, count: int):
|
||||
if not enable_pab():
|
||||
return False, count
|
||||
return PAB_MANAGER.if_broadcast_temporal(timestep, count)
|
||||
|
||||
|
||||
def if_broadcast_spatial(timestep: int, count: int, block_idx: int):
|
||||
if not enable_pab():
|
||||
return False, count
|
||||
return PAB_MANAGER.if_broadcast_spatial(timestep, count, block_idx)
|
||||
|
||||
|
||||
def if_broadcast_mlp(timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False):
|
||||
if not enable_pab():
|
||||
return False, count
|
||||
return PAB_MANAGER.if_skip_mlp(timestep, count, block_idx, all_timesteps, is_temporal)
|
||||
|
||||
|
||||
def save_mlp_output(timestep: int, block_idx: int, ff_output, is_temporal=False):
|
||||
return PAB_MANAGER.save_skip_output(timestep, block_idx, ff_output, is_temporal)
|
||||
|
||||
|
||||
def get_mlp_output(skip_range, timestep, block_idx: int, is_temporal=False):
|
||||
return PAB_MANAGER.get_mlp_output(skip_range, timestep, block_idx, is_temporal)
|
||||
52
videosys/core/pipeline.py
Normal file
52
videosys/core/pipeline.py
Normal file
@ -0,0 +1,52 @@
|
||||
import inspect
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.utils import BaseOutput
|
||||
|
||||
|
||||
class VideoSysPipeline(DiffusionPipeline):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@staticmethod
|
||||
def set_eval_and_device(device: torch.device, *modules):
|
||||
for module in modules:
|
||||
module.eval()
|
||||
module.to(device)
|
||||
|
||||
@abstractmethod
|
||||
def generate(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
"""
|
||||
In diffusers, it is a convention to call the pipeline object.
|
||||
But in VideoSys, we will use the generate method for better prompt.
|
||||
This is a wrapper for the generate method to support the diffusers usage.
|
||||
"""
|
||||
return self.generate(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def _get_signature_keys(cls, obj):
|
||||
parameters = inspect.signature(obj.__init__).parameters
|
||||
required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
|
||||
optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
|
||||
expected_modules = set(required_parameters.keys()) - {"self"}
|
||||
# modify: remove the config module from the expected modules
|
||||
expected_modules = expected_modules - {"config"}
|
||||
|
||||
optional_names = list(optional_parameters)
|
||||
for name in optional_names:
|
||||
if name in cls._optional_components:
|
||||
expected_modules.add(name)
|
||||
optional_parameters.remove(name)
|
||||
|
||||
return expected_modules, optional_parameters
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoSysPipelineOutput(BaseOutput):
|
||||
video: torch.Tensor
|
||||
0
videosys/modules/__init__.py
Normal file
0
videosys/modules/__init__.py
Normal file
3
videosys/modules/activations.py
Normal file
3
videosys/modules/activations.py
Normal file
@ -0,0 +1,3 @@
|
||||
import torch.nn as nn
|
||||
|
||||
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
||||
205
videosys/modules/attentions.py
Normal file
205
videosys/modules/attentions.py
Normal file
@ -0,0 +1,205 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable, List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
|
||||
from videosys.models.modules.normalization import LlamaRMSNorm
|
||||
|
||||
|
||||
class OpenSoraAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = False,
|
||||
qk_norm: bool = False,
|
||||
attn_drop: float = 0.0,
|
||||
proj_drop: float = 0.0,
|
||||
norm_layer: nn.Module = LlamaRMSNorm,
|
||||
enable_flash_attn: bool = False,
|
||||
rope=None,
|
||||
qk_norm_legacy: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0, "dim should be divisible by num_heads"
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.enable_flash_attn = enable_flash_attn
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||
self.qk_norm_legacy = qk_norm_legacy
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
self.rope = False
|
||||
if rope is not None:
|
||||
self.rope = True
|
||||
self.rotary_emb = rope
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
B, N, C = x.shape
|
||||
# flash attn is not memory efficient for small sequences, this is empirical
|
||||
enable_flash_attn = self.enable_flash_attn and (N > B)
|
||||
qkv = self.qkv(x)
|
||||
qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
|
||||
|
||||
qkv = qkv.view(qkv_shape).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv.unbind(0)
|
||||
if self.qk_norm_legacy:
|
||||
# WARNING: this may be a bug
|
||||
if self.rope:
|
||||
q = self.rotary_emb(q)
|
||||
k = self.rotary_emb(k)
|
||||
q, k = self.q_norm(q), self.k_norm(k)
|
||||
else:
|
||||
q, k = self.q_norm(q), self.k_norm(k)
|
||||
if self.rope:
|
||||
q = self.rotary_emb(q)
|
||||
k = self.rotary_emb(k)
|
||||
|
||||
if enable_flash_attn:
|
||||
from flash_attn import flash_attn_func
|
||||
|
||||
# (B, #heads, N, #dim) -> (B, N, #heads, #dim)
|
||||
q = q.permute(0, 2, 1, 3)
|
||||
k = k.permute(0, 2, 1, 3)
|
||||
v = v.permute(0, 2, 1, 3)
|
||||
x = flash_attn_func(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
dropout_p=self.attn_drop.p if self.training else 0.0,
|
||||
softmax_scale=self.scale,
|
||||
)
|
||||
else:
|
||||
x = F.scaled_dot_product_attention(q, k, v)
|
||||
|
||||
x_output_shape = (B, N, C)
|
||||
if not enable_flash_attn:
|
||||
x = x.transpose(1, 2)
|
||||
x = x.reshape(x_output_shape)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class OpenSoraMultiHeadCrossAttention(nn.Module):
|
||||
def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0, enable_flash_attn=False):
|
||||
super(OpenSoraMultiHeadCrossAttention, self).__init__()
|
||||
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
|
||||
|
||||
self.d_model = d_model
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = d_model // num_heads
|
||||
|
||||
self.q_linear = nn.Linear(d_model, d_model)
|
||||
self.kv_linear = nn.Linear(d_model, d_model * 2)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(d_model, d_model)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
self.enable_flash_attn = enable_flash_attn
|
||||
|
||||
def forward(self, x, cond, mask=None):
|
||||
# query/value: img tokens; key: condition; mask: if padding tokens
|
||||
B, N, C = x.shape
|
||||
|
||||
q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
|
||||
kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
|
||||
k, v = kv.unbind(2)
|
||||
|
||||
if self.enable_flash_attn:
|
||||
x = self.flash_attn_impl(q, k, v, mask, B, N, C)
|
||||
else:
|
||||
x = self.torch_impl(q, k, v, mask, B, N, C)
|
||||
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
def flash_attn_impl(self, q, k, v, mask, B, N, C):
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
|
||||
q_seqinfo = _SeqLenInfo.from_seqlens([N] * B)
|
||||
k_seqinfo = _SeqLenInfo.from_seqlens(mask)
|
||||
|
||||
x = flash_attn_varlen_func(
|
||||
q.view(-1, self.num_heads, self.head_dim),
|
||||
k.view(-1, self.num_heads, self.head_dim),
|
||||
v.view(-1, self.num_heads, self.head_dim),
|
||||
cu_seqlens_q=q_seqinfo.seqstart.cuda(),
|
||||
cu_seqlens_k=k_seqinfo.seqstart.cuda(),
|
||||
max_seqlen_q=q_seqinfo.max_seqlen,
|
||||
max_seqlen_k=k_seqinfo.max_seqlen,
|
||||
dropout_p=self.attn_drop.p if self.training else 0.0,
|
||||
)
|
||||
x = x.view(B, N, C)
|
||||
return x
|
||||
|
||||
def torch_impl(self, q, k, v, mask, B, N, C):
|
||||
q = q.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
k = k.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
v = v.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
attn_mask = torch.zeros(B, 1, N, k.shape[2], dtype=torch.bool, device=q.device)
|
||||
for i, m in enumerate(mask):
|
||||
attn_mask[i, :, :, :m] = -1e9
|
||||
|
||||
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||
x = out.transpose(1, 2).contiguous().view(B, N, C)
|
||||
return x
|
||||
|
||||
|
||||
@dataclass
|
||||
class _SeqLenInfo:
|
||||
"""
|
||||
from xformers
|
||||
|
||||
(Internal) Represents the division of a dimension into blocks.
|
||||
For example, to represents a dimension of length 7 divided into
|
||||
three blocks of lengths 2, 3 and 2, use `from_seqlength([2, 3, 2])`.
|
||||
The members will be:
|
||||
max_seqlen: 3
|
||||
min_seqlen: 2
|
||||
seqstart_py: [0, 2, 5, 7]
|
||||
seqstart: torch.IntTensor([0, 2, 5, 7])
|
||||
"""
|
||||
|
||||
seqstart: torch.Tensor
|
||||
max_seqlen: int
|
||||
min_seqlen: int
|
||||
seqstart_py: List[int]
|
||||
|
||||
def to(self, device: torch.device) -> None:
|
||||
self.seqstart = self.seqstart.to(device, non_blocking=True)
|
||||
|
||||
def intervals(self) -> Iterable[Tuple[int, int]]:
|
||||
yield from zip(self.seqstart_py, self.seqstart_py[1:])
|
||||
|
||||
@classmethod
|
||||
def from_seqlens(cls, seqlens: Iterable[int]) -> "_SeqLenInfo":
|
||||
"""
|
||||
Input tensors are assumed to be in shape [B, M, *]
|
||||
"""
|
||||
assert not isinstance(seqlens, torch.Tensor)
|
||||
seqstart_py = [0]
|
||||
max_seqlen = -1
|
||||
min_seqlen = -1
|
||||
for seqlen in seqlens:
|
||||
min_seqlen = min(min_seqlen, seqlen) if min_seqlen != -1 else seqlen
|
||||
max_seqlen = max(max_seqlen, seqlen)
|
||||
seqstart_py.append(seqstart_py[len(seqstart_py) - 1] + seqlen)
|
||||
seqstart = torch.tensor(seqstart_py, dtype=torch.int32)
|
||||
return cls(
|
||||
max_seqlen=max_seqlen,
|
||||
min_seqlen=min_seqlen,
|
||||
seqstart=seqstart,
|
||||
seqstart_py=seqstart_py,
|
||||
)
|
||||
71
videosys/modules/downsampling.py
Normal file
71
videosys/modules/downsampling.py
Normal file
@ -0,0 +1,71 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class CogVideoXDownsample3D(nn.Module):
|
||||
# Todo: Wait for paper relase.
|
||||
r"""
|
||||
A 3D Downsampling layer using in [CogVideoX]() by Tsinghua University & ZhipuAI
|
||||
|
||||
Args:
|
||||
in_channels (`int`):
|
||||
Number of channels in the input image.
|
||||
out_channels (`int`):
|
||||
Number of channels produced by the convolution.
|
||||
kernel_size (`int`, defaults to `3`):
|
||||
Size of the convolving kernel.
|
||||
stride (`int`, defaults to `2`):
|
||||
Stride of the convolution.
|
||||
padding (`int`, defaults to `0`):
|
||||
Padding added to all four sides of the input.
|
||||
compress_time (`bool`, defaults to `False`):
|
||||
Whether or not to compress the time dimension.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int = 3,
|
||||
stride: int = 2,
|
||||
padding: int = 0,
|
||||
compress_time: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
|
||||
self.compress_time = compress_time
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if self.compress_time:
|
||||
batch_size, channels, frames, height, width = x.shape
|
||||
|
||||
# (batch_size, channels, frames, height, width) -> (batch_size, height, width, channels, frames) -> (batch_size * height * width, channels, frames)
|
||||
x = x.permute(0, 3, 4, 1, 2).reshape(batch_size * height * width, channels, frames)
|
||||
|
||||
if x.shape[-1] % 2 == 1:
|
||||
x_first, x_rest = x[..., 0], x[..., 1:]
|
||||
if x_rest.shape[-1] > 0:
|
||||
# (batch_size * height * width, channels, frames - 1) -> (batch_size * height * width, channels, (frames - 1) // 2)
|
||||
x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2)
|
||||
|
||||
x = torch.cat([x_first[..., None], x_rest], dim=-1)
|
||||
# (batch_size * height * width, channels, (frames // 2) + 1) -> (batch_size, height, width, channels, (frames // 2) + 1) -> (batch_size, channels, (frames // 2) + 1, height, width)
|
||||
x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
|
||||
else:
|
||||
# (batch_size * height * width, channels, frames) -> (batch_size * height * width, channels, frames // 2)
|
||||
x = F.avg_pool1d(x, kernel_size=2, stride=2)
|
||||
# (batch_size * height * width, channels, frames // 2) -> (batch_size, height, width, channels, frames // 2) -> (batch_size, channels, frames // 2, height, width)
|
||||
x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
|
||||
|
||||
# Pad the tensor
|
||||
pad = (0, 1, 0, 1)
|
||||
x = F.pad(x, pad, mode="constant", value=0)
|
||||
batch_size, channels, frames, height, width = x.shape
|
||||
# (batch_size, channels, frames, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size * frames, channels, height, width)
|
||||
x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * frames, channels, height, width)
|
||||
x = self.conv(x)
|
||||
# (batch_size * frames, channels, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size, channels, frames, height, width)
|
||||
x = x.reshape(batch_size, frames, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4)
|
||||
return x
|
||||
412
videosys/modules/embeddings.py
Normal file
412
videosys/modules/embeddings.py
Normal file
@ -0,0 +1,412 @@
|
||||
import functools
|
||||
import math
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from einops import rearrange
|
||||
from timm.models.vision_transformer import Mlp
|
||||
|
||||
|
||||
class CogVideoXPatchEmbed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 2,
|
||||
in_channels: int = 16,
|
||||
embed_dim: int = 1920,
|
||||
text_embed_dim: int = 4096,
|
||||
bias: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
|
||||
self.proj = nn.Conv2d(
|
||||
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
|
||||
)
|
||||
self.text_proj = nn.Linear(text_embed_dim, embed_dim)
|
||||
|
||||
def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
|
||||
r"""
|
||||
Args:
|
||||
text_embeds (`torch.Tensor`):
|
||||
Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
|
||||
image_embeds (`torch.Tensor`):
|
||||
Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
|
||||
"""
|
||||
text_embeds = self.text_proj(text_embeds)
|
||||
|
||||
batch, num_frames, channels, height, width = image_embeds.shape
|
||||
image_embeds = image_embeds.reshape(-1, channels, height, width)
|
||||
image_embeds = self.proj(image_embeds)
|
||||
image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
|
||||
image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
|
||||
image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
|
||||
|
||||
embeds = torch.cat(
|
||||
[text_embeds, image_embeds], dim=1
|
||||
).contiguous() # [batch, seq_length + num_frames x height x width, channels]
|
||||
return embeds
|
||||
|
||||
|
||||
class OpenSoraPatchEmbed3D(nn.Module):
|
||||
"""Video to Patch Embedding.
|
||||
|
||||
Args:
|
||||
patch_size (int): Patch token size. Default: (2,4,4).
|
||||
in_chans (int): Number of input video channels. Default: 3.
|
||||
embed_dim (int): Number of linear projection output channels. Default: 96.
|
||||
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
patch_size=(2, 4, 4),
|
||||
in_chans=3,
|
||||
embed_dim=96,
|
||||
norm_layer=None,
|
||||
flatten=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.flatten = flatten
|
||||
|
||||
self.in_chans = in_chans
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
||||
if norm_layer is not None:
|
||||
self.norm = norm_layer(embed_dim)
|
||||
else:
|
||||
self.norm = None
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
# padding
|
||||
_, _, D, H, W = x.size()
|
||||
if W % self.patch_size[2] != 0:
|
||||
x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2]))
|
||||
if H % self.patch_size[1] != 0:
|
||||
x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1]))
|
||||
if D % self.patch_size[0] != 0:
|
||||
x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0]))
|
||||
|
||||
x = self.proj(x) # (B C T H W)
|
||||
if self.norm is not None:
|
||||
D, Wh, Ww = x.size(2), x.size(3), x.size(4)
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
x = self.norm(x)
|
||||
x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)
|
||||
if self.flatten:
|
||||
x = x.flatten(2).transpose(1, 2) # BCTHW -> BNC
|
||||
return x
|
||||
|
||||
|
||||
class TimestepEmbedder(nn.Module):
|
||||
"""
|
||||
Embeds scalar timesteps into vector representations.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256):
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
||||
nn.SiLU(),
|
||||
nn.Linear(hidden_size, hidden_size, bias=True),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
@staticmethod
|
||||
def timestep_embedding(t, dim, max_period=10000):
|
||||
"""
|
||||
Create sinusoidal timestep embeddings.
|
||||
:param t: a 1-D Tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
:param dim: the dimension of the output.
|
||||
:param max_period: controls the minimum frequency of the embeddings.
|
||||
:return: an (N, D) Tensor of positional embeddings.
|
||||
"""
|
||||
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
|
||||
half = dim // 2
|
||||
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half)
|
||||
freqs = freqs.to(device=t.device)
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding
|
||||
|
||||
def forward(self, t, dtype):
|
||||
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
||||
if t_freq.dtype != dtype:
|
||||
t_freq = t_freq.to(dtype)
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
|
||||
|
||||
class SizeEmbedder(TimestepEmbedder):
|
||||
"""
|
||||
Embeds scalar timesteps into vector representations.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256):
|
||||
super().__init__(hidden_size=hidden_size, frequency_embedding_size=frequency_embedding_size)
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
||||
nn.SiLU(),
|
||||
nn.Linear(hidden_size, hidden_size, bias=True),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
self.outdim = hidden_size
|
||||
|
||||
def forward(self, s, bs):
|
||||
if s.ndim == 1:
|
||||
s = s[:, None]
|
||||
assert s.ndim == 2
|
||||
if s.shape[0] != bs:
|
||||
s = s.repeat(bs // s.shape[0], 1)
|
||||
assert s.shape[0] == bs
|
||||
b, dims = s.shape[0], s.shape[1]
|
||||
s = rearrange(s, "b d -> (b d)")
|
||||
s_freq = self.timestep_embedding(s, self.frequency_embedding_size).to(self.dtype)
|
||||
s_emb = self.mlp(s_freq)
|
||||
s_emb = rearrange(s_emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim)
|
||||
return s_emb
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return next(self.parameters()).dtype
|
||||
|
||||
|
||||
class OpenSoraCaptionEmbedder(nn.Module):
|
||||
"""
|
||||
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
hidden_size,
|
||||
uncond_prob,
|
||||
act_layer=nn.GELU(approximate="tanh"),
|
||||
token_num=120,
|
||||
):
|
||||
super().__init__()
|
||||
self.y_proj = Mlp(
|
||||
in_features=in_channels,
|
||||
hidden_features=hidden_size,
|
||||
out_features=hidden_size,
|
||||
act_layer=act_layer,
|
||||
drop=0,
|
||||
)
|
||||
self.register_buffer(
|
||||
"y_embedding",
|
||||
torch.randn(token_num, in_channels) / in_channels**0.5,
|
||||
)
|
||||
self.uncond_prob = uncond_prob
|
||||
|
||||
def token_drop(self, caption, force_drop_ids=None):
|
||||
"""
|
||||
Drops labels to enable classifier-free guidance.
|
||||
"""
|
||||
if force_drop_ids is None:
|
||||
drop_ids = torch.rand(caption.shape[0]).cuda() < self.uncond_prob
|
||||
else:
|
||||
drop_ids = force_drop_ids == 1
|
||||
caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
|
||||
return caption
|
||||
|
||||
def forward(self, caption, train, force_drop_ids=None):
|
||||
if train:
|
||||
assert caption.shape[2:] == self.y_embedding.shape
|
||||
use_dropout = self.uncond_prob > 0
|
||||
if (train and use_dropout) or (force_drop_ids is not None):
|
||||
caption = self.token_drop(caption, force_drop_ids)
|
||||
caption = self.y_proj(caption)
|
||||
return caption
|
||||
|
||||
|
||||
class OpenSoraPositionEmbedding2D(nn.Module):
|
||||
def __init__(self, dim: int) -> None:
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
assert dim % 4 == 0, "dim must be divisible by 4"
|
||||
half_dim = dim // 2
|
||||
inv_freq = 1.0 / (10000 ** (torch.arange(0, half_dim, 2).float() / half_dim))
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
def _get_sin_cos_emb(self, t: torch.Tensor):
|
||||
out = torch.einsum("i,d->id", t, self.inv_freq)
|
||||
emb_cos = torch.cos(out)
|
||||
emb_sin = torch.sin(out)
|
||||
return torch.cat((emb_sin, emb_cos), dim=-1)
|
||||
|
||||
@functools.lru_cache(maxsize=512)
|
||||
def _get_cached_emb(
|
||||
self,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
h: int,
|
||||
w: int,
|
||||
scale: float = 1.0,
|
||||
base_size: Optional[int] = None,
|
||||
):
|
||||
grid_h = torch.arange(h, device=device) / scale
|
||||
grid_w = torch.arange(w, device=device) / scale
|
||||
if base_size is not None:
|
||||
grid_h *= base_size / h
|
||||
grid_w *= base_size / w
|
||||
grid_h, grid_w = torch.meshgrid(
|
||||
grid_w,
|
||||
grid_h,
|
||||
indexing="ij",
|
||||
) # here w goes first
|
||||
grid_h = grid_h.t().reshape(-1)
|
||||
grid_w = grid_w.t().reshape(-1)
|
||||
emb_h = self._get_sin_cos_emb(grid_h)
|
||||
emb_w = self._get_sin_cos_emb(grid_w)
|
||||
return torch.concat([emb_h, emb_w], dim=-1).unsqueeze(0).to(dtype)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
h: int,
|
||||
w: int,
|
||||
scale: Optional[float] = 1.0,
|
||||
base_size: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
return self._get_cached_emb(x.device, x.dtype, h, w, scale, base_size)
|
||||
|
||||
|
||||
def get_3d_rotary_pos_embed(
|
||||
embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""
|
||||
RoPE for video tokens with 3D structure.
|
||||
|
||||
Args:
|
||||
embed_dim: (`int`):
|
||||
The embedding dimension size, corresponding to hidden_size_head.
|
||||
crops_coords (`Tuple[int]`):
|
||||
The top-left and bottom-right coordinates of the crop.
|
||||
grid_size (`Tuple[int]`):
|
||||
The grid size of the spatial positional embedding (height, width).
|
||||
temporal_size (`int`):
|
||||
The size of the temporal dimension.
|
||||
theta (`float`):
|
||||
Scaling factor for frequency computation.
|
||||
use_real (`bool`):
|
||||
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
|
||||
"""
|
||||
start, stop = crops_coords
|
||||
grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
|
||||
grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
|
||||
grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
|
||||
|
||||
# Compute dimensions for each axis
|
||||
dim_t = embed_dim // 4
|
||||
dim_h = embed_dim // 8 * 3
|
||||
dim_w = embed_dim // 8 * 3
|
||||
|
||||
# Temporal frequencies
|
||||
freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t))
|
||||
grid_t = torch.from_numpy(grid_t).float()
|
||||
freqs_t = torch.einsum("n , f -> n f", grid_t, freqs_t)
|
||||
freqs_t = freqs_t.repeat_interleave(2, dim=-1)
|
||||
|
||||
# Spatial frequencies for height and width
|
||||
freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h))
|
||||
freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w))
|
||||
grid_h = torch.from_numpy(grid_h).float()
|
||||
grid_w = torch.from_numpy(grid_w).float()
|
||||
freqs_h = torch.einsum("n , f -> n f", grid_h, freqs_h)
|
||||
freqs_w = torch.einsum("n , f -> n f", grid_w, freqs_w)
|
||||
freqs_h = freqs_h.repeat_interleave(2, dim=-1)
|
||||
freqs_w = freqs_w.repeat_interleave(2, dim=-1)
|
||||
|
||||
# Broadcast and concatenate tensors along specified dimension
|
||||
def broadcast(tensors, dim=-1):
|
||||
num_tensors = len(tensors)
|
||||
shape_lens = {len(t.shape) for t in tensors}
|
||||
assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
|
||||
shape_len = list(shape_lens)[0]
|
||||
dim = (dim + shape_len) if dim < 0 else dim
|
||||
dims = list(zip(*(list(t.shape) for t in tensors)))
|
||||
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
|
||||
assert all(
|
||||
[*(len(set(t[1])) <= 2 for t in expandable_dims)]
|
||||
), "invalid dimensions for broadcastable concatenation"
|
||||
max_dims = [(t[0], max(t[1])) for t in expandable_dims]
|
||||
expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims]
|
||||
expanded_dims.insert(dim, (dim, dims[dim]))
|
||||
expandable_shapes = list(zip(*(t[1] for t in expanded_dims)))
|
||||
tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)]
|
||||
return torch.cat(tensors, dim=dim)
|
||||
|
||||
freqs = broadcast((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
|
||||
|
||||
t, h, w, d = freqs.shape
|
||||
freqs = freqs.view(t * h * w, d)
|
||||
|
||||
# Generate sine and cosine components
|
||||
sin = freqs.sin()
|
||||
cos = freqs.cos()
|
||||
|
||||
if use_real:
|
||||
return cos, sin
|
||||
else:
|
||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
|
||||
return freqs_cis
|
||||
|
||||
|
||||
def apply_rotary_emb(
|
||||
x: torch.Tensor,
|
||||
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
|
||||
use_real: bool = True,
|
||||
use_real_unbind_dim: int = -1,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
|
||||
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
|
||||
reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
|
||||
tensors contain rotary embeddings and are returned as real tensors.
|
||||
|
||||
Args:
|
||||
x (`torch.Tensor`):
|
||||
Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
|
||||
freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
|
||||
"""
|
||||
if use_real:
|
||||
cos, sin = freqs_cis # [S, D]
|
||||
cos = cos[None, None]
|
||||
sin = sin[None, None]
|
||||
cos, sin = cos.to(x.device), sin.to(x.device)
|
||||
|
||||
if use_real_unbind_dim == -1:
|
||||
# Use for example in Lumina
|
||||
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
|
||||
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
||||
elif use_real_unbind_dim == -2:
|
||||
# Use for example in Stable Audio
|
||||
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
|
||||
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
|
||||
else:
|
||||
raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
|
||||
|
||||
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
|
||||
|
||||
return out
|
||||
else:
|
||||
x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
|
||||
freqs_cis = freqs_cis.unsqueeze(2)
|
||||
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
|
||||
|
||||
return x_out.type_as(x)
|
||||
102
videosys/modules/normalization.py
Normal file
102
videosys/modules/normalization.py
Normal file
@ -0,0 +1,102 @@
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class LlamaRMSNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
"""
|
||||
LlamaRMSNorm is equivalent to T5LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
|
||||
class CogVideoXLayerNormZero(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
conditioning_dim: int,
|
||||
embedding_dim: int,
|
||||
elementwise_affine: bool = True,
|
||||
eps: float = 1e-5,
|
||||
bias: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(conditioning_dim, 6 * embedding_dim, bias=bias)
|
||||
self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1)
|
||||
hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
|
||||
encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :]
|
||||
return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :]
|
||||
|
||||
|
||||
class AdaLayerNorm(nn.Module):
|
||||
r"""
|
||||
Norm layer modified to incorporate timestep embeddings.
|
||||
|
||||
Parameters:
|
||||
embedding_dim (`int`): The size of each embedding vector.
|
||||
num_embeddings (`int`, *optional*): The size of the embeddings dictionary.
|
||||
output_dim (`int`, *optional*):
|
||||
norm_elementwise_affine (`bool`, defaults to `False):
|
||||
norm_eps (`bool`, defaults to `False`):
|
||||
chunk_dim (`int`, defaults to `0`):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
num_embeddings: Optional[int] = None,
|
||||
output_dim: Optional[int] = None,
|
||||
norm_elementwise_affine: bool = False,
|
||||
norm_eps: float = 1e-5,
|
||||
chunk_dim: int = 0,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.chunk_dim = chunk_dim
|
||||
output_dim = output_dim or embedding_dim * 2
|
||||
|
||||
if num_embeddings is not None:
|
||||
self.emb = nn.Embedding(num_embeddings, embedding_dim)
|
||||
else:
|
||||
self.emb = None
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(embedding_dim, output_dim)
|
||||
self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
if self.emb is not None:
|
||||
temb = self.emb(timestep)
|
||||
|
||||
temb = self.linear(self.silu(temb))
|
||||
|
||||
if self.chunk_dim == 1:
|
||||
# This is a bit weird why we have the order of "shift, scale" here and "scale, shift" in the
|
||||
# other if-branch. This branch is specific to CogVideoX for now.
|
||||
shift, scale = temb.chunk(2, dim=1)
|
||||
shift = shift[:, None, :]
|
||||
scale = scale[:, None, :]
|
||||
else:
|
||||
scale, shift = temb.chunk(2, dim=0)
|
||||
|
||||
x = self.norm(x) * (1 + scale) + shift
|
||||
return x
|
||||
67
videosys/modules/upsampling.py
Normal file
67
videosys/modules/upsampling.py
Normal file
@ -0,0 +1,67 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class CogVideoXUpsample3D(nn.Module):
|
||||
r"""
|
||||
A 3D Upsample layer using in CogVideoX by Tsinghua University & ZhipuAI # Todo: Wait for paper relase.
|
||||
|
||||
Args:
|
||||
in_channels (`int`):
|
||||
Number of channels in the input image.
|
||||
out_channels (`int`):
|
||||
Number of channels produced by the convolution.
|
||||
kernel_size (`int`, defaults to `3`):
|
||||
Size of the convolving kernel.
|
||||
stride (`int`, defaults to `1`):
|
||||
Stride of the convolution.
|
||||
padding (`int`, defaults to `1`):
|
||||
Padding added to all four sides of the input.
|
||||
compress_time (`bool`, defaults to `False`):
|
||||
Whether or not to compress the time dimension.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int = 3,
|
||||
stride: int = 1,
|
||||
padding: int = 1,
|
||||
compress_time: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
|
||||
self.compress_time = compress_time
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
if self.compress_time:
|
||||
if inputs.shape[2] > 1 and inputs.shape[2] % 2 == 1:
|
||||
# split first frame
|
||||
x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:]
|
||||
|
||||
x_first = F.interpolate(x_first, scale_factor=2.0)
|
||||
x_rest = F.interpolate(x_rest, scale_factor=2.0)
|
||||
x_first = x_first[:, :, None, :, :]
|
||||
inputs = torch.cat([x_first, x_rest], dim=2)
|
||||
elif inputs.shape[2] > 1:
|
||||
inputs = F.interpolate(inputs, scale_factor=2.0)
|
||||
else:
|
||||
inputs = inputs.squeeze(2)
|
||||
inputs = F.interpolate(inputs, scale_factor=2.0)
|
||||
inputs = inputs[:, :, None, :, :]
|
||||
else:
|
||||
# only interpolate 2D
|
||||
b, c, t, h, w = inputs.shape
|
||||
inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
|
||||
inputs = F.interpolate(inputs, scale_factor=2.0)
|
||||
inputs = inputs.reshape(b, t, c, *inputs.shape[2:]).permute(0, 2, 1, 3, 4)
|
||||
|
||||
b, c, t, h, w = inputs.shape
|
||||
inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
|
||||
inputs = self.conv(inputs)
|
||||
inputs = inputs.reshape(b, t, *inputs.shape[1:]).permute(0, 2, 1, 3, 4)
|
||||
|
||||
return inputs
|
||||
32
videosys/utils/logging.py
Normal file
32
videosys/utils/logging.py
Normal file
@ -0,0 +1,32 @@
|
||||
import logging
|
||||
|
||||
import torch.distributed as dist
|
||||
from rich.logging import RichHandler
|
||||
|
||||
|
||||
def create_logger():
|
||||
"""
|
||||
Create a logger that writes to a log file and stdout.
|
||||
"""
|
||||
logger = logging.getLogger(__name__)
|
||||
return logger
|
||||
|
||||
|
||||
def init_dist_logger():
|
||||
"""
|
||||
Update the logger to write to a log file.
|
||||
"""
|
||||
global logger
|
||||
if dist.get_rank() == 0:
|
||||
logger = logging.getLogger(__name__)
|
||||
handler = RichHandler(show_path=False, markup=True, rich_tracebacks=True)
|
||||
formatter = logging.Formatter("VideoSys - %(levelname)s: %(message)s")
|
||||
handler.setFormatter(formatter)
|
||||
logger.addHandler(handler)
|
||||
logger.setLevel(logging.INFO)
|
||||
else: # dummy logger (does nothing)
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.addHandler(logging.NullHandler())
|
||||
|
||||
|
||||
logger = create_logger()
|
||||
92
videosys/utils/utils.py
Normal file
92
videosys/utils/utils.py
Normal file
@ -0,0 +1,92 @@
|
||||
import os
|
||||
import random
|
||||
|
||||
import imageio
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from omegaconf import DictConfig, ListConfig, OmegaConf
|
||||
|
||||
|
||||
def requires_grad(model: torch.nn.Module, flag: bool = True) -> None:
|
||||
"""
|
||||
Set requires_grad flag for all parameters in a model.
|
||||
"""
|
||||
for p in model.parameters():
|
||||
p.requires_grad = flag
|
||||
|
||||
|
||||
def set_seed(seed, dp_rank=None):
|
||||
if seed == -1:
|
||||
seed = random.randint(0, 1000000)
|
||||
|
||||
if dp_rank is not None:
|
||||
seed = torch.tensor(seed, dtype=torch.int64).cuda()
|
||||
if dist.get_world_size() > 1:
|
||||
dist.broadcast(seed, 0)
|
||||
seed = seed + dp_rank
|
||||
|
||||
seed = int(seed)
|
||||
random.seed(seed)
|
||||
os.environ["PYTHONHASHSEED"] = str(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
|
||||
|
||||
def str_to_dtype(x: str):
|
||||
if x == "fp32":
|
||||
return torch.float32
|
||||
elif x == "fp16":
|
||||
return torch.float16
|
||||
elif x == "bf16":
|
||||
return torch.bfloat16
|
||||
else:
|
||||
raise RuntimeError(f"Only fp32, fp16 and bf16 are supported, but got {x}")
|
||||
|
||||
|
||||
def batch_func(func, *args):
|
||||
"""
|
||||
Apply a function to each element of a batch.
|
||||
"""
|
||||
batch = []
|
||||
for arg in args:
|
||||
if isinstance(arg, torch.Tensor) and arg.shape[0] == 2:
|
||||
batch.append(func(arg))
|
||||
else:
|
||||
batch.append(arg)
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
def merge_args(args1, args2):
|
||||
"""
|
||||
Merge two argparse Namespace objects.
|
||||
"""
|
||||
if args2 is None:
|
||||
return args1
|
||||
|
||||
for k in args2._content.keys():
|
||||
if k in args1.__dict__:
|
||||
v = getattr(args2, k)
|
||||
if isinstance(v, ListConfig) or isinstance(v, DictConfig):
|
||||
v = OmegaConf.to_object(v)
|
||||
setattr(args1, k, v)
|
||||
else:
|
||||
raise RuntimeError(f"Unknown argument {k}")
|
||||
|
||||
return args1
|
||||
|
||||
|
||||
def all_exists(paths):
|
||||
return all(os.path.exists(path) for path in paths)
|
||||
|
||||
|
||||
def save_video(video, output_path, fps):
|
||||
"""
|
||||
Save a video to disk.
|
||||
"""
|
||||
if dist.is_initialized() and dist.get_rank() != 0:
|
||||
return
|
||||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||
imageio.mimwrite(output_path, video, fps=fps)
|
||||
Loading…
x
Reference in New Issue
Block a user