Author: Woosuk Kwon
Date:   2024-04-17 18:21:39 +00:00
parent  756c4e78d3
commit  ef762cb110


@@ -20,6 +20,7 @@ from transformers import GemmaConfig
 from vllm.model_executor.models.jax.ops.flash_attn import flash_attn
 from vllm.model_executor.models.jax.ops.paged_attn import paged_attn
+from vllm.model_executor.models.jax.ops.write_to_cache import write_to_cache
 
 K_MASK = -2.3819763e38  # Set to a large negative number.
@@ -155,6 +156,9 @@ class Attention(nn.Module):
         # Write the incoming keys and values to KV cache.
         key_cache = cache[0]
         value_cache = cache[1]
+        key_cache = write_to_cache(key_proj, key_cache, slot_mapping)
+        value_cache = write_to_cache(value_proj, value_cache, slot_mapping)
+        cache = jnp.stack([key_cache, value_cache])
 
         if block_tables is None:
             # Prompt attention.
@@ -186,8 +190,8 @@
             # Decode attention.
             output = paged_attn(
                 query_proj,
-                key_cache,
-                value_cache,
+                cache[0],
+                cache[1],
                 block_tables,
                 context_lens,
             )
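
Note (not part of the commit): the change writes the newly projected keys and values into the paged KV cache via write_to_cache and then passes the updated cache (cache[0], cache[1]) to paged_attn, so the decode path attends over the tokens written in this same forward call rather than the stale pre-update cache. As a rough illustration only, a write_to_cache-style op can be sketched in JAX as a functional scatter; the cache layout [num_blocks, block_size, num_kv_heads, head_size] and the flat slot_mapping convention below are assumptions for the sketch, not taken from the vLLM JAX ops.

import jax.numpy as jnp

def write_to_cache_sketch(new_kv, cache, slot_mapping):
    # Hypothetical sketch of a write_to_cache-style op (assumed shapes, not from the diff):
    #   new_kv:       [num_tokens, num_kv_heads, head_size]
    #   cache:        [num_blocks, block_size, num_kv_heads, head_size]
    #   slot_mapping: [num_tokens], flat slot = block_idx * block_size + block_offset
    num_blocks, block_size, num_kv_heads, head_size = cache.shape
    flat = cache.reshape(num_blocks * block_size, num_kv_heads, head_size)
    # JAX arrays are immutable; .at[...].set returns an updated copy, which is
    # why the diff rebinds key_cache/value_cache and restacks them into `cache`.
    flat = flat.at[slot_mapping].set(new_kv)
    return flat.reshape(num_blocks, block_size, num_kv_heads, head_size)

# Toy usage of the sketch:
cache = jnp.zeros((4, 16, 2, 8))   # 4 blocks x 16 slots, 2 KV heads, head size 8
new_kv = jnp.ones((3, 2, 8))       # 3 incoming tokens
slots = jnp.array([5, 6, 20])      # slots 5-6 of block 0, slot 4 of block 1
cache = write_to_cache_sketch(new_kv, cache, slots)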