support cogvideox

This commit is contained in:
LiewFeng 2024-12-19 20:09:47 +08:00
parent 329faf52f3
commit 071075bd34

View File

@ -66,28 +66,19 @@ def teacache_forward(
hidden_states = hidden_states[:, text_seq_length:]
if self.enable_teacache:
inp = hidden_states.clone()
encoder_hidden_states_ = encoder_hidden_states.clone()
emb_ = emb.clone()
_, modulated_inp, _, _ = self.transformer_blocks[0].norm1(inp, encoder_hidden_states_, emb_)
if org_timestep[0] == all_timesteps[0] or org_timestep[0] == all_timesteps[-1]:
should_calc = True
self.accumulated_rel_l1_distance = 0
else:
if not self.config.use_rotary_positional_embeddings:
# CogVideoX-2B
coefficients = [1.42842830e+05, -3.99193393e+04, 3.85937428e+03, -1.49458838e+02, 2.04751119e+00]
else:
# CogVideoX-5B
coefficients = [1.80221813e+05, -5.37021537e+04, 5.61853221e+03, -2.44280388e+02, 3.83458338e+00]
coefficients = [-3.10658903e+01, 2.54732368e+01, -5.92380459e+00, 1.75769064e+00, -3.61568434e-03]
rescale_func = np.poly1d(coefficients)
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
self.accumulated_rel_l1_distance += rescale_func(((emb-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
should_calc = False
else:
should_calc = True
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = modulated_inp
self.previous_modulated_input = emb
if self.enable_teacache:
if not should_calc: