From 028f528aad8a6fdcc922d5de75691e317c50ab7d Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 25 Apr 2024 23:38:07 +0000 Subject: [PATCH] Fix KV cache shape --- vllm/model_executor/models/jax/gemma.py | 22 +++--- .../models/jax/ops/paged_attn.py | 13 +++- .../models/jax/ops/write_to_cache.py | 78 +++++++++++-------- vllm/worker/tpu_worker.py | 13 ++-- 4 files changed, 77 insertions(+), 49 deletions(-) diff --git a/vllm/model_executor/models/jax/gemma.py b/vllm/model_executor/models/jax/gemma.py index c6bc5b0849d92..9ed32a23b2f56 100644 --- a/vllm/model_executor/models/jax/gemma.py +++ b/vllm/model_executor/models/jax/gemma.py @@ -13,6 +13,8 @@ # limitations under the License. # ============================================================================ """Gemma transformer.""" +from typing import List, Tuple + import jax import jax.numpy as jnp from flax import linen as nn @@ -134,7 +136,7 @@ class Attention(nn.Module): slot_mapping: jax.Array, block_tables: jax.Array | None, context_lens: jax.Array | None, - cache: jax.Array, + cache: Tuple[jax.Array, jax.Array], ) -> tuple[jax.Array, jax.Array]: if self.use_qkv_einsum: query_proj, key_proj, value_proj = self.qkv_einsum('BTD,SNDH->SBTNH', x) @@ -154,7 +156,9 @@ class Attention(nn.Module): ) # Write the incoming keys and values to KV cache. - cache = write_to_kv_cache(key_proj, value_proj, cache, slot_mapping) + k_cache, v_cache = cache + k_cache, v_cache = write_to_kv_cache( + key_proj, value_proj, k_cache, v_cache, slot_mapping) if block_tables is None: # Prompt attention. @@ -187,15 +191,15 @@ class Attention(nn.Module): # Decode attention. 
output = paged_attn( query_proj, - cache[0], - cache[1], + k_cache, + v_cache, self.sm_scale, block_tables, context_lens, ) attn_output = self.attn_vec_einsum('BTNH,NHD->BTD', output) - return cache, attn_output + return (k_cache, v_cache), attn_output class FeedForward(nn.Module): @@ -253,8 +257,8 @@ class Block(nn.Module): slot_mapping: jax.Array, block_tables: jax.Array | None, context_lens: jax.Array | None, - cache: jax.Array, - ) -> tuple[jax.Array, jax.Array]: + cache: Tuple[jax.Array, jax.Array], + ) -> Tuple[Tuple[jax.Array, jax.Array], jax.Array]: inputs_normalized = self.pre_attention_norm(x) cache, attn_output = self.attn( inputs_normalized, @@ -302,9 +306,9 @@ class Transformer(nn.Module): slot_mapping: jax.Array, block_tables: jax.Array | None, context_lens: jax.Array | None, - kv_caches: list[jax.Array], + kv_caches: List[Tuple[jax.Array, jax.Array]], logits_indices: jax.Array, - ) -> tuple[jax.Array, list[jax.Array]]: + ) -> tuple[jax.Array, List[Tuple[jax.Array, jax.Array]]]: x = self.embedder.encode(token_ids) new_caches = [] for i, block in enumerate(self.blocks): diff --git a/vllm/model_executor/models/jax/ops/paged_attn.py b/vllm/model_executor/models/jax/ops/paged_attn.py index fc53651f9e1a1..9b099c7564353 100644 --- a/vllm/model_executor/models/jax/ops/paged_attn.py +++ b/vllm/model_executor/models/jax/ops/paged_attn.py @@ -4,20 +4,27 @@ from jax.experimental.pallas.ops.tpu.paged_attention import paged_attention def paged_attn( q: jax.Array, # [batch, 1, num_heads, head_size] - k_cache: jax.Array, # [num_kv_heads, num_blocks, block_size, head_size] - v_cache: jax.Array, # [num_kv_heads, num_blocks, block_size, head_size] + k_cache: jax.Array, # [num_kv_heads, num_blocks * block_size, head_size] + v_cache: jax.Array, # [num_kv_heads, num_blocks * block_size, head_size] sm_scale: float, block_tables: jax.Array, # [batch, max_num_blocks_per_batch] context_lens: jax.Array, # [batch] + block_size: int = 16, # FIXME(woosuk) ) -> jax.Array: # 
[batch, 1, num_heads, head_size] q = q.squeeze(1) q = q * sm_scale + + head_size = q.shape[-1] + num_slots = k_cache.shape[-2] + k_cache = k_cache.reshape(-1, num_slots // block_size, block_size, head_size) + v_cache = v_cache.reshape(-1, num_slots // block_size, block_size, head_size) + output = paged_attention( q, k_cache, v_cache, context_lens, block_tables, - pages_per_compute_block=4, + pages_per_compute_block=4, # TODO(woosuk): Tune this value. ) return output.reshape(q.shape[0], 1, q.shape[1], q.shape[2]) diff --git a/vllm/model_executor/models/jax/ops/write_to_cache.py b/vllm/model_executor/models/jax/ops/write_to_cache.py index 66fb8d659e316..8124bd6fd26bc 100644 --- a/vllm/model_executor/models/jax/ops/write_to_cache.py +++ b/vllm/model_executor/models/jax/ops/write_to_cache.py @@ -8,57 +8,65 @@ _PAD_SLOT_ID = -1 def _write_to_kv_cache( key: jax.Array, # [batch_size, seq_len, num_heads, head_size] value: jax.Array, # [batch_size, seq_len, num_heads, head_size] - kv_cache: jax.Array, # [2, num_heads, num_blocks, block_size, head_size] + k_cache: jax.Array, # [num_heads, num_blocks * block_size, head_size] + v_cache: jax.Array, # [num_heads, num_blocks * block_size, head_size] slot_mapping: jax.Array, # [batch_size, seq_len] -) -> jax.Array: - """Out-of-place write to KV cache.""" - num_heads, num_blocks, block_size, head_size = kv_cache.shape[1:] - key_value = jnp.stack([key, value]) # [2, batch_size, seq_len, num_heads, head_size] - key_value = key_value.reshape(2, -1, num_heads, head_size) - key_value = key_value.transpose((0, 2, 1, 3)) +) -> Tuple[jax.Array, jax.Array]: + num_heads = key.shape[-2] + head_size = key.shape[-1] - kv_cache = kv_cache.reshape(2, num_heads, num_blocks * block_size, head_size) - kv_cache = kv_cache.at[:, :, slot_mapping.reshape(-1), :].set(key_value) - kv_cache = kv_cache.reshape(2, num_heads, num_blocks, block_size, head_size) - return kv_cache + key = key.reshape(-1, num_heads, head_size) + key = key.transpose((1, 0, 2)) + 
value = value.reshape(-1, num_heads, head_size) + value = value.transpose((1, 0, 2)) + + k_cache = k_cache.at[:, slot_mapping.reshape(-1), :].set(key) + v_cache = v_cache.at[:, slot_mapping.reshape(-1), :].set(value) + return k_cache, v_cache def write_to_kv_cache( key: jax.Array, # [batch_size, seq_len, num_heads, head_size] value: jax.Array, # [batch_size, seq_len, num_heads, head_size] - kv_cache: jax.Array, # [2, num_heads, num_blocks, block_size, head_size] + k_cache: jax.Array, # [num_heads, num_blocks * block_size, head_size] + v_cache: jax.Array, # [num_heads, num_blocks * block_size, head_size] slot_mapping: jax.Array, # [batch_size, seq_len] -) -> jax.Array: - """In-place write to KV cache.""" +) -> Tuple[jax.Array, jax.Array]: batch_size = slot_mapping.shape[0] - key_value = jnp.stack([key, value], axis=2) # [batch_size, seq_len, 2, num_heads, head_size] def cond(val: _IteratorState): return val.idx < batch_size def body(val: _IteratorState): - val.kv_cache = _write_seq_to_kv_cache( - key_value[val.idx], - val.kv_cache, + k_cache, v_cache = _write_seq_to_kv_cache( + key[val.idx], + value[val.idx], + val.k_cache, + val.v_cache, slot_mapping[val.idx], ) + val.k_cache = k_cache + val.v_cache = v_cache val.idx += 1 return val - iterator = _IteratorState(idx=0, kv_cache=kv_cache) + iterator = _IteratorState(idx=0, k_cache=k_cache, v_cache=v_cache) iterator = jax.lax.while_loop(cond, body, iterator) - return iterator.kv_cache + return iterator.k_cache, iterator.v_cache def _write_seq_to_kv_cache( - key_value: jax.Array, # [seq_len, 2, num_heads, head_size] - kv_cache: jax.Array, # [2, num_heads, num_blocks, block_size, head_size] + key: jax.Array, # [seq_len, num_heads, head_size] + value: jax.Array, # [seq_len, num_heads, head_size] + k_cache: jax.Array, # [num_heads, num_blocks * block_size, head_size] + v_cache: jax.Array, # [num_heads, num_blocks * block_size, head_size] slot_mapping: jax.Array, # [seq_len] -) -> jax.Array: +) -> Tuple[jax.Array, 
jax.Array]:
     seq_len = slot_mapping.shape[0]
-    num_heads, _, block_size, head_size = kv_cache.shape[1:]
+    num_heads, _, head_size = k_cache.shape
 
     # Reshape to match the rank of kv_cache.
-    key_value = key_value.reshape(seq_len, 2, num_heads, 1, 1, head_size)
+    key = key.reshape(seq_len, num_heads, 1, head_size)
+    value = value.reshape(seq_len, num_heads, 1, head_size)
 
     def cond(val: _IteratorState):
         return jnp.logical_and(
@@ -66,21 +74,27 @@ def _write_seq_to_kv_cache(
 
     def body(val: _IteratorState):
         slot_idx = slot_mapping[val.idx]
-        val.kv_cache = jax.lax.dynamic_update_slice(
-            val.kv_cache,
-            key_value[val.idx],
-            (0, 0, slot_idx // block_size, slot_idx % block_size, 0),
+        val.k_cache = jax.lax.dynamic_update_slice(
+            val.k_cache,
+            key[val.idx],
+            (0, slot_idx, 0),
+        )
+        val.v_cache = jax.lax.dynamic_update_slice(
+            val.v_cache,
+            value[val.idx],
+            (0, slot_idx, 0),
         )
         val.idx += 1
         return val
 
-    iterator = _IteratorState(idx=0, kv_cache=kv_cache)
+    iterator = _IteratorState(idx=0, k_cache=k_cache, v_cache=v_cache)
     iterator = jax.lax.while_loop(cond, body, iterator)
-    return iterator.kv_cache
+    return iterator.k_cache, iterator.v_cache
 
 
 @chex.dataclass
 class _IteratorState:
 
     idx: jnp.int32
-    kv_cache: jnp.ndarray  # [2, num_heads, num_blocks, block_size, head_size]
+    k_cache: jnp.ndarray  # [num_heads, num_blocks * block_size, head_size]
+    v_cache: jnp.ndarray  # [num_heads, num_blocks * block_size, head_size]
diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py
index fb5f25639be6a..0fe632aa8923e 100644
--- a/vllm/worker/tpu_worker.py
+++ b/vllm/worker/tpu_worker.py
@@ -80,11 +80,14 @@ class TPUWorker(LoraNotSupportedWorkerBase):
         num_layers = self.model_config.get_num_layers(self.parallel_config)
         num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
         head_size = self.model_config.get_head_size()
-        self.tpu_cache = [
-            jnp.zeros(
-                (2, num_kv_heads, num_gpu_blocks, self.block_size, head_size),
-                dtype=dtype) for _ in range(num_layers)
-        ]
+
+        self.tpu_cache = []
+        for _ in range(num_layers):
+            key_cache = jnp.zeros(
+                (num_kv_heads, num_gpu_blocks * self.block_size, head_size),
+                dtype=dtype)
+            value_cache = jnp.zeros_like(key_cache)
+            self.tpu_cache.append((key_cache, value_cache))
         self.model_runner.block_size = self.block_size
         self._warmup_model()