mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-31 18:27:22 +08:00
Write kV
This commit is contained in:
parent
756c4e78d3
commit
ef762cb110
@ -20,6 +20,7 @@ from transformers import GemmaConfig
|
|||||||
|
|
||||||
from vllm.model_executor.models.jax.ops.flash_attn import flash_attn
|
from vllm.model_executor.models.jax.ops.flash_attn import flash_attn
|
||||||
from vllm.model_executor.models.jax.ops.paged_attn import paged_attn
|
from vllm.model_executor.models.jax.ops.paged_attn import paged_attn
|
||||||
|
from vllm.model_executor.models.jax.ops.write_to_cache import write_to_cache
|
||||||
|
|
||||||
K_MASK = -2.3819763e38 # Set to a large negative number.
|
K_MASK = -2.3819763e38 # Set to a large negative number.
|
||||||
|
|
||||||
@ -155,6 +156,9 @@ class Attention(nn.Module):
|
|||||||
# Write the incoming keys and values to KV cache.
|
# Write the incoming keys and values to KV cache.
|
||||||
key_cache = cache[0]
|
key_cache = cache[0]
|
||||||
value_cache = cache[1]
|
value_cache = cache[1]
|
||||||
|
key_cache = write_to_cache(key_proj, key_cache, slot_mapping)
|
||||||
|
value_cache = write_to_cache(value_proj, value_cache, slot_mapping)
|
||||||
|
cache = jnp.stack([key_cache, value_cache])
|
||||||
|
|
||||||
if block_tables is None:
|
if block_tables is None:
|
||||||
# Prompt attention.
|
# Prompt attention.
|
||||||
@ -186,8 +190,8 @@ class Attention(nn.Module):
|
|||||||
# Decode attention.
|
# Decode attention.
|
||||||
output = paged_attn(
|
output = paged_attn(
|
||||||
query_proj,
|
query_proj,
|
||||||
key_cache,
|
cache[0],
|
||||||
value_cache,
|
cache[1],
|
||||||
block_tables,
|
block_tables,
|
||||||
context_lens,
|
context_lens,
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user