Fix write_to_kv_cache

This commit is contained in:
Woosuk Kwon 2024-04-19 07:51:54 +00:00
parent 62b870fa07
commit 743695f586
2 changed files with 94 additions and 14 deletions

View File

@ -20,7 +20,7 @@ from transformers import GemmaConfig
from vllm.model_executor.models.jax.ops.flash_attn import flash_attn
from vllm.model_executor.models.jax.ops.paged_attn import paged_attn
from vllm.model_executor.models.jax.ops.write_to_cache import write_to_cache
from vllm.model_executor.models.jax.ops.write_to_cache import write_to_kv_cache
K_MASK = -2.3819763e38 # Set to a large negative number.
@ -154,9 +154,7 @@ class Attention(nn.Module):
)
# Write the incoming keys and values to KV cache.
key_cache = write_to_cache(key_proj, cache[0], slot_mapping)
value_cache = write_to_cache(value_proj, cache[1], slot_mapping)
cache = jnp.stack([key_cache, value_cache])
cache = write_to_kv_cache(key_proj, value_proj, cache, slot_mapping)
if block_tables is None:
# Prompt attention.

View File

@ -1,14 +1,96 @@
import chex
import jax
import jax.numpy as jnp
_PAD_SLOT_ID = -1
def write_to_cache(
x: jax.Array,
cache: jax.Array,
slot_mapping: jax.Array,
def write_to_kv_cache(
key: jax.Array, # [batch_size, seq_len, num_heads, head_size]
value: jax.Array, # [batch_size, seq_len, num_heads, head_size]
kv_cache: jax.Array, # [2, num_heads, num_blocks, block_size, head_size]
slot_mapping: jax.Array, # [batch_size, seq_len]
) -> jax.Array:
num_heads, num_blocks, block_size, head_size = cache.shape
cache = cache.reshape(num_heads, num_blocks * block_size, head_size)
x = x.reshape(-1, x.shape[-2], x.shape[-1])
slot_mapping = slot_mapping.reshape(-1)
cache = cache.at[:, slot_mapping, :].set(x.transpose(1, 0, 2))
return cache.reshape(num_heads, num_blocks, block_size, head_size)
f = _write_to_kv_cache
return f(key, value, kv_cache, slot_mapping)
def _write_to_kv_cache(
key: jax.Array, # [batch_size, seq_len, num_heads, head_size]
value: jax.Array, # [batch_size, seq_len, num_heads, head_size]
kv_cache: jax.Array, # [2, num_heads, num_blocks, block_size, head_size]
slot_mapping: jax.Array, # [batch_size, seq_len]
) -> jax.Array:
"""Out-of-place write to KV cache."""
num_heads, num_blocks, block_size, head_size = kv_cache.shape[1:]
key_value = jnp.stack([key, value]) # [2, batch_size, seq_len, num_heads, head_size]
key_value = key_value.reshape(2, -1, num_heads, head_size)
key_value = key_value.transpose((0, 2, 1, 3))
kv_cache = kv_cache.reshape(2, num_heads, num_blocks * block_size, head_size)
kv_cache = kv_cache.at[:, :, slot_mapping.reshape(-1), :].set(key_value)
kv_cache = kv_cache.reshape(2, num_heads, num_blocks, block_size, head_size)
return kv_cache
def _write_to_kv_cache_in_place(
key: jax.Array, # [batch_size, seq_len, num_heads, head_size]
value: jax.Array, # [batch_size, seq_len, num_heads, head_size]
kv_cache: jax.Array, # [2, num_heads, num_blocks, block_size, head_size]
slot_mapping: jax.Array, # [batch_size, seq_len]
) -> jax.Array:
"""In-place write to KV cache."""
batch_size = slot_mapping.shape[0]
key_value = jnp.stack([key, value], axis=2) # [batch_size, seq_len, 2, num_heads, head_size]
def cond(val: _IteratorState):
return val.idx < batch_size
def body(val: _IteratorState):
val.kv_cache = _write_seq_to_kv_cache(
key_value[val.idx],
val.kv_cache,
slot_mapping[val.idx],
)
val.idx += 1
return val
iterator = _IteratorState(idx=0, kv_cache=kv_cache)
iterator = jax.lax.while_loop(cond, body, iterator)
return iterator.kv_cache
def _write_seq_to_kv_cache(
key_value: jax.Array, # [seq_len, 2, num_heads, head_size]
kv_cache: jax.Array, # [2, num_heads, num_blocks, block_size, head_size]
slot_mapping: jax.Array, # [seq_len]
) -> jax.Array:
seq_len = slot_mapping.shape[0]
num_heads, _, block_size, head_size = kv_cache.shape[1:]
# Reshape to match the rank of kv_cache.
key_value = key_value.reshape(seq_len, 2, num_heads, 1, 1, head_size)
def cond(val: _IteratorState):
return jnp.logical_and(
val.idx < seq_len, slot_mapping[val.idx] != _PAD_SLOT_ID)
def body(val: _IteratorState):
slot_idx = slot_mapping[val.idx]
val.kv_cache = jax.lax.dynamic_update_slice(
val.kv_cache,
key_value[val.idx],
(0, 0, slot_idx // block_size, slot_idx % block_size, 0),
)
val.idx += 1
return val
iterator = _IteratorState(idx=0, kv_cache=kv_cache)
iterator = jax.lax.while_loop(cond, body, iterator)
return iterator.kv_cache
@chex.dataclass
class _IteratorState:
idx: jnp.int32
kv_cache: jnp.ndarray # [2, num_heads, num_blocks, block_size, head_size]