Update asymm_models_joint.py
This commit is contained in:
parent
d29e95d707
commit
1811b7b6c5
@ -417,11 +417,9 @@ class AsymmetricJointBlock(nn.Module):
|
||||
**attn_kwargs,
|
||||
)
|
||||
if fastercache_counter == fastercache_start_step:
|
||||
print("caching attention")
|
||||
self.cached_x_attention = [x_attn.to(fastercache_device), x_attn.to(fastercache_device)]
|
||||
self.cached_y_attention = [y_attn.to(fastercache_device), y_attn.to(fastercache_device)]
|
||||
elif fastercache_counter > fastercache_start_step:
|
||||
print("updating attention")
|
||||
self.cached_x_attention[-1].copy_(x_attn.to(fastercache_device))
|
||||
self.cached_y_attention[-1].copy_(y_attn.to(fastercache_device))
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user