mirror of https://git.datalinker.icu/ali-vilab/TeaCache, synced 2025-12-08 20:34:24 +08:00
Add support for HiDream-I1
This commit is contained in:
parent
73d9573763
commit
d680b3a2df
43  TeaCache4HiDream-I1/README.md  Normal file
@@ -0,0 +1,43 @@
<!-- ## **TeaCache4HiDream-I1** -->

# TeaCache4HiDream-I1

[TeaCache](https://github.com/LiewFeng/TeaCache) can speed up [HiDream-I1](https://github.com/HiDream-ai/HiDream-I1) 2x without much visual quality degradation, in a training-free manner. The following image shows the results generated by TeaCache-HiDream-I1-Full with various `rel_l1_thresh` values: 0 (original), 0.17 (1.5x speedup), 0.25 (1.7x speedup), 0.3 (2.0x speedup), and 0.45 (2.6x speedup).

![](../assets/TeaCache4HiDream-I1.png)

## 📈 Inference Latency Comparisons on a Single A100

| HiDream-I1-Full | TeaCache (0.17) | TeaCache (0.25) | TeaCache (0.3) | TeaCache (0.45) |
|:---------------:|:---------------:|:---------------:|:--------------:|:---------------:|
|      ~50 s      |      ~34 s      |      ~29 s      |     ~25 s      |      ~19 s      |

## Installation

```shell
pip install git+https://github.com/huggingface/diffusers
pip install --upgrade transformers protobuf tiktoken tokenizers sentencepiece
```

## Usage

You can modify `rel_l1_thresh` on line 297 of `teacache_hidream_i1.py` to obtain your desired trade-off between latency and visual quality (the relevant settings are shown below). For single-GPU inference, use the following command:

```bash
python teacache_hidream_i1.py
```
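
For reference, these are the TeaCache-related settings that `teacache_hidream_i1.py` attaches to the transformer class; tweak `rel_l1_thresh` here to pick a point from the table above:

```python
# TeaCache settings as configured in teacache_hidream_i1.py
pipeline.transformer.__class__.enable_teacache = True
pipeline.transformer.__class__.cnt = 0
pipeline.transformer.__class__.num_steps = num_inference_steps
pipeline.transformer.__class__.ret_steps = num_inference_steps * 0.1
pipeline.transformer.__class__.rel_l1_thresh = 0.3  # 0.17 for 1.5x, 0.25 for 1.7x, 0.3 for 2x, 0.45 for 2.6x speedup
pipeline.transformer.__class__.coefficients = [-3.13605009e+04, -7.12425503e+02, 4.91363285e+01, 8.26515490e+00, 1.08053901e-01]
```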

## Citation

If you find TeaCache useful in your research or applications, please consider giving us a star 🌟 and citing it with the following BibTeX entry.

```
@article{liu2024timestep,
  title={Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model},
  author={Liu, Feng and Zhang, Shiwei and Wang, Xiaofeng and Wei, Yujie and Qiu, Haonan and Zhao, Yuzhong and Zhang, Yingya and Ye, Qixiang and Wan, Fang},
  journal={arXiv preprint arXiv:2411.19108},
  year={2024}
}
```

## Acknowledgements

We would like to thank the contributors to the [HiDream-I1](https://github.com/HiDream-ai/HiDream-I1) and [Diffusers](https://github.com/huggingface/diffusers) repositories.
307  TeaCache4HiDream-I1/teacache_hidream_i1.py  Normal file
@@ -0,0 +1,307 @@
from typing import Any, Dict, List, Optional, Tuple

from transformers import PreTrainedTokenizerFast, LlamaForCausalLM

from diffusers import HiDreamImagePipeline
from diffusers.models import HiDreamImageTransformer2DModel
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.utils import logging, deprecate, USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers

import torch
import numpy as np


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

def teacache_forward(
    self,
    hidden_states: torch.Tensor,
    timesteps: torch.LongTensor = None,
    encoder_hidden_states_t5: torch.Tensor = None,
    encoder_hidden_states_llama3: torch.Tensor = None,
    pooled_embeds: torch.Tensor = None,
    img_ids: Optional[torch.Tensor] = None,
    img_sizes: Optional[List[Tuple[int, int]]] = None,
    hidden_states_masks: Optional[torch.Tensor] = None,
    attention_kwargs: Optional[Dict[str, Any]] = None,
    return_dict: bool = True,
    **kwargs,
):
    encoder_hidden_states = kwargs.get("encoder_hidden_states", None)

    if encoder_hidden_states is not None:
        deprecation_message = "The `encoder_hidden_states` argument is deprecated. Please use `encoder_hidden_states_t5` and `encoder_hidden_states_llama3` instead."
        deprecate("encoder_hidden_states", "0.35.0", deprecation_message)
        encoder_hidden_states_t5 = encoder_hidden_states[0]
        encoder_hidden_states_llama3 = encoder_hidden_states[1]

    if img_ids is not None and img_sizes is not None and hidden_states_masks is None:
        deprecation_message = (
            "Passing `img_ids` and `img_sizes` with unpatchified `hidden_states` is deprecated and will be ignored."
        )
        deprecate("img_ids", "0.35.0", deprecation_message)

    if hidden_states_masks is not None and (img_ids is None or img_sizes is None):
        raise ValueError("if `hidden_states_masks` is passed, `img_ids` and `img_sizes` must also be passed.")
    elif hidden_states_masks is not None and hidden_states.ndim != 3:
        raise ValueError(
            "if `hidden_states_masks` is passed, `hidden_states` must be a 3D tensor with shape (batch_size, patch_height * patch_width, patch_size * patch_size * channels)"
        )

    if attention_kwargs is not None:
        attention_kwargs = attention_kwargs.copy()
        lora_scale = attention_kwargs.pop("scale", 1.0)
    else:
        lora_scale = 1.0

    if USE_PEFT_BACKEND:
        # weight the lora layers by setting `lora_scale` for each PEFT layer
        scale_lora_layers(self, lora_scale)
    else:
        if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
            logger.warning(
                "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
            )

    # spatial forward
    batch_size = hidden_states.shape[0]
    hidden_states_type = hidden_states.dtype

    # Patchify the input
    if hidden_states_masks is None:
        hidden_states, hidden_states_masks, img_sizes, img_ids = self.patchify(hidden_states)

    # Embed the hidden states
    hidden_states = self.x_embedder(hidden_states)

    # 0. time
    timesteps = self.t_embedder(timesteps, hidden_states_type)
    p_embedder = self.p_embedder(pooled_embeds)
    temb = timesteps + p_embedder

    encoder_hidden_states = [encoder_hidden_states_llama3[k] for k in self.config.llama_layers]

    if self.caption_projection is not None:
        new_encoder_hidden_states = []
        for i, enc_hidden_state in enumerate(encoder_hidden_states):
            enc_hidden_state = self.caption_projection[i](enc_hidden_state)
            enc_hidden_state = enc_hidden_state.view(batch_size, -1, hidden_states.shape[-1])
            new_encoder_hidden_states.append(enc_hidden_state)
        encoder_hidden_states = new_encoder_hidden_states
        encoder_hidden_states_t5 = self.caption_projection[-1](encoder_hidden_states_t5)
        encoder_hidden_states_t5 = encoder_hidden_states_t5.view(batch_size, -1, hidden_states.shape[-1])
        encoder_hidden_states.append(encoder_hidden_states_t5)

    txt_ids = torch.zeros(
        batch_size,
        encoder_hidden_states[-1].shape[1]
        + encoder_hidden_states[-2].shape[1]
        + encoder_hidden_states[0].shape[1],
        3,
        device=img_ids.device,
        dtype=img_ids.dtype,
    )
    ids = torch.cat((img_ids, txt_ids), dim=1)
    image_rotary_emb = self.pe_embedder(ids)

    # 2. Blocks
    block_id = 0
    initial_encoder_hidden_states = torch.cat([encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1)
    initial_encoder_hidden_states_seq_len = initial_encoder_hidden_states.shape[1]
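
    # TeaCache: decide whether this step can reuse the cached residual. The timestep
    # embedding is used as a cheap proxy for how much the model input has changed; its
    # relative L1 distance from the previous step is rescaled by a fitted polynomial and
    # accumulated. While the accumulated distance stays below `rel_l1_thresh`, the
    # transformer blocks are skipped and the cached residual is reused. The first
    # `ret_steps` steps are always computed.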
    if self.enable_teacache:
        modulated_inp = timesteps.clone()
        if self.cnt < self.ret_steps:
            should_calc = True
            self.accumulated_rel_l1_distance = 0
        else:
            rescale_func = np.poly1d(self.coefficients)
            self.accumulated_rel_l1_distance += rescale_func(((modulated_inp - self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
            if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
                should_calc = False
            else:
                should_calc = True
                self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = modulated_inp
        self.cnt += 1
        if self.cnt == self.num_steps:
            self.cnt = 0

    if self.enable_teacache:
        if not should_calc:
            # Skipped step: reuse the residual cached from the last fully computed step.
            hidden_states += self.previous_residual
        else:
            # 2. Blocks
            ori_hidden_states = hidden_states.clone()
            for bid, block in enumerate(self.double_stream_blocks):
                cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
                cur_encoder_hidden_states = torch.cat(
                    [initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1
                )
                if torch.is_grad_enabled() and self.gradient_checkpointing:
                    hidden_states, initial_encoder_hidden_states = self._gradient_checkpointing_func(
                        block,
                        hidden_states,
                        hidden_states_masks,
                        cur_encoder_hidden_states,
                        temb,
                        image_rotary_emb,
                    )
                else:
                    hidden_states, initial_encoder_hidden_states = block(
                        hidden_states=hidden_states,
                        hidden_states_masks=hidden_states_masks,
                        encoder_hidden_states=cur_encoder_hidden_states,
                        temb=temb,
                        image_rotary_emb=image_rotary_emb,
                    )
                initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len]
                block_id += 1

            image_tokens_seq_len = hidden_states.shape[1]
            hidden_states = torch.cat([hidden_states, initial_encoder_hidden_states], dim=1)
            hidden_states_seq_len = hidden_states.shape[1]
            if hidden_states_masks is not None:
                encoder_attention_mask_ones = torch.ones(
                    (batch_size, initial_encoder_hidden_states.shape[1] + cur_llama31_encoder_hidden_states.shape[1]),
                    device=hidden_states_masks.device,
                    dtype=hidden_states_masks.dtype,
                )
                hidden_states_masks = torch.cat([hidden_states_masks, encoder_attention_mask_ones], dim=1)

            for bid, block in enumerate(self.single_stream_blocks):
                cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
                hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1)
                if torch.is_grad_enabled() and self.gradient_checkpointing:
                    hidden_states = self._gradient_checkpointing_func(
                        block,
                        hidden_states,
                        hidden_states_masks,
                        None,
                        temb,
                        image_rotary_emb,
                    )
                else:
                    hidden_states = block(
                        hidden_states=hidden_states,
                        hidden_states_masks=hidden_states_masks,
                        encoder_hidden_states=None,
                        temb=temb,
                        image_rotary_emb=image_rotary_emb,
                    )
                hidden_states = hidden_states[:, :hidden_states_seq_len]
                block_id += 1

            hidden_states = hidden_states[:, :image_tokens_seq_len, ...]
            # Cache the residual added by the transformer blocks so skipped steps can reuse it.
            self.previous_residual = hidden_states - ori_hidden_states
    else:
        for bid, block in enumerate(self.double_stream_blocks):
            cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
            cur_encoder_hidden_states = torch.cat(
                [initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1
            )
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states, initial_encoder_hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    hidden_states_masks,
                    cur_encoder_hidden_states,
                    temb,
                    image_rotary_emb,
                )
            else:
                hidden_states, initial_encoder_hidden_states = block(
                    hidden_states=hidden_states,
                    hidden_states_masks=hidden_states_masks,
                    encoder_hidden_states=cur_encoder_hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                )
            initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len]
            block_id += 1

        image_tokens_seq_len = hidden_states.shape[1]
        hidden_states = torch.cat([hidden_states, initial_encoder_hidden_states], dim=1)
        hidden_states_seq_len = hidden_states.shape[1]
        if hidden_states_masks is not None:
            encoder_attention_mask_ones = torch.ones(
                (batch_size, initial_encoder_hidden_states.shape[1] + cur_llama31_encoder_hidden_states.shape[1]),
                device=hidden_states_masks.device,
                dtype=hidden_states_masks.dtype,
            )
            hidden_states_masks = torch.cat([hidden_states_masks, encoder_attention_mask_ones], dim=1)

        for bid, block in enumerate(self.single_stream_blocks):
            cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
            hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1)
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    hidden_states_masks,
                    None,
                    temb,
                    image_rotary_emb,
                )
            else:
                hidden_states = block(
                    hidden_states=hidden_states,
                    hidden_states_masks=hidden_states_masks,
                    encoder_hidden_states=None,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                )
            hidden_states = hidden_states[:, :hidden_states_seq_len]
            block_id += 1

        hidden_states = hidden_states[:, :image_tokens_seq_len, ...]

    output = self.final_layer(hidden_states, temb)
    output = self.unpatchify(output, img_sizes, self.training)
    if hidden_states_masks is not None:
        hidden_states_masks = hidden_states_masks[:, :image_tokens_seq_len]

    if USE_PEFT_BACKEND:
        # remove `lora_scale` from each PEFT layer
        unscale_lora_layers(self, lora_scale)

    if not return_dict:
        return (output,)
    return Transformer2DModelOutput(sample=output)
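
# Monkey-patch HiDreamImageTransformer2DModel so the pipeline uses the TeaCache-aware forward above.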
HiDreamImageTransformer2DModel.forward = teacache_forward

num_inference_steps = 50
seed = 42
prompt = 'A cat holding a sign that says "Hi-Dreams.ai".'

tokenizer_4 = PreTrainedTokenizerFast.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
text_encoder_4 = LlamaForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    output_hidden_states=True,
    output_attentions=True,
    torch_dtype=torch.bfloat16,
)

pipeline = HiDreamImagePipeline.from_pretrained(
    "HiDream-ai/HiDream-I1-Full",
    tokenizer_4=tokenizer_4,
    text_encoder_4=text_encoder_4,
    torch_dtype=torch.bfloat16,
)
# pipeline.enable_model_cpu_offload()  # save some VRAM by offloading the model to CPU; remove this if you have enough GPU memory
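
# The TeaCache knobs below are attached as class attributes because `teacache_forward`
# reads them from `self`. `num_steps` must equal `num_inference_steps` so the step
# counter resets after each image, `ret_steps` keeps the first ~10% of steps uncached,
# and `coefficients` is the polynomial fit used to rescale the timestep-embedding change.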
# TeaCache
pipeline.transformer.__class__.enable_teacache = True
pipeline.transformer.__class__.cnt = 0
pipeline.transformer.__class__.num_steps = num_inference_steps
pipeline.transformer.__class__.ret_steps = num_inference_steps * 0.1
pipeline.transformer.__class__.rel_l1_thresh = 0.3  # 0.17 for 1.5x speedup, 0.25 for 1.7x speedup, 0.3 for 2x speedup, 0.45 for 2.6x speedup
pipeline.transformer.__class__.coefficients = [-3.13605009e+04, -7.12425503e+02, 4.91363285e+01, 8.26515490e+00, 1.08053901e-01]

pipeline.to("cuda")
img = pipeline(
    prompt,
    guidance_scale=5.0,
    num_inference_steps=num_inference_steps,
    generator=torch.Generator("cuda").manual_seed(seed),
).images[0]
img.save("{}.png".format('TeaCache_' + prompt))
BIN  assets/TeaCache4HiDream-I1.png  Normal file
Binary file not shown. After: Width | Height | Size: 4.7 MiB