From 3e4a7faab7eefea69a2cb21012be2499b1168897 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=AB=98=E7=A9=B9?= Date: Tue, 7 Jan 2025 16:44:48 +0800 Subject: [PATCH] support TangoFlux --- README.md | 9 +- TeaCache4TangoFlux/README.md | 41 ++++ TeaCache4TangoFlux/teacache_tango_flux.py | 272 ++++++++++++++++++++++ 3 files changed, 320 insertions(+), 2 deletions(-) create mode 100644 TeaCache4TangoFlux/README.md create mode 100644 TeaCache4TangoFlux/teacache_tango_flux.py diff --git a/README.md b/README.md index 09763ad..4b96812 100644 --- a/README.md +++ b/README.md @@ -64,6 +64,7 @@ ## Latest News 🔥 - **Welcome for PRs to support other models. Please star ⭐ our project and stay tuned.** +- [2025/01/07] 🔥 Support [TangoFlux](https://github.com/declare-lab/TangoFlux). TeaCache works well for Audio Diffusion Models! Rescaling coefficients for FLUX can be directly applied to TangoFLUX. - [2025/01/06] 🔥 [ComfyUI-HunyuanVideoWrapper](https://github.com/kijai/ComfyUI-HunyuanVideoWrapper) supports TeaCache. Thanks [@kijai](https://github.com/kijai), [ctf05](https://github.com/ctf05) and [DarioFT](https://github.com/DarioFT). - [2024/12/30] 🔥 Support [Mochi](https://github.com/genmoai/mochi) and [LTX-Video](https://github.com/Lightricks/LTX-Video) for Video Diffusion Models. Support [Lumina-T2X](https://github.com/Alpha-VLLM/Lumina-T2X) for Image Diffusion Models. - [2024/12/27] 🔥 Support [FLUX](https://github.com/black-forest-labs/flux). TeaCache works well for Image Diffusion Models! @@ -99,6 +100,10 @@ Please refer to [TeaCache4LTX-Video](./TeaCache4LTX-Video/README.md). Please refer to [TeaCache4Lumina-T2X](./TeaCache4Lumina-T2X/README.md). +## TeaCache for TangoFlux + +Please refer to [TeaCache4TangoFlux](./TeaCache4TangoFlux/README.md). + ## Installation Prerequisites: @@ -156,12 +161,12 @@ python common_metrics/eval.py --gt_video_dir aa --generated_video_dir bb ``` ## Acknowledgement -This repository is built based on [VideoSys](https://github.com/NUS-HPC-AI-Lab/VideoSys), [Diffusers](https://github.com/huggingface/diffusers), [Open-Sora](https://github.com/hpcaitech/Open-Sora), [Open-Sora-Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan), [Latte](https://github.com/Vchitect/Latte), [CogVideoX](https://github.com/THUDM/CogVideo), [HunyuanVideo](https://github.com/Tencent/HunyuanVideo), [ConsisID](https://github.com/PKU-YuanGroup/ConsisID), [FLUX](https://github.com/black-forest-labs/flux), [Mochi](https://github.com/genmoai/mochi), [LTX-Video](https://github.com/Lightricks/LTX-Video) and [Lumina-T2X](https://github.com/Alpha-VLLM/Lumina-T2X). Thanks for their contributions! +This repository is built based on [VideoSys](https://github.com/NUS-HPC-AI-Lab/VideoSys), [Diffusers](https://github.com/huggingface/diffusers), [Open-Sora](https://github.com/hpcaitech/Open-Sora), [Open-Sora-Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan), [Latte](https://github.com/Vchitect/Latte), [CogVideoX](https://github.com/THUDM/CogVideo), [HunyuanVideo](https://github.com/Tencent/HunyuanVideo), [ConsisID](https://github.com/PKU-YuanGroup/ConsisID), [FLUX](https://github.com/black-forest-labs/flux), [Mochi](https://github.com/genmoai/mochi), [LTX-Video](https://github.com/Lightricks/LTX-Video), [Lumina-T2X](https://github.com/Alpha-VLLM/Lumina-T2X) and [TangoFlux](https://github.com/declare-lab/TangoFlux). Thanks for their contributions! ## License * The majority of this project is released under the Apache 2.0 license as found in the [LICENSE](./LICENSE) file. -* For [VideoSys](https://github.com/NUS-HPC-AI-Lab/VideoSys), [Diffusers](https://github.com/huggingface/diffusers), [Open-Sora](https://github.com/hpcaitech/Open-Sora), [Open-Sora-Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan), [Latte](https://github.com/Vchitect/Latte), [CogVideoX](https://github.com/THUDM/CogVideo), [HunyuanVideo](https://github.com/Tencent/HunyuanVideo), [ConsisID](https://github.com/PKU-YuanGroup/ConsisID), [FLUX](https://github.com/black-forest-labs/flux), [Mochi](https://github.com/genmoai/mochi), [LTX-Video](https://github.com/Lightricks/LTX-Video) and [Lumina-T2X](https://github.com/Alpha-VLLM/Lumina-T2X), please follow their LICENSE. +* For [VideoSys](https://github.com/NUS-HPC-AI-Lab/VideoSys), [Diffusers](https://github.com/huggingface/diffusers), [Open-Sora](https://github.com/hpcaitech/Open-Sora), [Open-Sora-Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan), [Latte](https://github.com/Vchitect/Latte), [CogVideoX](https://github.com/THUDM/CogVideo), [HunyuanVideo](https://github.com/Tencent/HunyuanVideo), [ConsisID](https://github.com/PKU-YuanGroup/ConsisID), [FLUX](https://github.com/black-forest-labs/flux), [Mochi](https://github.com/genmoai/mochi), [LTX-Video](https://github.com/Lightricks/LTX-Video), [Lumina-T2X](https://github.com/Alpha-VLLM/Lumina-T2X), and [TangoFlux](https://github.com/declare-lab/TangoFlux), please follow their LICENSE. * The service is a research preview. Please contact us if you find any potential violations. (liufeng20@mails.ucas.ac.cn) ## Citation diff --git a/TeaCache4TangoFlux/README.md b/TeaCache4TangoFlux/README.md new file mode 100644 index 0000000..10a0a99 --- /dev/null +++ b/TeaCache4TangoFlux/README.md @@ -0,0 +1,41 @@ + +# TeaCache4TangoFlux + +[TeaCache](https://github.com/LiewFeng/TeaCache) can speedup [TangoFlux](https://github.com/declare-lab/TangoFlux) 2x without much audio quality degradation, in a training-free manner. + +## 📈 Inference Latency Comparisons on a Single A800 + + +| TangoFlux | TeaCache (0.25) | TeaCache (0.4) | +|:-------------------:|:----------------------------:|:--------------------:| +| ~4.08 s | ~2.42 s | ~1.95 s | + +## Installation + +```shell +pip install git+https://github.com/declare-lab/TangoFlux +``` + +## Usage + +You can modify the thresh in line 266 to obtain your desired trade-off between latency and audio quality. For single-gpu inference, you can use the following command: + +```bash +python teacache_tango_flux.py +``` + +## Citation +If you find TeaCache is useful in your research or applications, please consider giving us a star 🌟 and citing it by the following BibTeX entry. + +``` +@article{liu2024timestep, + title={Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model}, + author={Liu, Feng and Zhang, Shiwei and Wang, Xiaofeng and Wei, Yujie and Qiu, Haonan and Zhao, Yuzhong and Zhang, Yingya and Ye, Qixiang and Wan, Fang}, + journal={arXiv preprint arXiv:2411.19108}, + year={2024} +} +``` + +## Acknowledgements + +We would like to thank the contributors to the [TangoFlux](https://github.com/declare-lab/TangoFlux). \ No newline at end of file diff --git a/TeaCache4TangoFlux/teacache_tango_flux.py b/TeaCache4TangoFlux/teacache_tango_flux.py new file mode 100644 index 0000000..b208e08 --- /dev/null +++ b/TeaCache4TangoFlux/teacache_tango_flux.py @@ -0,0 +1,272 @@ +import torchaudio +from tangoflux import TangoFluxInference +from typing import Any, Dict, Optional, Tuple, Union +from diffusers.models import FluxTransformer2DModel +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +import torch +import numpy as np +import random + + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +def teacache_forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + pooled_projections: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_ids: torch.Tensor = None, + txt_ids: torch.Tensor = None, + guidance: torch.Tensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: + """ + The [`FluxTransformer2DModel`] forward method. + + Args: + hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): + Input `hidden_states`. + encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected + from the embeddings of input conditions. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + block_controlnet_hidden_states: (`list` of `torch.Tensor`): + A list of tensors that if specified are added to the residuals of transformer blocks. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + if joint_attention_kwargs is not None: + joint_attention_kwargs = joint_attention_kwargs.copy() + lora_scale = joint_attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." + ) + hidden_states = self.x_embedder(hidden_states) + + timestep = timestep.to(hidden_states.dtype) * 1000 + if guidance is not None: + guidance = guidance.to(hidden_states.dtype) * 1000 + else: + guidance = None + temb = ( + self.time_text_embed(timestep, pooled_projections) + if guidance is None + else self.time_text_embed(timestep, guidance, pooled_projections) + ) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + ids = torch.cat((txt_ids, img_ids), dim=1) + image_rotary_emb = self.pos_embed(ids) + + if self.enable_teacache: + inp = hidden_states.clone() + temb_ = temb.clone() + modulated_inp, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.transformer_blocks[0].norm1(inp, emb=temb_) + if self.cnt == 0 or self.cnt == self.num_steps-1: + should_calc = True + self.accumulated_rel_l1_distance = 0 + else: + coefficients = [4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, 2.64230861e-01] + rescale_func = np.poly1d(coefficients) + self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()) + if self.accumulated_rel_l1_distance < self.rel_l1_thresh: + should_calc = False + else: + should_calc = True + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = modulated_inp + self.cnt += 1 + if self.cnt == self.num_steps: + self.cnt = 0 + + if self.enable_teacache: + if not should_calc: + hidden_states += self.previous_residual + else: + ori_hidden_states = hidden_states.clone() + for index_block, block in enumerate(self.transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + **ckpt_kwargs, + ) + + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + for index_block, block in enumerate(self.single_transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + temb, + image_rotary_emb, + **ckpt_kwargs, + ) + + else: + hidden_states = block( + hidden_states=hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + + hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] + self.previous_residual = hidden_states - ori_hidden_states + else: + for index_block, block in enumerate(self.transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + **ckpt_kwargs, + ) + + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + for index_block, block in enumerate(self.single_transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + temb, + image_rotary_emb, + **ckpt_kwargs, + ) + + else: + hidden_states = block( + hidden_states=hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + + hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] + + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) + +FluxTransformer2DModel.forward = teacache_forward +seed = 42 +random.seed(seed) +np.random.seed(seed) +torch.manual_seed(seed) +if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) +torch.backends.cudnn.deterministic = True + +prompt = 'Hammer slowly hitting the wooden table' +steps = 50 + + +model = TangoFluxInference(name='declare-lab/TangoFlux') +# TeaCache +model.model.transformer.__class__.enable_teacache = True +model.model.transformer.__class__.cnt = 0 +model.model.transformer.__class__.num_steps = steps +model.model.transformer.__class__.rel_l1_thresh = 0.25 # 0.25 for 1.7x speedup, 0.4 for 2.1x speedup +model.model.transformer.__class__.accumulated_rel_l1_distance = 0 +model.model.transformer.__class__.previous_modulated_input = None +model.model.transformer.__class__.previous_residual = None + +audio = model.generate(prompt, steps=steps, duration=10) +torchaudio.save('teacache_tango_flux_{}.wav'.format(prompt), audio, 44100) \ No newline at end of file