From 62b870fa0730191ddc004df99efb67681d6f9673 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 17 Apr 2024 20:24:45 +0000 Subject: [PATCH] Use FlashAttention kernel --- vllm/model_executor/models/jax/gemma.py | 5 +++-- vllm/worker/tpu_model_runner.py | 4 ++++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/jax/gemma.py b/vllm/model_executor/models/jax/gemma.py index a0c976daecff3..e5861d4d2c665 100644 --- a/vllm/model_executor/models/jax/gemma.py +++ b/vllm/model_executor/models/jax/gemma.py @@ -165,8 +165,8 @@ class Attention(nn.Module): value_proj = jnp.repeat(value_proj, self.num_heads, axis=-2) key_proj = jnp.repeat(key_proj, self.num_heads, axis=-2) - if False: - # FIXME(woosuk) + if True: + # FlashAttention. output = flash_attn( query_proj, key_proj, @@ -174,6 +174,7 @@ class Attention(nn.Module): self.sm_scale, ) else: + # Naive attention with masking. seq_len = query_proj.shape[1] attn_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=jnp.bool_)) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index afb0a3b2ff317..c685dac9ff875 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -81,6 +81,10 @@ class TPUModelRunner: max_prompt_len = max(prompt_lens) assert max_prompt_len > 0 + # NOTE(woosuk): The pallas FlashAttention kernel requires the sequence + # length to be a multiple of 16. We round the prompt length up to the next + # multiple of 16. This is also good for performance. + max_prompt_len = (max_prompt_len + 15) // 16 * 16 input_tokens = _make_array_with_pad(input_tokens, max_prompt_len,