From 7da296be04933cfc29031f5bd1ba7cd28f376faa Mon Sep 17 00:00:00 2001
From: Chengji Yao
Date: Tue, 1 Jul 2025 23:33:37 -0700
Subject: [PATCH] [TPU] kv cache update kernel supports dynamic grid (#20235)

Signed-off-by: Chengji Yao
---
 tests/v1/tpu/test_kv_cache_update_kernel.py  |  8 +++--
 vllm/attention/ops/pallas_kv_cache_update.py |  9 ++++--
 vllm/v1/attention/backends/pallas.py         | 34 +++++++++++++-------
 vllm/v1/worker/tpu_model_runner.py           |  8 +++++
 4 files changed, 42 insertions(+), 17 deletions(-)

diff --git a/tests/v1/tpu/test_kv_cache_update_kernel.py b/tests/v1/tpu/test_kv_cache_update_kernel.py
index 63a1f6777e4df..f82737325e9b8 100644
--- a/tests/v1/tpu/test_kv_cache_update_kernel.py
+++ b/tests/v1/tpu/test_kv_cache_update_kernel.py
@@ -32,6 +32,7 @@ def test_kv_cache_update_kernel(page_size: int, combined_kv_head_num: int,
     new_kv_xla = new_kv_cpu.to(torch_xla.device())
     slice_lens = np.array([7, page_size, page_size, 1, 1, 1, 9],
                           dtype=np.int32)
+    num_kv_update_slices = len(slice_lens)
     kv_cache_start_indices = np.array([
         page_size * 2 - 7, page_size * 2, page_size * 3, page_size * 4 + 6,
         page_size * 5 + 7, page_size * 6 + 8, page_size * 15 + 3
@@ -52,12 +53,15 @@ def test_kv_cache_update_kernel(page_size: int, combined_kv_head_num: int,
                                     device="cpu",
                                     dtype=torch.int32)
     slot_mapping_xla = slot_mapping_cpu.to(torch_xla.device())
+    num_kv_update_slices_xla = torch.tensor([num_kv_update_slices],
+                                            device=torch_xla.device(),
+                                            dtype=torch.int32)
     torch_xla.sync()
 
     torch.ops.xla.dynamo_set_buffer_donor_(kv_cache_xla, True)
     new_kv_cache_xla = torch.ops.xla.kv_cache_update_op(
-        new_kv_xla, slot_mapping_xla, kv_cache_xla, page_size,
-        num_slices_per_block)
+        new_kv_xla, slot_mapping_xla, kv_cache_xla, num_kv_update_slices_xla,
+        page_size, num_slices_per_block)
     kv_cache_xla.copy_(new_kv_cache_xla)
     torch_xla.sync()
 
diff --git a/vllm/attention/ops/pallas_kv_cache_update.py b/vllm/attention/ops/pallas_kv_cache_update.py
index 1a92b10e4f9c7..e7d727a45e91c 100644
--- a/vllm/attention/ops/pallas_kv_cache_update.py
+++ b/vllm/attention/ops/pallas_kv_cache_update.py
@@ -7,11 +7,13 @@ import jax
 from jax.experimental import pallas as pl
 from jax.experimental.pallas import tpu as pltpu
 
+from vllm.utils import cdiv
+
 
 def _kv_cache_update_kernel(
     # Prefetch
-    slices_ref,  # [3, num_slices], list of (kv_cache_start, new_kv_start,
-    # slice_len)
+    slices_ref,  # [3, padded_num_slices], list of (kv_cache_start,
+    # new_kv_start, slice_len)
     # Input
     new_kv_hbm_ref,  # [num_tokens, num_combined_kv_heads, head_dim]
     kv_cache_hbm_ref,  # [total_num_pages * page_size, num_combined_kv_heads,
@@ -70,6 +72,7 @@ def kv_cache_update(
     Array,  # [3, slices], list of (kv_cache_start, new_kv_start, slice_len)
     kv_cache: jax.
     Array,  # [total_num_pages * page_size, num_combined_kv_heads, head_dim]
+    num_kv_update_slices: jax.Array,  # [1]
     *,
     page_size: int = 32,
     num_slices_per_block: int = 8,
@@ -107,7 +110,7 @@ def kv_cache_update(
             num_scalar_prefetch=len(scalar_prefetches),
             in_specs=in_specs,
             out_specs=out_specs,
-            grid=(slices.shape[1] // num_slices_per_block, ),
+            grid=(cdiv(num_kv_update_slices[0], num_slices_per_block), ),
             scratch_shapes=scratch_shapes,
         ),
         out_shape=out_shape,
diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py
index 49f0772c62d13..253d79d925cef 100644
--- a/vllm/v1/attention/backends/pallas.py
+++ b/vllm/v1/attention/backends/pallas.py
@@ -111,6 +111,7 @@ class PallasMetadata:
     context_lens: torch.Tensor
     query_start_loc: torch.Tensor
     num_seqs: torch.Tensor
+    num_kv_update_slices: torch.Tensor
     num_slices_per_kv_cache_update_block: int
 
 
@@ -219,7 +220,8 @@ class PallasAttentionBackendImpl(AttentionImpl):
             slot_mapping = attn_metadata.slot_mapping
             write_to_kv_cache(
                 key, value, kv_cache, slot_mapping,
-                attn_metadata.num_slices_per_kv_cache_update_block)
+                attn_metadata.num_slices_per_kv_cache_update_block,
+                attn_metadata.num_kv_update_slices)
 
         output = torch.ops.xla.ragged_paged_attention(
             query,
@@ -252,6 +254,7 @@ def write_to_kv_cache(
     kv_cache: torch.Tensor,
     slot_mapping: torch.Tensor,
     num_slices_per_kv_cache_update_block: int,
+    num_kv_update_slices: torch.Tensor,
 ) -> None:
     """ Write the key and values to the KV cache.
 
@@ -271,7 +274,7 @@ def write_to_kv_cache(
     kv_cache = kv_cache.flatten(0, 1)
 
     new_kv_cache = torch.ops.xla.kv_cache_update_op(
-        kv, slot_mapping, kv_cache, page_size,
+        kv, slot_mapping, kv_cache, num_kv_update_slices, page_size,
         num_slices_per_kv_cache_update_block)
     # NOTE: the in-place copy will be optimized away by XLA compiler.
     kv_cache.copy_(new_kv_cache)
@@ -279,32 +282,39 @@ def write_to_kv_cache(
 
 @requires_jax
 def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor,
-                            kv_cache: torch.Tensor, page_size: int,
+                            kv_cache: torch.Tensor,
+                            num_kv_update_slices: torch.Tensor, page_size: int,
                             num_slices_per_block: int):
     from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update
-    new_kv_cache = xb.call_jax(kv_cache_update, (kv, slot_mapping, kv_cache), {
-        "page_size": page_size,
-        "num_slices_per_block": num_slices_per_block
-    })
+    new_kv_cache = xb.call_jax(
+        kv_cache_update, (kv, slot_mapping, kv_cache, num_kv_update_slices), {
+            "page_size": page_size,
+            "num_slices_per_block": num_slices_per_block
+        })
     return new_kv_cache
 
 
 XLA_LIB.define(
-    "kv_cache_update_op(Tensor kv, Tensor slot_mapping, Tensor kv_cache, "
-    "int page_size, int num_slices_per_block) -> Tensor", )
+    "kv_cache_update_op(Tensor kv, Tensor slot_mapping, Tensor kv_cache," \
+    "Tensor num_kv_update_slices, int page_size, int num_slices_per_block)" \
+    "-> Tensor", )
 
 
 @impl(XLA_LIB, "kv_cache_update_op", "XLA")
 def kv_cache_update_op_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
-                           kv_cache: torch.Tensor, page_size: int,
+                           kv_cache: torch.Tensor,
+                           num_kv_update_slices: torch.Tensor, page_size: int,
                            num_slices_per_block: int) -> torch.Tensor:
     new_kv_cache = kv_cache_update_op_impl(kv, slot_mapping, kv_cache,
-                                           page_size, num_slices_per_block)
+                                           num_kv_update_slices, page_size,
+                                           num_slices_per_block)
     return new_kv_cache
 
 
 @impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd")
 def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
-                               kv_cache: torch.Tensor, page_size: int,
+                               kv_cache: torch.Tensor,
+                               num_kv_update_slices: torch.Tensor,
+                               page_size: int,
                                num_slices_per_block: int) -> torch.Tensor:
     return kv_cache
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index 0cc218bdb646f..f5f26d8fff98a 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -713,8 +713,10 @@ class TPUModelRunner(LoRAModelRunnerMixin):
                                        self.device)
         block_tables = block_tables.to(self.device)
 
+        # Calculate the slot mapping
         slot_mapping_metadata = self._get_slot_mapping_metadata(
             num_reqs, num_scheduled_tokens_per_req)
+        num_kv_update_slices = slot_mapping_metadata.shape[0]
         padded_num_slices = _get_padded_num_kv_cache_update_slices(
             padded_total_num_scheduled_tokens, self.max_num_reqs,
             self.block_size)
@@ -745,6 +747,9 @@ class TPUModelRunner(LoRAModelRunnerMixin):
             num_seqs=torch.tensor([num_reqs],
                                   dtype=torch.int32,
                                   device=self.device),
+            num_kv_update_slices=torch.tensor([num_kv_update_slices],
+                                              dtype=torch.int32,
+                                              device=self.device),
             num_slices_per_kv_cache_update_block=
             NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK,
         )
@@ -1174,6 +1179,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
                                        dtype=torch.int32).to(self.device)
         padded_num_slices = _get_padded_num_kv_cache_update_slices(
             num_tokens, self.max_num_reqs, self.block_size)
+        num_kv_update_slices = torch.tensor([padded_num_slices],
+                                            dtype=torch.int32).to(self.device)
         slot_mapping = torch.zeros((3, padded_num_slices),
                                    dtype=torch.int32).to(self.device)
         block_tables = torch.zeros((num_reqs, num_blocks),
@@ -1193,6 +1200,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
             context_lens=context_lens,
             query_start_loc=query_start_loc,
             num_seqs=num_seqs,
+            num_kv_update_slices=num_kv_update_slices,
             num_slices_per_kv_cache_update_block=
             NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK,
         )
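
Illustrative usage (not part of the patch): a minimal sketch of how a caller passes the new num_kv_update_slices tensor to the updated op, mirroring tests/v1/tpu/test_kv_cache_update_kernel.py. It assumes a TPU host with torch_xla installed and that importing vllm.v1.attention.backends.pallas registers kv_cache_update_op; the shapes, dtypes, and the two-slice slot mapping below are made-up example values.

# Sketch only: example shapes/values are assumptions, not requirements.
import torch
import torch_xla

import vllm.v1.attention.backends.pallas  # noqa: F401  (module import registers the XLA op)

page_size, num_combined_kv_heads, head_dim = 32, 8, 128
num_slices_per_block = 8

new_kv = torch.zeros(64, num_combined_kv_heads, head_dim,
                     dtype=torch.bfloat16).to(torch_xla.device())
kv_cache = torch.zeros(16 * page_size, num_combined_kv_heads, head_dim,
                       dtype=torch.bfloat16).to(torch_xla.device())

# Rows: (kv_cache_start, new_kv_start, slice_len); the slot mapping is padded
# to a fixed width, and only the first num_kv_update_slices columns are real.
slot_mapping = torch.tensor(
    [[0, 32, 0, 0, 0, 0, 0, 0],
     [0, 32, 0, 0, 0, 0, 0, 0],
     [32, 32, 0, 0, 0, 0, 0, 0]],
    dtype=torch.int32).to(torch_xla.device())
num_kv_update_slices = torch.tensor([2], dtype=torch.int32,
                                    device=torch_xla.device())

new_kv_cache = torch.ops.xla.kv_cache_update_op(
    new_kv, slot_mapping, kv_cache, num_kv_update_slices, page_size,
    num_slices_per_block)
kv_cache.copy_(new_kv_cache)
torch_xla.sync()

With this change the kernel grid is sized from the real slice count, cdiv(num_kv_update_slices[0], num_slices_per_block), instead of the padded width of slot_mapping, so padding columns no longer cost extra grid steps.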