[TPU] add kv cache update kernel (#19928)
Signed-off-by: Chengji Yao <chengjiyao@google.com>
parent b69781f107, commit 04e1642e32
@@ -159,6 +159,8 @@ run_and_track_test 14 "test_tpu_qkv_linear.py" \
     "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_tpu_qkv_linear.py"
 run_and_track_test 15 "test_spmd_model_weight_loading.py" \
     "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_spmd_model_weight_loading.py"
+run_and_track_test 16 "test_kv_cache_update_kernel.py" \
+    "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_kv_cache_update_kernel.py"

 # After all tests have been attempted, exit with the overall status.
 if [ "$overall_script_exit_code" -ne 0 ]; then
tests/v1/tpu/test_kv_cache_update_kernel.py (new file, 71 lines added)
@@ -0,0 +1,71 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import numpy as np
+import pytest
+import torch
+import torch_xla
+
+import vllm.v1.attention.backends.pallas  # noqa: F401
+from vllm.platforms import current_platform
+
+
+@pytest.mark.skipif(not current_platform.is_tpu(),
+                    reason="This is a test for TPU only")
+@pytest.mark.parametrize("page_size", [32, 33])
+@pytest.mark.parametrize("combined_kv_head_num", [2, 16])
+@pytest.mark.parametrize("head_dim", [128, 256])
+@pytest.mark.parametrize("num_slices_per_block", [4, 8])
+def test_kv_cache_update_kernel(page_size: int, combined_kv_head_num: int,
+                                head_dim: int, num_slices_per_block: int):
+    page_num = 1000
+    padded_num_tokens = 128
+    kv_cache_cpu = torch.zeros(
+        (page_num * page_size, combined_kv_head_num, head_dim),
+        dtype=torch.bfloat16,
+        device="cpu")
+    kv_cache_xla = kv_cache_cpu.to(torch_xla.device())
+    new_kv_cpu = torch.randn(
+        (padded_num_tokens, combined_kv_head_num, head_dim),
+        dtype=torch.bfloat16,
+        device="cpu")
+    new_kv_xla = new_kv_cpu.to(torch_xla.device())
+    slice_lens = np.array([7, page_size, page_size, 1, 1, 1, 9],
+                          dtype=np.int32)
+    kv_cache_start_indices = np.array([
+        page_size * 2 - 7, page_size * 2, page_size * 3, page_size * 4 + 6,
+        page_size * 5 + 7, page_size * 6 + 8, page_size * 15 + 3
+    ],
+                                      dtype=np.int32)
+    new_kv_cache_indices = np.concatenate(
+        [np.array([0], dtype=np.int32),
+         np.cumsum(slice_lens[:-1])])
+    slot_mapping = np.stack(
+        [kv_cache_start_indices, new_kv_cache_indices, slice_lens], axis=1)
+    padded_size = (slot_mapping.shape[0] + num_slices_per_block -
+                   1) // num_slices_per_block * num_slices_per_block
+    slot_mapping = np.pad(slot_mapping,
+                          [[0, padded_size - slot_mapping.shape[0]], [0, 0]],
+                          constant_values=0)
+    slot_mapping = np.transpose(slot_mapping)
+    slot_mapping_cpu = torch.tensor(slot_mapping,
+                                    device="cpu",
+                                    dtype=torch.int32)
+    slot_mapping_xla = slot_mapping_cpu.to(torch_xla.device())
+    torch_xla.sync()
+
+    torch.ops.xla.dynamo_set_buffer_donor_(kv_cache_xla, True)
+    new_kv_cache_xla = torch.ops.xla.kv_cache_update_op(
+        new_kv_xla, slot_mapping_xla, kv_cache_xla, page_size,
+        num_slices_per_block)
+    kv_cache_xla.copy_(new_kv_cache_xla)
+    torch_xla.sync()
+
+    for ni, ci, sl in zip(new_kv_cache_indices, kv_cache_start_indices,
+                          slice_lens):
+        kv_cache_cpu[ci:ci + sl, :, :] = new_kv_cpu[ni:ni + sl, :, :]
+
+    assert torch.allclose(kv_cache_xla.cpu(),
+                          kv_cache_cpu,
+                          atol=1e-4,
+                          rtol=1e-4)
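The test above checks the TPU op against plain slice copies on the CPU tensors. As a reading aid, here is a minimal NumPy sketch (not the TPU kernel and not part of the commit) of the semantics encoded by each slot_mapping column:

import numpy as np

def reference_kv_cache_update(kv_cache: np.ndarray, new_kv: np.ndarray,
                              slot_mapping: np.ndarray) -> np.ndarray:
    # slot_mapping: [3, num_slices]; rows are (kv_cache_start, new_kv_start,
    # slice_len). Zero-length padding columns are no-ops.
    out = kv_cache.copy()
    for kv_start, new_start, length in slot_mapping.T:
        out[kv_start:kv_start + length] = new_kv[new_start:new_start + length]
    return out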
@@ -47,7 +47,7 @@ def test_ragged_paged_attention():
     key = torch.zeros(num_tokens, num_kv_heads * head_size)
     value = torch.zeros(num_tokens, num_kv_heads * head_size)
     kv_cache = torch.zeros(num_blocks, block_size, num_kv_heads * 2, head_size)
-    slot_mapping = torch.zeros(num_tokens, dtype=torch.int64)
+    slot_mapping = torch.zeros((3, num_tokens), dtype=torch.int64)
     max_num_reqs = 8
     max_num_blocks_per_req = 8
     block_tables = torch.zeros((max_num_reqs, max_num_blocks_per_req),
@@ -65,6 +65,7 @@ def test_ragged_paged_attention():
         context_lens=context_lens,
         query_start_loc=query_start_loc,
         num_seqs=num_seqs,
+        num_slices_per_kv_cache_update_block=8,
     )

     with patch("torch.ops.xla.ragged_paged_attention"
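The slot_mapping placeholder above switches from one flat slot index per token to the new three-row layout, one column per contiguous slice with rows (kv_cache_start, new_kv_start, slice_len). A small illustrative sketch (made-up values, not from the commit) contrasting the two layouts:

import torch

# Old layout: one physical KV-cache slot index per token.
slot_mapping_old = torch.tensor([30, 31, 36, 37, 38], dtype=torch.int64)

# New layout: one column per contiguous slice; rows are
# (kv_cache_start, new_kv_start, slice_len). The same five tokens become
# two slices: slots 30-31 and 36-38.
slot_mapping_new = torch.tensor([[30, 36],
                                 [0, 2],
                                 [2, 3]], dtype=torch.int32)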
vllm/attention/ops/pallas_kv_cache_update.py (new file, 117 lines added)
@@ -0,0 +1,117 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import functools
+
+import jax
+from jax.experimental import pallas as pl
+from jax.experimental.pallas import tpu as pltpu
+
+
+def _kv_cache_update_kernel(
+    # Prefetch
+    slices_ref,  # [3, num_slices], list of (kv_cache_start, new_kv_start,
+    # slice_len)
+    # Input
+    new_kv_hbm_ref,  # [num_tokens, num_combined_kv_heads, head_dim]
+    kv_cache_hbm_ref,  # [total_num_pages * page_size, num_combined_kv_heads,
+    # head_dim]
+    # Output
+    _,  # [total_num_pages * page_size, num_combined_kv_heads, head_dim]
+    # Scratch
+    scratch,  # [num_slices_per_block, page_size, num_combined_kv_heads,
+    # head_dim]
+    sem,
+):
+    async_copies = []
+    block_idx = pl.program_id(0)
+    num_slices_per_block = scratch.shape[0]
+
+    # Copy from new_kv_hbm_ref to scratch
+    for i in range(num_slices_per_block):
+        offset_i = i + block_idx * num_slices_per_block
+        new_kv_start = slices_ref[1, offset_i]
+        length = slices_ref[2, offset_i]
+        async_copy = pltpu.make_async_copy(
+            new_kv_hbm_ref.at[pl.ds(new_kv_start, length), ...],
+            scratch.at[i, pl.ds(0, length), ...],
+            sem,
+        )
+        async_copy.start()
+        async_copies.append(async_copy)
+
+    for async_copy in async_copies:
+        async_copy.wait()
+
+    # Copy from scratch to kv_cache_hbm_ref
+    async_copies.clear()
+    for i in range(num_slices_per_block):
+        offset_i = i + block_idx * num_slices_per_block
+        kv_cache_start = slices_ref[0, offset_i]
+        length = slices_ref[2, offset_i]
+        async_copy = pltpu.make_async_copy(
+            scratch.at[i, pl.ds(0, length), ...],
+            kv_cache_hbm_ref.at[pl.ds(kv_cache_start, length), ...],
+            sem,
+        )
+        async_copy.start()
+        async_copies.append(async_copy)
+    for async_copy in async_copies:
+        async_copy.wait()
+
+
+@functools.partial(
+    jax.jit,
+    static_argnames=["page_size", "num_slices_per_block"],
+)
+def kv_cache_update(
+    new_kv: jax.Array,  # [total_num_token, num_combined_kv_heads, head_dim]
+    slices: jax.
+    Array,  # [3, slices], list of (kv_cache_start, new_kv_start, slice_len)
+    kv_cache: jax.
+    Array,  # [total_num_pages * page_size, num_combined_kv_heads, head_dim]
+    *,
+    page_size: int = 32,
+    num_slices_per_block: int = 8,
+):
+    assert slices.shape[1] % num_slices_per_block == 0
+    _, num_combined_kv_heads, head_dim = new_kv.shape
+    assert kv_cache.shape[1] == num_combined_kv_heads
+    assert kv_cache.shape[2] == head_dim
+    assert head_dim % 128 == 0
+    # TODO: Add dynamic check to make sure that the all the slice lengths are
+    # smaller or equal to page_size
+
+    in_specs = [
+        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
+        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
+    ]
+
+    out_specs = [pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)]
+    out_shape = [jax.ShapeDtypeStruct(kv_cache.shape, dtype=kv_cache.dtype)]
+
+    scalar_prefetches = [slices]
+    scratch = pltpu.VMEM(
+        (num_slices_per_block, page_size, num_combined_kv_heads, head_dim),
+        new_kv.dtype,
+    )
+
+    scratch_shapes = [
+        scratch,
+        pltpu.SemaphoreType.DMA,
+    ]
+
+    kernel = pl.pallas_call(
+        _kv_cache_update_kernel,
+        grid_spec=pltpu.PrefetchScalarGridSpec(
+            num_scalar_prefetch=len(scalar_prefetches),
+            in_specs=in_specs,
+            out_specs=out_specs,
+            grid=(slices.shape[1] // num_slices_per_block, ),
+            scratch_shapes=scratch_shapes,
+        ),
+        out_shape=out_shape,
+        input_output_aliases={len(scalar_prefetches) + 1: 0},
+    )
+
+    return kernel(*scalar_prefetches, new_kv, kv_cache)[0]
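For orientation, a hypothetical invocation of kv_cache_update is sketched below. It assumes a TPU device is attached (the Pallas kernel does not run on CPU), and all sizes and slice values are illustrative rather than taken from the commit:

import jax.numpy as jnp
import numpy as np

from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update

page_size, num_slices_per_block = 32, 8
# [total_num_pages * page_size, num_combined_kv_heads, head_dim]
kv_cache = jnp.zeros((16 * page_size, 2, 128), dtype=jnp.bfloat16)
# [total_num_token, num_combined_kv_heads, head_dim]
new_kv = jnp.ones((64, 2, 128), dtype=jnp.bfloat16)

# One real slice plus zero-length padding up to a multiple of
# num_slices_per_block; columns are (kv_cache_start, new_kv_start, slice_len).
slices = np.zeros((3, num_slices_per_block), dtype=np.int32)
slices[:, 0] = (2 * page_size, 0, 16)  # write 16 tokens starting at page 2

updated = kv_cache_update(new_kv, jnp.asarray(slices), kv_cache,
                          page_size=page_size,
                          num_slices_per_block=num_slices_per_block)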
@@ -5,8 +5,12 @@ from dataclasses import dataclass
 from typing import Any, Optional

 import torch
-# Required to register custom ops.
+import torch_xla.core.xla_builder as xb
 import torch_xla.experimental.custom_kernel  # noqa: F401
+# Required to register custom ops.
+from torch.library import impl
+from torch_xla._internal.jax_workarounds import requires_jax
+from torch_xla.experimental.custom_kernel import XLA_LIB

 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
@@ -107,6 +111,7 @@ class PallasMetadata:
     context_lens: torch.Tensor
     query_start_loc: torch.Tensor
     num_seqs: torch.Tensor
+    num_slices_per_kv_cache_update_block: int


 class PallasAttentionBackendImpl(AttentionImpl):
@@ -212,7 +217,9 @@ class PallasAttentionBackendImpl(AttentionImpl):
             # Write input keys and values to the KV cache.
             # Skip this if sharing KV cache with an earlier attention layer.
             slot_mapping = attn_metadata.slot_mapping
-            write_to_kv_cache(key, value, kv_cache, slot_mapping)
+            write_to_kv_cache(
+                key, value, kv_cache, slot_mapping,
+                attn_metadata.num_slices_per_kv_cache_update_block)

         output = torch.ops.xla.ragged_paged_attention(
             query,
@@ -244,6 +251,7 @@ def write_to_kv_cache(
     value: torch.Tensor,
     kv_cache: torch.Tensor,
     slot_mapping: torch.Tensor,
+    num_slices_per_kv_cache_update_block: int,
 ) -> None:
     """ Write the key and values to the KV cache.

@@ -251,9 +259,9 @@ def write_to_kv_cache(
         key: shape = [num_tokens, num_kv_heads * head_size]
         value: shape = [num_tokens, num_kv_heads * head_size]
         kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size]
+        num_slices_per_kv_cache_update_block: int
     """
-    _, _, num_combined_kv_heads, head_size = kv_cache.shape
+    _, page_size, num_combined_kv_heads, head_size = kv_cache.shape
     head_size = cdiv(head_size,
                      TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
     kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads,
@@ -262,4 +270,41 @@ def write_to_kv_cache(
     torch.ops.xla.dynamo_set_buffer_donor_(kv_cache, True)

     kv_cache = kv_cache.flatten(0, 1)
-    kv_cache.index_copy_(0, slot_mapping, kv)
+    new_kv_cache = torch.ops.xla.kv_cache_update_op(
+        kv, slot_mapping, kv_cache, page_size,
+        num_slices_per_kv_cache_update_block)
+    # NOTE: the in-place copy will be optimized away by XLA compiler.
+    kv_cache.copy_(new_kv_cache)
+
+
+@requires_jax
+def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor,
+                            kv_cache: torch.Tensor, page_size: int,
+                            num_slices_per_block: int):
+    from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update
+    new_kv_cache = xb.call_jax(kv_cache_update, (kv, slot_mapping, kv_cache), {
+        "page_size": page_size,
+        "num_slices_per_block": num_slices_per_block
+    })
+    return new_kv_cache
+
+
+XLA_LIB.define(
+    "kv_cache_update_op(Tensor kv, Tensor slot_mapping, Tensor kv_cache, "
+    "int page_size, int num_slices_per_block) -> Tensor", )
+
+
+@impl(XLA_LIB, "kv_cache_update_op", "XLA")
+def kv_cache_update_op_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
+                           kv_cache: torch.Tensor, page_size: int,
+                           num_slices_per_block: int) -> torch.Tensor:
+    new_kv_cache = kv_cache_update_op_impl(kv, slot_mapping, kv_cache,
+                                           page_size, num_slices_per_block)
+    return new_kv_cache
+
+
+@impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd")
+def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
+                               kv_cache: torch.Tensor, page_size: int,
+                               num_slices_per_block: int) -> torch.Tensor:
+    return kv_cache
@@ -53,12 +53,11 @@ if TYPE_CHECKING:

 logger = init_logger(__name__)

-# Here we utilize the behavior that out-of-bound index is ignored.
-# FIXME(woosuk): Find a more reliable way to prevent possible bugs.
-_PAD_SLOT_ID = 1_000_000_000
 INVALID_TOKEN_ID = -1
 # Smallest output size
 MIN_NUM_SEQS = 8
+# Block size used for kv cache updating kernel
+NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK = 8


 #########################################################
@@ -526,6 +525,69 @@ class TPUModelRunner(LoRAModelRunnerMixin):

         return kv_cache_spec

+    def _get_slot_mapping_metadata(self, num_reqs,
+                                   num_scheduled_tokens_per_req):
+        """
+        Computes metadata for mapping slots to blocks in the key-value (KV)
+        cache for a batch of requests.
+
+        This function determines, for each request in the batch, how the
+        scheduled tokens are distributed across memory blocks, and generates
+        metadata needed to map slices of tokens to their corresponding
+        positions in the KV cache.
+
+        Args:
+            num_reqs (int): Number of requests in the current batch.
+            num_scheduled_tokens_per_req (int or np.ndarray): Number of tokens
+                to be scheduled for each request.
+
+        Returns:
+            np.ndarray: A 2D array of shape (total_block_len, 3), where each
+                row contains:
+                - kv_cache_start_index (int): The starting index in the KV
+                    cache for the corresponding slice.
+                - new_kv_start_index (int): The starting index in the new KV
+                    cache for the corresponding slice.
+                - slice_len (int): The length of the slice.
+        """
+        slices_start = self.input_batch.num_computed_tokens_cpu[:num_reqs]
+        slices_end = self.input_batch.num_computed_tokens_cpu[:num_reqs] + \
+            num_scheduled_tokens_per_req
+        local_block_start_idx = slices_start // self.block_size
+        local_block_end_idx = (slices_end - 1) // self.block_size
+        no_repeat_req_indices = self.arange_np[:num_reqs]
+        global_block_start_idx = (
+            no_repeat_req_indices * self.max_num_blocks_per_req +
+            local_block_start_idx)
+        block_lens = local_block_end_idx - local_block_start_idx + 1
+        global_block_start_idx = np.repeat(global_block_start_idx, block_lens)
+        slice_arange = np.concatenate([self.arange_np[:n] for n in block_lens])
+        global_block_indices = global_block_start_idx + slice_arange
+        block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor()
+        block_numbers = block_table_cpu.flatten()[global_block_indices].numpy()
+        total_block_len = np.sum(block_lens)
+        slot_mapping_slices = np.repeat(np.array([[0, self.block_size]],
+                                                 dtype=np.int32),
+                                        total_block_len,
+                                        axis=0)
+        cu_block_lens = np.zeros(len(block_lens) + 1, dtype=np.int32)
+        np.cumsum(block_lens, out=cu_block_lens[1:])
+        for req_idx in range(num_reqs):
+            slot_mapping_slices[cu_block_lens[req_idx]][
+                0] = slices_start[req_idx] % self.block_size
+            slot_mapping_slices[
+                cu_block_lens[req_idx + 1] -
+                1][1] = (slices_end[req_idx] - 1) % self.block_size + 1
+        slice_lens = slot_mapping_slices[:, 1] - slot_mapping_slices[:, 0]
+        cu_slices_lens = np.zeros(len(slice_lens) + 1, dtype=np.int32)
+        np.cumsum(slice_lens, out=cu_slices_lens[1:])
+        kv_cache_start_indices = slot_mapping_slices[:, 0] + \
+            (block_numbers * self.block_size)
+        new_kv_start_indices = cu_slices_lens[:-1]
+        slot_mapping_metadata = np.stack(
+            [kv_cache_start_indices, new_kv_start_indices, slice_lens], axis=1)
+        return slot_mapping_metadata
+
     def _prepare_inputs(self, scheduler_output: "SchedulerOutput",
                         start_index: int):
         assert scheduler_output.total_num_scheduled_tokens > 0
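As a concrete illustration of the metadata this method produces, the standalone sketch below (illustrative values, not the runner code) walks one request's scheduled tokens through the same per-block arithmetic; it yields the same triples used in the layout sketch earlier:

import numpy as np

block_size = 4                      # toy page size for illustration
block_table = np.array([7, 9])      # physical block numbers for this request
num_computed, num_scheduled = 2, 5  # tokens 2..6 are written this step

start, end = num_computed, num_computed + num_scheduled
slices = []
new_kv_start = 0
for local_blk in range(start // block_size, (end - 1) // block_size + 1):
    lo = max(start, local_blk * block_size)
    hi = min(end, (local_blk + 1) * block_size)
    kv_cache_start = int(block_table[local_blk]) * block_size + lo % block_size
    slices.append((kv_cache_start, new_kv_start, hi - lo))
    new_kv_start += hi - lo

print(slices)  # [(30, 0, 2), (36, 2, 3)]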
@@ -603,26 +665,6 @@ class TPUModelRunner(LoRAModelRunnerMixin):
                            torch.from_numpy(token_indices),
                            out=self.input_ids_cpu[:total_num_scheduled_tokens])

-        # Calculate the slot mapping.
-        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
-        # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
-        # where K is the max_num_blocks_per_req and the block size is 2.
-        # NOTE(woosuk): We can't simply use `token_indices // block_size` here
-        # because M (max_model_len) is not necessarily divisible by block_size.
-        # req_indices: # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
-        block_table_indices = (req_indices * self.max_num_blocks_per_req +
-                               positions_np // self.block_size)
-        # NOTE(woosuk): We use torch.index_select instead of np.take here
-        # because torch.index_select is much faster than np.take for large
-        # tensors.
-        block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor()
-        block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
-        block_offsets = positions_np % self.block_size
-        np.add(block_numbers * self.block_size,
-               block_offsets,
-               out=self.input_batch.block_table[0].
-               slot_mapping_np[:total_num_scheduled_tokens])
-
         # Prepare the attention metadata.
         self.query_start_loc_np[0] = 0
         np.cumsum(num_scheduled_tokens_per_req,
@@ -645,12 +687,6 @@ class TPUModelRunner(LoRAModelRunnerMixin):
         self.position_ids = self.positions_cpu[:
                                                padded_total_num_scheduled_tokens].to(
                                                    self.device)
-        self.input_batch.block_table[0].slot_mapping_cpu[
-            total_num_scheduled_tokens:] = _PAD_SLOT_ID
-        slot_mapping = (
-            self.input_batch.block_table[0].
-            slot_mapping_cpu[:padded_total_num_scheduled_tokens].to(
-                self.device))
         if use_max_model_len:
             block_tables = self.block_table_cpu[:self.num_reqs_max_model_len, :
                                                 self.max_num_blocks_per_req]
@@ -675,6 +711,19 @@ class TPUModelRunner(LoRAModelRunnerMixin):
                 self.device)
         block_tables = block_tables.to(self.device)

+        slot_mapping_metadata = self._get_slot_mapping_metadata(
+            num_reqs, num_scheduled_tokens_per_req)
+        padded_num_slices = _get_padded_num_kv_cache_update_slices(
+            padded_total_num_scheduled_tokens, self.max_num_reqs,
+            self.block_size)
+        slot_mapping_metadata = np.pad(
+            slot_mapping_metadata,
+            [[0, padded_num_slices - len(slot_mapping_metadata)], [0, 0]],
+            constant_values=0)
+        slot_mapping_metadata = np.transpose(slot_mapping_metadata)
+        slot_mapping_metadata = torch.tensor(slot_mapping_metadata,
+                                             device=self.device)
+
         if self.lora_config is not None:
             # We need to respect padding when activating LoRA adapters
             padded_num_scheduled_tokens_per_req = np.copy(
@@ -687,13 +736,15 @@ class TPUModelRunner(LoRAModelRunnerMixin):
                 padded_num_scheduled_tokens_per_req)

         attn_metadata = PallasMetadata(
-            slot_mapping=slot_mapping,
+            slot_mapping=slot_mapping_metadata,
             block_tables=block_tables,
             context_lens=seq_lens,
             query_start_loc=query_start_loc,
             num_seqs=torch.tensor([num_reqs],
                                   dtype=torch.int32,
                                   device=self.device),
+            num_slices_per_kv_cache_update_block=
+            NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK,
         )
         # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
         # request in the batch. While we should not sample any token from this
@@ -1119,8 +1170,10 @@ class TPUModelRunner(LoRAModelRunnerMixin):
         actual_num_reqs = min(num_tokens, num_reqs)
         position_ids = torch.zeros(num_tokens,
                                    dtype=torch.int32).to(self.device)
-        slot_mapping = torch.zeros(num_tokens,
-                                   dtype=torch.int64).to(self.device)
+        padded_num_slices = _get_padded_num_kv_cache_update_slices(
+            num_tokens, self.max_num_reqs, self.block_size)
+        slot_mapping = torch.zeros((3, padded_num_slices),
+                                   dtype=torch.int32).to(self.device)
         block_tables = torch.zeros((num_reqs, num_blocks),
                                    dtype=torch.int32).to(self.device)
         query_lens = [1] * num_reqs
@@ -1138,6 +1191,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
             context_lens=context_lens,
             query_start_loc=query_start_loc,
             num_seqs=num_seqs,
+            num_slices_per_kv_cache_update_block=
+            NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK,
         )

         if self.is_multimodal_model:
@@ -1742,6 +1797,19 @@ def _get_padded_token_len(paddings: list[int], x: int) -> int:
     return paddings[index]


+def _get_padded_num_kv_cache_update_slices(num_tokens: int, max_num_reqs: int,
+                                           page_size: int) -> int:
+    """Calculates the padded number of KV cache update slices to avoid
+    recompilation."""
+    padded_num_slices = 2 * max_num_reqs + num_tokens // page_size
+    padded_num_slices = min(padded_num_slices, num_tokens)
+    padded_num_slices = (
+        padded_num_slices + NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK - 1
+    ) // NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK * \
+        NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK
+    return padded_num_slices
+
+
 def replace_set_lora(model):

     def _tpu_set_lora(
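A worked instance of this padding rule, with illustrative sizes (not taken from the commit):

NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK = 8  # value defined earlier in the diff
num_tokens, max_num_reqs, page_size = 100, 4, 16

padded = 2 * max_num_reqs + num_tokens // page_size  # 8 + 6 = 14
padded = min(padded, num_tokens)                     # still 14
padded = (padded + NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK - 1) \
    // NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK \
    * NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK           # rounded up to a multiple of 8
print(padded)  # 16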