From ef762cb1100f9e931f053bbfbc9df1bfedb79214 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Wed, 17 Apr 2024 18:21:39 +0000
Subject: [PATCH] Write KV

---
 vllm/model_executor/models/jax/gemma.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/models/jax/gemma.py b/vllm/model_executor/models/jax/gemma.py
index a244aaa9c94ff..da3d94a4c1dd7 100644
--- a/vllm/model_executor/models/jax/gemma.py
+++ b/vllm/model_executor/models/jax/gemma.py
@@ -20,6 +20,7 @@ from transformers import GemmaConfig
 
 from vllm.model_executor.models.jax.ops.flash_attn import flash_attn
 from vllm.model_executor.models.jax.ops.paged_attn import paged_attn
+from vllm.model_executor.models.jax.ops.write_to_cache import write_to_cache
 
 K_MASK = -2.3819763e38  # Set to a large negative number.
 
@@ -155,6 +156,9 @@ class Attention(nn.Module):
         # Write the incoming keys and values to KV cache.
         key_cache = cache[0]
         value_cache = cache[1]
+        key_cache = write_to_cache(key_proj, key_cache, slot_mapping)
+        value_cache = write_to_cache(value_proj, value_cache, slot_mapping)
+        cache = jnp.stack([key_cache, value_cache])
 
         if block_tables is None:
             # Prompt attention.
@@ -186,8 +190,8 @@ class Attention(nn.Module):
             # Decode attention.
             output = paged_attn(
                 query_proj,
-                key_cache,
-                value_cache,
+                cache[0],
+                cache[1],
                 block_tables,
                 context_lens,
             )
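
The patch only imports write_to_cache; the op itself lives outside this diff.
A minimal sketch of what such an op could look like, assuming a flat
slot-indexed cache and a per-token slot_mapping (the shapes and signature
below are assumptions for illustration, not the actual
vllm/model_executor/models/jax/ops/write_to_cache.py):

# Minimal sketch, assuming a flat slot-indexed cache layout; not the actual
# vllm write_to_cache implementation.
import jax
import jax.numpy as jnp


def write_to_cache(new_kv: jax.Array, cache: jax.Array,
                   slot_mapping: jax.Array) -> jax.Array:
    """Scatter newly projected keys or values into the KV cache.

    Assumed shapes (illustrative only):
      new_kv:       [num_tokens, num_kv_heads, head_dim]
      cache:        [num_slots, num_kv_heads, head_dim]
      slot_mapping: [num_tokens] flat cache slot index per token
    """
    # Functional scatter: returns a new cache with the token rows written at
    # their assigned slots; under jit with donated buffers XLA can often
    # perform this update in place.
    return cache.at[slot_mapping].set(new_kv)


# Example usage with toy shapes:
if __name__ == "__main__":
    cache = jnp.zeros((16, 2, 4))
    new_kv = jnp.ones((3, 2, 4))
    slot_mapping = jnp.array([5, 6, 7])
    cache = write_to_cache(new_kv, cache, slot_mapping)

After both writes, the hunk rebuilds cache with jnp.stack([key_cache,
value_cache]), and the decode path indexes that stacked array, so paged_attn
reads cache[0]/cache[1] that already contain the current tokens' keys and
values.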