diff --git a/vllm/model_executor/models/jax/ops/write_to_cache.py b/vllm/model_executor/models/jax/ops/write_to_cache.py index cdd562afcdf87..dc794803e5249 100644 --- a/vllm/model_executor/models/jax/ops/write_to_cache.py +++ b/vllm/model_executor/models/jax/ops/write_to_cache.py @@ -7,7 +7,7 @@ import jax.numpy as jnp _PAD_SLOT_ID = -1 -def _write_to_kv_cache( +def write_to_kv_cache( key: jax.Array, # [batch_size, seq_len, num_heads, head_size] value: jax.Array, # [batch_size, seq_len, num_heads, head_size] k_cache: jax.Array, # [num_heads, num_blocks * block_size, head_size] @@ -27,7 +27,7 @@ def _write_to_kv_cache( return k_cache, v_cache -def write_to_kv_cache( +def _write_to_kv_cache( key: jax.Array, # [batch_size, seq_len, num_heads, head_size] value: jax.Array, # [batch_size, seq_len, num_heads, head_size] k_cache: jax.Array, # [num_heads, num_blocks * block_size, head_size]