From 2a84fb422fc62ab29238dccbf7bdb214fc51c31e Mon Sep 17 00:00:00 2001
From: Chengji Yao
Date: Sat, 9 Aug 2025 20:49:04 -0700
Subject: [PATCH] [TPU] kv cache update kernel doesn't need to be padded
 slices to multiple of num_slices_per_block (#22394)

Signed-off-by: Chengji Yao
Co-authored-by: Chengji Yao
---
 tests/v1/tpu/test_kv_cache_update_kernel.py  |  5 -----
 vllm/attention/ops/pallas_kv_cache_update.py | 16 ++++++++++------
 vllm/v1/worker/tpu_model_runner.py           | 19 +++++++++----------
 3 files changed, 19 insertions(+), 21 deletions(-)

diff --git a/tests/v1/tpu/test_kv_cache_update_kernel.py b/tests/v1/tpu/test_kv_cache_update_kernel.py
index f82737325e9b..acb607247d75 100644
--- a/tests/v1/tpu/test_kv_cache_update_kernel.py
+++ b/tests/v1/tpu/test_kv_cache_update_kernel.py
@@ -43,11 +43,6 @@ def test_kv_cache_update_kernel(page_size: int, combined_kv_head_num: int,
          np.cumsum(slice_lens[:-1])])
     slot_mapping = np.stack(
         [kv_cache_start_indices, new_kv_cache_indices, slice_lens], axis=1)
-    padded_size = (slot_mapping.shape[0] + num_slices_per_block -
-                   1) // num_slices_per_block * num_slices_per_block
-    slot_mapping = np.pad(slot_mapping,
-                          [[0, padded_size - slot_mapping.shape[0]], [0, 0]],
-                          constant_values=0)
     slot_mapping = np.transpose(slot_mapping)
     slot_mapping_cpu = torch.tensor(slot_mapping,
                                     device="cpu",
diff --git a/vllm/attention/ops/pallas_kv_cache_update.py b/vllm/attention/ops/pallas_kv_cache_update.py
index e7d727a45e91..d75983bd407d 100644
--- a/vllm/attention/ops/pallas_kv_cache_update.py
+++ b/vllm/attention/ops/pallas_kv_cache_update.py
@@ -14,6 +14,7 @@ def _kv_cache_update_kernel(
     # Prefetch
     slices_ref,  # [3, padded_num_slices], list of (kv_cache_start,
     # new_kv_start, slice_len)
+    num_slices_ref,  # [1]
     # Input
     new_kv_hbm_ref,  # [num_tokens, num_combined_kv_heads, head_dim]
     kv_cache_hbm_ref,  # [total_num_pages * page_size, num_combined_kv_heads,
@@ -32,8 +33,10 @@ def _kv_cache_update_kernel(
     # Copy from new_kv_hbm_ref to scratch
     for i in range(num_slices_per_block):
         offset_i = i + block_idx * num_slices_per_block
-        new_kv_start = slices_ref[1, offset_i]
-        length = slices_ref[2, offset_i]
+        new_kv_start = jax.lax.select(offset_i < num_slices_ref[0],
+                                      slices_ref[1, offset_i], 0)
+        length = jax.lax.select(offset_i < num_slices_ref[0],
+                                slices_ref[2, offset_i], 0)
         async_copy = pltpu.make_async_copy(
             new_kv_hbm_ref.at[pl.ds(new_kv_start, length), ...],
             scratch.at[i, pl.ds(0, length), ...],
@@ -49,8 +52,10 @@ def _kv_cache_update_kernel(
     async_copies.clear()
     for i in range(num_slices_per_block):
         offset_i = i + block_idx * num_slices_per_block
-        kv_cache_start = slices_ref[0, offset_i]
-        length = slices_ref[2, offset_i]
+        kv_cache_start = jax.lax.select(offset_i < num_slices_ref[0],
+                                        slices_ref[0, offset_i], 0)
+        length = jax.lax.select(offset_i < num_slices_ref[0],
+                                slices_ref[2, offset_i], 0)
         async_copy = pltpu.make_async_copy(
             scratch.at[i, pl.ds(0, length), ...],
             kv_cache_hbm_ref.at[pl.ds(kv_cache_start, length), ...],
@@ -77,7 +82,6 @@ def kv_cache_update(
     page_size: int = 32,
     num_slices_per_block: int = 8,
 ):
-    assert slices.shape[1] % num_slices_per_block == 0
     _, num_combined_kv_heads, head_dim = new_kv.shape
     assert kv_cache.shape[1] == num_combined_kv_heads
     assert kv_cache.shape[2] == head_dim
@@ -93,7 +97,7 @@
     out_specs = [pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)]
     out_shape = [jax.ShapeDtypeStruct(kv_cache.shape, dtype=kv_cache.dtype)]

-    scalar_prefetches = [slices]
+    scalar_prefetches = [slices, num_kv_update_slices]
     scratch = pltpu.VMEM(
         (num_slices_per_block, page_size, num_combined_kv_heads, head_dim),
         new_kv.dtype,
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index 442c0ea068b9..915869726fbf 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -745,7 +745,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         num_kv_update_slices = slot_mapping_metadata.shape[0]
         padded_num_slices = _get_padded_num_kv_cache_update_slices(
             padded_total_num_scheduled_tokens, self.max_num_reqs,
-            self.block_size, self._num_slices_per_kv_cache_update_block)
+            self.block_size)
         slot_mapping_metadata = np.pad(
             slot_mapping_metadata,
             [[0, padded_num_slices - len(slot_mapping_metadata)], [0, 0]],
@@ -1244,8 +1244,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         position_ids = torch.zeros(num_tokens,
                                    dtype=torch.int32).to(self.device)
         padded_num_slices = _get_padded_num_kv_cache_update_slices(
-            num_tokens, self.max_num_reqs, self.block_size,
-            self._num_slices_per_kv_cache_update_block)
+            num_tokens, self.max_num_reqs, self.block_size)
         num_kv_update_slices = torch.tensor([padded_num_slices],
                                             dtype=torch.int32).to(self.device)
         slot_mapping = torch.zeros((3, padded_num_slices),
@@ -1963,17 +1962,17 @@ def copy_kv_blocks(
         _copy_fn(src_tensor, dst_tensor, src_indices, dst_indices)


-def _get_padded_num_kv_cache_update_slices(
-        num_tokens: int, max_num_reqs: int, page_size: int,
-        num_slices_per_kv_cache_update_block: int) -> int:
+def _get_padded_num_kv_cache_update_slices(num_tokens: int, max_num_reqs: int,
+                                           page_size: int) -> int:
     """Calculates the padded number of KV cache update slices to avoid
     recompilation."""
+    # NOTE(chengjiyao): let's say R_i is the token num for the i-th request,
+    # so it occupies at most 2 + R_i // page_size pages. The total maximum
+    # possible number of pages needed is sum(2 + R_i // page_size), which
+    # is <= 2 * max_num_reqs + sum(R_i) // page_size
+    # = 2 * max_num_reqs + num_tokens // page_size
     padded_num_slices = 2 * max_num_reqs + num_tokens // page_size
     padded_num_slices = min(padded_num_slices, num_tokens)
-    padded_num_slices = (
-        padded_num_slices + num_slices_per_kv_cache_update_block - 1
-    ) // num_slices_per_kv_cache_update_block * \
-        num_slices_per_kv_cache_update_block
     return padded_num_slices
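
With num_slices_ref prefetched, the kernel masks every slice index at or past the real slice count, so the trailing, padded columns of the slot mapping degenerate into zero-length copies and the host no longer has to round the slice count up to a multiple of num_slices_per_block. The following is a minimal pure-Python sketch of that behavior, not code from the patch or from vLLM; it assumes NumPy arrays in place of the Pallas refs, and the name emulated_kv_cache_update is hypothetical and used only for illustration.

import numpy as np


def emulated_kv_cache_update(new_kv, slices, kv_cache, num_slices,
                             num_slices_per_block=8):
    # slices: [3, padded_num_slices] with rows (kv_cache_start, new_kv_start,
    # slice_len); only the first `num_slices` columns are real, the rest is
    # padding whose values do not matter.
    padded_num_slices = slices.shape[1]
    num_blocks = (padded_num_slices + num_slices_per_block -
                  1) // num_slices_per_block
    for block_idx in range(num_blocks):
        for i in range(num_slices_per_block):
            offset_i = i + block_idx * num_slices_per_block
            if offset_i >= padded_num_slices:
                continue
            # Mirrors jax.lax.select(offset_i < num_slices_ref[0], ..., 0):
            # an out-of-range slice turns into a zero-length copy.
            in_range = offset_i < num_slices
            kv_cache_start = slices[0, offset_i] if in_range else 0
            new_kv_start = slices[1, offset_i] if in_range else 0
            length = slices[2, offset_i] if in_range else 0
            kv_cache[kv_cache_start:kv_cache_start + length] = (
                new_kv[new_kv_start:new_kv_start + length])
    return kv_cache


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    new_kv = rng.standard_normal((10, 2, 4))
    kv_cache = np.zeros((64, 2, 4))
    # Three real slices; the last column is padding with garbage values.
    slices = np.array([[0, 16, 32, 999],
                       [0, 4, 7, 999],
                       [4, 3, 3, 999]])
    emulated_kv_cache_update(new_kv, slices, kv_cache, num_slices=3)

Because a zero-length copy is a no-op, whatever sits in the padded tail of the slot mapping cannot corrupt the cache, which is what allows the patch to drop both the round-up padding in the test and the assert slices.shape[1] % num_slices_per_block == 0 check in kv_cache_update.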
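The NOTE added to _get_padded_num_kv_cache_update_slices bounds the slice count without the old round-up: a request of R_i tokens whose first token sits at an arbitrary offset inside a page touches at most 2 + R_i // page_size pages, so it contributes at most that many contiguous slices. A brute-force sanity check of that bound, again only an illustrative sketch (max_pages_touched is a hypothetical helper, not vLLM code):

def max_pages_touched(num_tokens: int, page_size: int) -> int:
    # Worst case over every offset the first token can have inside its page.
    worst = 0
    for start in range(page_size):
        last_page = (start + num_tokens - 1) // page_size
        worst = max(worst, last_page + 1)  # the first page is always page 0
    return worst


for page_size in (4, 16, 32):
    for num_tokens in range(1, 4 * page_size):
        assert max_pages_touched(num_tokens, page_size) <= \
            2 + num_tokens // page_size

Summing the per-request bound over all requests gives at most 2 * max_num_reqs + num_tokens // page_size slices, which is exactly the value the simplified helper now returns, clamped to num_tokens.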