Merge pull request #1 from ali-vilab/main

This commit is contained in:
spawner 2025-05-26 07:54:57 +08:00 committed by GitHub
commit ca1c215ee7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 356 additions and 2 deletions

View File

@@ -64,6 +64,8 @@ We introduce Timestep Embedding Aware Cache (TeaCache), a training-free caching
## 🔥 Latest News
- **If you like our project, please give us a star ⭐ on GitHub for the latest update.**
- [2025/05/26] 🔥 Support [Lumina-Image-2.0](https://github.com/Alpha-VLLM/Lumina-Image-2.0). Thanks [@spawner1145](https://github.com/spawner1145).
- [2025/05/25] 🔥 Support [HiDream-I1](https://github.com/HiDream-ai/HiDream-I1). Thanks [@YunjieYu](https://github.com/YunjieYu).
- [2025/04/14] 🔥 Update coefficients of CogVideoX1.5. Thanks [@zishen-ucap](https://github.com/zishen-ucap).
- [2025/04/05] 🎉 Recommended as a **highlight** in CVPR 2025, top 16.8% in accepted papers and top 3.7% in all papers.
- [2025/03/13] 🔥 Optimized TeaCache for [Wan2.1](https://github.com/Wan-Video/Wan2.1). Thanks [@zishen-ucap](https://github.com/zishen-ucap).
@@ -137,6 +139,8 @@ If you develop/use TeaCache in your projects and you would like more people to s
- EasyAnimate, see [here](https://github.com/aigc-apps/EasyAnimate).
**Text to Image**
- [TeaCache4Lumina2](./TeaCache4Lumina2/README.md)
- [TeaCache4HiDream-I1](./TeaCache4HiDream-I1/README.md)
- [TeaCache4FLUX](./TeaCache4FLUX/README.md)
- [TeaCache4Lumina-T2X](./TeaCache4Lumina-T2X/README.md)
@@ -150,12 +154,12 @@ If you develop/use TeaCache in your projects and you would like more people to s
## 💐 Acknowledgement
This repository is built based on [VideoSys](https://github.com/NUS-HPC-AI-Lab/VideoSys), [Diffusers](https://github.com/huggingface/diffusers), [Open-Sora](https://github.com/hpcaitech/Open-Sora), [Open-Sora-Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan), [Latte](https://github.com/Vchitect/Latte), [CogVideoX](https://github.com/THUDM/CogVideo), [HunyuanVideo](https://github.com/Tencent/HunyuanVideo), [ConsisID](https://github.com/PKU-YuanGroup/ConsisID), [FLUX](https://github.com/black-forest-labs/flux), [Mochi](https://github.com/genmoai/mochi), [LTX-Video](https://github.com/Lightricks/LTX-Video), [Lumina-T2X](https://github.com/Alpha-VLLM/Lumina-T2X), [TangoFlux](https://github.com/declare-lab/TangoFlux), [Cosmos](https://github.com/NVIDIA/Cosmos) and [Wan2.1](https://github.com/Wan-Video/Wan2.1). Thanks for their contributions!
This repository is built based on [VideoSys](https://github.com/NUS-HPC-AI-Lab/VideoSys), [Diffusers](https://github.com/huggingface/diffusers), [Open-Sora](https://github.com/hpcaitech/Open-Sora), [Open-Sora-Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan), [Latte](https://github.com/Vchitect/Latte), [CogVideoX](https://github.com/THUDM/CogVideo), [HunyuanVideo](https://github.com/Tencent/HunyuanVideo), [ConsisID](https://github.com/PKU-YuanGroup/ConsisID), [FLUX](https://github.com/black-forest-labs/flux), [Mochi](https://github.com/genmoai/mochi), [LTX-Video](https://github.com/Lightricks/LTX-Video), [Lumina-T2X](https://github.com/Alpha-VLLM/Lumina-T2X), [TangoFlux](https://github.com/declare-lab/TangoFlux), [Cosmos](https://github.com/NVIDIA/Cosmos), [Wan2.1](https://github.com/Wan-Video/Wan2.1), [HiDream-I1](https://github.com/HiDream-ai/HiDream-I1) and [Lumina-Image-2.0](https://github.com/Alpha-VLLM/Lumina-Image-2.0). Thanks for their contributions!
## 🔒 License
* The majority of this project is released under the Apache 2.0 license as found in the [LICENSE](./LICENSE) file.
* For [VideoSys](https://github.com/NUS-HPC-AI-Lab/VideoSys), [Diffusers](https://github.com/huggingface/diffusers), [Open-Sora](https://github.com/hpcaitech/Open-Sora), [Open-Sora-Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan), [Latte](https://github.com/Vchitect/Latte), [CogVideoX](https://github.com/THUDM/CogVideo), [HunyuanVideo](https://github.com/Tencent/HunyuanVideo), [ConsisID](https://github.com/PKU-YuanGroup/ConsisID), [FLUX](https://github.com/black-forest-labs/flux), [Mochi](https://github.com/genmoai/mochi), [LTX-Video](https://github.com/Lightricks/LTX-Video), [Lumina-T2X](https://github.com/Alpha-VLLM/Lumina-T2X), [TangoFlux](https://github.com/declare-lab/TangoFlux), [Cosmos](https://github.com/NVIDIA/Cosmos) and [Wan2.1](https://github.com/Wan-Video/Wan2.1), please follow their LICENSE.
* For [VideoSys](https://github.com/NUS-HPC-AI-Lab/VideoSys), [Diffusers](https://github.com/huggingface/diffusers), [Open-Sora](https://github.com/hpcaitech/Open-Sora), [Open-Sora-Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan), [Latte](https://github.com/Vchitect/Latte), [CogVideoX](https://github.com/THUDM/CogVideo), [HunyuanVideo](https://github.com/Tencent/HunyuanVideo), [ConsisID](https://github.com/PKU-YuanGroup/ConsisID), [FLUX](https://github.com/black-forest-labs/flux), [Mochi](https://github.com/genmoai/mochi), [LTX-Video](https://github.com/Lightricks/LTX-Video), [Lumina-T2X](https://github.com/Alpha-VLLM/Lumina-T2X), [TangoFlux](https://github.com/declare-lab/TangoFlux), [Cosmos](https://github.com/NVIDIA/Cosmos), [Wan2.1](https://github.com/Wan-Video/Wan2.1), [HiDream-I1](https://github.com/HiDream-ai/HiDream-I1) and [Lumina-Image-2.0](https://github.com/Alpha-VLLM/Lumina-Image-2.0), please follow their LICENSE.
* The service is a research preview. Please contact us if you find any potential violations. (liufeng20@mails.ucas.ac.cn)
## 📖 Citation

View File

@@ -0,0 +1,43 @@
<!-- ## **TeaCache4HiDream-I1** -->
# TeaCache4HiDream-I1
[TeaCache](https://github.com/LiewFeng/TeaCache) can speed up [HiDream-I1](https://github.com/HiDream-ai/HiDream-I1) 2x without much visual quality degradation, in a training-free manner. The following image shows the results generated by TeaCache-HiDream-I1-Full with various `rel_l1_thresh` values: 0 (original), 0.17 (1.5x speedup), 0.25 (1.7x speedup), 0.3 (2.0x speedup), and 0.45 (2.6x speedup).
![visualization](../assets/TeaCache4HiDream-I1.png)
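Higher `rel_l1_thresh` values let TeaCache skip more denoising steps, which is where the speedup comes from: at each step the change in the timestep embedding is rescaled by a fitted polynomial and accumulated, and while the accumulated value stays below the threshold the transformer blocks are skipped and the residual cached from the last full pass is reused. A simplified sketch of that decision, following `teacache_hidream_i1.py` (variable names shortened for illustration):
```python
import numpy as np

def should_run_full_pass(curr_emb, prev_emb, acc, coefficients, rel_l1_thresh):
    # Rescale the relative L1 change of the timestep embedding with the
    # model-specific polynomial fit, then accumulate it across steps.
    rescale = np.poly1d(coefficients)
    rel_l1 = ((curr_emb - prev_emb).abs().mean() / prev_emb.abs().mean()).item()
    acc += rescale(rel_l1)
    if acc < rel_l1_thresh:
        return False, acc  # below threshold: skip the blocks, reuse the cached residual
    return True, 0.0       # above threshold: run a full pass and reset the accumulator
```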
## 📈 Inference Latency Comparisons on a Single A100
| HiDream-I1-Full | TeaCache (0.17) | TeaCache (0.25) | TeaCache (0.3) | TeaCache (0.45) |
|:-----------------------:|:----------------------------:|:--------------------:|:---------------------:|:--------------------:|
| ~50 s | ~34 s | ~29 s | ~25 s | ~19 s |
## Installation
```shell
pip install git+https://github.com/huggingface/diffusers
pip install --upgrade transformers protobuf tiktoken tokenizers sentencepiece
```
## Usage
You can modify `rel_l1_thresh` on line 297 of `teacache_hidream_i1.py` to obtain your desired trade-off between latency and visual quality. For single-GPU inference, use the following command:
```bash
python teacache_hidream_i1.py
```
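If you prefer to set the trade-off programmatically, the TeaCache knobs are class-level attributes patched onto the transformer near the end of `teacache_hidream_i1.py`; the snippet below reproduces them with the script's values (it assumes the `pipeline` and `num_inference_steps` objects defined there), and `rel_l1_thresh` is normally the only one you change:
```python
# TeaCache settings as in teacache_hidream_i1.py, patched onto the transformer class
pipeline.transformer.__class__.enable_teacache = True
pipeline.transformer.__class__.cnt = 0
pipeline.transformer.__class__.num_steps = num_inference_steps
pipeline.transformer.__class__.ret_steps = num_inference_steps * 0.1  # warm-up steps that always run in full
pipeline.transformer.__class__.rel_l1_thresh = 0.3  # 0.17 ≈ 1.5x, 0.25 ≈ 1.7x, 0.3 ≈ 2x, 0.45 ≈ 2.6x speedup
pipeline.transformer.__class__.coefficients = [-3.13605009e+04, -7.12425503e+02, 4.91363285e+01, 8.26515490e+00, 1.08053901e-01]
```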
## Citation
If you find TeaCache useful in your research or applications, please consider giving us a star 🌟 and citing it with the following BibTeX entry.
```bibtex
@article{liu2024timestep,
title={Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model},
author={Liu, Feng and Zhang, Shiwei and Wang, Xiaofeng and Wei, Yujie and Qiu, Haonan and Zhao, Yuzhong and Zhang, Yingya and Ye, Qixiang and Wan, Fang},
journal={arXiv preprint arXiv:2411.19108},
year={2024}
}
```
## Acknowledgements
We would like to thank the contributors to the [HiDream-I1](https://github.com/HiDream-ai/HiDream-I1) and [Diffusers](https://github.com/huggingface/diffusers) repositories.

View File

@@ -0,0 +1,307 @@
from typing import Any, Dict, List, Optional, Tuple

from transformers import PreTrainedTokenizerFast, LlamaForCausalLM
from diffusers import HiDreamImagePipeline
from diffusers.models import HiDreamImageTransformer2DModel
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.utils import logging, deprecate, USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers
import torch
import numpy as np

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

def teacache_forward(
    self,
    hidden_states: torch.Tensor,
    timesteps: torch.LongTensor = None,
    encoder_hidden_states_t5: torch.Tensor = None,
    encoder_hidden_states_llama3: torch.Tensor = None,
    pooled_embeds: torch.Tensor = None,
    img_ids: Optional[torch.Tensor] = None,
    img_sizes: Optional[List[Tuple[int, int]]] = None,
    hidden_states_masks: Optional[torch.Tensor] = None,
    attention_kwargs: Optional[Dict[str, Any]] = None,
    return_dict: bool = True,
    **kwargs,
):
    encoder_hidden_states = kwargs.get("encoder_hidden_states", None)

    if encoder_hidden_states is not None:
        deprecation_message = "The `encoder_hidden_states` argument is deprecated. Please use `encoder_hidden_states_t5` and `encoder_hidden_states_llama3` instead."
        deprecate("encoder_hidden_states", "0.35.0", deprecation_message)
        encoder_hidden_states_t5 = encoder_hidden_states[0]
        encoder_hidden_states_llama3 = encoder_hidden_states[1]

    if img_ids is not None and img_sizes is not None and hidden_states_masks is None:
        deprecation_message = (
            "Passing `img_ids` and `img_sizes` with unpatchified `hidden_states` is deprecated and will be ignored."
        )
        deprecate("img_ids", "0.35.0", deprecation_message)

    if hidden_states_masks is not None and (img_ids is None or img_sizes is None):
        raise ValueError("if `hidden_states_masks` is passed, `img_ids` and `img_sizes` must also be passed.")
    elif hidden_states_masks is not None and hidden_states.ndim != 3:
        raise ValueError(
            "if `hidden_states_masks` is passed, `hidden_states` must be a 3D tensor with shape (batch_size, patch_height * patch_width, patch_size * patch_size * channels)"
        )

    if attention_kwargs is not None:
        attention_kwargs = attention_kwargs.copy()
        lora_scale = attention_kwargs.pop("scale", 1.0)
    else:
        lora_scale = 1.0

    if USE_PEFT_BACKEND:
        # weight the lora layers by setting `lora_scale` for each PEFT layer
        scale_lora_layers(self, lora_scale)
    else:
        if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
            logger.warning(
                "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
            )
    # spatial forward
    batch_size = hidden_states.shape[0]
    hidden_states_type = hidden_states.dtype

    # Patchify the input
    if hidden_states_masks is None:
        hidden_states, hidden_states_masks, img_sizes, img_ids = self.patchify(hidden_states)

    # Embed the hidden states
    hidden_states = self.x_embedder(hidden_states)

    # 0. time
    timesteps = self.t_embedder(timesteps, hidden_states_type)
    p_embedder = self.p_embedder(pooled_embeds)
    temb = timesteps + p_embedder

    encoder_hidden_states = [encoder_hidden_states_llama3[k] for k in self.config.llama_layers]

    if self.caption_projection is not None:
        new_encoder_hidden_states = []
        for i, enc_hidden_state in enumerate(encoder_hidden_states):
            enc_hidden_state = self.caption_projection[i](enc_hidden_state)
            enc_hidden_state = enc_hidden_state.view(batch_size, -1, hidden_states.shape[-1])
            new_encoder_hidden_states.append(enc_hidden_state)
        encoder_hidden_states = new_encoder_hidden_states
        encoder_hidden_states_t5 = self.caption_projection[-1](encoder_hidden_states_t5)
        encoder_hidden_states_t5 = encoder_hidden_states_t5.view(batch_size, -1, hidden_states.shape[-1])
        encoder_hidden_states.append(encoder_hidden_states_t5)

    txt_ids = torch.zeros(
        batch_size,
        encoder_hidden_states[-1].shape[1]
        + encoder_hidden_states[-2].shape[1]
        + encoder_hidden_states[0].shape[1],
        3,
        device=img_ids.device,
        dtype=img_ids.dtype,
    )
    ids = torch.cat((img_ids, txt_ids), dim=1)
    image_rotary_emb = self.pe_embedder(ids)

    # 2. Blocks
    block_id = 0
    initial_encoder_hidden_states = torch.cat([encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1)
    initial_encoder_hidden_states_seq_len = initial_encoder_hidden_states.shape[1]
    if self.enable_teacache:
        # TeaCache: use the timestep embedding as a cheap proxy for how much the
        # model output will change at this step.
        modulated_inp = timesteps.clone()
        if self.cnt < self.ret_steps:
            # Always compute the first `ret_steps` warm-up steps in full.
            should_calc = True
            self.accumulated_rel_l1_distance = 0
        else:
            # Rescale the relative L1 change of the timestep embedding with the fitted
            # polynomial and accumulate it; skip the blocks while it stays below `rel_l1_thresh`.
            rescale_func = np.poly1d(self.coefficients)
            self.accumulated_rel_l1_distance += rescale_func(
                ((modulated_inp - self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()
            )
            if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
                should_calc = False
            else:
                should_calc = True
                self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = modulated_inp
        self.cnt += 1
        if self.cnt == self.num_steps:
            self.cnt = 0
    if self.enable_teacache:
        if not should_calc:
            # Skipped step: reuse the residual cached from the last full pass.
            hidden_states += self.previous_residual
        else:
            # Full pass: run all blocks and cache the residual for later skipped steps.
            # 2. Blocks
            ori_hidden_states = hidden_states.clone()

            for bid, block in enumerate(self.double_stream_blocks):
                cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
                cur_encoder_hidden_states = torch.cat(
                    [initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1
                )
                if torch.is_grad_enabled() and self.gradient_checkpointing:
                    hidden_states, initial_encoder_hidden_states = self._gradient_checkpointing_func(
                        block,
                        hidden_states,
                        hidden_states_masks,
                        cur_encoder_hidden_states,
                        temb,
                        image_rotary_emb,
                    )
                else:
                    hidden_states, initial_encoder_hidden_states = block(
                        hidden_states=hidden_states,
                        hidden_states_masks=hidden_states_masks,
                        encoder_hidden_states=cur_encoder_hidden_states,
                        temb=temb,
                        image_rotary_emb=image_rotary_emb,
                    )
                initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len]
                block_id += 1

            image_tokens_seq_len = hidden_states.shape[1]
            hidden_states = torch.cat([hidden_states, initial_encoder_hidden_states], dim=1)
            hidden_states_seq_len = hidden_states.shape[1]
            if hidden_states_masks is not None:
                encoder_attention_mask_ones = torch.ones(
                    (batch_size, initial_encoder_hidden_states.shape[1] + cur_llama31_encoder_hidden_states.shape[1]),
                    device=hidden_states_masks.device,
                    dtype=hidden_states_masks.dtype,
                )
                hidden_states_masks = torch.cat([hidden_states_masks, encoder_attention_mask_ones], dim=1)

            for bid, block in enumerate(self.single_stream_blocks):
                cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
                hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1)
                if torch.is_grad_enabled() and self.gradient_checkpointing:
                    hidden_states = self._gradient_checkpointing_func(
                        block,
                        hidden_states,
                        hidden_states_masks,
                        None,
                        temb,
                        image_rotary_emb,
                    )
                else:
                    hidden_states = block(
                        hidden_states=hidden_states,
                        hidden_states_masks=hidden_states_masks,
                        encoder_hidden_states=None,
                        temb=temb,
                        image_rotary_emb=image_rotary_emb,
                    )
                hidden_states = hidden_states[:, :hidden_states_seq_len]
                block_id += 1

            hidden_states = hidden_states[:, :image_tokens_seq_len, ...]
            self.previous_residual = hidden_states - ori_hidden_states
    else:
        for bid, block in enumerate(self.double_stream_blocks):
            cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
            cur_encoder_hidden_states = torch.cat(
                [initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1
            )
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states, initial_encoder_hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    hidden_states_masks,
                    cur_encoder_hidden_states,
                    temb,
                    image_rotary_emb,
                )
            else:
                hidden_states, initial_encoder_hidden_states = block(
                    hidden_states=hidden_states,
                    hidden_states_masks=hidden_states_masks,
                    encoder_hidden_states=cur_encoder_hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                )
            initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len]
            block_id += 1

        image_tokens_seq_len = hidden_states.shape[1]
        hidden_states = torch.cat([hidden_states, initial_encoder_hidden_states], dim=1)
        hidden_states_seq_len = hidden_states.shape[1]
        if hidden_states_masks is not None:
            encoder_attention_mask_ones = torch.ones(
                (batch_size, initial_encoder_hidden_states.shape[1] + cur_llama31_encoder_hidden_states.shape[1]),
                device=hidden_states_masks.device,
                dtype=hidden_states_masks.dtype,
            )
            hidden_states_masks = torch.cat([hidden_states_masks, encoder_attention_mask_ones], dim=1)

        for bid, block in enumerate(self.single_stream_blocks):
            cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
            hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1)
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    hidden_states_masks,
                    None,
                    temb,
                    image_rotary_emb,
                )
            else:
                hidden_states = block(
                    hidden_states=hidden_states,
                    hidden_states_masks=hidden_states_masks,
                    encoder_hidden_states=None,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                )
            hidden_states = hidden_states[:, :hidden_states_seq_len]
            block_id += 1

        hidden_states = hidden_states[:, :image_tokens_seq_len, ...]
    output = self.final_layer(hidden_states, temb)
    output = self.unpatchify(output, img_sizes, self.training)
    if hidden_states_masks is not None:
        hidden_states_masks = hidden_states_masks[:, :image_tokens_seq_len]

    if USE_PEFT_BACKEND:
        # remove `lora_scale` from each PEFT layer
        unscale_lora_layers(self, lora_scale)

    if not return_dict:
        return (output,)
    return Transformer2DModelOutput(sample=output)


# Replace the transformer's forward with the TeaCache-aware version.
HiDreamImageTransformer2DModel.forward = teacache_forward

num_inference_steps = 50
seed = 42
prompt = 'A cat holding a sign that says "Hi-Dreams.ai".'

tokenizer_4 = PreTrainedTokenizerFast.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
text_encoder_4 = LlamaForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    output_hidden_states=True,
    output_attentions=True,
    torch_dtype=torch.bfloat16,
)

pipeline = HiDreamImagePipeline.from_pretrained(
    "HiDream-ai/HiDream-I1-Full",
    tokenizer_4=tokenizer_4,
    text_encoder_4=text_encoder_4,
    torch_dtype=torch.bfloat16,
)
# pipeline.enable_model_cpu_offload()  # save some VRAM by offloading the model to CPU. Remove this if you have enough GPU power

# TeaCache
pipeline.transformer.__class__.enable_teacache = True
pipeline.transformer.__class__.cnt = 0
pipeline.transformer.__class__.num_steps = num_inference_steps
pipeline.transformer.__class__.ret_steps = num_inference_steps * 0.1
pipeline.transformer.__class__.rel_l1_thresh = 0.3  # 0.17 for 1.5x speedup, 0.25 for 1.7x speedup, 0.3 for 2x speedup, 0.45 for 2.6x speedup
pipeline.transformer.__class__.coefficients = [-3.13605009e+04, -7.12425503e+02, 4.91363285e+01, 8.26515490e+00, 1.08053901e-01]
pipeline.to("cuda")

img = pipeline(
    prompt,
    guidance_scale=5.0,
    num_inference_steps=num_inference_steps,
    generator=torch.Generator("cuda").manual_seed(seed),
).images[0]
img.save("{}.png".format('TeaCache_' + prompt))

Binary file not shown.


Size: 4.7 MiB