From 4588c2d970f9eaba5f9c4850ab490f1302e62604 Mon Sep 17 00:00:00 2001 From: spawner Date: Sat, 7 Jun 2025 16:47:53 +0800 Subject: [PATCH 1/9] Update teacache_lumina2.py --- TeaCache4Lumina2/teacache_lumina2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/TeaCache4Lumina2/teacache_lumina2.py b/TeaCache4Lumina2/teacache_lumina2.py index 0145d24..021daea 100644 --- a/TeaCache4Lumina2/teacache_lumina2.py +++ b/TeaCache4Lumina2/teacache_lumina2.py @@ -68,7 +68,7 @@ def teacache_forward_working( current_cache["accumulated_rel_l1_distance"] = 0.0 else: if current_cache["previous_modulated_input"] is not None: - coefficients = [393.76566581, -603.50993606, 209.10239044, -23.00726601, 0.86377344] # taken from teacache_lumina_next.py + coefficients = [225.7042019806413, -608.8453716535591, 304.1869942338369, 124.21267720116742, -1.4089066892956552] rescale_func = np.poly1d(coefficients) prev_mod_input = current_cache["previous_modulated_input"] prev_mean = prev_mod_input.abs().mean() From 845823eed4db1e22d6f6481b925b1226e2d883c0 Mon Sep 17 00:00:00 2001 From: spawner Date: Sun, 8 Jun 2025 15:14:21 +0800 Subject: [PATCH 2/9] Update README.md --- TeaCache4Lumina2/README.md | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/TeaCache4Lumina2/README.md b/TeaCache4Lumina2/README.md index 70a7671..4fcac36 100644 --- a/TeaCache4Lumina2/README.md +++ b/TeaCache4Lumina2/README.md @@ -1,22 +1,22 @@ # TeaCache4Lumina2 -[TeaCache](https://github.com/LiewFeng/TeaCache) can speedup [Lumina-Image-2.0](https://github.com/Alpha-VLLM/Lumina-Image-2.0) without much visual quality degradation, in a training-free manner. The following image shows the results generated by TeaCache-Lumina-Image-2.0 with various rel_l1_thresh values: 0 (original), 0.2 (1.25x speedup), 0.3 (1.5625x speedup), 0.4 (2.0833x speedup), 0.5 (2.5x speedup). +[TeaCache](https://github.com/LiewFeng/TeaCache) can speedup [Lumina-Image-2.0](https://github.com/Alpha-VLLM/Lumina-Image-2.0) without much visual quality degradation, in a training-free manner. The following image shows the experimental results of Lumina-Image-2.0 and TeaCache with different versions: Lumina-Image-2.0 (~25 s), TeaCache (0.2) (~16.7 s, 1.5x speedup), TeaCache (0.3) (~15.6 s, 1.6x speedup), TeaCache (0.5) (~13.79 s, 1.8x speedup), TeaCache (1.1) (~11.9 s, 2.1x speedup).

-    (removed: five result images — original, rel_l1_thresh 0.2, 0.3, 0.4, 0.5)
+    (added: five result images — Lumina-Image-2.0 baseline, TeaCache 0.2, 0.3, 0.5, 1.1)

## 📈 Inference Latency Comparisons on a single 4090 (step 50) -| Lumina-Image-2.0 | TeaCache (0.2) | TeaCache (0.3) | TeaCache (0.4) | TeaCache (0.5) | +| Lumina-Image-2.0 | TeaCache (0.2) | TeaCache (0.3) | TeaCache (0.5) | TeaCache (1.1) | |:-------------------------:|:---------------------------:|:--------------------:|:---------------------:|:---------------------:| -| ~25 s | ~20 s | ~16 s | ~12 s | ~10 s | +| ~25 s | ~16.7 s | ~15.6 s | ~13.79 s | ~11.9 s | ## Installation From c9e2d6454c56120c6f4854110ac7db2b493fa942 Mon Sep 17 00:00:00 2001 From: spawner Date: Sun, 8 Jun 2025 15:45:52 +0800 Subject: [PATCH 3/9] Update README.md --- TeaCache4Lumina2/README.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/TeaCache4Lumina2/README.md b/TeaCache4Lumina2/README.md index 4fcac36..b4657fc 100644 --- a/TeaCache4Lumina2/README.md +++ b/TeaCache4Lumina2/README.md @@ -3,6 +3,12 @@ [TeaCache](https://github.com/LiewFeng/TeaCache) can speedup [Lumina-Image-2.0](https://github.com/Alpha-VLLM/Lumina-Image-2.0) without much visual quality degradation, in a training-free manner. The following image shows the experimental results of Lumina-Image-2.0 and TeaCache with different versions: Lumina-Image-2.0 (~25 s), TeaCache (0.2) (~16.7 s, 1.5x speedup), TeaCache (0.3) (~15.6 s, 1.6x speedup), TeaCache (0.5) (~13.79 s, 1.8x speedup), TeaCache (1.1) (~11.9 s, 2.1x speedup). +The original coefficients +`[393.76566581,−603.50993606,209.10239044,−23.00726601,0.86377344]` + exhibit poor quality at low L1 values but perform better with higher L1 settings, though at a slower speed. The new coefficients +`[225.7042019806413,−608.8453716535591,304.1869942338369,124.21267720116742,−1.4089066892956552]` +, however, offer faster computation and better quality at low L1 levels but incur significant feature loss at high L1 values. +
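The polynomials defined by these coefficients are evaluated with `np.poly1d` on the relative L1 change between successive modulated inputs; the rescaled values are accumulated and compared against `rel_l1_thresh` to decide whether a step can reuse the cached residual (see the bundled script later in this series). A minimal sketch of how the two sets behave — the sample `rel_l1` values are illustrative, not measured:

```python
import numpy as np

# Coefficient sets quoted above (highest-order term first, as np.poly1d expects).
V1_COEFFS = [393.76566581, -603.50993606, 209.10239044, -23.00726601, 0.86377344]
V2_COEFFS = [225.7042019806413, -608.8453716535591, 304.1869942338369,
             124.21267720116742, -1.4089066892956552]

rescale_v1 = np.poly1d(V1_COEFFS)
rescale_v2 = np.poly1d(V2_COEFFS)

# Illustrative relative-L1 changes between successive modulated inputs.
for rel_l1 in (0.01, 0.02, 0.05, 0.10):
    print(f"rel_l1={rel_l1:.2f}  v1 -> {rescale_v1(rel_l1):+.3f}  v2 -> {rescale_v2(rel_l1):+.3f}")

# During sampling the rescaled value of each step is added to a running total.
# While the total stays below rel_l1_thresh the transformer blocks are skipped
# and the cached residual is reused; the total resets whenever a full pass runs.
rel_l1_thresh = 0.3
accumulated = rescale_v2(0.01)
print("skip this step:", accumulated < rel_l1_thresh)
```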

From f7d676521a80be2ccf0de0fc1d2e364b61b343cb Mon Sep 17 00:00:00 2001
From: spawner
Date: Sun, 8 Jun 2025 17:47:30 +0800
Subject: [PATCH 4/9] Update README.md

---
 TeaCache4Lumina2/README.md | 22 ++++++++++++++++++----
 1 file changed, 18 insertions(+), 4 deletions(-)

diff --git a/TeaCache4Lumina2/README.md b/TeaCache4Lumina2/README.md
index b4657fc..2724d45 100644
--- a/TeaCache4Lumina2/README.md
+++ b/TeaCache4Lumina2/README.md
@@ -1,14 +1,24 @@
 # TeaCache4Lumina2

-[TeaCache](https://github.com/LiewFeng/TeaCache) can speedup [Lumina-Image-2.0](https://github.com/Alpha-VLLM/Lumina-Image-2.0) without much visual quality degradation, in a training-free manner. The following image shows the experimental results of Lumina-Image-2.0 and TeaCache with different versions: Lumina-Image-2.0 (~25 s), TeaCache (0.2) (~16.7 s, 1.5x speedup), TeaCache (0.3) (~15.6 s, 1.6x speedup), TeaCache (0.5) (~13.79 s, 1.8x speedup), TeaCache (1.1) (~11.9 s, 2.1x speedup).
+[TeaCache](https://github.com/LiewFeng/TeaCache) can speed up [Lumina-Image-2.0](https://github.com/Alpha-VLLM/Lumina-Image-2.0) without much visual quality degradation, in a training-free manner. The images below compare the two coefficient versions: v1 at rel_l1_thresh 0 (original), 0.2 (1.25x speedup), 0.3 (1.5625x speedup), 0.4 (2.0833x speedup) and 0.5 (2.5x speedup), and v2 against the Lumina-Image-2.0 baseline (~25 s): TeaCache (0.2) (~16.7 s, 1.5x speedup), TeaCache (0.3) (~15.6 s, 1.6x speedup), TeaCache (0.5) (~13.79 s, 1.8x speedup) and TeaCache (1.1) (~11.9 s, 2.1x speedup).

-The original coefficients
+The v1 coefficients
 `[393.76566581,−603.50993606,209.10239044,−23.00726601,0.86377344]`
- exhibit poor quality at low L1 values but perform better with higher L1 settings, though at a slower speed. The new coefficients
+ exhibit poor quality at low L1 values but perform better with higher L1 settings, though at a slower speed. The v2 coefficients
 `[225.7042019806413,−608.8453716535591,304.1869942338369,124.21267720116742,−1.4089066892956552]`
 , however, offer faster computation and better quality at low L1 levels but incur significant feature loss at high L1 values.

+## v1
+

+    (added: five v1 result images — original, rel_l1_thresh 0.2, 0.3, 0.4, 0.5)

+ +## v2

@@ -18,8 +28,12 @@ The original coefficients

## 📈 Inference Latency Comparisons on a single 4090 (step 50) +## v1 +| Lumina-Image-2.0 | TeaCache (0.2) | TeaCache (0.3) | TeaCache (0.4) | TeaCache (0.5) | +|:-------------------------:|:---------------------------:|:--------------------:|:---------------------:|:---------------------:| +| ~25 s | ~20 s | ~16 s | ~12 s | ~10 s | - +## v2 | Lumina-Image-2.0 | TeaCache (0.2) | TeaCache (0.3) | TeaCache (0.5) | TeaCache (1.1) | |:-------------------------:|:---------------------------:|:--------------------:|:---------------------:|:---------------------:| | ~25 s | ~16.7 s | ~15.6 s | ~13.79 s | ~11.9 s | From 5670dc8e9973edf8d838ff24a741a64f8caa0572 Mon Sep 17 00:00:00 2001 From: spawner Date: Sun, 8 Jun 2025 17:47:55 +0800 Subject: [PATCH 5/9] Rename teacache_lumina2.py to teacache_lumina2_v2.py --- TeaCache4Lumina2/{teacache_lumina2.py => teacache_lumina2_v2.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename TeaCache4Lumina2/{teacache_lumina2.py => teacache_lumina2_v2.py} (100%) diff --git a/TeaCache4Lumina2/teacache_lumina2.py b/TeaCache4Lumina2/teacache_lumina2_v2.py similarity index 100% rename from TeaCache4Lumina2/teacache_lumina2.py rename to TeaCache4Lumina2/teacache_lumina2_v2.py From 6a470cfade8025c9b794f2a3bbf447b1e838f840 Mon Sep 17 00:00:00 2001 From: spawner Date: Sun, 8 Jun 2025 17:49:04 +0800 Subject: [PATCH 6/9] Create teacache_lumina2_v1.py --- TeaCache4Lumina2/teacache_lumina2_v1.py | 182 ++++++++++++++++++++++++ 1 file changed, 182 insertions(+) create mode 100644 TeaCache4Lumina2/teacache_lumina2_v1.py diff --git a/TeaCache4Lumina2/teacache_lumina2_v1.py b/TeaCache4Lumina2/teacache_lumina2_v1.py new file mode 100644 index 0000000..0145d24 --- /dev/null +++ b/TeaCache4Lumina2/teacache_lumina2_v1.py @@ -0,0 +1,182 @@ +import torch +import torch.nn as nn +import numpy as np +from typing import Any, Dict, Optional, Tuple, Union, List + +from diffusers import Lumina2Transformer2DModel, Lumina2Pipeline +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +def teacache_forward_working( + self, + hidden_states: torch.Tensor, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor, + encoder_attention_mask: torch.Tensor, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, +) -> Union[torch.Tensor, Transformer2DModelOutput]: + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + if USE_PEFT_BACKEND: + scale_lora_layers(self, lora_scale) + + batch_size, _, height, width = hidden_states.shape + temb, encoder_hidden_states_processed = self.time_caption_embed(hidden_states, timestep, encoder_hidden_states) + (image_patch_embeddings, context_rotary_emb, noise_rotary_emb, joint_rotary_emb, + encoder_seq_lengths, seq_lengths) = self.rope_embedder(hidden_states, encoder_attention_mask) + image_patch_embeddings = self.x_embedder(image_patch_embeddings) + for layer in self.context_refiner: + encoder_hidden_states_processed = layer(encoder_hidden_states_processed, encoder_attention_mask, context_rotary_emb) + for layer in self.noise_refiner: + image_patch_embeddings = layer(image_patch_embeddings, None, noise_rotary_emb, temb) + + max_seq_len = max(seq_lengths) + input_to_main_loop = image_patch_embeddings.new_zeros(batch_size, max_seq_len, 
self.config.hidden_size) + for i, (enc_len, seq_len_val) in enumerate(zip(encoder_seq_lengths, seq_lengths)): + input_to_main_loop[i, :enc_len] = encoder_hidden_states_processed[i, :enc_len] + input_to_main_loop[i, enc_len:seq_len_val] = image_patch_embeddings[i] + + use_mask = len(set(seq_lengths)) > 1 + attention_mask_for_main_loop_arg = None + if use_mask: + mask = input_to_main_loop.new_zeros(batch_size, max_seq_len, dtype=torch.bool) + for i, (enc_len, seq_len_val) in enumerate(zip(encoder_seq_lengths, seq_lengths)): + mask[i, :seq_len_val] = True + attention_mask_for_main_loop_arg = mask + + should_calc = True + if self.enable_teacache: + cache_key = max_seq_len + if cache_key not in self.cache: + self.cache[cache_key] = { + "accumulated_rel_l1_distance": 0.0, + "previous_modulated_input": None, + "previous_residual": None, + } + + current_cache = self.cache[cache_key] + modulated_inp, _, _, _ = self.layers[0].norm1(input_to_main_loop, temb) + + if self.cnt == 0 or self.cnt == self.num_steps - 1: + should_calc = True + current_cache["accumulated_rel_l1_distance"] = 0.0 + else: + if current_cache["previous_modulated_input"] is not None: + coefficients = [393.76566581, -603.50993606, 209.10239044, -23.00726601, 0.86377344] # taken from teacache_lumina_next.py + rescale_func = np.poly1d(coefficients) + prev_mod_input = current_cache["previous_modulated_input"] + prev_mean = prev_mod_input.abs().mean() + + if prev_mean.item() > 1e-9: + rel_l1_change = ((modulated_inp - prev_mod_input).abs().mean() / prev_mean).cpu().item() + else: + rel_l1_change = 0.0 if modulated_inp.abs().mean().item() < 1e-9 else float('inf') + + current_cache["accumulated_rel_l1_distance"] += rescale_func(rel_l1_change) + + if current_cache["accumulated_rel_l1_distance"] < self.rel_l1_thresh: + should_calc = False + else: + should_calc = True + current_cache["accumulated_rel_l1_distance"] = 0.0 + else: + should_calc = True + current_cache["accumulated_rel_l1_distance"] = 0.0 + + current_cache["previous_modulated_input"] = modulated_inp.clone() + + if self.uncond_seq_len is None: + self.uncond_seq_len = cache_key + if cache_key != self.uncond_seq_len: + self.cnt += 1 + if self.cnt >= self.num_steps: + self.cnt = 0 + + if self.enable_teacache and not should_calc: + if max_seq_len in self.cache and "previous_residual" in self.cache[max_seq_len] and self.cache[max_seq_len]["previous_residual"] is not None: + processed_hidden_states = input_to_main_loop + self.cache[max_seq_len]["previous_residual"] + else: + should_calc = True + current_processing_states = input_to_main_loop + for layer in self.layers: + current_processing_states = layer(current_processing_states, attention_mask_for_main_loop_arg, joint_rotary_emb, temb) + processed_hidden_states = current_processing_states + + + if not (self.enable_teacache and not should_calc) : + current_processing_states = input_to_main_loop + for layer in self.layers: + current_processing_states = layer(current_processing_states, attention_mask_for_main_loop_arg, joint_rotary_emb, temb) + + if self.enable_teacache: + if max_seq_len in self.cache: + self.cache[max_seq_len]["previous_residual"] = current_processing_states - input_to_main_loop + else: + logger.warning(f"TeaCache: Cache key {max_seq_len} not found when trying to save residual.") + + processed_hidden_states = current_processing_states + + output_after_norm = self.norm_out(processed_hidden_states, temb) + p = self.config.patch_size + final_output_list = [] + for i, (enc_len, seq_len_val) in 
enumerate(zip(encoder_seq_lengths, seq_lengths)): + image_part = output_after_norm[i][enc_len:seq_len_val] + h_p, w_p = height // p, width // p + reconstructed_image = image_part.view(h_p, w_p, p, p, self.out_channels) \ + .permute(4, 0, 2, 1, 3) \ + .flatten(3, 4) \ + .flatten(1, 2) + final_output_list.append(reconstructed_image) + + final_output_tensor = torch.stack(final_output_list, dim=0) + + if USE_PEFT_BACKEND: + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (final_output_tensor,) + + return Transformer2DModelOutput(sample=final_output_tensor) + + +Lumina2Transformer2DModel.forward = teacache_forward_working + +ckpt_path = "NietaAniLumina_Alpha_full_round5_ep5_s182000.pth" +transformer = Lumina2Transformer2DModel.from_single_file( + ckpt_path, torch_dtype=torch.bfloat16 +) +pipeline = Lumina2Pipeline.from_pretrained( + "Alpha-VLLM/Lumina-Image-2.0", + transformer=transformer, + torch_dtype=torch.bfloat16 +).to("cuda") + +num_inference_steps = 30 +seed = 1024 +prompt = "a cat holding a sign that says hello" +output_filename = f"teacache_lumina2_output.png" + +# TeaCache +pipeline.transformer.__class__.enable_teacache = True +pipeline.transformer.__class__.cnt = 0 +pipeline.transformer.__class__.num_steps = num_inference_steps +pipeline.transformer.__class__.rel_l1_thresh = 0.3 +pipeline.transformer.__class__.cache = {} +pipeline.transformer.__class__.uncond_seq_len = None + + +pipeline.enable_model_cpu_offload() +image = pipeline( + prompt=prompt, + num_inference_steps=num_inference_steps, + generator=torch.Generator("cuda").manual_seed(seed) +).images[0] + +image.save(output_filename) +print(f"Image saved to {output_filename}") From 0a9b0358ca6333da607806540fb7473601e476c4 Mon Sep 17 00:00:00 2001 From: spawner Date: Sun, 8 Jun 2025 20:29:47 +0800 Subject: [PATCH 7/9] Delete TeaCache4Lumina2/teacache_lumina2_v2.py --- TeaCache4Lumina2/teacache_lumina2_v2.py | 182 ------------------------ 1 file changed, 182 deletions(-) delete mode 100644 TeaCache4Lumina2/teacache_lumina2_v2.py diff --git a/TeaCache4Lumina2/teacache_lumina2_v2.py b/TeaCache4Lumina2/teacache_lumina2_v2.py deleted file mode 100644 index 021daea..0000000 --- a/TeaCache4Lumina2/teacache_lumina2_v2.py +++ /dev/null @@ -1,182 +0,0 @@ -import torch -import torch.nn as nn -import numpy as np -from typing import Any, Dict, Optional, Tuple, Union, List - -from diffusers import Lumina2Transformer2DModel, Lumina2Pipeline -from diffusers.models.modeling_outputs import Transformer2DModelOutput -from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - -def teacache_forward_working( - self, - hidden_states: torch.Tensor, - timestep: torch.Tensor, - encoder_hidden_states: torch.Tensor, - encoder_attention_mask: torch.Tensor, - attention_kwargs: Optional[Dict[str, Any]] = None, - return_dict: bool = True, -) -> Union[torch.Tensor, Transformer2DModelOutput]: - if attention_kwargs is not None: - attention_kwargs = attention_kwargs.copy() - lora_scale = attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - if USE_PEFT_BACKEND: - scale_lora_layers(self, lora_scale) - - batch_size, _, height, width = hidden_states.shape - temb, encoder_hidden_states_processed = self.time_caption_embed(hidden_states, timestep, encoder_hidden_states) - (image_patch_embeddings, context_rotary_emb, noise_rotary_emb, joint_rotary_emb, - encoder_seq_lengths, seq_lengths) = self.rope_embedder(hidden_states, 
encoder_attention_mask) - image_patch_embeddings = self.x_embedder(image_patch_embeddings) - for layer in self.context_refiner: - encoder_hidden_states_processed = layer(encoder_hidden_states_processed, encoder_attention_mask, context_rotary_emb) - for layer in self.noise_refiner: - image_patch_embeddings = layer(image_patch_embeddings, None, noise_rotary_emb, temb) - - max_seq_len = max(seq_lengths) - input_to_main_loop = image_patch_embeddings.new_zeros(batch_size, max_seq_len, self.config.hidden_size) - for i, (enc_len, seq_len_val) in enumerate(zip(encoder_seq_lengths, seq_lengths)): - input_to_main_loop[i, :enc_len] = encoder_hidden_states_processed[i, :enc_len] - input_to_main_loop[i, enc_len:seq_len_val] = image_patch_embeddings[i] - - use_mask = len(set(seq_lengths)) > 1 - attention_mask_for_main_loop_arg = None - if use_mask: - mask = input_to_main_loop.new_zeros(batch_size, max_seq_len, dtype=torch.bool) - for i, (enc_len, seq_len_val) in enumerate(zip(encoder_seq_lengths, seq_lengths)): - mask[i, :seq_len_val] = True - attention_mask_for_main_loop_arg = mask - - should_calc = True - if self.enable_teacache: - cache_key = max_seq_len - if cache_key not in self.cache: - self.cache[cache_key] = { - "accumulated_rel_l1_distance": 0.0, - "previous_modulated_input": None, - "previous_residual": None, - } - - current_cache = self.cache[cache_key] - modulated_inp, _, _, _ = self.layers[0].norm1(input_to_main_loop, temb) - - if self.cnt == 0 or self.cnt == self.num_steps - 1: - should_calc = True - current_cache["accumulated_rel_l1_distance"] = 0.0 - else: - if current_cache["previous_modulated_input"] is not None: - coefficients = [225.7042019806413, -608.8453716535591, 304.1869942338369, 124.21267720116742, -1.4089066892956552] - rescale_func = np.poly1d(coefficients) - prev_mod_input = current_cache["previous_modulated_input"] - prev_mean = prev_mod_input.abs().mean() - - if prev_mean.item() > 1e-9: - rel_l1_change = ((modulated_inp - prev_mod_input).abs().mean() / prev_mean).cpu().item() - else: - rel_l1_change = 0.0 if modulated_inp.abs().mean().item() < 1e-9 else float('inf') - - current_cache["accumulated_rel_l1_distance"] += rescale_func(rel_l1_change) - - if current_cache["accumulated_rel_l1_distance"] < self.rel_l1_thresh: - should_calc = False - else: - should_calc = True - current_cache["accumulated_rel_l1_distance"] = 0.0 - else: - should_calc = True - current_cache["accumulated_rel_l1_distance"] = 0.0 - - current_cache["previous_modulated_input"] = modulated_inp.clone() - - if self.uncond_seq_len is None: - self.uncond_seq_len = cache_key - if cache_key != self.uncond_seq_len: - self.cnt += 1 - if self.cnt >= self.num_steps: - self.cnt = 0 - - if self.enable_teacache and not should_calc: - if max_seq_len in self.cache and "previous_residual" in self.cache[max_seq_len] and self.cache[max_seq_len]["previous_residual"] is not None: - processed_hidden_states = input_to_main_loop + self.cache[max_seq_len]["previous_residual"] - else: - should_calc = True - current_processing_states = input_to_main_loop - for layer in self.layers: - current_processing_states = layer(current_processing_states, attention_mask_for_main_loop_arg, joint_rotary_emb, temb) - processed_hidden_states = current_processing_states - - - if not (self.enable_teacache and not should_calc) : - current_processing_states = input_to_main_loop - for layer in self.layers: - current_processing_states = layer(current_processing_states, attention_mask_for_main_loop_arg, joint_rotary_emb, temb) - - if 
self.enable_teacache: - if max_seq_len in self.cache: - self.cache[max_seq_len]["previous_residual"] = current_processing_states - input_to_main_loop - else: - logger.warning(f"TeaCache: Cache key {max_seq_len} not found when trying to save residual.") - - processed_hidden_states = current_processing_states - - output_after_norm = self.norm_out(processed_hidden_states, temb) - p = self.config.patch_size - final_output_list = [] - for i, (enc_len, seq_len_val) in enumerate(zip(encoder_seq_lengths, seq_lengths)): - image_part = output_after_norm[i][enc_len:seq_len_val] - h_p, w_p = height // p, width // p - reconstructed_image = image_part.view(h_p, w_p, p, p, self.out_channels) \ - .permute(4, 0, 2, 1, 3) \ - .flatten(3, 4) \ - .flatten(1, 2) - final_output_list.append(reconstructed_image) - - final_output_tensor = torch.stack(final_output_list, dim=0) - - if USE_PEFT_BACKEND: - unscale_lora_layers(self, lora_scale) - - if not return_dict: - return (final_output_tensor,) - - return Transformer2DModelOutput(sample=final_output_tensor) - - -Lumina2Transformer2DModel.forward = teacache_forward_working - -ckpt_path = "NietaAniLumina_Alpha_full_round5_ep5_s182000.pth" -transformer = Lumina2Transformer2DModel.from_single_file( - ckpt_path, torch_dtype=torch.bfloat16 -) -pipeline = Lumina2Pipeline.from_pretrained( - "Alpha-VLLM/Lumina-Image-2.0", - transformer=transformer, - torch_dtype=torch.bfloat16 -).to("cuda") - -num_inference_steps = 30 -seed = 1024 -prompt = "a cat holding a sign that says hello" -output_filename = f"teacache_lumina2_output.png" - -# TeaCache -pipeline.transformer.__class__.enable_teacache = True -pipeline.transformer.__class__.cnt = 0 -pipeline.transformer.__class__.num_steps = num_inference_steps -pipeline.transformer.__class__.rel_l1_thresh = 0.3 -pipeline.transformer.__class__.cache = {} -pipeline.transformer.__class__.uncond_seq_len = None - - -pipeline.enable_model_cpu_offload() -image = pipeline( - prompt=prompt, - num_inference_steps=num_inference_steps, - generator=torch.Generator("cuda").manual_seed(seed) -).images[0] - -image.save(output_filename) -print(f"Image saved to {output_filename}") From ff6a08389665e2656e514b9cc65b9eaddf2d63d0 Mon Sep 17 00:00:00 2001 From: spawner Date: Sun, 8 Jun 2025 20:39:07 +0800 Subject: [PATCH 8/9] Update and rename teacache_lumina2_v1.py to teacache_lumina2.py --- .../{teacache_lumina2_v1.py => teacache_lumina2.py} | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) rename TeaCache4Lumina2/{teacache_lumina2_v1.py => teacache_lumina2.py} (97%) diff --git a/TeaCache4Lumina2/teacache_lumina2_v1.py b/TeaCache4Lumina2/teacache_lumina2.py similarity index 97% rename from TeaCache4Lumina2/teacache_lumina2_v1.py rename to TeaCache4Lumina2/teacache_lumina2.py index 0145d24..d9f181c 100644 --- a/TeaCache4Lumina2/teacache_lumina2_v1.py +++ b/TeaCache4Lumina2/teacache_lumina2.py @@ -68,7 +68,8 @@ def teacache_forward_working( current_cache["accumulated_rel_l1_distance"] = 0.0 else: if current_cache["previous_modulated_input"] is not None: - coefficients = [393.76566581, -603.50993606, 209.10239044, -23.00726601, 0.86377344] # taken from teacache_lumina_next.py +# v1 coefficients,you can switch it to [225.7042019806413, -608.8453716535591, 304.1869942338369, 124.21267720116742, -1.4089066892956552] as v2 + coefficients = [393.76566581, -603.50993606, 209.10239044, -23.00726601, 0.86377344] rescale_func = np.poly1d(coefficients) prev_mod_input = current_cache["previous_modulated_input"] prev_mean = prev_mod_input.abs().mean() From 
78d2f837d5dc5d3040de51ce554200c5d7a0d036 Mon Sep 17 00:00:00 2001
From: spawner
Date: Sun, 8 Jun 2025 20:40:57 +0800
Subject: [PATCH 9/9] Update README.md

---
 TeaCache4Lumina2/README.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/TeaCache4Lumina2/README.md b/TeaCache4Lumina2/README.md
index 2724d45..6c0fcdb 100644
--- a/TeaCache4Lumina2/README.md
+++ b/TeaCache4Lumina2/README.md
@@ -9,6 +9,8 @@ The v1 coefficients
 `[225.7042019806413,−608.8453716535591,304.1869942338369,124.21267720116742,−1.4089066892956552]`
 , however, offer faster computation and better quality at low L1 levels but incur significant feature loss at high L1 values.

+You can switch between the v1 and v2 coefficients by changing the list on line 72 of `teacache_lumina2.py`.
+
 ## v1
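If editing the hardcoded list is inconvenient, the same switch can be made with a small lookup near the top of the script — a minimal sketch (the `TEACACHE_VERSION` flag and `COEFFICIENTS` dict are illustrative names, not part of the released script):

```python
# Hypothetical convenience wrapper: choose the coefficient set once instead of
# editing the hardcoded list inside teacache_forward_working (line 72).
TEACACHE_VERSION = "v1"  # or "v2"

COEFFICIENTS = {
    "v1": [393.76566581, -603.50993606, 209.10239044, -23.00726601, 0.86377344],
    "v2": [225.7042019806413, -608.8453716535591, 304.1869942338369,
           124.21267720116742, -1.4089066892956552],
}

coefficients = COEFFICIENTS[TEACACHE_VERSION]
```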