mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-18 15:46:59 +08:00
Add write_to_cache ops
This commit is contained in:
parent
4880de35d2
commit
756c4e78d3
14
vllm/model_executor/models/jax/ops/write_to_cache.py
Normal file
14
vllm/model_executor/models/jax/ops/write_to_cache.py
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
import jax
|
||||||
|
|
||||||
|
|
||||||
|
def write_to_cache(
    x: jax.Array,
    cache: jax.Array,
    slot_mapping: jax.Array,
) -> jax.Array:
    """Scatter the rows of ``x`` into the cache slots named by ``slot_mapping``.

    The cache is unpacked as
    ``(num_heads, num_blocks, block_size, head_size)``; its block and
    in-block axes are merged into one flat slot axis for the write and the
    original 4-D shape is restored on return.

    Args:
        x: Values to store. Its last two axes are kept and every leading
            axis is flattened into a single row axis; the second-to-last
            axis is moved to the front so that it lines up with the cache's
            leading (head) axis.
            NOTE(review): presumably ``(..., num_heads, head_size)`` with
            leading batch/sequence axes - confirm against the caller.
        cache: 4-D cache array to write into.
        slot_mapping: Flat slot index (into ``num_blocks * block_size``) for
            each row of ``x``; flattened to 1-D before use.

    Returns:
        An array with the same shape as ``cache`` holding the written rows.
        The update uses JAX's functional ``.at[...].set``, so the result is
        returned rather than written through the ``cache`` argument.
    """
    num_heads, num_blocks, block_size, head_size = cache.shape
    total_slots = num_blocks * block_size

    # Collapse (num_blocks, block_size) so that one integer addresses a slot.
    flat_cache = cache.reshape(num_heads, total_slots, head_size)
    flat_slots = slot_mapping.reshape(-1)

    # Flatten every leading axis of x into rows, then put the head axis
    # first to match the layout of the flattened cache.
    rows = x.reshape(-1, x.shape[-2], x.shape[-1])
    head_major = rows.swapaxes(0, 1)

    updated = flat_cache.at[:, flat_slots, :].set(head_major)
    return updated.reshape(num_heads, num_blocks, block_size, head_size)
|
||||||
Loading…
x
Reference in New Issue
Block a user