Add write_to_cache ops

This commit is contained in:
Woosuk Kwon 2024-04-17 18:20:55 +00:00
parent 4880de35d2
commit 756c4e78d3

View File

@ -0,0 +1,14 @@
import jax
import jax.numpy as jnp
def write_to_cache(
    x: jax.Array,
    cache: jax.Array,
    slot_mapping: jax.Array,
) -> jax.Array:
    """Scatter per-token vectors into a blocked cache and return the result.

    Args:
        x: Values to store. The last two axes are read as
            ``(num_heads, head_size)``; all leading axes are flattened into
            a single token axis.
        cache: Array of shape
            ``(num_heads, num_blocks, block_size, head_size)``.
        slot_mapping: Integer destination slot for every token; it is
            flattened to 1-D. A slot indexes the flattened
            ``num_blocks * block_size`` axis of ``cache``. Slots outside
            ``[0, num_blocks * block_size)`` -- including negative values --
            are skipped, so those tokens are not written anywhere.

    Returns:
        A new array with the same shape as ``cache`` in which each valid
        slot holds the corresponding token's ``(num_heads, head_size)``
        values. The input ``cache`` is not modified.
    """
    num_heads, num_blocks, block_size, head_size = cache.shape
    num_slots = num_blocks * block_size

    # View the cache as one flat run of slots per head.
    flat_cache = cache.reshape(num_heads, num_slots, head_size)

    # Collapse leading axes into a token axis, then put heads first so the
    # layout lines up with the (num_heads, slot, head_size) cache view.
    tokens = x.reshape(-1, x.shape[-2], x.shape[-1]).transpose(1, 0, 2)

    slots = slot_mapping.reshape(-1)
    # Negative indices would otherwise wrap around NumPy-style and overwrite
    # slots at the end of the cache. Redirect them past the end so that
    # mode="drop" discards them together with any other out-of-range slot.
    slots = jnp.where(slots < 0, num_slots, slots)

    flat_cache = flat_cache.at[:, slots, :].set(tokens, mode="drop")
    return flat_cache.reshape(num_heads, num_blocks, block_size, head_size)