mirror of
https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2025-12-09 04:44:22 +08:00
fix fastercache start step
This commit is contained in:
parent
4c5807397a
commit
66ba4e1ee7
@ -318,20 +318,18 @@ class CogVideoXBlock(nn.Module):
|
|||||||
del h, fuser
|
del h, fuser
|
||||||
|
|
||||||
#fastercache
|
#fastercache
|
||||||
if fastercache_counter >= fastercache_start_step + 3 and fastercache_counter%3!=0 and self.cached_hidden_states[-1].shape[0]>=norm_hidden_states.shape[0]:
|
B = norm_hidden_states.shape[0]
|
||||||
|
if fastercache_counter >= fastercache_start_step + 3 and fastercache_counter%3!=0 and self.cached_hidden_states[-1].shape[0] >= B:
|
||||||
attn_hidden_states = (
|
attn_hidden_states = (
|
||||||
self.cached_hidden_states[1][:norm_hidden_states.shape[0]] +
|
self.cached_hidden_states[1][:B] +
|
||||||
(self.cached_hidden_states[1][:norm_hidden_states.shape[0]] -
|
(self.cached_hidden_states[1][:B] - self.cached_hidden_states[0][:B])
|
||||||
self.cached_hidden_states[0][:norm_hidden_states.shape[0]])
|
|
||||||
* 0.3
|
* 0.3
|
||||||
).to(norm_hidden_states.device, non_blocking=True)
|
).to(norm_hidden_states.device, non_blocking=True)
|
||||||
attn_encoder_hidden_states = (
|
attn_encoder_hidden_states = (
|
||||||
self.cached_encoder_hidden_states[1][:norm_hidden_states.shape[0]] +
|
self.cached_encoder_hidden_states[1][:B] +
|
||||||
(self.cached_encoder_hidden_states[1][:norm_hidden_states.shape[0]] -
|
(self.cached_encoder_hidden_states[1][:B] - self.cached_encoder_hidden_states[0][:B])
|
||||||
self.cached_encoder_hidden_states[0][:norm_hidden_states.shape[0]])
|
|
||||||
* 0.3
|
* 0.3
|
||||||
).to(norm_hidden_states.device, non_blocking=True)
|
).to(norm_hidden_states.device, non_blocking=True)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
|
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
|
||||||
hidden_states=norm_hidden_states,
|
hidden_states=norm_hidden_states,
|
||||||
@ -339,14 +337,8 @@ class CogVideoXBlock(nn.Module):
|
|||||||
image_rotary_emb=image_rotary_emb,
|
image_rotary_emb=image_rotary_emb,
|
||||||
)
|
)
|
||||||
if fastercache_counter == fastercache_start_step:
|
if fastercache_counter == fastercache_start_step:
|
||||||
self.cached_hidden_states = [
|
self.cached_hidden_states = [attn_hidden_states.to(fastercache_device), attn_hidden_states.to(fastercache_device)]
|
||||||
attn_hidden_states.to(fastercache_device),
|
self.cached_encoder_hidden_states = [attn_encoder_hidden_states.to(fastercache_device), attn_encoder_hidden_states.to(fastercache_device)]
|
||||||
attn_hidden_states.to(fastercache_device)
|
|
||||||
]
|
|
||||||
self.cached_encoder_hidden_states = [
|
|
||||||
attn_encoder_hidden_states.to(fastercache_device),
|
|
||||||
attn_encoder_hidden_states.to(fastercache_device)
|
|
||||||
]
|
|
||||||
elif fastercache_counter > fastercache_start_step:
|
elif fastercache_counter > fastercache_start_step:
|
||||||
self.cached_hidden_states[-1].copy_(attn_hidden_states.to(fastercache_device))
|
self.cached_hidden_states[-1].copy_(attn_hidden_states.to(fastercache_device))
|
||||||
self.cached_encoder_hidden_states[-1].copy_(attn_encoder_hidden_states.to(fastercache_device))
|
self.cached_encoder_hidden_states[-1].copy_(attn_encoder_hidden_states.to(fastercache_device))
|
||||||
@ -769,7 +761,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
|
|||||||
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
|
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
|
||||||
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
|
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
|
||||||
|
|
||||||
if self.fastercache_counter>=16:
|
if self.fastercache_counter >= self.fastercache_start_step + 1:
|
||||||
(bb, tt, cc, hh, ww) = output.shape
|
(bb, tt, cc, hh, ww) = output.shape
|
||||||
cond = rearrange(output[0:1].float(), "B T C H W -> (B T) C H W", B=bb//2, C=cc, T=tt, H=hh, W=ww)
|
cond = rearrange(output[0:1].float(), "B T C H W -> (B T) C H W", B=bb//2, C=cc, T=tt, H=hh, W=ww)
|
||||||
uncond = rearrange(output[1:2].float(), "B T C H W -> (B T) C H W", B=bb//2, C=cc, T=tt, H=hh, W=ww)
|
uncond = rearrange(output[1:2].float(), "B T C H W -> (B T) C H W", B=bb//2, C=cc, T=tt, H=hh, W=ww)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user