diff --git a/benchmarks/bench_cache_write.py b/benchmarks/bench_cache_write.py
new file mode 100644
index 0000000000000..2c2ba318bfe8f
--- /dev/null
+++ b/benchmarks/bench_cache_write.py
@@ -0,0 +1,146 @@
+import time
+from typing import Tuple
+
+import chex
+import jax
+import jax.numpy as jnp
+
+_PAD_SLOT_ID = -1
+
+
+@jax.jit
+def write_to_kv_cache1(
+    key: jax.Array,  # [batch_size, seq_len, num_heads, head_size]
+    value: jax.Array,  # [batch_size, seq_len, num_heads, head_size]
+    k_cache: jax.Array,  # [num_heads, num_blocks * block_size, head_size]
+    v_cache: jax.Array,  # [num_heads, num_blocks * block_size, head_size]
+    slot_mapping: jax.Array,  # [batch_size, seq_len]
+) -> Tuple[jax.Array, jax.Array]:
+    num_heads = key.shape[-2]
+    head_size = key.shape[-1]
+
+    key = key.reshape(-1, num_heads, head_size)
+    key = key.transpose((1, 0, 2))
+    value = value.reshape(-1, num_heads, head_size)
+    value = value.transpose((1, 0, 2))
+
+    k_cache = k_cache.at[:, slot_mapping.reshape(-1), :].set(key)
+    v_cache = v_cache.at[:, slot_mapping.reshape(-1), :].set(value)
+    return k_cache, v_cache
+
+
+@jax.jit
+def write_to_kv_cache2(
+    key: jax.Array,  # [batch_size, seq_len, num_heads, head_size]
+    value: jax.Array,  # [batch_size, seq_len, num_heads, head_size]
+    k_cache: jax.Array,  # [num_heads, num_blocks * block_size, head_size]
+    v_cache: jax.Array,  # [num_heads, num_blocks * block_size, head_size]
+    slot_mapping: jax.Array,  # [batch_size, seq_len]
+) -> Tuple[jax.Array, jax.Array]:
+    batch_size = slot_mapping.shape[0]
+
+    def cond(val: _IteratorState):
+        return val.idx < batch_size
+
+    def body(val: _IteratorState):
+        k_cache, v_cache = _write_seq_to_kv_cache(
+            key[val.idx],
+            value[val.idx],
+            val.k_cache,
+            val.v_cache,
+            slot_mapping[val.idx],
+        )
+        val.k_cache = k_cache
+        val.v_cache = v_cache
+        val.idx += 1
+        return val
+
+    iterator = _IteratorState(idx=0, k_cache=k_cache, v_cache=v_cache)
+    iterator = jax.lax.while_loop(cond, body, iterator)
+    return iterator.k_cache, iterator.v_cache
+
+
+def _write_seq_to_kv_cache(
+    key: jax.Array,  # [seq_len, num_heads, head_size]
+    value: jax.Array,  # [seq_len, num_heads, head_size]
+    k_cache: jax.Array,  # [num_heads, num_blocks * block_size, head_size]
+    v_cache: jax.Array,  # [num_heads, num_blocks * block_size, head_size]
+    slot_mapping: jax.Array,  # [seq_len]
+) -> Tuple[jax.Array, jax.Array]:
+    seq_len = slot_mapping.shape[0]
+    num_heads, _, head_size = k_cache.shape
+    # Reshape to match the rank of kv_cache.
+    key = key.reshape(seq_len, num_heads, 1, head_size)
+    value = value.reshape(seq_len, num_heads, 1, head_size)
+
+    def cond(val: _IteratorState):
+        return jnp.logical_and(
+            val.idx < seq_len, slot_mapping[val.idx] != _PAD_SLOT_ID)
+
+    def body(val: _IteratorState):
+        slot_idx = slot_mapping[val.idx]
+        val.k_cache = jax.lax.dynamic_update_slice(
+            val.k_cache,
+            key[val.idx],
+            (0, slot_idx, 0),
+        )
+        val.v_cache = jax.lax.dynamic_update_slice(
+            val.v_cache,
+            value[val.idx],
+            (0, slot_idx, 0),
+        )
+        val.idx += 1
+        return val
+
+    iterator = _IteratorState(idx=0, k_cache=k_cache, v_cache=v_cache)
+    iterator = jax.lax.while_loop(cond, body, iterator)
+    return iterator.k_cache, iterator.v_cache
+
+
+@chex.dataclass
+class _IteratorState:
+
+    idx: jnp.int32
+    k_cache: jnp.ndarray  # [num_heads, num_blocks, block_size, head_size]
+    v_cache: jnp.ndarray  # [num_heads, num_blocks, block_size, head_size]
+
+
+def benchmark_write_to_kv_cache(
+    batch_size: int,
+    seq_len: int,
+    num_kv_heads: int,
+    head_size: int,
+    num_blocks: int,
+    block_size: int,
+    version: int = 1,
+):
+    if version == 1:
+        f = write_to_kv_cache1
+    elif version == 2:
+        f = write_to_kv_cache2
+    else:
+        raise ValueError(f"Invalid version: {version}")
+
+    rng_key = jax.random.PRNGKey(0)
+    key = jax.random.normal(rng_key, (batch_size, seq_len, num_kv_heads, head_size), dtype=jnp.bfloat16)
+    value = jax.random.normal(rng_key, (batch_size, seq_len, num_kv_heads, head_size), dtype=jnp.bfloat16)
+    k_cache = jax.random.normal(rng_key, (num_kv_heads, num_blocks * block_size, head_size), dtype=jnp.bfloat16)
+    v_cache = jax.random.normal(rng_key, (num_kv_heads, num_blocks * block_size, head_size), dtype=jnp.bfloat16)
+    slot_mapping = jax.random.randint(rng_key, (batch_size, seq_len), 0, num_blocks * block_size, dtype=jnp.int32)
+
+    # For JIT compilation.
+    k, v = f(key, value, k_cache, v_cache, slot_mapping)
+    k.block_until_ready()
+
+    start = time.time()
+    for _ in range(100):
+        k, v = f(key, value, k_cache, v_cache, slot_mapping)
+    k.block_until_ready()
+    end = time.time()
+    print(f"Time taken: {(end - start) * 10:.2f} ms")
+
+
+if __name__ == "__main__":
+    for num_blocks in [16, 256, 512, 1024, 2048]:
+        print(f"Benchmarking Write to KV Cache w/ {num_blocks} blocks")
+        benchmark_write_to_kv_cache(1, 1024, 16, 256, num_blocks, 16, version=1)
diff --git a/benchmarks/bench_paged_attn.py b/benchmarks/bench_paged_attn.py
new file mode 100644
index 0000000000000..2ff27622fd6cb
--- /dev/null
+++ b/benchmarks/bench_paged_attn.py
@@ -0,0 +1,78 @@
+import time
+import jax
+import jax.numpy as jnp
+from jax.experimental.pallas.ops.tpu.paged_attention import paged_attention
+
+BLOCK_SIZE = 16
+
+@jax.jit
+def paged_attn(
+    q: jax.Array,  # [batch, 1, num_heads, head_size]
+    k_cache: jax.Array,  # [num_kv_heads, num_blocks * block_size, head_size]
+    v_cache: jax.Array,  # [num_kv_heads, num_blocks * block_size, head_size]
+    sm_scale: float,
+    block_tables: jax.Array,  # [batch, max_num_blocks_per_batch]
+    context_lens: jax.Array,  # [batch]
+) -> jax.Array:  # [batch, 1, num_heads, head_size]
+    q = q.squeeze(1)
+    q = q * sm_scale
+
+    head_size = q.shape[-1]
+    num_slots = k_cache.shape[-2]
+    k_cache = k_cache.reshape(-1, num_slots // BLOCK_SIZE, BLOCK_SIZE, head_size)
+    v_cache = v_cache.reshape(-1, num_slots // BLOCK_SIZE, BLOCK_SIZE, head_size)
+
+    output = paged_attention(
+        q,
+        k_cache,
+        v_cache,
+        context_lens,
+        block_tables,
+        pages_per_compute_block=4,  # TODO(woosuk): Tune this value.
+    )
+    return output.reshape(q.shape[0], 1, q.shape[1], q.shape[2])
+
+
+def benchmark_paged_attn(
+    batch_size: int,
+    num_heads: int,
+    num_kv_heads: int,
+    head_size: int,
+    context_len: int,
+    num_blocks: int,
+):
+    rng_key = jax.random.PRNGKey(0)
+    query = jax.random.normal(rng_key, (batch_size, 1, num_heads, head_size), dtype=jnp.bfloat16)
+    k_cache = jax.random.normal(rng_key, (num_kv_heads, num_blocks * BLOCK_SIZE, head_size), dtype=jnp.bfloat16)
+    v_cache = jax.random.normal(rng_key, (num_kv_heads, num_blocks * BLOCK_SIZE, head_size), dtype=jnp.bfloat16)
+    sm_scale = head_size ** -0.5  # 1/sqrt(head_size); was BLOCK_SIZE ** -0.5, the wrong constant.
+    block_tables = jax.random.randint(rng_key, (batch_size, context_len // BLOCK_SIZE), 0, num_blocks, dtype=jnp.int32)
+    context_lens = jnp.array([context_len] * batch_size, dtype=jnp.int32)
+
+    # For JIT compilation.
+    output = paged_attn(query, k_cache, v_cache, sm_scale, block_tables, context_lens)
+    output.block_until_ready()
+
+    start = time.time()
+    for _ in range(100):
+        output = paged_attn(query, k_cache, v_cache, sm_scale, block_tables, context_lens)
+    output.block_until_ready()
+    end = time.time()
+
+    print(f"Time taken: {(end - start) * 10:.2f} ms")
+
+
+if __name__ == "__main__":
+
+    for num_blocks in [16, 256, 512, 2048]:
+        print(f"Benchmarking Paged Attention w/ {num_blocks} blocks")
+        benchmark_paged_attn(1, 16, 16, 256, 128, num_blocks)
+
+    # BUG: This will raise the following error:
+    # jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Program or fatal error occurred; computation may be invalid:
+    # INTERNAL: Accelerator device halted prematurely, perhaps due to an on-device check-failure.
+    # Node 0 halted unexpectedly at tag:pc TensorCoreSequencer:1:0xad3 (from TensorCoreSequencer:1:0xad4):
+    # no debugging message found for this tag:pc. HLO: custom-call.2; HLO computation: main.55
+    num_blocks = 1024
+    print(f"Benchmarking Paged Attention w/ {num_blocks} blocks")
+    benchmark_paged_attn(1, 16, 16, 256, 128, num_blocks)