From 21f35c22896eb1d8c5bcfe7731f488306960b846 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 26 Apr 2024 05:00:26 +0000 Subject: [PATCH] Change version --- vllm/model_executor/models/jax/ops/write_to_cache.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/jax/ops/write_to_cache.py b/vllm/model_executor/models/jax/ops/write_to_cache.py index cdd562afcdf87..dc794803e5249 100644 --- a/vllm/model_executor/models/jax/ops/write_to_cache.py +++ b/vllm/model_executor/models/jax/ops/write_to_cache.py @@ -7,7 +7,7 @@ import jax.numpy as jnp _PAD_SLOT_ID = -1 -def _write_to_kv_cache( +def write_to_kv_cache( key: jax.Array, # [batch_size, seq_len, num_heads, head_size] value: jax.Array, # [batch_size, seq_len, num_heads, head_size] k_cache: jax.Array, # [num_heads, num_blocks * block_size, head_size] @@ -27,7 +27,7 @@ def _write_to_kv_cache( return k_cache, v_cache -def write_to_kv_cache( +def _write_to_kv_cache( key: jax.Array, # [batch_size, seq_len, num_heads, head_size] value: jax.Array, # [batch_size, seq_len, num_heads, head_size] k_cache: jax.Array, # [num_heads, num_blocks * block_size, head_size]