mirror of
https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2025-12-08 20:34:23 +08:00
fix fastercache start step
This commit is contained in:
parent
4c5807397a
commit
66ba4e1ee7
@ -318,20 +318,18 @@ class CogVideoXBlock(nn.Module):
|
||||
del h, fuser
|
||||
|
||||
#fastercache
|
||||
if fastercache_counter >= fastercache_start_step + 3 and fastercache_counter%3!=0 and self.cached_hidden_states[-1].shape[0]>=norm_hidden_states.shape[0]:
|
||||
B = norm_hidden_states.shape[0]
|
||||
if fastercache_counter >= fastercache_start_step + 3 and fastercache_counter%3!=0 and self.cached_hidden_states[-1].shape[0] >= B:
|
||||
attn_hidden_states = (
|
||||
self.cached_hidden_states[1][:norm_hidden_states.shape[0]] +
|
||||
(self.cached_hidden_states[1][:norm_hidden_states.shape[0]] -
|
||||
self.cached_hidden_states[0][:norm_hidden_states.shape[0]])
|
||||
self.cached_hidden_states[1][:B] +
|
||||
(self.cached_hidden_states[1][:B] - self.cached_hidden_states[0][:B])
|
||||
* 0.3
|
||||
).to(norm_hidden_states.device, non_blocking=True)
|
||||
attn_encoder_hidden_states = (
|
||||
self.cached_encoder_hidden_states[1][:norm_hidden_states.shape[0]] +
|
||||
(self.cached_encoder_hidden_states[1][:norm_hidden_states.shape[0]] -
|
||||
self.cached_encoder_hidden_states[0][:norm_hidden_states.shape[0]])
|
||||
self.cached_encoder_hidden_states[1][:B] +
|
||||
(self.cached_encoder_hidden_states[1][:B] - self.cached_encoder_hidden_states[0][:B])
|
||||
* 0.3
|
||||
).to(norm_hidden_states.device, non_blocking=True)
|
||||
|
||||
else:
|
||||
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
|
||||
hidden_states=norm_hidden_states,
|
||||
@ -339,14 +337,8 @@ class CogVideoXBlock(nn.Module):
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
if fastercache_counter == fastercache_start_step:
|
||||
self.cached_hidden_states = [
|
||||
attn_hidden_states.to(fastercache_device),
|
||||
attn_hidden_states.to(fastercache_device)
|
||||
]
|
||||
self.cached_encoder_hidden_states = [
|
||||
attn_encoder_hidden_states.to(fastercache_device),
|
||||
attn_encoder_hidden_states.to(fastercache_device)
|
||||
]
|
||||
self.cached_hidden_states = [attn_hidden_states.to(fastercache_device), attn_hidden_states.to(fastercache_device)]
|
||||
self.cached_encoder_hidden_states = [attn_encoder_hidden_states.to(fastercache_device), attn_encoder_hidden_states.to(fastercache_device)]
|
||||
elif fastercache_counter > fastercache_start_step:
|
||||
self.cached_hidden_states[-1].copy_(attn_hidden_states.to(fastercache_device))
|
||||
self.cached_encoder_hidden_states[-1].copy_(attn_encoder_hidden_states.to(fastercache_device))
|
||||
@ -769,7 +761,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
|
||||
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
|
||||
|
||||
if self.fastercache_counter>=16:
|
||||
if self.fastercache_counter >= self.fastercache_start_step + 1:
|
||||
(bb, tt, cc, hh, ww) = output.shape
|
||||
cond = rearrange(output[0:1].float(), "B T C H W -> (B T) C H W", B=bb//2, C=cc, T=tt, H=hh, W=ww)
|
||||
uncond = rearrange(output[1:2].float(), "B T C H W -> (B T) C H W", B=bb//2, C=cc, T=tt, H=hh, W=ww)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user