Fix KV cache shape

This commit is contained in:
Woosuk Kwon 2024-04-25 23:38:07 +00:00
parent fa5bacd5b0
commit 028f528aad
4 changed files with 77 additions and 49 deletions

View File

@ -13,6 +13,8 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""Gemma transformer.""" """Gemma transformer."""
from typing import List, Tuple
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from flax import linen as nn from flax import linen as nn
@ -134,7 +136,7 @@ class Attention(nn.Module):
slot_mapping: jax.Array, slot_mapping: jax.Array,
block_tables: jax.Array | None, block_tables: jax.Array | None,
context_lens: jax.Array | None, context_lens: jax.Array | None,
cache: jax.Array, cache: Tuple[jax.Array, jax.Array],
) -> tuple[jax.Array, jax.Array]: ) -> tuple[jax.Array, jax.Array]:
if self.use_qkv_einsum: if self.use_qkv_einsum:
query_proj, key_proj, value_proj = self.qkv_einsum('BTD,SNDH->SBTNH', x) query_proj, key_proj, value_proj = self.qkv_einsum('BTD,SNDH->SBTNH', x)
@ -154,7 +156,9 @@ class Attention(nn.Module):
) )
# Write the incoming keys and values to KV cache. # Write the incoming keys and values to KV cache.
cache = write_to_kv_cache(key_proj, value_proj, cache, slot_mapping) k_cache, v_cache = cache
k_cache, v_cache = write_to_kv_cache(
key_proj, value_proj, k_cache, v_cache, slot_mapping)
if block_tables is None: if block_tables is None:
# Prompt attention. # Prompt attention.
@ -187,15 +191,15 @@ class Attention(nn.Module):
# Decode attention. # Decode attention.
output = paged_attn( output = paged_attn(
query_proj, query_proj,
cache[0], k_cache,
cache[1], v_cache,
self.sm_scale, self.sm_scale,
block_tables, block_tables,
context_lens, context_lens,
) )
attn_output = self.attn_vec_einsum('BTNH,NHD->BTD', output) attn_output = self.attn_vec_einsum('BTNH,NHD->BTD', output)
return cache, attn_output return (k_cache, v_cache), attn_output
class FeedForward(nn.Module): class FeedForward(nn.Module):
@ -253,8 +257,8 @@ class Block(nn.Module):
slot_mapping: jax.Array, slot_mapping: jax.Array,
block_tables: jax.Array | None, block_tables: jax.Array | None,
context_lens: jax.Array | None, context_lens: jax.Array | None,
cache: jax.Array, cache: Tuple[jax.Array, jax.Array],
) -> tuple[jax.Array, jax.Array]: ) -> Tuple[Tuple[jax.Array, jax.Array], jax.Array]:
inputs_normalized = self.pre_attention_norm(x) inputs_normalized = self.pre_attention_norm(x)
cache, attn_output = self.attn( cache, attn_output = self.attn(
inputs_normalized, inputs_normalized,
@ -302,9 +306,9 @@ class Transformer(nn.Module):
slot_mapping: jax.Array, slot_mapping: jax.Array,
block_tables: jax.Array | None, block_tables: jax.Array | None,
context_lens: jax.Array | None, context_lens: jax.Array | None,
kv_caches: list[jax.Array], kv_caches: List[Tuple[jax.Array, jax.Array]],
logits_indices: jax.Array, logits_indices: jax.Array,
) -> tuple[jax.Array, list[jax.Array]]: ) -> tuple[jax.Array, List[Tuple[jax.Array, jax.Array]]]:
x = self.embedder.encode(token_ids) x = self.embedder.encode(token_ids)
new_caches = [] new_caches = []
for i, block in enumerate(self.blocks): for i, block in enumerate(self.blocks):

View File

@ -4,20 +4,27 @@ from jax.experimental.pallas.ops.tpu.paged_attention import paged_attention
def paged_attn( def paged_attn(
q: jax.Array, # [batch, 1, num_heads, head_size] q: jax.Array, # [batch, 1, num_heads, head_size]
k_cache: jax.Array, # [num_kv_heads, num_blocks, block_size, head_size] k_cache: jax.Array, # [num_kv_heads, num_blocks * block_size, head_size]
v_cache: jax.Array, # [num_kv_heads, num_blocks, block_size, head_size] v_cache: jax.Array, # [num_kv_heads, num_blocks * block_size, head_size]
sm_scale: float, sm_scale: float,
block_tables: jax.Array, # [batch, max_num_blocks_per_batch] block_tables: jax.Array, # [batch, max_num_blocks_per_batch]
context_lens: jax.Array, # [batch] context_lens: jax.Array, # [batch]
block_size: int = 16, # FIXME(woosuk)
) -> jax.Array: # [batch, 1, num_heads, head_size] ) -> jax.Array: # [batch, 1, num_heads, head_size]
q = q.squeeze(1) q = q.squeeze(1)
q = q * sm_scale q = q * sm_scale
head_size = q.shape[-1]
num_slots = k_cache.shape[-2]
k_cache = k_cache.reshape(-1, num_slots // block_size, block_size, head_size)
v_cache = v_cache.reshape(-1, num_slots // block_size, block_size, head_size)
output = paged_attention( output = paged_attention(
q, q,
k_cache, k_cache,
v_cache, v_cache,
context_lens, context_lens,
block_tables, block_tables,
pages_per_compute_block=4, pages_per_compute_block=4, # TODO(woosuk): Tune this value.
) )
return output.reshape(q.shape[0], 1, q.shape[1], q.shape[2]) return output.reshape(q.shape[0], 1, q.shape[1], q.shape[2])

View File

@ -8,57 +8,65 @@ _PAD_SLOT_ID = -1
def _write_to_kv_cache( def _write_to_kv_cache(
key: jax.Array, # [batch_size, seq_len, num_heads, head_size] key: jax.Array, # [batch_size, seq_len, num_heads, head_size]
value: jax.Array, # [batch_size, seq_len, num_heads, head_size] value: jax.Array, # [batch_size, seq_len, num_heads, head_size]
kv_cache: jax.Array, # [2, num_heads, num_blocks, block_size, head_size] k_cache: jax.Array, # [num_heads, num_blocks * block_size, head_size]
v_cache: jax.Array, # [num_heads, num_blocks * block_size, head_size]
slot_mapping: jax.Array, # [batch_size, seq_len] slot_mapping: jax.Array, # [batch_size, seq_len]
) -> jax.Array: ) -> Tuple[jax.Array, jax.Array]:
"""Out-of-place write to KV cache.""" num_heads = key.shape[-2]
num_heads, num_blocks, block_size, head_size = kv_cache.shape[1:] head_size = key.shape[-1]
key_value = jnp.stack([key, value]) # [2, batch_size, seq_len, num_heads, head_size]
key_value = key_value.reshape(2, -1, num_heads, head_size)
key_value = key_value.transpose((0, 2, 1, 3))
kv_cache = kv_cache.reshape(2, num_heads, num_blocks * block_size, head_size) key = key.reshape(-1, num_heads, head_size)
kv_cache = kv_cache.at[:, :, slot_mapping.reshape(-1), :].set(key_value) key = key.transpose((1, 0, 2))
kv_cache = kv_cache.reshape(2, num_heads, num_blocks, block_size, head_size) value = value.reshape(-1, num_heads, head_size)
return kv_cache value = value.transpose((1, 0, 2))
k_cache = k_cache.at[:, slot_mapping.reshape(-1), :].set(key)
v_cache = v_cache.at[:, slot_mapping.reshape(-1), :].set(value)
return k_cache, v_cache
def write_to_kv_cache( def write_to_kv_cache(
key: jax.Array, # [batch_size, seq_len, num_heads, head_size] key: jax.Array, # [batch_size, seq_len, num_heads, head_size]
value: jax.Array, # [batch_size, seq_len, num_heads, head_size] value: jax.Array, # [batch_size, seq_len, num_heads, head_size]
kv_cache: jax.Array, # [2, num_heads, num_blocks, block_size, head_size] k_cache: jax.Array, # [num_heads, num_blocks * block_size, head_size]
v_cache: jax.Array, # [num_heads, num_blocks * block_size, head_size]
slot_mapping: jax.Array, # [batch_size, seq_len] slot_mapping: jax.Array, # [batch_size, seq_len]
) -> jax.Array: ) -> Tuple[jax.Array, jax.Array]:
"""In-place write to KV cache."""
batch_size = slot_mapping.shape[0] batch_size = slot_mapping.shape[0]
key_value = jnp.stack([key, value], axis=2) # [batch_size, seq_len, 2, num_heads, head_size]
def cond(val: _IteratorState): def cond(val: _IteratorState):
return val.idx < batch_size return val.idx < batch_size
def body(val: _IteratorState): def body(val: _IteratorState):
val.kv_cache = _write_seq_to_kv_cache( k_cache, v_cache = _write_seq_to_kv_cache(
key_value[val.idx], key[val.idx],
val.kv_cache, value[val.idx],
val.k_cache,
val.v_cache,
slot_mapping[val.idx], slot_mapping[val.idx],
) )
val.k_cache = k_cache
val.v_cache = v_cache
val.idx += 1 val.idx += 1
return val return val
iterator = _IteratorState(idx=0, kv_cache=kv_cache) iterator = _IteratorState(idx=0, k_cache=k_cache, v_cache=v_cache)
iterator = jax.lax.while_loop(cond, body, iterator) iterator = jax.lax.while_loop(cond, body, iterator)
return iterator.kv_cache return iterator.k_cache, iterator.v_cache
def _write_seq_to_kv_cache( def _write_seq_to_kv_cache(
key_value: jax.Array, # [seq_len, 2, num_heads, head_size] key: jax.Array, # [seq_len, num_heads, head_size]
kv_cache: jax.Array, # [2, num_heads, num_blocks, block_size, head_size] value: jax.Array, # [seq_len, num_heads, head_size]
k_cache: jax.Array, # [num_heads, num_blocks * block_size, head_size]
v_cache: jax.Array, # [num_heads, num_blocks * block_size, head_size]
slot_mapping: jax.Array, # [seq_len] slot_mapping: jax.Array, # [seq_len]
) -> jax.Array: ) -> Tuple[jax.Array, jax.Array]:
seq_len = slot_mapping.shape[0] seq_len = slot_mapping.shape[0]
num_heads, _, block_size, head_size = kv_cache.shape[1:] num_heads, _, head_size = k_cache.shape
# Reshape to match the rank of kv_cache. # Reshape to match the rank of kv_cache.
key_value = key_value.reshape(seq_len, 2, num_heads, 1, 1, head_size) key = key.reshape(seq_len, num_heads, 1, head_size)
value = value.reshape(seq_len, num_heads, 1, head_size)
def cond(val: _IteratorState): def cond(val: _IteratorState):
return jnp.logical_and( return jnp.logical_and(
@ -66,21 +74,27 @@ def _write_seq_to_kv_cache(
def body(val: _IteratorState): def body(val: _IteratorState):
slot_idx = slot_mapping[val.idx] slot_idx = slot_mapping[val.idx]
val.kv_cache = jax.lax.dynamic_update_slice( val.k_cache = jax.lax.dynamic_update_slice(
val.kv_cache, val.k_cache,
key_value[val.idx], key[val.idx],
(0, 0, slot_idx // block_size, slot_idx % block_size, 0), (0, slot_idx, 0),
)
val.v_cache = jax.lax.dynamic_update_slice(
val.v_cache,
value[val.idx],
(0, slot_idx, 0),
) )
val.idx += 1 val.idx += 1
return val return val
iterator = _IteratorState(idx=0, kv_cache=kv_cache) iterator = _IteratorState(idx=0, k_cache=k_cache, v_cache=v_cache)
iterator = jax.lax.while_loop(cond, body, iterator) iterator = jax.lax.while_loop(cond, body, iterator)
return iterator.kv_cache return iterator.k_cache, iterator.v_cache
@chex.dataclass @chex.dataclass
class _IteratorState: class _IteratorState:
idx: jnp.int32 idx: jnp.int32
kv_cache: jnp.ndarray # [2, num_heads, num_blocks, block_size, head_size] k_cache: jnp.ndarray # [num_heads, num_blocks, block_size, head_size]
v_cache: jnp.ndarray # [num_heads, num_blocks, block_size, head_size]

View File

@ -80,11 +80,14 @@ class TPUWorker(LoraNotSupportedWorkerBase):
num_layers = self.model_config.get_num_layers(self.parallel_config) num_layers = self.model_config.get_num_layers(self.parallel_config)
num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config) num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
head_size = self.model_config.get_head_size() head_size = self.model_config.get_head_size()
self.tpu_cache = [
jnp.zeros( self.tpu_cache = []
(2, num_kv_heads, num_gpu_blocks, self.block_size, head_size), for _ in range(num_layers):
dtype=dtype) for _ in range(num_layers) key_cache = jnp.zeros(
] (num_kv_heads, num_gpu_blocks * self.block_size, head_size),
dtype=dtype)
value_cache = jnp.zeros_like(key_cache)
self.tpu_cache.append((key_cache, value_cache))
self.model_runner.block_size = self.block_size self.model_runner.block_size = self.block_size
self._warmup_model() self._warmup_model()