Support Mochi, LTX-Video and Lumina-T2X

This commit is contained in:
LiewFeng 2024-12-30 16:27:11 +08:00
parent 04bb381962
commit 1857e57216
11 changed files with 695 additions and 2 deletions

View File

@@ -64,6 +64,7 @@
## Latest News 🔥
- **PRs to support other models are welcome. Please star ⭐ our project and stay tuned.**
- [2024/12/30] 🔥 Support [Mochi](https://github.com/genmoai/mochi) and [LTX-Video](https://github.com/Lightricks/LTX-Video) for Video Diffusion Models. Support [Lumina-T2X](https://github.com/Alpha-VLLM/Lumina-T2X) for Image Diffusion Models.
- [2024/12/27] 🔥 Support [FLUX](https://github.com/black-forest-labs/flux). TeaCache works well for Image Diffusion Models!
- [2024/12/26] 🔥 Support [ConsisID](https://github.com/PKU-YuanGroup/ConsisID). Thanks [@SHYuanBest](https://github.com/SHYuanBest).
- [2024/12/24] 🔥 Support [HunyuanVideo](https://github.com/Tencent/HunyuanVideo).
@@ -85,6 +86,18 @@ Please refer to [TeaCache4ConsisID](./TeaCache4ConsisID/README.md).
Please refer to [TeaCache4FLUX](./TeaCache4FLUX/README.md).
## TeaCache for Mochi
Please refer to [TeaCache4Mochi](./TeaCache4Mochi/README.md).
## TeaCache for LTX-Video
Please refer to [TeaCache4LTX-Video](./TeaCache4LTX-Video/README.md).
## TeaCache for Lumina-T2X
Please refer to [TeaCache4LuminaT2X](./TeaCache4LuminaT2X/README.md).
## Installation
Prerequisites:
@@ -142,12 +155,12 @@ python common_metrics/eval.py --gt_video_dir aa --generated_video_dir bb
```
## Acknowledgement
This repository is built based on [VideoSys](https://github.com/NUS-HPC-AI-Lab/VideoSys), [Diffusers](https://github.com/huggingface/diffusers), [Open-Sora](https://github.com/hpcaitech/Open-Sora), [Open-Sora-Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan), [Latte](https://github.com/Vchitect/Latte), [CogVideoX](https://github.com/THUDM/CogVideo), [HunyuanVideo](https://github.com/Tencent/HunyuanVideo), [ConsisID](https://github.com/PKU-YuanGroup/ConsisID) and [FLUX](https://github.com/black-forest-labs/flux). Thanks for their contributions!
This repository is built based on [VideoSys](https://github.com/NUS-HPC-AI-Lab/VideoSys), [Diffusers](https://github.com/huggingface/diffusers), [Open-Sora](https://github.com/hpcaitech/Open-Sora), [Open-Sora-Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan), [Latte](https://github.com/Vchitect/Latte), [CogVideoX](https://github.com/THUDM/CogVideo), [HunyuanVideo](https://github.com/Tencent/HunyuanVideo), [ConsisID](https://github.com/PKU-YuanGroup/ConsisID), [FLUX](https://github.com/black-forest-labs/flux), [Mochi](https://github.com/genmoai/mochi), [LTX-Video](https://github.com/Lightricks/LTX-Video) and [Lumina-T2X](https://github.com/Alpha-VLLM/Lumina-T2X). Thanks for their contributions!
## License
* The majority of this project is released under the Apache 2.0 license as found in the [LICENSE](./LICENSE) file.
* For [VideoSys](https://github.com/NUS-HPC-AI-Lab/VideoSys), [Diffusers](https://github.com/huggingface/diffusers), [Open-Sora](https://github.com/hpcaitech/Open-Sora), [Open-Sora-Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan), [Latte](https://github.com/Vchitect/Latte), [CogVideoX](https://github.com/THUDM/CogVideo), [HunyuanVideo](https://github.com/Tencent/HunyuanVideo), [ConsisID](https://github.com/PKU-YuanGroup/ConsisID) and [FLUX](https://github.com/black-forest-labs/flux), please follow their LICENSE.
* For [VideoSys](https://github.com/NUS-HPC-AI-Lab/VideoSys), [Diffusers](https://github.com/huggingface/diffusers), [Open-Sora](https://github.com/hpcaitech/Open-Sora), [Open-Sora-Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan), [Latte](https://github.com/Vchitect/Latte), [CogVideoX](https://github.com/THUDM/CogVideo), [HunyuanVideo](https://github.com/Tencent/HunyuanVideo), [ConsisID](https://github.com/PKU-YuanGroup/ConsisID), [FLUX](https://github.com/black-forest-labs/flux), [Mochi](https://github.com/genmoai/mochi), [LTX-Video](https://github.com/Lightricks/LTX-Video) and [Lumina-T2X](https://github.com/Alpha-VLLM/Lumina-T2X), please follow their LICENSE.
* The service is a research preview. Please contact us if you find any potential violations. (liufeng20@mails.ucas.ac.cn)
## Citation

View File

@@ -0,0 +1,43 @@
<!-- ## **TeaCache4LTX-Video** -->
# TeaCache4LTX-Video
[TeaCache](https://github.com/LiewFeng/TeaCache) can speed up [LTX-Video](https://github.com/Lightricks/LTX-Video) by about 2x without much visual quality degradation, in a training-free manner. The following video shows results generated by TeaCache-LTX-Video with various `rel_l1_thresh` values: 0 (original), 0.03 (1.6x speedup), 0.05 (2.1x speedup), 0.6 (2.0x speedup), and 0.8 (2.25x speedup).
https://github.com/user-attachments/assets/1f4cf26c-b8c6-45e3-b402-840bcd6ba00e
## 📈 Inference Latency Comparisons on a Single A800
| LTX-Video-0.9.1 | TeaCache (0.03) | TeaCache (0.05) |
|:--------------------------:|:----------------------------:|:---------------------:|
| ~32 s | ~20 s | ~16 s |
## Installation
```shell
pip install --upgrade diffusers[torch] transformers protobuf tokenizers sentencepiece imageio
```
## Usage
You can modify `rel_l1_thresh` in line 187 of `teacache_ltx.py` to obtain your desired trade-off between latency and visual quality. For single-GPU inference, you can use the following command:
```bash
python teacache_ltx.py
```
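For reference, the threshold mentioned above is the `rel_l1_thresh` class attribute set in `teacache_ltx.py`. The excerpt below mirrors that configuration block; it only takes effect after the script has replaced `LTXVideoTransformer3DModel.forward` with its TeaCache-aware `teacache_forward`:
```python
import torch
from diffusers import LTXPipeline

num_inference_steps = 50
pipe = LTXPipeline.from_pretrained("a-r-r-o-w/LTX-Video-0.9.1-diffusers", torch_dtype=torch.bfloat16)

# TeaCache state (mirrors teacache_ltx.py; the script patches
# LTXVideoTransformer3DModel.forward with teacache_forward before this point)
pipe.transformer.__class__.enable_teacache = True
pipe.transformer.__class__.cnt = 0
pipe.transformer.__class__.num_steps = num_inference_steps
pipe.transformer.__class__.rel_l1_thresh = 0.05  # 0.03 for 1.6x speedup, 0.05 for 2.1x speedup
pipe.transformer.__class__.accumulated_rel_l1_distance = 0
pipe.transformer.__class__.previous_modulated_input = None
pipe.transformer.__class__.previous_residual = None
```
Higher thresholds cache more aggressively, trading visual quality for speed.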
## Citation
If you find TeaCache useful in your research or applications, please consider giving us a star 🌟 and citing it with the following BibTeX entry.
```
@article{liu2024timestep,
title={Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model},
author={Liu, Feng and Zhang, Shiwei and Wang, Xiaofeng and Wei, Yujie and Qiu, Haonan and Zhao, Yuzhong and Zhang, Yingya and Ye, Qixiang and Wan, Fang},
journal={arXiv preprint arXiv:2411.19108},
year={2024}
}
```
## Acknowledgements
We would like to thank the contributors to [LTX-Video](https://github.com/Lightricks/LTX-Video) and [Diffusers](https://github.com/huggingface/diffusers).

View File

@@ -0,0 +1,204 @@
import torch
from diffusers import LTXPipeline
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.transformers import LTXVideoTransformer3DModel
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from diffusers.utils import export_to_video
from typing import Any, Dict, Optional, Tuple
import numpy as np

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
def teacache_forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
timestep: torch.LongTensor,
encoder_attention_mask: torch.Tensor,
num_frames: int,
height: int,
width: int,
rope_interpolation_scale: Optional[Tuple[float, float, float]] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> torch.Tensor:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale)
# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
batch_size = hidden_states.size(0)
hidden_states = self.proj_in(hidden_states)
temb, embedded_timestep = self.time_embed(
timestep.flatten(),
batch_size=batch_size,
hidden_dtype=hidden_states.dtype,
)
temb = temb.view(batch_size, -1, temb.size(-1))
embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1))
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1))
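# TeaCache: derive a cheap change indicator from the first transformer block's modulated input
# (norm1 output shifted/scaled by the timestep embedding). Its relative L1 change w.r.t. the
# previous step is rescaled by a fitted polynomial and accumulated; the transformer stack is
# only recomputed when the accumulated value exceeds rel_l1_thresh (and always at the first
# and last steps).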
if self.enable_teacache:
inp = hidden_states.clone()
temb_ = temb.clone()
inp = self.transformer_blocks[0].norm1(inp)
num_ada_params = self.transformer_blocks[0].scale_shift_table.shape[0]
ada_values = self.transformer_blocks[0].scale_shift_table[None, None] + temb_.reshape(batch_size, temb_.size(1), num_ada_params, -1)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
modulated_inp = inp * (1 + scale_msa) + shift_msa
if self.cnt == 0 or self.cnt == self.num_steps-1:
should_calc = True
self.accumulated_rel_l1_distance = 0
else:
coefficients = [2.14700694e+01, -1.28016453e+01, 2.31279151e+00, 7.92487521e-01, 9.69274326e-03]
rescale_func = np.poly1d(coefficients)
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
should_calc = False
else:
should_calc = True
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = modulated_inp
self.cnt += 1
if self.cnt == self.num_steps:
self.cnt = 0
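# Either reuse the residual cached from the last full forward pass (cheap path) or run all
# transformer blocks and refresh the cached residual (full path).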
if self.enable_teacache:
if not should_calc:
hidden_states += self.previous_residual
else:
ori_hidden_states = hidden_states.clone()
for block in self.transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
encoder_attention_mask,
**ckpt_kwargs,
)
else:
hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
encoder_attention_mask=encoder_attention_mask,
)
scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
hidden_states = self.norm_out(hidden_states)
hidden_states = hidden_states * (1 + scale) + shift
self.previous_residual = hidden_states - ori_hidden_states
else:
for block in self.transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
encoder_attention_mask,
**ckpt_kwargs,
)
else:
hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
encoder_attention_mask=encoder_attention_mask,
)
scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
hidden_states = self.norm_out(hidden_states)
hidden_states = hidden_states * (1 + scale) + shift
output = self.proj_out(hidden_states)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
LTXVideoTransformer3DModel.forward = teacache_forward
prompt = "A clear, turquoise river flows through a rocky canyon, cascading over a small waterfall and forming a pool of water at the bottom.The river is the main focus of the scene, with its clear water reflecting the surrounding trees and rocks. The canyon walls are steep and rocky, with some vegetation growing on them. The trees are mostly pine trees, with their green needles contrasting with the brown and gray rocks. The overall tone of the scene is one of peace and tranquility."
negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
seed = 42
num_inference_steps = 50
pipe = LTXPipeline.from_pretrained("a-r-r-o-w/LTX-Video-0.9.1-diffusers", torch_dtype=torch.bfloat16)
# TeaCache
pipe.transformer.__class__.enable_teacache = True
pipe.transformer.__class__.cnt = 0
pipe.transformer.__class__.num_steps = num_inference_steps
pipe.transformer.__class__.rel_l1_thresh = 0.05 # 0.03 for 1.6x speedup, 0.05 for 2.1x speedup
pipe.transformer.__class__.accumulated_rel_l1_distance = 0
pipe.transformer.__class__.previous_modulated_input = None
pipe.transformer.__class__.previous_residual = None
pipe.to("cuda")
video = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
width=768,
height=512,
num_frames=161,
decode_timestep=0.03,
decode_noise_scale=0.025,
num_inference_steps=num_inference_steps,
generator=torch.Generator("cuda").manual_seed(seed)
).frames[0]
export_to_video(video, "teacache_ltx_{}.mp4".format(prompt[:50]), fps=24)

View File

@@ -0,0 +1,44 @@
<!-- ## **TeaCache4LuminaT2X** -->
# TeaCache4LuminaT2X
[TeaCache](https://github.com/LiewFeng/TeaCache) can speed up [Lumina-T2X](https://github.com/Alpha-VLLM/Lumina-T2X) by about 2x without much visual quality degradation, in a training-free manner. The following image shows results generated by TeaCache-Lumina-Next with various `rel_l1_thresh` values: 0 (original), 0.2 (1.5x speedup), 0.3 (1.9x speedup), 0.4 (2.4x speedup), and 0.5 (2.8x speedup).
![visualization](../assets/TeaCache4LuminaT2X.png)
## 📈 Inference Latency Comparisons on a Single A800
| Lumina-Next-SFT | TeaCache (0.2) | TeaCache (0.3) | TeaCache (0.4) | TeaCache (0.5) |
|:-------------------------:|:---------------------------:|:--------------------:|:---------------------:|:---------------------:|
| ~17 s | ~11 s | ~9 s | ~7 s | ~6 s |
## Installation
```shell
pip install --upgrade diffusers[torch] transformers protobuf tokenizers sentencepiece
pip install flash-attn --no-build-isolation
```
## Usage
You can modify `rel_l1_thresh` in line 113 of `teacache_lumina_next.py` to obtain your desired trade-off between latency and visual quality. For single-GPU inference, you can use the following command:
```bash
python teacache_lumina_next.py
```
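For reference, the threshold mentioned above is the `rel_l1_thresh` class attribute set in `teacache_lumina_next.py`. The excerpt below mirrors that configuration block; it only takes effect after the script has replaced `LuminaNextDiT2DModel.forward` with its TeaCache-aware `teacache_forward`:
```python
import torch
from diffusers import LuminaText2ImgPipeline

num_inference_steps = 30
pipeline = LuminaText2ImgPipeline.from_pretrained("Alpha-VLLM/Lumina-Next-SFT-diffusers", torch_dtype=torch.bfloat16).to("cuda")

# TeaCache state (mirrors teacache_lumina_next.py; the script patches
# LuminaNextDiT2DModel.forward with teacache_forward before this point)
pipeline.transformer.__class__.enable_teacache = True
pipeline.transformer.__class__.cnt = 0
pipeline.transformer.__class__.num_steps = num_inference_steps
pipeline.transformer.__class__.rel_l1_thresh = 0.3  # 0.2 for 1.5x, 0.3 for 1.9x, 0.4 for 2.4x, 0.5 for 2.8x speedup
pipeline.transformer.__class__.accumulated_rel_l1_distance = 0
pipeline.transformer.__class__.previous_modulated_input = None
pipeline.transformer.__class__.previous_residual = None
```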
## Citation
If you find TeaCache useful in your research or applications, please consider giving us a star 🌟 and citing it with the following BibTeX entry.
```
@article{liu2024timestep,
title={Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model},
author={Liu, Feng and Zhang, Shiwei and Wang, Xiaofeng and Wei, Yujie and Qiu, Haonan and Zhao, Yuzhong and Zhang, Yingya and Ye, Qixiang and Wan, Fang},
journal={arXiv preprint arXiv:2411.19108},
year={2024}
}
```
## Acknowledgements
We would like to thank the contributors to [Lumina-T2X](https://github.com/Alpha-VLLM/Lumina-T2X) and [Diffusers](https://github.com/huggingface/diffusers).

View File

@@ -0,0 +1,123 @@
import torch
from typing import Any, Dict, Optional, Tuple, Union
from diffusers import LuminaText2ImgPipeline
from diffusers.models import LuminaNextDiT2DModel
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
import numpy as np
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def teacache_forward(
self,
hidden_states: torch.Tensor,
timestep: torch.Tensor,
encoder_hidden_states: torch.Tensor,
encoder_mask: torch.Tensor,
image_rotary_emb: torch.Tensor,
cross_attention_kwargs: Dict[str, Any] = None,
return_dict=True,
) -> torch.Tensor:
"""
Forward pass of LuminaNextDiT.
Parameters:
hidden_states (torch.Tensor): Input tensor of shape (N, C, H, W).
timestep (torch.Tensor): Tensor of diffusion timesteps of shape (N,).
encoder_hidden_states (torch.Tensor): Tensor of caption features of shape (N, D).
encoder_mask (torch.Tensor): Tensor of caption masks of shape (N, L).
"""
hidden_states, mask, img_size, image_rotary_emb = self.patch_embedder(hidden_states, image_rotary_emb)
image_rotary_emb = image_rotary_emb.to(hidden_states.device)
temb = self.time_caption_embed(timestep, encoder_hidden_states, encoder_mask)
encoder_mask = encoder_mask.bool()
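# TeaCache: the first layer's norm1 output (timestep-modulated input) serves as the change
# indicator; its polynomially rescaled relative L1 distance to the previous step is accumulated
# and compared against rel_l1_thresh (full recomputation is forced at the first and last steps).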
if self.enable_teacache:
inp = hidden_states.clone()
temb_ = temb.clone()
modulated_inp, gate_msa, scale_mlp, gate_mlp = self.layers[0].norm1(inp, temb_)
if self.cnt == 0 or self.cnt == self.num_steps-1:
should_calc = True
self.accumulated_rel_l1_distance = 0
else:
coefficients = [393.76566581, -603.50993606, 209.10239044, -23.00726601, 0.86377344]
rescale_func = np.poly1d(coefficients)
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
should_calc = False
else:
should_calc = True
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = modulated_inp
self.cnt += 1
if self.cnt == self.num_steps:
self.cnt = 0
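# Reuse the cached residual when little change is predicted; otherwise run all layers and
# refresh the cached residual.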
if self.enable_teacache:
if not should_calc:
hidden_states += self.previous_residual
else:
ori_hidden_states = hidden_states.clone()
for layer in self.layers:
hidden_states = layer(
hidden_states,
mask,
image_rotary_emb,
encoder_hidden_states,
encoder_mask,
temb=temb,
cross_attention_kwargs=cross_attention_kwargs,
)
self.previous_residual = hidden_states - ori_hidden_states
else:
for layer in self.layers:
hidden_states = layer(
hidden_states,
mask,
image_rotary_emb,
encoder_hidden_states,
encoder_mask,
temb=temb,
cross_attention_kwargs=cross_attention_kwargs,
)
hidden_states = self.norm_out(hidden_states, temb)
# unpatchify
height_tokens = width_tokens = self.patch_size
height, width = img_size[0]
batch_size = hidden_states.size(0)
sequence_length = (height // height_tokens) * (width // width_tokens)
hidden_states = hidden_states[:, :sequence_length].view(
batch_size, height // height_tokens, width // width_tokens, height_tokens, width_tokens, self.out_channels
)
output = hidden_states.permute(0, 5, 1, 3, 2, 4).flatten(4, 5).flatten(2, 3)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
LuminaNextDiT2DModel.forward = teacache_forward
num_inference_steps = 30
seed = 1024
prompt = "Upper body of a young woman in a Victorian-era outfit with brass goggles and leather straps. "
pipeline = LuminaText2ImgPipeline.from_pretrained("Alpha-VLLM/Lumina-Next-SFT-diffusers", torch_dtype=torch.bfloat16).to("cuda")
# TeaCache
pipeline.transformer.__class__.enable_teacache = True
pipeline.transformer.__class__.cnt = 0
pipeline.transformer.__class__.num_steps = num_inference_steps
pipeline.transformer.__class__.rel_l1_thresh = 0.3 # 0.2 for 1.5x speedup, 0.3 for 1.9x speedup, 0.4 for 2.4x speedup, 0.5 for 2.8x speedup
pipeline.transformer.__class__.accumulated_rel_l1_distance = 0
pipeline.transformer.__class__.previous_modulated_input = None
pipeline.transformer.__class__.previous_residual = None
image = pipeline(
prompt=prompt,
num_inference_steps=num_inference_steps,
generator=torch.Generator("cpu").manual_seed(seed)
).images[0]
image.save("teacache_lumina_{}.png".format(prompt))

TeaCache4Mochi/README.md Normal file
View File

@@ -0,0 +1,43 @@
<!-- ## **TeaCache4Mochi** -->
# TeaCache4Mochi
[TeaCache](https://github.com/LiewFeng/TeaCache) can speed up [Mochi](https://github.com/genmoai/mochi) by about 2x without much visual quality degradation, in a training-free manner. The following video shows results generated by TeaCache-Mochi with various `rel_l1_thresh` values: 0 (original), 0.06 (1.5x speedup), and 0.09 (2.1x speedup).
https://github.com/user-attachments/assets/29a81380-46b3-414f-a96b-6e3acc71b6c4
## 📈 Inference Latency Comparisons on a Single A800
| mochi-1-preview | TeaCache (0.06) | TeaCache (0.09) |
|:--------------------------:|:----------------------------:|:--------------------:|
| ~30 min | ~20 min | ~14 min |
## Installation
```shell
pip install --upgrade diffusers[torch] transformers protobuf tokenizers sentencepiece imageio
```
## Usage
You can modify `rel_l1_thresh` in line 174 of `teacache_mochi.py` to obtain your desired trade-off between latency and visual quality. For single-GPU inference, you can use the following command:
```bash
python teacache_mochi.py
```
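For reference, the threshold mentioned above is the `rel_l1_thresh` class attribute set in `teacache_mochi.py`. The excerpt below mirrors that configuration block; it only takes effect after the script has replaced `MochiTransformer3DModel.forward` with its TeaCache-aware `teacache_forward`:
```python
import torch
from diffusers import MochiPipeline

num_inference_steps = 64
pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", force_zeros_for_empty_prompt=True)

# TeaCache state (mirrors teacache_mochi.py; the script patches
# MochiTransformer3DModel.forward with teacache_forward before this point)
pipe.transformer.__class__.enable_teacache = True
pipe.transformer.__class__.cnt = 0
pipe.transformer.__class__.num_steps = num_inference_steps
pipe.transformer.__class__.rel_l1_thresh = 0.09  # 0.06 for 1.5x speedup, 0.09 for 2.1x speedup
pipe.transformer.__class__.accumulated_rel_l1_distance = 0
pipe.transformer.__class__.previous_modulated_input = None
pipe.transformer.__class__.previous_residual = None
```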
## Citation
If you find TeaCache useful in your research or applications, please consider giving us a star 🌟 and citing it with the following BibTeX entry.
```
@article{liu2024timestep,
title={Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model},
author={Liu, Feng and Zhang, Shiwei and Wang, Xiaofeng and Wei, Yujie and Qiu, Haonan and Zhao, Yuzhong and Zhang, Yingya and Ye, Qixiang and Wan, Fang},
journal={arXiv preprint arXiv:2411.19108},
year={2024}
}
```
## Acknowledgements
We would like to thank the contributors to [Mochi](https://github.com/genmoai/mochi) and [Diffusers](https://github.com/huggingface/diffusers).

View File

@@ -0,0 +1,223 @@
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel
from diffusers import MochiPipeline
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.transformers import MochiTransformer3DModel
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from diffusers.utils import export_to_video
from diffusers.video_processor import VideoProcessor
from typing import Any, Dict, Optional, Tuple
import numpy as np

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
def teacache_forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
timestep: torch.LongTensor,
encoder_attention_mask: torch.Tensor,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> torch.Tensor:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p = self.config.patch_size
post_patch_height = height // p
post_patch_width = width // p
temb, encoder_hidden_states = self.time_embed(
timestep,
encoder_hidden_states,
encoder_attention_mask,
hidden_dtype=hidden_states.dtype,
)
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
hidden_states = self.patch_embed(hidden_states)
hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2)
image_rotary_emb = self.rope(
self.pos_frequencies,
num_frames,
post_patch_height,
post_patch_width,
device=hidden_states.device,
dtype=torch.float32,
)
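# TeaCache: estimate the step-to-step change from the first transformer block's modulated input;
# the polynomially rescaled relative L1 distance is accumulated and compared against rel_l1_thresh
# (full recomputation is forced at the first and last steps).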
if self.enable_teacache:
inp = hidden_states.clone()
temb_ = temb.clone()
modulated_inp, gate_msa, scale_mlp, gate_mlp = self.transformer_blocks[0].norm1(inp, temb_)
if self.cnt == 0 or self.cnt == self.num_steps-1:
should_calc = True
self.accumulated_rel_l1_distance = 0
else:
coefficients = [-3.51241319e+03, 8.11675948e+02, -6.09400215e+01, 2.42429681e+00, 3.05291719e-03]
rescale_func = np.poly1d(coefficients)
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
should_calc = False
else:
should_calc = True
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = modulated_inp
self.cnt += 1
if self.cnt == self.num_steps:
self.cnt = 0
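# Skip the transformer blocks and reuse the cached residual when the accumulated change is small;
# otherwise recompute and refresh the cache.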
if self.enable_teacache:
if not should_calc:
hidden_states += self.previous_residual
else:
ori_hidden_states = hidden_states.clone()
for i, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
encoder_attention_mask,
image_rotary_emb,
**ckpt_kwargs,
)
else:
hidden_states, encoder_hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
encoder_attention_mask=encoder_attention_mask,
image_rotary_emb=image_rotary_emb,
)
hidden_states = self.norm_out(hidden_states, temb)
self.previous_residual = hidden_states - ori_hidden_states
else:
for i, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
encoder_attention_mask,
image_rotary_emb,
**ckpt_kwargs,
)
else:
hidden_states, encoder_hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
encoder_attention_mask=encoder_attention_mask,
image_rotary_emb=image_rotary_emb,
)
hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(batch_size, num_frames, post_patch_height, post_patch_width, p, p, -1)
hidden_states = hidden_states.permute(0, 6, 1, 2, 4, 3, 5)
output = hidden_states.reshape(batch_size, -1, num_frames, height, width)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
MochiTransformer3DModel.forward = teacache_forward
prompt = "A hand with delicate fingers picks up a bright yellow lemon from a wooden bowl filled with lemons and sprigs of mint against a peach-colored background. The hand gently tosses the lemon up and catches it, showcasing its smooth texture. \
A beige string bag sits beside the bowl, adding a rustic touch to the scene. \
Additional lemons, one halved, are scattered around the base of the bowl. \
The even lighting enhances the vibrant colors and creates a fresh, \
inviting atmosphere."
num_inference_steps = 64
pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", force_zeros_for_empty_prompt=True)
# TeaCache
pipe.transformer.__class__.enable_teacache = True
pipe.transformer.__class__.cnt = 0
pipe.transformer.__class__.num_steps = num_inference_steps
pipe.transformer.__class__.rel_l1_thresh = 0.09 # 0.06 for 1.5x speedup, 0.09 for 2.1x speedup
pipe.transformer.__class__.accumulated_rel_l1_distance = 0
pipe.transformer.__class__.previous_modulated_input = None
pipe.transformer.__class__.previous_residual = None
# Enable memory savings
pipe.enable_vae_tiling()
pipe.enable_model_cpu_offload()
with torch.no_grad():
prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = (
pipe.encode_prompt(prompt=prompt)
)
with torch.autocast("cuda", torch.bfloat16):
with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
frames = pipe(
prompt_embeds=prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_embeds=negative_prompt_embeds,
negative_prompt_attention_mask=negative_prompt_attention_mask,
guidance_scale=4.5,
num_inference_steps=num_inference_steps,
height=480,
width=848,
num_frames=163,
generator=torch.Generator("cuda").manual_seed(0),
output_type="latent",
return_dict=False,
)[0]
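# The pipeline is asked for latents (output_type="latent"); denormalize them with the VAE's
# latents_mean / latents_std when available, then decode and post-process into frames manually.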
video_processor = VideoProcessor(vae_scale_factor=8)
has_latents_mean = hasattr(pipe.vae.config, "latents_mean") and pipe.vae.config.latents_mean is not None
has_latents_std = hasattr(pipe.vae.config, "latents_std") and pipe.vae.config.latents_std is not None
if has_latents_mean and has_latents_std:
latents_mean = (
torch.tensor(pipe.vae.config.latents_mean).view(1, 12, 1, 1, 1).to(frames.device, frames.dtype)
)
latents_std = (
torch.tensor(pipe.vae.config.latents_std).view(1, 12, 1, 1, 1).to(frames.device, frames.dtype)
)
frames = frames * latents_std / pipe.vae.config.scaling_factor + latents_mean
else:
frames = frames / pipe.vae.config.scaling_factor
with torch.no_grad():
video = pipe.vae.decode(frames.to(pipe.vae.dtype), return_dict=False)[0]
video = video_processor.postprocess_video(video)[0]
export_to_video(video, "teacache_mochi__{}.mp4".format(prompt[:50]), fps=30)

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.
