Update asymm_models_joint.py

This commit is contained in:
kijai 2024-11-05 08:50:36 +02:00
parent d29e95d707
commit 1811b7b6c5

View File

@ -417,11 +417,9 @@ class AsymmetricJointBlock(nn.Module):
**attn_kwargs,
)
if fastercache_counter == fastercache_start_step:
print("caching attention")
self.cached_x_attention = [x_attn.to(fastercache_device), x_attn.to(fastercache_device)]
self.cached_y_attention = [y_attn.to(fastercache_device), y_attn.to(fastercache_device)]
elif fastercache_counter > fastercache_start_step:
print("updating attention")
self.cached_x_attention[-1].copy_(x_attn.to(fastercache_device))
self.cached_y_attention[-1].copy_(y_attn.to(fastercache_device))