From 7e3a230c38ac2ab0fa4bc6694d6f08e471d93bf9 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Wed, 17 Apr 2024 20:06:26 +0000
Subject: [PATCH] Fix paged_attn

---
 vllm/model_executor/models/jax/gemma.py          | 1 +
 vllm/model_executor/models/jax/ops/paged_attn.py | 2 ++
 2 files changed, 3 insertions(+)

diff --git a/vllm/model_executor/models/jax/gemma.py b/vllm/model_executor/models/jax/gemma.py
index 4d98d23c4d2cb..a0c976daecff3 100644
--- a/vllm/model_executor/models/jax/gemma.py
+++ b/vllm/model_executor/models/jax/gemma.py
@@ -190,6 +190,7 @@ class Attention(nn.Module):
             query_proj,
             cache[0],
             cache[1],
+            self.sm_scale,
             block_tables,
             context_lens,
         )
diff --git a/vllm/model_executor/models/jax/ops/paged_attn.py b/vllm/model_executor/models/jax/ops/paged_attn.py
index ab751bf1ca86e..fc53651f9e1a1 100644
--- a/vllm/model_executor/models/jax/ops/paged_attn.py
+++ b/vllm/model_executor/models/jax/ops/paged_attn.py
@@ -6,10 +6,12 @@
 def paged_attn(
     q: jax.Array,  # [batch, 1, num_heads, head_size]
     k_cache: jax.Array,  # [num_kv_heads, num_blocks, block_size, head_size]
     v_cache: jax.Array,  # [num_kv_heads, num_blocks, block_size, head_size]
+    sm_scale: float,
     block_tables: jax.Array,  # [batch, max_num_blocks_per_batch]
     context_lens: jax.Array,  # [batch]
 ) -> jax.Array:  # [batch, 1, num_heads, head_size]
     q = q.squeeze(1)
+    q = q * sm_scale
     output = paged_attention(
         q,
         k_cache,
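
The patch threads the softmax scale through to the kernel: the query is multiplied by sm_scale once before the paged-attention call, so the kernel itself never needs to scale the logits. Below is a minimal, self-contained sketch of why pre-scaling the query is equivalent to scaling the logits; it is not the vLLM kernel, and the function names (attn_scaled_logits, attn_scaled_query) and the 1/sqrt(head_size) choice of sm_scale are illustrative assumptions, not taken from the patch.

    import jax
    import jax.numpy as jnp

    def attn_scaled_logits(q, k, v, sm_scale):
        # Reference formulation: scale the attention logits.
        logits = jnp.einsum("hd,nd->hn", q, k) * sm_scale
        return jnp.einsum("hn,nd->hd", jax.nn.softmax(logits, axis=-1), v)

    def attn_scaled_query(q, k, v, sm_scale):
        # The patch's formulation: scale q once up front.
        # (q * s) @ k.T == s * (q @ k.T), so the softmax inputs are identical.
        q = q * sm_scale
        logits = jnp.einsum("hd,nd->hn", q, k)
        return jnp.einsum("hn,nd->hd", jax.nn.softmax(logits, axis=-1), v)

    head_size = 64
    kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
    q = jax.random.normal(kq, (4, head_size))    # [num_heads, head_size]
    k = jax.random.normal(kk, (128, head_size))  # [context_len, head_size]
    v = jax.random.normal(kv, (128, head_size))
    sm_scale = head_size ** -0.5  # assumed: the conventional 1/sqrt(head_size)

    assert jnp.allclose(
        attn_scaled_logits(q, k, v, sm_scale),
        attn_scaled_query(q, k, v, sm_scale),
        atol=1e-5,
    )

Scaling q once on the host side keeps the per-logit multiply out of the kernel's inner loop, which is why the fix lands in paged_attn rather than inside the attention kernel itself.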