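"""Evaluate TeaCache acceleration for Latte with VideoSys.

TeaCache monkey-patches the transformer's forward pass so that denoising steps whose inputs
change little from the previous step skip the transformer blocks and reuse a cached residual.
The script compares baseline generation against a slow (conservative) and a fast (aggressive)
TeaCache threshold on the VBench prompt list.
"""
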
from functools import partial
from typing import Any, Dict, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import nn

from utils import generate_func, read_prompt_list
from videosys import LatteConfig, VideoSysEngine
from videosys.core.comm import all_to_all_with_pad, gather_sequence, get_pad, set_pad, split_sequence
from videosys.models.transformers.latte_transformer_3d import Transformer3DModelOutput
from videosys.utils.utils import batch_func


def teacache_forward(
self,
hidden_states: torch.Tensor,
timestep: Optional[torch.LongTensor] = None,
all_timesteps=None,
encoder_hidden_states: Optional[torch.Tensor] = None,
added_cond_kwargs: Dict[str, torch.Tensor] = None,
class_labels: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
use_image_num: int = 0,
enable_temporal_attentions: bool = True,
return_dict: bool = True,
):
"""
    The TeaCache-enabled forward method for Latte's 3D transformer (monkey-patched over the original `forward`).
Args:
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, frame, channel, height, width)` if continuous):
Input `hidden_states`.
encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
self-attention.
timestep ( `torch.LongTensor`, *optional*):
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
        class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
            Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
            `AdaLayerNormZero`.
cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
attention_mask ( `torch.Tensor`, *optional*):
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
negative values to the attention scores corresponding to "discard" tokens.
encoder_attention_mask ( `torch.Tensor`, *optional*):
Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
* Mask `(batch, sequence_length)` True = keep, False = discard.
* Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
above. This bias will be added to the cross-attention scores.
        all_timesteps (`list`, *optional*):
            Full list of sampling timesteps for the run. TeaCache always computes the first and last step in
            full and may reuse cached residuals in between.
        use_image_num (`int`, *optional*, defaults to `0`):
            Number of extra images appended after the video frames for image-video joint training.
        enable_temporal_attentions (`bool`, *optional*, defaults to `True`):
            Whether to run the temporal transformer blocks in addition to the spatial ones.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`Transformer3DModelOutput`] instead of a plain tuple.
    Returns:
        If `return_dict` is True, a [`Transformer3DModelOutput`] is returned, otherwise a `tuple` whose first
        element is the sample tensor.
"""
# 0. Split batch for data parallelism
if self.parallel_manager.cp_size > 1:
(
hidden_states,
timestep,
encoder_hidden_states,
added_cond_kwargs,
class_labels,
attention_mask,
encoder_attention_mask,
) = batch_func(
partial(split_sequence, process_group=self.parallel_manager.cp_group, dim=0),
hidden_states,
timestep,
encoder_hidden_states,
added_cond_kwargs,
class_labels,
attention_mask,
encoder_attention_mask,
)
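    # Fold frames into the batch dimension for the spatial blocks: (b, c, f, h, w) -> (b*f, c, h, w).
    # `frame` counts only the video frames; images appended for joint training are excluded.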
input_batch_size, c, frame, h, w = hidden_states.shape
frame = frame - use_image_num
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w").contiguous()
org_timestep = timestep
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
# expects mask of shape:
# [batch, key_tokens]
# adds singleton query_tokens dimension:
# [batch, 1, key_tokens]
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
if attention_mask is not None and attention_mask.ndim == 2:
# assume that mask is expressed as:
# (1 = keep, 0 = discard)
# convert mask into a bias that can be added to attention scores:
# (keep = +0, discard = -10000.0)
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: # ndim == 2 means no image joint
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
encoder_attention_mask = repeat(encoder_attention_mask, "b 1 l -> (b f) 1 l", f=frame).contiguous()
elif encoder_attention_mask is not None and encoder_attention_mask.ndim == 3: # ndim == 3 means image joint
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
encoder_attention_mask_video = encoder_attention_mask[:, :1, ...]
encoder_attention_mask_video = repeat(
encoder_attention_mask_video, "b 1 l -> b (1 f) l", f=frame
).contiguous()
encoder_attention_mask_image = encoder_attention_mask[:, 1:, ...]
encoder_attention_mask = torch.cat([encoder_attention_mask_video, encoder_attention_mask_image], dim=1)
encoder_attention_mask = rearrange(encoder_attention_mask, "b n l -> (b n) l").contiguous().unsqueeze(1)
    # LoRA scale lookup kept from the original forward; the value is not used in this patched version.
    cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
# 1. Input
    if self.is_input_patches:  # Latte takes this branch (patch-embedded latents)
height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
num_patches = height * width
        hidden_states = self.pos_embed(hidden_states)  # already adds positional embeddings
if self.adaln_single is not None:
if self.use_additional_conditions and added_cond_kwargs is None:
raise ValueError(
"`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
)
# batch_size = hidden_states.shape[0]
batch_size = input_batch_size
timestep, embedded_timestep = self.adaln_single(
timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
)
# 2. Blocks
if self.caption_projection is not None:
batch_size = hidden_states.shape[0]
        encoder_hidden_states = self.caption_projection(encoder_hidden_states)  # e.g. (3, 120, 1152) = (batch, text tokens, hidden dim)
if use_image_num != 0 and self.training:
encoder_hidden_states_video = encoder_hidden_states[:, :1, ...]
encoder_hidden_states_video = repeat(
encoder_hidden_states_video, "b 1 t d -> b (1 f) t d", f=frame
).contiguous()
encoder_hidden_states_image = encoder_hidden_states[:, 1:, ...]
encoder_hidden_states = torch.cat([encoder_hidden_states_video, encoder_hidden_states_image], dim=1)
encoder_hidden_states_spatial = rearrange(encoder_hidden_states, "b f t d -> (b f) t d").contiguous()
else:
encoder_hidden_states_spatial = repeat(
encoder_hidden_states, "b t d -> (b f) t d", f=frame
).contiguous()
# prepare timesteps for spatial and temporal block
timestep_spatial = repeat(timestep, "b d -> (b f) d", f=frame + use_image_num).contiguous()
timestep_temp = repeat(timestep, "b d -> (b p) d", p=num_patches).contiguous()
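    # TeaCache: decide whether this step can skip the transformer blocks. The decision uses the
    # relative L1 change of the first block's timestep-modulated input, rescaled by an empirical
    # polynomial and accumulated over consecutive steps; the first and last timesteps always
    # recompute in full.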
if self.enable_teacache:
inp = hidden_states.clone()
batch_size = inp.shape[0]
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.transformer_blocks[0].scale_shift_table[None] + timestep_spatial.reshape(batch_size, 6, -1)
).chunk(6, dim=1)
modulated_inp = self.transformer_blocks[0].norm1(inp) * (1 + scale_msa) + shift_msa
if org_timestep[0] == all_timesteps[0] or org_timestep[0] == all_timesteps[-1]:
should_calc = True
self.accumulated_rel_l1_distance = 0
else:
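            # Rescale the raw relative L1 change of the modulated input with an empirically
            # fitted polynomial before accumulating it toward the skip threshold.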
coefficients = [-2.46434137e+03, 3.08044764e+02, 8.07447667e+01, -4.11385132e+00, 1.11001402e-01]
rescale_func = np.poly1d(coefficients)
            rel_l1 = (
                (modulated_inp - self.previous_modulated_input).abs().mean()
                / self.previous_modulated_input.abs().mean()
            )
            self.accumulated_rel_l1_distance += rescale_func(rel_l1.cpu().item())
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
should_calc = False
else:
should_calc = True
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = modulated_inp
if self.enable_teacache:
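        # Skip path: reuse the residual cached at the last fully computed step.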
if not should_calc:
hidden_states += self.previous_residual
else:
if self.parallel_manager.sp_size > 1:
set_pad("temporal", frame + use_image_num, self.parallel_manager.sp_group)
set_pad("spatial", num_patches, self.parallel_manager.sp_group)
hidden_states = self.split_from_second_dim(hidden_states, input_batch_size)
encoder_hidden_states_spatial = self.split_from_second_dim(encoder_hidden_states_spatial, input_batch_size)
timestep_spatial = self.split_from_second_dim(timestep_spatial, input_batch_size)
temp_pos_embed = split_sequence(
self.temp_pos_embed, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal")
)
else:
temp_pos_embed = self.temp_pos_embed
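            # Full path: keep a copy of the block input so the residual (output - input)
            # can be cached after the transformer stack runs.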
hidden_states_origin = hidden_states.clone().detach()
for i, (spatial_block, temp_block) in enumerate(zip(self.transformer_blocks, self.temporal_transformer_blocks)):
if self.training and self.gradient_checkpointing:
hidden_states = torch.utils.checkpoint.checkpoint(
spatial_block,
hidden_states,
attention_mask,
encoder_hidden_states_spatial,
encoder_attention_mask,
timestep_spatial,
cross_attention_kwargs,
class_labels,
use_reentrant=False,
)
if enable_temporal_attentions:
hidden_states = rearrange(hidden_states, "(b f) t d -> (b t) f d", b=input_batch_size).contiguous()
                        if use_image_num != 0:  # image-video joint training
hidden_states_video = hidden_states[:, :frame, ...]
hidden_states_image = hidden_states[:, frame:, ...]
if i == 0:
hidden_states_video = hidden_states_video + temp_pos_embed
hidden_states_video = torch.utils.checkpoint.checkpoint(
temp_block,
hidden_states_video,
None, # attention_mask
None, # encoder_hidden_states
None, # encoder_attention_mask
timestep_temp,
cross_attention_kwargs,
class_labels,
use_reentrant=False,
)
hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1)
hidden_states = rearrange(
hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
).contiguous()
else:
if i == 0:
hidden_states = hidden_states + temp_pos_embed
hidden_states = torch.utils.checkpoint.checkpoint(
temp_block,
hidden_states,
None, # attention_mask
None, # encoder_hidden_states
None, # encoder_attention_mask
timestep_temp,
cross_attention_kwargs,
class_labels,
use_reentrant=False,
)
hidden_states = rearrange(
hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
).contiguous()
else:
hidden_states = spatial_block(
hidden_states,
attention_mask,
encoder_hidden_states_spatial,
encoder_attention_mask,
timestep_spatial,
cross_attention_kwargs,
class_labels,
None,
org_timestep,
all_timesteps=all_timesteps,
)
if enable_temporal_attentions:
hidden_states = rearrange(hidden_states, "(b f) t d -> (b t) f d", b=input_batch_size).contiguous()
if use_image_num != 0 and self.training:
hidden_states_video = hidden_states[:, :frame, ...]
hidden_states_image = hidden_states[:, frame:, ...]
hidden_states_video = temp_block(
hidden_states_video,
None, # attention_mask
None, # encoder_hidden_states
None, # encoder_attention_mask
timestep_temp,
cross_attention_kwargs,
class_labels,
org_timestep,
)
hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1)
hidden_states = rearrange(
hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
).contiguous()
else:
if i == 0 and frame > 1:
hidden_states = hidden_states + temp_pos_embed
hidden_states = temp_block(
hidden_states,
None, # attention_mask
None, # encoder_hidden_states
None, # encoder_attention_mask
timestep_temp,
cross_attention_kwargs,
class_labels,
org_timestep,
all_timesteps=all_timesteps,
)
hidden_states = rearrange(
hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
).contiguous()
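            # Cache the residual of the full pass for reuse on skipped steps.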
self.previous_residual = hidden_states - hidden_states_origin
else:
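        # TeaCache disabled: run the original spatial/temporal transformer stack unchanged.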
if self.parallel_manager.sp_size > 1:
set_pad("temporal", frame + use_image_num, self.parallel_manager.sp_group)
set_pad("spatial", num_patches, self.parallel_manager.sp_group)
hidden_states = self.split_from_second_dim(hidden_states, input_batch_size)
encoder_hidden_states_spatial = self.split_from_second_dim(encoder_hidden_states_spatial, input_batch_size)
timestep_spatial = self.split_from_second_dim(timestep_spatial, input_batch_size)
temp_pos_embed = split_sequence(
self.temp_pos_embed, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal")
)
else:
temp_pos_embed = self.temp_pos_embed
for i, (spatial_block, temp_block) in enumerate(zip(self.transformer_blocks, self.temporal_transformer_blocks)):
if self.training and self.gradient_checkpointing:
hidden_states = torch.utils.checkpoint.checkpoint(
spatial_block,
hidden_states,
attention_mask,
encoder_hidden_states_spatial,
encoder_attention_mask,
timestep_spatial,
cross_attention_kwargs,
class_labels,
use_reentrant=False,
)
if enable_temporal_attentions:
hidden_states = rearrange(hidden_states, "(b f) t d -> (b t) f d", b=input_batch_size).contiguous()
                    if use_image_num != 0:  # image-video joint training
hidden_states_video = hidden_states[:, :frame, ...]
hidden_states_image = hidden_states[:, frame:, ...]
if i == 0:
hidden_states_video = hidden_states_video + temp_pos_embed
hidden_states_video = torch.utils.checkpoint.checkpoint(
temp_block,
hidden_states_video,
None, # attention_mask
None, # encoder_hidden_states
None, # encoder_attention_mask
timestep_temp,
cross_attention_kwargs,
class_labels,
use_reentrant=False,
)
hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1)
hidden_states = rearrange(
hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
).contiguous()
else:
if i == 0:
hidden_states = hidden_states + temp_pos_embed
hidden_states = torch.utils.checkpoint.checkpoint(
temp_block,
hidden_states,
None, # attention_mask
None, # encoder_hidden_states
None, # encoder_attention_mask
timestep_temp,
cross_attention_kwargs,
class_labels,
use_reentrant=False,
)
hidden_states = rearrange(
hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
).contiguous()
else:
hidden_states = spatial_block(
hidden_states,
attention_mask,
encoder_hidden_states_spatial,
encoder_attention_mask,
timestep_spatial,
cross_attention_kwargs,
class_labels,
None,
org_timestep,
all_timesteps=all_timesteps,
)
if enable_temporal_attentions:
hidden_states = rearrange(hidden_states, "(b f) t d -> (b t) f d", b=input_batch_size).contiguous()
if use_image_num != 0 and self.training:
hidden_states_video = hidden_states[:, :frame, ...]
hidden_states_image = hidden_states[:, frame:, ...]
hidden_states_video = temp_block(
hidden_states_video,
None, # attention_mask
None, # encoder_hidden_states
None, # encoder_attention_mask
timestep_temp,
cross_attention_kwargs,
class_labels,
org_timestep,
)
hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1)
hidden_states = rearrange(
hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
).contiguous()
else:
if i == 0 and frame > 1:
hidden_states = hidden_states + temp_pos_embed
hidden_states = temp_block(
hidden_states,
None, # attention_mask
None, # encoder_hidden_states
None, # encoder_attention_mask
timestep_temp,
cross_attention_kwargs,
class_labels,
org_timestep,
all_timesteps=all_timesteps,
)
hidden_states = rearrange(
hidden_states, "(b t) f d -> (b f) t d", b=input_batch_size
).contiguous()
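    # With sequence parallelism, gather the split sequence back before the output projection;
    # the cached residual is gathered too when it was just recomputed.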
if self.parallel_manager.sp_size > 1:
if self.enable_teacache:
if should_calc:
hidden_states = self.gather_from_second_dim(hidden_states, input_batch_size)
self.previous_residual = self.gather_from_second_dim(self.previous_residual, input_batch_size)
else:
hidden_states = self.gather_from_second_dim(hidden_states, input_batch_size)
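    # Output head: final norm + adaLN modulation, projection, then unpatchify back to video frames.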
if self.is_input_patches:
if self.config.norm_type != "ada_norm_single":
conditioning = self.transformer_blocks[0].norm1.emb(
timestep, class_labels, hidden_dtype=hidden_states.dtype
)
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
hidden_states = self.proj_out_2(hidden_states)
elif self.config.norm_type == "ada_norm_single":
embedded_timestep = repeat(embedded_timestep, "b d -> (b f) d", f=frame + use_image_num).contiguous()
shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states)
# Modulation
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.proj_out(hidden_states)
# unpatchify
if self.adaln_single is None:
height = width = int(hidden_states.shape[1] ** 0.5)
hidden_states = hidden_states.reshape(
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
)
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
output = hidden_states.reshape(
shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
)
output = rearrange(output, "(b f) c h w -> b c f h w", b=input_batch_size).contiguous()
# 3. Gather batch for data parallelism
if self.parallel_manager.cp_size > 1:
output = gather_sequence(output, self.parallel_manager.cp_group, dim=0)
if not return_dict:
return (output,)
return Transformer3DModelOutput(sample=output)
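

# "Slow" TeaCache setting: a conservative threshold (rel_l1_thresh=0.1) skips fewer steps,
# trading speedup for output quality. TeaCache state and the patched forward are attached
# directly to the transformer class before generation.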
def eval_teacache_slow(prompt_list):
config = LatteConfig()
engine = VideoSysEngine(config)
engine.driver_worker.transformer.__class__.enable_teacache = True
engine.driver_worker.transformer.__class__.rel_l1_thresh = 0.1
engine.driver_worker.transformer.__class__.accumulated_rel_l1_distance = 0
engine.driver_worker.transformer.__class__.previous_modulated_input = None
engine.driver_worker.transformer.__class__.previous_residual = None
engine.driver_worker.transformer.__class__.forward = teacache_forward
generate_func(engine, prompt_list, "./samples/latte_teacache_slow", loop=5)
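

# "Fast" TeaCache setting: a looser threshold (rel_l1_thresh=0.2) reuses cached residuals
# more often, giving a larger speedup at some cost in fidelity.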
def eval_teacache_fast(prompt_list):
config = LatteConfig()
engine = VideoSysEngine(config)
engine.driver_worker.transformer.__class__.enable_teacache = True
engine.driver_worker.transformer.__class__.rel_l1_thresh = 0.2
engine.driver_worker.transformer.__class__.accumulated_rel_l1_distance = 0
engine.driver_worker.transformer.__class__.previous_modulated_input = None
engine.driver_worker.transformer.__class__.previous_residual = None
engine.driver_worker.transformer.__class__.forward = teacache_forward
generate_func(engine, prompt_list, "./samples/latte_teacache_fast", loop=5)
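

# Baseline Latte generation without TeaCache, for comparison.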
def eval_base(prompt_list):
config = LatteConfig()
engine = VideoSysEngine(config)
generate_func(engine, prompt_list, "./samples/latte_base", loop=5)
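

# Run all three configurations over the VBench prompt list.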
if __name__ == "__main__":
prompt_list = read_prompt_list("vbench/VBench_full_info.json")
eval_base(prompt_list)
eval_teacache_slow(prompt_list)
eval_teacache_fast(prompt_list)