# TeaCache/eval/teacache/experiments/opensora_plan.py
from utils import generate_func, read_prompt_list
from videosys import OpenSoraPlanConfig, VideoSysEngine
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
import numpy as np
from typing import Any, Dict, Optional, Tuple
from videosys.core.comm import all_to_all_with_pad, gather_sequence, get_pad, set_pad, split_sequence
from videosys.models.transformers.open_sora_plan_v110_transformer_3d import Transformer3DModelOutput
from videosys.utils.utils import batch_func
from functools import partial
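

# TeaCache evaluation for Open-Sora-Plan v1.1.0. `teacache_forward` below replaces the
# transformer's forward so that, at denoising steps where the (rescaled) change of the
# first block's modulated input stays under `rel_l1_thresh`, the full spatial/temporal
# block stack is skipped and a residual cached from the previous computed step is reused.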
def teacache_forward(
self,
hidden_states: torch.Tensor,
timestep: Optional[torch.LongTensor] = None,
all_timesteps=None,
encoder_hidden_states: Optional[torch.Tensor] = None,
added_cond_kwargs: Dict[str, torch.Tensor] = None,
class_labels: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
use_image_num: int = 0,
enable_temporal_attentions: bool = True,
return_dict: bool = True,
):
"""
    The TeaCache-patched forward method for the Open-Sora-Plan [`Transformer3DModel`].
Args:
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, frame, channel, height, width)` if continuous):
Input `hidden_states`.
encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
self-attention.
timestep ( `torch.LongTensor`, *optional*):
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
            `AdaLayerNormZero`.
cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
attention_mask ( `torch.Tensor`, *optional*):
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
negative values to the attention scores corresponding to "discard" tokens.
encoder_attention_mask ( `torch.Tensor`, *optional*):
Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
* Mask `(batch, sequence_length)` True = keep, False = discard.
* Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
above. This bias will be added to the cross-attention scores.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`Transformer3DModelOutput`] instead of a plain tuple.
    Returns:
        If `return_dict` is True, a [`Transformer3DModelOutput`] is returned; otherwise a `tuple` whose first
        element is the sample tensor.
"""
# 0. Split batch
if self.parallel_manager.cp_size > 1:
(
hidden_states,
timestep,
encoder_hidden_states,
class_labels,
attention_mask,
encoder_attention_mask,
) = batch_func(
partial(split_sequence, process_group=self.parallel_manager.cp_group, dim=0),
hidden_states,
timestep,
encoder_hidden_states,
class_labels,
attention_mask,
encoder_attention_mask,
)
input_batch_size, c, frame, h, w = hidden_states.shape
    frame = frame - use_image_num  # e.g. 20 - 4 = 16 video frames when 4 conditioning images are appended
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w").contiguous()
org_timestep = timestep
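    # keep the raw timestep: `timestep` is re-embedded by `adaln_single` below, while the
    # TeaCache schedule check and the patched blocks still need the original values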
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
# expects mask of shape:
# [batch, key_tokens]
# adds singleton query_tokens dimension:
# [batch, 1, key_tokens]
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
if attention_mask is None:
attention_mask = torch.ones(
(input_batch_size, frame + use_image_num, h, w), device=hidden_states.device, dtype=hidden_states.dtype
)
attention_mask = self.vae_to_diff_mask(attention_mask, use_image_num)
dtype = attention_mask.dtype
attention_mask_compress = F.max_pool2d(
attention_mask.float(), kernel_size=self.compress_kv_factor, stride=self.compress_kv_factor
)
attention_mask_compress = attention_mask_compress.to(dtype)
attention_mask = self.make_attn_mask(attention_mask, frame, hidden_states.dtype)
attention_mask_compress = self.make_attn_mask(attention_mask_compress, frame, hidden_states.dtype)
# 1 + 4, 1 -> video condition, 4 -> image condition
# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: # ndim == 2 means no image joint
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
encoder_attention_mask = repeat(encoder_attention_mask, "b 1 l -> (b f) 1 l", f=frame).contiguous()
encoder_attention_mask = encoder_attention_mask.to(self.dtype)
elif encoder_attention_mask is not None and encoder_attention_mask.ndim == 3: # ndim == 3 means image joint
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
encoder_attention_mask_video = encoder_attention_mask[:, :1, ...]
encoder_attention_mask_video = repeat(
encoder_attention_mask_video, "b 1 l -> b (1 f) l", f=frame
).contiguous()
encoder_attention_mask_image = encoder_attention_mask[:, 1:, ...]
encoder_attention_mask = torch.cat([encoder_attention_mask_video, encoder_attention_mask_image], dim=1)
encoder_attention_mask = rearrange(encoder_attention_mask, "b n l -> (b n) l").contiguous().unsqueeze(1)
encoder_attention_mask = encoder_attention_mask.to(self.dtype)
    # Retrieve lora scale (kept for parity with the upstream forward; the value is unused here).
    cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
# 1. Input
if self.is_input_patches: # here
height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
hw = (height, width)
num_patches = height * width
        hidden_states = self.pos_embed(hidden_states.to(self.dtype))  # already adds positional embeddings
if self.adaln_single is not None:
if self.use_additional_conditions and added_cond_kwargs is None:
raise ValueError(
"`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
)
# batch_size = hidden_states.shape[0]
batch_size = input_batch_size
timestep, embedded_timestep = self.adaln_single(
timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
)
# 2. Blocks
if self.caption_projection is not None:
batch_size = hidden_states.shape[0]
encoder_hidden_states = self.caption_projection(encoder_hidden_states.to(self.dtype)) # 3 120 1152
if use_image_num != 0 and self.training:
encoder_hidden_states_video = encoder_hidden_states[:, :1, ...]
encoder_hidden_states_video = repeat(
encoder_hidden_states_video, "b 1 t d -> b (1 f) t d", f=frame
).contiguous()
encoder_hidden_states_image = encoder_hidden_states[:, 1:, ...]
encoder_hidden_states = torch.cat([encoder_hidden_states_video, encoder_hidden_states_image], dim=1)
encoder_hidden_states_spatial = rearrange(encoder_hidden_states, "b f t d -> (b f) t d").contiguous()
else:
encoder_hidden_states_spatial = repeat(
encoder_hidden_states, "b 1 t d -> (b f) t d", f=frame
).contiguous()
# prepare timesteps for spatial and temporal block
timestep_spatial = repeat(timestep, "b d -> (b f) d", f=frame + use_image_num).contiguous()
timestep_temp = repeat(timestep, "b d -> (b p) d", p=num_patches).contiguous()
pos_hw, pos_t = None, None
if self.use_rope:
pos_hw, pos_t = self.make_position(
input_batch_size, frame, use_image_num, height, width, hidden_states.device
)
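    # TeaCache step 1: decide whether this denoising step needs a full computation. The
    # AdaLN-modulated input of the first spatial block serves as a cheap proxy for how much
    # the output will change; its relative L1 change is accumulated across steps and
    # compared against `rel_l1_thresh`. The first and last steps always recompute.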
if self.enable_teacache:
inp = hidden_states.clone()
batch_size = hidden_states.shape[0]
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.transformer_blocks[0].scale_shift_table[None] + timestep_spatial.reshape(batch_size, 6, -1)
).chunk(6, dim=1)
modulated_inp = self.transformer_blocks[0].norm1(inp) * (1 + scale_msa) + shift_msa
if org_timestep[0] == all_timesteps[0] or org_timestep[0] == all_timesteps[-1]:
should_calc = True
self.accumulated_rel_l1_distance = 0
else:
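            # rescale the raw relative L1 change with a polynomial fitted for this model
            # (TeaCache's calibration step) before accumulating it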
coefficients = [2.05943668e+05, -1.48759286e+04, 3.06085986e+02, 1.31418080e+00, 2.39658469e-03]
rescale_func = np.poly1d(coefficients)
            self.accumulated_rel_l1_distance += rescale_func(
                ((modulated_inp - self.previous_modulated_input).abs().mean()
                 / self.previous_modulated_input.abs().mean()).cpu().item()
            )
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
should_calc = False
else:
should_calc = True
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = modulated_inp
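    # TeaCache step 2: either replay the cached residual (cheap path) or run the full
    # spatial/temporal block stack and refresh the cache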
if self.enable_teacache:
if not should_calc:
hidden_states += self.previous_residual
else:
if self.parallel_manager.sp_size > 1:
set_pad("temporal", frame + use_image_num, self.parallel_manager.sp_group)
set_pad("spatial", num_patches, self.parallel_manager.sp_group)
hidden_states = self.split_from_second_dim(hidden_states, input_batch_size)
encoder_hidden_states_spatial = self.split_from_second_dim(encoder_hidden_states_spatial, input_batch_size)
timestep_spatial = self.split_from_second_dim(timestep_spatial, input_batch_size)
attention_mask = self.split_from_second_dim(attention_mask, input_batch_size)
attention_mask_compress = self.split_from_second_dim(attention_mask_compress, input_batch_size)
temp_pos_embed = split_sequence(
self.temp_pos_embed, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal")
)
else:
temp_pos_embed = self.temp_pos_embed
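            # snapshot the block input so the residual (output - input) can be cached after the full pass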
ori_hidden_states = hidden_states.clone().detach()
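            # full pass over the transformer: each spatial block, then (optionally) its paired temporal block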
for i, (spatial_block, temp_block) in enumerate(zip(self.transformer_blocks, self.temporal_transformer_blocks)):
if self.training and self.gradient_checkpointing:
hidden_states = torch.utils.checkpoint.checkpoint(
spatial_block,
hidden_states,
attention_mask_compress if i >= self.num_layers // 2 else attention_mask,
encoder_hidden_states_spatial,
encoder_attention_mask,
timestep_spatial,
cross_attention_kwargs,
class_labels,
pos_hw,
pos_hw,
hw,
use_reentrant=False,
)
if enable_temporal_attentions:
hidden_states = rearrange(hidden_states, "(b f) t d -> (b t) f d", b=input_batch_size).contiguous()
                        if use_image_num != 0:  # image-video joint training
hidden_states_video = hidden_states[:, :frame, ...]
hidden_states_image = hidden_states[:, frame:, ...]
# if i == 0 and not self.use_rope:
if i == 0:
hidden_states_video = hidden_states_video + temp_pos_embed
hidden_states_video = torch.utils.checkpoint.checkpoint(
temp_block,
hidden_states_video,
None, # attention_mask
None, # encoder_hidden_states
None, # encoder_attention_mask
timestep_temp,
cross_attention_kwargs,
class_labels,
pos_t,
pos_t,
(frame,),
use_reentrant=False,
)
hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1)
hidden_states = rearrange(
hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
).contiguous()
else:
# if i == 0 and not self.use_rope:
if i == 0:
hidden_states = hidden_states + temp_pos_embed
hidden_states = torch.utils.checkpoint.checkpoint(
temp_block,
hidden_states,
None, # attention_mask
None, # encoder_hidden_states
None, # encoder_attention_mask
timestep_temp,
cross_attention_kwargs,
class_labels,
pos_t,
pos_t,
(frame,),
use_reentrant=False,
)
hidden_states = rearrange(
hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
).contiguous()
else:
hidden_states = spatial_block(
hidden_states,
attention_mask_compress if i >= self.num_layers // 2 else attention_mask,
encoder_hidden_states_spatial,
encoder_attention_mask,
timestep_spatial,
cross_attention_kwargs,
class_labels,
pos_hw,
pos_hw,
hw,
org_timestep,
all_timesteps=all_timesteps,
)
if enable_temporal_attentions:
# b c f h w, f = 16 + 4
hidden_states = rearrange(hidden_states, "(b f) t d -> (b t) f d", b=input_batch_size).contiguous()
if use_image_num != 0 and self.training:
hidden_states_video = hidden_states[:, :frame, ...]
hidden_states_image = hidden_states[:, frame:, ...]
# if i == 0 and not self.use_rope:
# hidden_states_video = hidden_states_video + temp_pos_embed
hidden_states_video = temp_block(
hidden_states_video,
None, # attention_mask
None, # encoder_hidden_states
None, # encoder_attention_mask
timestep_temp,
cross_attention_kwargs,
class_labels,
pos_t,
pos_t,
(frame,),
org_timestep,
)
hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1)
hidden_states = rearrange(
hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
).contiguous()
else:
# if i == 0 and not self.use_rope:
if i == 0:
hidden_states = hidden_states + temp_pos_embed
hidden_states = temp_block(
hidden_states,
None, # attention_mask
None, # encoder_hidden_states
None, # encoder_attention_mask
timestep_temp,
cross_attention_kwargs,
class_labels,
pos_t,
pos_t,
(frame,),
org_timestep,
all_timesteps=all_timesteps,
)
hidden_states = rearrange(
hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
).contiguous()
self.previous_residual = hidden_states - ori_hidden_states
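    # TeaCache disabled: always run the full spatial/temporal block stack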
else:
if self.parallel_manager.sp_size > 1:
set_pad("temporal", frame + use_image_num, self.parallel_manager.sp_group)
set_pad("spatial", num_patches, self.parallel_manager.sp_group)
hidden_states = self.split_from_second_dim(hidden_states, input_batch_size)
self.previous_residual = self.split_from_second_dim(self.previous_residual, input_batch_size) if self.previous_residual is not None else None
encoder_hidden_states_spatial = self.split_from_second_dim(encoder_hidden_states_spatial, input_batch_size)
timestep_spatial = self.split_from_second_dim(timestep_spatial, input_batch_size)
attention_mask = self.split_from_second_dim(attention_mask, input_batch_size)
attention_mask_compress = self.split_from_second_dim(attention_mask_compress, input_batch_size)
temp_pos_embed = split_sequence(
self.temp_pos_embed, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal")
)
else:
temp_pos_embed = self.temp_pos_embed
for i, (spatial_block, temp_block) in enumerate(zip(self.transformer_blocks, self.temporal_transformer_blocks)):
if self.training and self.gradient_checkpointing:
hidden_states = torch.utils.checkpoint.checkpoint(
spatial_block,
hidden_states,
attention_mask_compress if i >= self.num_layers // 2 else attention_mask,
encoder_hidden_states_spatial,
encoder_attention_mask,
timestep_spatial,
cross_attention_kwargs,
class_labels,
pos_hw,
pos_hw,
hw,
use_reentrant=False,
)
if enable_temporal_attentions:
hidden_states = rearrange(hidden_states, "(b f) t d -> (b t) f d", b=input_batch_size).contiguous()
                    if use_image_num != 0:  # image-video joint training
hidden_states_video = hidden_states[:, :frame, ...]
hidden_states_image = hidden_states[:, frame:, ...]
# if i == 0 and not self.use_rope:
if i == 0:
hidden_states_video = hidden_states_video + temp_pos_embed
hidden_states_video = torch.utils.checkpoint.checkpoint(
temp_block,
hidden_states_video,
None, # attention_mask
None, # encoder_hidden_states
None, # encoder_attention_mask
timestep_temp,
cross_attention_kwargs,
class_labels,
pos_t,
pos_t,
(frame,),
use_reentrant=False,
)
hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1)
hidden_states = rearrange(
hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
).contiguous()
else:
# if i == 0 and not self.use_rope:
if i == 0:
hidden_states = hidden_states + temp_pos_embed
hidden_states = torch.utils.checkpoint.checkpoint(
temp_block,
hidden_states,
None, # attention_mask
None, # encoder_hidden_states
None, # encoder_attention_mask
timestep_temp,
cross_attention_kwargs,
class_labels,
pos_t,
pos_t,
(frame,),
use_reentrant=False,
)
hidden_states = rearrange(
hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
).contiguous()
else:
hidden_states = spatial_block(
hidden_states,
attention_mask_compress if i >= self.num_layers // 2 else attention_mask,
encoder_hidden_states_spatial,
encoder_attention_mask,
timestep_spatial,
cross_attention_kwargs,
class_labels,
pos_hw,
pos_hw,
hw,
org_timestep,
all_timesteps=all_timesteps,
)
if enable_temporal_attentions:
# b c f h w, f = 16 + 4
hidden_states = rearrange(hidden_states, "(b f) t d -> (b t) f d", b=input_batch_size).contiguous()
if use_image_num != 0 and self.training:
hidden_states_video = hidden_states[:, :frame, ...]
hidden_states_image = hidden_states[:, frame:, ...]
# if i == 0 and not self.use_rope:
# hidden_states_video = hidden_states_video + temp_pos_embed
hidden_states_video = temp_block(
hidden_states_video,
None, # attention_mask
None, # encoder_hidden_states
None, # encoder_attention_mask
timestep_temp,
cross_attention_kwargs,
class_labels,
pos_t,
pos_t,
(frame,),
org_timestep,
)
hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1)
hidden_states = rearrange(
hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
).contiguous()
else:
# if i == 0 and not self.use_rope:
if i == 0:
hidden_states = hidden_states + temp_pos_embed
hidden_states = temp_block(
hidden_states,
None, # attention_mask
None, # encoder_hidden_states
None, # encoder_attention_mask
timestep_temp,
cross_attention_kwargs,
class_labels,
pos_t,
pos_t,
(frame,),
org_timestep,
all_timesteps=all_timesteps,
)
hidden_states = rearrange(
hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
).contiguous()
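    # with sequence parallelism, gather the split hidden states (and the freshly refreshed
    # residual cache, when a full pass just ran) back to the full sequence length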
if self.parallel_manager.sp_size > 1:
if self.enable_teacache:
if should_calc:
hidden_states = self.gather_from_second_dim(hidden_states, input_batch_size)
self.previous_residual = self.gather_from_second_dim(self.previous_residual, input_batch_size)
else:
hidden_states = self.gather_from_second_dim(hidden_states, input_batch_size)
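    # final AdaLN modulation, output projection, and unpatchify back to (b, c, f, h, w)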
if self.is_input_patches:
if self.config.norm_type != "ada_norm_single":
conditioning = self.transformer_blocks[0].norm1.emb(
timestep, class_labels, hidden_dtype=hidden_states.dtype
)
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
hidden_states = self.proj_out_2(hidden_states)
elif self.config.norm_type == "ada_norm_single":
embedded_timestep = repeat(embedded_timestep, "b d -> (b f) d", f=frame + use_image_num).contiguous()
shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states)
# Modulation
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.proj_out(hidden_states)
# unpatchify
if self.adaln_single is None:
height = width = int(hidden_states.shape[1] ** 0.5)
hidden_states = hidden_states.reshape(
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
)
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
output = hidden_states.reshape(
shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
)
output = rearrange(output, "(b f) c h w -> b c f h w", b=input_batch_size).contiguous()
# 3. Gather batch for data parallelism
if self.parallel_manager.cp_size > 1:
output = gather_sequence(output, self.parallel_manager.cp_group, dim=0)
if not return_dict:
return (output,)
return Transformer3DModelOutput(sample=output)
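

# --- Illustrative sketch (not used by the evaluation below) -------------------
# A minimal, standalone restatement of the caching rule implemented inside
# `teacache_forward`, assuming `modulated_inp` is a torch tensor and
# `coefficients` is the fitted rescaling polynomial used above. The helper name
# and the (prev, accumulated) state tuple are made up for illustration; the real
# forward additionally forces a full computation at the first and last denoising
# steps, which this sketch omits.
def _teacache_should_calc_sketch(modulated_inp, state, rel_l1_thresh, coefficients):
    """Return (should_calc, new_state) for one denoising step.

    `state` is a (prev_modulated_inp, accumulated_rel_l1_distance) tuple;
    pass (None, 0.0) on the first call.
    """
    prev, accumulated = state
    if prev is None:
        # nothing to compare against yet: compute and reset the accumulator
        should_calc, accumulated = True, 0.0
    else:
        # relative L1 change of the modulated input, rescaled by the fitted polynomial
        rescale_func = np.poly1d(coefficients)
        rel_change = ((modulated_inp - prev).abs().mean() / prev.abs().mean()).cpu().item()
        accumulated += rescale_func(rel_change)
        if accumulated < rel_l1_thresh:
            should_calc = False  # reuse the cached residual
        else:
            should_calc, accumulated = True, 0.0  # recompute and reset
    return should_calc, (modulated_inp, accumulated)

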
def eval_teacache_slow(prompt_list):
config = OpenSoraPlanConfig(version="v110", transformer_type="65x512x512")
engine = VideoSysEngine(config)
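    # patch the transformer class in place: enable TeaCache with the more conservative
    # 0.1 threshold (fewer skipped steps, output closer to the baseline)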
engine.driver_worker.transformer.__class__.enable_teacache = True
engine.driver_worker.transformer.__class__.rel_l1_thresh = 0.1
engine.driver_worker.transformer.__class__.accumulated_rel_l1_distance = 0
engine.driver_worker.transformer.__class__.previous_modulated_input = None
engine.driver_worker.transformer.__class__.previous_residual = None
engine.driver_worker.transformer.__class__.forward = teacache_forward
generate_func(engine, prompt_list, "./samples/opensoraplan_teacache_slow", loop=5)


def eval_teacache_fast(prompt_list):
config = OpenSoraPlanConfig(version="v110", transformer_type="65x512x512")
engine = VideoSysEngine(config)
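    # same patch with the looser 0.2 threshold (more skipped steps, larger speedup)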
engine.driver_worker.transformer.__class__.enable_teacache = True
engine.driver_worker.transformer.__class__.rel_l1_thresh = 0.2
engine.driver_worker.transformer.__class__.accumulated_rel_l1_distance = 0
engine.driver_worker.transformer.__class__.previous_modulated_input = None
engine.driver_worker.transformer.__class__.previous_residual = None
engine.driver_worker.transformer.__class__.forward = teacache_forward
generate_func(engine, prompt_list, "./samples/opensoraplan_teacache_fast", loop=5)


def eval_base(prompt_list):
config = OpenSoraPlanConfig(version="v110", transformer_type="65x512x512", )
engine = VideoSysEngine(config)
generate_func(engine, prompt_list, "./samples/opensoraplan_base", loop=5)
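

# Run all three settings over the VBench prompt list: baseline (no caching),
# TeaCache slow (rel_l1_thresh=0.1), and TeaCache fast (rel_l1_thresh=0.2).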
if __name__ == "__main__":
prompt_list = read_prompt_list("vbench/VBench_full_info.json")
eval_base(prompt_list)
eval_teacache_slow(prompt_list)
eval_teacache_fast(prompt_list)