From 756c4e78d31fe840c5b2a7e30a6f0d3208afc5ca Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 17 Apr 2024 18:20:55 +0000 Subject: [PATCH] Add write_to_cache ops --- .../models/jax/ops/write_to_cache.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) create mode 100644 vllm/model_executor/models/jax/ops/write_to_cache.py diff --git a/vllm/model_executor/models/jax/ops/write_to_cache.py b/vllm/model_executor/models/jax/ops/write_to_cache.py new file mode 100644 index 0000000000000..cd54e90bafe16 --- /dev/null +++ b/vllm/model_executor/models/jax/ops/write_to_cache.py @@ -0,0 +1,14 @@ +import jax + + +def write_to_cache( + x: jax.Array, + cache: jax.Array, + slot_mapping: jax.Array, +) -> jax.Array: + num_heads, num_blocks, block_size, head_size = cache.shape + cache = cache.reshape(num_heads, num_blocks * block_size, head_size) + x = x.reshape(-1, x.shape[-2], x.shape[-1]) + slot_mapping = slot_mapping.reshape(-1) + cache = cache.at[:, slot_mapping, :].set(x.transpose(1, 0, 2)) + return cache.reshape(num_heads, num_blocks, block_size, head_size)