From 743695f58625e72442327560a0bd4e54f7055a29 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Fri, 19 Apr 2024 07:51:54 +0000
Subject: [PATCH] Fix write_to_kv_cache

---
 vllm/model_executor/models/jax/gemma.py      |   6 +-
 .../models/jax/ops/write_to_cache.py         | 102 ++++++++++++++++--
 2 files changed, 94 insertions(+), 14 deletions(-)

diff --git a/vllm/model_executor/models/jax/gemma.py b/vllm/model_executor/models/jax/gemma.py
index e5861d4d2c665..c6bc5b0849d92 100644
--- a/vllm/model_executor/models/jax/gemma.py
+++ b/vllm/model_executor/models/jax/gemma.py
@@ -20,7 +20,7 @@ from transformers import GemmaConfig
 
 from vllm.model_executor.models.jax.ops.flash_attn import flash_attn
 from vllm.model_executor.models.jax.ops.paged_attn import paged_attn
-from vllm.model_executor.models.jax.ops.write_to_cache import write_to_cache
+from vllm.model_executor.models.jax.ops.write_to_cache import write_to_kv_cache
 
 K_MASK = -2.3819763e38  # Set to a large negative number.
 
@@ -154,9 +154,7 @@ class Attention(nn.Module):
         )
 
         # Write the incoming keys and values to KV cache.
-        key_cache = write_to_cache(key_proj, cache[0], slot_mapping)
-        value_cache = write_to_cache(value_proj, cache[1], slot_mapping)
-        cache = jnp.stack([key_cache, value_cache])
+        cache = write_to_kv_cache(key_proj, value_proj, cache, slot_mapping)
 
         if block_tables is None:
             # Prompt attention.
diff --git a/vllm/model_executor/models/jax/ops/write_to_cache.py b/vllm/model_executor/models/jax/ops/write_to_cache.py
index cd54e90bafe16..890deacf86c78 100644
--- a/vllm/model_executor/models/jax/ops/write_to_cache.py
+++ b/vllm/model_executor/models/jax/ops/write_to_cache.py
@@ -1,14 +1,96 @@
+import chex
 import jax
+import jax.numpy as jnp
+
+_PAD_SLOT_ID = -1
 
 
-def write_to_cache(
-    x: jax.Array,
-    cache: jax.Array,
-    slot_mapping: jax.Array,
+def write_to_kv_cache(
+    key: jax.Array,  # [batch_size, seq_len, num_heads, head_size]
+    value: jax.Array,  # [batch_size, seq_len, num_heads, head_size]
+    kv_cache: jax.Array,  # [2, num_heads, num_blocks, block_size, head_size]
+    slot_mapping: jax.Array,  # [batch_size, seq_len]
 ) -> jax.Array:
-    num_heads, num_blocks, block_size, head_size = cache.shape
-    cache = cache.reshape(num_heads, num_blocks * block_size, head_size)
-    x = x.reshape(-1, x.shape[-2], x.shape[-1])
-    slot_mapping = slot_mapping.reshape(-1)
-    cache = cache.at[:, slot_mapping, :].set(x.transpose(1, 0, 2))
-    return cache.reshape(num_heads, num_blocks, block_size, head_size)
+    f = _write_to_kv_cache
+    return f(key, value, kv_cache, slot_mapping)
+
+
+def _write_to_kv_cache(
+    key: jax.Array,  # [batch_size, seq_len, num_heads, head_size]
+    value: jax.Array,  # [batch_size, seq_len, num_heads, head_size]
+    kv_cache: jax.Array,  # [2, num_heads, num_blocks, block_size, head_size]
+    slot_mapping: jax.Array,  # [batch_size, seq_len]
+) -> jax.Array:
+    """Out-of-place write to KV cache."""
+    num_heads, num_blocks, block_size, head_size = kv_cache.shape[1:]
+    key_value = jnp.stack([key, value])  # [2, batch_size, seq_len, num_heads, head_size]
+    key_value = key_value.reshape(2, -1, num_heads, head_size)
+    key_value = key_value.transpose((0, 2, 1, 3))
+
+    kv_cache = kv_cache.reshape(2, num_heads, num_blocks * block_size, head_size)
+    kv_cache = kv_cache.at[:, :, slot_mapping.reshape(-1), :].set(key_value)
+    kv_cache = kv_cache.reshape(2, num_heads, num_blocks, block_size, head_size)
+    return kv_cache
+
+
+def _write_to_kv_cache_in_place(
+    key: jax.Array,  # [batch_size, seq_len, num_heads, head_size]
+    value: jax.Array,  # [batch_size, seq_len, num_heads, head_size]
+    kv_cache: jax.Array,  # [2, num_heads, num_blocks, block_size, head_size]
+    slot_mapping: jax.Array,  # [batch_size, seq_len]
+) -> jax.Array:
+    """In-place write to KV cache."""
+    batch_size = slot_mapping.shape[0]
+    key_value = jnp.stack([key, value], axis=2)  # [batch_size, seq_len, 2, num_heads, head_size]
+
+    def cond(val: _IteratorState):
+        return val.idx < batch_size
+
+    def body(val: _IteratorState):
+        val.kv_cache = _write_seq_to_kv_cache(
+            key_value[val.idx],
+            val.kv_cache,
+            slot_mapping[val.idx],
+        )
+        val.idx += 1
+        return val
+
+    iterator = _IteratorState(idx=0, kv_cache=kv_cache)
+    iterator = jax.lax.while_loop(cond, body, iterator)
+    return iterator.kv_cache
+
+
+def _write_seq_to_kv_cache(
+    key_value: jax.Array,  # [seq_len, 2, num_heads, head_size]
+    kv_cache: jax.Array,  # [2, num_heads, num_blocks, block_size, head_size]
+    slot_mapping: jax.Array,  # [seq_len]
+) -> jax.Array:
+    seq_len = slot_mapping.shape[0]
+    num_heads, _, block_size, head_size = kv_cache.shape[1:]
+    # Reshape to match the rank of kv_cache.
+    key_value = key_value.reshape(seq_len, 2, num_heads, 1, 1, head_size)
+
+    def cond(val: _IteratorState):
+        return jnp.logical_and(
+            val.idx < seq_len, slot_mapping[val.idx] != _PAD_SLOT_ID)
+
+    def body(val: _IteratorState):
+        slot_idx = slot_mapping[val.idx]
+        val.kv_cache = jax.lax.dynamic_update_slice(
+            val.kv_cache,
+            key_value[val.idx],
+            (0, 0, slot_idx // block_size, slot_idx % block_size, 0),
+        )
+        val.idx += 1
+        return val
+
+    iterator = _IteratorState(idx=0, kv_cache=kv_cache)
+    iterator = jax.lax.while_loop(cond, body, iterator)
+    return iterator.kv_cache
+
+
+@chex.dataclass
+class _IteratorState:
+
+    idx: jnp.int32
+    kv_cache: jnp.ndarray  # [2, num_heads, num_blocks, block_size, head_size]
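
Usage sketch (reviewer note, not part of the patch): a minimal call to the new write_to_kv_cache entry point with toy shapes, to make the slot-mapping layout concrete. All sizes below are illustrative assumptions, not values taken from vLLM. Note that write_to_kv_cache currently dispatches to the out-of-place _write_to_kv_cache; the in-place variant, which skips _PAD_SLOT_ID entries, is kept in the file but unused.

import jax
import jax.numpy as jnp

from vllm.model_executor.models.jax.ops.write_to_cache import write_to_kv_cache

# Assumed toy sizes for illustration only.
batch_size, seq_len = 2, 4
num_heads, head_size = 8, 64
num_blocks, block_size = 16, 16

rng_k, rng_v = jax.random.split(jax.random.PRNGKey(0))
key = jax.random.normal(rng_k, (batch_size, seq_len, num_heads, head_size))
value = jax.random.normal(rng_v, (batch_size, seq_len, num_heads, head_size))
kv_cache = jnp.zeros((2, num_heads, num_blocks, block_size, head_size))

# Each token gets a flat slot index in [0, num_blocks * block_size);
# slot i lands in block i // block_size at offset i % block_size.
slot_mapping = jnp.arange(batch_size * seq_len).reshape(batch_size, seq_len)

kv_cache = write_to_kv_cache(key, value, kv_cache, slot_mapping)

# Token (batch 0, position 0) was assigned slot 0, i.e. block 0, offset 0,
# so its per-head keys should now sit at kv_cache[0, :, 0, 0, :].
assert jnp.allclose(kv_cache[0, :, 0, 0, :], key[0, 0])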