[Feature][ROCm] Add full graph capture support for TritonAttentionBackend (#19158)

Signed-off-by: charlifu <charlifu@amd.com>

Parent: b447624ee3
Commit: a44b1c951d
@@ -147,6 +147,7 @@ def test_lower_max_num_seqs(model, supported):
         llm.generate(["Hello, my name is"] * 10)
 
 
+@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
 def test_full_cudagraph_with_invalid_backend():
     with temporary_environ({
             "VLLM_USE_V1": "1",
@@ -141,7 +141,8 @@ def use_rocm_custom_paged_attention(
             and (head_size == 64 or head_size == 128)
             and (block_size == 16 or block_size == 32)
             and (gqa_ratio >= 1 and gqa_ratio <= 16)
-            and max_seq_len <= 32768 and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
+            and max_seq_len <= 128 * 1024
+            and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
             and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
                      and envs.VLLM_ROCM_USE_AITER))
 
@@ -151,7 +152,7 @@ def use_rocm_custom_paged_attention(
             and (qtype == torch.half or qtype == torch.bfloat16)
             and head_size == 128 and block_size == 16
             and (gqa_ratio >= 3 and gqa_ratio <= 16)
-            and max_seq_len <= 32768 and alibi_slopes is None
+            and max_seq_len <= 128 * 1024 and alibi_slopes is None
             and kv_cache_dtype == "auto"
             and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
 
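The two hunks above relax the sequence-length cap for ROCm's custom paged-attention kernel from 32768 to 128 * 1024 (131072) tokens. Below is a minimal sketch of the relaxed guard, using only the conditions visible in the second hunk; the helper name and the collapsed env flag are illustrative, not vLLM API, and the real use_rocm_custom_paged_attention() also consults alibi_slopes, kv_cache_dtype, and the AITER env switches.

import torch

def rocm_custom_paged_attn_ok(qtype: torch.dtype, head_size: int,
                              block_size: int, gqa_ratio: int,
                              max_seq_len: int,
                              custom_attn_enabled: bool = True) -> bool:
    # Mirrors the checks shown in the hunk above (simplified).
    return (qtype in (torch.half, torch.bfloat16)
            and head_size == 128 and block_size == 16
            and 3 <= gqa_ratio <= 16
            and max_seq_len <= 128 * 1024  # raised from 32768 by this commit
            and custom_attn_enabled)

assert rocm_custom_paged_attn_ok(torch.half, 128, 16, 8, 100_000)
assert not rocm_custom_paged_attn_ok(torch.half, 128, 16, 8, 200_000)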
@@ -19,9 +19,9 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils import cdiv
-from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
-                                              CommonAttentionMetadata,
-                                              get_kv_cache_layout)
+from vllm.v1.attention.backends.utils import (
+    AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout,
+    make_local_attention_virtual_batches)
 from vllm.v1.kv_cache_interface import AttentionSpec
 from vllm.v1.worker.block_table import BlockTable
 
@@ -126,172 +126,6 @@ class FlashAttentionMetadata:
     local_attn_metadata: Optional[LocalAttentionMetadata] = None
 
 
[... 166 deleted lines elided: the explanatory comment block and the
 make_local_attention_virtual_batches() helper, which this commit moves verbatim
 to vllm.v1.attention.backends.utils; the full text appears in the final hunk of
 this diff ...]
 def _get_sliding_window_configs(
         vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]:
     """Get the set of all sliding window configs used in the model."""
@@ -1,7 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Attention layer with PagedAttention and Triton prefix prefill."""
-from typing import TYPE_CHECKING, Any, Optional
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, ClassVar, Optional
 
 import torch
 
@@ -15,8 +16,10 @@ from vllm.attention.ops.paged_attn import PagedAttention
 from vllm.attention.ops.triton_unified_attention import unified_attention
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.v1.attention.backends.flash_attn import (
-    FlashAttentionMetadata, FlashAttentionMetadataBuilder)
+from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
+from vllm.v1.attention.backends.utils import (
+    AttentionMetadataBuilder, CommonAttentionMetadata,
+    make_local_attention_virtual_batches)
 from vllm.v1.kv_cache_interface import AttentionSpec
 from vllm.v1.worker.block_table import BlockTable
 
@@ -26,12 +29,161 @@ if TYPE_CHECKING:
 logger = init_logger(__name__)
 
 
-class TritonAttentionMetadataBuilder(FlashAttentionMetadataBuilder):
+@dataclass
+class TritonAttentionMetadata:
+    # NOTE(sang): Definition of context_len, query_len, and seq_len.
+    # |---------- N-1 iteration --------|
+    # |---------------- N iteration ---------------------|
+    # |- tokenA -|......................|-- newTokens ---|
+    # |---------- context_len ----------|
+    # |-------------------- seq_len ---------------------|
+    #                                   |-- query_len ---|
+
+    num_actual_tokens: int  # Number of tokens excluding padding.
+    max_query_len: int
+    query_start_loc: torch.Tensor
+    max_seq_len: int
+    seq_lens: torch.Tensor
+    block_table: torch.Tensor
+    slot_mapping: torch.Tensor
+
+    # For cascade attention.
+    use_cascade: bool
+    common_prefix_len: int
+    cu_prefix_query_lens: Optional[torch.Tensor]
+    prefix_kv_lens: Optional[torch.Tensor]
+    suffix_kv_lens: Optional[torch.Tensor]
+
+    # Optional aot scheduling
+    scheduler_metadata: Optional[torch.Tensor] = None
+    prefix_scheduler_metadata: Optional[torch.Tensor] = None
+
+    # for local attention
+    @dataclass
+    class LocalAttentionMetadata:
+        local_query_start_loc: torch.Tensor
+        local_seqused_k: torch.Tensor
+        local_block_table: torch.Tensor
+        local_max_query_len: int
+        local_max_seq_len: int
+        local_scheduler_metadata: Optional[torch.Tensor]
+
+    local_attn_metadata: Optional[LocalAttentionMetadata] = None
+
+
+class TritonAttentionMetadataBuilder(
+        AttentionMetadataBuilder[TritonAttentionMetadata]):
+    full_cudagraph_supported: ClassVar[bool] = True
 
     def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
                  block_table: BlockTable):
-        super().__init__(runner, kv_cache_spec, block_table)
-        self.aot_schedule = False
+        self.runner = runner
+        self.block_size = kv_cache_spec.block_size
+        self.kv_cache_spec = kv_cache_spec
+        self.block_table = block_table
+
+    def build_for_cudagraph_capture(
+            self, common_attn_metadata: CommonAttentionMetadata
+    ) -> TritonAttentionMetadata:
+        attn_metadata = self.build(0, common_attn_metadata)
+        # When doing full graph capture, setting seq_lens to
+        # max_model_len will cause graph capture to be extremely
+        # slow, so here we set it to 1.
+        attn_metadata.seq_lens.fill_(1)
+        return attn_metadata
+
+    def build(
+            self, common_prefix_len: int,
+            common_attn_metadata: CommonAttentionMetadata
+    ) -> TritonAttentionMetadata:
+        num_reqs = common_attn_metadata.num_reqs
+        num_actual_tokens = common_attn_metadata.num_actual_tokens
+        max_query_len = common_attn_metadata.max_query_len
+
+        max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
+        query_start_loc = common_attn_metadata.query_start_loc
+        seq_lens = common_attn_metadata.seq_lens
+        block_table = self.block_table
+        block_table_tensor = block_table.get_device_tensor()[:num_reqs]
+
+        block_table.slot_mapping[:num_actual_tokens].copy_(
+            block_table.slot_mapping_cpu[:num_actual_tokens],
+            non_blocking=True)
+        # Fill unused with -1. Needed for reshape_and_cache in full cuda graph
+        # mode.
+        block_table.slot_mapping[num_actual_tokens:].fill_(-1)
+
+        slot_mapping = block_table.slot_mapping[:num_actual_tokens]
+
+        # for local attention
+        local_attn_metadata = None
+        if self.runner.attention_chunk_size is not None:
+            seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
+                virt_block_table_tensor = make_local_attention_virtual_batches(
+                    self.runner.attention_chunk_size,
+                    self.runner.query_start_loc_np[:num_reqs + 1],
+                    self.runner.seq_lens_np[:num_reqs],
+                    block_table_tensor,
+                    self.block_size,
+                )
+            local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
+                self.runner.device, non_blocking=True)
+            local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
+                self.runner.device, non_blocking=True)
+            local_max_query_len = seqlens_q_local_np.max()
+            local_max_seq_len = virt_k_seqlens_np.max()
+
+            local_attn_metadata = TritonAttentionMetadata \
+                .LocalAttentionMetadata(
+                    local_query_start_loc=local_query_start_loc,
+                    local_seqused_k=local_seqused_k,
+                    local_block_table=virt_block_table_tensor,
+                    local_max_query_len=local_max_query_len,
+                    local_max_seq_len=local_max_seq_len,
+                    local_scheduler_metadata=None,
+                )
+
+        use_cascade = common_prefix_len > 0
+
+        if use_cascade:
+            cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
+                                                dtype=torch.int32,
+                                                device=self.runner.device)
+            prefix_kv_lens = torch.tensor([common_prefix_len],
+                                          dtype=torch.int32,
+                                          device=self.runner.device)
+            suffix_kv_lens = (self.runner.seq_lens_np[:num_reqs] -
+                              common_prefix_len)
+            suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(
+                self.runner.device)
+        else:
+            cu_prefix_query_lens = None
+            prefix_kv_lens = None
+            suffix_kv_lens = None
+        prefix_scheduler_metadata = None
+
+        attn_metadata = TritonAttentionMetadata(
+            num_actual_tokens=num_actual_tokens,
+            max_query_len=max_query_len,
+            query_start_loc=query_start_loc,
+            max_seq_len=max_seq_len,
+            seq_lens=seq_lens,
+            block_table=block_table_tensor,
+            slot_mapping=slot_mapping,
+            use_cascade=use_cascade,
+            common_prefix_len=common_prefix_len,
+            cu_prefix_query_lens=cu_prefix_query_lens,
+            prefix_kv_lens=prefix_kv_lens,
+            suffix_kv_lens=suffix_kv_lens,
+            local_attn_metadata=local_attn_metadata,
+            prefix_scheduler_metadata=prefix_scheduler_metadata,
+        )
+        return attn_metadata
+
+    def can_run_in_cudagraph(
+            self, common_attn_metadata: CommonAttentionMetadata) -> bool:
+        # Full CUDA Graph always supported
+        return True
 
 
 class TritonAttentionBackend(AttentionBackend):
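With full_cudagraph_supported set and build_for_cudagraph_capture() defined above, the Triton backend can take part in vLLM's full CUDA graph capture (the capture path pins seq_lens to 1 so capture does not pay max_model_len cost). A hedged usage sketch follows; the backend name, environment variables, and the CompilationConfig(full_cuda_graph=True) switch are assumptions inferred from the V1 engine and the test touched by this PR, and the model name is only a placeholder, so verify against your vLLM version.

import os

# Assumed environment switches (cf. the test hunk at the top of this diff).
os.environ["VLLM_USE_V1"] = "1"
os.environ["VLLM_ATTENTION_BACKEND"] = "TRITON_ATTN_VLLM_V1"

from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig

# full_cuda_graph=True is assumed to be the flag that routes attention metadata
# through build_for_cudagraph_capture(); the model is just a placeholder.
llm = LLM(model="facebook/opt-125m",
          compilation_config=CompilationConfig(full_cuda_graph=True))
print(llm.generate(["Hello, my name is"], SamplingParams(max_tokens=8)))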
@@ -52,7 +204,7 @@ class TritonAttentionBackend(AttentionBackend):
 
     @staticmethod
     def get_metadata_cls() -> type["AttentionMetadata"]:
-        return FlashAttentionMetadata
+        return TritonAttentionMetadata
 
     @staticmethod
     def get_kv_cache_shape(
@@ -9,6 +9,8 @@ from typing import TYPE_CHECKING, ClassVar, Generic, TypeVar
 import numpy as np
 import torch
 
+from vllm.utils import cdiv
+
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
     from vllm.v1.worker.gpu_input_batch import InputBatch
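cdiv is vLLM's ceiling-division helper; the function added in the next hunk uses it to count how many chunk-sized local blocks a query span covers. A minimal equivalent for reference (the real helper lives in vllm.utils):

import numpy as np

def cdiv(a, b):
    # Ceiling division; works for Python ints and elementwise for numpy arrays.
    return -(a // -b)

assert cdiv(9, 4) == 3 and cdiv(8, 4) == 2
assert (cdiv(np.array([2, 9, 1]), 4) == np.array([1, 3, 1])).all()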
@@ -140,3 +142,169 @@ def get_kv_cache_layout():
                     "detected. Setting KV cache layout to %s.", cache_layout)
 
     return cache_layout
+
+
+#
+# Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into
+# local attention blocks, where each block is passed to the attention kernel
+# as an independent local ("virtual") batch item.
+#
+# For example, if are performing a chunked prefill a batch of 3 sequences:
+#   q_seqlens  = [4, 10, 5]
+#   kv_seqlens = [6, 17, 9]
+# Then normally for regular attention we would compute with an attention mask
+#  for batch idx 0 (q_seqlens = 4, kv_seqlens = 6) like:
+#   batch idx: 0  (q_seqlens = 4, kv_seqlens = 6)
+#        k_toks >   0 1 2 3 4 5
+#        q_toks v  _____________
+#               0 | 1 1 1
+#               1 | 1 1 1 1
+#               2 | 1 1 1 1 1
+#               3 | 1 1 1 1 1 1
+#
+# for local attention (with attn_chunk_size = 4) we would compute with an
+#  attention mask like:
+#   batch idx: 0  (q_seqlens = 4, kv_seqlens = 6, attn_chunk_size = 4)
+#        k_toks >   0 1 2 3 4 5
+#        q_toks v  _____________
+#               0 | 1 1 1
+#               1 | 1 1 1 1
+#               2 |         1
+#               3 |         1 1
+#
+# We can simulate this mask using standard flash-attention by breaking the
+#  sequences into local ("virtual") batches, where each local batch item is a
+#  local attention block, so in this case batch idx 0 would be broken up into:
+#
+#   local-batch idx: 0 (q_seqlens = 2, kv_seqlens = 4)  (batch 0)
+#        k_toks >   0 1 2 3
+#        q_toks v  _____________
+#               0 | 1 1 1
+#               1 | 1 1 1 1
+#   local-batch idx: 1 (q_seqlens = 2, kv_seqlens = 2)  (batch 0)
+#        k_toks >   4 5
+#        q_toks v  _____________
+#               2 | 1
+#               3 | 1 1
+#
+# e.g. if we have:
+#   attn_chunk_size = 4
+#   query_start_loc_np = [0, 4, 14, 19] (q_seqlens = [4, 10, 5])
+# Then this function would return:
+#                           __b0__  ______b1______  __b2__ < orig batch indices
+#   q_seqlens_local    = [   2,  2,  1,  4,  4,  1,  4,  1]
+#   cu_seqlens_q_local = [0, 4,  6, 10, 14, 18, 19, 23, 24]
+#   seqlens_k_local    = [   4,  2,  4,  4,  4,  1,  4,  1]
+#   block_table_local  : shape[local_virtual_batches, pages_per_local_batch]
+def make_local_attention_virtual_batches(
+    attn_chunk_size: int,
+    query_start_loc_np: np.ndarray,
+    seq_lens_np: np.ndarray,
+    block_table: torch.Tensor,
+    block_size: int = 0,
+) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]:
+    q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
+    actual_batch_size = seq_lens_np.shape[0]
+
+    # Handle if we are starting in the middle of a local attention block,
+    #  we assume q_seqlens > 0 (for all elements), for each batch idx we compute
+    #  the number of tokens that are not in the first local attention block and
+    #  then we can simply use a cdiv for the rest.
+    # For example if we have:
+    #   attn_chunk_size = 4
+    #   q_seqlens = [4, 10, 5]
+    #   k_seqlens = [6, 17, 9]
+    # Then we would get:
+    #   new_tokens_in_first_block = [2, 1, 4]
+    #   local_blocks = [2, 4, 2]
+    q_tokens_in_first_block = np.minimum(
+        attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size),
+        q_seqlens).astype(np.int32)
+    tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size)
+    local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block,
+                            attn_chunk_size)
+
+    # Once we know the number of local blocks we can compute the request spans
+    #  for each batch idx, we can figure out the number of "virtual" requests we
+    #  have to make,
+    # For the above example we would get:
+    #   seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1]
+    #
+    # First Get batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1])
+    #   (TODO: max a utility to share this code with _prepare_inputs)
+    # arange step 1. [2, 4, 2] -> [2, 6, 8]
+    cu_num_blocks = np.cumsum(local_blocks)
+    virtual_batches = cu_num_blocks[-1]
+    # arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6]
+    block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks)
+    # arange step 3. [0, 1, 0, 1, 2, 3, 0, 1]
+    arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets
+    # also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0])
+    rarange = np.repeat(local_blocks, local_blocks) - arange - 1
+    # Then we can compute the seqlens_q_local, handling the fact that the
+    #  first and last blocks could be partial
+    seqlens_q_local = \
+        np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks)
+    # set the first block since this may be a partial block
+    seqlens_q_local[arange == 0] = q_tokens_in_first_block
+    # set the remaining blocks
+    seqlens_q_local[arange > 0] = np.minimum(
+        seqlens_q_local - attn_chunk_size * (arange - 1),
+        attn_chunk_size)[arange > 0]
+
+    # convert from q_seqlens to cu_seqlens_q
+    cu_seqlens_q_local = np.pad(np.cumsum(seqlens_q_local), (1, 0))\
+        .astype(np.int32)
+
+    # compute the seqlens_k_local,
+    #  basically a full local attention block for all but the last block in each
+    #  batch
+    # For our example this will be:
+    #   seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1]
+    seqlens_k_local = np.full(cu_num_blocks[-1],
+                              attn_chunk_size,
+                              dtype=np.int32)
+    seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block
+
+    k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - \
+        (rarange * attn_chunk_size + \
+            np.repeat(tokens_in_last_block, local_blocks))
+    # For the example the local attention blocks start at:
+    #                           _b0_  _____b1_____  _b2_
+    #   k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8]
+    block_starts = k_seqstarts_absolute // block_size
+    assert attn_chunk_size % block_size == 0, \
+        f"attn_chunk_size {attn_chunk_size} is not " \
+        f"divisible by block_size {block_size}"
+    pages_per_local_batch = attn_chunk_size // block_size
+
+    # Create a block_table for the local attention blocks
+    # For out example if we have a block-table like (assuming block_size=2):
+    #   block_table = [
+    #     [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9], < batch 0
+    #     [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], < batch 1
+    #     [20, 21, 22, 23, 24, 25, 26, 27, 28, 29], < batch 2
+    #   ]
+    # Then for the local batches we would want a block-table like
+    #   block_table_local = [
+    #     [ 0,  1 ],  < local-batch 0, (batch 0, starting from k[0])
+    #     [ 2,  3 ],  < local-batch 1, (batch 0, starting from k[4])
+    #     [ 12, 13 ], < local-batch 2, (batch 1, starting from k[4])
+    #     [ 14, 15 ], < local-batch 3, (batch 1, starting from k[8])
+    #     [ 16, 17 ], < local-batch 4, (batch 1, starting from k[12])
+    #     [ 18, 19 ], < local-batch 5, (batch 1, starting from k[16])
+    #     [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4])
+    #     [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8])
+    #   ]
+    block_indices = np.broadcast_to(
+        np.arange(pages_per_local_batch, dtype=np.int32),
+        (virtual_batches, pages_per_local_batch)) \
+            + np.expand_dims(block_starts, axis=1)
+    block_indices = block_indices.flatten().clip(max=block_table.shape[1] - 1)
+    batch_indices = np.repeat(np.arange(actual_batch_size, dtype=np.int32),
+                              local_blocks * pages_per_local_batch)
+    block_table_local = block_table[batch_indices, block_indices]\
+        .view(virtual_batches, -1)
+
+    return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, \
+        block_table_local
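To sanity-check the worked example in the comment block above, the small driver below calls the helper with the same inputs (attn_chunk_size=4, q_seqlens=[4, 10, 5], kv_seqlens=[6, 17, 9], block_size=2, a 3x10 block table) and asserts the per-local-batch query/key lengths and the local block table that the comments describe. The import path assumes a vLLM build that already contains this change.

import numpy as np
import torch

from vllm.v1.attention.backends.utils import make_local_attention_virtual_batches

query_start_loc = np.array([0, 4, 14, 19], dtype=np.int32)  # q_seqlens = [4, 10, 5]
seq_lens = np.array([6, 17, 9], dtype=np.int32)             # kv_seqlens = [6, 17, 9]
block_table = torch.arange(30, dtype=torch.int32).reshape(3, 10)  # block_size = 2

seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, block_table_local = \
    make_local_attention_virtual_batches(4, query_start_loc, seq_lens,
                                         block_table, block_size=2)

# Matches the worked example in the comment block above.
assert seqlens_q_local.tolist() == [2, 2, 1, 4, 4, 1, 4, 1]
assert seqlens_k_local.tolist() == [4, 2, 4, 4, 4, 1, 4, 1]
assert block_table_local.tolist() == [[0, 1], [2, 3], [12, 13], [14, 15],
                                      [16, 17], [18, 19], [22, 23], [24, 25]]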