Merge pull request #70 from spawner1145/main

support for lumina2
Feng Liu 2025-05-25 23:59:36 +08:00 committed by GitHub
commit 9caba2ff26
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 216 additions and 0 deletions

README.md
@@ -0,0 +1,50 @@
<!-- ## **TeaCache4LuminaT2X** -->
# TeaCache4Lumina2
[TeaCache](https://github.com/LiewFeng/TeaCache) can speed up [Lumina-Image-2.0](https://github.com/Alpha-VLLM/Lumina-Image-2.0) without much visual quality degradation, in a training-free manner. The following images show results generated by TeaCache-Lumina-Image-2.0 with various rel_l1_thresh values: 0 (original), 0.2 (1.25x speedup), 0.3 (1.5625x speedup), 0.4 (2.0833x speedup), and 0.5 (2.5x speedup).
<p align="center">
<img src="https://github.com/user-attachments/assets/d2c87b99-e4ac-4407-809a-caf9750f41ef" width="150" style="margin: 5px;">
<img src="https://github.com/user-attachments/assets/411ff763-9c31-438d-8a9b-3ec5c88f6c27" width="150" style="margin: 5px;">
<img src="https://github.com/user-attachments/assets/e57dfb60-a07f-4e17-837e-e46a69d8b9c0" width="150" style="margin: 5px;">
<img src="https://github.com/user-attachments/assets/6e3184fe-e31a-452c-a447-48d4b74fcc10" width="150" style="margin: 5px;">
<img src="https://github.com/user-attachments/assets/d6a52c4c-bd22-45c0-9f40-00a2daa85fc8" width="150" style="margin: 5px;">
</p>
## 📈 Inference Latency Comparisons on a Single RTX 4090 (50 Steps)
| Lumina-Image-2.0 | TeaCache (0.2) | TeaCache (0.3) | TeaCache (0.4) | TeaCache (0.5) |
|:-------------------------:|:---------------------------:|:--------------------:|:---------------------:|:---------------------:|
| ~25 s | ~20 s | ~16 s | ~12 s | ~10 s |
## Installation
```shell
pip install --upgrade diffusers[torch] transformers protobuf tokenizers sentencepiece
pip install flash-attn --no-build-isolation
```
## Usage
You can modify `rel_l1_thresh` in line 154 of `teacache_lumina2.py` to obtain your desired trade-off between latency and visual quality. For single-GPU inference, use the following command:
```bash
python teacache_lumina2.py
```
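The script enables TeaCache by setting class-level attributes on the patched transformer before calling the pipeline. For quick reference, the relevant settings (condensed from `teacache_lumina2.py`; only `rel_l1_thresh` normally needs changing) look like this:
```python
pipeline.transformer.__class__.enable_teacache = True
pipeline.transformer.__class__.cnt = 0                          # per-run step counter
pipeline.transformer.__class__.num_steps = num_inference_steps  # must match the sampler's step count
pipeline.transformer.__class__.rel_l1_thresh = 0.3              # higher = more skipped steps = faster, lower quality
pipeline.transformer.__class__.cache = {}                       # residual cache, keyed by sequence length
pipeline.transformer.__class__.uncond_seq_len = None
```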
## Citation
If you find TeaCache useful in your research or applications, please consider giving us a star 🌟 and citing it with the following BibTeX entry.
```bibtex
@article{liu2024timestep,
title={Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model},
author={Liu, Feng and Zhang, Shiwei and Wang, Xiaofeng and Wei, Yujie and Qiu, Haonan and Zhao, Yuzhong and Zhang, Yingya and Ye, Qixiang and Wan, Fang},
journal={arXiv preprint arXiv:2411.19108},
year={2024}
}
```
## Acknowledgements
We would like to thank the contributors to the [Lumina-Image-2.0](https://github.com/Alpha-VLLM/Lumina-Image-2.0) and [Diffusers](https://github.com/huggingface/diffusers) repositories.

teacache_lumina2.py
@@ -0,0 +1,166 @@
import torch
import numpy as np
from typing import Any, Dict, Optional, Union
from diffusers import Lumina2Transformer2DModel, Lumina2Pipeline
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
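
# TeaCache-enabled replacement for Lumina2Transformer2DModel.forward.
# It mirrors the stock diffusers forward, but tracks how much the first block's
# modulated input changes between steps and, when the accumulated change stays
# below rel_l1_thresh, reuses the cached transformer residual instead of
# running all layers.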
def teacache_forward_working(
    self,
    hidden_states: torch.Tensor,
    timestep: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    encoder_attention_mask: torch.Tensor,
    attention_kwargs: Optional[Dict[str, Any]] = None,
    return_dict: bool = True,
) -> Union[torch.Tensor, Transformer2DModelOutput]:
    if attention_kwargs is not None:
        attention_kwargs = attention_kwargs.copy()
        lora_scale = attention_kwargs.pop("scale", 1.0)
    else:
        lora_scale = 1.0
    if USE_PEFT_BACKEND:
        scale_lora_layers(self, lora_scale)
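
    # Standard Lumina2 preprocessing: time/caption embedding, rotary embeddings,
    # patch embedding, and refinement of text and image tokens.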
    batch_size, _, height, width = hidden_states.shape

    temb, encoder_hidden_states_processed = self.time_caption_embed(hidden_states, timestep, encoder_hidden_states)
    (image_patch_embeddings, context_rotary_emb, noise_rotary_emb, joint_rotary_emb,
     encoder_seq_lengths, seq_lengths) = self.rope_embedder(hidden_states, encoder_attention_mask)
    image_patch_embeddings = self.x_embedder(image_patch_embeddings)
    for layer in self.context_refiner:
        encoder_hidden_states_processed = layer(encoder_hidden_states_processed, encoder_attention_mask, context_rotary_emb)
    for layer in self.noise_refiner:
        image_patch_embeddings = layer(image_patch_embeddings, None, noise_rotary_emb, temb)

    max_seq_len = max(seq_lengths)
    input_to_main_loop = image_patch_embeddings.new_zeros(batch_size, max_seq_len, self.config.hidden_size)
    for i, (enc_len, seq_len_val) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
        input_to_main_loop[i, :enc_len] = encoder_hidden_states_processed[i, :enc_len]
        input_to_main_loop[i, enc_len:seq_len_val] = image_patch_embeddings[i]

    use_mask = len(set(seq_lengths)) > 1
    attention_mask_for_main_loop_arg = None
    if use_mask:
        mask = input_to_main_loop.new_zeros(batch_size, max_seq_len, dtype=torch.bool)
        for i, (enc_len, seq_len_val) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
            mask[i, :seq_len_val] = True
        attention_mask_for_main_loop_arg = mask

    should_calc = True
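
    # TeaCache: decide whether to run the full transformer stack for this call
    # or reuse the residual cached for this sequence length.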
    if self.enable_teacache:
        cache_key = max_seq_len
        if cache_key not in self.cache:
            self.cache[cache_key] = {
                "accumulated_rel_l1_distance": 0.0,
                "previous_modulated_input": None,
                "previous_residual": None,
            }
        current_cache = self.cache[cache_key]
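
        # The first block's modulated input acts as the change indicator between steps.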
        modulated_inp, _, _, _ = self.layers[0].norm1(input_to_main_loop.clone(), temb.clone())

        if self.cnt == 0 or self.cnt == self.num_steps - 1:
            # Always compute the first and last steps.
            should_calc = True
            current_cache["accumulated_rel_l1_distance"] = 0.0
        else:
            if current_cache["previous_modulated_input"] is not None:
                coefficients = [393.76566581, -603.50993606, 209.10239044, -23.00726601, 0.86377344]  # taken from teacache_lumina_next.py
                rescale_func = np.poly1d(coefficients)
                prev_mod_input = current_cache["previous_modulated_input"]
                prev_mean = prev_mod_input.abs().mean()
                if prev_mean.item() > 1e-9:
                    rel_l1_change = ((modulated_inp - prev_mod_input).abs().mean() / prev_mean).cpu().item()
                else:
                    rel_l1_change = 0.0 if modulated_inp.abs().mean().item() < 1e-9 else float('inf')
                current_cache["accumulated_rel_l1_distance"] += rescale_func(rel_l1_change)
                if current_cache["accumulated_rel_l1_distance"] < self.rel_l1_thresh:
                    should_calc = False
                else:
                    should_calc = True
                    current_cache["accumulated_rel_l1_distance"] = 0.0
            else:
                should_calc = True
                current_cache["accumulated_rel_l1_distance"] = 0.0

        current_cache["previous_modulated_input"] = modulated_inp.clone()

        # uncond_seq_len is pre-initialised to None on the class below, so test for None
        # (a hasattr check would never trigger there) and only advance the step counter
        # on the pass whose sequence length differs from the first one seen.
        if self.uncond_seq_len is None:
            self.uncond_seq_len = cache_key
        if cache_key != self.uncond_seq_len:
            self.cnt += 1
            if self.cnt >= self.num_steps:
                self.cnt = 0
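
    # Skip path: add the cached residual; compute path: run all layers and refresh the cache.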
    if self.enable_teacache and not should_calc:
        processed_hidden_states = input_to_main_loop + self.cache[max_seq_len]["previous_residual"]
    else:
        ori_input = input_to_main_loop.clone()
        current_processing_states = input_to_main_loop
        for layer in self.layers:
            current_processing_states = layer(current_processing_states, attention_mask_for_main_loop_arg, joint_rotary_emb, temb)
        if self.enable_teacache:
            self.cache[max_seq_len]["previous_residual"] = current_processing_states - ori_input
        processed_hidden_states = current_processing_states
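
    # Final norm, then unpatchify each sample back to (out_channels, height, width).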
    output_after_norm = self.norm_out(processed_hidden_states, temb)
    p = self.config.patch_size
    final_output_list = []
    for i, (enc_len, seq_len_val) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
        image_part = output_after_norm[i][enc_len:seq_len_val]
        h_p, w_p = height // p, width // p
        reconstructed_image = image_part.view(h_p, w_p, p, p, self.out_channels) \
            .permute(4, 0, 2, 1, 3) \
            .flatten(3, 4) \
            .flatten(1, 2)
        final_output_list.append(reconstructed_image)
    final_output_tensor = torch.stack(final_output_list, dim=0)

    if USE_PEFT_BACKEND:
        unscale_lora_layers(self, lora_scale)

    if not return_dict:
        return (final_output_tensor,)
    return Transformer2DModelOutput(sample=final_output_tensor)
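
# Monkey-patch the diffusers transformer so every pipeline call uses the TeaCache-aware forward.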
Lumina2Transformer2DModel.forward = teacache_forward_working
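
# Path to a single-file Lumina2 transformer checkpoint; replace with your own model file.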
ckpt_path = "NietaAniLumina_Alpha_full_round5_ep5_s182000.pth"
transformer = Lumina2Transformer2DModel.from_single_file(
    ckpt_path, torch_dtype=torch.bfloat16
)
pipeline = Lumina2Pipeline.from_pretrained(
    "Alpha-VLLM/Lumina-Image-2.0",
    transformer=transformer,
    torch_dtype=torch.bfloat16
).to("cuda")
num_inference_steps = 30
seed = 1024
prompt = "a cat holding a sign that says hello"
output_filename = "teacache_lumina2_output.png"
# TeaCache
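# Class-level TeaCache state: step counter, total steps, skip threshold, and per-sequence-length cache.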
pipeline.transformer.__class__.enable_teacache = True
pipeline.transformer.__class__.cnt = 0
pipeline.transformer.__class__.num_steps = num_inference_steps
pipeline.transformer.__class__.rel_l1_thresh = 0.3 # taken from teacache_lumina_next.py, 0.2 for 1.5x speedup, 0.3 for 1.9x speedup, 0.4 for 2.4x speedup, 0.5 for 2.8x speedup
pipeline.transformer.__class__.cache = {}
pipeline.transformer.__class__.uncond_seq_len = None
pipeline.enable_model_cpu_offload()
image = pipeline(
    prompt=prompt,
    num_inference_steps=num_inference_steps,
    generator=torch.Generator("cuda").manual_seed(seed)
).images[0]
image.save(output_filename)