# TeaCache/TeaCache4Lumina2/teacache_lumina2_v1.py

import torch
import torch.nn as nn
import numpy as np
from typing import Any, Dict, Optional, Tuple, Union, List
from diffusers import Lumina2Transformer2DModel, Lumina2Pipeline
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
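

# TeaCache-enabled forward for Lumina2Transformer2DModel (patched onto the class below).
# Per denoising step it measures how much the first block's modulated input changed,
# accumulates a polynomially rescaled relative L1 distance, and, while that stays under
# rel_l1_thresh, skips the transformer layers and reuses the cached output residual.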
def teacache_forward_working(
    self,
    hidden_states: torch.Tensor,
    timestep: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    encoder_attention_mask: torch.Tensor,
    attention_kwargs: Optional[Dict[str, Any]] = None,
    return_dict: bool = True,
) -> Union[torch.Tensor, Transformer2DModelOutput]:
    if attention_kwargs is not None:
        attention_kwargs = attention_kwargs.copy()
        lora_scale = attention_kwargs.pop("scale", 1.0)
    else:
        lora_scale = 1.0

    if USE_PEFT_BACKEND:
        scale_lora_layers(self, lora_scale)
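
    # Embedding stage: time/caption embedding, RoPE frequencies, patch embedding, and the
    # context/noise refiner blocks, before packing text and image tokens into one padded
    # sequence per sample (text tokens first, image tokens after).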
    batch_size, _, height, width = hidden_states.shape

    temb, encoder_hidden_states_processed = self.time_caption_embed(hidden_states, timestep, encoder_hidden_states)
    (image_patch_embeddings, context_rotary_emb, noise_rotary_emb, joint_rotary_emb,
     encoder_seq_lengths, seq_lengths) = self.rope_embedder(hidden_states, encoder_attention_mask)

    image_patch_embeddings = self.x_embedder(image_patch_embeddings)

    for layer in self.context_refiner:
        encoder_hidden_states_processed = layer(encoder_hidden_states_processed, encoder_attention_mask, context_rotary_emb)
    for layer in self.noise_refiner:
        image_patch_embeddings = layer(image_patch_embeddings, None, noise_rotary_emb, temb)

    max_seq_len = max(seq_lengths)
    input_to_main_loop = image_patch_embeddings.new_zeros(batch_size, max_seq_len, self.config.hidden_size)
    for i, (enc_len, seq_len_val) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
        input_to_main_loop[i, :enc_len] = encoder_hidden_states_processed[i, :enc_len]
        input_to_main_loop[i, enc_len:seq_len_val] = image_patch_embeddings[i]

    use_mask = len(set(seq_lengths)) > 1
    attention_mask_for_main_loop_arg = None
    if use_mask:
        mask = input_to_main_loop.new_zeros(batch_size, max_seq_len, dtype=torch.bool)
        for i, (enc_len, seq_len_val) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
            mask[i, :seq_len_val] = True
        attention_mask_for_main_loop_arg = mask
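
    # TeaCache decision: the cache is keyed by padded sequence length (which typically
    # separates the conditional and unconditional passes). Accumulate the rescaled relative
    # L1 change of the first block's modulated input and only run the full layer stack when
    # it exceeds rel_l1_thresh, or on the first and last steps.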
    should_calc = True
    if self.enable_teacache:
        cache_key = max_seq_len
        if cache_key not in self.cache:
            self.cache[cache_key] = {
                "accumulated_rel_l1_distance": 0.0,
                "previous_modulated_input": None,
                "previous_residual": None,
            }
        current_cache = self.cache[cache_key]

        modulated_inp, _, _, _ = self.layers[0].norm1(input_to_main_loop, temb)

        if self.cnt == 0 or self.cnt == self.num_steps - 1:
            should_calc = True
            current_cache["accumulated_rel_l1_distance"] = 0.0
        else:
            if current_cache["previous_modulated_input"] is not None:
                coefficients = [393.76566581, -603.50993606, 209.10239044, -23.00726601, 0.86377344]  # taken from teacache_lumina_next.py
                rescale_func = np.poly1d(coefficients)
                prev_mod_input = current_cache["previous_modulated_input"]
                prev_mean = prev_mod_input.abs().mean()
                if prev_mean.item() > 1e-9:
                    rel_l1_change = ((modulated_inp - prev_mod_input).abs().mean() / prev_mean).cpu().item()
                else:
                    rel_l1_change = 0.0 if modulated_inp.abs().mean().item() < 1e-9 else float('inf')
                current_cache["accumulated_rel_l1_distance"] += rescale_func(rel_l1_change)
                if current_cache["accumulated_rel_l1_distance"] < self.rel_l1_thresh:
                    should_calc = False
                else:
                    should_calc = True
                    current_cache["accumulated_rel_l1_distance"] = 0.0
            else:
                should_calc = True
                current_cache["accumulated_rel_l1_distance"] = 0.0

        current_cache["previous_modulated_input"] = modulated_inp.clone()

        if self.uncond_seq_len is None:
            self.uncond_seq_len = cache_key
        if cache_key != self.uncond_seq_len:
            self.cnt += 1
            if self.cnt >= self.num_steps:
                self.cnt = 0
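
    # Apply the decision: either add the cached residual to the packed input (cheap path)
    # or run the full transformer stack and cache the new residual for later steps.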
    if self.enable_teacache and not should_calc:
        if max_seq_len in self.cache and "previous_residual" in self.cache[max_seq_len] and self.cache[max_seq_len]["previous_residual"] is not None:
            # Fast path: reuse the cached residual instead of running the transformer layers.
            processed_hidden_states = input_to_main_loop + self.cache[max_seq_len]["previous_residual"]
        else:
            # Defensive fallback: no residual has been cached yet, so compute the layers after all.
            should_calc = True
            current_processing_states = input_to_main_loop
            for layer in self.layers:
                current_processing_states = layer(current_processing_states, attention_mask_for_main_loop_arg, joint_rotary_emb, temb)
            processed_hidden_states = current_processing_states

    if not (self.enable_teacache and not should_calc):
        current_processing_states = input_to_main_loop
        for layer in self.layers:
            current_processing_states = layer(current_processing_states, attention_mask_for_main_loop_arg, joint_rotary_emb, temb)
        if self.enable_teacache:
            if max_seq_len in self.cache:
                self.cache[max_seq_len]["previous_residual"] = current_processing_states - input_to_main_loop
            else:
                logger.warning(f"TeaCache: Cache key {max_seq_len} not found when trying to save residual.")
        processed_hidden_states = current_processing_states
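
    # Unpatchify: fold each sample's image tokens back into a channels-first latent grid
    # and stack the per-sample results into the final output tensor.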
    output_after_norm = self.norm_out(processed_hidden_states, temb)

    p = self.config.patch_size
    final_output_list = []
    for i, (enc_len, seq_len_val) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
        image_part = output_after_norm[i][enc_len:seq_len_val]
        h_p, w_p = height // p, width // p
        reconstructed_image = (
            image_part.view(h_p, w_p, p, p, self.out_channels)
            .permute(4, 0, 2, 1, 3)
            .flatten(3, 4)
            .flatten(1, 2)
        )
        final_output_list.append(reconstructed_image)
    final_output_tensor = torch.stack(final_output_list, dim=0)

    if USE_PEFT_BACKEND:
        unscale_lora_layers(self, lora_scale)

    if not return_dict:
        return (final_output_tensor,)
    return Transformer2DModelOutput(sample=final_output_tensor)
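

# Patch the TeaCache-aware forward onto the transformer class so the pipeline uses it.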
Lumina2Transformer2DModel.forward = teacache_forward_working

# Load a local single-file Lumina2 checkpoint into the transformer, then build the pipeline
# around it from the base Lumina-Image-2.0 weights for the remaining components.
ckpt_path = "NietaAniLumina_Alpha_full_round5_ep5_s182000.pth"
transformer = Lumina2Transformer2DModel.from_single_file(
    ckpt_path, torch_dtype=torch.bfloat16
)
pipeline = Lumina2Pipeline.from_pretrained(
    "Alpha-VLLM/Lumina-Image-2.0",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
).to("cuda")

num_inference_steps = 30
seed = 1024
prompt = "a cat holding a sign that says hello"
output_filename = "teacache_lumina2_output.png"

# TeaCache configuration: stored as class attributes so the patched forward can read them.
pipeline.transformer.__class__.enable_teacache = True
pipeline.transformer.__class__.cnt = 0
pipeline.transformer.__class__.num_steps = num_inference_steps
pipeline.transformer.__class__.rel_l1_thresh = 0.3  # larger values skip more steps (faster, may trade off fidelity)
pipeline.transformer.__class__.cache = {}
pipeline.transformer.__class__.uncond_seq_len = None
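
# Offload pipeline components to CPU between uses to keep peak VRAM low, then generate.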
pipeline.enable_model_cpu_offload()

image = pipeline(
    prompt=prompt,
    num_inference_steps=num_inference_steps,
    generator=torch.Generator("cuda").manual_seed(seed),
).images[0]
image.save(output_filename)
print(f"Image saved to {output_filename}")