mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-11 01:31:25 +08:00
[TPU] kv cache update kernel supports dynamic grid (#20235)
Signed-off-by: Chengji Yao <chengjiyao@google.com>
This commit is contained in:
parent
b205e8467d
commit
7da296be04
@ -32,6 +32,7 @@ def test_kv_cache_update_kernel(page_size: int, combined_kv_head_num: int,
|
|||||||
new_kv_xla = new_kv_cpu.to(torch_xla.device())
|
new_kv_xla = new_kv_cpu.to(torch_xla.device())
|
||||||
slice_lens = np.array([7, page_size, page_size, 1, 1, 1, 9],
|
slice_lens = np.array([7, page_size, page_size, 1, 1, 1, 9],
|
||||||
dtype=np.int32)
|
dtype=np.int32)
|
||||||
|
num_kv_update_slices = len(slice_lens)
|
||||||
kv_cache_start_indices = np.array([
|
kv_cache_start_indices = np.array([
|
||||||
page_size * 2 - 7, page_size * 2, page_size * 3, page_size * 4 + 6,
|
page_size * 2 - 7, page_size * 2, page_size * 3, page_size * 4 + 6,
|
||||||
page_size * 5 + 7, page_size * 6 + 8, page_size * 15 + 3
|
page_size * 5 + 7, page_size * 6 + 8, page_size * 15 + 3
|
||||||
@ -52,12 +53,15 @@ def test_kv_cache_update_kernel(page_size: int, combined_kv_head_num: int,
|
|||||||
device="cpu",
|
device="cpu",
|
||||||
dtype=torch.int32)
|
dtype=torch.int32)
|
||||||
slot_mapping_xla = slot_mapping_cpu.to(torch_xla.device())
|
slot_mapping_xla = slot_mapping_cpu.to(torch_xla.device())
|
||||||
|
num_kv_update_slices_xla = torch.tensor([num_kv_update_slices],
|
||||||
|
device=torch_xla.device(),
|
||||||
|
dtype=torch.int32)
|
||||||
torch_xla.sync()
|
torch_xla.sync()
|
||||||
|
|
||||||
torch.ops.xla.dynamo_set_buffer_donor_(kv_cache_xla, True)
|
torch.ops.xla.dynamo_set_buffer_donor_(kv_cache_xla, True)
|
||||||
new_kv_cache_xla = torch.ops.xla.kv_cache_update_op(
|
new_kv_cache_xla = torch.ops.xla.kv_cache_update_op(
|
||||||
new_kv_xla, slot_mapping_xla, kv_cache_xla, page_size,
|
new_kv_xla, slot_mapping_xla, kv_cache_xla, num_kv_update_slices_xla,
|
||||||
num_slices_per_block)
|
page_size, num_slices_per_block)
|
||||||
kv_cache_xla.copy_(new_kv_cache_xla)
|
kv_cache_xla.copy_(new_kv_cache_xla)
|
||||||
torch_xla.sync()
|
torch_xla.sync()
|
||||||
|
|
||||||
|
|||||||
@ -7,11 +7,13 @@ import jax
|
|||||||
from jax.experimental import pallas as pl
|
from jax.experimental import pallas as pl
|
||||||
from jax.experimental.pallas import tpu as pltpu
|
from jax.experimental.pallas import tpu as pltpu
|
||||||
|
|
||||||
|
from vllm.utils import cdiv
|
||||||
|
|
||||||
|
|
||||||
def _kv_cache_update_kernel(
|
def _kv_cache_update_kernel(
|
||||||
# Prefetch
|
# Prefetch
|
||||||
slices_ref, # [3, num_slices], list of (kv_cache_start, new_kv_start,
|
slices_ref, # [3, padded_num_slices], list of (kv_cache_start,
|
||||||
# slice_len)
|
# new_kv_start, slice_len)
|
||||||
# Input
|
# Input
|
||||||
new_kv_hbm_ref, # [num_tokens, num_combined_kv_heads, head_dim]
|
new_kv_hbm_ref, # [num_tokens, num_combined_kv_heads, head_dim]
|
||||||
kv_cache_hbm_ref, # [total_num_pages * page_size, num_combined_kv_heads,
|
kv_cache_hbm_ref, # [total_num_pages * page_size, num_combined_kv_heads,
|
||||||
@ -70,6 +72,7 @@ def kv_cache_update(
|
|||||||
Array, # [3, slices], list of (kv_cache_start, new_kv_start, slice_len)
|
Array, # [3, slices], list of (kv_cache_start, new_kv_start, slice_len)
|
||||||
kv_cache: jax.
|
kv_cache: jax.
|
||||||
Array, # [total_num_pages * page_size, num_combined_kv_heads, head_dim]
|
Array, # [total_num_pages * page_size, num_combined_kv_heads, head_dim]
|
||||||
|
num_kv_update_slices: jax.Array, # [1]
|
||||||
*,
|
*,
|
||||||
page_size: int = 32,
|
page_size: int = 32,
|
||||||
num_slices_per_block: int = 8,
|
num_slices_per_block: int = 8,
|
||||||
@ -107,7 +110,7 @@ def kv_cache_update(
|
|||||||
num_scalar_prefetch=len(scalar_prefetches),
|
num_scalar_prefetch=len(scalar_prefetches),
|
||||||
in_specs=in_specs,
|
in_specs=in_specs,
|
||||||
out_specs=out_specs,
|
out_specs=out_specs,
|
||||||
grid=(slices.shape[1] // num_slices_per_block, ),
|
grid=(cdiv(num_kv_update_slices[0], num_slices_per_block), ),
|
||||||
scratch_shapes=scratch_shapes,
|
scratch_shapes=scratch_shapes,
|
||||||
),
|
),
|
||||||
out_shape=out_shape,
|
out_shape=out_shape,
|
||||||
|
|||||||
@ -111,6 +111,7 @@ class PallasMetadata:
|
|||||||
context_lens: torch.Tensor
|
context_lens: torch.Tensor
|
||||||
query_start_loc: torch.Tensor
|
query_start_loc: torch.Tensor
|
||||||
num_seqs: torch.Tensor
|
num_seqs: torch.Tensor
|
||||||
|
num_kv_update_slices: torch.Tensor
|
||||||
num_slices_per_kv_cache_update_block: int
|
num_slices_per_kv_cache_update_block: int
|
||||||
|
|
||||||
|
|
||||||
@ -219,7 +220,8 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
|||||||
slot_mapping = attn_metadata.slot_mapping
|
slot_mapping = attn_metadata.slot_mapping
|
||||||
write_to_kv_cache(
|
write_to_kv_cache(
|
||||||
key, value, kv_cache, slot_mapping,
|
key, value, kv_cache, slot_mapping,
|
||||||
attn_metadata.num_slices_per_kv_cache_update_block)
|
attn_metadata.num_slices_per_kv_cache_update_block,
|
||||||
|
attn_metadata.num_kv_update_slices)
|
||||||
|
|
||||||
output = torch.ops.xla.ragged_paged_attention(
|
output = torch.ops.xla.ragged_paged_attention(
|
||||||
query,
|
query,
|
||||||
@ -252,6 +254,7 @@ def write_to_kv_cache(
|
|||||||
kv_cache: torch.Tensor,
|
kv_cache: torch.Tensor,
|
||||||
slot_mapping: torch.Tensor,
|
slot_mapping: torch.Tensor,
|
||||||
num_slices_per_kv_cache_update_block: int,
|
num_slices_per_kv_cache_update_block: int,
|
||||||
|
num_kv_update_slices: torch.Tensor,
|
||||||
) -> None:
|
) -> None:
|
||||||
""" Write the key and values to the KV cache.
|
""" Write the key and values to the KV cache.
|
||||||
|
|
||||||
@ -271,7 +274,7 @@ def write_to_kv_cache(
|
|||||||
|
|
||||||
kv_cache = kv_cache.flatten(0, 1)
|
kv_cache = kv_cache.flatten(0, 1)
|
||||||
new_kv_cache = torch.ops.xla.kv_cache_update_op(
|
new_kv_cache = torch.ops.xla.kv_cache_update_op(
|
||||||
kv, slot_mapping, kv_cache, page_size,
|
kv, slot_mapping, kv_cache, num_kv_update_slices, page_size,
|
||||||
num_slices_per_kv_cache_update_block)
|
num_slices_per_kv_cache_update_block)
|
||||||
# NOTE: the in-place copy will be optimized away by XLA compiler.
|
# NOTE: the in-place copy will be optimized away by XLA compiler.
|
||||||
kv_cache.copy_(new_kv_cache)
|
kv_cache.copy_(new_kv_cache)
|
||||||
@ -279,32 +282,39 @@ def write_to_kv_cache(
|
|||||||
|
|
||||||
@requires_jax
|
@requires_jax
|
||||||
def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor,
|
def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor,
|
||||||
kv_cache: torch.Tensor, page_size: int,
|
kv_cache: torch.Tensor,
|
||||||
|
num_kv_update_slices: torch.Tensor, page_size: int,
|
||||||
num_slices_per_block: int):
|
num_slices_per_block: int):
|
||||||
from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update
|
from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update
|
||||||
new_kv_cache = xb.call_jax(kv_cache_update, (kv, slot_mapping, kv_cache), {
|
new_kv_cache = xb.call_jax(
|
||||||
"page_size": page_size,
|
kv_cache_update, (kv, slot_mapping, kv_cache, num_kv_update_slices), {
|
||||||
"num_slices_per_block": num_slices_per_block
|
"page_size": page_size,
|
||||||
})
|
"num_slices_per_block": num_slices_per_block
|
||||||
|
})
|
||||||
return new_kv_cache
|
return new_kv_cache
|
||||||
|
|
||||||
|
|
||||||
XLA_LIB.define(
|
XLA_LIB.define(
|
||||||
"kv_cache_update_op(Tensor kv, Tensor slot_mapping, Tensor kv_cache, "
|
"kv_cache_update_op(Tensor kv, Tensor slot_mapping, Tensor kv_cache," \
|
||||||
"int page_size, int num_slices_per_block) -> Tensor", )
|
"Tensor num_kv_update_slices, int page_size, int num_slices_per_block)" \
|
||||||
|
"-> Tensor", )
|
||||||
|
|
||||||
|
|
||||||
@impl(XLA_LIB, "kv_cache_update_op", "XLA")
|
@impl(XLA_LIB, "kv_cache_update_op", "XLA")
|
||||||
def kv_cache_update_op_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
|
def kv_cache_update_op_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
|
||||||
kv_cache: torch.Tensor, page_size: int,
|
kv_cache: torch.Tensor,
|
||||||
|
num_kv_update_slices: torch.Tensor, page_size: int,
|
||||||
num_slices_per_block: int) -> torch.Tensor:
|
num_slices_per_block: int) -> torch.Tensor:
|
||||||
new_kv_cache = kv_cache_update_op_impl(kv, slot_mapping, kv_cache,
|
new_kv_cache = kv_cache_update_op_impl(kv, slot_mapping, kv_cache,
|
||||||
page_size, num_slices_per_block)
|
num_kv_update_slices, page_size,
|
||||||
|
num_slices_per_block)
|
||||||
return new_kv_cache
|
return new_kv_cache
|
||||||
|
|
||||||
|
|
||||||
@impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd")
|
@impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd")
|
||||||
def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
|
def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
|
||||||
kv_cache: torch.Tensor, page_size: int,
|
kv_cache: torch.Tensor,
|
||||||
|
num_kv_update_slices: torch.Tensor,
|
||||||
|
page_size: int,
|
||||||
num_slices_per_block: int) -> torch.Tensor:
|
num_slices_per_block: int) -> torch.Tensor:
|
||||||
return kv_cache
|
return kv_cache
|
||||||
|
|||||||
@ -713,8 +713,10 @@ class TPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
self.device)
|
self.device)
|
||||||
block_tables = block_tables.to(self.device)
|
block_tables = block_tables.to(self.device)
|
||||||
|
|
||||||
|
# Calculate the slot mapping
|
||||||
slot_mapping_metadata = self._get_slot_mapping_metadata(
|
slot_mapping_metadata = self._get_slot_mapping_metadata(
|
||||||
num_reqs, num_scheduled_tokens_per_req)
|
num_reqs, num_scheduled_tokens_per_req)
|
||||||
|
num_kv_update_slices = slot_mapping_metadata.shape[0]
|
||||||
padded_num_slices = _get_padded_num_kv_cache_update_slices(
|
padded_num_slices = _get_padded_num_kv_cache_update_slices(
|
||||||
padded_total_num_scheduled_tokens, self.max_num_reqs,
|
padded_total_num_scheduled_tokens, self.max_num_reqs,
|
||||||
self.block_size)
|
self.block_size)
|
||||||
@ -745,6 +747,9 @@ class TPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
num_seqs=torch.tensor([num_reqs],
|
num_seqs=torch.tensor([num_reqs],
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device=self.device),
|
device=self.device),
|
||||||
|
num_kv_update_slices=torch.tensor([num_kv_update_slices],
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.device),
|
||||||
num_slices_per_kv_cache_update_block=
|
num_slices_per_kv_cache_update_block=
|
||||||
NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK,
|
NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK,
|
||||||
)
|
)
|
||||||
@ -1174,6 +1179,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
dtype=torch.int32).to(self.device)
|
dtype=torch.int32).to(self.device)
|
||||||
padded_num_slices = _get_padded_num_kv_cache_update_slices(
|
padded_num_slices = _get_padded_num_kv_cache_update_slices(
|
||||||
num_tokens, self.max_num_reqs, self.block_size)
|
num_tokens, self.max_num_reqs, self.block_size)
|
||||||
|
num_kv_update_slices = torch.tensor([padded_num_slices],
|
||||||
|
dtype=torch.int32).to(self.device)
|
||||||
slot_mapping = torch.zeros((3, padded_num_slices),
|
slot_mapping = torch.zeros((3, padded_num_slices),
|
||||||
dtype=torch.int32).to(self.device)
|
dtype=torch.int32).to(self.device)
|
||||||
block_tables = torch.zeros((num_reqs, num_blocks),
|
block_tables = torch.zeros((num_reqs, num_blocks),
|
||||||
@ -1193,6 +1200,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
context_lens=context_lens,
|
context_lens=context_lens,
|
||||||
query_start_loc=query_start_loc,
|
query_start_loc=query_start_loc,
|
||||||
num_seqs=num_seqs,
|
num_seqs=num_seqs,
|
||||||
|
num_kv_update_slices=num_kv_update_slices,
|
||||||
num_slices_per_kv_cache_update_block=
|
num_slices_per_kv_cache_update_block=
|
||||||
NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK,
|
NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK,
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user