Fix KV cache shape

Woosuk Kwon 2024-04-25 23:38:07 +00:00
parent fa5bacd5b0
commit 028f528aad
4 changed files with 77 additions and 49 deletions

View File

@@ -13,6 +13,8 @@
 # limitations under the License.
 # ============================================================================
 """Gemma transformer."""
+from typing import List, Tuple
+
 import jax
 import jax.numpy as jnp
 from flax import linen as nn
@@ -134,7 +136,7 @@ class Attention(nn.Module):
       slot_mapping: jax.Array,
       block_tables: jax.Array | None,
       context_lens: jax.Array | None,
-      cache: jax.Array,
+      cache: Tuple[jax.Array, jax.Array],
   ) -> tuple[jax.Array, jax.Array]:
     if self.use_qkv_einsum:
       query_proj, key_proj, value_proj = self.qkv_einsum('BTD,SNDH->SBTNH', x)
@@ -154,7 +156,9 @@
     )

     # Write the incoming keys and values to KV cache.
-    cache = write_to_kv_cache(key_proj, value_proj, cache, slot_mapping)
+    k_cache, v_cache = cache
+    k_cache, v_cache = write_to_kv_cache(
+        key_proj, value_proj, k_cache, v_cache, slot_mapping)

     if block_tables is None:
       # Prompt attention.
@@ -187,15 +191,15 @@
       # Decode attention.
       output = paged_attn(
           query_proj,
-          cache[0],
-          cache[1],
+          k_cache,
+          v_cache,
           self.sm_scale,
           block_tables,
           context_lens,
       )

     attn_output = self.attn_vec_einsum('BTNH,NHD->BTD', output)
-    return cache, attn_output
+    return (k_cache, v_cache), attn_output


 class FeedForward(nn.Module):
@@ -253,8 +257,8 @@ class Block(nn.Module):
       slot_mapping: jax.Array,
       block_tables: jax.Array | None,
       context_lens: jax.Array | None,
-      cache: jax.Array,
-  ) -> tuple[jax.Array, jax.Array]:
+      cache: Tuple[jax.Array, jax.Array],
+  ) -> Tuple[Tuple[jax.Array, jax.Array], jax.Array]:
     inputs_normalized = self.pre_attention_norm(x)
     cache, attn_output = self.attn(
         inputs_normalized,
@@ -302,9 +306,9 @@ class Transformer(nn.Module):
       slot_mapping: jax.Array,
       block_tables: jax.Array | None,
       context_lens: jax.Array | None,
-      kv_caches: list[jax.Array],
+      kv_caches: List[Tuple[jax.Array, jax.Array]],
       logits_indices: jax.Array,
-  ) -> tuple[jax.Array, list[jax.Array]]:
+  ) -> tuple[jax.Array, List[Tuple[jax.Array, jax.Array]]]:
     x = self.embedder.encode(token_ids)
     new_caches = []
     for i, block in enumerate(self.blocks):
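Taken together, the transformer-side change swaps each layer's single stacked cache array of shape [2, num_kv_heads, num_blocks, block_size, head_size] for a (k_cache, v_cache) pair of [num_kv_heads, num_blocks * block_size, head_size] arrays. A minimal sketch of the resulting calling convention, with made-up dimensions rather than the model's real config:

import jax.numpy as jnp

num_layers, num_kv_heads, num_blocks, block_size, head_size = 4, 8, 32, 16, 128

# One (k_cache, v_cache) pair per layer; the block and slot dimensions are
# flattened into a single axis of num_blocks * block_size slots.
kv_caches = [
    (jnp.zeros((num_kv_heads, num_blocks * block_size, head_size)),
     jnp.zeros((num_kv_heads, num_blocks * block_size, head_size)))
    for _ in range(num_layers)
]

# JAX arrays are immutable, so each block returns an updated pair and the
# caller rebuilds the list, mirroring the new_caches loop above.
new_caches = []
for k_cache, v_cache in kv_caches:
    # ... each Block would read and write its pair here ...
    new_caches.append((k_cache, v_cache))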

View File

@@ -4,20 +4,27 @@ from jax.experimental.pallas.ops.tpu.paged_attention import paged_attention

 def paged_attn(
     q: jax.Array,  # [batch, 1, num_heads, head_size]
-    k_cache: jax.Array,  # [num_kv_heads, num_blocks, block_size, head_size]
-    v_cache: jax.Array,  # [num_kv_heads, num_blocks, block_size, head_size]
+    k_cache: jax.Array,  # [num_kv_heads, num_blocks * block_size, head_size]
+    v_cache: jax.Array,  # [num_kv_heads, num_blocks * block_size, head_size]
     sm_scale: float,
     block_tables: jax.Array,  # [batch, max_num_blocks_per_batch]
     context_lens: jax.Array,  # [batch]
+    block_size: int = 16,  # FIXME(woosuk)
 ) -> jax.Array:  # [batch, 1, num_heads, head_size]
   q = q.squeeze(1)
   q = q * sm_scale
+  head_size = q.shape[-1]
+  num_slots = k_cache.shape[-2]
+  k_cache = k_cache.reshape(-1, num_slots // block_size, block_size, head_size)
+  v_cache = v_cache.reshape(-1, num_slots // block_size, block_size, head_size)
   output = paged_attention(
       q,
       k_cache,
       v_cache,
       context_lens,
       block_tables,
-      pages_per_compute_block=4,
+      pages_per_compute_block=4,  # TODO(woosuk): Tune this value.
   )
   return output.reshape(q.shape[0], 1, q.shape[1], q.shape[2])
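The Pallas paged_attention kernel still expects block-structured caches, so paged_attn views the flattened slot axis as [num_blocks, block_size] again before calling it. A small shape-only check of that round trip (dimensions arbitrary); the reshape is a pure view, so slot s lands at block s // block_size, offset s % block_size with no data movement:

import jax.numpy as jnp

num_kv_heads, num_blocks, block_size, head_size = 4, 8, 16, 64
k_flat = jnp.zeros((num_kv_heads, num_blocks * block_size, head_size))

num_slots = k_flat.shape[-2]  # num_blocks * block_size = 128
k_blocked = k_flat.reshape(-1, num_slots // block_size, block_size, head_size)
assert k_blocked.shape == (num_kv_heads, num_blocks, block_size, head_size)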

View File

@@ -8,57 +8,65 @@ _PAD_SLOT_ID = -1

 def _write_to_kv_cache(
     key: jax.Array,  # [batch_size, seq_len, num_heads, head_size]
     value: jax.Array,  # [batch_size, seq_len, num_heads, head_size]
-    kv_cache: jax.Array,  # [2, num_heads, num_blocks, block_size, head_size]
+    k_cache: jax.Array,  # [num_heads, num_blocks * block_size, head_size]
+    v_cache: jax.Array,  # [num_heads, num_blocks * block_size, head_size]
     slot_mapping: jax.Array,  # [batch_size, seq_len]
-) -> jax.Array:
-  """Out-of-place write to KV cache."""
-  num_heads, num_blocks, block_size, head_size = kv_cache.shape[1:]
-  key_value = jnp.stack([key, value])  # [2, batch_size, seq_len, num_heads, head_size]
-  key_value = key_value.reshape(2, -1, num_heads, head_size)
-  key_value = key_value.transpose((0, 2, 1, 3))
+) -> Tuple[jax.Array, jax.Array]:
+  num_heads = key.shape[-2]
+  head_size = key.shape[-1]

-  kv_cache = kv_cache.reshape(2, num_heads, num_blocks * block_size, head_size)
-  kv_cache = kv_cache.at[:, :, slot_mapping.reshape(-1), :].set(key_value)
-  kv_cache = kv_cache.reshape(2, num_heads, num_blocks, block_size, head_size)
-  return kv_cache
+  key = key.reshape(-1, num_heads, head_size)
+  key = key.transpose((1, 0, 2))
+  value = value.reshape(-1, num_heads, head_size)
+  value = value.transpose((1, 0, 2))
+  k_cache = k_cache.at[:, slot_mapping.reshape(-1), :].set(key)
+  v_cache = v_cache.at[:, slot_mapping.reshape(-1), :].set(value)
+  return k_cache, v_cache

 def write_to_kv_cache(
     key: jax.Array,  # [batch_size, seq_len, num_heads, head_size]
     value: jax.Array,  # [batch_size, seq_len, num_heads, head_size]
-    kv_cache: jax.Array,  # [2, num_heads, num_blocks, block_size, head_size]
+    k_cache: jax.Array,  # [num_heads, num_blocks * block_size, head_size]
+    v_cache: jax.Array,  # [num_heads, num_blocks * block_size, head_size]
     slot_mapping: jax.Array,  # [batch_size, seq_len]
-) -> jax.Array:
-  """In-place write to KV cache."""
+) -> Tuple[jax.Array, jax.Array]:
   batch_size = slot_mapping.shape[0]
-  key_value = jnp.stack([key, value], axis=2)  # [batch_size, seq_len, 2, num_heads, head_size]

   def cond(val: _IteratorState):
     return val.idx < batch_size

   def body(val: _IteratorState):
-    val.kv_cache = _write_seq_to_kv_cache(
-        key_value[val.idx],
-        val.kv_cache,
+    k_cache, v_cache = _write_seq_to_kv_cache(
+        key[val.idx],
+        value[val.idx],
+        val.k_cache,
+        val.v_cache,
         slot_mapping[val.idx],
     )
+    val.k_cache = k_cache
+    val.v_cache = v_cache
     val.idx += 1
     return val

-  iterator = _IteratorState(idx=0, kv_cache=kv_cache)
+  iterator = _IteratorState(idx=0, k_cache=k_cache, v_cache=v_cache)
   iterator = jax.lax.while_loop(cond, body, iterator)
-  return iterator.kv_cache
+  return iterator.k_cache, iterator.v_cache

 def _write_seq_to_kv_cache(
-    key_value: jax.Array,  # [seq_len, 2, num_heads, head_size]
-    kv_cache: jax.Array,  # [2, num_heads, num_blocks, block_size, head_size]
+    key: jax.Array,  # [seq_len, num_heads, head_size]
+    value: jax.Array,  # [seq_len, num_heads, head_size]
+    k_cache: jax.Array,  # [num_heads, num_blocks * block_size, head_size]
+    v_cache: jax.Array,  # [num_heads, num_blocks * block_size, head_size]
     slot_mapping: jax.Array,  # [seq_len]
-) -> jax.Array:
+) -> Tuple[jax.Array, jax.Array]:
   seq_len = slot_mapping.shape[0]
-  num_heads, _, block_size, head_size = kv_cache.shape[1:]
+  num_heads, _, head_size = k_cache.shape
   # Reshape to match the rank of kv_cache.
-  key_value = key_value.reshape(seq_len, 2, num_heads, 1, 1, head_size)
+  key = key.reshape(seq_len, num_heads, 1, head_size)
+  value = value.reshape(seq_len, num_heads, 1, head_size)

   def cond(val: _IteratorState):
     return jnp.logical_and(
@@ -66,21 +74,27 @@ def _write_seq_to_kv_cache(
   def body(val: _IteratorState):
     slot_idx = slot_mapping[val.idx]
-    val.kv_cache = jax.lax.dynamic_update_slice(
-        val.kv_cache,
-        key_value[val.idx],
-        (0, 0, slot_idx // block_size, slot_idx % block_size, 0),
+    val.k_cache = jax.lax.dynamic_update_slice(
+        val.k_cache,
+        key[val.idx],
+        (0, slot_idx, 0),
     )
+    val.v_cache = jax.lax.dynamic_update_slice(
+        val.v_cache,
+        value[val.idx],
+        (0, slot_idx, 0),
+    )
     val.idx += 1
     return val

-  iterator = _IteratorState(idx=0, kv_cache=kv_cache)
+  iterator = _IteratorState(idx=0, k_cache=k_cache, v_cache=v_cache)
   iterator = jax.lax.while_loop(cond, body, iterator)
-  return iterator.kv_cache
+  return iterator.k_cache, iterator.v_cache

 @chex.dataclass
 class _IteratorState:
   idx: jnp.int32
-  kv_cache: jnp.ndarray  # [2, num_heads, num_blocks, block_size, head_size]
+  k_cache: jnp.ndarray  # [num_heads, num_blocks * block_size, head_size]
+  v_cache: jnp.ndarray  # [num_heads, num_blocks * block_size, head_size]
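Both write paths rely on JAX's functional update semantics rather than in-place mutation: .at[...].set(...) and jax.lax.dynamic_update_slice each return a new array, which is why the caches are threaded through return values and why _IteratorState now carries the two caches separately as the while_loop carry. A toy version of the batched scatter in _write_to_kv_cache, with shrunken shapes (all names and sizes here are illustrative only):

import jax.numpy as jnp

num_heads, num_slots, head_size = 2, 8, 4
k_cache = jnp.zeros((num_heads, num_slots, head_size))

# Two incoming tokens destined for slots 5 and 2.
slot_mapping = jnp.array([5, 2])
key = jnp.ones((2, num_heads, head_size))  # [num_tokens, num_heads, head_size]

# Move the token axis next to the cache's slot axis, then scatter.
k_cache = k_cache.at[:, slot_mapping, :].set(key.transpose(1, 0, 2))
assert bool((k_cache[:, 5] == 1.0).all()) and bool((k_cache[:, 0] == 0.0).all())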

View File

@@ -80,11 +80,14 @@ class TPUWorker(LoraNotSupportedWorkerBase):
         num_layers = self.model_config.get_num_layers(self.parallel_config)
         num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
         head_size = self.model_config.get_head_size()
-        self.tpu_cache = [
-            jnp.zeros(
-                (2, num_kv_heads, num_gpu_blocks, self.block_size, head_size),
-                dtype=dtype) for _ in range(num_layers)
-        ]
+        self.tpu_cache = []
+        for _ in range(num_layers):
+            key_cache = jnp.zeros(
+                (num_kv_heads, num_gpu_blocks * self.block_size, head_size),
+                dtype=dtype)
+            value_cache = jnp.zeros_like(key_cache)
+            self.tpu_cache.append((key_cache, value_cache))
         self.model_runner.block_size = self.block_size
         self._warmup_model()
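The worker's allocation is unchanged in total size; it is merely split into two arrays per layer. A rough sizing helper for the new layout (the function and the Gemma-7B-like constants below are illustrative assumptions, not part of the worker code):

def kv_cache_bytes(num_layers: int, num_kv_heads: int, num_blocks: int,
                   block_size: int, head_size: int, dtype_bytes: int = 2) -> int:
    """Total bytes across all layers' (key_cache, value_cache) pairs."""
    per_array = num_kv_heads * num_blocks * block_size * head_size * dtype_bytes
    return num_layers * 2 * per_array

# E.g. 28 layers, 16 KV heads, head size 256, 2048 blocks of 16 slots,
# bfloat16 (2 bytes per element) -> 14.0 GiB.
print(kv_cache_bytes(28, 16, 2048, 16, 256) / 2**30, "GiB")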