From 620e7646d38a41e85f247577ed3b4beb261cef3a Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 24 Apr 2024 08:56:30 +0000 Subject: [PATCH] Fix cache write --- vllm/model_executor/models/jax/ops/write_to_cache.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/vllm/model_executor/models/jax/ops/write_to_cache.py b/vllm/model_executor/models/jax/ops/write_to_cache.py index 890deacf86c78..66fb8d659e316 100644 --- a/vllm/model_executor/models/jax/ops/write_to_cache.py +++ b/vllm/model_executor/models/jax/ops/write_to_cache.py @@ -5,16 +5,6 @@ import jax.numpy as jnp _PAD_SLOT_ID = -1 -def write_to_kv_cache( - key: jax.Array, # [batch_size, seq_len, num_heads, head_size] - value: jax.Array, # [batch_size, seq_len, num_heads, head_size] - kv_cache: jax.Array, # [2, num_heads, num_blocks, block_size, head_size] - slot_mapping: jax.Array, # [batch_size, seq_len] -) -> jax.Array: - f = _write_to_kv_cache - return f(key, value, kv_cache, slot_mapping) - - def _write_to_kv_cache( key: jax.Array, # [batch_size, seq_len, num_heads, head_size] value: jax.Array, # [batch_size, seq_len, num_heads, head_size] @@ -33,7 +23,7 @@ def _write_to_kv_cache( return kv_cache -def _write_to_kv_cache_in_place( +def write_to_kv_cache( key: jax.Array, # [batch_size, seq_len, num_heads, head_size] value: jax.Array, # [batch_size, seq_len, num_heads, head_size] kv_cache: jax.Array, # [2, num_heads, num_blocks, block_size, head_size]