support FLUX

This commit is contained in:
LiewFeng 2024-12-27 19:35:50 +08:00
parent 2b8b201b3c
commit bceec3f9ab
6 changed files with 388 additions and 5 deletions

README.md

@@ -63,7 +63,8 @@
![visualization](./assets/tisser.png)
## Latest News 🔥
- **PRs to support other models are welcome. Please star ⭐ our project and stay tuned.**
- [2024/12/27] 🔥 Support [FLUX](https://github.com/black-forest-labs/flux). TeaCache works well for Image Diffusion Models!
- [2024/12/26] 🔥 Support [ConsisID](https://github.com/PKU-YuanGroup/ConsisID). Thanks [@SHYuanBest](https://github.com/SHYuanBest).
- [2024/12/24] 🔥 Support [HunyuanVideo](https://github.com/Tencent/HunyuanVideo).
- [2024/12/19] 🔥 Support [CogVideoX](https://github.com/THUDM/CogVideo).
@@ -80,6 +81,10 @@ Please refer to [TeaCache4HunyuanVideo](./TeaCache4HunyuanVideo/README.md).
Please refer to [TeaCache4ConsisID](./TeaCache4ConsisID/README.md).
## TeaCache for FLUX
Please refer to [TeaCache4FLUX](./TeaCache4FLUX/README.md).
## Installation
Prerequisites:
@@ -137,12 +142,12 @@ python common_metrics/eval.py --gt_video_dir aa --generated_video_dir bb
```
## Acknowledgement
This repository is built upon [VideoSys](https://github.com/NUS-HPC-AI-Lab/VideoSys), [Diffusers](https://github.com/huggingface/diffusers), [Open-Sora](https://github.com/hpcaitech/Open-Sora), [Open-Sora-Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan), [Latte](https://github.com/Vchitect/Latte), [CogVideoX](https://github.com/THUDM/CogVideo), [HunyuanVideo](https://github.com/Tencent/HunyuanVideo), [ConsisID](https://github.com/PKU-YuanGroup/ConsisID) and [FLUX](https://github.com/black-forest-labs/flux). Thanks for their contributions!
## License
* The majority of this project is released under the Apache 2.0 license as found in the [LICENSE](./LICENSE) file.
* For [VideoSys](https://github.com/NUS-HPC-AI-Lab/VideoSys), [Diffusers](https://github.com/huggingface/diffusers), [Open-Sora](https://github.com/hpcaitech/Open-Sora), [Open-Sora-Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan), [Latte](https://github.com/Vchitect/Latte), [CogVideoX](https://github.com/THUDM/CogVideo), [HunyuanVideo](https://github.com/Tencent/HunyuanVideo), [ConsisID](https://github.com/PKU-YuanGroup/ConsisID) and [FLUX](https://github.com/black-forest-labs/flux), please follow their LICENSE.
* The service is a research preview. Please contact us if you find any potential violations. (liufeng20@mails.ucas.ac.cn)
## Citation

TeaCache4FLUX/README.md Normal file

@@ -0,0 +1,43 @@
<!-- ## **TeaCache4FLUX** -->
# TeaCache4FLUX
[TeaCache](https://github.com/LiewFeng/TeaCache) can speed up [FLUX](https://github.com/black-forest-labs/flux) by 2x in a training-free manner, with little degradation in visual quality. The following image shows the results generated by TeaCache-FLUX with various `rel_l1_thresh` values: 0 (original), 0.25 (1.5x speedup), 0.4 (1.8x speedup), 0.6 (2.0x speedup), and 0.8 (2.25x speedup).
![visualization](../assets/TeaCache4FLUX.png)
## 📈 Inference Latency Comparisons on a Single A800
| FLUX.1 [dev] | TeaCache (0.25) | TeaCache (0.4) | TeaCache (0.6) | TeaCache (0.8) |
|:-----------------------:|:----------------------------:|:--------------------:|:---------------------:|:---------------------:|
| ~18 s | ~12 s | ~10 s | ~9 s | ~8 s |
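
The numbers above are wall-clock latencies on a single A800. A minimal sketch of how such a comparison can be timed, assuming a `pipeline` object built as in `teacache_flux.py` (the helper `time_one_run` is illustrative and not part of this repo):

```python
import time

import torch


def time_one_run(pipeline, prompt, num_inference_steps=28, seed=42):
    """Return wall-clock seconds for one full sampling run of the pipeline."""
    torch.cuda.synchronize()
    start = time.perf_counter()
    pipeline(
        prompt,
        num_inference_steps=num_inference_steps,
        generator=torch.Generator("cpu").manual_seed(seed),
    )
    torch.cuda.synchronize()
    return time.perf_counter() - start
```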
## Installation
```shell
pip install --upgrade diffusers[torch] transformers protobuf tokenizers sentencepiece
```
## Usage
You can modify the `rel_l1_thresh` in line 320 to obtain your desired trade-off between latency and visual quality. For single-GPU inference, you can use the following command:
```bash
python teacache_flux.py
```
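
For reference, the TeaCache-related settings inside `teacache_flux.py` look roughly like the sketch below (the exact line numbers may drift between versions; the values mirror the speedups listed above):

```python
# TeaCache configuration (illustrative; see teacache_flux.py in this folder)
pipeline.transformer.__class__.enable_teacache = True
pipeline.transformer.__class__.cnt = 0
pipeline.transformer.__class__.num_steps = num_inference_steps
pipeline.transformer.__class__.rel_l1_thresh = 0.4  # 0.25 ~ 1.5x, 0.4 ~ 1.8x, 0.6 ~ 2.0x, 0.8 ~ 2.25x speedup
pipeline.transformer.__class__.accumulated_rel_l1_distance = 0
pipeline.transformer.__class__.previous_modulated_input = None
pipeline.transformer.__class__.previous_residual = None
```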
## Citation
If you find TeaCache useful in your research or applications, please consider giving us a star 🌟 and citing it with the following BibTeX entry.
```
@article{liu2024timestep,
title={Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model},
author={Liu, Feng and Zhang, Shiwei and Wang, Xiaofeng and Wei, Yujie and Qiu, Haonan and Zhao, Yuzhong and Zhang, Yingya and Ye, Qixiang and Wan, Fang},
journal={arXiv preprint arXiv:2411.19108},
year={2024}
}
```
## Acknowledgements
We would like to thank the contributors to the [FLUX](https://github.com/black-forest-labs/flux) and [Diffusers](https://github.com/huggingface/diffusers) repositories.

teacache_flux.py Normal file

@@ -0,0 +1,335 @@
from typing import Any, Dict, Optional, Tuple, Union
from diffusers import DiffusionPipeline
from diffusers.models import FluxTransformer2DModel
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
import torch
import numpy as np
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def teacache_forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
pooled_projections: torch.Tensor = None,
timestep: torch.LongTensor = None,
img_ids: torch.Tensor = None,
txt_ids: torch.Tensor = None,
guidance: torch.Tensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_block_samples=None,
controlnet_single_block_samples=None,
return_dict: bool = True,
controlnet_blocks_repeat: bool = False,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
"""
The [`FluxTransformer2DModel`] forward method.
Args:
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
Input `hidden_states`.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
from the embeddings of input conditions.
timestep ( `torch.LongTensor`):
Used to indicate denoising step.
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
A list of tensors that if specified are added to the residuals of transformer blocks.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
hidden_states = self.x_embedder(hidden_states)
timestep = timestep.to(hidden_states.dtype) * 1000
if guidance is not None:
guidance = guidance.to(hidden_states.dtype) * 1000
else:
guidance = None
temb = (
self.time_text_embed(timestep, pooled_projections)
if guidance is None
else self.time_text_embed(timestep, guidance, pooled_projections)
)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
if txt_ids.ndim == 3:
logger.warning(
"Passing `txt_ids` 3d torch.Tensor is deprecated."
"Please remove the batch dimension and pass it as a 2d torch Tensor"
)
txt_ids = txt_ids[0]
if img_ids.ndim == 3:
logger.warning(
"Passing `img_ids` 3d torch.Tensor is deprecated."
"Please remove the batch dimension and pass it as a 2d torch Tensor"
)
img_ids = img_ids[0]
ids = torch.cat((txt_ids, img_ids), dim=0)
image_rotary_emb = self.pos_embed(ids)
if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
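    # TeaCache: measure how much the timestep-modulated input of the first transformer
    # block has changed relative to the previous step. The relative L1 distance is
    # rescaled by a fitted polynomial and accumulated; the expensive transformer blocks
    # are only recomputed once the accumulated value exceeds `rel_l1_thresh`.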
if self.enable_teacache:
inp = hidden_states.clone()
temb_ = temb.clone()
modulated_inp, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.transformer_blocks[0].norm1(inp, emb=temb_)
if self.cnt == 0 or self.cnt == self.num_steps-1:
should_calc = True
self.accumulated_rel_l1_distance = 0
else:
coefficients = [4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, 2.64230861e-01]
rescale_func = np.poly1d(coefficients)
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
should_calc = False
else:
should_calc = True
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = modulated_inp
self.cnt += 1
if self.cnt == self.num_steps:
self.cnt = 0
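    # Skip the transformer blocks and reuse the cached residual when the change is small;
    # otherwise run all blocks and cache the new residual for later steps.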
if self.enable_teacache:
if not should_calc:
hidden_states += self.previous_residual
else:
ori_hidden_states = hidden_states.clone()
for index_block, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
**ckpt_kwargs,
)
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
)
# controlnet residual
if controlnet_block_samples is not None:
interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
interval_control = int(np.ceil(interval_control))
# For Xlabs ControlNet.
if controlnet_blocks_repeat:
hidden_states = (
hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
)
else:
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
for index_block, block in enumerate(self.single_transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
temb,
image_rotary_emb,
**ckpt_kwargs,
)
else:
hidden_states = block(
hidden_states=hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
)
# controlnet residual
if controlnet_single_block_samples is not None:
interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
interval_control = int(np.ceil(interval_control))
hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+ controlnet_single_block_samples[index_block // interval_control]
)
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
self.previous_residual = hidden_states - ori_hidden_states
else:
for index_block, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
**ckpt_kwargs,
)
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
)
# controlnet residual
if controlnet_block_samples is not None:
interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
interval_control = int(np.ceil(interval_control))
# For Xlabs ControlNet.
if controlnet_blocks_repeat:
hidden_states = (
hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
)
else:
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
for index_block, block in enumerate(self.single_transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
temb,
image_rotary_emb,
**ckpt_kwargs,
)
else:
hidden_states = block(
hidden_states=hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
)
# controlnet residual
if controlnet_single_block_samples is not None:
interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
interval_control = int(np.ceil(interval_control))
hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+ controlnet_single_block_samples[index_block // interval_control]
)
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
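# Monkey-patch the Diffusers FLUX transformer so that the pipeline below uses the
# TeaCache-aware forward defined above.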
FluxTransformer2DModel.forward = teacache_forward
num_inference_steps = 28
seed = 42
prompt = "An image of a squirrel in Picasso style"
pipeline = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.float16)
pipeline.enable_model_cpu_offload()  # Save some VRAM by offloading the model to CPU; remove this if you have enough GPU memory.
# TeaCache
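# The caching state is attached as class attributes so teacache_forward can read and
# update it through `self` without modifying the model's constructor.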
pipeline.transformer.__class__.enable_teacache = True
pipeline.transformer.__class__.cnt = 0
pipeline.transformer.__class__.num_steps = num_inference_steps
pipeline.transformer.__class__.rel_l1_thresh = 0.6 # 0.25 for 1.5x speedup, 0.4 for 1.8x speedup, 0.6 for 2.0x speedup, 0.8 for 2.25x speedup
pipeline.transformer.__class__.accumulated_rel_l1_distance = 0
pipeline.transformer.__class__.previous_modulated_input = None
pipeline.transformer.__class__.previous_residual = None
pipeline.to("cuda")
img = pipeline(
prompt,
num_inference_steps=num_inference_steps,
generator=torch.Generator("cpu").manual_seed(seed)
).images[0]
img.save("{}.png".format('TeaCache_' + prompt))

TeaCache4HunyuanVideo/README.md

@@ -1,7 +1,7 @@
<!-- ## **TeaCache4HunyuanVideo** -->
# TeaCache4HunyuanVideo
[TeaCache](https://github.com/LiewFeng/TeaCache) can speed up [HunyuanVideo](https://github.com/Tencent/HunyuanVideo) by 2x in a training-free manner, with little degradation in visual quality. The following video shows the results generated by TeaCache-HunyuanVideo with various `rel_l1_thresh` values: 0 (original), 0.1 (1.6x speedup), and 0.15 (2.1x speedup).
https://github.com/user-attachments/assets/7f75f4e2-3d7e-4762-9afe-c5cc3dcabe44
@@ -17,7 +17,7 @@ https://github.com/user-attachments/assets/7f75f4e2-3d7e-4762-9afe-c5cc3dcabe44
## Usage
Follow [HunyuanVideo](https://github.com/Tencent/HunyuanVideo) to clone the repo and finish the installation, then copy `teacache_sample_video.py` from this repo into the HunyuanVideo repo. You can modify the `rel_l1_thresh` in line 220 to obtain your desired trade-off between latency and visual quality.
For single-GPU inference, you can use the following command:

BIN assets/TeaCache4FLUX.png Normal file (binary image, 9.1 MiB; not shown)