mirror of
https://git.datalinker.icu/ali-vilab/TeaCache
synced 2026-05-15 17:11:27 +08:00
support cogvideox
This commit is contained in:
parent
329faf52f3
commit
071075bd34
@ -66,28 +66,19 @@ def teacache_forward(
|
|||||||
hidden_states = hidden_states[:, text_seq_length:]
|
hidden_states = hidden_states[:, text_seq_length:]
|
||||||
|
|
||||||
if self.enable_teacache:
|
if self.enable_teacache:
|
||||||
inp = hidden_states.clone()
|
|
||||||
encoder_hidden_states_ = encoder_hidden_states.clone()
|
|
||||||
emb_ = emb.clone()
|
|
||||||
_, modulated_inp, _, _ = self.transformer_blocks[0].norm1(inp, encoder_hidden_states_, emb_)
|
|
||||||
if org_timestep[0] == all_timesteps[0] or org_timestep[0] == all_timesteps[-1]:
|
if org_timestep[0] == all_timesteps[0] or org_timestep[0] == all_timesteps[-1]:
|
||||||
should_calc = True
|
should_calc = True
|
||||||
self.accumulated_rel_l1_distance = 0
|
self.accumulated_rel_l1_distance = 0
|
||||||
else:
|
else:
|
||||||
if not self.config.use_rotary_positional_embeddings:
|
coefficients = [-3.10658903e+01, 2.54732368e+01, -5.92380459e+00, 1.75769064e+00, -3.61568434e-03]
|
||||||
# CogVideoX-2B
|
|
||||||
coefficients = [1.42842830e+05, -3.99193393e+04, 3.85937428e+03, -1.49458838e+02, 2.04751119e+00]
|
|
||||||
else:
|
|
||||||
# CogVideoX-5B
|
|
||||||
coefficients = [1.80221813e+05, -5.37021537e+04, 5.61853221e+03, -2.44280388e+02, 3.83458338e+00]
|
|
||||||
rescale_func = np.poly1d(coefficients)
|
rescale_func = np.poly1d(coefficients)
|
||||||
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
|
self.accumulated_rel_l1_distance += rescale_func(((emb-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
|
||||||
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
|
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
|
||||||
should_calc = False
|
should_calc = False
|
||||||
else:
|
else:
|
||||||
should_calc = True
|
should_calc = True
|
||||||
self.accumulated_rel_l1_distance = 0
|
self.accumulated_rel_l1_distance = 0
|
||||||
self.previous_modulated_input = modulated_inp
|
self.previous_modulated_input = emb
|
||||||
|
|
||||||
if self.enable_teacache:
|
if self.enable_teacache:
|
||||||
if not should_calc:
|
if not should_calc:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user