From 5cb213c85e829914cbe4d805b0b84bcfd0989eb3 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Wed, 17 Apr 2024 18:02:28 +0000
Subject: [PATCH] Add flash-attn op

---
 .../models/jax/ops/flash_attn.py | 28 +++++++++++++++++++
 1 file changed, 28 insertions(+)
 create mode 100644 vllm/model_executor/models/jax/ops/flash_attn.py

diff --git a/vllm/model_executor/models/jax/ops/flash_attn.py b/vllm/model_executor/models/jax/ops/flash_attn.py
new file mode 100644
index 0000000000000..e87a6a2785e8e
--- /dev/null
+++ b/vllm/model_executor/models/jax/ops/flash_attn.py
@@ -0,0 +1,28 @@
+import jax
+from jax.experimental.pallas.ops.tpu.flash_attention import BlockSizes, flash_attention
+
+_DEFAULT_BLOCK_SIZES = {
+    "block_q": 512,
+    "block_k_major": 512,
+    "block_k": 512,
+    "block_b": 2,
+}
+
+def flash_attn(
+    q: jax.Array,  # [batch, seq_len, num_heads, head_size]
+    k: jax.Array,  # [batch, seq_len, num_heads, head_size]
+    v: jax.Array,  # [batch, seq_len, num_heads, head_size]
+    sm_scale: float,
+) -> jax.Array:  # [batch, seq_len, num_heads, head_size]
+    return flash_attention(
+        q.transpose(0, 2, 1, 3),
+        k.transpose(0, 2, 1, 3),
+        v.transpose(0, 2, 1, 3),
+        causal=True,
+        sm_scale=sm_scale,
+        block_sizes=BlockSizes(
+            min(_DEFAULT_BLOCK_SIZES["block_q"], q.shape[1]),
+            min(_DEFAULT_BLOCK_SIZES["block_k_major"], k.shape[1]),
+            min(_DEFAULT_BLOCK_SIZES["block_k"], k.shape[1]),
+            min(_DEFAULT_BLOCK_SIZES["block_b"], q.shape[0]))
+    ).transpose(0, 2, 1, 3)
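
--
Note (not part of the patch): a minimal usage sketch of the new op for context. The wrapper takes tensors in [batch, seq_len, num_heads, head_size] layout and transposes to the [batch, num_heads, seq_len, head_size] layout the Pallas kernel expects. The shapes, dtype, and RNG setup below are illustrative assumptions, and the underlying kernel targets TPU.

    import jax
    import jax.numpy as jnp

    from vllm.model_executor.models.jax.ops.flash_attn import flash_attn

    # Illustrative dimensions (assumptions, not taken from the patch).
    batch, seq_len, num_heads, head_size = 2, 1024, 8, 128

    kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
    shape = (batch, seq_len, num_heads, head_size)
    q = jax.random.normal(kq, shape, dtype=jnp.bfloat16)
    k = jax.random.normal(kk, shape, dtype=jnp.bfloat16)
    v = jax.random.normal(kv, shape, dtype=jnp.bfloat16)

    # Softmax scale is conventionally 1/sqrt(head_size).
    out = flash_attn(q, k, v, sm_scale=head_size ** -0.5)
    assert out.shape == shape  # output keeps the input layout

The block sizes are capped with min() so that short sequences (or small batches) do not request a kernel block larger than the corresponding array dimension.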