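"""TeaCache evaluation for Open-Sora-Plan v1.1.0 (65x512x512) on VideoSys.

`teacache_forward` below is a TeaCache-augmented version of the transformer forward pass. The eval functions
monkey-patch it (plus the TeaCache state) onto the transformer class of a VideoSysEngine and render the VBench
prompt list three ways: baseline, TeaCache "slow" (rel_l1_thresh=0.1), and TeaCache "fast" (rel_l1_thresh=0.2).
"""
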
from functools import partial
from typing import Any, Dict, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange, repeat

from videosys import OpenSoraPlanConfig, VideoSysEngine
from videosys.core.comm import all_to_all_with_pad, gather_sequence, get_pad, set_pad, split_sequence
from videosys.models.transformers.open_sora_plan_v110_transformer_3d import Transformer3DModelOutput
from videosys.utils.utils import batch_func

from utils import generate_func, read_prompt_list


def teacache_forward(
    self,
    hidden_states: torch.Tensor,
    timestep: Optional[torch.LongTensor] = None,
    all_timesteps=None,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    added_cond_kwargs: Dict[str, torch.Tensor] = None,
    class_labels: Optional[torch.LongTensor] = None,
    cross_attention_kwargs: Dict[str, Any] = None,
    attention_mask: Optional[torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
    use_image_num: int = 0,
    enable_temporal_attentions: bool = True,
    return_dict: bool = True,
):
    """
    TeaCache-enabled forward pass, used as a drop-in replacement for the Open-Sora-Plan v1.1.0 transformer's
    `forward` method.

    Args:
        hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, frame, channel, height, width)` if continuous):
            Input `hidden_states`.
        encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
            Conditional embeddings for the cross-attention layer. If not given, cross-attention defaults to
            self-attention.
        timestep (`torch.LongTensor`, *optional*):
            Used to indicate the denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
        class_labels (`torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
            Used to indicate class-label conditioning. Optional class labels to be applied as an embedding in
            `AdaLayerNormZero`.
        cross_attention_kwargs (`Dict[str, Any]`, *optional*):
            A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        attention_mask (`torch.Tensor`, *optional*):
            An attention mask of shape `(batch, key_tokens)` applied to `encoder_hidden_states`. If `1` the mask
            is kept, otherwise if `0` it is discarded. The mask will be converted into a bias, which adds large
            negative values to the attention scores corresponding to "discard" tokens.
        encoder_attention_mask (`torch.Tensor`, *optional*):
            Cross-attention mask applied to `encoder_hidden_states`. Two formats are supported:

                * Mask `(batch, sequence_length)`: True = keep, False = discard.
                * Bias `(batch, 1, sequence_length)`: 0 = keep, -10000 = discard.

            If `ndim == 2`, it is interpreted as a mask and converted into a bias consistent with the format
            above. This bias is added to the cross-attention scores.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`Transformer3DModelOutput`] instead of a plain tuple.

    Returns:
        If `return_dict` is True, a [`Transformer3DModelOutput`] is returned, otherwise a `tuple` whose first
        element is the sample tensor.
    """
    # 0. Split batch
    if self.parallel_manager.cp_size > 1:
        (
            hidden_states,
            timestep,
            encoder_hidden_states,
            class_labels,
            attention_mask,
            encoder_attention_mask,
        ) = batch_func(
            partial(split_sequence, process_group=self.parallel_manager.cp_group, dim=0),
            hidden_states,
            timestep,
            encoder_hidden_states,
            class_labels,
            attention_mask,
            encoder_attention_mask,
        )
    input_batch_size, c, frame, h, w = hidden_states.shape
    frame = frame - use_image_num  # 20 - 4 = 16
    hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w").contiguous()
    org_timestep = timestep
    # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
    # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
    # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
    # expects mask of shape:
    #   [batch, key_tokens]
    # adds singleton query_tokens dimension:
    #   [batch, 1, key_tokens]
    # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
    #   [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
    #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
    if attention_mask is None:
        attention_mask = torch.ones(
            (input_batch_size, frame + use_image_num, h, w), device=hidden_states.device, dtype=hidden_states.dtype
        )
    attention_mask = self.vae_to_diff_mask(attention_mask, use_image_num)
    dtype = attention_mask.dtype
    attention_mask_compress = F.max_pool2d(
        attention_mask.float(), kernel_size=self.compress_kv_factor, stride=self.compress_kv_factor
    )
    attention_mask_compress = attention_mask_compress.to(dtype)

    attention_mask = self.make_attn_mask(attention_mask, frame, hidden_states.dtype)
    attention_mask_compress = self.make_attn_mask(attention_mask_compress, frame, hidden_states.dtype)

    # 1 + 4, 1 -> video condition, 4 -> image condition
    # convert encoder_attention_mask to a bias the same way we do for attention_mask
    if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:  # ndim == 2 means no image joint
        encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
        encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
        encoder_attention_mask = repeat(encoder_attention_mask, "b 1 l -> (b f) 1 l", f=frame).contiguous()
        encoder_attention_mask = encoder_attention_mask.to(self.dtype)
    elif encoder_attention_mask is not None and encoder_attention_mask.ndim == 3:  # ndim == 3 means image joint
        encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
        encoder_attention_mask_video = encoder_attention_mask[:, :1, ...]
        encoder_attention_mask_video = repeat(
            encoder_attention_mask_video, "b 1 l -> b (1 f) l", f=frame
        ).contiguous()
        encoder_attention_mask_image = encoder_attention_mask[:, 1:, ...]
        encoder_attention_mask = torch.cat([encoder_attention_mask_video, encoder_attention_mask_image], dim=1)
        encoder_attention_mask = rearrange(encoder_attention_mask, "b n l -> (b n) l").contiguous().unsqueeze(1)
        encoder_attention_mask = encoder_attention_mask.to(self.dtype)

    # Retrieve lora scale (the value is not used further in this forward).
    cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0

    # 1. Input
    if self.is_input_patches:  # here
        height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
        hw = (height, width)
        num_patches = height * width

        hidden_states = self.pos_embed(hidden_states.to(self.dtype))  # already adds positional embeddings

        if self.adaln_single is not None:
            if self.use_additional_conditions and added_cond_kwargs is None:
                raise ValueError(
                    "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
                )
            # batch_size = hidden_states.shape[0]
            batch_size = input_batch_size
            timestep, embedded_timestep = self.adaln_single(
                timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
            )

    # 2. Blocks
    if self.caption_projection is not None:
        batch_size = hidden_states.shape[0]
        encoder_hidden_states = self.caption_projection(encoder_hidden_states.to(self.dtype))  # 3 120 1152

        if use_image_num != 0 and self.training:
            encoder_hidden_states_video = encoder_hidden_states[:, :1, ...]
            encoder_hidden_states_video = repeat(
                encoder_hidden_states_video, "b 1 t d -> b (1 f) t d", f=frame
            ).contiguous()
            encoder_hidden_states_image = encoder_hidden_states[:, 1:, ...]
            encoder_hidden_states = torch.cat([encoder_hidden_states_video, encoder_hidden_states_image], dim=1)
            encoder_hidden_states_spatial = rearrange(encoder_hidden_states, "b f t d -> (b f) t d").contiguous()
        else:
            encoder_hidden_states_spatial = repeat(
                encoder_hidden_states, "b 1 t d -> (b f) t d", f=frame
            ).contiguous()

    # prepare timesteps for spatial and temporal block
    timestep_spatial = repeat(timestep, "b d -> (b f) d", f=frame + use_image_num).contiguous()
    timestep_temp = repeat(timestep, "b d -> (b p) d", p=num_patches).contiguous()

    pos_hw, pos_t = None, None
    if self.use_rope:
        pos_hw, pos_t = self.make_position(
            input_batch_size, frame, use_image_num, height, width, hidden_states.device
        )
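
    # TeaCache: decide whether this denoising step can reuse the cached transformer output. The input is
    # modulated with the first spatial block's AdaLN shift/scale, and the relative L1 change of this modulated
    # input between consecutive steps is rescaled by a hard-coded polynomial (np.poly1d, highest-order
    # coefficient first) and accumulated. The first and last timesteps are always computed in full.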
    if self.enable_teacache:
        inp = hidden_states.clone()
        batch_size = hidden_states.shape[0]
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            self.transformer_blocks[0].scale_shift_table[None] + timestep_spatial.reshape(batch_size, 6, -1)
        ).chunk(6, dim=1)
        modulated_inp = self.transformer_blocks[0].norm1(inp) * (1 + scale_msa) + shift_msa
        if org_timestep[0] == all_timesteps[0] or org_timestep[0] == all_timesteps[-1]:
            should_calc = True
            self.accumulated_rel_l1_distance = 0
        else:
            coefficients = [2.05943668e+05, -1.48759286e+04, 3.06085986e+02, 1.31418080e+00, 2.39658469e-03]
            rescale_func = np.poly1d(coefficients)
            self.accumulated_rel_l1_distance += rescale_func(
                (
                    (modulated_inp - self.previous_modulated_input).abs().mean()
                    / self.previous_modulated_input.abs().mean()
                )
                .cpu()
                .item()
            )
            if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
                should_calc = False
            else:
                should_calc = True
                self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = modulated_inp
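
    # If the accumulated (rescaled) change is still below `rel_l1_thresh`, skip all transformer blocks and reuse
    # the cached residual from the last fully computed step; otherwise run the blocks and refresh the cache.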
    if self.enable_teacache:
        if not should_calc:
            hidden_states += self.previous_residual
        else:
            if self.parallel_manager.sp_size > 1:
                set_pad("temporal", frame + use_image_num, self.parallel_manager.sp_group)
                set_pad("spatial", num_patches, self.parallel_manager.sp_group)
                hidden_states = self.split_from_second_dim(hidden_states, input_batch_size)
                encoder_hidden_states_spatial = self.split_from_second_dim(encoder_hidden_states_spatial, input_batch_size)
                timestep_spatial = self.split_from_second_dim(timestep_spatial, input_batch_size)
                attention_mask = self.split_from_second_dim(attention_mask, input_batch_size)
                attention_mask_compress = self.split_from_second_dim(attention_mask_compress, input_batch_size)
                temp_pos_embed = split_sequence(
                    self.temp_pos_embed, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal")
                )
            else:
                temp_pos_embed = self.temp_pos_embed
            ori_hidden_states = hidden_states.clone().detach()
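            # Alternate spatial and temporal blocks; the deeper half of the spatial blocks (i >= num_layers // 2)
            # uses the max-pooled `attention_mask_compress` instead of the full-resolution mask.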
            for i, (spatial_block, temp_block) in enumerate(zip(self.transformer_blocks, self.temporal_transformer_blocks)):
                if self.training and self.gradient_checkpointing:
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        spatial_block,
                        hidden_states,
                        attention_mask_compress if i >= self.num_layers // 2 else attention_mask,
                        encoder_hidden_states_spatial,
                        encoder_attention_mask,
                        timestep_spatial,
                        cross_attention_kwargs,
                        class_labels,
                        pos_hw,
                        pos_hw,
                        hw,
                        use_reentrant=False,
                    )

                    if enable_temporal_attentions:
                        hidden_states = rearrange(hidden_states, "(b f) t d -> (b t) f d", b=input_batch_size).contiguous()

                        if use_image_num != 0:  # image-video joint training
                            hidden_states_video = hidden_states[:, :frame, ...]
                            hidden_states_image = hidden_states[:, frame:, ...]

                            # if i == 0 and not self.use_rope:
                            if i == 0:
                                hidden_states_video = hidden_states_video + temp_pos_embed

                            hidden_states_video = torch.utils.checkpoint.checkpoint(
                                temp_block,
                                hidden_states_video,
                                None,  # attention_mask
                                None,  # encoder_hidden_states
                                None,  # encoder_attention_mask
                                timestep_temp,
                                cross_attention_kwargs,
                                class_labels,
                                pos_t,
                                pos_t,
                                (frame,),
                                use_reentrant=False,
                            )

                            hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1)
                            hidden_states = rearrange(
                                hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
                            ).contiguous()

                        else:
                            # if i == 0 and not self.use_rope:
                            if i == 0:
                                hidden_states = hidden_states + temp_pos_embed

                            hidden_states = torch.utils.checkpoint.checkpoint(
                                temp_block,
                                hidden_states,
                                None,  # attention_mask
                                None,  # encoder_hidden_states
                                None,  # encoder_attention_mask
                                timestep_temp,
                                cross_attention_kwargs,
                                class_labels,
                                pos_t,
                                pos_t,
                                (frame,),
                                use_reentrant=False,
                            )

                            hidden_states = rearrange(
                                hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
                            ).contiguous()
                else:
                    hidden_states = spatial_block(
                        hidden_states,
                        attention_mask_compress if i >= self.num_layers // 2 else attention_mask,
                        encoder_hidden_states_spatial,
                        encoder_attention_mask,
                        timestep_spatial,
                        cross_attention_kwargs,
                        class_labels,
                        pos_hw,
                        pos_hw,
                        hw,
                        org_timestep,
                        all_timesteps=all_timesteps,
                    )

                    if enable_temporal_attentions:
                        # b c f h w, f = 16 + 4
                        hidden_states = rearrange(hidden_states, "(b f) t d -> (b t) f d", b=input_batch_size).contiguous()

                        if use_image_num != 0 and self.training:
                            hidden_states_video = hidden_states[:, :frame, ...]
                            hidden_states_image = hidden_states[:, frame:, ...]

                            # if i == 0 and not self.use_rope:
                            #     hidden_states_video = hidden_states_video + temp_pos_embed

                            hidden_states_video = temp_block(
                                hidden_states_video,
                                None,  # attention_mask
                                None,  # encoder_hidden_states
                                None,  # encoder_attention_mask
                                timestep_temp,
                                cross_attention_kwargs,
                                class_labels,
                                pos_t,
                                pos_t,
                                (frame,),
                                org_timestep,
                            )

                            hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1)
                            hidden_states = rearrange(
                                hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
                            ).contiguous()

                        else:
                            # if i == 0 and not self.use_rope:
                            if i == 0:
                                hidden_states = hidden_states + temp_pos_embed
                            hidden_states = temp_block(
                                hidden_states,
                                None,  # attention_mask
                                None,  # encoder_hidden_states
                                None,  # encoder_attention_mask
                                timestep_temp,
                                cross_attention_kwargs,
                                class_labels,
                                pos_t,
                                pos_t,
                                (frame,),
                                org_timestep,
                                all_timesteps=all_timesteps,
                            )

                            hidden_states = rearrange(
                                hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
                            ).contiguous()
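            # Cache the residual added by the whole block stack so that skipped steps can reuse it.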
            self.previous_residual = hidden_states - ori_hidden_states
    else:
        if self.parallel_manager.sp_size > 1:
            set_pad("temporal", frame + use_image_num, self.parallel_manager.sp_group)
            set_pad("spatial", num_patches, self.parallel_manager.sp_group)
            hidden_states = self.split_from_second_dim(hidden_states, input_batch_size)
            self.previous_residual = (
                self.split_from_second_dim(self.previous_residual, input_batch_size)
                if self.previous_residual is not None
                else None
            )
            encoder_hidden_states_spatial = self.split_from_second_dim(encoder_hidden_states_spatial, input_batch_size)
            timestep_spatial = self.split_from_second_dim(timestep_spatial, input_batch_size)
            attention_mask = self.split_from_second_dim(attention_mask, input_batch_size)
            attention_mask_compress = self.split_from_second_dim(attention_mask_compress, input_batch_size)
            temp_pos_embed = split_sequence(
                self.temp_pos_embed, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal")
            )
        else:
            temp_pos_embed = self.temp_pos_embed
        for i, (spatial_block, temp_block) in enumerate(zip(self.transformer_blocks, self.temporal_transformer_blocks)):
            if self.training and self.gradient_checkpointing:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    spatial_block,
                    hidden_states,
                    attention_mask_compress if i >= self.num_layers // 2 else attention_mask,
                    encoder_hidden_states_spatial,
                    encoder_attention_mask,
                    timestep_spatial,
                    cross_attention_kwargs,
                    class_labels,
                    pos_hw,
                    pos_hw,
                    hw,
                    use_reentrant=False,
                )

                if enable_temporal_attentions:
                    hidden_states = rearrange(hidden_states, "(b f) t d -> (b t) f d", b=input_batch_size).contiguous()

                    if use_image_num != 0:  # image-video joint training
                        hidden_states_video = hidden_states[:, :frame, ...]
                        hidden_states_image = hidden_states[:, frame:, ...]

                        # if i == 0 and not self.use_rope:
                        if i == 0:
                            hidden_states_video = hidden_states_video + temp_pos_embed

                        hidden_states_video = torch.utils.checkpoint.checkpoint(
                            temp_block,
                            hidden_states_video,
                            None,  # attention_mask
                            None,  # encoder_hidden_states
                            None,  # encoder_attention_mask
                            timestep_temp,
                            cross_attention_kwargs,
                            class_labels,
                            pos_t,
                            pos_t,
                            (frame,),
                            use_reentrant=False,
                        )

                        hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1)
                        hidden_states = rearrange(
                            hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
                        ).contiguous()

                    else:
                        # if i == 0 and not self.use_rope:
                        if i == 0:
                            hidden_states = hidden_states + temp_pos_embed

                        hidden_states = torch.utils.checkpoint.checkpoint(
                            temp_block,
                            hidden_states,
                            None,  # attention_mask
                            None,  # encoder_hidden_states
                            None,  # encoder_attention_mask
                            timestep_temp,
                            cross_attention_kwargs,
                            class_labels,
                            pos_t,
                            pos_t,
                            (frame,),
                            use_reentrant=False,
                        )

                        hidden_states = rearrange(
                            hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
                        ).contiguous()
            else:
                hidden_states = spatial_block(
                    hidden_states,
                    attention_mask_compress if i >= self.num_layers // 2 else attention_mask,
                    encoder_hidden_states_spatial,
                    encoder_attention_mask,
                    timestep_spatial,
                    cross_attention_kwargs,
                    class_labels,
                    pos_hw,
                    pos_hw,
                    hw,
                    org_timestep,
                    all_timesteps=all_timesteps,
                )

                if enable_temporal_attentions:
                    # b c f h w, f = 16 + 4
                    hidden_states = rearrange(hidden_states, "(b f) t d -> (b t) f d", b=input_batch_size).contiguous()

                    if use_image_num != 0 and self.training:
                        hidden_states_video = hidden_states[:, :frame, ...]
                        hidden_states_image = hidden_states[:, frame:, ...]

                        # if i == 0 and not self.use_rope:
                        #     hidden_states_video = hidden_states_video + temp_pos_embed

                        hidden_states_video = temp_block(
                            hidden_states_video,
                            None,  # attention_mask
                            None,  # encoder_hidden_states
                            None,  # encoder_attention_mask
                            timestep_temp,
                            cross_attention_kwargs,
                            class_labels,
                            pos_t,
                            pos_t,
                            (frame,),
                            org_timestep,
                        )

                        hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1)
                        hidden_states = rearrange(
                            hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
                        ).contiguous()

                    else:
                        # if i == 0 and not self.use_rope:
                        if i == 0:
                            hidden_states = hidden_states + temp_pos_embed
                        hidden_states = temp_block(
                            hidden_states,
                            None,  # attention_mask
                            None,  # encoder_hidden_states
                            None,  # encoder_attention_mask
                            timestep_temp,
                            cross_attention_kwargs,
                            class_labels,
                            pos_t,
                            pos_t,
                            (frame,),
                            org_timestep,
                            all_timesteps=all_timesteps,
                        )

                        hidden_states = rearrange(
                            hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
                        ).contiguous()
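
    # With sequence parallelism, activations were only split across ranks when the blocks actually ran
    # (TeaCache disabled, or `should_calc` was True), so gather only in those cases.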
    if self.parallel_manager.sp_size > 1:
        if self.enable_teacache:
            if should_calc:
                hidden_states = self.gather_from_second_dim(hidden_states, input_batch_size)
                self.previous_residual = self.gather_from_second_dim(self.previous_residual, input_batch_size)
        else:
            hidden_states = self.gather_from_second_dim(hidden_states, input_batch_size)

    if self.is_input_patches:
        if self.config.norm_type != "ada_norm_single":
            conditioning = self.transformer_blocks[0].norm1.emb(
                timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
            shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
            hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
            hidden_states = self.proj_out_2(hidden_states)
        elif self.config.norm_type == "ada_norm_single":
            embedded_timestep = repeat(embedded_timestep, "b d -> (b f) d", f=frame + use_image_num).contiguous()
            shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
            hidden_states = self.norm_out(hidden_states)
            # Modulation
            hidden_states = hidden_states * (1 + scale) + shift
            hidden_states = self.proj_out(hidden_states)

        # unpatchify
        if self.adaln_single is None:
            height = width = int(hidden_states.shape[1] ** 0.5)
        hidden_states = hidden_states.reshape(
            shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
        )
        hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
        output = hidden_states.reshape(
            shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
        )
        output = rearrange(output, "(b f) c h w -> b c f h w", b=input_batch_size).contiguous()

    # 3. Gather batch for data parallelism
    if self.parallel_manager.cp_size > 1:
        output = gather_sequence(output, self.parallel_manager.cp_group, dim=0)

    if not return_dict:
        return (output,)

    return Transformer3DModelOutput(sample=output)
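

# Each eval function monkey-patches the TeaCache state and `teacache_forward` onto the transformer *class* used by
# the VideoSys engine, then renders the VBench prompts. A larger `rel_l1_thresh` lets more estimated change
# accumulate before a full recomputation, so more steps are skipped: 0.1 ("slow") skips fewer steps, 0.2 ("fast")
# skips more and is therefore faster, typically at some cost in fidelity.
#
# A custom-threshold variant could follow the same pattern; this is only a sketch (not part of the original
# script), with `rel_l1_thresh` as the single knob:
#
#   def eval_teacache_custom(prompt_list, rel_l1_thresh=0.15):
#       config = OpenSoraPlanConfig(version="v110", transformer_type="65x512x512")
#       engine = VideoSysEngine(config)
#       cls = engine.driver_worker.transformer.__class__
#       cls.enable_teacache = True
#       cls.rel_l1_thresh = rel_l1_thresh
#       cls.accumulated_rel_l1_distance = 0
#       cls.previous_modulated_input = None
#       cls.previous_residual = None
#       cls.forward = teacache_forward
#       generate_func(engine, prompt_list, "./samples/opensoraplan_teacache_custom", loop=5)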
def eval_teacache_slow(prompt_list):
    config = OpenSoraPlanConfig(version="v110", transformer_type="65x512x512")
    engine = VideoSysEngine(config)
    engine.driver_worker.transformer.__class__.enable_teacache = True
    engine.driver_worker.transformer.__class__.rel_l1_thresh = 0.1
    engine.driver_worker.transformer.__class__.accumulated_rel_l1_distance = 0
    engine.driver_worker.transformer.__class__.previous_modulated_input = None
    engine.driver_worker.transformer.__class__.previous_residual = None
    engine.driver_worker.transformer.__class__.forward = teacache_forward
    generate_func(engine, prompt_list, "./samples/opensoraplan_teacache_slow", loop=5)


def eval_teacache_fast(prompt_list):
    config = OpenSoraPlanConfig(version="v110", transformer_type="65x512x512")
    engine = VideoSysEngine(config)
    engine.driver_worker.transformer.__class__.enable_teacache = True
    engine.driver_worker.transformer.__class__.rel_l1_thresh = 0.2
    engine.driver_worker.transformer.__class__.accumulated_rel_l1_distance = 0
    engine.driver_worker.transformer.__class__.previous_modulated_input = None
    engine.driver_worker.transformer.__class__.previous_residual = None
    engine.driver_worker.transformer.__class__.forward = teacache_forward
    generate_func(engine, prompt_list, "./samples/opensoraplan_teacache_fast", loop=5)


def eval_base(prompt_list):
    config = OpenSoraPlanConfig(version="v110", transformer_type="65x512x512")
    engine = VideoSysEngine(config)
    generate_func(engine, prompt_list, "./samples/opensoraplan_base", loop=5)
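

# Entry point: load the VBench prompt list, then run the baseline and both TeaCache settings back to back so the
# resulting sample folders can be compared (e.g., scored with VBench).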
if __name__ == "__main__":
    prompt_list = read_prompt_list("vbench/VBench_full_info.json")
    eval_base(prompt_list)
    eval_teacache_slow(prompt_list)
    eval_teacache_fast(prompt_list)