from utils import generate_func, read_prompt_list
from videosys import LatteConfig, VideoSysEngine

import torch
import torch.nn.functional as F  # needed for F.silu in the non-"ada_norm_single" output projection
from einops import rearrange, repeat
from torch import nn
import numpy as np
from typing import Any, Dict, Optional, Tuple
from functools import partial

from videosys.core.comm import all_to_all_with_pad, gather_sequence, get_pad, set_pad, split_sequence
from videosys.models.transformers.latte_transformer_3d import Transformer3DModelOutput
from videosys.utils.utils import batch_func


def teacache_forward(
    self,
    hidden_states: torch.Tensor,
    timestep: Optional[torch.LongTensor] = None,
    all_timesteps=None,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    added_cond_kwargs: Dict[str, torch.Tensor] = None,
    class_labels: Optional[torch.LongTensor] = None,
    cross_attention_kwargs: Dict[str, Any] = None,
    attention_mask: Optional[torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
    use_image_num: int = 0,
    enable_temporal_attentions: bool = True,
    return_dict: bool = True,
):
    """
    The [`Transformer2DModel`] forward method.

    Args:
        hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete,
            `torch.FloatTensor` of shape `(batch size, frame, channel, height, width)` if continuous):
            Input `hidden_states`.
        encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
            Conditional embeddings for the cross-attention layer. If not given, cross-attention defaults to
            self-attention.
        timestep (`torch.LongTensor`, *optional*):
            Used to indicate the denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
        class_labels (`torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
            Used to indicate class-label conditioning. Optional class labels to be applied as an embedding in
            `AdaLayerZeroNorm`.
        cross_attention_kwargs (`Dict[str, Any]`, *optional*):
            A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        attention_mask (`torch.Tensor`, *optional*):
            An attention mask of shape `(batch, key_tokens)` applied to `encoder_hidden_states`. If `1` the mask
            is kept, otherwise if `0` it is discarded. The mask will be converted into a bias, which adds large
            negative values to the attention scores corresponding to "discard" tokens.
        encoder_attention_mask (`torch.Tensor`, *optional*):
            Cross-attention mask applied to `encoder_hidden_states`. Two formats are supported:

                * Mask `(batch, sequence_length)`: True = keep, False = discard.
                * Bias `(batch, 1, sequence_length)`: 0 = keep, -10000 = discard.

            If `ndim == 2`, it is interpreted as a mask and then converted into a bias consistent with the format
            above. This bias will be added to the cross-attention scores.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
            tuple.

    Returns:
        If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
        `tuple` where the first element is the sample tensor.
    """
    # 0. Split batch for data parallelism
    if self.parallel_manager.cp_size > 1:
        (
            hidden_states,
            timestep,
            encoder_hidden_states,
            added_cond_kwargs,
            class_labels,
            attention_mask,
            encoder_attention_mask,
        ) = batch_func(
            partial(split_sequence, process_group=self.parallel_manager.cp_group, dim=0),
            hidden_states,
            timestep,
            encoder_hidden_states,
            added_cond_kwargs,
            class_labels,
            attention_mask,
            encoder_attention_mask,
        )

    input_batch_size, c, frame, h, w = hidden_states.shape
    frame = frame - use_image_num
    hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w").contiguous()
    org_timestep = timestep

    # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
    #   we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
    #   we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
    # expects mask of shape:
    #   [batch, key_tokens]
    # adds singleton query_tokens dimension:
    #   [batch, 1, key_tokens]
    # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
    #   [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
    #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
    if attention_mask is not None and attention_mask.ndim == 2:
        # assume that mask is expressed as:
        #   (1 = keep, 0 = discard)
        # convert mask into a bias that can be added to attention scores:
        #   (keep = +0, discard = -10000.0)
        attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
        attention_mask = attention_mask.unsqueeze(1)

    # convert encoder_attention_mask to a bias the same way we do for attention_mask
    if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:  # ndim == 2 means no image joint
        encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
        encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
        encoder_attention_mask = repeat(encoder_attention_mask, "b 1 l -> (b f) 1 l", f=frame).contiguous()
    elif encoder_attention_mask is not None and encoder_attention_mask.ndim == 3:  # ndim == 3 means image joint
        encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
        encoder_attention_mask_video = encoder_attention_mask[:, :1, ...]
        encoder_attention_mask_video = repeat(
            encoder_attention_mask_video, "b 1 l -> b (1 f) l", f=frame
        ).contiguous()
        encoder_attention_mask_image = encoder_attention_mask[:, 1:, ...]
        encoder_attention_mask = torch.cat([encoder_attention_mask_video, encoder_attention_mask_image], dim=1)
        encoder_attention_mask = rearrange(encoder_attention_mask, "b n l -> (b n) l").contiguous().unsqueeze(1)

    # Retrieve lora scale.
    cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0

    # 1. Input
    if self.is_input_patches:  # here
        height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
        num_patches = height * width

        hidden_states = self.pos_embed(hidden_states)  # already adds positional embeddings

        if self.adaln_single is not None:
            if self.use_additional_conditions and added_cond_kwargs is None:
                raise ValueError(
                    "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
                )
            # batch_size = hidden_states.shape[0]
            batch_size = input_batch_size
            timestep, embedded_timestep = self.adaln_single(
                timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
            )
    # 2. Blocks
    if self.caption_projection is not None:
        batch_size = hidden_states.shape[0]
        encoder_hidden_states = self.caption_projection(encoder_hidden_states)  # 3 120 1152

        if use_image_num != 0 and self.training:
            encoder_hidden_states_video = encoder_hidden_states[:, :1, ...]
            encoder_hidden_states_video = repeat(
                encoder_hidden_states_video, "b 1 t d -> b (1 f) t d", f=frame
            ).contiguous()
            encoder_hidden_states_image = encoder_hidden_states[:, 1:, ...]
            encoder_hidden_states = torch.cat([encoder_hidden_states_video, encoder_hidden_states_image], dim=1)
            encoder_hidden_states_spatial = rearrange(encoder_hidden_states, "b f t d -> (b f) t d").contiguous()
        else:
            encoder_hidden_states_spatial = repeat(
                encoder_hidden_states, "b t d -> (b f) t d", f=frame
            ).contiguous()

    # prepare timesteps for the spatial and temporal blocks
    timestep_spatial = repeat(timestep, "b d -> (b f) d", f=frame + use_image_num).contiguous()
    timestep_temp = repeat(timestep, "b d -> (b p) d", p=num_patches).contiguous()

    if self.enable_teacache:
        # TeaCache: decide whether this step's transformer blocks can be skipped by
        # accumulating a rescaled relative L1 change of the modulated input.
        inp = hidden_states.clone()
        batch_size = inp.shape[0]
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            self.transformer_blocks[0].scale_shift_table[None] + timestep_spatial.reshape(batch_size, 6, -1)
        ).chunk(6, dim=1)
        modulated_inp = self.transformer_blocks[0].norm1(inp) * (1 + scale_msa) + shift_msa
        if org_timestep[0] == all_timesteps[0] or org_timestep[0] == all_timesteps[-1]:
            # always compute the first and last denoising steps
            should_calc = True
            self.accumulated_rel_l1_distance = 0
        else:
            coefficients = [-2.46434137e+03, 3.08044764e+02, 8.07447667e+01, -4.11385132e+00, 1.11001402e-01]
            rescale_func = np.poly1d(coefficients)
            self.accumulated_rel_l1_distance += rescale_func(
                ((modulated_inp - self.previous_modulated_input).abs().mean()
                 / self.previous_modulated_input.abs().mean()).cpu().item()
            )
            if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
                should_calc = False
            else:
                should_calc = True
                self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = modulated_inp

    if self.enable_teacache:
        if not should_calc:
            # reuse the cached residual instead of running the transformer blocks
            hidden_states += self.previous_residual
        else:
            if self.parallel_manager.sp_size > 1:
                set_pad("temporal", frame + use_image_num, self.parallel_manager.sp_group)
                set_pad("spatial", num_patches, self.parallel_manager.sp_group)
                hidden_states = self.split_from_second_dim(hidden_states, input_batch_size)
                encoder_hidden_states_spatial = self.split_from_second_dim(encoder_hidden_states_spatial, input_batch_size)
                timestep_spatial = self.split_from_second_dim(timestep_spatial, input_batch_size)
                temp_pos_embed = split_sequence(
                    self.temp_pos_embed, self.parallel_manager.sp_group, dim=1, grad_scale="down",
                    pad=get_pad("temporal"),
                )
            else:
                temp_pos_embed = self.temp_pos_embed

            hidden_states_origin = hidden_states.clone().detach()

            for i, (spatial_block, temp_block) in enumerate(
                zip(self.transformer_blocks, self.temporal_transformer_blocks)
            ):
                if self.training and self.gradient_checkpointing:
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        spatial_block,
                        hidden_states,
                        attention_mask,
                        encoder_hidden_states_spatial,
                        encoder_attention_mask,
                        timestep_spatial,
                        cross_attention_kwargs,
                        class_labels,
                        use_reentrant=False,
                    )

                    if enable_temporal_attentions:
                        hidden_states = rearrange(
                            hidden_states, "(b f) t d -> (b t) f d", b=input_batch_size
                        ).contiguous()

                        if use_image_num != 0:  # image-video joint training
                            hidden_states_video = hidden_states[:, :frame, ...]
                            hidden_states_image = hidden_states[:, frame:, ...]
                            if i == 0:
                                hidden_states_video = hidden_states_video + temp_pos_embed

                            hidden_states_video = torch.utils.checkpoint.checkpoint(
                                temp_block,
                                hidden_states_video,
                                None,  # attention_mask
                                None,  # encoder_hidden_states
                                None,  # encoder_attention_mask
                                timestep_temp,
                                cross_attention_kwargs,
                                class_labels,
                                use_reentrant=False,
                            )

                            hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1)
                            hidden_states = rearrange(
                                hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
                            ).contiguous()
                        else:
                            if i == 0:
                                hidden_states = hidden_states + temp_pos_embed

                            hidden_states = torch.utils.checkpoint.checkpoint(
                                temp_block,
                                hidden_states,
                                None,  # attention_mask
                                None,  # encoder_hidden_states
                                None,  # encoder_attention_mask
                                timestep_temp,
                                cross_attention_kwargs,
                                class_labels,
                                use_reentrant=False,
                            )

                            hidden_states = rearrange(
                                hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
                            ).contiguous()
                else:
                    hidden_states = spatial_block(
                        hidden_states,
                        attention_mask,
                        encoder_hidden_states_spatial,
                        encoder_attention_mask,
                        timestep_spatial,
                        cross_attention_kwargs,
                        class_labels,
                        None,
                        org_timestep,
                        all_timesteps=all_timesteps,
                    )

                    if enable_temporal_attentions:
                        hidden_states = rearrange(
                            hidden_states, "(b f) t d -> (b t) f d", b=input_batch_size
                        ).contiguous()

                        if use_image_num != 0 and self.training:
                            hidden_states_video = hidden_states[:, :frame, ...]
                            hidden_states_image = hidden_states[:, frame:, ...]

                            hidden_states_video = temp_block(
                                hidden_states_video,
                                None,  # attention_mask
                                None,  # encoder_hidden_states
                                None,  # encoder_attention_mask
                                timestep_temp,
                                cross_attention_kwargs,
                                class_labels,
                                org_timestep,
                            )

                            hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1)
                            hidden_states = rearrange(
                                hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
                            ).contiguous()
                        else:
                            if i == 0 and frame > 1:
                                hidden_states = hidden_states + temp_pos_embed

                            hidden_states = temp_block(
                                hidden_states,
                                None,  # attention_mask
                                None,  # encoder_hidden_states
                                None,  # encoder_attention_mask
                                timestep_temp,
                                cross_attention_kwargs,
                                class_labels,
                                org_timestep,
                                all_timesteps=all_timesteps,
                            )

                            hidden_states = rearrange(
                                hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
                            ).contiguous()

            # cache the residual so later steps can skip the blocks entirely
            self.previous_residual = hidden_states - hidden_states_origin
    else:
        if self.parallel_manager.sp_size > 1:
            set_pad("temporal", frame + use_image_num, self.parallel_manager.sp_group)
            set_pad("spatial", num_patches, self.parallel_manager.sp_group)
            hidden_states = self.split_from_second_dim(hidden_states, input_batch_size)
            encoder_hidden_states_spatial = self.split_from_second_dim(encoder_hidden_states_spatial, input_batch_size)
            timestep_spatial = self.split_from_second_dim(timestep_spatial, input_batch_size)
            temp_pos_embed = split_sequence(
                self.temp_pos_embed, self.parallel_manager.sp_group, dim=1, grad_scale="down",
                pad=get_pad("temporal"),
            )
        else:
            temp_pos_embed = self.temp_pos_embed

        for i, (spatial_block, temp_block) in enumerate(
            zip(self.transformer_blocks, self.temporal_transformer_blocks)
        ):
            if self.training and self.gradient_checkpointing:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    spatial_block,
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states_spatial,
                    encoder_attention_mask,
                    timestep_spatial,
                    cross_attention_kwargs,
                    class_labels,
                    use_reentrant=False,
                )

                if enable_temporal_attentions:
                    hidden_states = rearrange(
                        hidden_states, "(b f) t d -> (b t) f d", b=input_batch_size
                    ).contiguous()

                    if use_image_num != 0:  # image-video joint training
                        hidden_states_video = hidden_states[:, :frame, ...]
                        hidden_states_image = hidden_states[:, frame:, ...]

                        if i == 0:
                            hidden_states_video = hidden_states_video + temp_pos_embed

                        hidden_states_video = torch.utils.checkpoint.checkpoint(
                            temp_block,
                            hidden_states_video,
                            None,  # attention_mask
                            None,  # encoder_hidden_states
                            None,  # encoder_attention_mask
                            timestep_temp,
                            cross_attention_kwargs,
                            class_labels,
                            use_reentrant=False,
                        )

                        hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1)
                        hidden_states = rearrange(
                            hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
                        ).contiguous()
                    else:
                        if i == 0:
                            hidden_states = hidden_states + temp_pos_embed

                        hidden_states = torch.utils.checkpoint.checkpoint(
                            temp_block,
                            hidden_states,
                            None,  # attention_mask
                            None,  # encoder_hidden_states
                            None,  # encoder_attention_mask
                            timestep_temp,
                            cross_attention_kwargs,
                            class_labels,
                            use_reentrant=False,
                        )

                        hidden_states = rearrange(
                            hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
                        ).contiguous()
            else:
                hidden_states = spatial_block(
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states_spatial,
                    encoder_attention_mask,
                    timestep_spatial,
                    cross_attention_kwargs,
                    class_labels,
                    None,
                    org_timestep,
                    all_timesteps=all_timesteps,
                )

                if enable_temporal_attentions:
                    hidden_states = rearrange(
                        hidden_states, "(b f) t d -> (b t) f d", b=input_batch_size
                    ).contiguous()

                    if use_image_num != 0 and self.training:
                        hidden_states_video = hidden_states[:, :frame, ...]
                        hidden_states_image = hidden_states[:, frame:, ...]

                        hidden_states_video = temp_block(
                            hidden_states_video,
                            None,  # attention_mask
                            None,  # encoder_hidden_states
                            None,  # encoder_attention_mask
                            timestep_temp,
                            cross_attention_kwargs,
                            class_labels,
                            org_timestep,
                        )

                        hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1)
                        hidden_states = rearrange(
                            hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
                        ).contiguous()
                    else:
                        if i == 0 and frame > 1:
                            hidden_states = hidden_states + temp_pos_embed

                        hidden_states = temp_block(
                            hidden_states,
                            None,  # attention_mask
                            None,  # encoder_hidden_states
                            None,  # encoder_attention_mask
                            timestep_temp,
                            cross_attention_kwargs,
                            class_labels,
                            org_timestep,
                            all_timesteps=all_timesteps,
                        )

                        hidden_states = rearrange(
                            hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
                        ).contiguous()

    if self.parallel_manager.sp_size > 1:
        if self.enable_teacache:
            if should_calc:
                hidden_states = self.gather_from_second_dim(hidden_states, input_batch_size)
                self.previous_residual = self.gather_from_second_dim(self.previous_residual, input_batch_size)
        else:
            hidden_states = self.gather_from_second_dim(hidden_states, input_batch_size)

    if self.is_input_patches:
        if self.config.norm_type != "ada_norm_single":
            conditioning = self.transformer_blocks[0].norm1.emb(
                timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
            shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
            hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
            hidden_states = self.proj_out_2(hidden_states)
        elif self.config.norm_type == "ada_norm_single":
            embedded_timestep = repeat(embedded_timestep, "b d -> (b f) d", f=frame + use_image_num).contiguous()
            shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
            hidden_states = self.norm_out(hidden_states)
            # Modulation
            hidden_states = hidden_states * (1 + scale) + shift
            hidden_states = self.proj_out(hidden_states)

        # unpatchify
        if self.adaln_single is None:
            height = width = int(hidden_states.shape[1] ** 0.5)
        hidden_states = hidden_states.reshape(
            shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
        )
        hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
        output = hidden_states.reshape(
            shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
        )
        output = rearrange(output, "(b f) c h w -> b c f h w", b=input_batch_size).contiguous()

    # 3. Gather batch for data parallelism
    if self.parallel_manager.cp_size > 1:
        output = gather_sequence(output, self.parallel_manager.cp_group, dim=0)

    if not return_dict:
        return (output,)

    return Transformer3DModelOutput(sample=output)


def eval_teacache_slow(prompt_list):
    # TeaCache "slow" setting: lower skip threshold, so fewer steps are skipped (higher fidelity).
    config = LatteConfig()
    engine = VideoSysEngine(config)
    # patch the transformer class with the TeaCache state and the teacache_forward defined above
    engine.driver_worker.transformer.__class__.enable_teacache = True
    engine.driver_worker.transformer.__class__.rel_l1_thresh = 0.1
    engine.driver_worker.transformer.__class__.accumulated_rel_l1_distance = 0
    engine.driver_worker.transformer.__class__.previous_modulated_input = None
    engine.driver_worker.transformer.__class__.previous_residual = None
    engine.driver_worker.transformer.__class__.forward = teacache_forward
    generate_func(engine, prompt_list, "./samples/latte_teacache_slow", loop=5)


def eval_teacache_fast(prompt_list):
    # TeaCache "fast" setting: higher skip threshold, so more steps are skipped (faster sampling).
    config = LatteConfig()
    engine = VideoSysEngine(config)
    engine.driver_worker.transformer.__class__.enable_teacache = True
    engine.driver_worker.transformer.__class__.rel_l1_thresh = 0.2
    engine.driver_worker.transformer.__class__.accumulated_rel_l1_distance = 0
    engine.driver_worker.transformer.__class__.previous_modulated_input = None
    engine.driver_worker.transformer.__class__.previous_residual = None
    engine.driver_worker.transformer.__class__.forward = teacache_forward
    generate_func(engine, prompt_list, "./samples/latte_teacache_fast", loop=5)


def eval_base(prompt_list):
    # baseline: unmodified Latte pipeline
    config = LatteConfig()
    engine = VideoSysEngine(config)
    generate_func(engine, prompt_list, "./samples/latte_base", loop=5)


if __name__ == "__main__":
    prompt_list = read_prompt_list("vbench/VBench_full_info.json")
    eval_base(prompt_list)
    eval_teacache_slow(prompt_list)
    eval_teacache_fast(prompt_list)
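

# ---------------------------------------------------------------------------
# Illustrative sketch only (not called by the evaluation above): the TeaCache
# skip rule used inside `teacache_forward`, isolated as a small helper for
# clarity. It rescales the relative L1 change of the modulated input with the
# same polynomial coefficients, accumulates it across denoising steps, and
# only triggers a full recomputation of the transformer blocks once the
# accumulated value reaches `rel_l1_thresh`; otherwise the cached residual is
# reused. The helper name and the plain-float state passed in are assumptions
# made for this sketch, not part of the evaluation script's API.
# ---------------------------------------------------------------------------
def _teacache_should_calc_sketch(modulated_inp, previous_modulated_input, accumulated_rel_l1_distance, rel_l1_thresh):
    """Return (should_calc, new_accumulated_rel_l1_distance) for one denoising step (illustrative)."""
    coefficients = [-2.46434137e+03, 3.08044764e+02, 8.07447667e+01, -4.11385132e+00, 1.11001402e-01]
    rescale_func = np.poly1d(coefficients)
    rel_l1 = (
        (modulated_inp - previous_modulated_input).abs().mean() / previous_modulated_input.abs().mean()
    ).cpu().item()
    accumulated_rel_l1_distance += rescale_func(rel_l1)
    if accumulated_rel_l1_distance < rel_l1_thresh:
        return False, accumulated_rel_l1_distance  # skip: reuse the cached previous_residual
    return True, 0  # recompute the blocks and reset the accumulator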