mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-01 23:37:09 +08:00
Fix write_to_kv_cache
This commit is contained in:
parent
62b870fa07
commit
743695f586
@ -20,7 +20,7 @@ from transformers import GemmaConfig
|
||||
|
||||
from vllm.model_executor.models.jax.ops.flash_attn import flash_attn
|
||||
from vllm.model_executor.models.jax.ops.paged_attn import paged_attn
|
||||
from vllm.model_executor.models.jax.ops.write_to_cache import write_to_cache
|
||||
from vllm.model_executor.models.jax.ops.write_to_cache import write_to_kv_cache
|
||||
|
||||
K_MASK = -2.3819763e38 # Set to a large negative number.
|
||||
|
||||
@ -154,9 +154,7 @@ class Attention(nn.Module):
|
||||
)
|
||||
|
||||
# Write the incoming keys and values to KV cache.
|
||||
key_cache = write_to_cache(key_proj, cache[0], slot_mapping)
|
||||
value_cache = write_to_cache(value_proj, cache[1], slot_mapping)
|
||||
cache = jnp.stack([key_cache, value_cache])
|
||||
cache = write_to_kv_cache(key_proj, value_proj, cache, slot_mapping)
|
||||
|
||||
if block_tables is None:
|
||||
# Prompt attention.
|
||||
|
||||
@ -1,14 +1,96 @@
|
||||
import chex
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
_PAD_SLOT_ID = -1  # Marks padding positions in slot_mapping; the in-place writer stops at it.
|
||||
|
||||
|
||||
def write_to_kv_cache(
    key: jax.Array,  # [batch_size, seq_len, num_heads, head_size]
    value: jax.Array,  # [batch_size, seq_len, num_heads, head_size]
    kv_cache: jax.Array,  # [2, num_heads, num_blocks, block_size, head_size]
    slot_mapping: jax.Array,  # [batch_size, seq_len]
) -> jax.Array:
    """Write the given keys and values into the KV cache.

    Returns the updated cache. Dispatches to the out-of-place
    implementation; ``_write_to_kv_cache_in_place`` is an alternative
    backend with the same signature.
    """
    # The original `f = _write_to_kv_cache; return f(...)` indirection added
    # nothing — call the backend directly.
    return _write_to_kv_cache(key, value, kv_cache, slot_mapping)
|
||||
|
||||
|
||||
def _write_to_kv_cache(
|
||||
key: jax.Array, # [batch_size, seq_len, num_heads, head_size]
|
||||
value: jax.Array, # [batch_size, seq_len, num_heads, head_size]
|
||||
kv_cache: jax.Array, # [2, num_heads, num_blocks, block_size, head_size]
|
||||
slot_mapping: jax.Array, # [batch_size, seq_len]
|
||||
) -> jax.Array:
|
||||
"""Out-of-place write to KV cache."""
|
||||
num_heads, num_blocks, block_size, head_size = kv_cache.shape[1:]
|
||||
key_value = jnp.stack([key, value]) # [2, batch_size, seq_len, num_heads, head_size]
|
||||
key_value = key_value.reshape(2, -1, num_heads, head_size)
|
||||
key_value = key_value.transpose((0, 2, 1, 3))
|
||||
|
||||
kv_cache = kv_cache.reshape(2, num_heads, num_blocks * block_size, head_size)
|
||||
kv_cache = kv_cache.at[:, :, slot_mapping.reshape(-1), :].set(key_value)
|
||||
kv_cache = kv_cache.reshape(2, num_heads, num_blocks, block_size, head_size)
|
||||
return kv_cache
|
||||
|
||||
|
||||
def _write_to_kv_cache_in_place(
    key: jax.Array,  # [batch_size, seq_len, num_heads, head_size]
    value: jax.Array,  # [batch_size, seq_len, num_heads, head_size]
    kv_cache: jax.Array,  # [2, num_heads, num_blocks, block_size, head_size]
    slot_mapping: jax.Array,  # [batch_size, seq_len]
) -> jax.Array:
    """In-place write to KV cache, one sequence at a time."""
    num_seqs = slot_mapping.shape[0]
    # Interleave K and V per token: [batch_size, seq_len, 2, num_heads, head_size].
    kv_pairs = jnp.stack([key, value], axis=2)

    def has_next(state: _IteratorState):
        return state.idx < num_seqs

    def write_seq(state: _IteratorState):
        state.kv_cache = _write_seq_to_kv_cache(
            kv_pairs[state.idx],
            state.kv_cache,
            slot_mapping[state.idx],
        )
        state.idx += 1
        return state

    # NOTE(review): whether the update is truly in place depends on buffer
    # donation at the jit boundary — confirm with the caller.
    final_state = jax.lax.while_loop(
        has_next, write_seq, _IteratorState(idx=0, kv_cache=kv_cache))
    return final_state.kv_cache
|
||||
|
||||
|
||||
def _write_seq_to_kv_cache(
    key_value: jax.Array,  # [seq_len, 2, num_heads, head_size]
    kv_cache: jax.Array,  # [2, num_heads, num_blocks, block_size, head_size]
    slot_mapping: jax.Array,  # [seq_len]
) -> jax.Array:
    """Write one sequence's tokens into the KV cache; returns the updated cache."""
    seq_len = slot_mapping.shape[0]
    num_heads, _, block_size, head_size = kv_cache.shape[1:]
    # Give every token the same rank as the cache so dynamic_update_slice can
    # drop it at (kv, head, block, offset, feature).
    tokens = key_value.reshape(seq_len, 2, num_heads, 1, 1, head_size)

    def keep_writing(state: _IteratorState):
        # Loop ends at the end of the sequence or at the first padded slot.
        return jnp.logical_and(
            state.idx < seq_len, slot_mapping[state.idx] != _PAD_SLOT_ID)

    def write_token(state: _IteratorState):
        slot = slot_mapping[state.idx]
        block, offset = slot // block_size, slot % block_size
        state.kv_cache = jax.lax.dynamic_update_slice(
            state.kv_cache,
            tokens[state.idx],
            (0, 0, block, offset, 0),
        )
        state.idx += 1
        return state

    final_state = jax.lax.while_loop(
        keep_writing, write_token, _IteratorState(idx=0, kv_cache=kv_cache))
    return final_state.kv_cache
|
||||
|
||||
|
||||
@chex.dataclass
class _IteratorState:
    """Mutable loop-carried state for jax.lax.while_loop in the cache writers."""

    idx: jnp.int32  # next sequence/token index to process
    kv_cache: jnp.ndarray  # [2, num_heads, num_blocks, block_size, head_size]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user