from utils import generate_func, read_prompt_list
from videosys import OpenSoraPlanConfig, VideoSysEngine

import torch
import torch.nn.functional as F
from einops import rearrange, repeat
import numpy as np
from typing import Any, Dict, Optional, Tuple

from videosys.core.comm import all_to_all_with_pad, gather_sequence, get_pad, set_pad, split_sequence
from videosys.models.transformers.open_sora_plan_v110_transformer_3d import Transformer3DModelOutput
from videosys.utils.utils import batch_func
from functools import partial


def teacache_forward(
    self,
    hidden_states: torch.Tensor,
    timestep: Optional[torch.LongTensor] = None,
    all_timesteps=None,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    added_cond_kwargs: Dict[str, torch.Tensor] = None,
    class_labels: Optional[torch.LongTensor] = None,
    cross_attention_kwargs: Dict[str, Any] = None,
    attention_mask: Optional[torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
    use_image_num: int = 0,
    enable_temporal_attentions: bool = True,
    return_dict: bool = True,
):
    """
    The [`Transformer2DModel`] forward method.

    Args:
        hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, frame, channel, height, width)` if continuous):
            Input `hidden_states`.
        encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
            Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
            self-attention.
        timestep (`torch.LongTensor`, *optional*):
            Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
        class_labels (`torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
            Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
            `AdaLayerZeroNorm`.
        cross_attention_kwargs (`Dict[str, Any]`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        attention_mask (`torch.Tensor`, *optional*):
            An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
            is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
            negative values to the attention scores corresponding to "discard" tokens.
        encoder_attention_mask (`torch.Tensor`, *optional*):
            Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:

                * Mask `(batch, sequence_length)` True = keep, False = discard.
                * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.

            If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
            above. This bias will be added to the cross-attention scores.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
            tuple.

    Returns:
        If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
        `tuple` where the first element is the sample tensor.
    """
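    # TeaCache note: this patched forward relies on class-level state assigned in
    # eval_teacache_slow / eval_teacache_fast below (`enable_teacache`, `rel_l1_thresh`,
    # `accumulated_rel_l1_distance`, `previous_modulated_input`, `previous_residual`).
    # At each denoising step, the modulated input of the first transformer block is
    # compared with the previous step; while the accumulated (polynomial-rescaled)
    # relative L1 change stays below `rel_l1_thresh`, the cached residual is reused
    # instead of running the full block stack.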
    # 0. Split batch
    if self.parallel_manager.cp_size > 1:
        (
            hidden_states,
            timestep,
            encoder_hidden_states,
            class_labels,
            attention_mask,
            encoder_attention_mask,
        ) = batch_func(
            partial(split_sequence, process_group=self.parallel_manager.cp_group, dim=0),
            hidden_states,
            timestep,
            encoder_hidden_states,
            class_labels,
            attention_mask,
            encoder_attention_mask,
        )

    input_batch_size, c, frame, h, w = hidden_states.shape
    frame = frame - use_image_num  # 20-4=16
    hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w").contiguous()
    org_timestep = timestep

    # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
    #   we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
    #   we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
    # expects mask of shape:
    #   [batch, key_tokens]
    # adds singleton query_tokens dimension:
    #   [batch, 1, key_tokens]
    # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
    #   [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
    #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
    if attention_mask is None:
        attention_mask = torch.ones(
            (input_batch_size, frame + use_image_num, h, w), device=hidden_states.device, dtype=hidden_states.dtype
        )
    attention_mask = self.vae_to_diff_mask(attention_mask, use_image_num)
    dtype = attention_mask.dtype
    attention_mask_compress = F.max_pool2d(
        attention_mask.float(), kernel_size=self.compress_kv_factor, stride=self.compress_kv_factor
    )
    attention_mask_compress = attention_mask_compress.to(dtype)

    attention_mask = self.make_attn_mask(attention_mask, frame, hidden_states.dtype)
    attention_mask_compress = self.make_attn_mask(attention_mask_compress, frame, hidden_states.dtype)

    # 1 + 4, 1 -> video condition, 4 -> image condition
    # convert encoder_attention_mask to a bias the same way we do for attention_mask
    if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:  # ndim == 2 means no image joint
        encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
        encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
        encoder_attention_mask = repeat(encoder_attention_mask, "b 1 l -> (b f) 1 l", f=frame).contiguous()
        encoder_attention_mask = encoder_attention_mask.to(self.dtype)
    elif encoder_attention_mask is not None and encoder_attention_mask.ndim == 3:  # ndim == 3 means image joint
        encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
        encoder_attention_mask_video = encoder_attention_mask[:, :1, ...]
        encoder_attention_mask_video = repeat(
            encoder_attention_mask_video, "b 1 l -> b (1 f) l", f=frame
        ).contiguous()
        encoder_attention_mask_image = encoder_attention_mask[:, 1:, ...]
        encoder_attention_mask = torch.cat([encoder_attention_mask_video, encoder_attention_mask_image], dim=1)
        encoder_attention_mask = rearrange(encoder_attention_mask, "b n l -> (b n) l").contiguous().unsqueeze(1)
        encoder_attention_mask = encoder_attention_mask.to(self.dtype)

    # Retrieve lora scale.
    cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0

    # 1. Input
    if self.is_input_patches:  # here
        height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
        hw = (height, width)
        num_patches = height * width
        hidden_states = self.pos_embed(hidden_states.to(self.dtype))  # already adds positional embeddings

        if self.adaln_single is not None:
            if self.use_additional_conditions and added_cond_kwargs is None:
                raise ValueError(
                    "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
                )
            # batch_size = hidden_states.shape[0]
            batch_size = input_batch_size
            timestep, embedded_timestep = self.adaln_single(
                timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
            )

    # 2. Blocks
    if self.caption_projection is not None:
        batch_size = hidden_states.shape[0]
        encoder_hidden_states = self.caption_projection(encoder_hidden_states.to(self.dtype))  # 3 120 1152

        if use_image_num != 0 and self.training:
            encoder_hidden_states_video = encoder_hidden_states[:, :1, ...]
            encoder_hidden_states_video = repeat(
                encoder_hidden_states_video, "b 1 t d -> b (1 f) t d", f=frame
            ).contiguous()
            encoder_hidden_states_image = encoder_hidden_states[:, 1:, ...]
            encoder_hidden_states = torch.cat([encoder_hidden_states_video, encoder_hidden_states_image], dim=1)
            encoder_hidden_states_spatial = rearrange(encoder_hidden_states, "b f t d -> (b f) t d").contiguous()
        else:
            encoder_hidden_states_spatial = repeat(
                encoder_hidden_states, "b 1 t d -> (b f) t d", f=frame
            ).contiguous()

    # prepare timesteps for spatial and temporal block
    timestep_spatial = repeat(timestep, "b d -> (b f) d", f=frame + use_image_num).contiguous()
    timestep_temp = repeat(timestep, "b d -> (b p) d", p=num_patches).contiguous()

    pos_hw, pos_t = None, None
    if self.use_rope:
        pos_hw, pos_t = self.make_position(
            input_batch_size, frame, use_image_num, height, width, hidden_states.device
        )

    if self.enable_teacache:
        inp = hidden_states.clone()
        batch_size = hidden_states.shape[0]
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            self.transformer_blocks[0].scale_shift_table[None] + timestep_spatial.reshape(batch_size, 6, -1)
        ).chunk(6, dim=1)
        modulated_inp = self.transformer_blocks[0].norm1(inp) * (1 + scale_msa) + shift_msa
        if org_timestep[0] == all_timesteps[0] or org_timestep[0] == all_timesteps[-1]:
            # always recompute at the first and last denoising steps
            should_calc = True
            self.accumulated_rel_l1_distance = 0
        else:
            coefficients = [2.05943668e+05, -1.48759286e+04, 3.06085986e+02, 1.31418080e+00, 2.39658469e-03]
            rescale_func = np.poly1d(coefficients)
            self.accumulated_rel_l1_distance += rescale_func(
                ((modulated_inp - self.previous_modulated_input).abs().mean()
                 / self.previous_modulated_input.abs().mean()).cpu().item()
            )
            if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
                should_calc = False
            else:
                should_calc = True
                self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = modulated_inp
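    # Either reuse the cached residual from the previous step (skip path) or run the
    # full spatial/temporal transformer stack and refresh `previous_residual`.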
    if self.enable_teacache:
        if not should_calc:
            hidden_states += self.previous_residual
        else:
            if self.parallel_manager.sp_size > 1:
                set_pad("temporal", frame + use_image_num, self.parallel_manager.sp_group)
                set_pad("spatial", num_patches, self.parallel_manager.sp_group)
                hidden_states = self.split_from_second_dim(hidden_states, input_batch_size)
                encoder_hidden_states_spatial = self.split_from_second_dim(encoder_hidden_states_spatial, input_batch_size)
                timestep_spatial = self.split_from_second_dim(timestep_spatial, input_batch_size)
                attention_mask = self.split_from_second_dim(attention_mask, input_batch_size)
                attention_mask_compress = self.split_from_second_dim(attention_mask_compress, input_batch_size)
                temp_pos_embed = split_sequence(
                    self.temp_pos_embed, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal")
                )
            else:
                temp_pos_embed = self.temp_pos_embed

            ori_hidden_states = hidden_states.clone().detach()
            for i, (spatial_block, temp_block) in enumerate(
                zip(self.transformer_blocks, self.temporal_transformer_blocks)
            ):
                if self.training and self.gradient_checkpointing:
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        spatial_block,
                        hidden_states,
                        attention_mask_compress if i >= self.num_layers // 2 else attention_mask,
                        encoder_hidden_states_spatial,
                        encoder_attention_mask,
                        timestep_spatial,
                        cross_attention_kwargs,
                        class_labels,
                        pos_hw,
                        pos_hw,
                        hw,
                        use_reentrant=False,
                    )

                    if enable_temporal_attentions:
                        hidden_states = rearrange(hidden_states, "(b f) t d -> (b t) f d", b=input_batch_size).contiguous()

                        if use_image_num != 0:  # image-video joint training
                            hidden_states_video = hidden_states[:, :frame, ...]
                            hidden_states_image = hidden_states[:, frame:, ...]

                            # if i == 0 and not self.use_rope:
                            if i == 0:
                                hidden_states_video = hidden_states_video + temp_pos_embed

                            hidden_states_video = torch.utils.checkpoint.checkpoint(
                                temp_block,
                                hidden_states_video,
                                None,  # attention_mask
                                None,  # encoder_hidden_states
                                None,  # encoder_attention_mask
                                timestep_temp,
                                cross_attention_kwargs,
                                class_labels,
                                pos_t,
                                pos_t,
                                (frame,),
                                use_reentrant=False,
                            )

                            hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1)
                            hidden_states = rearrange(
                                hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
                            ).contiguous()
                        else:
                            # if i == 0 and not self.use_rope:
                            if i == 0:
                                hidden_states = hidden_states + temp_pos_embed

                            hidden_states = torch.utils.checkpoint.checkpoint(
                                temp_block,
                                hidden_states,
                                None,  # attention_mask
                                None,  # encoder_hidden_states
                                None,  # encoder_attention_mask
                                timestep_temp,
                                cross_attention_kwargs,
                                class_labels,
                                pos_t,
                                pos_t,
                                (frame,),
                                use_reentrant=False,
                            )

                            hidden_states = rearrange(
                                hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
                            ).contiguous()
                else:
                    hidden_states = spatial_block(
                        hidden_states,
                        attention_mask_compress if i >= self.num_layers // 2 else attention_mask,
                        encoder_hidden_states_spatial,
                        encoder_attention_mask,
                        timestep_spatial,
                        cross_attention_kwargs,
                        class_labels,
                        pos_hw,
                        pos_hw,
                        hw,
                        org_timestep,
                        all_timesteps=all_timesteps,
                    )

                    if enable_temporal_attentions:
                        # b c f h w, f = 16 + 4
                        hidden_states = rearrange(hidden_states, "(b f) t d -> (b t) f d", b=input_batch_size).contiguous()

                        if use_image_num != 0 and self.training:
                            hidden_states_video = hidden_states[:, :frame, ...]
                            hidden_states_image = hidden_states[:, frame:, ...]
                            # if i == 0 and not self.use_rope:
                            #     hidden_states_video = hidden_states_video + temp_pos_embed

                            hidden_states_video = temp_block(
                                hidden_states_video,
                                None,  # attention_mask
                                None,  # encoder_hidden_states
                                None,  # encoder_attention_mask
                                timestep_temp,
                                cross_attention_kwargs,
                                class_labels,
                                pos_t,
                                pos_t,
                                (frame,),
                                org_timestep,
                            )

                            hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1)
                            hidden_states = rearrange(
                                hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
                            ).contiguous()
                        else:
                            # if i == 0 and not self.use_rope:
                            if i == 0:
                                hidden_states = hidden_states + temp_pos_embed

                            hidden_states = temp_block(
                                hidden_states,
                                None,  # attention_mask
                                None,  # encoder_hidden_states
                                None,  # encoder_attention_mask
                                timestep_temp,
                                cross_attention_kwargs,
                                class_labels,
                                pos_t,
                                pos_t,
                                (frame,),
                                org_timestep,
                                all_timesteps=all_timesteps,
                            )

                            hidden_states = rearrange(
                                hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
                            ).contiguous()

            # cache the residual contributed by the full block stack for reuse at later steps
            self.previous_residual = hidden_states - ori_hidden_states
    else:
        if self.parallel_manager.sp_size > 1:
            set_pad("temporal", frame + use_image_num, self.parallel_manager.sp_group)
            set_pad("spatial", num_patches, self.parallel_manager.sp_group)
            hidden_states = self.split_from_second_dim(hidden_states, input_batch_size)
            self.previous_residual = (
                self.split_from_second_dim(self.previous_residual, input_batch_size)
                if self.previous_residual is not None
                else None
            )
            encoder_hidden_states_spatial = self.split_from_second_dim(encoder_hidden_states_spatial, input_batch_size)
            timestep_spatial = self.split_from_second_dim(timestep_spatial, input_batch_size)
            attention_mask = self.split_from_second_dim(attention_mask, input_batch_size)
            attention_mask_compress = self.split_from_second_dim(attention_mask_compress, input_batch_size)
            temp_pos_embed = split_sequence(
                self.temp_pos_embed, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal")
            )
        else:
            temp_pos_embed = self.temp_pos_embed

        for i, (spatial_block, temp_block) in enumerate(
            zip(self.transformer_blocks, self.temporal_transformer_blocks)
        ):
            if self.training and self.gradient_checkpointing:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    spatial_block,
                    hidden_states,
                    attention_mask_compress if i >= self.num_layers // 2 else attention_mask,
                    encoder_hidden_states_spatial,
                    encoder_attention_mask,
                    timestep_spatial,
                    cross_attention_kwargs,
                    class_labels,
                    pos_hw,
                    pos_hw,
                    hw,
                    use_reentrant=False,
                )

                if enable_temporal_attentions:
                    hidden_states = rearrange(hidden_states, "(b f) t d -> (b t) f d", b=input_batch_size).contiguous()

                    if use_image_num != 0:  # image-video joint training
                        hidden_states_video = hidden_states[:, :frame, ...]
                        hidden_states_image = hidden_states[:, frame:, ...]
                        # if i == 0 and not self.use_rope:
                        if i == 0:
                            hidden_states_video = hidden_states_video + temp_pos_embed

                        hidden_states_video = torch.utils.checkpoint.checkpoint(
                            temp_block,
                            hidden_states_video,
                            None,  # attention_mask
                            None,  # encoder_hidden_states
                            None,  # encoder_attention_mask
                            timestep_temp,
                            cross_attention_kwargs,
                            class_labels,
                            pos_t,
                            pos_t,
                            (frame,),
                            use_reentrant=False,
                        )

                        hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1)
                        hidden_states = rearrange(
                            hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
                        ).contiguous()
                    else:
                        # if i == 0 and not self.use_rope:
                        if i == 0:
                            hidden_states = hidden_states + temp_pos_embed

                        hidden_states = torch.utils.checkpoint.checkpoint(
                            temp_block,
                            hidden_states,
                            None,  # attention_mask
                            None,  # encoder_hidden_states
                            None,  # encoder_attention_mask
                            timestep_temp,
                            cross_attention_kwargs,
                            class_labels,
                            pos_t,
                            pos_t,
                            (frame,),
                            use_reentrant=False,
                        )

                        hidden_states = rearrange(
                            hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
                        ).contiguous()
            else:
                hidden_states = spatial_block(
                    hidden_states,
                    attention_mask_compress if i >= self.num_layers // 2 else attention_mask,
                    encoder_hidden_states_spatial,
                    encoder_attention_mask,
                    timestep_spatial,
                    cross_attention_kwargs,
                    class_labels,
                    pos_hw,
                    pos_hw,
                    hw,
                    org_timestep,
                    all_timesteps=all_timesteps,
                )

                if enable_temporal_attentions:
                    # b c f h w, f = 16 + 4
                    hidden_states = rearrange(hidden_states, "(b f) t d -> (b t) f d", b=input_batch_size).contiguous()

                    if use_image_num != 0 and self.training:
                        hidden_states_video = hidden_states[:, :frame, ...]
                        hidden_states_image = hidden_states[:, frame:, ...]

                        # if i == 0 and not self.use_rope:
                        #     hidden_states_video = hidden_states_video + temp_pos_embed

                        hidden_states_video = temp_block(
                            hidden_states_video,
                            None,  # attention_mask
                            None,  # encoder_hidden_states
                            None,  # encoder_attention_mask
                            timestep_temp,
                            cross_attention_kwargs,
                            class_labels,
                            pos_t,
                            pos_t,
                            (frame,),
                            org_timestep,
                        )

                        hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1)
                        hidden_states = rearrange(
                            hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
                        ).contiguous()
                    else:
                        # if i == 0 and not self.use_rope:
                        if i == 0:
                            hidden_states = hidden_states + temp_pos_embed

                        hidden_states = temp_block(
                            hidden_states,
                            None,  # attention_mask
                            None,  # encoder_hidden_states
                            None,  # encoder_attention_mask
                            timestep_temp,
                            cross_attention_kwargs,
                            class_labels,
                            pos_t,
                            pos_t,
                            (frame,),
                            org_timestep,
                            all_timesteps=all_timesteps,
                        )

                        hidden_states = rearrange(
                            hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
                        ).contiguous()

    if self.parallel_manager.sp_size > 1:
        if self.enable_teacache:
            if should_calc:
                hidden_states = self.gather_from_second_dim(hidden_states, input_batch_size)
                self.previous_residual = self.gather_from_second_dim(self.previous_residual, input_batch_size)
        else:
            hidden_states = self.gather_from_second_dim(hidden_states, input_batch_size)
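    # Final layer: apply the adaptive-norm modulation, project out, and unpatchify back to (b, c, f, h, w).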
    if self.is_input_patches:
        if self.config.norm_type != "ada_norm_single":
            conditioning = self.transformer_blocks[0].norm1.emb(
                timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
            shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
            hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
            hidden_states = self.proj_out_2(hidden_states)
        elif self.config.norm_type == "ada_norm_single":
            embedded_timestep = repeat(embedded_timestep, "b d -> (b f) d", f=frame + use_image_num).contiguous()
            shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
            hidden_states = self.norm_out(hidden_states)
            # Modulation
            hidden_states = hidden_states * (1 + scale) + shift
            hidden_states = self.proj_out(hidden_states)

        # unpatchify
        if self.adaln_single is None:
            height = width = int(hidden_states.shape[1] ** 0.5)
        hidden_states = hidden_states.reshape(
            shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
        )
        hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
        output = hidden_states.reshape(
            shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
        )
        output = rearrange(output, "(b f) c h w -> b c f h w", b=input_batch_size).contiguous()

    # 3. Gather batch for data parallelism
    if self.parallel_manager.cp_size > 1:
        output = gather_sequence(output, self.parallel_manager.cp_group, dim=0)

    if not return_dict:
        return (output,)

    return Transformer3DModelOutput(sample=output)


def eval_teacache_slow(prompt_list):
    config = OpenSoraPlanConfig(version="v110", transformer_type="65x512x512")
    engine = VideoSysEngine(config)
    engine.driver_worker.transformer.__class__.enable_teacache = True
    engine.driver_worker.transformer.__class__.rel_l1_thresh = 0.1
    engine.driver_worker.transformer.__class__.accumulated_rel_l1_distance = 0
    engine.driver_worker.transformer.__class__.previous_modulated_input = None
    engine.driver_worker.transformer.__class__.previous_residual = None
    engine.driver_worker.transformer.__class__.forward = teacache_forward
    generate_func(engine, prompt_list, "./samples/opensoraplan_teacache_slow", loop=5)


def eval_teacache_fast(prompt_list):
    config = OpenSoraPlanConfig(version="v110", transformer_type="65x512x512")
    engine = VideoSysEngine(config)
    engine.driver_worker.transformer.__class__.enable_teacache = True
    engine.driver_worker.transformer.__class__.rel_l1_thresh = 0.2
    engine.driver_worker.transformer.__class__.accumulated_rel_l1_distance = 0
    engine.driver_worker.transformer.__class__.previous_modulated_input = None
    engine.driver_worker.transformer.__class__.previous_residual = None
    engine.driver_worker.transformer.__class__.forward = teacache_forward
    generate_func(engine, prompt_list, "./samples/opensoraplan_teacache_fast", loop=5)


def eval_base(prompt_list):
    config = OpenSoraPlanConfig(version="v110", transformer_type="65x512x512")
    engine = VideoSysEngine(config)
    generate_func(engine, prompt_list, "./samples/opensoraplan_base", loop=5)


if __name__ == "__main__":
    prompt_list = read_prompt_list("vbench/VBench_full_info.json")
    eval_base(prompt_list)
    eval_teacache_slow(prompt_list)
    eval_teacache_fast(prompt_list)