mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 22:55:51 +08:00
[Feature][ROCm] Add full graph capture support for TritonAttentionBackend (#19158)
Signed-off-by: charlifu <charlifu@amd.com>
This commit is contained in:
parent
b447624ee3
commit
a44b1c951d
@ -147,6 +147,7 @@ def test_lower_max_num_seqs(model, supported):
|
||||
llm.generate(["Hello, my name is"] * 10)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
|
||||
def test_full_cudagraph_with_invalid_backend():
|
||||
with temporary_environ({
|
||||
"VLLM_USE_V1": "1",
|
||||
|
||||
@ -141,7 +141,8 @@ def use_rocm_custom_paged_attention(
|
||||
and (head_size == 64 or head_size == 128)
|
||||
and (block_size == 16 or block_size == 32)
|
||||
and (gqa_ratio >= 1 and gqa_ratio <= 16)
|
||||
and max_seq_len <= 32768 and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
|
||||
and max_seq_len <= 128 * 1024
|
||||
and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
|
||||
and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
|
||||
and envs.VLLM_ROCM_USE_AITER))
|
||||
|
||||
@ -151,7 +152,7 @@ def use_rocm_custom_paged_attention(
|
||||
and (qtype == torch.half or qtype == torch.bfloat16)
|
||||
and head_size == 128 and block_size == 16
|
||||
and (gqa_ratio >= 3 and gqa_ratio <= 16)
|
||||
and max_seq_len <= 32768 and alibi_slopes is None
|
||||
and max_seq_len <= 128 * 1024 and alibi_slopes is None
|
||||
and kv_cache_dtype == "auto"
|
||||
and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
|
||||
|
||||
|
||||
@ -19,9 +19,9 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import cdiv
|
||||
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
get_kv_cache_layout)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout,
|
||||
make_local_attention_virtual_batches)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.v1.worker.block_table import BlockTable
|
||||
|
||||
@ -126,172 +126,6 @@ class FlashAttentionMetadata:
|
||||
local_attn_metadata: Optional[LocalAttentionMetadata] = None
|
||||
|
||||
|
||||
#
|
||||
# Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into
|
||||
# local attention blocks, where each block is passed to the attention kernel
|
||||
# as an independent local ("virtual") batch item.
|
||||
#
|
||||
# For example, if are performing a chunked prefill a batch of 3 sequences:
|
||||
# q_seqlens = [4, 10, 5]
|
||||
# kv_seqlens = [6, 17, 9]
|
||||
# Then normally for regular attention we would compute with an attention mask
|
||||
# for batch idx 0 (q_seqlens = 4, kv_seqlens = 6) like:
|
||||
# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6)
|
||||
# k_toks > 0 1 2 3 4 5
|
||||
# q_toks v _____________
|
||||
# 0 | 1 1 1
|
||||
# 1 | 1 1 1 1
|
||||
# 2 | 1 1 1 1 1
|
||||
# 3 | 1 1 1 1 1 1
|
||||
#
|
||||
# for local attention (with attn_chunk_size = 4) we would compute with an
|
||||
# attention mask like:
|
||||
# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6, attn_chunk_size = 4)
|
||||
# k_toks > 0 1 2 3 4 5
|
||||
# q_toks v _____________
|
||||
# 0 | 1 1 1
|
||||
# 1 | 1 1 1 1
|
||||
# 2 | 1
|
||||
# 3 | 1 1
|
||||
#
|
||||
# We can simulate this mask using standard flash-attention by breaking the
|
||||
# sequences into local ("virtual") batches, where each local batch item is a
|
||||
# local attention block, so in this case batch idx 0 would be broken up into:
|
||||
#
|
||||
# local-batch idx: 0 (q_seqlens = 2, kv_seqlens = 4) (batch 0)
|
||||
# k_toks > 0 1 2 3
|
||||
# q_toks v _____________
|
||||
# 0 | 1 1 1
|
||||
# 1 | 1 1 1 1
|
||||
# local-batch idx: 1 (q_seqlens = 2, kv_seqlens = 2) (batch 0)
|
||||
# k_toks > 4 5
|
||||
# q_toks v _____________
|
||||
# 2 | 1
|
||||
# 3 | 1 1
|
||||
#
|
||||
# e.g. if we have:
|
||||
# attn_chunk_size = 4
|
||||
# query_start_loc_np = [0, 4, 14, 19] (q_seqlens = [4, 10, 5])
|
||||
# Then this function would return:
|
||||
# __b0__ ______b1______ __b2__ < orig batch indices
|
||||
# q_seqlens_local = [ 2, 2, 1, 4, 4, 1, 4, 1]
|
||||
# cu_seqlens_q_local = [0, 4, 6, 10, 14, 18, 19, 23, 24]
|
||||
# seqlens_k_local = [ 4, 2, 4, 4, 4, 1, 4, 1]
|
||||
# block_table_local : shape[local_virtual_batches, pages_per_local_batch]
|
||||
def make_local_attention_virtual_batches(
|
||||
attn_chunk_size: int,
|
||||
query_start_loc_np: np.ndarray,
|
||||
seq_lens_np: np.ndarray,
|
||||
block_table: torch.Tensor,
|
||||
block_size: int = 0,
|
||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]:
|
||||
q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
|
||||
actual_batch_size = seq_lens_np.shape[0]
|
||||
|
||||
# Handle if we are starting in the middle of a local attention block,
|
||||
# we assume q_seqlens > 0 (for all elements), for each batch idx we compute
|
||||
# the number of tokens that are not in the first local attention block and
|
||||
# then we can simply use a cdiv for the rest.
|
||||
# For example if we have:
|
||||
# attn_chunk_size = 4
|
||||
# q_seqlens = [4, 10, 5]
|
||||
# k_seqlens = [6, 17, 9]
|
||||
# Then we would get:
|
||||
# new_tokens_in_first_block = [2, 1, 4]
|
||||
# local_blocks = [2, 4, 2]
|
||||
q_tokens_in_first_block = np.minimum(
|
||||
attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size),
|
||||
q_seqlens).astype(np.int32)
|
||||
tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size)
|
||||
local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block,
|
||||
attn_chunk_size)
|
||||
|
||||
# Once we know the number of local blocks we can compute the request spans
|
||||
# for each batch idx, we can figure out the number of "virtual" requests we
|
||||
# have to make,
|
||||
# For the above example we would get:
|
||||
# seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1]
|
||||
#
|
||||
# First Get batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1])
|
||||
# (TODO: max a utility to share this code with _prepare_inputs)
|
||||
# arange step 1. [2, 4, 2] -> [2, 6, 8]
|
||||
cu_num_blocks = np.cumsum(local_blocks)
|
||||
virtual_batches = cu_num_blocks[-1]
|
||||
# arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6]
|
||||
block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks)
|
||||
# arange step 3. [0, 1, 0, 1, 2, 3, 0, 1]
|
||||
arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets
|
||||
# also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0])
|
||||
rarange = np.repeat(local_blocks, local_blocks) - arange - 1
|
||||
# Then we can compute the seqlens_q_local, handling the fact that the
|
||||
# first and last blocks could be partial
|
||||
seqlens_q_local = \
|
||||
np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks)
|
||||
# set the first block since this may be a partial block
|
||||
seqlens_q_local[arange == 0] = q_tokens_in_first_block
|
||||
# set the remaining blocks
|
||||
seqlens_q_local[arange > 0] = np.minimum(
|
||||
seqlens_q_local - attn_chunk_size * (arange - 1),
|
||||
attn_chunk_size)[arange > 0]
|
||||
|
||||
# convert from q_seqlens to cu_seqlens_q
|
||||
cu_seqlens_q_local = np.pad(np.cumsum(seqlens_q_local), (1, 0))\
|
||||
.astype(np.int32)
|
||||
|
||||
# compute the seqlens_k_local,
|
||||
# basically a full local attention block for all but the last block in each
|
||||
# batch
|
||||
# For our example this will be:
|
||||
# seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1]
|
||||
seqlens_k_local = np.full(cu_num_blocks[-1],
|
||||
attn_chunk_size,
|
||||
dtype=np.int32)
|
||||
seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block
|
||||
|
||||
k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - \
|
||||
(rarange * attn_chunk_size + \
|
||||
np.repeat(tokens_in_last_block, local_blocks))
|
||||
# For the example the local attention blocks start at:
|
||||
# _b0_ _____b1_____ _b2_
|
||||
# k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8]
|
||||
block_starts = k_seqstarts_absolute // block_size
|
||||
assert attn_chunk_size % block_size == 0, \
|
||||
f"attn_chunk_size {attn_chunk_size} is not " \
|
||||
f"divisible by block_size {block_size}"
|
||||
pages_per_local_batch = attn_chunk_size // block_size
|
||||
|
||||
# Create a block_table for the local attention blocks
|
||||
# For out example if we have a block-table like (assuming block_size=2):
|
||||
# block_table = [
|
||||
# [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9], < batch 0
|
||||
# [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], < batch 1
|
||||
# [20, 21, 22, 23, 24, 25, 26, 27, 28, 29], < batch 2
|
||||
# ]
|
||||
# Then for the local batches we would want a block-table like
|
||||
# block_table_local = [
|
||||
# [ 0, 1 ], < local-batch 0, (batch 0, starting from k[0])
|
||||
# [ 2, 3 ], < local-batch 1, (batch 0, starting from k[4])
|
||||
# [ 12, 13 ], < local-batch 2, (batch 1, starting from k[4])
|
||||
# [ 14, 15 ], < local-batch 3, (batch 1, starting from k[8])
|
||||
# [ 16, 17 ], < local-batch 4, (batch 1, starting from k[12])
|
||||
# [ 18, 19 ], < local-batch 5, (batch 1, starting from k[16])
|
||||
# [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4])
|
||||
# [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8])
|
||||
# ]
|
||||
block_indices= np.broadcast_to(
|
||||
np.arange(pages_per_local_batch, dtype=np.int32),
|
||||
(virtual_batches, pages_per_local_batch)) \
|
||||
+ np.expand_dims(block_starts, axis=1)
|
||||
block_indices = block_indices.flatten().clip(max=block_table.shape[1] - 1)
|
||||
batch_indices = np.repeat(np.arange(actual_batch_size, dtype=np.int32),
|
||||
local_blocks * pages_per_local_batch)
|
||||
block_table_local = block_table[batch_indices, block_indices]\
|
||||
.view(virtual_batches, -1)
|
||||
|
||||
return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, \
|
||||
block_table_local
|
||||
|
||||
|
||||
def _get_sliding_window_configs(
|
||||
vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]:
|
||||
"""Get the set of all sliding window configs used in the model."""
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Attention layer with PagedAttention and Triton prefix prefill."""
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@ -15,8 +16,10 @@ from vllm.attention.ops.paged_attn import PagedAttention
|
||||
from vllm.attention.ops.triton_unified_attention import unified_attention
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.attention.backends.flash_attn import (
|
||||
FlashAttentionMetadata, FlashAttentionMetadataBuilder)
|
||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionMetadataBuilder, CommonAttentionMetadata,
|
||||
make_local_attention_virtual_batches)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.v1.worker.block_table import BlockTable
|
||||
|
||||
@ -26,12 +29,161 @@ if TYPE_CHECKING:
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class TritonAttentionMetadataBuilder(FlashAttentionMetadataBuilder):
|
||||
@dataclass
|
||||
class TritonAttentionMetadata:
|
||||
# NOTE(sang): Definition of context_len, query_len, and seq_len.
|
||||
# |---------- N-1 iteration --------|
|
||||
# |---------------- N iteration ---------------------|
|
||||
# |- tokenA -|......................|-- newTokens ---|
|
||||
# |---------- context_len ----------|
|
||||
# |-------------------- seq_len ---------------------|
|
||||
# |-- query_len ---|
|
||||
|
||||
num_actual_tokens: int # Number of tokens excluding padding.
|
||||
max_query_len: int
|
||||
query_start_loc: torch.Tensor
|
||||
max_seq_len: int
|
||||
seq_lens: torch.Tensor
|
||||
block_table: torch.Tensor
|
||||
slot_mapping: torch.Tensor
|
||||
|
||||
# For cascade attention.
|
||||
use_cascade: bool
|
||||
common_prefix_len: int
|
||||
cu_prefix_query_lens: Optional[torch.Tensor]
|
||||
prefix_kv_lens: Optional[torch.Tensor]
|
||||
suffix_kv_lens: Optional[torch.Tensor]
|
||||
|
||||
# Optional aot scheduling
|
||||
scheduler_metadata: Optional[torch.Tensor] = None
|
||||
prefix_scheduler_metadata: Optional[torch.Tensor] = None
|
||||
|
||||
# for local attention
|
||||
@dataclass
|
||||
class LocalAttentionMetadata:
|
||||
local_query_start_loc: torch.Tensor
|
||||
local_seqused_k: torch.Tensor
|
||||
local_block_table: torch.Tensor
|
||||
local_max_query_len: int
|
||||
local_max_seq_len: int
|
||||
local_scheduler_metadata: Optional[torch.Tensor]
|
||||
|
||||
local_attn_metadata: Optional[LocalAttentionMetadata] = None
|
||||
|
||||
|
||||
class TritonAttentionMetadataBuilder(
|
||||
AttentionMetadataBuilder[TritonAttentionMetadata]):
|
||||
full_cudagraph_supported: ClassVar[bool] = True
|
||||
|
||||
def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
|
||||
block_table: BlockTable):
|
||||
super().__init__(runner, kv_cache_spec, block_table)
|
||||
self.aot_schedule = False
|
||||
self.runner = runner
|
||||
self.block_size = kv_cache_spec.block_size
|
||||
self.kv_cache_spec = kv_cache_spec
|
||||
self.block_table = block_table
|
||||
|
||||
def build_for_cudagraph_capture(
|
||||
self, common_attn_metadata: CommonAttentionMetadata
|
||||
) -> TritonAttentionMetadata:
|
||||
attn_metadata = self.build(0, common_attn_metadata)
|
||||
# When doing full graph capture, setting seq_lens to
|
||||
# max_model_len will cause graph capture to be extremely
|
||||
# slow, so here we set it to 1.
|
||||
attn_metadata.seq_lens.fill_(1)
|
||||
return attn_metadata
|
||||
|
||||
def build(
|
||||
self, common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata
|
||||
) -> TritonAttentionMetadata:
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||
max_query_len = common_attn_metadata.max_query_len
|
||||
|
||||
max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
|
||||
query_start_loc = common_attn_metadata.query_start_loc
|
||||
seq_lens = common_attn_metadata.seq_lens
|
||||
block_table = self.block_table
|
||||
block_table_tensor = block_table.get_device_tensor()[:num_reqs]
|
||||
|
||||
block_table.slot_mapping[:num_actual_tokens].copy_(
|
||||
block_table.slot_mapping_cpu[:num_actual_tokens],
|
||||
non_blocking=True)
|
||||
# Fill unused with -1. Needed for reshape_and_cache in full cuda graph
|
||||
# mode.
|
||||
block_table.slot_mapping[num_actual_tokens:].fill_(-1)
|
||||
|
||||
slot_mapping = block_table.slot_mapping[:num_actual_tokens]
|
||||
|
||||
# for local attention
|
||||
local_attn_metadata = None
|
||||
if self.runner.attention_chunk_size is not None:
|
||||
seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
|
||||
virt_block_table_tensor = make_local_attention_virtual_batches(
|
||||
self.runner.attention_chunk_size,
|
||||
self.runner.query_start_loc_np[:num_reqs + 1],
|
||||
self.runner.seq_lens_np[:num_reqs],
|
||||
block_table_tensor,
|
||||
self.block_size,
|
||||
)
|
||||
local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
|
||||
self.runner.device, non_blocking=True)
|
||||
local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
|
||||
self.runner.device, non_blocking=True)
|
||||
local_max_query_len = seqlens_q_local_np.max()
|
||||
local_max_seq_len = virt_k_seqlens_np.max()
|
||||
|
||||
local_attn_metadata = TritonAttentionMetadata \
|
||||
.LocalAttentionMetadata(
|
||||
local_query_start_loc=local_query_start_loc,
|
||||
local_seqused_k=local_seqused_k,
|
||||
local_block_table=virt_block_table_tensor,
|
||||
local_max_query_len=local_max_query_len,
|
||||
local_max_seq_len=local_max_seq_len,
|
||||
local_scheduler_metadata=None,
|
||||
)
|
||||
|
||||
use_cascade = common_prefix_len > 0
|
||||
|
||||
if use_cascade:
|
||||
cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
|
||||
dtype=torch.int32,
|
||||
device=self.runner.device)
|
||||
prefix_kv_lens = torch.tensor([common_prefix_len],
|
||||
dtype=torch.int32,
|
||||
device=self.runner.device)
|
||||
suffix_kv_lens = (self.runner.seq_lens_np[:num_reqs] -
|
||||
common_prefix_len)
|
||||
suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(
|
||||
self.runner.device)
|
||||
else:
|
||||
cu_prefix_query_lens = None
|
||||
prefix_kv_lens = None
|
||||
suffix_kv_lens = None
|
||||
prefix_scheduler_metadata = None
|
||||
|
||||
attn_metadata = TritonAttentionMetadata(
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
max_query_len=max_query_len,
|
||||
query_start_loc=query_start_loc,
|
||||
max_seq_len=max_seq_len,
|
||||
seq_lens=seq_lens,
|
||||
block_table=block_table_tensor,
|
||||
slot_mapping=slot_mapping,
|
||||
use_cascade=use_cascade,
|
||||
common_prefix_len=common_prefix_len,
|
||||
cu_prefix_query_lens=cu_prefix_query_lens,
|
||||
prefix_kv_lens=prefix_kv_lens,
|
||||
suffix_kv_lens=suffix_kv_lens,
|
||||
local_attn_metadata=local_attn_metadata,
|
||||
prefix_scheduler_metadata=prefix_scheduler_metadata,
|
||||
)
|
||||
return attn_metadata
|
||||
|
||||
def can_run_in_cudagraph(
|
||||
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
|
||||
# Full CUDA Graph always supported
|
||||
return True
|
||||
|
||||
|
||||
class TritonAttentionBackend(AttentionBackend):
|
||||
@ -52,7 +204,7 @@ class TritonAttentionBackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> type["AttentionMetadata"]:
|
||||
return FlashAttentionMetadata
|
||||
return TritonAttentionMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
|
||||
@ -9,6 +9,8 @@ from typing import TYPE_CHECKING, ClassVar, Generic, TypeVar
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.utils import cdiv
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
@ -140,3 +142,169 @@ def get_kv_cache_layout():
|
||||
"detected. Setting KV cache layout to %s.", cache_layout)
|
||||
|
||||
return cache_layout
|
||||
|
||||
|
||||
#
|
||||
# Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into
|
||||
# local attention blocks, where each block is passed to the attention kernel
|
||||
# as an independent local ("virtual") batch item.
|
||||
#
|
||||
# For example, if are performing a chunked prefill a batch of 3 sequences:
|
||||
# q_seqlens = [4, 10, 5]
|
||||
# kv_seqlens = [6, 17, 9]
|
||||
# Then normally for regular attention we would compute with an attention mask
|
||||
# for batch idx 0 (q_seqlens = 4, kv_seqlens = 6) like:
|
||||
# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6)
|
||||
# k_toks > 0 1 2 3 4 5
|
||||
# q_toks v _____________
|
||||
# 0 | 1 1 1
|
||||
# 1 | 1 1 1 1
|
||||
# 2 | 1 1 1 1 1
|
||||
# 3 | 1 1 1 1 1 1
|
||||
#
|
||||
# for local attention (with attn_chunk_size = 4) we would compute with an
|
||||
# attention mask like:
|
||||
# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6, attn_chunk_size = 4)
|
||||
# k_toks > 0 1 2 3 4 5
|
||||
# q_toks v _____________
|
||||
# 0 | 1 1 1
|
||||
# 1 | 1 1 1 1
|
||||
# 2 | 1
|
||||
# 3 | 1 1
|
||||
#
|
||||
# We can simulate this mask using standard flash-attention by breaking the
|
||||
# sequences into local ("virtual") batches, where each local batch item is a
|
||||
# local attention block, so in this case batch idx 0 would be broken up into:
|
||||
#
|
||||
# local-batch idx: 0 (q_seqlens = 2, kv_seqlens = 4) (batch 0)
|
||||
# k_toks > 0 1 2 3
|
||||
# q_toks v _____________
|
||||
# 0 | 1 1 1
|
||||
# 1 | 1 1 1 1
|
||||
# local-batch idx: 1 (q_seqlens = 2, kv_seqlens = 2) (batch 0)
|
||||
# k_toks > 4 5
|
||||
# q_toks v _____________
|
||||
# 2 | 1
|
||||
# 3 | 1 1
|
||||
#
|
||||
# e.g. if we have:
|
||||
# attn_chunk_size = 4
|
||||
# query_start_loc_np = [0, 4, 14, 19] (q_seqlens = [4, 10, 5])
|
||||
# Then this function would return:
|
||||
# __b0__ ______b1______ __b2__ < orig batch indices
|
||||
# q_seqlens_local = [ 2, 2, 1, 4, 4, 1, 4, 1]
|
||||
# cu_seqlens_q_local = [0, 4, 6, 10, 14, 18, 19, 23, 24]
|
||||
# seqlens_k_local = [ 4, 2, 4, 4, 4, 1, 4, 1]
|
||||
# block_table_local : shape[local_virtual_batches, pages_per_local_batch]
|
||||
def make_local_attention_virtual_batches(
|
||||
attn_chunk_size: int,
|
||||
query_start_loc_np: np.ndarray,
|
||||
seq_lens_np: np.ndarray,
|
||||
block_table: torch.Tensor,
|
||||
block_size: int = 0,
|
||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]:
|
||||
q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
|
||||
actual_batch_size = seq_lens_np.shape[0]
|
||||
|
||||
# Handle if we are starting in the middle of a local attention block,
|
||||
# we assume q_seqlens > 0 (for all elements), for each batch idx we compute
|
||||
# the number of tokens that are not in the first local attention block and
|
||||
# then we can simply use a cdiv for the rest.
|
||||
# For example if we have:
|
||||
# attn_chunk_size = 4
|
||||
# q_seqlens = [4, 10, 5]
|
||||
# k_seqlens = [6, 17, 9]
|
||||
# Then we would get:
|
||||
# new_tokens_in_first_block = [2, 1, 4]
|
||||
# local_blocks = [2, 4, 2]
|
||||
q_tokens_in_first_block = np.minimum(
|
||||
attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size),
|
||||
q_seqlens).astype(np.int32)
|
||||
tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size)
|
||||
local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block,
|
||||
attn_chunk_size)
|
||||
|
||||
# Once we know the number of local blocks we can compute the request spans
|
||||
# for each batch idx, we can figure out the number of "virtual" requests we
|
||||
# have to make,
|
||||
# For the above example we would get:
|
||||
# seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1]
|
||||
#
|
||||
# First Get batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1])
|
||||
# (TODO: max a utility to share this code with _prepare_inputs)
|
||||
# arange step 1. [2, 4, 2] -> [2, 6, 8]
|
||||
cu_num_blocks = np.cumsum(local_blocks)
|
||||
virtual_batches = cu_num_blocks[-1]
|
||||
# arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6]
|
||||
block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks)
|
||||
# arange step 3. [0, 1, 0, 1, 2, 3, 0, 1]
|
||||
arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets
|
||||
# also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0])
|
||||
rarange = np.repeat(local_blocks, local_blocks) - arange - 1
|
||||
# Then we can compute the seqlens_q_local, handling the fact that the
|
||||
# first and last blocks could be partial
|
||||
seqlens_q_local = \
|
||||
np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks)
|
||||
# set the first block since this may be a partial block
|
||||
seqlens_q_local[arange == 0] = q_tokens_in_first_block
|
||||
# set the remaining blocks
|
||||
seqlens_q_local[arange > 0] = np.minimum(
|
||||
seqlens_q_local - attn_chunk_size * (arange - 1),
|
||||
attn_chunk_size)[arange > 0]
|
||||
|
||||
# convert from q_seqlens to cu_seqlens_q
|
||||
cu_seqlens_q_local = np.pad(np.cumsum(seqlens_q_local), (1, 0))\
|
||||
.astype(np.int32)
|
||||
|
||||
# compute the seqlens_k_local,
|
||||
# basically a full local attention block for all but the last block in each
|
||||
# batch
|
||||
# For our example this will be:
|
||||
# seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1]
|
||||
seqlens_k_local = np.full(cu_num_blocks[-1],
|
||||
attn_chunk_size,
|
||||
dtype=np.int32)
|
||||
seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block
|
||||
|
||||
k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - \
|
||||
(rarange * attn_chunk_size + \
|
||||
np.repeat(tokens_in_last_block, local_blocks))
|
||||
# For the example the local attention blocks start at:
|
||||
# _b0_ _____b1_____ _b2_
|
||||
# k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8]
|
||||
block_starts = k_seqstarts_absolute // block_size
|
||||
assert attn_chunk_size % block_size == 0, \
|
||||
f"attn_chunk_size {attn_chunk_size} is not " \
|
||||
f"divisible by block_size {block_size}"
|
||||
pages_per_local_batch = attn_chunk_size // block_size
|
||||
|
||||
# Create a block_table for the local attention blocks
|
||||
# For out example if we have a block-table like (assuming block_size=2):
|
||||
# block_table = [
|
||||
# [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9], < batch 0
|
||||
# [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], < batch 1
|
||||
# [20, 21, 22, 23, 24, 25, 26, 27, 28, 29], < batch 2
|
||||
# ]
|
||||
# Then for the local batches we would want a block-table like
|
||||
# block_table_local = [
|
||||
# [ 0, 1 ], < local-batch 0, (batch 0, starting from k[0])
|
||||
# [ 2, 3 ], < local-batch 1, (batch 0, starting from k[4])
|
||||
# [ 12, 13 ], < local-batch 2, (batch 1, starting from k[4])
|
||||
# [ 14, 15 ], < local-batch 3, (batch 1, starting from k[8])
|
||||
# [ 16, 17 ], < local-batch 4, (batch 1, starting from k[12])
|
||||
# [ 18, 19 ], < local-batch 5, (batch 1, starting from k[16])
|
||||
# [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4])
|
||||
# [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8])
|
||||
# ]
|
||||
block_indices= np.broadcast_to(
|
||||
np.arange(pages_per_local_batch, dtype=np.int32),
|
||||
(virtual_batches, pages_per_local_batch)) \
|
||||
+ np.expand_dims(block_starts, axis=1)
|
||||
block_indices = block_indices.flatten().clip(max=block_table.shape[1] - 1)
|
||||
batch_indices = np.repeat(np.arange(actual_batch_size, dtype=np.int32),
|
||||
local_blocks * pages_per_local_batch)
|
||||
block_table_local = block_table[batch_indices, block_indices]\
|
||||
.view(virtual_batches, -1)
|
||||
|
||||
return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, \
|
||||
block_table_local
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user