Fix cache write

This commit is contained in:
Woosuk Kwon 2024-04-24 08:56:30 +00:00
parent d5fb1c20c1
commit 620e7646d3

View File

@ -5,16 +5,6 @@ import jax.numpy as jnp
_PAD_SLOT_ID = -1
def write_to_kv_cache(
key: jax.Array, # [batch_size, seq_len, num_heads, head_size]
value: jax.Array, # [batch_size, seq_len, num_heads, head_size]
kv_cache: jax.Array, # [2, num_heads, num_blocks, block_size, head_size]
slot_mapping: jax.Array, # [batch_size, seq_len]
) -> jax.Array:
f = _write_to_kv_cache
return f(key, value, kv_cache, slot_mapping)
def _write_to_kv_cache(
key: jax.Array, # [batch_size, seq_len, num_heads, head_size]
value: jax.Array, # [batch_size, seq_len, num_heads, head_size]
@ -33,7 +23,7 @@ def _write_to_kv_cache(
return kv_cache
def _write_to_kv_cache_in_place(
def write_to_kv_cache(
key: jax.Array, # [batch_size, seq_len, num_heads, head_size]
value: jax.Array, # [batch_size, seq_len, num_heads, head_size]
kv_cache: jax.Array, # [2, num_heads, num_blocks, block_size, head_size]