[TPU] kv cache update kernel supports dynamic grid (#20235)

Signed-off-by: Chengji Yao <chengjiyao@google.com>
This commit is contained in:
Chengji Yao 2025-07-01 23:33:37 -07:00 committed by GitHub
parent b205e8467d
commit 7da296be04
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 42 additions and 17 deletions

View File

@ -32,6 +32,7 @@ def test_kv_cache_update_kernel(page_size: int, combined_kv_head_num: int,
new_kv_xla = new_kv_cpu.to(torch_xla.device()) new_kv_xla = new_kv_cpu.to(torch_xla.device())
slice_lens = np.array([7, page_size, page_size, 1, 1, 1, 9], slice_lens = np.array([7, page_size, page_size, 1, 1, 1, 9],
dtype=np.int32) dtype=np.int32)
num_kv_update_slices = len(slice_lens)
kv_cache_start_indices = np.array([ kv_cache_start_indices = np.array([
page_size * 2 - 7, page_size * 2, page_size * 3, page_size * 4 + 6, page_size * 2 - 7, page_size * 2, page_size * 3, page_size * 4 + 6,
page_size * 5 + 7, page_size * 6 + 8, page_size * 15 + 3 page_size * 5 + 7, page_size * 6 + 8, page_size * 15 + 3
@ -52,12 +53,15 @@ def test_kv_cache_update_kernel(page_size: int, combined_kv_head_num: int,
device="cpu", device="cpu",
dtype=torch.int32) dtype=torch.int32)
slot_mapping_xla = slot_mapping_cpu.to(torch_xla.device()) slot_mapping_xla = slot_mapping_cpu.to(torch_xla.device())
num_kv_update_slices_xla = torch.tensor([num_kv_update_slices],
device=torch_xla.device(),
dtype=torch.int32)
torch_xla.sync() torch_xla.sync()
torch.ops.xla.dynamo_set_buffer_donor_(kv_cache_xla, True) torch.ops.xla.dynamo_set_buffer_donor_(kv_cache_xla, True)
new_kv_cache_xla = torch.ops.xla.kv_cache_update_op( new_kv_cache_xla = torch.ops.xla.kv_cache_update_op(
new_kv_xla, slot_mapping_xla, kv_cache_xla, page_size, new_kv_xla, slot_mapping_xla, kv_cache_xla, num_kv_update_slices_xla,
num_slices_per_block) page_size, num_slices_per_block)
kv_cache_xla.copy_(new_kv_cache_xla) kv_cache_xla.copy_(new_kv_cache_xla)
torch_xla.sync() torch_xla.sync()

View File

@ -7,11 +7,13 @@ import jax
from jax.experimental import pallas as pl from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu from jax.experimental.pallas import tpu as pltpu
from vllm.utils import cdiv
def _kv_cache_update_kernel( def _kv_cache_update_kernel(
# Prefetch # Prefetch
slices_ref, # [3, num_slices], list of (kv_cache_start, new_kv_start, slices_ref, # [3, padded_num_slices], list of (kv_cache_start,
# slice_len) # new_kv_start, slice_len)
# Input # Input
new_kv_hbm_ref, # [num_tokens, num_combined_kv_heads, head_dim] new_kv_hbm_ref, # [num_tokens, num_combined_kv_heads, head_dim]
kv_cache_hbm_ref, # [total_num_pages * page_size, num_combined_kv_heads, kv_cache_hbm_ref, # [total_num_pages * page_size, num_combined_kv_heads,
@ -70,6 +72,7 @@ def kv_cache_update(
Array, # [3, slices], list of (kv_cache_start, new_kv_start, slice_len) Array, # [3, slices], list of (kv_cache_start, new_kv_start, slice_len)
kv_cache: jax. kv_cache: jax.
Array, # [total_num_pages * page_size, num_combined_kv_heads, head_dim] Array, # [total_num_pages * page_size, num_combined_kv_heads, head_dim]
num_kv_update_slices: jax.Array, # [1]
*, *,
page_size: int = 32, page_size: int = 32,
num_slices_per_block: int = 8, num_slices_per_block: int = 8,
@ -107,7 +110,7 @@ def kv_cache_update(
num_scalar_prefetch=len(scalar_prefetches), num_scalar_prefetch=len(scalar_prefetches),
in_specs=in_specs, in_specs=in_specs,
out_specs=out_specs, out_specs=out_specs,
grid=(slices.shape[1] // num_slices_per_block, ), grid=(cdiv(num_kv_update_slices[0], num_slices_per_block), ),
scratch_shapes=scratch_shapes, scratch_shapes=scratch_shapes,
), ),
out_shape=out_shape, out_shape=out_shape,

View File

@ -111,6 +111,7 @@ class PallasMetadata:
context_lens: torch.Tensor context_lens: torch.Tensor
query_start_loc: torch.Tensor query_start_loc: torch.Tensor
num_seqs: torch.Tensor num_seqs: torch.Tensor
num_kv_update_slices: torch.Tensor
num_slices_per_kv_cache_update_block: int num_slices_per_kv_cache_update_block: int
@ -219,7 +220,8 @@ class PallasAttentionBackendImpl(AttentionImpl):
slot_mapping = attn_metadata.slot_mapping slot_mapping = attn_metadata.slot_mapping
write_to_kv_cache( write_to_kv_cache(
key, value, kv_cache, slot_mapping, key, value, kv_cache, slot_mapping,
attn_metadata.num_slices_per_kv_cache_update_block) attn_metadata.num_slices_per_kv_cache_update_block,
attn_metadata.num_kv_update_slices)
output = torch.ops.xla.ragged_paged_attention( output = torch.ops.xla.ragged_paged_attention(
query, query,
@ -252,6 +254,7 @@ def write_to_kv_cache(
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
slot_mapping: torch.Tensor, slot_mapping: torch.Tensor,
num_slices_per_kv_cache_update_block: int, num_slices_per_kv_cache_update_block: int,
num_kv_update_slices: torch.Tensor,
) -> None: ) -> None:
""" Write the key and values to the KV cache. """ Write the key and values to the KV cache.
@ -271,7 +274,7 @@ def write_to_kv_cache(
kv_cache = kv_cache.flatten(0, 1) kv_cache = kv_cache.flatten(0, 1)
new_kv_cache = torch.ops.xla.kv_cache_update_op( new_kv_cache = torch.ops.xla.kv_cache_update_op(
kv, slot_mapping, kv_cache, page_size, kv, slot_mapping, kv_cache, num_kv_update_slices, page_size,
num_slices_per_kv_cache_update_block) num_slices_per_kv_cache_update_block)
# NOTE: the in-place copy will be optimized away by XLA compiler. # NOTE: the in-place copy will be optimized away by XLA compiler.
kv_cache.copy_(new_kv_cache) kv_cache.copy_(new_kv_cache)
@ -279,32 +282,39 @@ def write_to_kv_cache(
@requires_jax @requires_jax
def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor, def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor,
kv_cache: torch.Tensor, page_size: int, kv_cache: torch.Tensor,
num_kv_update_slices: torch.Tensor, page_size: int,
num_slices_per_block: int): num_slices_per_block: int):
from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update
new_kv_cache = xb.call_jax(kv_cache_update, (kv, slot_mapping, kv_cache), { new_kv_cache = xb.call_jax(
"page_size": page_size, kv_cache_update, (kv, slot_mapping, kv_cache, num_kv_update_slices), {
"num_slices_per_block": num_slices_per_block "page_size": page_size,
}) "num_slices_per_block": num_slices_per_block
})
return new_kv_cache return new_kv_cache
XLA_LIB.define( XLA_LIB.define(
"kv_cache_update_op(Tensor kv, Tensor slot_mapping, Tensor kv_cache, " "kv_cache_update_op(Tensor kv, Tensor slot_mapping, Tensor kv_cache," \
"int page_size, int num_slices_per_block) -> Tensor", ) "Tensor num_kv_update_slices, int page_size, int num_slices_per_block)" \
"-> Tensor", )
@impl(XLA_LIB, "kv_cache_update_op", "XLA") @impl(XLA_LIB, "kv_cache_update_op", "XLA")
def kv_cache_update_op_xla(kv: torch.Tensor, slot_mapping: torch.Tensor, def kv_cache_update_op_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
kv_cache: torch.Tensor, page_size: int, kv_cache: torch.Tensor,
num_kv_update_slices: torch.Tensor, page_size: int,
num_slices_per_block: int) -> torch.Tensor: num_slices_per_block: int) -> torch.Tensor:
new_kv_cache = kv_cache_update_op_impl(kv, slot_mapping, kv_cache, new_kv_cache = kv_cache_update_op_impl(kv, slot_mapping, kv_cache,
page_size, num_slices_per_block) num_kv_update_slices, page_size,
num_slices_per_block)
return new_kv_cache return new_kv_cache
@impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd") @impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd")
def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor, def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
kv_cache: torch.Tensor, page_size: int, kv_cache: torch.Tensor,
num_kv_update_slices: torch.Tensor,
page_size: int,
num_slices_per_block: int) -> torch.Tensor: num_slices_per_block: int) -> torch.Tensor:
return kv_cache return kv_cache

View File

@ -713,8 +713,10 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self.device) self.device)
block_tables = block_tables.to(self.device) block_tables = block_tables.to(self.device)
# Calculate the slot mapping
slot_mapping_metadata = self._get_slot_mapping_metadata( slot_mapping_metadata = self._get_slot_mapping_metadata(
num_reqs, num_scheduled_tokens_per_req) num_reqs, num_scheduled_tokens_per_req)
num_kv_update_slices = slot_mapping_metadata.shape[0]
padded_num_slices = _get_padded_num_kv_cache_update_slices( padded_num_slices = _get_padded_num_kv_cache_update_slices(
padded_total_num_scheduled_tokens, self.max_num_reqs, padded_total_num_scheduled_tokens, self.max_num_reqs,
self.block_size) self.block_size)
@ -745,6 +747,9 @@ class TPUModelRunner(LoRAModelRunnerMixin):
num_seqs=torch.tensor([num_reqs], num_seqs=torch.tensor([num_reqs],
dtype=torch.int32, dtype=torch.int32,
device=self.device), device=self.device),
num_kv_update_slices=torch.tensor([num_kv_update_slices],
dtype=torch.int32,
device=self.device),
num_slices_per_kv_cache_update_block= num_slices_per_kv_cache_update_block=
NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK, NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK,
) )
@ -1174,6 +1179,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
dtype=torch.int32).to(self.device) dtype=torch.int32).to(self.device)
padded_num_slices = _get_padded_num_kv_cache_update_slices( padded_num_slices = _get_padded_num_kv_cache_update_slices(
num_tokens, self.max_num_reqs, self.block_size) num_tokens, self.max_num_reqs, self.block_size)
num_kv_update_slices = torch.tensor([padded_num_slices],
dtype=torch.int32).to(self.device)
slot_mapping = torch.zeros((3, padded_num_slices), slot_mapping = torch.zeros((3, padded_num_slices),
dtype=torch.int32).to(self.device) dtype=torch.int32).to(self.device)
block_tables = torch.zeros((num_reqs, num_blocks), block_tables = torch.zeros((num_reqs, num_blocks),
@ -1193,6 +1200,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
context_lens=context_lens, context_lens=context_lens,
query_start_loc=query_start_loc, query_start_loc=query_start_loc,
num_seqs=num_seqs, num_seqs=num_seqs,
num_kv_update_slices=num_kv_update_slices,
num_slices_per_kv_cache_update_block= num_slices_per_kv_cache_update_block=
NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK, NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK,
) )