# Provenance: contents of the patch "Add paged_attn op"
# (commit 6661c030c4f4ee84884e879b48c60feeb0c702d5, Woosuk Kwon,
# Wed, 17 Apr 2024 18:02:00 +0000), which creates
# vllm/model_executor/models/jax/ops/paged_attn.py.
"""Thin wrapper around the JAX Pallas TPU ``paged_attention`` kernel."""
import jax
from jax.experimental.pallas.ops.tpu.paged_attention import paged_attention


def paged_attn(
    q: jax.Array,               # [batch, 1, num_heads, head_size]
    k_cache: jax.Array,         # [num_kv_heads, num_blocks, block_size, head_size]
    v_cache: jax.Array,         # [num_kv_heads, num_blocks, block_size, head_size]
    block_tables: jax.Array,    # [batch, max_num_blocks_per_batch]
    context_lens: jax.Array,    # [batch]
    *,
    pages_per_compute_block: int = 4,
) -> jax.Array:                 # [batch, 1, num_heads, head_size]
    """Run paged attention for one query position per sequence.

    The length-1 axis of ``q`` is dropped before the call to the Pallas
    ``paged_attention`` kernel and restored on the result, so the output
    has the same layout as ``q``.

    Args:
        q: Query tensor of shape ``[batch, 1, num_heads, head_size]``.
        k_cache: Key cache of shape
            ``[num_kv_heads, num_blocks, block_size, head_size]``.
        v_cache: Value cache with the same layout as ``k_cache``.
        block_tables: ``[batch, max_num_blocks_per_batch]`` table that is
            forwarded to the kernel unchanged.
        context_lens: ``[batch]`` per-sequence lengths, forwarded to the
            kernel unchanged.
        pages_per_compute_block: Forwarded to the kernel under the same
            keyword. Defaults to 4, the value that used to be hard-coded.
            NOTE(review): the kernel presumably requires
            ``max_num_blocks_per_batch`` to be a multiple of this value;
            confirm against the JAX kernel documentation.

    Returns:
        Attention output of shape ``[batch, 1, num_heads, head_size]``.

    Raises:
        ValueError: If ``q`` is not 4-D with a length-1 second axis, or if
            ``pages_per_compute_block`` is smaller than 1.
    """
    # Fail early with a descriptive message instead of letting `squeeze`
    # (or the kernel) raise a less specific shape error later on.
    if q.ndim != 4 or q.shape[1] != 1:
        raise ValueError(
            "q must have shape [batch, 1, num_heads, head_size], "
            f"got {tuple(q.shape)}"
        )
    if pages_per_compute_block < 1:
        raise ValueError(
            "pages_per_compute_block must be a positive integer, "
            f"got {pages_per_compute_block}"
        )

    batch, _, num_heads, head_size = q.shape

    # The kernel takes the lengths before the page table, i.e. the reverse
    # of this wrapper's own parameter order.
    output = paged_attention(
        q.squeeze(1),
        k_cache,
        v_cache,
        context_lens,
        block_tables,
        pages_per_compute_block=pages_per_compute_block,
    )
    # Restore the singleton query axis removed above.
    return output.reshape(batch, 1, num_heads, head_size)