Change version

This commit is contained in:
Woosuk Kwon 2024-04-26 05:00:26 +00:00
parent 2aa9831dd3
commit 21f35c2289

View File

@@ -7,7 +7,7 @@ import jax.numpy as jnp
_PAD_SLOT_ID = -1
-def _write_to_kv_cache(
+def write_to_kv_cache(
key: jax.Array, # [batch_size, seq_len, num_heads, head_size]
value: jax.Array, # [batch_size, seq_len, num_heads, head_size]
k_cache: jax.Array, # [num_heads, num_blocks * block_size, head_size]
@@ -27,7 +27,7 @@ def _write_to_kv_cache(
return k_cache, v_cache
-def write_to_kv_cache(
+def _write_to_kv_cache(
key: jax.Array, # [batch_size, seq_len, num_heads, head_size]
value: jax.Array, # [batch_size, seq_len, num_heads, head_size]
k_cache: jax.Array, # [num_heads, num_blocks * block_size, head_size]