fix fastercache start step

kijai 2024-10-28 22:52:30 +02:00
parent 4c5807397a
commit 66ba4e1ee7

@@ -318,36 +318,28 @@ class CogVideoXBlock(nn.Module):
         del h, fuser
         #fastercache
-        if fastercache_counter >= fastercache_start_step + 3 and fastercache_counter%3!=0 and self.cached_hidden_states[-1].shape[0]>=norm_hidden_states.shape[0]:
+        B = norm_hidden_states.shape[0]
+        if fastercache_counter >= fastercache_start_step + 3 and fastercache_counter%3!=0 and self.cached_hidden_states[-1].shape[0] >= B:
             attn_hidden_states = (
-                self.cached_hidden_states[1][:norm_hidden_states.shape[0]] +
-                (self.cached_hidden_states[1][:norm_hidden_states.shape[0]] -
-                 self.cached_hidden_states[0][:norm_hidden_states.shape[0]])
+                self.cached_hidden_states[1][:B] +
+                (self.cached_hidden_states[1][:B] - self.cached_hidden_states[0][:B])
                 * 0.3
             ).to(norm_hidden_states.device, non_blocking=True)
             attn_encoder_hidden_states = (
-                self.cached_encoder_hidden_states[1][:norm_hidden_states.shape[0]] +
-                (self.cached_encoder_hidden_states[1][:norm_hidden_states.shape[0]] -
-                 self.cached_encoder_hidden_states[0][:norm_hidden_states.shape[0]])
-                *0.3
+                self.cached_encoder_hidden_states[1][:B] +
+                (self.cached_encoder_hidden_states[1][:B] - self.cached_encoder_hidden_states[0][:B])
+                * 0.3
             ).to(norm_hidden_states.device, non_blocking=True)
         else:
             attn_hidden_states, attn_encoder_hidden_states = self.attn1(
                 hidden_states=norm_hidden_states,
                 encoder_hidden_states=norm_encoder_hidden_states,
                 image_rotary_emb=image_rotary_emb,
             )
-            if fastercache_counter==fastercache_start_step:
-                self.cached_hidden_states = [
-                    attn_hidden_states.to(fastercache_device),
-                    attn_hidden_states.to(fastercache_device)
-                ]
-                self.cached_encoder_hidden_states = [
-                    attn_encoder_hidden_states.to(fastercache_device),
-                    attn_encoder_hidden_states.to(fastercache_device)
-                ]
-            elif fastercache_counter>fastercache_start_step:
+            if fastercache_counter == fastercache_start_step:
+                self.cached_hidden_states = [attn_hidden_states.to(fastercache_device), attn_hidden_states.to(fastercache_device)]
+                self.cached_encoder_hidden_states = [attn_encoder_hidden_states.to(fastercache_device), attn_encoder_hidden_states.to(fastercache_device)]
+            elif fastercache_counter > fastercache_start_step:
                 self.cached_hidden_states[-1].copy_(attn_hidden_states.to(fastercache_device))
                 self.cached_encoder_hidden_states[-1].copy_(attn_encoder_hidden_states.to(fastercache_device))
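
For context, the cached branch in this hunk reuses the two most recently cached attention outputs and extrapolates between them on skipped steps (newer + (newer - older) * 0.3) instead of recomputing self.attn1. A minimal standalone sketch of that update-and-reuse pattern follows; the AttnOutputCache helper and its names are illustrative and not part of this repository.

import torch

class AttnOutputCache:
    # Hypothetical helper sketching the caching pattern used in the hunk above.

    def __init__(self, device="cpu"):
        self.device = device
        self.states = None  # [older, newer]

    def update(self, attn_out, counter, start_step):
        if counter == start_step:
            # Seed both slots with the output cached at the start step.
            self.states = [attn_out.to(self.device), attn_out.to(self.device)]
        elif counter > start_step:
            # Overwrite the newer slot in place on later steps.
            self.states[-1].copy_(attn_out.to(self.device))

    def extrapolate(self, batch, target_device):
        # newer + (newer - older) * 0.3, the same weight used in the diff.
        older, newer = self.states
        out = newer[:batch] + (newer[:batch] - older[:batch]) * 0.3
        return out.to(target_device, non_blocking=True)

On steps where the counter condition in the diff holds (counter >= start_step + 3 and counter % 3 != 0), the extrapolated cache replaces a fresh attention call; otherwise attention runs and the cache is refreshed.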
@@ -769,7 +761,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
         output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
         output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
-        if self.fastercache_counter>=16:
+        if self.fastercache_counter >= self.fastercache_start_step + 1:
             (bb, tt, cc, hh, ww) = output.shape
             cond = rearrange(output[0:1].float(), "B T C H W -> (B T) C H W", B=bb//2, C=cc, T=tt, H=hh, W=ww)
             uncond = rearrange(output[1:2].float(), "B T C H W -> (B T) C H W", B=bb//2, C=cc, T=tt, H=hh, W=ww)
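
The second hunk replaces a hardcoded step threshold (16) with the configured start step, so the conditional/unconditional split of the packed output begins one step after caching starts regardless of the chosen fastercache_start_step. A minimal sketch of that gating and the split, assuming the output stacks the conditional and unconditional batches along dim 0; the function names are illustrative, not from the repository.

import torch

def should_split_cfg_output(counter: int, start_step: int) -> bool:
    # Before the fix this compared against a hardcoded 16; now it tracks the configured start step.
    return counter >= start_step + 1

def split_cond_uncond(output: torch.Tensor):
    # output assumed shaped [2, T, C, H, W]: conditional batch first, unconditional second.
    cond = output[0:1].float()
    uncond = output[1:2].float()
    return cond, uncond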