Remove redundant tensor clones and harden TeaCache cache handling

This commit is contained in:
spawner 2025-06-04 21:15:31 +08:00 committed by GitHub
parent ca1c215ee7
commit e945259c7d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -61,7 +61,7 @@ def teacache_forward_working(
} }
current_cache = self.cache[cache_key] current_cache = self.cache[cache_key]
modulated_inp, _, _, _ = self.layers[0].norm1(input_to_main_loop.clone(), temb.clone()) modulated_inp, _, _, _ = self.layers[0].norm1(input_to_main_loop, temb)
if self.cnt == 0 or self.cnt == self.num_steps - 1: if self.cnt == 0 or self.cnt == self.num_steps - 1:
should_calc = True should_calc = True
@ -91,23 +91,35 @@ def teacache_forward_working(
current_cache["previous_modulated_input"] = modulated_inp.clone() current_cache["previous_modulated_input"] = modulated_inp.clone()
if not hasattr(self, 'uncond_seq_len'): if self.uncond_seq_len is None:
self.uncond_seq_len = cache_key self.uncond_seq_len = cache_key
if cache_key != self.uncond_seq_len: if cache_key != self.uncond_seq_len:
self.cnt += 1 self.cnt += 1
if self.cnt >= self.num_steps: if self.cnt >= self.num_steps:
self.cnt = 0 self.cnt = 0
if self.enable_teacache and not should_calc: if self.enable_teacache and not should_calc:
processed_hidden_states = input_to_main_loop + self.cache[max_seq_len]["previous_residual"] if max_seq_len in self.cache and "previous_residual" in self.cache[max_seq_len] and self.cache[max_seq_len]["previous_residual"] is not None:
else: processed_hidden_states = input_to_main_loop + self.cache[max_seq_len]["previous_residual"]
ori_input = input_to_main_loop.clone() else:
should_calc = True
current_processing_states = input_to_main_loop
for layer in self.layers:
current_processing_states = layer(current_processing_states, attention_mask_for_main_loop_arg, joint_rotary_emb, temb)
processed_hidden_states = current_processing_states
if not (self.enable_teacache and not should_calc) :
current_processing_states = input_to_main_loop current_processing_states = input_to_main_loop
for layer in self.layers: for layer in self.layers:
current_processing_states = layer(current_processing_states, attention_mask_for_main_loop_arg, joint_rotary_emb, temb) current_processing_states = layer(current_processing_states, attention_mask_for_main_loop_arg, joint_rotary_emb, temb)
if self.enable_teacache: if self.enable_teacache:
self.cache[max_seq_len]["previous_residual"] = current_processing_states - ori_input if max_seq_len in self.cache:
self.cache[max_seq_len]["previous_residual"] = current_processing_states - input_to_main_loop
else:
logger.warning(f"TeaCache: Cache key {max_seq_len} not found when trying to save residual.")
processed_hidden_states = current_processing_states processed_hidden_states = current_processing_states
output_after_norm = self.norm_out(processed_hidden_states, temb) output_after_norm = self.norm_out(processed_hidden_states, temb)
@ -127,6 +139,9 @@ def teacache_forward_working(
if USE_PEFT_BACKEND: if USE_PEFT_BACKEND:
unscale_lora_layers(self, lora_scale) unscale_lora_layers(self, lora_scale)
if not return_dict:
return (final_output_tensor,)
return Transformer2DModelOutput(sample=final_output_tensor) return Transformer2DModelOutput(sample=final_output_tensor)
@ -151,7 +166,7 @@ output_filename = f"teacache_lumina2_output.png"
pipeline.transformer.__class__.enable_teacache = True pipeline.transformer.__class__.enable_teacache = True
pipeline.transformer.__class__.cnt = 0 pipeline.transformer.__class__.cnt = 0
pipeline.transformer.__class__.num_steps = num_inference_steps pipeline.transformer.__class__.num_steps = num_inference_steps
pipeline.transformer.__class__.rel_l1_thresh = 0.3 # taken from teacache_lumina_next.py, 0.2 for 1.5x speedup, 0.3 for 1.9x speedup, 0.4 for 2.4x speedup, 0.5 for 2.8x speedup pipeline.transformer.__class__.rel_l1_thresh = 0.3
pipeline.transformer.__class__.cache = {} pipeline.transformer.__class__.cache = {}
pipeline.transformer.__class__.uncond_seq_len = None pipeline.transformer.__class__.uncond_seq_len = None
@ -164,3 +179,4 @@ image = pipeline(
).images[0] ).images[0]
image.save(output_filename) image.save(output_filename)
print(f"Image saved to {output_filename}")