Update asymm_models_joint.py
commit 1811b7b6c5
parent d29e95d707
@@ -417,11 +417,9 @@ class AsymmetricJointBlock(nn.Module):
             **attn_kwargs,
         )
         if fastercache_counter == fastercache_start_step:
-            print("caching attention")
             self.cached_x_attention = [x_attn.to(fastercache_device), x_attn.to(fastercache_device)]
             self.cached_y_attention = [y_attn.to(fastercache_device), y_attn.to(fastercache_device)]
         elif fastercache_counter > fastercache_start_step:
-            print("updating attention")
             self.cached_x_attention[-1].copy_(x_attn.to(fastercache_device))
             self.cached_y_attention[-1].copy_(y_attn.to(fastercache_device))
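For context, the surrounding logic follows a FasterCache-style pattern: once a step counter reaches a configured start step, the block stores two copies of its attention output on a cache device, and on every later step it overwrites only the most recent cached copy in place. The sketch below is a minimal, self-contained illustration of that pattern only; the class name, the use of nn.MultiheadAttention, and the parameter names are assumptions for the example, not the actual AsymmetricJointBlock interface (which caches separate x and y attention streams).

import torch
import torch.nn as nn

class CachedAttentionBlock(nn.Module):
    """Illustrative counter-based attention caching (not the real module)."""

    def __init__(self, dim: int, cache_device: str = "cpu"):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads=4, batch_first=True)
        self.cache_device = cache_device
        self.cached_attention = None  # holds [previous, current] attention outputs

    def forward(self, x: torch.Tensor, counter: int, start_step: int) -> torch.Tensor:
        attn_out, _ = self.attn(x, x, x)
        if counter == start_step:
            # First caching step: store two independent copies so a "previous"
            # and a "current" slot both exist from the start.
            cached = attn_out.detach().to(self.cache_device, copy=True)
            self.cached_attention = [cached, cached.clone()]
        elif counter > start_step:
            # Later steps: overwrite only the most recent slot in place,
            # keeping the older cached output untouched.
            self.cached_attention[-1].copy_(attn_out.detach().to(self.cache_device))
        return attn_out

A short usage example of the sketch: calling the block across steps, caching begins at start_step and the current slot is refreshed afterwards.

block = CachedAttentionBlock(dim=64)
x = torch.randn(2, 16, 64)
for step in range(4):
    block(x, counter=step, start_step=1)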