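# TeaCache for Lumina-Next-SFT (diffusers): monkey-patch
# LuminaNextDiT2DModel.forward so that, at denoising steps where the
# timestep-modulated input has barely drifted since the previous step, the
# transformer stack is skipped and a cached residual is reused instead.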
from typing import Any, Dict, Optional, Tuple, Union

import numpy as np
import torch
from diffusers import LuminaText2ImgPipeline
from diffusers.models import LuminaNextDiT2DModel
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def teacache_forward(
    self,
    hidden_states: torch.Tensor,
    timestep: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    encoder_mask: torch.Tensor,
    image_rotary_emb: torch.Tensor,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    return_dict: bool = True,
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
    """
    TeaCache-enabled forward pass of LuminaNextDiT.

    Parameters:
        hidden_states (torch.Tensor): Input tensor of shape (N, C, H, W).
        timestep (torch.Tensor): Tensor of diffusion timesteps of shape (N,).
        encoder_hidden_states (torch.Tensor): Tensor of caption features of shape (N, D).
        encoder_mask (torch.Tensor): Tensor of caption masks of shape (N, L).
    """
    hidden_states, mask, img_size, image_rotary_emb = self.patch_embedder(hidden_states, image_rotary_emb)
    image_rotary_emb = image_rotary_emb.to(hidden_states.device)

    temb = self.time_caption_embed(timestep, encoder_hidden_states, encoder_mask)

    encoder_mask = encoder_mask.bool()
    if self.enable_teacache:
        # Probe how much the timestep-modulated input has changed since the
        # previous step; this drift decides whether the transformer stack
        # must be recomputed or can be skipped.
        inp = hidden_states.clone()
        temb_ = temb.clone()
        modulated_inp, gate_msa, scale_mlp, gate_mlp = self.layers[0].norm1(inp, temb_)
        if self.cnt == 0 or self.cnt == self.num_steps - 1:
            # Always compute the first and last steps in full.
            should_calc = True
            self.accumulated_rel_l1_distance = 0
        else:
            # Model-specific polynomial that rescales the raw relative L1
            # distance between consecutive modulated inputs into an estimate
            # of the change in the model output.
            coefficients = [393.76566581, -603.50993606, 209.10239044, -23.00726601, 0.86377344]
            rescale_func = np.poly1d(coefficients)
            self.accumulated_rel_l1_distance += rescale_func(
                ((modulated_inp - self.previous_modulated_input).abs().mean()
                 / self.previous_modulated_input.abs().mean()).cpu().item()
            )
            if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
                # Accumulated change is still small: reuse the cached residual.
                should_calc = False
            else:
                should_calc = True
                self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = modulated_inp
        self.cnt += 1
        if self.cnt == self.num_steps:
            self.cnt = 0

    if self.enable_teacache:
        if not should_calc:
            # Skip the transformer stack; reapply the residual cached from
            # the last fully computed step.
            hidden_states += self.previous_residual
        else:
            ori_hidden_states = hidden_states.clone()
            for layer in self.layers:
                hidden_states = layer(
                    hidden_states,
                    mask,
                    image_rotary_emb,
                    encoder_hidden_states,
                    encoder_mask,
                    temb=temb,
                    cross_attention_kwargs=cross_attention_kwargs,
                )
            # Cache the residual so that later skipped steps can reuse it.
            self.previous_residual = hidden_states - ori_hidden_states
    else:
        for layer in self.layers:
            hidden_states = layer(
                hidden_states,
                mask,
                image_rotary_emb,
                encoder_hidden_states,
                encoder_mask,
                temb=temb,
                cross_attention_kwargs=cross_attention_kwargs,
            )

    hidden_states = self.norm_out(hidden_states, temb)

    # Unpatchify: fold the token sequence back into an (N, C, H, W) image.
    height_tokens = width_tokens = self.patch_size
    height, width = img_size[0]
    batch_size = hidden_states.size(0)
    sequence_length = (height // height_tokens) * (width // width_tokens)
    hidden_states = hidden_states[:, :sequence_length].view(
        batch_size, height // height_tokens, width // width_tokens, height_tokens, width_tokens, self.out_channels
    )
    output = hidden_states.permute(0, 5, 1, 3, 2, 4).flatten(4, 5).flatten(2, 3)

    if not return_dict:
        return (output,)

    return Transformer2DModelOutput(sample=output)
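

# Quick illustration (an addition, not in the original script): evaluate the
# rescaling polynomial used inside teacache_forward at a sample raw relative
# L1 distance; the rescaled value is what accumulates against rel_l1_thresh.
_rescale = np.poly1d([393.76566581, -603.50993606, 209.10239044, -23.00726601, 0.86377344])
print(_rescale(0.05))  # ~0.163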

# Swap the stock forward for the TeaCache-enabled one.
LuminaNextDiT2DModel.forward = teacache_forward
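

# Helper (an addition, not in the original script): reset the TeaCache
# counters and caches before reusing the pipeline with a different step
# count, so a stale residual from an earlier run is never reapplied.
def reset_teacache_state(transformer, num_steps):
    cls = transformer.__class__
    cls.cnt = 0
    cls.num_steps = num_steps
    cls.accumulated_rel_l1_distance = 0
    cls.previous_modulated_input = None
    cls.previous_residual = None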


num_inference_steps = 30
seed = 1024
prompt = "Upper body of a young woman in a Victorian-era outfit with brass goggles and leather straps. "
pipeline = LuminaText2ImgPipeline.from_pretrained(
    "Alpha-VLLM/Lumina-Next-SFT-diffusers", torch_dtype=torch.bfloat16
).to("cuda")

# TeaCache configuration. These are set on the transformer class, as in the
# upstream script; class-level attributes are shared across all instances.
pipeline.transformer.__class__.enable_teacache = True
pipeline.transformer.__class__.cnt = 0
pipeline.transformer.__class__.num_steps = num_inference_steps
# Higher thresholds skip more steps: 0.2 for ~1.5x speedup, 0.3 for ~1.9x,
# 0.4 for ~2.4x, 0.5 for ~2.8x.
pipeline.transformer.__class__.rel_l1_thresh = 0.3
pipeline.transformer.__class__.accumulated_rel_l1_distance = 0
pipeline.transformer.__class__.previous_modulated_input = None
pipeline.transformer.__class__.previous_residual = None

image = pipeline(
    prompt=prompt,
    num_inference_steps=num_inference_steps,
    generator=torch.Generator("cpu").manual_seed(seed),
).images[0]
image.save("teacache_lumina_{}.png".format(prompt))
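

# Rough wall-clock check (an illustrative sketch, not part of the original
# script): run the same prompt with TeaCache on and then off and compare
# timings; the threshold table above advertises ~1.9x at rel_l1_thresh = 0.3.
import time

def timed_run(label):
    start = time.time()
    pipeline(
        prompt=prompt,
        num_inference_steps=num_inference_steps,
        generator=torch.Generator("cpu").manual_seed(seed),
    )
    print("{}: {:.1f}s".format(label, time.time() - start))

reset_teacache_state(pipeline.transformer, num_inference_steps)
pipeline.transformer.__class__.enable_teacache = True
timed_run("teacache")
pipeline.transformer.__class__.enable_teacache = False
timed_run("baseline")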