Mirror of https://git.datalinker.icu/ali-vilab/TeaCache (synced 2026-05-02 10:11:19 +08:00)
Optimize redundant code
commit e945259c7d
parent ca1c215ee7
@@ -61,7 +61,7 @@ def teacache_forward_working(
                 }
 
             current_cache = self.cache[cache_key]
-            modulated_inp, _, _, _ = self.layers[0].norm1(input_to_main_loop.clone(), temb.clone())
+            modulated_inp, _, _, _ = self.layers[0].norm1(input_to_main_loop, temb)
 
             if self.cnt == 0 or self.cnt == self.num_steps - 1:
                 should_calc = True
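Note on the surrounding logic: the hunk above keeps the rule that the first and last denoising steps always recompute. For readers new to TeaCache, here is a minimal sketch of how `should_calc` is typically decided on the remaining steps, in the spirit of the upstream teacache_lumina_next.py example; the `state` dict, the `decide_should_calc` helper, and the default `coefficients` are illustrative assumptions, not code from this commit.

import numpy as np

def decide_should_calc(state, modulated_inp, rel_l1_thresh, coefficients=(1.0, 0.0)):
    # Always recompute on the first and last step.
    if state["cnt"] == 0 or state["cnt"] == state["num_steps"] - 1:
        state["accumulated_rel_l1"] = 0.0
        return True
    prev = state["previous_modulated_input"]
    # Relative L1 change of the modulated input since the previous step.
    rel_l1 = ((modulated_inp - prev).abs().mean() / prev.abs().mean()).item()
    # Rescale with a fitted, model-specific polynomial before accumulating.
    state["accumulated_rel_l1"] += float(np.polyval(coefficients, rel_l1))
    if state["accumulated_rel_l1"] < rel_l1_thresh:
        return False  # reuse the cached residual for this step
    state["accumulated_rel_l1"] = 0.0
    return True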
@@ -91,23 +91,35 @@ def teacache_forward_working(
 
             current_cache["previous_modulated_input"] = modulated_inp.clone()
 
-            if not hasattr(self, 'uncond_seq_len'):
+            if self.uncond_seq_len is None:
                 self.uncond_seq_len = cache_key
             if cache_key != self.uncond_seq_len:
                 self.cnt += 1
                 if self.cnt >= self.num_steps:
                     self.cnt = 0
 
         if self.enable_teacache and not should_calc:
-            processed_hidden_states = input_to_main_loop + self.cache[max_seq_len]["previous_residual"]
-        else:
-            ori_input = input_to_main_loop.clone()
+            if max_seq_len in self.cache and "previous_residual" in self.cache[max_seq_len] and self.cache[max_seq_len]["previous_residual"] is not None:
+                processed_hidden_states = input_to_main_loop + self.cache[max_seq_len]["previous_residual"]
+            else:
+                should_calc = True
+                current_processing_states = input_to_main_loop
+                for layer in self.layers:
+                    current_processing_states = layer(current_processing_states, attention_mask_for_main_loop_arg, joint_rotary_emb, temb)
+                processed_hidden_states = current_processing_states
+
+
+        if not (self.enable_teacache and not should_calc) :
             current_processing_states = input_to_main_loop
             for layer in self.layers:
                 current_processing_states = layer(current_processing_states, attention_mask_for_main_loop_arg, joint_rotary_emb, temb)
+
             if self.enable_teacache:
-                self.cache[max_seq_len]["previous_residual"] = current_processing_states - ori_input
+                if max_seq_len in self.cache:
+                    self.cache[max_seq_len]["previous_residual"] = current_processing_states - input_to_main_loop
+                else:
+                    logger.warning(f"TeaCache: Cache key {max_seq_len} not found when trying to save residual.")
 
             processed_hidden_states = current_processing_states
 
         output_after_norm = self.norm_out(processed_hidden_states, temb)
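The restructured branch above never indexes the cache without checking it first, and it drops the extra `ori_input` clone by computing the residual directly against `input_to_main_loop`. A self-contained sketch of the same guard-then-fallback pattern; `cache`, `key`, and `compute_fn` are placeholder names, not identifiers from this file.

def cached_forward(x, cache, key, should_calc, compute_fn):
    entry = cache.get(key)
    if not should_calc and entry is not None and entry.get("previous_residual") is not None:
        # Cheap path: reuse the residual captured on the last full pass.
        return x + entry["previous_residual"]
    # Full path: run the layers and record the residual relative to the input.
    out = compute_fn(x)
    if entry is not None:
        entry["previous_residual"] = out - x
    return out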
@@ -127,6 +139,9 @@ def teacache_forward_working(
     if USE_PEFT_BACKEND:
         unscale_lora_layers(self, lora_scale)
 
+    if not return_dict:
+        return (final_output_tensor,)
+
     return Transformer2DModelOutput(sample=final_output_tensor)
 
 
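The three added lines adopt the usual diffusers convention for model outputs: a plain tuple when `return_dict=False`, otherwise a `Transformer2DModelOutput`. A minimal sketch of how a caller consumes either form; the import path may differ across diffusers versions, and `transformer_kwargs` is a placeholder for the real forward arguments.

from diffusers.models.modeling_outputs import Transformer2DModelOutput  # dataclass with a .sample field

# tuple form
sample = transformer(**transformer_kwargs, return_dict=False)[0]
# dataclass form (the default)
sample = transformer(**transformer_kwargs).sample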
@@ -151,7 +166,7 @@ output_filename = f"teacache_lumina2_output.png"
 pipeline.transformer.__class__.enable_teacache = True
 pipeline.transformer.__class__.cnt = 0
 pipeline.transformer.__class__.num_steps = num_inference_steps
-pipeline.transformer.__class__.rel_l1_thresh = 0.3 # taken from teacache_lumina_next.py, 0.2 for 1.5x speedup, 0.3 for 1.9x speedup, 0.4 for 2.4x speedup, 0.5 for 2.8x speedup
+pipeline.transformer.__class__.rel_l1_thresh = 0.3
 pipeline.transformer.__class__.cache = {}
 pipeline.transformer.__class__.uncond_seq_len = None
 
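The trailing comment removed above carried the threshold-to-speedup guidance, so it is worth keeping nearby. A hedged configuration sketch: `pipeline` and `num_inference_steps` come from the example script in this diff, and the speedup figures are the ones quoted from teacache_lumina_next.py.

# Attributes are set on the class so the patched forward installed on
# pipeline.transformer.__class__ can read them via self.<attr>.
pipeline.transformer.__class__.enable_teacache = True
pipeline.transformer.__class__.cnt = 0
pipeline.transformer.__class__.num_steps = num_inference_steps
# rel_l1_thresh: 0.2 ~1.5x speedup, 0.3 ~1.9x, 0.4 ~2.4x, 0.5 ~2.8x
pipeline.transformer.__class__.rel_l1_thresh = 0.3
pipeline.transformer.__class__.cache = {}
pipeline.transformer.__class__.uncond_seq_len = None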
@@ -164,3 +179,4 @@ image = pipeline(
 ).images[0]
 
 image.save(output_filename)
+print(f"Image saved to {output_filename}")