diff --git a/vllm/model_executor/models/jax/gemma.py b/vllm/model_executor/models/jax/gemma.py
index a0c976daecff3..e5861d4d2c665 100644
--- a/vllm/model_executor/models/jax/gemma.py
+++ b/vllm/model_executor/models/jax/gemma.py
@@ -165,8 +165,8 @@ class Attention(nn.Module):
         value_proj = jnp.repeat(value_proj, self.num_heads, axis=-2)
         key_proj = jnp.repeat(key_proj, self.num_heads, axis=-2)
-        if False:
-            # FIXME(woosuk)
+        if True:
+            # FlashAttention.
             output = flash_attn(
                 query_proj,
                 key_proj,
                 value_proj,
@@ -174,6 +174,7 @@ class Attention(nn.Module):
                 self.sm_scale,
             )
         else:
+            # Naive attention with masking.
             seq_len = query_proj.shape[1]
             attn_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=jnp.bool_))
diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py
index afb0a3b2ff317..c685dac9ff875 100644
--- a/vllm/worker/tpu_model_runner.py
+++ b/vllm/worker/tpu_model_runner.py
@@ -81,6 +81,10 @@ class TPUModelRunner:
         max_prompt_len = max(prompt_lens)
         assert max_prompt_len > 0
+        # NOTE(woosuk): The pallas FlashAttention kernel requires the sequence
+        # length to be a multiple of 16. We pad the prompt length to the nearest
+        # multiple of 16. This is also good for performance.
+        max_prompt_len = (max_prompt_len + 15) // 16 * 16
         input_tokens = _make_array_with_pad(input_tokens, max_prompt_len,