mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-20 16:39:09 +08:00
Fix KV cache shape
This commit is contained in:
parent
fa5bacd5b0
commit
028f528aad
@ -13,6 +13,8 @@
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Gemma transformer."""
|
||||
from typing import List, Tuple
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from flax import linen as nn
|
||||
@ -134,7 +136,7 @@ class Attention(nn.Module):
|
||||
slot_mapping: jax.Array,
|
||||
block_tables: jax.Array | None,
|
||||
context_lens: jax.Array | None,
|
||||
cache: jax.Array,
|
||||
cache: Tuple[jax.Array, jax.Array],
|
||||
) -> tuple[jax.Array, jax.Array]:
|
||||
if self.use_qkv_einsum:
|
||||
query_proj, key_proj, value_proj = self.qkv_einsum('BTD,SNDH->SBTNH', x)
|
||||
@ -154,7 +156,9 @@ class Attention(nn.Module):
|
||||
)
|
||||
|
||||
# Write the incoming keys and values to KV cache.
|
||||
cache = write_to_kv_cache(key_proj, value_proj, cache, slot_mapping)
|
||||
k_cache, v_cache = cache
|
||||
k_cache, v_cache = write_to_kv_cache(
|
||||
key_proj, value_proj, k_cache, v_cache, slot_mapping)
|
||||
|
||||
if block_tables is None:
|
||||
# Prompt attention.
|
||||
@ -187,15 +191,15 @@ class Attention(nn.Module):
|
||||
# Decode attention.
|
||||
output = paged_attn(
|
||||
query_proj,
|
||||
cache[0],
|
||||
cache[1],
|
||||
k_cache,
|
||||
v_cache,
|
||||
self.sm_scale,
|
||||
block_tables,
|
||||
context_lens,
|
||||
)
|
||||
|
||||
attn_output = self.attn_vec_einsum('BTNH,NHD->BTD', output)
|
||||
return cache, attn_output
|
||||
return (k_cache, v_cache), attn_output
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
@ -253,8 +257,8 @@ class Block(nn.Module):
|
||||
slot_mapping: jax.Array,
|
||||
block_tables: jax.Array | None,
|
||||
context_lens: jax.Array | None,
|
||||
cache: jax.Array,
|
||||
) -> tuple[jax.Array, jax.Array]:
|
||||
cache: Tuple[jax.Array, jax.Array],
|
||||
) -> Tuple[Tuple[jax.Array, jax.Array], jax.Array]:
|
||||
inputs_normalized = self.pre_attention_norm(x)
|
||||
cache, attn_output = self.attn(
|
||||
inputs_normalized,
|
||||
@ -302,9 +306,9 @@ class Transformer(nn.Module):
|
||||
slot_mapping: jax.Array,
|
||||
block_tables: jax.Array | None,
|
||||
context_lens: jax.Array | None,
|
||||
kv_caches: list[jax.Array],
|
||||
kv_caches: List[Tuple[jax.Array, jax.Array]],
|
||||
logits_indices: jax.Array,
|
||||
) -> tuple[jax.Array, list[jax.Array]]:
|
||||
) -> tuple[jax.Array, List[Tuple[jax.Array, jax.Array]]]:
|
||||
x = self.embedder.encode(token_ids)
|
||||
new_caches = []
|
||||
for i, block in enumerate(self.blocks):
|
||||
|
||||
@ -4,20 +4,27 @@ from jax.experimental.pallas.ops.tpu.paged_attention import paged_attention
|
||||
|
||||
def paged_attn(
|
||||
q: jax.Array, # [batch, 1, num_heads, head_size]
|
||||
k_cache: jax.Array, # [num_kv_heads, num_blocks, block_size, head_size]
|
||||
v_cache: jax.Array, # [num_kv_heads, num_blocks, block_size, head_size]
|
||||
k_cache: jax.Array, # [num_kv_heads, num_blocks * block_size, head_size]
|
||||
v_cache: jax.Array, # [num_kv_heads, num_blocks * block_size, head_size]
|
||||
sm_scale: float,
|
||||
block_tables: jax.Array, # [batch, max_num_blocks_per_batch]
|
||||
context_lens: jax.Array, # [batch]
|
||||
block_size: int = 16, # FIXME(woosuk)
|
||||
) -> jax.Array: # [batch, 1, num_heads, head_size]
|
||||
q = q.squeeze(1)
|
||||
q = q * sm_scale
|
||||
|
||||
head_size = q.shape[-1]
|
||||
num_slots = k_cache.shape[-2]
|
||||
k_cache = k_cache.reshape(-1, num_slots // block_size, block_size, head_size)
|
||||
v_cache = v_cache.reshape(-1, num_slots // block_size, block_size, head_size)
|
||||
|
||||
output = paged_attention(
|
||||
q,
|
||||
k_cache,
|
||||
v_cache,
|
||||
context_lens,
|
||||
block_tables,
|
||||
pages_per_compute_block=4,
|
||||
pages_per_compute_block=4, # TODO(woosuk): Tune this value.
|
||||
)
|
||||
return output.reshape(q.shape[0], 1, q.shape[1], q.shape[2])
|
||||
|
||||
@ -8,57 +8,65 @@ _PAD_SLOT_ID = -1
|
||||
def _write_to_kv_cache(
|
||||
key: jax.Array, # [batch_size, seq_len, num_heads, head_size]
|
||||
value: jax.Array, # [batch_size, seq_len, num_heads, head_size]
|
||||
kv_cache: jax.Array, # [2, num_heads, num_blocks, block_size, head_size]
|
||||
k_cache: jax.Array, # [num_heads, num_blocks * block_size, head_size]
|
||||
v_cache: jax.Array, # [num_heads, num_blocks * block_size, head_size]
|
||||
slot_mapping: jax.Array, # [batch_size, seq_len]
|
||||
) -> jax.Array:
|
||||
"""Out-of-place write to KV cache."""
|
||||
num_heads, num_blocks, block_size, head_size = kv_cache.shape[1:]
|
||||
key_value = jnp.stack([key, value]) # [2, batch_size, seq_len, num_heads, head_size]
|
||||
key_value = key_value.reshape(2, -1, num_heads, head_size)
|
||||
key_value = key_value.transpose((0, 2, 1, 3))
|
||||
) -> Tuple[jax.Array, jax.Array]:
|
||||
num_heads = key.shape[-2]
|
||||
head_size = key.shape[-1]
|
||||
|
||||
kv_cache = kv_cache.reshape(2, num_heads, num_blocks * block_size, head_size)
|
||||
kv_cache = kv_cache.at[:, :, slot_mapping.reshape(-1), :].set(key_value)
|
||||
kv_cache = kv_cache.reshape(2, num_heads, num_blocks, block_size, head_size)
|
||||
return kv_cache
|
||||
key = key.reshape(-1, num_heads, head_size)
|
||||
key = key.transpose((1, 0, 2))
|
||||
value = value.reshape(-1, num_heads, head_size)
|
||||
value = value.transpose((1, 0, 2))
|
||||
|
||||
k_cache = k_cache.at[:, slot_mapping.reshape(-1), :].set(key)
|
||||
v_cache = v_cache.at[:, slot_mapping.reshape(-1), :].set(value)
|
||||
return k_cache, v_cache
|
||||
|
||||
|
||||
def write_to_kv_cache(
|
||||
key: jax.Array, # [batch_size, seq_len, num_heads, head_size]
|
||||
value: jax.Array, # [batch_size, seq_len, num_heads, head_size]
|
||||
kv_cache: jax.Array, # [2, num_heads, num_blocks, block_size, head_size]
|
||||
k_cache: jax.Array, # [num_heads, num_blocks * block_size, head_size]
|
||||
v_cache: jax.Array, # [num_heads, num_blocks * block_size, head_size]
|
||||
slot_mapping: jax.Array, # [batch_size, seq_len]
|
||||
) -> jax.Array:
|
||||
"""In-place write to KV cache."""
|
||||
) -> Tuple[jax.Array, jax.Array]:
|
||||
batch_size = slot_mapping.shape[0]
|
||||
key_value = jnp.stack([key, value], axis=2) # [batch_size, seq_len, 2, num_heads, head_size]
|
||||
|
||||
def cond(val: _IteratorState):
|
||||
return val.idx < batch_size
|
||||
|
||||
def body(val: _IteratorState):
|
||||
val.kv_cache = _write_seq_to_kv_cache(
|
||||
key_value[val.idx],
|
||||
val.kv_cache,
|
||||
k_cache, v_cache = _write_seq_to_kv_cache(
|
||||
key[val.idx],
|
||||
value[val.idx],
|
||||
val.k_cache,
|
||||
val.v_cache,
|
||||
slot_mapping[val.idx],
|
||||
)
|
||||
val.k_cache = k_cache
|
||||
val.v_cache = v_cache
|
||||
val.idx += 1
|
||||
return val
|
||||
|
||||
iterator = _IteratorState(idx=0, kv_cache=kv_cache)
|
||||
iterator = _IteratorState(idx=0, k_cache=k_cache, v_cache=v_cache)
|
||||
iterator = jax.lax.while_loop(cond, body, iterator)
|
||||
return iterator.kv_cache
|
||||
return iterator.k_cache, iterator.v_cache
|
||||
|
||||
|
||||
def _write_seq_to_kv_cache(
|
||||
key_value: jax.Array, # [seq_len, 2, num_heads, head_size]
|
||||
kv_cache: jax.Array, # [2, num_heads, num_blocks, block_size, head_size]
|
||||
key: jax.Array, # [seq_len, num_heads, head_size]
|
||||
value: jax.Array, # [seq_len, num_heads, head_size]
|
||||
k_cache: jax.Array, # [num_heads, num_blocks * block_size, head_size]
|
||||
v_cache: jax.Array, # [num_heads, num_blocks * block_size, head_size]
|
||||
slot_mapping: jax.Array, # [seq_len]
|
||||
) -> jax.Array:
|
||||
) -> Tuple[jax.Array, jax.Array]:
|
||||
seq_len = slot_mapping.shape[0]
|
||||
num_heads, _, block_size, head_size = kv_cache.shape[1:]
|
||||
num_heads, _, head_size = k_cache.shape
|
||||
# Reshape to match the rank of kv_cache.
|
||||
key_value = key_value.reshape(seq_len, 2, num_heads, 1, 1, head_size)
|
||||
key = key.reshape(seq_len, num_heads, 1, head_size)
|
||||
value = value.reshape(seq_len, num_heads, 1, head_size)
|
||||
|
||||
def cond(val: _IteratorState):
|
||||
return jnp.logical_and(
|
||||
@ -66,21 +74,27 @@ def _write_seq_to_kv_cache(
|
||||
|
||||
def body(val: _IteratorState):
|
||||
slot_idx = slot_mapping[val.idx]
|
||||
val.kv_cache = jax.lax.dynamic_update_slice(
|
||||
val.kv_cache,
|
||||
key_value[val.idx],
|
||||
(0, 0, slot_idx // block_size, slot_idx % block_size, 0),
|
||||
val.k_cache = jax.lax.dynamic_update_slice(
|
||||
val.k_cache,
|
||||
key[val.idx],
|
||||
(0, slot_idx, 0),
|
||||
)
|
||||
val.v_cache = jax.lax.dynamic_update_slice(
|
||||
val.v_cache,
|
||||
value[val.idx],
|
||||
(0, slot_idx, 0),
|
||||
)
|
||||
val.idx += 1
|
||||
return val
|
||||
|
||||
iterator = _IteratorState(idx=0, kv_cache=kv_cache)
|
||||
iterator = _IteratorState(idx=0, k_cache=k_cache, v_cache=v_cache)
|
||||
iterator = jax.lax.while_loop(cond, body, iterator)
|
||||
return iterator.kv_cache
|
||||
return iterator.k_cache, iterator.v_cache
|
||||
|
||||
|
||||
@chex.dataclass
|
||||
class _IteratorState:
|
||||
|
||||
idx: jnp.int32
|
||||
kv_cache: jnp.ndarray # [2, num_heads, num_blocks, block_size, head_size]
|
||||
k_cache: jnp.ndarray # [num_heads, num_blocks, block_size, head_size]
|
||||
v_cache: jnp.ndarray # [num_heads, num_blocks, block_size, head_size]
|
||||
|
||||
@ -80,11 +80,14 @@ class TPUWorker(LoraNotSupportedWorkerBase):
|
||||
num_layers = self.model_config.get_num_layers(self.parallel_config)
|
||||
num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
|
||||
head_size = self.model_config.get_head_size()
|
||||
self.tpu_cache = [
|
||||
jnp.zeros(
|
||||
(2, num_kv_heads, num_gpu_blocks, self.block_size, head_size),
|
||||
dtype=dtype) for _ in range(num_layers)
|
||||
]
|
||||
|
||||
self.tpu_cache = []
|
||||
for _ in range(num_layers):
|
||||
key_cache = jnp.zeros(
|
||||
(num_kv_heads, num_gpu_blocks * self.block_size, head_size),
|
||||
dtype=dtype)
|
||||
value_cache = jnp.zeros_like(key_cache)
|
||||
self.tpu_cache.append((key_cache, value_cache))
|
||||
self.model_runner.block_size = self.block_size
|
||||
self._warmup_model()
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user