mirror of
https://git.datalinker.icu/ali-vilab/TeaCache
synced 2025-12-08 20:34:24 +08:00
support Mochi, LTX-Video and LuminaT2X
This commit is contained in:
parent
04bb381962
commit
1857e57216
17
README.md
17
README.md
@ -64,6 +64,7 @@
|
||||
|
||||
## Latest News 🔥
|
||||
- **Welcome for PRs to support other models. Please star ⭐ our project and stay tuned.**
|
||||
- [2024/12/30] 🔥 Support [Mochi](https://github.com/genmoai/mochi) and [LTX-Video](https://github.com/Lightricks/LTX-Video) for Video Diffusion Models. Support [Lumina-T2X](https://github.com/Alpha-VLLM/Lumina-T2X) Image Diffusion Models.
|
||||
- [2024/12/27] 🔥 Support [FLUX](https://github.com/black-forest-labs/flux). TeaCache works well for Image Diffusion Models!
|
||||
- [2024/12/26] 🔥 Support [ConsisID](https://github.com/PKU-YuanGroup/ConsisID). Thanks [@SHYuanBest](https://github.com/SHYuanBest).
|
||||
- [2024/12/24] 🔥 Support [HunyuanVideo](https://github.com/Tencent/HunyuanVideo).
|
||||
@ -85,6 +86,18 @@ Please refer to [TeaCache4ConsisID](./TeaCache4ConsisID/README.md).
|
||||
|
||||
Please refer to [TeaCache4FLUX](./TeaCache4FLUX/README.md).
|
||||
|
||||
## TeaCache for Mochi
|
||||
|
||||
Please refer to [TeaCache4Mochi](./TeaCache4Mochi/README.md).
|
||||
|
||||
## TeaCache for LTX-Video
|
||||
|
||||
Please refer to [TeaCache4LTX-Video](./TeaCache4LTX-Video/README.md).
|
||||
|
||||
## TeaCache for LuminaT2X
|
||||
|
||||
Please refer to [TeaCache4LuminaT2X](./TeaCache4LuminaT2X/README.md).
|
||||
|
||||
## Installation
|
||||
|
||||
Prerequisites:
|
||||
@ -142,12 +155,12 @@ python common_metrics/eval.py --gt_video_dir aa --generated_video_dir bb
|
||||
```
|
||||
## Acknowledgement
|
||||
|
||||
This repository is built based on [VideoSys](https://github.com/NUS-HPC-AI-Lab/VideoSys), [Diffusers](https://github.com/huggingface/diffusers), [Open-Sora](https://github.com/hpcaitech/Open-Sora), [Open-Sora-Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan), [Latte](https://github.com/Vchitect/Latte), [CogVideoX](https://github.com/THUDM/CogVideo), [HunyuanVideo](https://github.com/Tencent/HunyuanVideo), [ConsisID](https://github.com/PKU-YuanGroup/ConsisID) and [FLUX](https://github.com/black-forest-labs/flux). Thanks for their contributions!
|
||||
This repository is built based on [VideoSys](https://github.com/NUS-HPC-AI-Lab/VideoSys), [Diffusers](https://github.com/huggingface/diffusers), [Open-Sora](https://github.com/hpcaitech/Open-Sora), [Open-Sora-Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan), [Latte](https://github.com/Vchitect/Latte), [CogVideoX](https://github.com/THUDM/CogVideo), [HunyuanVideo](https://github.com/Tencent/HunyuanVideo), [ConsisID](https://github.com/PKU-YuanGroup/ConsisID), [FLUX](https://github.com/black-forest-labs/flux), [Mochi](https://github.com/genmoai/mochi), [LTX-Video](https://github.com/Lightricks/LTX-Video) and [Lumina-T2X](https://github.com/Alpha-VLLM/Lumina-T2X). Thanks for their contributions!
|
||||
|
||||
## License
|
||||
|
||||
* The majority of this project is released under the Apache 2.0 license as found in the [LICENSE](./LICENSE) file.
|
||||
* For [VideoSys](https://github.com/NUS-HPC-AI-Lab/VideoSys), [Diffusers](https://github.com/huggingface/diffusers), [Open-Sora](https://github.com/hpcaitech/Open-Sora), [Open-Sora-Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan), [Latte](https://github.com/Vchitect/Latte), [CogVideoX](https://github.com/THUDM/CogVideo), [HunyuanVideo](https://github.com/Tencent/HunyuanVideo), [ConsisID](https://github.com/PKU-YuanGroup/ConsisID) and [FLUX](https://github.com/black-forest-labs/flux), please follow thier LICENSE.
|
||||
* For [VideoSys](https://github.com/NUS-HPC-AI-Lab/VideoSys), [Diffusers](https://github.com/huggingface/diffusers), [Open-Sora](https://github.com/hpcaitech/Open-Sora), [Open-Sora-Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan), [Latte](https://github.com/Vchitect/Latte), [CogVideoX](https://github.com/THUDM/CogVideo), [HunyuanVideo](https://github.com/Tencent/HunyuanVideo), [ConsisID](https://github.com/PKU-YuanGroup/ConsisID), [FLUX](https://github.com/black-forest-labs/flux), [Mochi](https://github.com/genmoai/mochi), [LTX-Video](https://github.com/Lightricks/LTX-Video) and [Lumina-T2X](https://github.com/Alpha-VLLM/Lumina-T2X), please follow thier LICENSE.
|
||||
* The service is a research preview. Please contact us if you find any potential violations. (liufeng20@mails.ucas.ac.cn)
|
||||
|
||||
## Citation
|
||||
|
||||
43
TeaCache4LTX-Video/README.md
Normal file
43
TeaCache4LTX-Video/README.md
Normal file
@ -0,0 +1,43 @@
|
||||
<!-- ## **TeaCache4LTX-Video** -->
|
||||
# TeaCache4LTX-Video
|
||||
|
||||
[TeaCache](https://github.com/LiewFeng/TeaCache) can speedup [LTX-Video](https://github.com/Lightricks/LTX-Video) 2x without much visual quality degradation, in a training-free manner. The following video presents the videos generated by TeaCache-LTX-Video with various rel_l1_thresh values: 0 (original), 0.03 (1.6x speedup), 0.05 (2.1x speedup), 0.6 (2.0x speedup), and 0.8 (2.25x speedup).
|
||||
|
||||
https://github.com/user-attachments/assets/1f4cf26c-b8c6-45e3-b402-840bcd6ba00e
|
||||
|
||||
## 📈 Inference Latency Comparisons on a Single A800
|
||||
|
||||
|
||||
| LTX-Video-0.9.1 | TeaCache (0.03) | TeaCache (0.05) |
|
||||
|:--------------------------:|:----------------------------:|:---------------------:|
|
||||
| ~32 s | ~20 s | ~16 s |
|
||||
|
||||
## Installation
|
||||
|
||||
```shell
|
||||
pip install --upgrade diffusers[torch] transformers protobuf tokenizers sentencepiece imageio
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
You can modify the thresh in line 187 to obtain your desired trade-off between latency and visul quality. For single-gpu inference, you can use the following command:
|
||||
|
||||
```bash
|
||||
python teacache_ltx.py
|
||||
```
|
||||
|
||||
## Citation
|
||||
If you find TeaCache is useful in your research or applications, please consider giving us a star 🌟 and citing it by the following BibTeX entry.
|
||||
|
||||
```
|
||||
@article{liu2024timestep,
|
||||
title={Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model},
|
||||
author={Liu, Feng and Zhang, Shiwei and Wang, Xiaofeng and Wei, Yujie and Qiu, Haonan and Zhao, Yuzhong and Zhang, Yingya and Ye, Qixiang and Wan, Fang},
|
||||
journal={arXiv preprint arXiv:2411.19108},
|
||||
year={2024}
|
||||
}
|
||||
```
|
||||
|
||||
## Acknowledgements
|
||||
|
||||
We would like to thank the contributors to the [LTX-Video](https://github.com/Lightricks/LTX-Video) and [Diffusers](https://github.com/huggingface/diffusers).
|
||||
204
TeaCache4LTX-Video/teacache_ltx.py
Normal file
204
TeaCache4LTX-Video/teacache_ltx.py
Normal file
@ -0,0 +1,204 @@
|
||||
import torch
|
||||
from diffusers import LTXPipeline
|
||||
from diffusers.models.transformers import LTXVideoTransformer3DModel
|
||||
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
||||
from diffusers.utils import export_to_video
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
import numpy as np
|
||||
|
||||
|
||||
def teacache_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
timestep: torch.LongTensor,
|
||||
encoder_attention_mask: torch.Tensor,
|
||||
num_frames: int,
|
||||
height: int,
|
||||
width: int,
|
||||
rope_interpolation_scale: Optional[Tuple[float, float, float]] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> torch.Tensor:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale)
|
||||
|
||||
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
||||
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
|
||||
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
|
||||
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
||||
|
||||
batch_size = hidden_states.size(0)
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
|
||||
temb, embedded_timestep = self.time_embed(
|
||||
timestep.flatten(),
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_states.dtype,
|
||||
)
|
||||
|
||||
temb = temb.view(batch_size, -1, temb.size(-1))
|
||||
embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1))
|
||||
|
||||
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
|
||||
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1))
|
||||
|
||||
if self.enable_teacache:
|
||||
inp = hidden_states.clone()
|
||||
temb_ = temb.clone()
|
||||
inp = self.transformer_blocks[0].norm1(inp)
|
||||
num_ada_params = self.transformer_blocks[0].scale_shift_table.shape[0]
|
||||
ada_values = self.transformer_blocks[0].scale_shift_table[None, None] + temb_.reshape(batch_size, temb_.size(1), num_ada_params, -1)
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
|
||||
modulated_inp = inp * (1 + scale_msa) + shift_msa
|
||||
if self.cnt == 0 or self.cnt == self.num_steps-1:
|
||||
should_calc = True
|
||||
self.accumulated_rel_l1_distance = 0
|
||||
else:
|
||||
coefficients = [2.14700694e+01, -1.28016453e+01, 2.31279151e+00, 7.92487521e-01, 9.69274326e-03]
|
||||
rescale_func = np.poly1d(coefficients)
|
||||
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
|
||||
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
|
||||
should_calc = False
|
||||
else:
|
||||
should_calc = True
|
||||
self.accumulated_rel_l1_distance = 0
|
||||
self.previous_modulated_input = modulated_inp
|
||||
self.cnt += 1
|
||||
if self.cnt == self.num_steps:
|
||||
self.cnt = 0
|
||||
|
||||
if self.enable_teacache:
|
||||
if not should_calc:
|
||||
hidden_states += self.previous_residual
|
||||
else:
|
||||
ori_hidden_states = hidden_states.clone()
|
||||
for block in self.transformer_blocks:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
encoder_attention_mask,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
else:
|
||||
hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
)
|
||||
|
||||
scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
|
||||
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
|
||||
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
hidden_states = hidden_states * (1 + scale) + shift
|
||||
self.previous_residual = hidden_states - ori_hidden_states
|
||||
else:
|
||||
for block in self.transformer_blocks:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
encoder_attention_mask,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
else:
|
||||
hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
)
|
||||
|
||||
scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
|
||||
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
|
||||
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
hidden_states = hidden_states * (1 + scale) + shift
|
||||
|
||||
|
||||
output = self.proj_out(hidden_states)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
|
||||
LTXVideoTransformer3DModel.forward = teacache_forward
|
||||
prompt = "A clear, turquoise river flows through a rocky canyon, cascading over a small waterfall and forming a pool of water at the bottom.The river is the main focus of the scene, with its clear water reflecting the surrounding trees and rocks. The canyon walls are steep and rocky, with some vegetation growing on them. The trees are mostly pine trees, with their green needles contrasting with the brown and gray rocks. The overall tone of the scene is one of peace and tranquility."
|
||||
negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
|
||||
seed = 42
|
||||
num_inference_steps = 50
|
||||
pipe = LTXPipeline.from_pretrained("a-r-r-o-w/LTX-Video-0.9.1-diffusers", torch_dtype=torch.bfloat16)
|
||||
|
||||
# TeaCache
|
||||
pipe.transformer.__class__.enable_teacache = True
|
||||
pipe.transformer.__class__.cnt = 0
|
||||
pipe.transformer.__class__.num_steps = num_inference_steps
|
||||
pipe.transformer.__class__.rel_l1_thresh = 0.05 # 0.03 for 1.6x speedup, 0.05 for 2.1x speedup
|
||||
pipe.transformer.__class__.accumulated_rel_l1_distance = 0
|
||||
pipe.transformer.__class__.previous_modulated_input = None
|
||||
pipe.transformer.__class__.previous_residual = None
|
||||
|
||||
pipe.to("cuda")
|
||||
video = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
width=768,
|
||||
height=512,
|
||||
num_frames=161,
|
||||
decode_timestep=0.03,
|
||||
decode_noise_scale=0.025,
|
||||
num_inference_steps=num_inference_steps,
|
||||
generator=torch.Generator("cuda").manual_seed(seed)
|
||||
).frames[0]
|
||||
export_to_video(video, "teacache_ltx_{}.mp4".format(prompt[:50]), fps=24)
|
||||
44
TeaCache4LuminaT2X/README.md
Normal file
44
TeaCache4LuminaT2X/README.md
Normal file
@ -0,0 +1,44 @@
|
||||
<!-- ## **TeaCache4LuminaT2X** -->
|
||||
# TeaCache4LuminaT2X
|
||||
|
||||
[TeaCache](https://github.com/LiewFeng/TeaCache) can speedup [Lumina-T2X](https://github.com/Alpha-VLLM/Lumina-T2X) 2x without much visual quality degradation, in a training-free manner. The following image shows the results generated by TeaCache-Lumina-Next with various rel_l1_thresh values: 0 (original), 0.2 (1.5x speedup), 0.3 (1.9x speedup), 0.4 (2.4x speedup), and 0.5 (2.8x speedup).
|
||||
|
||||

|
||||
|
||||
## 📈 Inference Latency Comparisons on a Single A800
|
||||
|
||||
|
||||
| Lumina-Next-SFT | TeaCache (0.2) | TeaCache (0.3) | TeaCache (0.4) | TeaCache (0.5) |
|
||||
|:-------------------------:|:---------------------------:|:--------------------:|:---------------------:|:---------------------:|
|
||||
| ~17 s | ~11 s | ~9 s | ~7 s | ~6 s |
|
||||
|
||||
## Installation
|
||||
|
||||
```shell
|
||||
pip install --upgrade diffusers[torch] transformers protobuf tokenizers sentencepiece
|
||||
pip install flash-attn --no-build-isolation
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
You can modify the thresh in line 113 to obtain your desired trade-off between latency and visul quality. For single-gpu inference, you can use the following command:
|
||||
|
||||
```bash
|
||||
python teacache_lumina_next.py
|
||||
```
|
||||
|
||||
## Citation
|
||||
If you find TeaCache is useful in your research or applications, please consider giving us a star 🌟 and citing it by the following BibTeX entry.
|
||||
|
||||
```
|
||||
@article{liu2024timestep,
|
||||
title={Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model},
|
||||
author={Liu, Feng and Zhang, Shiwei and Wang, Xiaofeng and Wei, Yujie and Qiu, Haonan and Zhao, Yuzhong and Zhang, Yingya and Ye, Qixiang and Wan, Fang},
|
||||
journal={arXiv preprint arXiv:2411.19108},
|
||||
year={2024}
|
||||
}
|
||||
```
|
||||
|
||||
## Acknowledgements
|
||||
|
||||
We would like to thank the contributors to the [Lumina-T2X](https://github.com/Alpha-VLLM/Lumina-T2X) and [Diffusers](https://github.com/huggingface/diffusers).
|
||||
123
TeaCache4LuminaT2X/teacache_lumina_next.py
Normal file
123
TeaCache4LuminaT2X/teacache_lumina_next.py
Normal file
@ -0,0 +1,123 @@
|
||||
import torch
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
from diffusers import LuminaText2ImgPipeline
|
||||
from diffusers.models import LuminaNextDiT2DModel
|
||||
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
||||
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
||||
import numpy as np
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
def teacache_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
encoder_mask: torch.Tensor,
|
||||
image_rotary_emb: torch.Tensor,
|
||||
cross_attention_kwargs: Dict[str, Any] = None,
|
||||
return_dict=True,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass of LuminaNextDiT.
|
||||
|
||||
Parameters:
|
||||
hidden_states (torch.Tensor): Input tensor of shape (N, C, H, W).
|
||||
timestep (torch.Tensor): Tensor of diffusion timesteps of shape (N,).
|
||||
encoder_hidden_states (torch.Tensor): Tensor of caption features of shape (N, D).
|
||||
encoder_mask (torch.Tensor): Tensor of caption masks of shape (N, L).
|
||||
"""
|
||||
hidden_states, mask, img_size, image_rotary_emb = self.patch_embedder(hidden_states, image_rotary_emb)
|
||||
image_rotary_emb = image_rotary_emb.to(hidden_states.device)
|
||||
|
||||
temb = self.time_caption_embed(timestep, encoder_hidden_states, encoder_mask)
|
||||
|
||||
encoder_mask = encoder_mask.bool()
|
||||
if self.enable_teacache:
|
||||
inp = hidden_states.clone()
|
||||
temb_ = temb.clone()
|
||||
modulated_inp, gate_msa, scale_mlp, gate_mlp = self.layers[0].norm1(inp, temb_)
|
||||
if self.cnt == 0 or self.cnt == self.num_steps-1:
|
||||
should_calc = True
|
||||
self.accumulated_rel_l1_distance = 0
|
||||
else:
|
||||
coefficients = [393.76566581, -603.50993606, 209.10239044, -23.00726601, 0.86377344]
|
||||
rescale_func = np.poly1d(coefficients)
|
||||
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
|
||||
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
|
||||
should_calc = False
|
||||
else:
|
||||
should_calc = True
|
||||
self.accumulated_rel_l1_distance = 0
|
||||
self.previous_modulated_input = modulated_inp
|
||||
self.cnt += 1
|
||||
if self.cnt == self.num_steps:
|
||||
self.cnt = 0
|
||||
|
||||
if self.enable_teacache:
|
||||
if not should_calc:
|
||||
hidden_states += self.previous_residual
|
||||
else:
|
||||
ori_hidden_states = hidden_states.clone()
|
||||
for layer in self.layers:
|
||||
hidden_states = layer(
|
||||
hidden_states,
|
||||
mask,
|
||||
image_rotary_emb,
|
||||
encoder_hidden_states,
|
||||
encoder_mask,
|
||||
temb=temb,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
)
|
||||
self.previous_residual = hidden_states - ori_hidden_states
|
||||
|
||||
else:
|
||||
for layer in self.layers:
|
||||
hidden_states = layer(
|
||||
hidden_states,
|
||||
mask,
|
||||
image_rotary_emb,
|
||||
encoder_hidden_states,
|
||||
encoder_mask,
|
||||
temb=temb,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
|
||||
# unpatchify
|
||||
height_tokens = width_tokens = self.patch_size
|
||||
height, width = img_size[0]
|
||||
batch_size = hidden_states.size(0)
|
||||
sequence_length = (height // height_tokens) * (width // width_tokens)
|
||||
hidden_states = hidden_states[:, :sequence_length].view(
|
||||
batch_size, height // height_tokens, width // width_tokens, height_tokens, width_tokens, self.out_channels
|
||||
)
|
||||
output = hidden_states.permute(0, 5, 1, 3, 2, 4).flatten(4, 5).flatten(2, 3)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
|
||||
|
||||
LuminaNextDiT2DModel.forward = teacache_forward
|
||||
num_inference_steps = 30
|
||||
seed = 1024
|
||||
prompt = "Upper body of a young woman in a Victorian-era outfit with brass goggles and leather straps. "
|
||||
pipeline = LuminaText2ImgPipeline.from_pretrained("Alpha-VLLM/Lumina-Next-SFT-diffusers", torch_dtype=torch.bfloat16).to("cuda")
|
||||
|
||||
# TeaCache
|
||||
pipeline.transformer.__class__.enable_teacache = True
|
||||
pipeline.transformer.__class__.cnt = 0
|
||||
pipeline.transformer.__class__.num_steps = num_inference_steps
|
||||
pipeline.transformer.__class__.rel_l1_thresh = 0.3 # 0.2 for 1.5x speedup, 0.3 for 1.9x speedup, 0.4 for 2.4x speedup, 0.5 for 2.8x speedup
|
||||
pipeline.transformer.__class__.accumulated_rel_l1_distance = 0
|
||||
pipeline.transformer.__class__.previous_modulated_input = None
|
||||
pipeline.transformer.__class__.previous_residual = None
|
||||
|
||||
image = pipeline(
|
||||
prompt=prompt,
|
||||
num_inference_steps=num_inference_steps,
|
||||
generator=torch.Generator("cpu").manual_seed(seed)
|
||||
).images[0]
|
||||
image.save("teacache_lumina_{}.png".format(prompt))
|
||||
43
TeaCache4Mochi/README.md
Normal file
43
TeaCache4Mochi/README.md
Normal file
@ -0,0 +1,43 @@
|
||||
<!-- ## **TeaCache4Mochi** -->
|
||||
# TeaCache4Mochi
|
||||
|
||||
[TeaCache](https://github.com/LiewFeng/TeaCache) can speedup [Mochi](https://github.com/genmoai/mochi) 2x without much visual quality degradation, in a training-free manner. The following video shows the results generated by TeaCache-Mochi with various rel_l1_thresh values: 0 (original), 0.06 (1.5x speedup), 0.09 (2.1x speedup).
|
||||
|
||||
https://github.com/user-attachments/assets/29a81380-46b3-414f-a96b-6e3acc71b6c4
|
||||
|
||||
## 📈 Inference Latency Comparisons on a Single A800
|
||||
|
||||
|
||||
| mochi-1-preview | TeaCache (0.06) | TeaCache (0.09) |
|
||||
|:--------------------------:|:----------------------------:|:--------------------:|
|
||||
| ~30 min | ~20 min | ~14 min |
|
||||
|
||||
## Installation
|
||||
|
||||
```shell
|
||||
pip install --upgrade diffusers[torch] transformers protobuf tokenizers sentencepiece imageio
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
You can modify the thresh in line 174 to obtain your desired trade-off between latency and visul quality. For single-gpu inference, you can use the following command:
|
||||
|
||||
```bash
|
||||
python teacache_mochi.py
|
||||
```
|
||||
|
||||
## Citation
|
||||
If you find TeaCache is useful in your research or applications, please consider giving us a star 🌟 and citing it by the following BibTeX entry.
|
||||
|
||||
```
|
||||
@article{liu2024timestep,
|
||||
title={Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model},
|
||||
author={Liu, Feng and Zhang, Shiwei and Wang, Xiaofeng and Wei, Yujie and Qiu, Haonan and Zhao, Yuzhong and Zhang, Yingya and Ye, Qixiang and Wan, Fang},
|
||||
journal={arXiv preprint arXiv:2411.19108},
|
||||
year={2024}
|
||||
}
|
||||
```
|
||||
|
||||
## Acknowledgements
|
||||
|
||||
We would like to thank the contributors to the [Mochi](https://github.com/genmoai/mochi) and [Diffusers](https://github.com/huggingface/diffusers).
|
||||
223
TeaCache4Mochi/teacache_mochi.py
Normal file
223
TeaCache4Mochi/teacache_mochi.py
Normal file
@ -0,0 +1,223 @@
|
||||
import torch
|
||||
from torch.nn.attention import SDPBackend, sdpa_kernel
|
||||
from diffusers import MochiPipeline
|
||||
from diffusers.models.transformers import MochiTransformer3DModel
|
||||
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
||||
from diffusers.utils import export_to_video
|
||||
from diffusers.video_processor import VideoProcessor
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
import numpy as np
|
||||
|
||||
|
||||
def teacache_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
timestep: torch.LongTensor,
|
||||
encoder_attention_mask: torch.Tensor,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> torch.Tensor:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
p = self.config.patch_size
|
||||
|
||||
post_patch_height = height // p
|
||||
post_patch_width = width // p
|
||||
|
||||
temb, encoder_hidden_states = self.time_embed(
|
||||
timestep,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
hidden_dtype=hidden_states.dtype,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
|
||||
hidden_states = self.patch_embed(hidden_states)
|
||||
hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2)
|
||||
|
||||
image_rotary_emb = self.rope(
|
||||
self.pos_frequencies,
|
||||
num_frames,
|
||||
post_patch_height,
|
||||
post_patch_width,
|
||||
device=hidden_states.device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
if self.enable_teacache:
|
||||
inp = hidden_states.clone()
|
||||
temb_ = temb.clone()
|
||||
modulated_inp, gate_msa, scale_mlp, gate_mlp = self.transformer_blocks[0].norm1(inp, temb_)
|
||||
if self.cnt == 0 or self.cnt == self.num_steps-1:
|
||||
should_calc = True
|
||||
self.accumulated_rel_l1_distance = 0
|
||||
else:
|
||||
coefficients = [-3.51241319e+03, 8.11675948e+02, -6.09400215e+01, 2.42429681e+00, 3.05291719e-03]
|
||||
rescale_func = np.poly1d(coefficients)
|
||||
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
|
||||
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
|
||||
should_calc = False
|
||||
else:
|
||||
should_calc = True
|
||||
self.accumulated_rel_l1_distance = 0
|
||||
self.previous_modulated_input = modulated_inp
|
||||
self.cnt += 1
|
||||
if self.cnt == self.num_steps:
|
||||
self.cnt = 0
|
||||
|
||||
if self.enable_teacache:
|
||||
if not should_calc:
|
||||
hidden_states += self.previous_residual
|
||||
else:
|
||||
ori_hidden_states = hidden_states.clone()
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
encoder_attention_mask,
|
||||
image_rotary_emb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
else:
|
||||
hidden_states, encoder_hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=temb,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
self.previous_residual = hidden_states - ori_hidden_states
|
||||
else:
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
encoder_attention_mask,
|
||||
image_rotary_emb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
else:
|
||||
hidden_states, encoder_hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=temb,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
hidden_states = hidden_states.reshape(batch_size, num_frames, post_patch_height, post_patch_width, p, p, -1)
|
||||
hidden_states = hidden_states.permute(0, 6, 1, 2, 4, 3, 5)
|
||||
output = hidden_states.reshape(batch_size, -1, num_frames, height, width)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
|
||||
|
||||
MochiTransformer3DModel.forward = teacache_forward
|
||||
prompt = "A hand with delicate fingers picks up a bright yellow lemon from a wooden bowl filled with lemons and sprigs of mint against a peach-colored background. The hand gently tosses the lemon up and catches it, showcasing its smooth texture. \
|
||||
A beige string bag sits beside the bowl, adding a rustic touch to the scene. \
|
||||
Additional lemons, one halved, are scattered around the base of the bowl. \
|
||||
The even lighting enhances the vibrant colors and creates a fresh, \
|
||||
inviting atmosphere."
|
||||
num_inference_steps = 64
|
||||
pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", force_zeros_for_empty_prompt=True)
|
||||
|
||||
# TeaCache
|
||||
pipe.transformer.__class__.enable_teacache = True
|
||||
pipe.transformer.__class__.cnt = 0
|
||||
pipe.transformer.__class__.num_steps = num_inference_steps
|
||||
pipe.transformer.__class__.rel_l1_thresh = 0.09 # 0.06 for 1.5x speedup, 0.09 for 2.1x speedup
|
||||
pipe.transformer.__class__.accumulated_rel_l1_distance = 0
|
||||
pipe.transformer.__class__.previous_modulated_input = None
|
||||
pipe.transformer.__class__.previous_residual = None
|
||||
|
||||
# Enable memory savings
|
||||
pipe.enable_vae_tiling()
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
with torch.no_grad():
|
||||
prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = (
|
||||
pipe.encode_prompt(prompt=prompt)
|
||||
)
|
||||
|
||||
with torch.autocast("cuda", torch.bfloat16):
|
||||
with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
|
||||
frames = pipe(
|
||||
prompt_embeds=prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
guidance_scale=4.5,
|
||||
num_inference_steps=num_inference_steps,
|
||||
height=480,
|
||||
width=848,
|
||||
num_frames=163,
|
||||
generator=torch.Generator("cuda").manual_seed(0),
|
||||
output_type="latent",
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
video_processor = VideoProcessor(vae_scale_factor=8)
|
||||
has_latents_mean = hasattr(pipe.vae.config, "latents_mean") and pipe.vae.config.latents_mean is not None
|
||||
has_latents_std = hasattr(pipe.vae.config, "latents_std") and pipe.vae.config.latents_std is not None
|
||||
if has_latents_mean and has_latents_std:
|
||||
latents_mean = (
|
||||
torch.tensor(pipe.vae.config.latents_mean).view(1, 12, 1, 1, 1).to(frames.device, frames.dtype)
|
||||
)
|
||||
latents_std = (
|
||||
torch.tensor(pipe.vae.config.latents_std).view(1, 12, 1, 1, 1).to(frames.device, frames.dtype)
|
||||
)
|
||||
frames = frames * latents_std / pipe.vae.config.scaling_factor + latents_mean
|
||||
else:
|
||||
frames = frames / pipe.vae.config.scaling_factor
|
||||
|
||||
with torch.no_grad():
|
||||
video = pipe.vae.decode(frames.to(pipe.vae.dtype), return_dict=False)[0]
|
||||
|
||||
video = video_processor.postprocess_video(video)[0]
|
||||
export_to_video(video, "teacache_mochi__{}.mp4".format(prompt[:50]), fps=30)
|
||||
BIN
__MACOSX/TeaCache4LTX-Video/._teacache_ltx.py
Normal file
BIN
__MACOSX/TeaCache4LTX-Video/._teacache_ltx.py
Normal file
Binary file not shown.
BIN
__MACOSX/TeaCache4LuminaT2X/._teacache_lumina_next.py
Normal file
BIN
__MACOSX/TeaCache4LuminaT2X/._teacache_lumina_next.py
Normal file
Binary file not shown.
BIN
__MACOSX/TeaCache4Mochi/._teacache_mochi.py
Normal file
BIN
__MACOSX/TeaCache4Mochi/._teacache_mochi.py
Normal file
Binary file not shown.
BIN
assets/TeaCache4LuminaT2X.png
Normal file
BIN
assets/TeaCache4LuminaT2X.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 5.8 MiB |
Loading…
x
Reference in New Issue
Block a user