mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-30 01:57:12 +08:00
Change version
This commit is contained in:
parent
2aa9831dd3
commit
21f35c2289
@@ -7,7 +7,7 @@ import jax.numpy as jnp
|
||||
_PAD_SLOT_ID = -1
|
||||
|
||||
|
||||
def _write_to_kv_cache(
|
||||
def write_to_kv_cache(
|
||||
key: jax.Array, # [batch_size, seq_len, num_heads, head_size]
|
||||
value: jax.Array, # [batch_size, seq_len, num_heads, head_size]
|
||||
k_cache: jax.Array, # [num_heads, num_blocks * block_size, head_size]
|
||||
@@ -27,7 +27,7 @@ def _write_to_kv_cache(
|
||||
return k_cache, v_cache
|
||||
|
||||
|
||||
def write_to_kv_cache(
|
||||
def _write_to_kv_cache(
|
||||
key: jax.Array, # [batch_size, seq_len, num_heads, head_size]
|
||||
value: jax.Array, # [batch_size, seq_len, num_heads, head_size]
|
||||
k_cache: jax.Array, # [num_heads, num_blocks * block_size, head_size]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user