[Neuron][Kernel] Vectorize KV cache load in FlashPagedAttention to maximize DMA bandwidth (#13245)

Signed-off-by: Lingfan Yu <lingfany@amazon.com>
Lingfan Yu 2025-02-20 17:45:45 -08:00 committed by GitHub
parent 71face8540
commit 33170081f1
3 changed files with 769 additions and 353 deletions


@ -0,0 +1,153 @@
# SPDX-License-Identifier: Apache-2.0
import os
import neuronxcc.nki.language as nl
import pytest
import torch
import torch.nn.functional as F
from neuronxcc import nki
from vllm.attention.ops.nki_flash_attn import (
load_block_tables, transform_block_tables_for_indirect_load)
def is_power_of_2(n):
return n > 0 and (n & (n - 1) == 0)
def nki_load_and_transform_block_tables(
block_tables,
num_tiles,
num_blocks_per_tile,
num_head,
head_id,
block_size_tiling_factor,
):
assert is_power_of_2(
num_blocks_per_tile), f"{num_blocks_per_tile=} must be a power of 2"
block_tables_sbuf = load_block_tables(block_tables, num_tiles,
num_blocks_per_tile)
# we need to pass an Index as head_id
head_id = nl.arange(1)[None, :] + head_id
block_tables_transposed = transform_block_tables_for_indirect_load(
block_tables_sbuf, block_size_tiling_factor, num_head, head_id)
B_P_SIZE = 128
assert block_tables_transposed.shape[1] == B_P_SIZE
out = nl.ndarray(
block_tables_transposed.shape,
dtype=nl.int32,
buffer=nl.shared_hbm,
)
for i in nl.affine_range(block_tables_transposed.shape[0]):
nl.store(dst=out[i], value=block_tables_transposed[i])
return out
def ref_block_tables_transform(
block_tables,
num_tiles,
num_blocks_per_tile,
num_head,
head_id,
block_size_tiling_factor,
):
assert block_tables.numel() == num_tiles * num_blocks_per_tile
block_tables = block_tables.view(num_tiles, num_blocks_per_tile)
B_F_SIZE = 128
num_tiles_padded = (num_tiles + B_F_SIZE - 1) // B_F_SIZE * B_F_SIZE
block_tables = F.pad(
block_tables,
(0, 0, 0, num_tiles_padded - num_tiles),
"constant",
0,
)
block_tables = block_tables * num_head + head_id
block_tables = block_tables.view(num_tiles_padded, num_blocks_per_tile, 1)
offset = torch.arange(0, block_size_tiling_factor).view(1, 1, -1)
block_tables = block_tables * block_size_tiling_factor + offset
block_tables_transposed = block_tables.view(num_tiles_padded, -1).t()
num_blocks_per_tile = block_tables_transposed.shape[0]
assert num_blocks_per_tile % B_F_SIZE == 0
return block_tables_transposed.view(num_blocks_per_tile // B_F_SIZE,
B_F_SIZE, num_tiles_padded)
@pytest.mark.parametrize(
"q_head_per_kv_head,head_id",
[
(1, 0),
(3, 1),
],
)
@pytest.mark.parametrize(
"num_tiles,num_blocks_per_tile",
[
(1, 1),
(13, 16),
(17, 128),
(35, 512),
(128, 128),
(130, 64),
(280, 256),
(315, 1),
],
)
@torch.inference_mode()
def test_load_and_transform_block_tables(
num_tiles,
num_blocks_per_tile,
q_head_per_kv_head,
head_id,
) -> None:
import torch_xla.core.xla_model as xm
device = xm.xla_device()
compiler_flags = [
"-O1",
"--retry_failed_compilation",
]
compiler_flags_str = " ".join(compiler_flags)
os.environ["NEURON_CC_FLAGS"] = compiler_flags_str
torch.manual_seed(10000)
torch.set_printoptions(sci_mode=False)
# On Neuron, we need B_P_SIZE = 128 blocks to make DMA efficient
B_P_SIZE = 128
if num_blocks_per_tile < B_P_SIZE:
assert B_P_SIZE % num_blocks_per_tile == 0
block_size_tiling_factor = B_P_SIZE // num_blocks_per_tile
else:
block_size_tiling_factor = 1
max_num_blocks = 100000
block_tables = torch.randint(
0,
max_num_blocks,
(num_tiles * num_blocks_per_tile, ),
dtype=torch.int32,
)
nki_out = nki.jit(nki_load_and_transform_block_tables)[1, 1](
block_tables.to(device=device),
num_tiles,
num_blocks_per_tile,
q_head_per_kv_head,
head_id,
block_size_tiling_factor,
).cpu()
ref_out = ref_block_tables_transform(
block_tables,
num_tiles,
num_blocks_per_tile,
q_head_per_kv_head,
head_id,
block_size_tiling_factor,
)
assert (nki_out.shape == ref_out.shape
), f"{nki_out.shape=} != {ref_out.shape=}"
assert torch.all(nki_out == ref_out)


@ -107,7 +107,7 @@ def ref_masked_attention(
masked_score, dim=-1, return_max_reduce=True)
else:
norm_score = ref_softmax(masked_score, dim=-1)
out = torch.einsum("hqk,khd->qhd", norm_score, value)
out = torch.einsum("hqk,khd->qhd", norm_score.to(value.dtype), value)
if return_max_reduce:
return (
out,
@ -118,7 +118,7 @@ def ref_masked_attention(
scaled_qk,
)
else:
return out
return (out, )
def ref_context_attention(
@ -128,8 +128,6 @@ def ref_context_attention(
query_lens,
seq_lens,
head_size,
num_kv_heads,
num_heads,
num_queries_per_kv,
return_max_reduce=False,
):
@ -146,18 +144,19 @@ def ref_context_attention(
attn_mask = torch.logical_not(attn_mask)
attn_mask = attn_mask.float() * -30000
output, cached_max, cached_sum_reciprocal, lse, masked_score, scaled_qk = (
ref_masked_attention(
query,
key,
value,
scale,
attn_mask,
return_max_reduce=return_max_reduce,
))
output, *debug_tensors = ref_masked_attention(
query,
key,
value,
scale,
attn_mask,
return_max_reduce=return_max_reduce,
)
output = output.unsqueeze(1)
if return_max_reduce:
cached_max, cached_sum_reciprocal, lse, masked_score, scaled_qk = (
debug_tensors)
return (
output,
cached_max,
@ -170,65 +169,22 @@ def ref_context_attention(
return output
@pytest.mark.parametrize(
"block_size, large_tile_size",
[
(32, 2048), # 64 blocks
(32, 4096), # 128 blocks
(32, 8192), # 256 blocks
(64, 8192), # 128 blocks
],
)
@pytest.mark.parametrize(
"num_heads,num_queries_per_kv,head_size,mixed_precision",
[
(4, 2, 8, False),
(4, 2, 8, True),
(32, 8, 64, True),
(16, 2, 128, True),
],
)
@torch.inference_mode()
def test_contexted_kv_attention(
num_heads: int,
num_queries_per_kv: int,
head_size: int,
block_size: int,
large_tile_size,
mixed_precision: bool,
) -> None:
import os
import torch_xla.core.xla_model as xm
from vllm.attention.ops.nki_flash_attn import flash_attn_varlen_nkifunc
assert large_tile_size % block_size == 0
device = xm.xla_device()
compiler_flags = [
"--model-type=transformer -O1",
"--internal-hlo2tensorizer-options='--verify-hlo'",
"--retry_failed_compilation",
]
compiler_flags_str = " ".join(compiler_flags)
os.environ["NEURON_CC_FLAGS"] = compiler_flags_str
torch.manual_seed(0)
torch.set_printoptions(sci_mode=False)
min_ctx_len = 32
max_ctx_len = 1024
min_query_len = 16
max_query_len = 512
prefill_batch_size = 4
decode_batch_size = 12
def sample_inputs(
prefill_batch_size,
decode_batch_size,
min_query_len,
max_query_len,
min_ctx_len,
max_ctx_len,
block_size,
num_heads,
num_kv_heads,
head_size,
dtype,
):
batch_size = prefill_batch_size + decode_batch_size
max_model_len = (max_query_len + max_ctx_len) * 4
max_block_per_request = max_model_len // block_size
dtype = torch.float32
cache_size = (batch_size * max_block_per_request) + 2
prefill_ctx_lens = torch.randint(min_ctx_len,
max_ctx_len + 1, (prefill_batch_size, ),
@ -244,7 +200,6 @@ def test_contexted_kv_attention(
dtype=torch.long,
).tolist() + [1 for _ in range(decode_batch_size)]
seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
num_kv_heads = num_heads // num_queries_per_kv
num_tokens = sum(query_lens)
query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
@ -304,47 +259,139 @@ def test_contexted_kv_attention(
cur_ctx += block_size
block_id += 1
return (
query,
k,
v,
k_cache,
v_cache,
block_table,
key,
value,
query_lens,
seq_lens,
)
def get_active_block_tables(block_tables, query_lens, seq_lens, block_size,
num_blocks):
context_lens = seq_lens - query_lens
blocks_per_seq = (context_lens + block_size - 1) // block_size
num_seqs = len(seq_lens)
active_blocks: list[int] = []
for seq_id in range(num_seqs):
active_blocks = (
active_blocks +
block_tables[seq_id, :blocks_per_seq[seq_id]].tolist())
return F.pad(
torch.tensor(active_blocks, dtype=torch.int32),
(0, num_blocks - len(active_blocks)),
"constant",
0,
)
@pytest.mark.parametrize(
"prefill_batch_size,decode_batch_size,block_size,large_tile_size",
[
(1, 199, 1, 512), # 512 blocks
(4, 12, 256, 2048), # 128 blocks
(4, 12, 16, 2048), # 128 blocks
(4, 12, 4, 1024), # 256 blocks
(4, 12, 32, 2048), # 64 blocks
(4, 12, 32, 4096), # 128 blocks
(4, 12, 32, 8192), # 256 blocks
(4, 12, 64, 8192), # 128 blocks
],
)
@pytest.mark.parametrize(
"num_heads,num_queries_per_kv,head_size",
[
(4, 2, 8),
(32, 8, 64),
(4, 4, 128),
(8, 1, 32),
],
)
@pytest.mark.parametrize("mixed_precision", [True, False])
@torch.inference_mode()
def test_contexted_kv_attention(
prefill_batch_size: int,
decode_batch_size: int,
num_heads: int,
num_queries_per_kv: int,
head_size: int,
block_size: int,
large_tile_size,
mixed_precision: bool,
) -> None:
import os
import torch_xla.core.xla_model as xm
from vllm.attention.ops.nki_flash_attn import (flash_attn_varlen_nkifunc,
reorder_context_mask)
assert large_tile_size % block_size == 0
device = xm.xla_device()
compiler_flags = [
"-O1",
"--retry_failed_compilation",
]
compiler_flags_str = " ".join(compiler_flags)
os.environ["NEURON_CC_FLAGS"] = compiler_flags_str
torch.manual_seed(0)
torch.set_printoptions(sci_mode=False)
dtype = torch.float32
min_ctx_len = 32
max_ctx_len = 1024
min_query_len = 16
max_query_len = 512
num_kv_heads = num_heads // num_queries_per_kv
(
output_ref,
cached_max,
cached_sum_reciprocal,
lse,
masked_score,
scaled_qk,
) = ref_context_attention(
query,
k_active,
v_active,
k_cache,
v_cache,
block_table,
key,
value,
query_lens,
seq_lens,
) = sample_inputs(
prefill_batch_size=prefill_batch_size,
decode_batch_size=decode_batch_size,
min_query_len=min_query_len,
max_query_len=max_query_len,
min_ctx_len=min_ctx_len,
max_ctx_len=max_ctx_len,
block_size=block_size,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
head_size=head_size,
dtype=dtype,
)
output_ref = ref_context_attention(
query,
key,
value,
query_lens,
seq_lens,
head_size,
num_kv_heads,
num_heads,
num_queries_per_kv,
return_max_reduce=True,
return_max_reduce=False,
)
# build neuron program
return_debug_tensors = False
B_P_SIZE = 128
LARGE_TILE_SZ = large_tile_size
def get_active_block_tables(block_tables, query_lens, seq_lens, block_size,
num_blocks):
context_lens = seq_lens - query_lens
blocks_per_seq = (context_lens + block_size - 1) // block_size
num_seqs = len(seq_lens)
active_blocks: list[int] = []
for seq_id in range(num_seqs):
active_blocks = (
active_blocks +
block_tables[seq_id, :blocks_per_seq[seq_id]].tolist())
return F.pad(
torch.tensor(active_blocks),
(0, num_blocks - len(active_blocks)),
"constant",
0,
)
assert (large_tile_size >= B_P_SIZE
), f"Expect {large_tile_size=} to be at least {B_P_SIZE=}"
def ceil_div(a, b):
return (a + b - 1) // b
@ -357,32 +404,27 @@ def test_contexted_kv_attention(
return 2**int(a - 1).bit_length()
# calculate input shapes
max_num_queries = pad_to_multiple(sum(query_lens), block_size)
max_num_queries = pad_to_next_power_of_2(max_num_queries)
head_size_padded = B_P_SIZE
assert head_size_padded >= head_size
max_num_queries = pad_to_next_power_of_2(sum(query_lens))
context_lens = torch.tensor(seq_lens) - torch.tensor(query_lens)
num_active_blocks = ceil_div(context_lens, block_size).sum().item()
num_active_blocks = pad_to_multiple(num_active_blocks,
LARGE_TILE_SZ // block_size)
large_tile_size // block_size)
context_kv_len = num_active_blocks * block_size
assert (context_kv_len %
LARGE_TILE_SZ == 0), f"invalid context_kv_len={context_kv_len}"
large_tile_size == 0), f"invalid context_kv_len={context_kv_len}"
# pad QKV tensors
pad_dims = (
0,
head_size_padded - query.shape[2],
0,
0,
0,
0,
max_num_queries - query.shape[0],
)
query = F.pad(query, pad_dims, "constant", 0)
k = F.pad(k, pad_dims, "constant", 0)
v = F.pad(v, pad_dims, "constant", 0)
k_cache = F.pad(k_cache, (0, head_size_padded - head_size), "constant", 0)
v_cache = F.pad(v_cache, (0, head_size_padded - head_size), "constant", 0)
k = F.pad(k_active, pad_dims, "constant", 0)
v = F.pad(v_active, pad_dims, "constant", 0)
# permute QKV tensors
# query: (1, n_heads, d, seq_q)
@ -391,6 +433,8 @@ def test_contexted_kv_attention(
query = query.unsqueeze(0).permute(0, 2, 3, 1).contiguous()
k = k.unsqueeze(0).permute(0, 2, 3, 1).contiguous()
v = v.unsqueeze(0).permute(0, 2, 1, 3).contiguous()
k_cache = k_cache.permute(0, 2, 1, 3).contiguous()
v_cache = v_cache.permute(0, 2, 1, 3).contiguous()
# transform block table
active_block_table = get_active_block_tables(
@ -405,33 +449,31 @@ def test_contexted_kv_attention(
prior_mask, active_mask = (
BlockDiagonalCausalFromBottomRightMask.from_seqlens(
query_lens, seq_lens, block_size=block_size))
attn_mask = torch.concat(
[
F.pad(
prior_mask,
(
0,
context_kv_len - prior_mask.shape[1],
0,
max_num_queries - prior_mask.shape[0],
),
"constant",
0,
).bool(),
F.pad(
active_mask,
(
0,
max_num_queries - active_mask.shape[1],
0,
max_num_queries - active_mask.shape[0],
),
"constant",
0,
).bool(),
],
dim=1,
)
prior_mask_padded = F.pad(
prior_mask,
(
0,
context_kv_len - prior_mask.shape[1],
0,
max_num_queries - prior_mask.shape[0],
),
"constant",
0,
).bool()
active_mask_padded = F.pad(
active_mask,
(
0,
max_num_queries - active_mask.shape[1],
0,
max_num_queries - active_mask.shape[0],
),
"constant",
0,
).bool()
attn_mask = torch.concat([prior_mask_padded, active_mask_padded], dim=1)
attn_mask = reorder_context_mask(attn_mask, large_tile_size, block_size)
input_args = (
query.to(device=device),
@ -439,29 +481,21 @@ def test_contexted_kv_attention(
v.to(device=device),
k_cache.to(device=device),
v_cache.to(device=device),
active_block_table.to(torch.int32).to(device=device),
active_block_table.to(device=device),
attn_mask.to(device=device),
)
input_kwargs = dict(
n_kv_head=num_kv_heads,
head_size=head_size,
mixed_precision=mixed_precision,
LARGE_TILE_SZ=LARGE_TILE_SZ,
return_debug_tensors=return_debug_tensors,
LARGE_TILE_SZ=large_tile_size,
)
if return_debug_tensors:
output_nki, *debug_tensors = flash_attn_varlen_nkifunc(
*input_args, **input_kwargs)
else:
output_nki = flash_attn_varlen_nkifunc(*input_args, **input_kwargs)
debug_tensors = []
debug_tensors = [torch.tensor(dt).cpu() for dt in debug_tensors]
output_nki = flash_attn_varlen_nkifunc(*input_args, **input_kwargs)
num_actual_tokens = sum(query_lens)
# - o: shape (bs, n_heads, seq_q, d) -> (bs, seq_q, n_heads, d)
output_nki = output_nki.cpu().permute(0, 2, 1, 3)[:, :, :, :head_size]
output_nki = output_nki.cpu().permute(0, 2, 1, 3)
output_nki = output_nki[0, :num_actual_tokens, :, :]
output_ref_padded = F.pad(
output_ref,


@ -1,27 +1,203 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
import neuronxcc.nki.isa as nisa
import neuronxcc.nki.language as nl
import numpy as np
import torch
from neuronxcc import nki
from neuronxcc.nki.language import par_dim
@dataclass(frozen=True)
class FlashConfig:
"""
Config class for flash attention with default values
"""
def ceil_div(a, b):
return (a + b - 1) // b
seq_tile_size: int = 2048
should_transpose_v: bool = False
__annotations__ = {
"seq_tile_size": int,
"should_transpose_v": bool,
}
def is_power_of_2(x):
return x > 0 and (x & (x - 1)) == 0
@nki.jit
def load_block_tables(block_tables_hbm, num_tiles, num_blocks_per_tile):
"""
Load block tables from HBM into SBUF (on-chip SRAM)
`block_tables_hbm` has shape `(num_tiles * num_blocks_per_tile, )`.
In case `num_tiles > B_P_SIZE`, we need to further tile the `num_tiles`
dimension.
"""
B_P_SIZE = 128
# reshape as `(num_tiles, num_blocks_per_tile)`
assert len(block_tables_hbm.shape) == 1
(num_total_blocks, ) = block_tables_hbm.shape
assert num_blocks_per_tile * num_tiles == num_total_blocks
block_tables_hbm = block_tables_hbm.reshape(
(num_tiles, num_blocks_per_tile))
block_tables_sbuf = nl.zeros(
(ceil_div(num_tiles,
B_P_SIZE), par_dim(B_P_SIZE), num_blocks_per_tile),
dtype=nl.int32,
)
for i in nl.affine_range(ceil_div(num_tiles, B_P_SIZE)):
i_p = nl.arange(B_P_SIZE)[:, None]
i_f = nl.arange(num_blocks_per_tile)[None, :]
block_tables_sbuf[i, i_p, i_f] = nl.load(
block_tables_hbm[i_p + i * B_P_SIZE, i_f],
dtype=nl.int32,
mask=(i_p + i * B_P_SIZE < num_tiles),
)
return block_tables_sbuf
@nki.jit
def transform_block_tables_for_indirect_load(
block_tables,
block_size_tiling_factor,
num_head,
head_id,
):
"""
This function does two things:
1. calculate new `block_tables` for a `head_id` after flattening
`num_block`, `num_head`, and `block_size_tiling_factor` dimensions
2. transpose the result so that `block_table` for each tile is mapped to
SBUF Partition dimension for vectorized DMA
Tiling trick to further improve DMA performance:
Given KV cache shape `(num_block, num_head, block_size, D)`, when loading M
blocks of a given `head_id` from HBM, the load `cache[block_tables,
head_id]` has shape `(M, block_size, D)`. If M < B_P_SIZE = 128, DMA may not
fully utilize hardware parallelization. The solution is to tile `block_size`
into `(block_size_tiling_factor, tiled_block_size)` s.t. `M *
block_size_tiling_factor = B_P_SIZE`. After tiling, KV cache has shape
`(num_block, num_head, block_size_tiling_factor, tiled_block_size, D)`.
Note:
We don't further tile the D dimension, since small DMA sizes also hurt performance.
"""
B_P_SIZE = 128
num_partitions, num_tiles_per_partition, num_blocks_per_tile = (
block_tables.shape)
assert num_tiles_per_partition == B_P_SIZE
assert is_power_of_2(
num_blocks_per_tile), f"{num_blocks_per_tile=} is not a power of 2"
num_loads = ceil_div(num_blocks_per_tile, B_P_SIZE)
block_tables_transposed = nl.ndarray(
(
num_loads,
par_dim(B_P_SIZE),
num_partitions * num_tiles_per_partition,
),
dtype=nl.int32,
)
# prepare iota ahead of time to avoid repeatedly using Gpsimd
if num_head > 1:
head_id = nisa.iota(head_id, dtype=nl.int32).reshape((1, 1))
head_id = nl.transpose(
head_id.broadcast_to((1, num_tiles_per_partition)))
if num_blocks_per_tile > 1:
head_id = head_id.broadcast_to(
(num_tiles_per_partition, num_blocks_per_tile))
if block_size_tiling_factor > 1:
broadcast_shape = (
num_tiles_per_partition,
num_blocks_per_tile,
block_size_tiling_factor,
)
offset = nisa.iota(nl.arange(block_size_tiling_factor)[None, None, :],
dtype=nl.int32).broadcast_to(broadcast_shape)
for partition_id in nl.affine_range(num_partitions):
block_tables_partition = block_tables[partition_id]
if num_head > 1:
# fuse num_block and num_head dimension
block_tables_partition = block_tables_partition * num_head + head_id
if block_size_tiling_factor > 1:
# need to apply block size tiling trick
assert num_blocks_per_tile * block_size_tiling_factor == B_P_SIZE
block_tables_partition = ((block_tables_partition *
block_size_tiling_factor).reshape(
(num_tiles_per_partition,
num_blocks_per_tile,
1)).broadcast_to(broadcast_shape))
new_block_tables = block_tables_partition + offset
new_block_tables = new_block_tables.reshape(
(num_tiles_per_partition, B_P_SIZE))
else:
new_block_tables = block_tables_partition
# transpose the block table so that it can be used by vector DGE
for i in nl.affine_range(num_loads):
i_p = nl.arange(B_P_SIZE)[:, None]
i_f = (partition_id * num_tiles_per_partition +
nl.arange(num_tiles_per_partition)[None, :])
block_tables_transposed[i, i_p, i_f] = nl.transpose(
new_block_tables[:, nl.ds(i * B_P_SIZE, B_P_SIZE)])
return block_tables_transposed
@nki.jit
def load_kv_tile_from_cache(
cur_k_tile,
cur_v_tile,
key_cache,
value_cache,
block_tables,
large_k_tile_idx,
num_blocks_per_large_tile,
tiled_block_size,
B_P_SIZE,
B_D_SIZE,
):
"""
Load the KV cache and transform Key and Value into the layout required by matmul
Vectorized DMA Load layout:
Key and Value: (par_dim(B_P_SIZE), seqlen_kv // B_P_SIZE * B_D_SIZE)
Layout used by attention matmuls:
Key: (par_dim(B_D_SIZE), seqlen_kv)
Value: (seqlen_kv // B_P_SIZE, par_dim(B_P_SIZE), B_D_SIZE)
equivalent to (par_dim(B_P_SIZE), seqlen_kv // B_P_SIZE * B_D_SIZE)
"""
# load key cache
num_loads = ceil_div(num_blocks_per_large_tile, B_P_SIZE)
for load_idx in nl.affine_range(num_loads):
i_p = nl.arange(B_P_SIZE)[:, None]
i_f = nl.arange(tiled_block_size * B_D_SIZE)[None, :]
loaded = nl.load(key_cache[block_tables[load_idx, i_p,
large_k_tile_idx], i_f])
if cur_k_tile.dtype != loaded.dtype:
loaded = nl.copy(loaded, dtype=cur_k_tile.dtype)
# Transpose SBUF tensor using PE
for tb_i in nl.affine_range(tiled_block_size):
cur_k_tile[
:,
nl.ds(
load_idx * B_P_SIZE * tiled_block_size + tb_i * B_P_SIZE,
B_P_SIZE,
),
] = nl.transpose(loaded[:, nl.ds(tb_i * B_D_SIZE, B_D_SIZE)])
# load value cache
for load_idx in nl.affine_range(num_loads):
i_p = nl.arange(B_P_SIZE)[:, None]
i_f = nl.arange(tiled_block_size * B_D_SIZE)[None, :]
loaded = nl.load(value_cache[block_tables[load_idx, i_p,
large_k_tile_idx], i_f])
if cur_v_tile.dtype != loaded.dtype:
loaded = nl.copy(loaded, dtype=cur_v_tile.dtype)
cur_v_tile[
:,
nl.ds(
load_idx * tiled_block_size * B_D_SIZE,
tiled_block_size * B_D_SIZE,
),
] = loaded
@nki.jit
@ -62,13 +238,13 @@ def _flash_attention_core(
o_buffer,
l_buffer,
m_buffer,
q_tile_idx,
kernel_dtype,
acc_type,
flash_config: FlashConfig,
use_causal_mask,
tile_mask,
use_causal_mask,
q_tile_idx=None,
initialize=False,
LARGE_TILE_SZ=2048,
B_P_SIZE=128,
B_F_SIZE=512,
B_D_SIZE=128,
@ -77,19 +253,19 @@ def _flash_attention_core(
"""
The flash attention core function to calculate self attention between a tile
of q and a block of K and V.
The q_local_tile has (B_P_SIZE, B_F_SIZE), which is loaded into the SBUF
already. The block size of K and V
is defined in the seq_tile_size of the flash_config. The results are stored
in the following three buffers
The q_local_tile has (B_P_SIZE, B_D_SIZE)
The K and V have shape (B_D_SIZE, LARGE_TILE_SZ), whose free dimension will
be split into size B_F_SIZE tiles
The results are stored in the following three buffers
o_buffer: (B_P_SIZE, d)
l_buffer: (B_P_SIZE, 1)
m_buffer: (B_P_SIZE, 1)
All IO buffers are in SBUF.
"""
LARGE_TILE_SZ = flash_config.seq_tile_size
num_k_tile_per_large_tile = LARGE_TILE_SZ // B_F_SIZE
# masks are used to only apply computation to the lower half of the matrix,
# which reduces the arithmetic intensity by half
qk_res_buf = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
buffer=nl.sbuf,
dtype=acc_type)
@ -99,6 +275,8 @@ def _flash_attention_core(
k_i_b_f_slice = nl.ds(k_i * B_F_SIZE, B_F_SIZE)
if use_causal_mask:
# masks are used to only apply computation to the lower half of the
# matrix, which reduces the arithmetic intensity by up to 50%
multiplication_required_selection = (q_tile_idx * B_P_SIZE
>= k_i * B_F_SIZE)
else:
@ -165,7 +343,9 @@ def _flash_attention_core(
REDUCTION_TILE = min(2048, LARGE_TILE_SZ // 2)
p_partial_sum = nl.ndarray(
(par_dim(B_P_SIZE), LARGE_TILE_SZ // REDUCTION_TILE), dtype=acc_type)
(par_dim(B_P_SIZE), LARGE_TILE_SZ // REDUCTION_TILE),
dtype=acc_type,
)
for k_r_i in nl.affine_range(LARGE_TILE_SZ // REDUCTION_TILE):
k_r_i_reduce_slice = nl.ds(k_r_i * REDUCTION_TILE, REDUCTION_TILE)
@ -194,13 +374,15 @@ def _flash_attention_core(
B_F_SIZE=B_F_SIZE,
)
pv_psum = nl.zeros((par_dim(B_P_SIZE), B_D_SIZE),
dtype=np.float32,
buffer=nl.psum)
pv_psum = nl.zeros(
(par_dim(B_P_SIZE), B_D_SIZE),
dtype=np.float32,
buffer=nl.psum,
)
for k_i in nl.affine_range(LARGE_TILE_SZ // B_P_SIZE):
pv_psum[:, :] += nl.matmul(
p_local_transposed[:, nl.ds(k_i * B_P_SIZE, B_P_SIZE)],
v[k_i, :, :],
v[:, nl.ds(k_i * B_D_SIZE, B_D_SIZE)],
transpose_x=True,
) # (128, 128) (p(Br), d)
@ -219,44 +401,16 @@ def _flash_attention_core(
@nki.jit
def load_v_tile(v_hbm_tile, cur_v_tile, j, v_i, config):
LARGE_TILE_SZ = config.seq_tile_size
def load_v_tile(v_hbm_tile, cur_v_tile, large_tile_idx, v_i, LARGE_TILE_SZ):
B_P_SIZE = 128
if not config.should_transpose_v:
cur_v_tile[v_i, :, :] = nl.load(
v_hbm_tile[nl.ds(j * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE), :],
dtype=cur_v_tile.dtype,
)
return
if nisa.get_nc_version() == nisa.nc_version.gen3:
cur_v_tile_transposed = nisa.dma_transpose(
v_hbm_tile[:,
nl.ds(j * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE)])
cur_v_tile[v_i, :, :] = nisa.tensor_copy(cur_v_tile_transposed,
dtype=cur_v_tile.dtype)
return
cur_v_tile[v_i, :, :] = nl.load_transpose2d(
v_hbm_tile[:, nl.ds(j * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE)],
dtype=cur_v_tile.dtype,
)
@nki.jit
def load_block_tables(block_tables_hbm, num_tiles):
(num_blocks, ) = block_tables_hbm.shape
assert num_blocks % num_tiles == 0
num_blocks_per_tile = num_blocks // num_tiles
block_tables_hbm = block_tables_hbm.reshape(
(num_tiles, num_blocks_per_tile))
block_tables_buffer = nl.load(block_tables_hbm, dtype=nl.int32)
return block_tables_buffer
def is_power_of_2(x):
return x > 0 and (x & (x - 1)) == 0
B_D_SIZE = v_hbm_tile.shape[-1]
loaded = nl.load(v_hbm_tile[
nl.ds(large_tile_idx * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE),
:,
])
if cur_v_tile.dtype != loaded.dtype:
loaded = nl.copy(loaded, dtype=cur_v_tile.dtype)
cur_v_tile[:, nl.ds(v_i * B_D_SIZE, B_D_SIZE)] = loaded
@nki.jit
@ -270,24 +424,21 @@ def flash_paged_attention(
mask,
softmax_scale=None,
mixed_precision=True,
config=None,
LARGE_TILE_SZ=2048,
return_debug_tensors=False,
):
"""
Flash PagedAttention Forward Kernel.
- PagedAttention Paper: https://arxiv.org/abs/2309.06180
- Chunked Prefill Paper: https://arxiv.org/abs/2403.02310
IO tensor layouts:
- query: shape (1, n_heads, d, seq_q)
- key: shape (1, n_kv_heads, d, seq_k)
- value: shape (1, n_kv_heads, seq_v, d)
- key_cache: (num_blocks, block_size, n_kv_heads, d)
- value_cache: (num_blocks, block_size, n_kv_heads, d)
- key_cache: (num_blocks, n_kv_heads, block_size, d)
- value_cache: (num_blocks, n_kv_heads, block_size, d)
- block_tables: (num_active_blocks, )
- mask: (seq_q, num_active_blocks * block_size)
- mask: (seq_q, num_active_blocks * block_size + seq_q)
- o: shape (1, n_heads, seq_q, d)
- l_m: shape (1, n_heads, seq_q, 2)
- This kernel requires seq_k == seq_v
- We use continuous batching by default, so the batch dimension is
@ -306,11 +457,8 @@ def flash_paged_attention(
- softmax_scale: scaling for softmax; if None, default is `1.0/(d**0.5)`
- mixed_precision: flag to set non-matmul ops in fp32 precision; default
is `True`; if `False`, we use the same precision as the input types
- config: Instance of dataclass :class:`nki.kernels.attention.FlashConfig`
with Performance config parameters for flash attention with default
values
seq_tile_size: `default=2048`, size of the kv tile size for attention
computation reduction
- LARGE_TILE_SZ: `default=2048`, size of the KV tile used for the attention
computation reduction
GQA support Notes:
the spmd kernel for launching kernel should be on kv_heads instead of
@ -322,31 +470,65 @@ def flash_paged_attention(
GQA: q: [b, h, d, s], k: [b, kv_h, d, s], v: [b, kv_h, s, d]
usage: `flash_fwd[b, kv_h](q, k, v, ...)`
"""
config = config or FlashConfig()
B_F_SIZE = 512
B_P_SIZE = 128
b, h, d, seqlen_q = query.shape
B_D_SIZE = d
LARGE_TILE_SZ = config.seq_tile_size
n_tile_q = seqlen_q // B_P_SIZE # since q will be loaded on tensor engine
num_blocks, block_size, k_h, _ = key_cache.shape
num_blocks, k_h, block_size, _ = key_cache.shape
q_h_per_k_h = h // k_h
assert tuple(key_cache.shape) == (
num_blocks,
block_size,
k_h,
d,
), "Input shape mismatch!"
assert tuple(value_cache.shape) == (
num_blocks,
block_size,
k_h,
d,
), "Input shape mismatch!"
assert b == 1, f"invalid batch size {b=}"
assert d <= 128, f" we do not support head_dim > 128, got head dim {d}"
assert d <= 128, f" we do not support head_dim > 128, got head dim {d=}"
cache_shape = (num_blocks, k_h, block_size, d)
assert (tuple(key_cache.shape) == cache_shape
), f"{key_cache.shape=} mismatch, expect {cache_shape}"
assert (tuple(value_cache.shape) == cache_shape
), f"{value_cache.shape=} mismatch, expect {cache_shape}"
assert key is None or tuple(key.shape) == (
1,
k_h,
d,
seqlen_q,
), f"key shape {key.shape} mismatch!"
assert value is None or tuple(value.shape) == (
1,
k_h,
seqlen_q,
d,
), f"value shape {value.shape} mismatch!"
assert (
nl.program_ndim() == 2
), f"Expect spmd grid with 2 dimensions, got {nl.program_ndim()} instead!"
batch_id = nl.program_id(axis=0)
head_id = nl.program_id(axis=1)
(num_active_blocks, ) = block_tables.shape
context_kv_len = num_active_blocks * block_size
assert (
LARGE_TILE_SZ % B_F_SIZE == 0
), f"Need {LARGE_TILE_SZ=} to be divisible by {B_F_SIZE=} in transpose_p"
assert (context_kv_len % LARGE_TILE_SZ == 0
), f"Need {context_kv_len=} to be divisible by {LARGE_TILE_SZ=}"
num_blocks_per_large_tile = LARGE_TILE_SZ // block_size
assert is_power_of_2(
num_blocks_per_large_tile
), f"{num_blocks_per_large_tile=} is expected of be power of 2"
if seqlen_q > B_F_SIZE:
MAX_REDUCTION_TILE = 2048
if seqlen_q // 2 > MAX_REDUCTION_TILE:
assert (
seqlen_q % MAX_REDUCTION_TILE == 0
), f"{seqlen_q=} should be divisible by {MAX_REDUCTION_TILE=}"
else:
assert (seqlen_q % B_F_SIZE == 0
), f"{seqlen_q=} should be divisible by {B_F_SIZE=})"
kernel_dtype = nl.bfloat16 if mixed_precision else query.dtype
acc_type = np.dtype(np.float32) if mixed_precision else kernel_dtype
softmax_scale = softmax_scale or (1.0 / (d**0.5))
num_large_k_tile = context_kv_len // LARGE_TILE_SZ
o = nl.ndarray((b, h, seqlen_q, d),
dtype=query.dtype,
@ -373,35 +555,38 @@ def flash_paged_attention(
buffer=nl.sbuf,
lazy_initialization=True,
)
block_tables_sbuf = load_block_tables(
block_tables_hbm=block_tables,
num_tiles=num_large_k_tile,
num_blocks_per_tile=num_blocks_per_large_tile,
)
assert (
nl.program_ndim() == 2
), f"Expect spmd grid with 2 dimensions, got {nl.program_ndim()} instead!"
batch_id = nl.program_id(axis=0)
head_id = nl.program_id(axis=1)
# On Neuron, we need B_P_SIZE = 128 blocks to make DMA efficient
if num_blocks_per_large_tile < B_P_SIZE:
# we checked num_blocks_per_tile is a power of 2
assert B_P_SIZE % num_blocks_per_large_tile == 0
block_size_tiling_factor = B_P_SIZE // num_blocks_per_large_tile
# We assume block_size >= block_size_tiling_factor
assert block_size % block_size_tiling_factor == 0
else:
block_size_tiling_factor = 1
tiled_block_size = block_size // block_size_tiling_factor
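# e.g. (illustrative) block_size=256 with LARGE_TILE_SZ=2048 gives
# num_blocks_per_large_tile=8, block_size_tiling_factor=16 and
# tiled_block_size=16, so each block is split into 16 DMA rows of 16 tokens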
softmax_scale = softmax_scale or (1.0 / (d**0.5))
# Indirect DMA load must be placed along Partition Dimension
block_tables_sbuf = transform_block_tables_for_indirect_load(
block_tables_sbuf,
block_size_tiling_factor=block_size_tiling_factor,
num_head=k_h,
head_id=head_id,
)
(num_active_blocks, ) = block_tables.shape
context_kv_len = num_active_blocks * block_size
assert (config.seq_tile_size >= 512
), f" seq tile_size {config.seq_tile_size} cannot be less than 512"
assert (context_kv_len % LARGE_TILE_SZ == 0
), f"Need {context_kv_len=} to be divisible by {LARGE_TILE_SZ=}"
assert (
LARGE_TILE_SZ % B_P_SIZE == 0
), f"Need LARGE_TILE_SZ ({LARGE_TILE_SZ}) to be divisible by {B_P_SIZE=}"
assert (B_P_SIZE % block_size == 0
), f"Need B_P_SIZE ({B_P_SIZE}) to be divisible by {block_size=}"
num_large_k_tile = context_kv_len // LARGE_TILE_SZ
num_blocks_per_large_tile = LARGE_TILE_SZ // block_size
assert block_size % 32 == 0, "block_size is expected to be a multiple of 32"
assert is_power_of_2(
num_blocks_per_large_tile
), "The number of blocks in each large tile is expected of be power of 2"
assert is_power_of_2(seqlen_q), "seqlen_q is expected to be power of 2"
block_tables_sbuf = load_block_tables(block_tables, num_large_k_tile)
# Flatten KV cache to be 2D for loading into SBUF
new_cache_shape = (
num_blocks * k_h * block_size_tiling_factor,
tiled_block_size * d,
)
key_cache = key_cache.reshape(new_cache_shape)
value_cache = value_cache.reshape(new_cache_shape)
# Global Flash Attention accumulators
o_buffer = nl.zeros(
@ -411,7 +596,7 @@ def flash_paged_attention(
lazy_initialization=True,
)
l_buffer = nl.zeros(
(par_dim(B_P_SIZE), n_tile_q, q_h_per_k_h),
(n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), 1),
dtype=acc_type,
buffer=nl.sbuf,
lazy_initialization=True,
@ -423,50 +608,42 @@ def flash_paged_attention(
lazy_initialization=True,
)
for j in nl.sequential_range(0, num_large_k_tile):
cur_k_tile = nl.ndarray((par_dim(B_D_SIZE), LARGE_TILE_SZ),
dtype=kernel_dtype)
cur_v_tile = nl.ndarray(
(LARGE_TILE_SZ // B_P_SIZE, par_dim(B_P_SIZE), B_D_SIZE),
for large_k_tile_idx in nl.sequential_range(0, num_large_k_tile):
num_loads = ceil_div(num_blocks_per_large_tile, B_P_SIZE)
cur_k_tile = nl.ndarray(
(par_dim(B_D_SIZE), LARGE_TILE_SZ),
dtype=kernel_dtype,
)
for k_i in nl.affine_range(num_blocks_per_large_tile):
loaded = nl.load(key_cache[block_tables_sbuf[j, k_i], :,
head_id, :])
cur_k_tile[:, nl.ds(k_i *
block_size, block_size)] = nl.transpose(loaded)
load_tile_size = B_P_SIZE
num_blocks_per_partition = load_tile_size // block_size
for partition_idx in nl.affine_range(LARGE_TILE_SZ // load_tile_size):
for block_in_partition in nl.affine_range(
num_blocks_per_partition):
v_i = (partition_idx * num_blocks_per_partition +
block_in_partition)
loaded_v = nl.load(value_cache[block_tables_sbuf[j, v_i], :,
head_id, :])
cur_v_tile[
partition_idx,
nl.ds(block_in_partition * block_size, block_size),
:,
] = loaded_v
cur_v_tile = nl.ndarray(
(par_dim(B_P_SIZE), num_loads * tiled_block_size * B_D_SIZE),
dtype=kernel_dtype,
)
load_kv_tile_from_cache(
cur_k_tile=cur_k_tile,
cur_v_tile=cur_v_tile,
key_cache=key_cache,
value_cache=value_cache,
block_tables=block_tables_sbuf,
large_k_tile_idx=large_k_tile_idx,
num_blocks_per_large_tile=num_blocks_per_large_tile,
tiled_block_size=tiled_block_size,
B_P_SIZE=B_P_SIZE,
B_D_SIZE=B_D_SIZE,
)
for i in nl.affine_range(n_tile_q):
cur_mask = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
dtype=mask.dtype)
for m_i in nl.affine_range(LARGE_TILE_SZ // B_F_SIZE):
cur_mask[:, nl.ds(m_i * B_F_SIZE, B_F_SIZE)] = nl.load(mask[
nl.ds(i * B_P_SIZE, B_P_SIZE),
nl.ds(j * LARGE_TILE_SZ + m_i * B_F_SIZE, B_F_SIZE),
])
cur_mask = nl.load(mask[
nl.ds(i * B_P_SIZE, B_P_SIZE),
nl.ds(large_k_tile_idx * LARGE_TILE_SZ, LARGE_TILE_SZ),
])
for i_q_h in nl.affine_range(q_h_per_k_h):
q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype)
q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h]
q_sbuf_tile = nl.load(
q_hbm_tile[:, nl.ds(i * B_P_SIZE, B_P_SIZE)],
dtype=kernel_dtype,
) # load (d, 128) tile in SBUF
q_sbuf_tile = nl.load(q_hbm_tile[:,
nl.ds(i *
B_P_SIZE, B_P_SIZE)])
if q_sbuf_tile.dtype != kernel_dtype:
q_sbuf_tile = nl.copy(q_sbuf_tile, dtype=kernel_dtype)
q_tile[:, :] = q_sbuf_tile * softmax_scale
_flash_attention_core(
@ -474,15 +651,15 @@ def flash_paged_attention(
k=cur_k_tile,
v=cur_v_tile,
o_buffer=o_buffer[i, i_q_h],
l_buffer=l_buffer[:, i, i_q_h],
l_buffer=l_buffer[i, i_q_h],
m_buffer=m_buffer[i, i_q_h],
q_tile_idx=i,
kernel_dtype=kernel_dtype,
acc_type=acc_type,
flash_config=config,
use_causal_mask=False,
tile_mask=cur_mask,
initialize=j == 0,
use_causal_mask=False,
q_tile_idx=i,
initialize=large_k_tile_idx == 0,
LARGE_TILE_SZ=LARGE_TILE_SZ,
B_P_SIZE=B_P_SIZE,
B_F_SIZE=B_F_SIZE,
B_D_SIZE=B_D_SIZE,
@ -492,62 +669,58 @@ def flash_paged_attention(
if key is not None and value is not None:
B_F_SIZE = min(seqlen_q, B_F_SIZE)
LARGE_TILE_SZ = seqlen_q
active_config = FlashConfig(
seq_tile_size=LARGE_TILE_SZ,
should_transpose_v=config.should_transpose_v,
)
cur_k_tile = nl.ndarray((par_dim(B_D_SIZE), LARGE_TILE_SZ),
dtype=kernel_dtype)
cur_v_tile = nl.ndarray(
(LARGE_TILE_SZ // B_P_SIZE, par_dim(B_P_SIZE), B_D_SIZE),
(par_dim(B_P_SIZE), LARGE_TILE_SZ // B_P_SIZE * B_D_SIZE),
dtype=kernel_dtype,
)
cur_k_tile[:, :] = nl.load(key[batch_id, head_id, :, :])
loaded = nl.load(key[batch_id, head_id, :, :])
if loaded.dtype != kernel_dtype:
loaded = nl.copy(loaded, dtype=kernel_dtype)
cur_k_tile[:, :] = loaded
load_tile_size = B_P_SIZE
v_hbm_tile = value[batch_id, head_id]
for v_i in nl.affine_range(LARGE_TILE_SZ // load_tile_size):
for v_i in nl.affine_range(LARGE_TILE_SZ // B_P_SIZE):
load_v_tile(
v_hbm_tile=v_hbm_tile,
cur_v_tile=cur_v_tile,
j=0,
large_tile_idx=0,
v_i=v_i,
config=active_config,
LARGE_TILE_SZ=LARGE_TILE_SZ,
)
for i in nl.affine_range(n_tile_q):
cur_mask = nl.load(
mask[
nl.ds(i * B_P_SIZE, B_P_SIZE),
nl.ds(context_kv_len, LARGE_TILE_SZ),
],
dtype=mask.dtype,
)
cur_mask = nl.load(mask[
nl.ds(i * B_P_SIZE, B_P_SIZE),
nl.ds(context_kv_len, LARGE_TILE_SZ),
])
for i_q_h in nl.affine_range(q_h_per_k_h):
q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype)
q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h]
q_sbuf_tile = nl.load(
q_hbm_tile[:, nl.ds(i * B_P_SIZE, B_P_SIZE)],
dtype=kernel_dtype,
) # load (d, 128) tile in SBUF
q_sbuf_tile = nl.load(q_hbm_tile[:,
nl.ds(i *
B_P_SIZE, B_P_SIZE)])
if q_sbuf_tile.dtype != kernel_dtype:
q_sbuf_tile = nl.copy(q_sbuf_tile, dtype=kernel_dtype)
q_tile[:, :] = q_sbuf_tile * softmax_scale
_flash_attention_core(
q_local_tile=q_tile,
k=cur_k_tile,
v=cur_v_tile,
o_buffer=o_buffer[i, i_q_h],
l_buffer=l_buffer[:, i, i_q_h],
l_buffer=l_buffer[i, i_q_h],
m_buffer=m_buffer[i, i_q_h],
q_tile_idx=i,
kernel_dtype=kernel_dtype,
acc_type=acc_type,
flash_config=active_config,
use_causal_mask=True,
tile_mask=cur_mask,
use_causal_mask=True,
q_tile_idx=i,
initialize=False,
LARGE_TILE_SZ=LARGE_TILE_SZ,
B_P_SIZE=B_P_SIZE,
B_F_SIZE=B_F_SIZE,
B_D_SIZE=B_D_SIZE,
@ -559,8 +732,8 @@ def flash_paged_attention(
for i_q_h in nl.affine_range(q_h_per_k_h):
for i in nl.affine_range(n_tile_q):
out = nl.multiply(
o_buffer[i, i_q_h, :, :],
nl.exp(m_buffer[i, i_q_h, :, :] - l_buffer[:, i, i_q_h]),
o_buffer[i, i_q_h],
nl.exp(m_buffer[i, i_q_h] - l_buffer[i, i_q_h]),
dtype=kernel_dtype,
)
@ -589,7 +762,7 @@ def flash_paged_attention(
head_id * q_h_per_k_h + i_q_h,
nl.ds(i * B_P_SIZE, B_P_SIZE),
],
l_buffer[:, i, i_q_h],
l_buffer[i, i_q_h],
)
nl.store(
hbm_qk_res[batch_id, head_id * q_h_per_k_h + i_q_h, :, :],
@ -601,6 +774,49 @@ def flash_paged_attention(
return o
def reorder_context_mask(mask, LARGE_TILE_SZ, block_size):
"""
Reorder the mask to make it compatible with the flash attention kernel.
We vectorize the KV cache read to improve DMA utilization. However, the
layout that maximizes DMA bandwidth changes the order in which tokens are
consumed.
The token layout (inner 2 dimensions) after the vectorized load is
(B_P_SIZE, tiled_block_size) within a tile of `B_P_SIZE * tiled_block_size`
tokens, and at each step the engine consumes a column (rather than a row)
of B_P_SIZE tokens. Therefore, the tokens are visited in a strided order.
To make sure the mask matches the order in which tokens are consumed, we
need to transpose it accordingly.
"""
total_query_len, total_seq_len = mask.shape
context_kv_len = total_seq_len - total_query_len
B_P_SIZE = 128
assert (LARGE_TILE_SZ
>= B_P_SIZE), f"{LARGE_TILE_SZ=} must be at least {B_P_SIZE=}"
num_tiled_blocks = max(B_P_SIZE, LARGE_TILE_SZ // block_size)
tiled_block_size = LARGE_TILE_SZ // num_tiled_blocks
if tiled_block_size > 1:
# Mask reordering is needed when tiled_block_size > 1
device = mask.device
mask = mask.cpu()
context_mask = mask[:, :context_kv_len]
context_mask = context_mask.view(
total_query_len,
context_kv_len // LARGE_TILE_SZ,
num_tiled_blocks // B_P_SIZE,
B_P_SIZE,
tiled_block_size,
)
context_mask = context_mask.transpose(3, 4).reshape(
total_query_len, context_kv_len)
new_mask = mask[:, context_kv_len:]
return torch.concat([context_mask, new_mask], dim=1).to(device)
else:
return mask
def flash_attn_varlen_nkifunc(
query,
key,
@ -612,13 +828,32 @@ def flash_attn_varlen_nkifunc(
n_kv_head=None,
head_size=None,
LARGE_TILE_SZ=2048,
return_debug_tensors=False,
mixed_precision=True,
):
config = FlashConfig(
seq_tile_size=LARGE_TILE_SZ,
should_transpose_v=False,
)
"""
Compute flash paged attention for variable length sequences.
This function is a wrapper around the flash attention NKI kernel. It takes
in the following arguments:
- query: (1, n_heads, d, seq_q)
- key: (1, n_kv_heads, d, seq_k)
- value: (1, n_kv_heads, seq_v, d)
- key_cache: (n_blocks, n_kv_heads, block_size, d)
- value_cache: (n_blocks, n_kv_heads, block_size, d)
- block_tables: (n_active_blocks, )
- attn_mask: (seq_q, n_active_blocks * block_size + seq_q)
Notes:
- attn_mask must be reordered by the caller using `reorder_context_mask`
- The key/value cache layout must be (n_blocks, n_kv_heads, block_size, d)
for better DMA throughput
"""
if n_kv_head is None:
n_kv_head = key_cache.shape[1]
assert key_cache.shape[1] == n_kv_head
if head_size is None:
head_size = key_cache.shape[-1]
kwargs = dict(
query=query,
key=key,
@ -628,15 +863,9 @@ def flash_attn_varlen_nkifunc(
block_tables=block_table,
mask=attn_mask,
softmax_scale=1.0 / (head_size**0.5),
config=config,
mixed_precision=mixed_precision,
return_debug_tensors=return_debug_tensors,
LARGE_TILE_SZ=LARGE_TILE_SZ,
)
_, n_kv_head, _, _ = key.shape
if return_debug_tensors:
o, *debug_tensors = flash_paged_attention[1, n_kv_head](**kwargs)
return o, *debug_tensors
else:
o = flash_paged_attention[1, n_kv_head](**kwargs)
return o
o = flash_paged_attention[1, n_kv_head](**kwargs)
return o