Update asymm_models_joint.py

This commit is contained in:
kijai 2024-11-05 08:51:59 +02:00
parent 1811b7b6c5
commit 04d15b64ae

View File

@ -396,7 +396,6 @@ class AsymmetricJointBlock(nn.Module):
B = x.shape[0] B = x.shape[0]
#print("x", x.shape) #([1, 9540, 3072]) #print("x", x.shape) #([1, 9540, 3072])
if fastercache_counter >= fastercache_start_step + 3 and fastercache_counter%3!=0 and self.cached_x_attention[-1].shape[0] >= B: if fastercache_counter >= fastercache_start_step + 3 and fastercache_counter%3!=0 and self.cached_x_attention[-1].shape[0] >= B:
print("using fastercache")
x_attn = ( x_attn = (
self.cached_x_attention[1][:B] + self.cached_x_attention[1][:B] +
(self.cached_x_attention[1][:B] - self.cached_x_attention[0][:B]) (self.cached_x_attention[1][:B] - self.cached_x_attention[0][:B])