mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-06 07:02:14 +08:00
Fix cache write
This commit is contained in:
parent
d5fb1c20c1
commit
620e7646d3
@ -5,16 +5,6 @@ import jax.numpy as jnp
|
|||||||
_PAD_SLOT_ID = -1
|
_PAD_SLOT_ID = -1
|
||||||
|
|
||||||
|
|
||||||
def write_to_kv_cache(
|
|
||||||
key: jax.Array, # [batch_size, seq_len, num_heads, head_size]
|
|
||||||
value: jax.Array, # [batch_size, seq_len, num_heads, head_size]
|
|
||||||
kv_cache: jax.Array, # [2, num_heads, num_blocks, block_size, head_size]
|
|
||||||
slot_mapping: jax.Array, # [batch_size, seq_len]
|
|
||||||
) -> jax.Array:
|
|
||||||
f = _write_to_kv_cache
|
|
||||||
return f(key, value, kv_cache, slot_mapping)
|
|
||||||
|
|
||||||
|
|
||||||
def _write_to_kv_cache(
|
def _write_to_kv_cache(
|
||||||
key: jax.Array, # [batch_size, seq_len, num_heads, head_size]
|
key: jax.Array, # [batch_size, seq_len, num_heads, head_size]
|
||||||
value: jax.Array, # [batch_size, seq_len, num_heads, head_size]
|
value: jax.Array, # [batch_size, seq_len, num_heads, head_size]
|
||||||
@ -33,7 +23,7 @@ def _write_to_kv_cache(
|
|||||||
return kv_cache
|
return kv_cache
|
||||||
|
|
||||||
|
|
||||||
def _write_to_kv_cache_in_place(
|
def write_to_kv_cache(
|
||||||
key: jax.Array, # [batch_size, seq_len, num_heads, head_size]
|
key: jax.Array, # [batch_size, seq_len, num_heads, head_size]
|
||||||
value: jax.Array, # [batch_size, seq_len, num_heads, head_size]
|
value: jax.Array, # [batch_size, seq_len, num_heads, head_size]
|
||||||
kv_cache: jax.Array, # [2, num_heads, num_blocks, block_size, head_size]
|
kv_cache: jax.Array, # [2, num_heads, num_blocks, block_size, head_size]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user