diff --git a/tests/compile/piecewise/test_full_cudagraph.py b/tests/compile/piecewise/test_full_cudagraph.py index c1f5d9658af16..efe9c843f144c 100644 --- a/tests/compile/piecewise/test_full_cudagraph.py +++ b/tests/compile/piecewise/test_full_cudagraph.py @@ -147,6 +147,7 @@ def test_lower_max_num_seqs(model, supported): llm.generate(["Hello, my name is"] * 10) +@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda") def test_full_cudagraph_with_invalid_backend(): with temporary_environ({ "VLLM_USE_V1": "1", diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index a929366db49cc..445b24d72f078 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -141,7 +141,8 @@ def use_rocm_custom_paged_attention( and (head_size == 64 or head_size == 128) and (block_size == 16 or block_size == 32) and (gqa_ratio >= 1 and gqa_ratio <= 16) - and max_seq_len <= 32768 and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN) + and max_seq_len <= 128 * 1024 + and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN) and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN and envs.VLLM_ROCM_USE_AITER)) @@ -151,7 +152,7 @@ def use_rocm_custom_paged_attention( and (qtype == torch.half or qtype == torch.bfloat16) and head_size == 128 and block_size == 16 and (gqa_ratio >= 3 and gqa_ratio <= 16) - and max_seq_len <= 32768 and alibi_slopes is None + and max_seq_len <= 128 * 1024 and alibi_slopes is None and kv_cache_dtype == "auto" and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 43a664476aaae..4ad7178374b1a 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -19,9 +19,9 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import cdiv -from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, - CommonAttentionMetadata, - get_kv_cache_layout) +from vllm.v1.attention.backends.utils import ( + AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout, + make_local_attention_virtual_batches) from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.worker.block_table import BlockTable @@ -126,172 +126,6 @@ class FlashAttentionMetadata: local_attn_metadata: Optional[LocalAttentionMetadata] = None -# -# Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into -# local attention blocks, where each block is passed to the attention kernel -# as an independent local ("virtual") batch item. 
-# -# For example, if are performing a chunked prefill a batch of 3 sequences: -# q_seqlens = [4, 10, 5] -# kv_seqlens = [6, 17, 9] -# Then normally for regular attention we would compute with an attention mask -# for batch idx 0 (q_seqlens = 4, kv_seqlens = 6) like: -# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6) -# k_toks > 0 1 2 3 4 5 -# q_toks v _____________ -# 0 | 1 1 1 -# 1 | 1 1 1 1 -# 2 | 1 1 1 1 1 -# 3 | 1 1 1 1 1 1 -# -# for local attention (with attn_chunk_size = 4) we would compute with an -# attention mask like: -# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6, attn_chunk_size = 4) -# k_toks > 0 1 2 3 4 5 -# q_toks v _____________ -# 0 | 1 1 1 -# 1 | 1 1 1 1 -# 2 | 1 -# 3 | 1 1 -# -# We can simulate this mask using standard flash-attention by breaking the -# sequences into local ("virtual") batches, where each local batch item is a -# local attention block, so in this case batch idx 0 would be broken up into: -# -# local-batch idx: 0 (q_seqlens = 2, kv_seqlens = 4) (batch 0) -# k_toks > 0 1 2 3 -# q_toks v _____________ -# 0 | 1 1 1 -# 1 | 1 1 1 1 -# local-batch idx: 1 (q_seqlens = 2, kv_seqlens = 2) (batch 0) -# k_toks > 4 5 -# q_toks v _____________ -# 2 | 1 -# 3 | 1 1 -# -# e.g. if we have: -# attn_chunk_size = 4 -# query_start_loc_np = [0, 4, 14, 19] (q_seqlens = [4, 10, 5]) -# Then this function would return: -# __b0__ ______b1______ __b2__ < orig batch indices -# q_seqlens_local = [ 2, 2, 1, 4, 4, 1, 4, 1] -# cu_seqlens_q_local = [0, 4, 6, 10, 14, 18, 19, 23, 24] -# seqlens_k_local = [ 4, 2, 4, 4, 4, 1, 4, 1] -# block_table_local : shape[local_virtual_batches, pages_per_local_batch] -def make_local_attention_virtual_batches( - attn_chunk_size: int, - query_start_loc_np: np.ndarray, - seq_lens_np: np.ndarray, - block_table: torch.Tensor, - block_size: int = 0, -) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]: - q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1] - actual_batch_size = seq_lens_np.shape[0] - - # Handle if we are starting in the middle of a local attention block, - # we assume q_seqlens > 0 (for all elements), for each batch idx we compute - # the number of tokens that are not in the first local attention block and - # then we can simply use a cdiv for the rest. - # For example if we have: - # attn_chunk_size = 4 - # q_seqlens = [4, 10, 5] - # k_seqlens = [6, 17, 9] - # Then we would get: - # new_tokens_in_first_block = [2, 1, 4] - # local_blocks = [2, 4, 2] - q_tokens_in_first_block = np.minimum( - attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size), - q_seqlens).astype(np.int32) - tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size) - local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block, - attn_chunk_size) - - # Once we know the number of local blocks we can compute the request spans - # for each batch idx, we can figure out the number of "virtual" requests we - # have to make, - # For the above example we would get: - # seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1] - # - # First Get batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1]) - # (TODO: max a utility to share this code with _prepare_inputs) - # arange step 1. [2, 4, 2] -> [2, 6, 8] - cu_num_blocks = np.cumsum(local_blocks) - virtual_batches = cu_num_blocks[-1] - # arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6] - block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks) - # arange step 3. 
[0, 1, 0, 1, 2, 3, 0, 1] - arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets - # also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0]) - rarange = np.repeat(local_blocks, local_blocks) - arange - 1 - # Then we can compute the seqlens_q_local, handling the fact that the - # first and last blocks could be partial - seqlens_q_local = \ - np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks) - # set the first block since this may be a partial block - seqlens_q_local[arange == 0] = q_tokens_in_first_block - # set the remaining blocks - seqlens_q_local[arange > 0] = np.minimum( - seqlens_q_local - attn_chunk_size * (arange - 1), - attn_chunk_size)[arange > 0] - - # convert from q_seqlens to cu_seqlens_q - cu_seqlens_q_local = np.pad(np.cumsum(seqlens_q_local), (1, 0))\ - .astype(np.int32) - - # compute the seqlens_k_local, - # basically a full local attention block for all but the last block in each - # batch - # For our example this will be: - # seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1] - seqlens_k_local = np.full(cu_num_blocks[-1], - attn_chunk_size, - dtype=np.int32) - seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block - - k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - \ - (rarange * attn_chunk_size + \ - np.repeat(tokens_in_last_block, local_blocks)) - # For the example the local attention blocks start at: - # _b0_ _____b1_____ _b2_ - # k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8] - block_starts = k_seqstarts_absolute // block_size - assert attn_chunk_size % block_size == 0, \ - f"attn_chunk_size {attn_chunk_size} is not " \ - f"divisible by block_size {block_size}" - pages_per_local_batch = attn_chunk_size // block_size - - # Create a block_table for the local attention blocks - # For out example if we have a block-table like (assuming block_size=2): - # block_table = [ - # [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9], < batch 0 - # [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], < batch 1 - # [20, 21, 22, 23, 24, 25, 26, 27, 28, 29], < batch 2 - # ] - # Then for the local batches we would want a block-table like - # block_table_local = [ - # [ 0, 1 ], < local-batch 0, (batch 0, starting from k[0]) - # [ 2, 3 ], < local-batch 1, (batch 0, starting from k[4]) - # [ 12, 13 ], < local-batch 2, (batch 1, starting from k[4]) - # [ 14, 15 ], < local-batch 3, (batch 1, starting from k[8]) - # [ 16, 17 ], < local-batch 4, (batch 1, starting from k[12]) - # [ 18, 19 ], < local-batch 5, (batch 1, starting from k[16]) - # [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4]) - # [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8]) - # ] - block_indices= np.broadcast_to( - np.arange(pages_per_local_batch, dtype=np.int32), - (virtual_batches, pages_per_local_batch)) \ - + np.expand_dims(block_starts, axis=1) - block_indices = block_indices.flatten().clip(max=block_table.shape[1] - 1) - batch_indices = np.repeat(np.arange(actual_batch_size, dtype=np.int32), - local_blocks * pages_per_local_batch) - block_table_local = block_table[batch_indices, block_indices]\ - .view(virtual_batches, -1) - - return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, \ - block_table_local - - def _get_sliding_window_configs( vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]: """Get the set of all sliding window configs used in the model.""" diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 9782ec087babb..ecb92bb1e4161 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ 
b/vllm/v1/attention/backends/triton_attn.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer with PagedAttention and Triton prefix prefill.""" -from typing import TYPE_CHECKING, Any, Optional +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, ClassVar, Optional import torch @@ -15,8 +16,10 @@ from vllm.attention.ops.paged_attn import PagedAttention from vllm.attention.ops.triton_unified_attention import unified_attention from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.v1.attention.backends.flash_attn import ( - FlashAttentionMetadata, FlashAttentionMetadataBuilder) +from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata +from vllm.v1.attention.backends.utils import ( + AttentionMetadataBuilder, CommonAttentionMetadata, + make_local_attention_virtual_batches) from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.worker.block_table import BlockTable @@ -26,12 +29,161 @@ if TYPE_CHECKING: logger = init_logger(__name__) -class TritonAttentionMetadataBuilder(FlashAttentionMetadataBuilder): +@dataclass +class TritonAttentionMetadata: + # NOTE(sang): Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + + num_actual_tokens: int # Number of tokens excluding padding. + max_query_len: int + query_start_loc: torch.Tensor + max_seq_len: int + seq_lens: torch.Tensor + block_table: torch.Tensor + slot_mapping: torch.Tensor + + # For cascade attention. + use_cascade: bool + common_prefix_len: int + cu_prefix_query_lens: Optional[torch.Tensor] + prefix_kv_lens: Optional[torch.Tensor] + suffix_kv_lens: Optional[torch.Tensor] + + # Optional aot scheduling + scheduler_metadata: Optional[torch.Tensor] = None + prefix_scheduler_metadata: Optional[torch.Tensor] = None + + # for local attention + @dataclass + class LocalAttentionMetadata: + local_query_start_loc: torch.Tensor + local_seqused_k: torch.Tensor + local_block_table: torch.Tensor + local_max_query_len: int + local_max_seq_len: int + local_scheduler_metadata: Optional[torch.Tensor] + + local_attn_metadata: Optional[LocalAttentionMetadata] = None + + +class TritonAttentionMetadataBuilder( + AttentionMetadataBuilder[TritonAttentionMetadata]): + full_cudagraph_supported: ClassVar[bool] = True def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec, block_table: BlockTable): - super().__init__(runner, kv_cache_spec, block_table) - self.aot_schedule = False + self.runner = runner + self.block_size = kv_cache_spec.block_size + self.kv_cache_spec = kv_cache_spec + self.block_table = block_table + + def build_for_cudagraph_capture( + self, common_attn_metadata: CommonAttentionMetadata + ) -> TritonAttentionMetadata: + attn_metadata = self.build(0, common_attn_metadata) + # When doing full graph capture, setting seq_lens to + # max_model_len will cause graph capture to be extremely + # slow, so here we set it to 1. 
+ attn_metadata.seq_lens.fill_(1) + return attn_metadata + + def build( + self, common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata + ) -> TritonAttentionMetadata: + num_reqs = common_attn_metadata.num_reqs + num_actual_tokens = common_attn_metadata.num_actual_tokens + max_query_len = common_attn_metadata.max_query_len + + max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max()) + query_start_loc = common_attn_metadata.query_start_loc + seq_lens = common_attn_metadata.seq_lens + block_table = self.block_table + block_table_tensor = block_table.get_device_tensor()[:num_reqs] + + block_table.slot_mapping[:num_actual_tokens].copy_( + block_table.slot_mapping_cpu[:num_actual_tokens], + non_blocking=True) + # Fill unused with -1. Needed for reshape_and_cache in full cuda graph + # mode. + block_table.slot_mapping[num_actual_tokens:].fill_(-1) + + slot_mapping = block_table.slot_mapping[:num_actual_tokens] + + # for local attention + local_attn_metadata = None + if self.runner.attention_chunk_size is not None: + seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \ + virt_block_table_tensor = make_local_attention_virtual_batches( + self.runner.attention_chunk_size, + self.runner.query_start_loc_np[:num_reqs + 1], + self.runner.seq_lens_np[:num_reqs], + block_table_tensor, + self.block_size, + ) + local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to( + self.runner.device, non_blocking=True) + local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to( + self.runner.device, non_blocking=True) + local_max_query_len = seqlens_q_local_np.max() + local_max_seq_len = virt_k_seqlens_np.max() + + local_attn_metadata = TritonAttentionMetadata \ + .LocalAttentionMetadata( + local_query_start_loc=local_query_start_loc, + local_seqused_k=local_seqused_k, + local_block_table=virt_block_table_tensor, + local_max_query_len=local_max_query_len, + local_max_seq_len=local_max_seq_len, + local_scheduler_metadata=None, + ) + + use_cascade = common_prefix_len > 0 + + if use_cascade: + cu_prefix_query_lens = torch.tensor([0, num_actual_tokens], + dtype=torch.int32, + device=self.runner.device) + prefix_kv_lens = torch.tensor([common_prefix_len], + dtype=torch.int32, + device=self.runner.device) + suffix_kv_lens = (self.runner.seq_lens_np[:num_reqs] - + common_prefix_len) + suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to( + self.runner.device) + else: + cu_prefix_query_lens = None + prefix_kv_lens = None + suffix_kv_lens = None + prefix_scheduler_metadata = None + + attn_metadata = TritonAttentionMetadata( + num_actual_tokens=num_actual_tokens, + max_query_len=max_query_len, + query_start_loc=query_start_loc, + max_seq_len=max_seq_len, + seq_lens=seq_lens, + block_table=block_table_tensor, + slot_mapping=slot_mapping, + use_cascade=use_cascade, + common_prefix_len=common_prefix_len, + cu_prefix_query_lens=cu_prefix_query_lens, + prefix_kv_lens=prefix_kv_lens, + suffix_kv_lens=suffix_kv_lens, + local_attn_metadata=local_attn_metadata, + prefix_scheduler_metadata=prefix_scheduler_metadata, + ) + return attn_metadata + + def can_run_in_cudagraph( + self, common_attn_metadata: CommonAttentionMetadata) -> bool: + # Full CUDA Graph always supported + return True class TritonAttentionBackend(AttentionBackend): @@ -52,7 +204,7 @@ class TritonAttentionBackend(AttentionBackend): @staticmethod def get_metadata_cls() -> type["AttentionMetadata"]: - return FlashAttentionMetadata + return TritonAttentionMetadata @staticmethod def get_kv_cache_shape( diff --git 
a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 82798afee32cb..8083f20026024 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -9,6 +9,8 @@ from typing import TYPE_CHECKING, ClassVar, Generic, TypeVar import numpy as np import torch +from vllm.utils import cdiv + if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.worker.gpu_input_batch import InputBatch @@ -140,3 +142,169 @@ def get_kv_cache_layout(): "detected. Setting KV cache layout to %s.", cache_layout) return cache_layout + + +# +# Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into +# local attention blocks, where each block is passed to the attention kernel +# as an independent local ("virtual") batch item. +# +# For example, if are performing a chunked prefill a batch of 3 sequences: +# q_seqlens = [4, 10, 5] +# kv_seqlens = [6, 17, 9] +# Then normally for regular attention we would compute with an attention mask +# for batch idx 0 (q_seqlens = 4, kv_seqlens = 6) like: +# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6) +# k_toks > 0 1 2 3 4 5 +# q_toks v _____________ +# 0 | 1 1 1 +# 1 | 1 1 1 1 +# 2 | 1 1 1 1 1 +# 3 | 1 1 1 1 1 1 +# +# for local attention (with attn_chunk_size = 4) we would compute with an +# attention mask like: +# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6, attn_chunk_size = 4) +# k_toks > 0 1 2 3 4 5 +# q_toks v _____________ +# 0 | 1 1 1 +# 1 | 1 1 1 1 +# 2 | 1 +# 3 | 1 1 +# +# We can simulate this mask using standard flash-attention by breaking the +# sequences into local ("virtual") batches, where each local batch item is a +# local attention block, so in this case batch idx 0 would be broken up into: +# +# local-batch idx: 0 (q_seqlens = 2, kv_seqlens = 4) (batch 0) +# k_toks > 0 1 2 3 +# q_toks v _____________ +# 0 | 1 1 1 +# 1 | 1 1 1 1 +# local-batch idx: 1 (q_seqlens = 2, kv_seqlens = 2) (batch 0) +# k_toks > 4 5 +# q_toks v _____________ +# 2 | 1 +# 3 | 1 1 +# +# e.g. if we have: +# attn_chunk_size = 4 +# query_start_loc_np = [0, 4, 14, 19] (q_seqlens = [4, 10, 5]) +# Then this function would return: +# __b0__ ______b1______ __b2__ < orig batch indices +# q_seqlens_local = [ 2, 2, 1, 4, 4, 1, 4, 1] +# cu_seqlens_q_local = [0, 4, 6, 10, 14, 18, 19, 23, 24] +# seqlens_k_local = [ 4, 2, 4, 4, 4, 1, 4, 1] +# block_table_local : shape[local_virtual_batches, pages_per_local_batch] +def make_local_attention_virtual_batches( + attn_chunk_size: int, + query_start_loc_np: np.ndarray, + seq_lens_np: np.ndarray, + block_table: torch.Tensor, + block_size: int = 0, +) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]: + q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1] + actual_batch_size = seq_lens_np.shape[0] + + # Handle if we are starting in the middle of a local attention block, + # we assume q_seqlens > 0 (for all elements), for each batch idx we compute + # the number of tokens that are not in the first local attention block and + # then we can simply use a cdiv for the rest. 
+ # For example if we have: + # attn_chunk_size = 4 + # q_seqlens = [4, 10, 5] + # k_seqlens = [6, 17, 9] + # Then we would get: + # new_tokens_in_first_block = [2, 1, 4] + # local_blocks = [2, 4, 2] + q_tokens_in_first_block = np.minimum( + attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size), + q_seqlens).astype(np.int32) + tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size) + local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block, + attn_chunk_size) + + # Once we know the number of local blocks we can compute the request spans + # for each batch idx, we can figure out the number of "virtual" requests we + # have to make, + # For the above example we would get: + # seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1] + # + # First Get batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1]) + # (TODO: max a utility to share this code with _prepare_inputs) + # arange step 1. [2, 4, 2] -> [2, 6, 8] + cu_num_blocks = np.cumsum(local_blocks) + virtual_batches = cu_num_blocks[-1] + # arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6] + block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks) + # arange step 3. [0, 1, 0, 1, 2, 3, 0, 1] + arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets + # also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0]) + rarange = np.repeat(local_blocks, local_blocks) - arange - 1 + # Then we can compute the seqlens_q_local, handling the fact that the + # first and last blocks could be partial + seqlens_q_local = \ + np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks) + # set the first block since this may be a partial block + seqlens_q_local[arange == 0] = q_tokens_in_first_block + # set the remaining blocks + seqlens_q_local[arange > 0] = np.minimum( + seqlens_q_local - attn_chunk_size * (arange - 1), + attn_chunk_size)[arange > 0] + + # convert from q_seqlens to cu_seqlens_q + cu_seqlens_q_local = np.pad(np.cumsum(seqlens_q_local), (1, 0))\ + .astype(np.int32) + + # compute the seqlens_k_local, + # basically a full local attention block for all but the last block in each + # batch + # For our example this will be: + # seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1] + seqlens_k_local = np.full(cu_num_blocks[-1], + attn_chunk_size, + dtype=np.int32) + seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block + + k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - \ + (rarange * attn_chunk_size + \ + np.repeat(tokens_in_last_block, local_blocks)) + # For the example the local attention blocks start at: + # _b0_ _____b1_____ _b2_ + # k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8] + block_starts = k_seqstarts_absolute // block_size + assert attn_chunk_size % block_size == 0, \ + f"attn_chunk_size {attn_chunk_size} is not " \ + f"divisible by block_size {block_size}" + pages_per_local_batch = attn_chunk_size // block_size + + # Create a block_table for the local attention blocks + # For out example if we have a block-table like (assuming block_size=2): + # block_table = [ + # [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9], < batch 0 + # [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], < batch 1 + # [20, 21, 22, 23, 24, 25, 26, 27, 28, 29], < batch 2 + # ] + # Then for the local batches we would want a block-table like + # block_table_local = [ + # [ 0, 1 ], < local-batch 0, (batch 0, starting from k[0]) + # [ 2, 3 ], < local-batch 1, (batch 0, starting from k[4]) + # [ 12, 13 ], < local-batch 2, (batch 1, starting from k[4]) + # [ 14, 15 ], < local-batch 3, (batch 1, starting from k[8]) + # [ 16, 17 ], < 
local-batch 4, (batch 1, starting from k[12]) + # [ 18, 19 ], < local-batch 5, (batch 1, starting from k[16]) + # [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4]) + # [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8]) + # ] + block_indices= np.broadcast_to( + np.arange(pages_per_local_batch, dtype=np.int32), + (virtual_batches, pages_per_local_batch)) \ + + np.expand_dims(block_starts, axis=1) + block_indices = block_indices.flatten().clip(max=block_table.shape[1] - 1) + batch_indices = np.repeat(np.arange(actual_batch_size, dtype=np.int32), + local_blocks * pages_per_local_batch) + block_table_local = block_table[batch_indices, block_indices]\ + .view(virtual_batches, -1) + + return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, \ + block_table_local
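
Notes on the hunks above.

The rocm.py hunks raise the custom paged-attention sequence-length cap from 32768 to 128 * 1024 in both branches of use_rocm_custom_paged_attention. Below is a condensed restatement (a sketch only) of the dtype/shape/length gating visible in the second hunk; the helper name custom_paged_attn_gate_sketch is hypothetical, and the real function additionally consults VLLM_ROCM_CUSTOM_PAGED_ATTN and the AITER environment flags, which are left out here.

    import torch

    _MAX_SEQ_LEN = 128 * 1024  # raised from 32768 by this patch

    def custom_paged_attn_gate_sketch(qtype: torch.dtype, head_size: int,
                                      block_size: int, gqa_ratio: int,
                                      max_seq_len: int, alibi_slopes,
                                      kv_cache_dtype: str) -> bool:
        # Mirrors the checks shown in the second rocm.py hunk; environment-flag
        # checks are intentionally omitted from this sketch.
        return (qtype in (torch.half, torch.bfloat16)
                and head_size == 128
                and block_size == 16
                and 3 <= gqa_ratio <= 16
                and max_seq_len <= _MAX_SEQ_LEN
                and alibi_slopes is None
                and kv_cache_dtype == "auto")

    # A 64K-token sequence now passes the length check; 256K still does not.
    assert custom_paged_attn_gate_sketch(torch.bfloat16, 128, 16, 8,
                                         64 * 1024, None, "auto")
    assert not custom_paged_attn_gate_sketch(torch.bfloat16, 128, 16, 8,
                                             256 * 1024, None, "auto")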
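
The comment block moved into utils.py shows, for batch 0 (q_seqlens = 4, kv_seqlens = 6, attn_chunk_size = 4), how the chunked local-attention mask decomposes into two local ("virtual") batches. The short NumPy sketch below rebuilds that mask from the two conditions the diagram implies (causal, and query/key in the same attn_chunk_size-sized chunk of absolute positions) and reproduces the diagram.

    import numpy as np

    q_len, kv_len, chunk = 4, 6, 4
    ctx = kv_len - q_len                      # 2 KV tokens of prior context
    q_pos = ctx + np.arange(q_len)            # absolute positions 2..5 of the query tokens
    k_pos = np.arange(kv_len)
    causal = k_pos[None, :] <= q_pos[:, None]
    same_chunk = (k_pos[None, :] // chunk) == (q_pos[:, None] // chunk)
    print((causal & same_chunk).astype(int))
    # [[1 1 1 0 0 0]
    #  [1 1 1 1 0 0]
    #  [0 0 0 0 1 0]
    #  [0 0 0 0 1 1]]
    # Rows 0-1 only see k[0:4] and rows 2-3 only see k[4:6], i.e. exactly
    # local-batch 0 (q_seqlens = 2, kv_seqlens = 4) and
    # local-batch 1 (q_seqlens = 2, kv_seqlens = 2) from the comment.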
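
One easy-to-miss step inside make_local_attention_virtual_batches is tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size). It relies on NumPy's floored modulus with a negative divisor: seq_len % -chunk lies in (-chunk, 0], so the sum equals chunk for exact multiples and seq_len % chunk otherwise. A quick check with the example's KV lengths plus one exact multiple:

    import numpy as np

    chunk = 4
    seq_lens = np.array([6, 17, 9, 8])
    print(chunk + (seq_lens % -chunk))  # -> [2 1 1 4]; 8 is an exact multiple, so a full last block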
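
Finally, the worked example in the comments (q_seqlens = [4, 10, 5], kv_seqlens = [6, 17, 9], attn_chunk_size = 4, block_size = 2, and the 3 x 10 block table) can be checked end to end. This is a minimal sketch assuming a vLLM checkout with this patch applied, so that make_local_attention_virtual_batches is importable from its new home in vllm.v1.attention.backends.utils; the expected seqlens_q_local and seqlens_k_local values are the ones stated in the comment, and cu_seqlens_q_local is the zero-padded cumulative sum of seqlens_q_local.

    import numpy as np
    import torch

    from vllm.v1.attention.backends.utils import (
        make_local_attention_virtual_batches)

    attn_chunk_size = 4
    query_start_loc_np = np.array([0, 4, 14, 19], dtype=np.int32)  # q_seqlens = [4, 10, 5]
    seq_lens_np = np.array([6, 17, 9], dtype=np.int32)             # kv_seqlens = [6, 17, 9]
    block_size = 2
    # The 3 x 10 block table from the comment: batch 0 owns pages 0-9,
    # batch 1 owns 10-19, batch 2 owns 20-29.
    block_table = torch.arange(30, dtype=torch.int32).reshape(3, 10)

    (seqlens_q_local, cu_seqlens_q_local, seqlens_k_local,
     block_table_local) = make_local_attention_virtual_batches(
         attn_chunk_size, query_start_loc_np, seq_lens_np, block_table,
         block_size)

    print(seqlens_q_local.tolist())     # expect [2, 2, 1, 4, 4, 1, 4, 1]
    print(seqlens_k_local.tolist())     # expect [4, 2, 4, 4, 4, 1, 4, 1]
    print(cu_seqlens_q_local.tolist())  # zero-padded cumsum of seqlens_q_local
    print(block_table_local)            # 8 local batches x 2 pages each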