mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-16 15:47:22 +08:00
Write KV
This commit is contained in:
parent
756c4e78d3
commit
ef762cb110
@ -20,6 +20,7 @@ from transformers import GemmaConfig
|
||||
|
||||
from vllm.model_executor.models.jax.ops.flash_attn import flash_attn
|
||||
from vllm.model_executor.models.jax.ops.paged_attn import paged_attn
|
||||
from vllm.model_executor.models.jax.ops.write_to_cache import write_to_cache
|
||||
|
||||
K_MASK = -2.3819763e38 # Set to a large negative number.
|
||||
|
||||
@ -155,6 +156,9 @@ class Attention(nn.Module):
|
||||
# Write the incoming keys and values to KV cache.
|
||||
key_cache = cache[0]
|
||||
value_cache = cache[1]
|
||||
key_cache = write_to_cache(key_proj, key_cache, slot_mapping)
|
||||
value_cache = write_to_cache(value_proj, value_cache, slot_mapping)
|
||||
cache = jnp.stack([key_cache, value_cache])
|
||||
|
||||
if block_tables is None:
|
||||
# Prompt attention.
|
||||
@ -186,8 +190,8 @@ class Attention(nn.Module):
|
||||
# Decode attention.
|
||||
output = paged_attn(
|
||||
query_proj,
|
||||
key_cache,
|
||||
value_cache,
|
||||
cache[0],
|
||||
cache[1],
|
||||
block_tables,
|
||||
context_lens,
|
||||
)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user