mirror of https://git.datalinker.icu/ali-vilab/TeaCache, synced 2025-12-08 20:34:24 +08:00
Optimize redundant code
This commit is contained in: parent ca1c215ee7, commit e945259c7d
@@ -23,7 +23,7 @@ def teacache_forward_working(
         lora_scale = attention_kwargs.pop("scale", 1.0)
     else:
         lora_scale = 1.0
     if USE_PEFT_BACKEND:
         scale_lora_layers(self, lora_scale)

     batch_size, _, height, width = hidden_states.shape
@@ -31,9 +31,9 @@ def teacache_forward_working(
     (image_patch_embeddings, context_rotary_emb, noise_rotary_emb, joint_rotary_emb,
      encoder_seq_lengths, seq_lengths) = self.rope_embedder(hidden_states, encoder_attention_mask)
     image_patch_embeddings = self.x_embedder(image_patch_embeddings)
     for layer in self.context_refiner:
         encoder_hidden_states_processed = layer(encoder_hidden_states_processed, encoder_attention_mask, context_rotary_emb)
     for layer in self.noise_refiner:
         image_patch_embeddings = layer(image_patch_embeddings, None, noise_rotary_emb, temb)

     max_seq_len = max(seq_lengths)
@@ -41,12 +41,12 @@ def teacache_forward_working(
     for i, (enc_len, seq_len_val) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
         input_to_main_loop[i, :enc_len] = encoder_hidden_states_processed[i, :enc_len]
         input_to_main_loop[i, enc_len:seq_len_val] = image_patch_embeddings[i]

     use_mask = len(set(seq_lengths)) > 1
     attention_mask_for_main_loop_arg = None
     if use_mask:
         mask = input_to_main_loop.new_zeros(batch_size, max_seq_len, dtype=torch.bool)
         for i, (enc_len, seq_len_val) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
             mask[i, :seq_len_val] = True
         attention_mask_for_main_loop_arg = mask

@@ -59,9 +59,9 @@ def teacache_forward_working(
             "previous_modulated_input": None,
             "previous_residual": None,
         }

     current_cache = self.cache[cache_key]
-    modulated_inp, _, _, _ = self.layers[0].norm1(input_to_main_loop.clone(), temb.clone())
+    modulated_inp, _, _, _ = self.layers[0].norm1(input_to_main_loop, temb)

     if self.cnt == 0 or self.cnt == self.num_steps - 1:
         should_calc = True
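The only functional edit in this hunk drops the defensive `.clone()` calls before `norm1`. A minimal sketch of why the output is unchanged, assuming the norm layer only reads its inputs and returns new tensors; `norm1_like` below is an illustrative stand-in, not the real layer:

import torch

def norm1_like(hidden, emb):
    # stand-in for self.layers[0].norm1: reads its arguments, never writes them in place
    return hidden * (1.0 + emb.mean(dim=-1, keepdim=True).unsqueeze(1))

x, temb = torch.randn(2, 16, 8), torch.randn(2, 8)
# identical result, minus two tensor copies per forward pass
assert torch.equal(norm1_like(x.clone(), temb.clone()), norm1_like(x, temb))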
@@ -72,12 +72,12 @@ def teacache_forward_working(
             rescale_func = np.poly1d(coefficients)
             prev_mod_input = current_cache["previous_modulated_input"]
             prev_mean = prev_mod_input.abs().mean()

             if prev_mean.item() > 1e-9:
                 rel_l1_change = ((modulated_inp - prev_mod_input).abs().mean() / prev_mean).cpu().item()
             else:
                 rel_l1_change = 0.0 if modulated_inp.abs().mean().item() < 1e-9 else float('inf')

             current_cache["accumulated_rel_l1_distance"] += rescale_func(rel_l1_change)

             if current_cache["accumulated_rel_l1_distance"] < self.rel_l1_thresh:
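For reference, a minimal sketch of the accumulate-and-threshold decision these context lines implement; the polynomial coefficients below are placeholders, not the model-specific values defined in the repo:

import numpy as np

rescale_func = np.poly1d([1.0, 0.0])   # placeholder coefficients: rescale(x) = x
accumulated, rel_l1_thresh = 0.0, 0.3

for rel_l1_change in (0.05, 0.08, 0.12, 0.30):   # per-step relative L1 change of the modulated input
    accumulated += rescale_func(rel_l1_change)
    should_calc = accumulated >= rel_l1_thresh   # recompute only once the accumulated drift crosses the threshold
    if should_calc:
        accumulated = 0.0                        # reset after a full pass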
@@ -85,29 +85,41 @@ def teacache_forward_working(
             else:
                 should_calc = True
                 current_cache["accumulated_rel_l1_distance"] = 0.0
         else:
             should_calc = True
             current_cache["accumulated_rel_l1_distance"] = 0.0

     current_cache["previous_modulated_input"] = modulated_inp.clone()

-    if not hasattr(self, 'uncond_seq_len'):
+    if self.uncond_seq_len is None:
         self.uncond_seq_len = cache_key
     if cache_key != self.uncond_seq_len:
         self.cnt += 1
         if self.cnt >= self.num_steps:
             self.cnt = 0

     if self.enable_teacache and not should_calc:
-        processed_hidden_states = input_to_main_loop + self.cache[max_seq_len]["previous_residual"]
-    else:
-        ori_input = input_to_main_loop.clone()
+        if max_seq_len in self.cache and "previous_residual" in self.cache[max_seq_len] and self.cache[max_seq_len]["previous_residual"] is not None:
+            processed_hidden_states = input_to_main_loop + self.cache[max_seq_len]["previous_residual"]
+        else:
+            should_calc = True
+            current_processing_states = input_to_main_loop
+            for layer in self.layers:
+                current_processing_states = layer(current_processing_states, attention_mask_for_main_loop_arg, joint_rotary_emb, temb)
+            processed_hidden_states = current_processing_states
+
+
+    if not (self.enable_teacache and not should_calc) :
         current_processing_states = input_to_main_loop
         for layer in self.layers:
             current_processing_states = layer(current_processing_states, attention_mask_for_main_loop_arg, joint_rotary_emb, temb)

         if self.enable_teacache:
-            self.cache[max_seq_len]["previous_residual"] = current_processing_states - ori_input
+            if max_seq_len in self.cache:
+                self.cache[max_seq_len]["previous_residual"] = current_processing_states - input_to_main_loop
+            else:
+                logger.warning(f"TeaCache: Cache key {max_seq_len} not found when trying to save residual.")
+
         processed_hidden_states = current_processing_states

     output_after_norm = self.norm_out(processed_hidden_states, temb)
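The reworked branch above only reuses a cached residual when the cache entry actually exists, and otherwise falls back to running the full layer stack. A self-contained sketch of that guarded reuse pattern, with an illustrative layer stack and cache dict standing in for the real module attributes:

import torch
import torch.nn as nn

layers = nn.Sequential(nn.Linear(8, 8), nn.GELU(), nn.Linear(8, 8))  # stand-in for self.layers
cache = {}                                                           # stand-in for self.cache

def forward_block(x, key, skip):
    entry = cache.get(key, {})
    if skip and entry.get("previous_residual") is not None:
        # cheap step: re-apply the residual remembered from the last full pass
        return x + entry["previous_residual"]
    # full step: run the layers, then store output - input for later reuse
    out = layers(x)
    cache[key] = {"previous_residual": (out - x).detach()}
    return out

x = torch.randn(2, 8)
full = forward_block(x, key=8, skip=False)   # computes and caches the residual
reused = forward_block(x, key=8, skip=True)  # skips the layers, reuses the cached residual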
@@ -123,10 +135,13 @@ def teacache_forward_working(
         final_output_list.append(reconstructed_image)

     final_output_tensor = torch.stack(final_output_list, dim=0)

     if USE_PEFT_BACKEND:
         unscale_lora_layers(self, lora_scale)

+    if not return_dict:
+        return (final_output_tensor,)
+
     return Transformer2DModelOutput(sample=final_output_tensor)


@@ -137,8 +152,8 @@ transformer = Lumina2Transformer2DModel.from_single_file(
     ckpt_path, torch_dtype=torch.bfloat16
 )
 pipeline = Lumina2Pipeline.from_pretrained(
     "Alpha-VLLM/Lumina-Image-2.0",
     transformer=transformer,
     torch_dtype=torch.bfloat16
 ).to("cuda")

@@ -151,9 +166,9 @@ output_filename = f"teacache_lumina2_output.png"
 pipeline.transformer.__class__.enable_teacache = True
 pipeline.transformer.__class__.cnt = 0
 pipeline.transformer.__class__.num_steps = num_inference_steps
-pipeline.transformer.__class__.rel_l1_thresh = 0.3 # taken from teacache_lumina_next.py, 0.2 for 1.5x speedup, 0.3 for 1.9x speedup, 0.4 for 2.4x speedup, 0.5 for 2.8x speedup
+pipeline.transformer.__class__.rel_l1_thresh = 0.3
 pipeline.transformer.__class__.cache = {}
 pipeline.transformer.__class__.uncond_seq_len = None


 pipeline.enable_model_cpu_offload()
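The script drives TeaCache through class-level attributes, so counters and cached residuals persist across calls. A small helper, sketched under the assumption that you reuse the same attributes shown above, to reset that state before each new prompt:

def reset_teacache(transformer_cls, num_inference_steps, rel_l1_thresh=0.3):
    # per the comment removed above: 0.2 ~ 1.5x, 0.3 ~ 1.9x, 0.4 ~ 2.4x, 0.5 ~ 2.8x speedup
    transformer_cls.enable_teacache = True
    transformer_cls.cnt = 0
    transformer_cls.num_steps = num_inference_steps
    transformer_cls.rel_l1_thresh = rel_l1_thresh
    transformer_cls.cache = {}
    transformer_cls.uncond_seq_len = None

# usage: reset_teacache(pipeline.transformer.__class__, num_inference_steps)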
@@ -163,4 +178,5 @@ image = pipeline(
     generator=torch.Generator("cuda").manual_seed(seed)
 ).images[0]

 image.save(output_filename)
+print(f"Image saved to {output_filename}")