[Neuron][Kernel] Vectorize KV cache load in FlashPagedAttention to maximize DMA bandwidth (#13245)
Signed-off-by: Lingfan Yu <lingfany@amazon.com>
This commit is contained in:
parent 71face8540
commit 33170081f1

tests/neuron/test_block_table.py (new file, 153 lines)
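Editor's note: the sketch below is not part of the commit; it is a plain-PyTorch restatement of the index transform that the new test checks. The kernel wants every large KV tile to issue exactly B_P_SIZE = 128 indirect-DMA descriptors, one per SBUF partition, so when a tile holds M < 128 blocks the block size is split by a tiling factor with M * factor = 128. The helper name `tile_block_table` and the demo values are illustrative only; the authoritative version is `ref_block_tables_transform` in the diff below.

import torch


def tile_block_table(block_tables: torch.Tensor, num_head: int, head_id: int,
                     tiling_factor: int) -> torch.Tensor:
    # block_tables: (num_tiles, num_blocks_per_tile) integer block ids.
    num_tiles, num_blocks_per_tile = block_tables.shape
    # Fuse the head dimension: the cache is indexed by (block, head) pairs.
    flat = block_tables * num_head + head_id
    # Split each block into `tiling_factor` sub-blocks so that
    # num_blocks_per_tile * tiling_factor == 128 (the SBUF partition count).
    offset = torch.arange(tiling_factor).view(1, 1, -1)
    flat = flat.unsqueeze(-1) * tiling_factor + offset
    # Transpose so each tile's 128 indices land on the partition dimension.
    return flat.view(num_tiles, -1).t()


if __name__ == "__main__":
    bt = torch.arange(4 * 16, dtype=torch.int32).view(4, 16)  # 16 blocks/tile
    print(tile_block_table(bt, num_head=2, head_id=1, tiling_factor=8).shape)
    # torch.Size([128, 4]): one row per SBUF partition, one column per tile.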
@@ -0,0 +1,153 @@
# SPDX-License-Identifier: Apache-2.0
import os

import neuronxcc.nki.language as nl
import pytest
import torch
import torch.nn.functional as F
from neuronxcc import nki

from vllm.attention.ops.nki_flash_attn import (
    load_block_tables, transform_block_tables_for_indirect_load)


def is_power_of_2(n):
    return n > 0 and (n & (n - 1) == 0)


def nki_load_and_transform_block_tables(
    block_tables,
    num_tiles,
    num_blocks_per_tile,
    num_head,
    head_id,
    block_size_tiling_factor,
):
    assert is_power_of_2(
        num_blocks_per_tile), f"{num_blocks_per_tile=} must be power of 2"
    block_tables_sbuf = load_block_tables(block_tables, num_tiles,
                                          num_blocks_per_tile)

    # we need to pass an Index as head_id
    head_id = nl.arange(1)[None, :] + head_id

    block_tables_transposed = transform_block_tables_for_indirect_load(
        block_tables_sbuf, block_size_tiling_factor, num_head, head_id)
    B_P_SIZE = 128
    assert block_tables_transposed.shape[1] == B_P_SIZE

    out = nl.ndarray(
        block_tables_transposed.shape,
        dtype=nl.int32,
        buffer=nl.shared_hbm,
    )
    for i in nl.affine_range(block_tables_transposed.shape[0]):
        nl.store(dst=out[i], value=block_tables_transposed[i])
    return out


def ref_block_tables_transform(
    block_tables,
    num_tiles,
    num_blocks_per_tile,
    num_head,
    head_id,
    block_size_tiling_factor,
):
    assert block_tables.numel() == num_tiles * num_blocks_per_tile
    block_tables = block_tables.view(num_tiles, num_blocks_per_tile)
    B_F_SIZE = 128
    num_tiles_padded = (num_tiles + B_F_SIZE - 1) // B_F_SIZE * B_F_SIZE
    block_tables = F.pad(
        block_tables,
        (0, 0, 0, num_tiles_padded - num_tiles),
        "constant",
        0,
    )

    block_tables = block_tables * num_head + head_id
    block_tables = block_tables.view(num_tiles_padded, num_blocks_per_tile, 1)
    offset = torch.arange(0, block_size_tiling_factor).view(1, 1, -1)
    block_tables = block_tables * block_size_tiling_factor + offset
    block_tables_transposed = block_tables.view(num_tiles_padded, -1).t()

    num_blocks_per_tile = block_tables_transposed.shape[0]
    assert num_blocks_per_tile % B_F_SIZE == 0
    return block_tables_transposed.view(num_blocks_per_tile // B_F_SIZE,
                                        B_F_SIZE, num_tiles_padded)


@pytest.mark.parametrize(
    "q_head_per_kv_head,head_id",
    [
        (1, 0),
        (3, 1),
    ],
)
@pytest.mark.parametrize(
    "num_tiles,num_blocks_per_tile",
    [
        (1, 1),
        (13, 16),
        (17, 128),
        (35, 512),
        (128, 128),
        (130, 64),
        (280, 256),
        (315, 1),
    ],
)
@torch.inference_mode()
def test_load_and_transform_block_tables(
    num_tiles,
    num_blocks_per_tile,
    q_head_per_kv_head,
    head_id,
) -> None:
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()

    compiler_flags = [
        "-O1",
        "--retry_failed_compilation",
    ]
    compiler_flags_str = " ".join(compiler_flags)
    os.environ["NEURON_CC_FLAGS"] = compiler_flags_str

    torch.manual_seed(10000)
    torch.set_printoptions(sci_mode=False)

    # On Neuron, we need B_P_SIZE = 128 blocks to make DMA efficient
    B_P_SIZE = 128
    if num_blocks_per_tile < B_P_SIZE:
        assert B_P_SIZE % num_blocks_per_tile == 0
        block_size_tiling_factor = B_P_SIZE // num_blocks_per_tile
    else:
        block_size_tiling_factor = 1
    max_num_blocks = 100000
    block_tables = torch.randint(
        0,
        max_num_blocks,
        (num_tiles * num_blocks_per_tile, ),
        dtype=torch.int32,
    )
    nki_out = nki.jit(nki_load_and_transform_block_tables)[1, 1](
        block_tables.to(device=device),
        num_tiles,
        num_blocks_per_tile,
        q_head_per_kv_head,
        head_id,
        block_size_tiling_factor,
    ).cpu()
    ref_out = ref_block_tables_transform(
        block_tables,
        num_tiles,
        num_blocks_per_tile,
        q_head_per_kv_head,
        head_id,
        block_size_tiling_factor,
    )
    assert (nki_out.shape == ref_out.shape
            ), f"{nki_out.shape=} != {ref_out.shape=}"
    assert torch.all(nki_out == ref_out)
@@ -107,7 +107,7 @@ def ref_masked_attention(
            masked_score, dim=-1, return_max_reduce=True)
    else:
        norm_score = ref_softmax(masked_score, dim=-1)
    out = torch.einsum("hqk,khd->qhd", norm_score, value)
    out = torch.einsum("hqk,khd->qhd", norm_score.to(value.dtype), value)
    if return_max_reduce:
        return (
            out,
@@ -118,7 +118,7 @@ def ref_masked_attention(
            scaled_qk,
        )
    else:
        return out
        return (out, )


def ref_context_attention(
@@ -128,8 +128,6 @@ def ref_context_attention(
    query_lens,
    seq_lens,
    head_size,
    num_kv_heads,
    num_heads,
    num_queries_per_kv,
    return_max_reduce=False,
):
@@ -146,18 +144,19 @@ def ref_context_attention(
    attn_mask = torch.logical_not(attn_mask)
    attn_mask = attn_mask.float() * -30000

    output, cached_max, cached_sum_reciprocal, lse, masked_score, scaled_qk = (
        ref_masked_attention(
            query,
            key,
            value,
            scale,
            attn_mask,
            return_max_reduce=return_max_reduce,
        ))
    output, *debug_tensors = ref_masked_attention(
        query,
        key,
        value,
        scale,
        attn_mask,
        return_max_reduce=return_max_reduce,
    )

    output = output.unsqueeze(1)
    if return_max_reduce:
        cached_max, cached_sum_reciprocal, lse, masked_score, scaled_qk = (
            debug_tensors)
        return (
            output,
            cached_max,
@@ -170,65 +169,22 @@ def ref_context_attention(
        return output


@pytest.mark.parametrize(
    "block_size, large_tile_size",
    [
        (32, 2048),  # 64 blocks
        (32, 4096),  # 128 blocks
        (32, 8192),  # 256 blocks
        (64, 8192),  # 128 blocks
    ],
)
@pytest.mark.parametrize(
    "num_heads,num_queries_per_kv,head_size,mixed_precision",
    [
        (4, 2, 8, False),
        (4, 2, 8, True),
        (32, 8, 64, True),
        (16, 2, 128, True),
    ],
)
@torch.inference_mode()
def test_contexted_kv_attention(
    num_heads: int,
    num_queries_per_kv: int,
    head_size: int,
    block_size: int,
    large_tile_size,
    mixed_precision: bool,
) -> None:
    import os

    import torch_xla.core.xla_model as xm

    from vllm.attention.ops.nki_flash_attn import flash_attn_varlen_nkifunc

    assert large_tile_size % block_size == 0

    device = xm.xla_device()

    compiler_flags = [
        "--model-type=transformer -O1",
        "--internal-hlo2tensorizer-options='--verify-hlo'",
        "--retry_failed_compilation",
    ]
    compiler_flags_str = " ".join(compiler_flags)
    os.environ["NEURON_CC_FLAGS"] = compiler_flags_str

    torch.manual_seed(0)
    torch.set_printoptions(sci_mode=False)

    min_ctx_len = 32
    max_ctx_len = 1024
    min_query_len = 16
    max_query_len = 512
    prefill_batch_size = 4
    decode_batch_size = 12
def sample_inputs(
    prefill_batch_size,
    decode_batch_size,
    min_query_len,
    max_query_len,
    min_ctx_len,
    max_ctx_len,
    block_size,
    num_heads,
    num_kv_heads,
    head_size,
    dtype,
):
    batch_size = prefill_batch_size + decode_batch_size
    max_model_len = (max_query_len + max_ctx_len) * 4

    max_block_per_request = max_model_len // block_size
    dtype = torch.float32
    cache_size = (batch_size * max_block_per_request) + 2
    prefill_ctx_lens = torch.randint(min_ctx_len,
                                     max_ctx_len + 1, (prefill_batch_size, ),
@@ -244,7 +200,6 @@ def test_contexted_kv_attention(
        dtype=torch.long,
    ).tolist() + [1 for _ in range(decode_batch_size)]
    seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
    num_kv_heads = num_heads // num_queries_per_kv

    num_tokens = sum(query_lens)
    query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
@@ -304,47 +259,139 @@ def test_contexted_kv_attention(
            cur_ctx += block_size
            block_id += 1

    return (
        query,
        k,
        v,
        k_cache,
        v_cache,
        block_table,
        key,
        value,
        query_lens,
        seq_lens,
    )


def get_active_block_tables(block_tables, query_lens, seq_lens, block_size,
                            num_blocks):
    context_lens = seq_lens - query_lens
    blocks_per_seq = (context_lens + block_size - 1) // block_size
    num_seqs = len(seq_lens)
    active_blocks: list[int] = []
    for seq_id in range(num_seqs):
        active_blocks = (
            active_blocks +
            block_tables[seq_id, :blocks_per_seq[seq_id]].tolist())
    return F.pad(
        torch.tensor(active_blocks, dtype=torch.int32),
        (0, num_blocks - len(active_blocks)),
        "constant",
        0,
    )


@pytest.mark.parametrize(
    "prefill_batch_size,decode_batch_size,block_size,large_tile_size",
    [
        (1, 199, 1, 512),  # 512 blocks
        (4, 12, 256, 2048),  # 128 blocks
        (4, 12, 16, 2048),  # 128 blocks
        (4, 12, 4, 1024),  # 256 blocks
        (4, 12, 32, 2048),  # 64 blocks
        (4, 12, 32, 4096),  # 128 blocks
        (4, 12, 32, 8192),  # 256 blocks
        (4, 12, 64, 8192),  # 128 blocks
    ],
)
@pytest.mark.parametrize(
    "num_heads,num_queries_per_kv,head_size",
    [
        (4, 2, 8),
        (32, 8, 64),
        (4, 4, 128),
        (8, 1, 32),
    ],
)
@pytest.mark.parametrize("mixed_precision", [True, False])
@torch.inference_mode()
def test_contexted_kv_attention(
    prefill_batch_size: int,
    decode_batch_size: int,
    num_heads: int,
    num_queries_per_kv: int,
    head_size: int,
    block_size: int,
    large_tile_size,
    mixed_precision: bool,
) -> None:
    import os

    import torch_xla.core.xla_model as xm

    from vllm.attention.ops.nki_flash_attn import (flash_attn_varlen_nkifunc,
                                                   reorder_context_mask)

    assert large_tile_size % block_size == 0

    device = xm.xla_device()

    compiler_flags = [
        "-O1",
        "--retry_failed_compilation",
    ]
    compiler_flags_str = " ".join(compiler_flags)
    os.environ["NEURON_CC_FLAGS"] = compiler_flags_str

    torch.manual_seed(0)
    torch.set_printoptions(sci_mode=False)
    dtype = torch.float32

    min_ctx_len = 32
    max_ctx_len = 1024
    min_query_len = 16
    max_query_len = 512
    num_kv_heads = num_heads // num_queries_per_kv
    (
        output_ref,
        cached_max,
        cached_sum_reciprocal,
        lse,
        masked_score,
        scaled_qk,
    ) = ref_context_attention(
        query,
        k_active,
        v_active,
        k_cache,
        v_cache,
        block_table,
        key,
        value,
        query_lens,
        seq_lens,
    ) = sample_inputs(
        prefill_batch_size=prefill_batch_size,
        decode_batch_size=decode_batch_size,
        min_query_len=min_query_len,
        max_query_len=max_query_len,
        min_ctx_len=min_ctx_len,
        max_ctx_len=max_ctx_len,
        block_size=block_size,
        num_heads=num_heads,
        num_kv_heads=num_kv_heads,
        head_size=head_size,
        dtype=dtype,
    )

    output_ref = ref_context_attention(
        query,
        key,
        value,
        query_lens,
        seq_lens,
        head_size,
        num_kv_heads,
        num_heads,
        num_queries_per_kv,
        return_max_reduce=True,
        return_max_reduce=False,
    )

    # build neuron program
    return_debug_tensors = False
    B_P_SIZE = 128
    LARGE_TILE_SZ = large_tile_size

    def get_active_block_tables(block_tables, query_lens, seq_lens, block_size,
                                num_blocks):
        context_lens = seq_lens - query_lens
        blocks_per_seq = (context_lens + block_size - 1) // block_size
        num_seqs = len(seq_lens)
        active_blocks: list[int] = []
        for seq_id in range(num_seqs):
            active_blocks = (
                active_blocks +
                block_tables[seq_id, :blocks_per_seq[seq_id]].tolist())
        return F.pad(
            torch.tensor(active_blocks),
            (0, num_blocks - len(active_blocks)),
            "constant",
            0,
        )
    assert (large_tile_size >= B_P_SIZE
            ), f"Expect {large_tile_size=} to be larger than {B_P_SIZE=}"

    def ceil_div(a, b):
        return (a + b - 1) // b
@@ -357,32 +404,27 @@ def test_contexted_kv_attention(
        return 2**int(a - 1).bit_length()

    # calculate input shapes
    max_num_queries = pad_to_multiple(sum(query_lens), block_size)
    max_num_queries = pad_to_next_power_of_2(max_num_queries)
    head_size_padded = B_P_SIZE
    assert head_size_padded >= head_size
    max_num_queries = pad_to_next_power_of_2(sum(query_lens))
    context_lens = torch.tensor(seq_lens) - torch.tensor(query_lens)
    num_active_blocks = ceil_div(context_lens, block_size).sum().item()
    num_active_blocks = pad_to_multiple(num_active_blocks,
                                        LARGE_TILE_SZ // block_size)
                                        large_tile_size // block_size)
    context_kv_len = num_active_blocks * block_size
    assert (context_kv_len %
            LARGE_TILE_SZ == 0), f"invalid context_kv_len={context_kv_len}"
            large_tile_size == 0), f"invalid context_kv_len={context_kv_len}"

    # pad QKV tensors
    pad_dims = (
        0,
        head_size_padded - query.shape[2],
        0,
        0,
        0,
        0,
        max_num_queries - query.shape[0],
    )
    query = F.pad(query, pad_dims, "constant", 0)
    k = F.pad(k, pad_dims, "constant", 0)
    v = F.pad(v, pad_dims, "constant", 0)
    k_cache = F.pad(k_cache, (0, head_size_padded - head_size), "constant", 0)
    v_cache = F.pad(v_cache, (0, head_size_padded - head_size), "constant", 0)
    k = F.pad(k_active, pad_dims, "constant", 0)
    v = F.pad(v_active, pad_dims, "constant", 0)

    # permute QKV tensors
    # query: (1, n_heads, d, seq_q)
@@ -391,6 +433,8 @@ def test_contexted_kv_attention(
    query = query.unsqueeze(0).permute(0, 2, 3, 1).contiguous()
    k = k.unsqueeze(0).permute(0, 2, 3, 1).contiguous()
    v = v.unsqueeze(0).permute(0, 2, 1, 3).contiguous()
    k_cache = k_cache.permute(0, 2, 1, 3).contiguous()
    v_cache = v_cache.permute(0, 2, 1, 3).contiguous()

    # transform block table
    active_block_table = get_active_block_tables(
@@ -405,33 +449,31 @@ def test_contexted_kv_attention(
    prior_mask, active_mask = (
        BlockDiagonalCausalFromBottomRightMask.from_seqlens(
            query_lens, seq_lens, block_size=block_size))
    attn_mask = torch.concat(
        [
            F.pad(
                prior_mask,
                (
                    0,
                    context_kv_len - prior_mask.shape[1],
                    0,
                    max_num_queries - prior_mask.shape[0],
                ),
                "constant",
                0,
            ).bool(),
            F.pad(
                active_mask,
                (
                    0,
                    max_num_queries - active_mask.shape[1],
                    0,
                    max_num_queries - active_mask.shape[0],
                ),
                "constant",
                0,
            ).bool(),
        ],
        dim=1,
    )
    prior_mask_padded = F.pad(
        prior_mask,
        (
            0,
            context_kv_len - prior_mask.shape[1],
            0,
            max_num_queries - prior_mask.shape[0],
        ),
        "constant",
        0,
    ).bool()
    active_mask_padded = F.pad(
        active_mask,
        (
            0,
            max_num_queries - active_mask.shape[1],
            0,
            max_num_queries - active_mask.shape[0],
        ),
        "constant",
        0,
    ).bool()
    attn_mask = torch.concat([prior_mask_padded, active_mask_padded], dim=1)

    attn_mask = reorder_context_mask(attn_mask, large_tile_size, block_size)

    input_args = (
        query.to(device=device),
@@ -439,29 +481,21 @@ def test_contexted_kv_attention(
        v.to(device=device),
        k_cache.to(device=device),
        v_cache.to(device=device),
        active_block_table.to(torch.int32).to(device=device),
        active_block_table.to(device=device),
        attn_mask.to(device=device),
    )
    input_kwargs = dict(
        n_kv_head=num_kv_heads,
        head_size=head_size,
        mixed_precision=mixed_precision,
        LARGE_TILE_SZ=LARGE_TILE_SZ,
        return_debug_tensors=return_debug_tensors,
        LARGE_TILE_SZ=large_tile_size,
    )

    if return_debug_tensors:
        output_nki, *debug_tensors = flash_attn_varlen_nkifunc(
            *input_args, **input_kwargs)
    else:
        output_nki = flash_attn_varlen_nkifunc(*input_args, **input_kwargs)
        debug_tensors = []

    debug_tensors = [torch.tensor(dt).cpu() for dt in debug_tensors]
    output_nki = flash_attn_varlen_nkifunc(*input_args, **input_kwargs)

    num_actual_tokens = sum(query_lens)
    # - o: shape (bs, n_heads, seq_q, d) -> (bs, seq_q, n_heads, d)
    output_nki = output_nki.cpu().permute(0, 2, 1, 3)[:, :, :, :head_size]
    output_nki = output_nki.cpu().permute(0, 2, 1, 3)
    output_nki = output_nki[0, :num_actual_tokens, :, :]
    output_ref_padded = F.pad(
        output_ref,

@@ -1,27 +1,203 @@
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass

import neuronxcc.nki.isa as nisa
import neuronxcc.nki.language as nl
import numpy as np
import torch
from neuronxcc import nki
from neuronxcc.nki.language import par_dim


@dataclass(frozen=True)
class FlashConfig:
    """
    Config class for flash attention with default values
    """
def ceil_div(a, b):
    return (a + b - 1) // b

    seq_tile_size: int = 2048
    should_transpose_v: bool = False

    __annotations__ = {
        "seq_tile_size": int,
        "should_transpose_v": bool,
    }
def is_power_of_2(x):
    return x > 0 and (x & (x - 1)) == 0


@nki.jit
def load_block_tables(block_tables_hbm, num_tiles, num_blocks_per_tile):
    """
    Load block tables from HBM into SRAM

    `block_tables_hbm` has shape `(num_tiles * num_blocks_per_tile, )`.
    In case `num_tiles > B_P_SIZE`, we need further tile `num_tile` dimension.
    """
    B_P_SIZE = 128

    # reshape as `(num_tiles, num_blocks_per_tile)`
    assert len(block_tables_hbm.shape) == 1
    (num_total_blocks, ) = block_tables_hbm.shape
    assert num_blocks_per_tile * num_tiles == num_total_blocks
    block_tables_hbm = block_tables_hbm.reshape(
        (num_tiles, num_blocks_per_tile))

    block_tables_sbuf = nl.zeros(
        (ceil_div(num_tiles,
                  B_P_SIZE), par_dim(B_P_SIZE), num_blocks_per_tile),
        dtype=nl.int32,
    )
    for i in nl.affine_range(ceil_div(num_tiles, B_P_SIZE)):
        i_p = nl.arange(B_P_SIZE)[:, None]
        i_f = nl.arange(num_blocks_per_tile)[None, :]
        block_tables_sbuf[i, i_p, i_f] = nl.load(
            block_tables_hbm[i_p + i * B_P_SIZE, i_f],
            dtype=nl.int32,
            mask=(i_p + i * B_P_SIZE < num_tiles),
        )
    return block_tables_sbuf


@nki.jit
def transform_block_tables_for_indirect_load(
    block_tables,
    block_size_tiling_factor,
    num_head,
    head_id,
):
    """
    This function does two things:
    1. calculate new `block_tables` for a `head_id` after flattening
       `num_block`, `num_head`, and `block_size_tiling_factor` dimensions
    2. transpose the result so that `block_table` for each tile is mapped to
       SBUF Partition dimension for vectorized DMA

    Tiling trick to further improve DMA performance:
    Given KV cache shape `(num_block, num_head, block_size, D)`, when loading M
    blocks of a given `head_id` from HBM, the load `cache[block_tables,
    head_id]` has shape `(M, block_size, D)`. If M < B_P_SIZE = 128, DMA may not
    fully utilize hardware parallelization. The solution is to tile `block_size`
    into `(block_size_tiling_factor, tiled_block_size)` s.t. `M *
    block_size_tiling_factor = B_P_SIZE`. After tiling, KV cache has shape
    `(num_block, num_head, block_size_tiling_factor, tiled_block_size, D)`.

    Note:
    We don't further tile D dimension as small DMA size also hurts performance.
    """
    B_P_SIZE = 128
    num_partitions, num_tiles_per_partition, num_blocks_per_tile = (
        block_tables.shape)
    assert num_tiles_per_partition == B_P_SIZE
    assert is_power_of_2(
        num_blocks_per_tile), f"{num_blocks_per_tile=} is not power of 2"

    num_loads = ceil_div(num_blocks_per_tile, B_P_SIZE)
    block_tables_transposed = nl.ndarray(
        (
            num_loads,
            par_dim(B_P_SIZE),
            num_partitions * num_tiles_per_partition,
        ),
        dtype=nl.int32,
    )

    # prepare iota ahead of time to avoid repeatedly using Gpsimd
    if num_head > 1:
        head_id = nisa.iota(head_id, dtype=nl.int32).reshape((1, 1))
        head_id = nl.transpose(
            head_id.broadcast_to((1, num_tiles_per_partition)))
        if num_blocks_per_tile > 1:
            head_id = head_id.broadcast_to(
                (num_tiles_per_partition, num_blocks_per_tile))

    if block_size_tiling_factor > 1:
        broadcast_shape = (
            num_tiles_per_partition,
            num_blocks_per_tile,
            block_size_tiling_factor,
        )
        offset = nisa.iota(nl.arange(block_size_tiling_factor)[None, None, :],
                           dtype=nl.int32).broadcast_to(broadcast_shape)

    for partition_id in nl.affine_range(num_partitions):
        block_tables_partition = block_tables[partition_id]
        if num_head > 1:
            # fuse num_block and num_head dimension
            block_tables_partition = block_tables_partition * num_head + head_id

        if block_size_tiling_factor > 1:
            # need to apply block size tiling trick
            assert num_blocks_per_tile * block_size_tiling_factor == B_P_SIZE
            block_tables_partition = ((block_tables_partition *
                                       block_size_tiling_factor).reshape(
                                           (num_tiles_per_partition,
                                            num_blocks_per_tile,
                                            1)).broadcast_to(broadcast_shape))
            new_block_tables = block_tables_partition + offset
            new_block_tables = new_block_tables.reshape(
                (num_tiles_per_partition, B_P_SIZE))
        else:
            new_block_tables = block_tables_partition

        # transpose the block table so that it can be used by vector DGE
        for i in nl.affine_range(num_loads):
            i_p = nl.arange(B_P_SIZE)[:, None]
            i_f = (partition_id * num_tiles_per_partition +
                   nl.arange(num_tiles_per_partition)[None, :])
            block_tables_transposed[i, i_p, i_f] = nl.transpose(
                new_block_tables[:, nl.ds(i * B_P_SIZE, B_P_SIZE)])
    return block_tables_transposed


@nki.jit
def load_kv_tile_from_cache(
    cur_k_tile,
    cur_v_tile,
    key_cache,
    value_cache,
    block_tables,
    large_k_tile_idx,
    num_blocks_per_large_tile,
    tiled_block_size,
    B_P_SIZE,
    B_D_SIZE,
):
    """
    Load KV cache and transform Key and Value into layout required by Matmul

    Vectorized DMA Load layout:
    Key and Value: (par_dim(B_P_SIZE), seqlen_kv // B_P_SIZE * B_D_SIZE)

    Layout used by attention matmuls:
    Key: (par_dim(B_D_SIZE), seqlen_kv)
    Value: (seqlen_kv // B_P_SIZE, par_dim(B_P_SIZE), B_D_SIZE)
        equivalent to (par_dim(B_P_SIZE), seqlen_kv // B_P_SIZE * B_D_SIZE)
    """
    # load key cache
    num_loads = ceil_div(num_blocks_per_large_tile, B_P_SIZE)
    for load_idx in nl.affine_range(num_loads):
        i_p = nl.arange(B_P_SIZE)[:, None]
        i_f = nl.arange(tiled_block_size * B_D_SIZE)[None, :]
        loaded = nl.load(key_cache[block_tables[load_idx, i_p,
                                                large_k_tile_idx], i_f])
        if cur_k_tile.dtype != loaded.dtype:
            loaded = nl.copy(loaded, dtype=cur_k_tile.dtype)
        # Transpose SBUF tensor using PE
        for tb_i in nl.affine_range(tiled_block_size):
            cur_k_tile[
                :,
                nl.ds(
                    load_idx * B_P_SIZE * tiled_block_size + tb_i * B_P_SIZE,
                    B_P_SIZE,
                ),
            ] = nl.transpose(loaded[:, nl.ds(tb_i * B_D_SIZE, B_D_SIZE)])

    # load value cache
    for load_idx in nl.affine_range(num_loads):
        loaded = nl.load(value_cache[block_tables[load_idx, i_p,
                                                  large_k_tile_idx], i_f])
        if cur_v_tile.dtype != loaded.dtype:
            loaded = nl.copy(loaded, dtype=cur_v_tile.dtype)
        i_p = nl.arange(B_P_SIZE)[:, None]
        i_f = nl.arange(tiled_block_size * B_D_SIZE)[None, :]
        cur_v_tile[
            :,
            nl.ds(
                load_idx * tiled_block_size * B_D_SIZE,
                tiled_block_size * B_D_SIZE,
            ),
        ] = loaded


@nki.jit
@@ -62,13 +238,13 @@ def _flash_attention_core(
    o_buffer,
    l_buffer,
    m_buffer,
    q_tile_idx,
    kernel_dtype,
    acc_type,
    flash_config: FlashConfig,
    use_causal_mask,
    tile_mask,
    use_causal_mask,
    q_tile_idx=None,
    initialize=False,
    LARGE_TILE_SZ=2048,
    B_P_SIZE=128,
    B_F_SIZE=512,
    B_D_SIZE=128,
@@ -77,19 +253,19 @@ def _flash_attention_core(
    """
    The flash attention core function to calculate self attention between a tile
    of q and a block of K and V.
    The q_local_tile has (B_P_SIZE, B_F_SIZE), which is loaded into the SBUF
    already. The block size of K and V
    is defined in the seq_tile_size of the flash_config. The results are stored
    in the following three buffers
    The q_local_tile has (B_P_SIZE, B_D_SIZE)
    The K and V have shape (B_D_SIZE, LARGE_TILE_SZ), whose free dimension will
    be split into size B_F_SIZE tiles

    The results are stored in the following three buffers
    o_buffer: (B_P_SIZE, d)
    l_buffer: (B_P_SIZE, 1)
    m_buffer: (B_P_SIZE, 1)

    All IO buffers are in SBUF.
    """
    LARGE_TILE_SZ = flash_config.seq_tile_size
    num_k_tile_per_large_tile = LARGE_TILE_SZ // B_F_SIZE

    # mask are used to only apply computation to the lower half of the matrix,
    # which reduce the arithmetic intensity by half
    qk_res_buf = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
                            buffer=nl.sbuf,
                            dtype=acc_type)
@@ -99,6 +275,8 @@ def _flash_attention_core(
        k_i_b_f_slice = nl.ds(k_i * B_F_SIZE, B_F_SIZE)

        if use_causal_mask:
            # mask are used to only apply computation to the lower half of the
            # matrix, which reduce the arithmetic intensity by up to 50%
            multiplication_required_selection = (q_tile_idx * B_P_SIZE
                                                 >= k_i * B_F_SIZE)
        else:
@@ -165,7 +343,9 @@ def _flash_attention_core(
        REDUCTION_TILE = min(2048, LARGE_TILE_SZ // 2)

        p_partial_sum = nl.ndarray(
            (par_dim(B_P_SIZE), LARGE_TILE_SZ // REDUCTION_TILE), dtype=acc_type)
            (par_dim(B_P_SIZE), LARGE_TILE_SZ // REDUCTION_TILE),
            dtype=acc_type,
        )

        for k_r_i in nl.affine_range(LARGE_TILE_SZ // REDUCTION_TILE):
            k_r_i_reduce_slice = nl.ds(k_r_i * REDUCTION_TILE, REDUCTION_TILE)
@@ -194,13 +374,15 @@ def _flash_attention_core(
        B_F_SIZE=B_F_SIZE,
    )

    pv_psum = nl.zeros((par_dim(B_P_SIZE), B_D_SIZE),
                       dtype=np.float32,
                       buffer=nl.psum)
    pv_psum = nl.zeros(
        (par_dim(B_P_SIZE), B_D_SIZE),
        dtype=np.float32,
        buffer=nl.psum,
    )
    for k_i in nl.affine_range(LARGE_TILE_SZ // B_P_SIZE):
        pv_psum[:, :] += nl.matmul(
            p_local_transposed[:, nl.ds(k_i * B_P_SIZE, B_P_SIZE)],
            v[k_i, :, :],
            v[:, nl.ds(k_i * B_D_SIZE, B_D_SIZE)],
            transpose_x=True,
        )  # (128, 128) (p(Br), d)

@@ -219,44 +401,16 @@ def _flash_attention_core(


@nki.jit
def load_v_tile(v_hbm_tile, cur_v_tile, j, v_i, config):
    LARGE_TILE_SZ = config.seq_tile_size
def load_v_tile(v_hbm_tile, cur_v_tile, large_tile_idx, v_i, LARGE_TILE_SZ):
    B_P_SIZE = 128

    if not config.should_transpose_v:
        cur_v_tile[v_i, :, :] = nl.load(
            v_hbm_tile[nl.ds(j * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE), :],
            dtype=cur_v_tile.dtype,
        )
        return

    if nisa.get_nc_version() == nisa.nc_version.gen3:
        cur_v_tile_transposed = nisa.dma_transpose(
            v_hbm_tile[:,
                       nl.ds(j * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE)])
        cur_v_tile[v_i, :, :] = nisa.tensor_copy(cur_v_tile_transposed,
                                                 dtype=cur_v_tile.dtype)
        return

    cur_v_tile[v_i, :, :] = nl.load_transpose2d(
        v_hbm_tile[:, nl.ds(j * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE)],
        dtype=cur_v_tile.dtype,
    )


@nki.jit
def load_block_tables(block_tables_hbm, num_tiles):
    (num_blocks, ) = block_tables_hbm.shape
    assert num_blocks % num_tiles == 0
    num_blocks_per_tile = num_blocks // num_tiles
    block_tables_hbm = block_tables_hbm.reshape(
        (num_tiles, num_blocks_per_tile))
    block_tables_buffer = nl.load(block_tables_hbm, dtype=nl.int32)
    return block_tables_buffer


def is_power_of_2(x):
    return x > 0 and (x & (x - 1)) == 0
    B_D_SIZE = v_hbm_tile.shape[-1]
    loaded = nl.load(v_hbm_tile[
        nl.ds(large_tile_idx * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE),
        :,
    ])
    if cur_v_tile.dtype != loaded.dtype:
        loaded = nl.copy(loaded, dtype=cur_v_tile.dtype)
    cur_v_tile[:, nl.ds(v_i * B_D_SIZE, B_D_SIZE)] = loaded


@nki.jit
@@ -270,24 +424,21 @@ def flash_paged_attention(
    mask,
    softmax_scale=None,
    mixed_precision=True,
    config=None,
    LARGE_TILE_SZ=2048,
    return_debug_tensors=False,
):
    """
    Flash PagedAttention Forward Kernel.
      - PagedAttention Paper: https://arxiv.org/abs/2309.06180
      - Chunked Prefill Paper: https://arxiv.org/abs/2403.02310

    IO tensor layouts:
      - query: shape (1, n_heads, d, seq_q)
      - key: shape (1, n_kv_heads, d, seq_k)
      - value: shape (1, n_kv_heads, seq_v, d)
      - key_cache: (num_blocks, block_size, n_kv_heads, d)
      - value_cache: (num_blocks, block_size, n_kv_heads, d)
      - key_cache: (num_blocks, n_kv_heads, block_size, d)
      - value_cache: (num_blocks, n_kv_heads, block_size, d)
      - block_tables: (num_active_blocks, )
      - mask: (seq_q, num_active_blocks * block_size)
      - mask: (seq_q, num_active_blocks * block_size + seq_q)
      - o: shape (1, n_heads, seq_q, d)
      - l_m: shape (1, n_heads, seq_q, 2)

      - This kernel requires seq_k == seq_v
      - We use continuous batching by default, so the batch dimension is
@@ -306,11 +457,8 @@ def flash_paged_attention(
      - softmax_scale: scaling for softmax, is None, default is `1.0/(d**0.5)`
      - mixed_precision: flag to set non-matmul ops in fp32 precision, default
        is set to `true`, if false, we use same precision as input types
      - config: Instance of dataclass :class:`nki.kernels.attention.FlashConfig`
        with Performance config parameters for flash attention with default
        values
        seq_tile_size: `default=2048`, size of the kv tile size for attention
        computation reduction
      - LARGE_TILE_SZ: `default=2048`, size of the kv tile size for attention
        computation reduction

    GQA support Notes:
      the spmd kernel for launching kernel should be on kv_heads instead of
@@ -322,31 +470,65 @@ def flash_paged_attention(
      GQA: q: [b, h, d, s], k: [b, kv_h, d, s], v: [b, kv_h, s, d]
      usage: `flash_fwd[b, kv_h](q, k, v, ...)`
    """
    config = config or FlashConfig()
    B_F_SIZE = 512
    B_P_SIZE = 128
    b, h, d, seqlen_q = query.shape
    B_D_SIZE = d
    LARGE_TILE_SZ = config.seq_tile_size
    n_tile_q = seqlen_q // B_P_SIZE  # since q will be loaded on tensor engine
    num_blocks, block_size, k_h, _ = key_cache.shape
    num_blocks, k_h, block_size, _ = key_cache.shape
    q_h_per_k_h = h // k_h
    assert tuple(key_cache.shape) == (
        num_blocks,
        block_size,
        k_h,
        d,
    ), "Input shape mismatch!"
    assert tuple(value_cache.shape) == (
        num_blocks,
        block_size,
        k_h,
        d,
    ), "Input shape mismatch!"
    assert b == 1, f"invalid batch size {b=}"
    assert d <= 128, f" we do not support head_dim > 128, got head dim {d}"
    assert d <= 128, f" we do not support head_dim > 128, got head dim {d=}"
    cache_shape = (num_blocks, k_h, block_size, d)
    assert (tuple(key_cache.shape) == cache_shape
            ), f"{key_cache.shape=} mismatch, expect {cache_shape}"
    assert (tuple(value_cache.shape) == cache_shape
            ), f"{value_cache.shape=} mismatch, expect {cache_shape}"
    assert key is None or tuple(key.shape) == (
        1,
        k_h,
        d,
        seqlen_q,
    ), f"key shape {key.shape} mismatch!"
    assert value is None or tuple(value.shape) == (
        1,
        k_h,
        seqlen_q,
        d,
    ), f"value shape {value.shape} mismatch!"

    assert (
        nl.program_ndim() == 2
    ), f"Expect spmd grid with 2 dimensions, got {nl.program_ndim()} instead!"
    batch_id = nl.program_id(axis=0)
    head_id = nl.program_id(axis=1)

    (num_active_blocks, ) = block_tables.shape
    context_kv_len = num_active_blocks * block_size
    assert (
        LARGE_TILE_SZ % B_F_SIZE == 0
    ), f"Need {LARGE_TILE_SZ=} to be divisible by {B_F_SIZE=} in transpose_p"
    assert (context_kv_len % LARGE_TILE_SZ == 0
            ), f"Need {context_kv_len=} to be divisible by {LARGE_TILE_SZ=}"

    num_blocks_per_large_tile = LARGE_TILE_SZ // block_size
    assert is_power_of_2(
        num_blocks_per_large_tile
    ), f"{num_blocks_per_large_tile=} is expected of be power of 2"
    if seqlen_q > B_F_SIZE:
        MAX_REDUCTION_TILE = 2048
        if seqlen_q // 2 > MAX_REDUCTION_TILE:
            assert (
                seqlen_q % MAX_REDUCTION_TILE == 0
            ), f"{seqlen_q=} should be divisible by {MAX_REDUCTION_TILE=}"
    else:
        assert (seqlen_q % B_F_SIZE == 0
                ), f"{seqlen_q=} should be divisible by {B_F_SIZE=})"

    kernel_dtype = nl.bfloat16 if mixed_precision else query.dtype
    acc_type = np.dtype(np.float32) if mixed_precision else kernel_dtype
    softmax_scale = softmax_scale or (1.0 / (d**0.5))
    num_large_k_tile = context_kv_len // LARGE_TILE_SZ

    o = nl.ndarray((b, h, seqlen_q, d),
                   dtype=query.dtype,
@@ -373,35 +555,38 @@ def flash_paged_attention(
        buffer=nl.sbuf,
        lazy_initialization=True,
    )
    block_tables_sbuf = load_block_tables(
        block_tables_hbm=block_tables,
        num_tiles=num_large_k_tile,
        num_blocks_per_tile=num_blocks_per_large_tile,
    )

    assert (
        nl.program_ndim() == 2
    ), f"Expect spmd grid with 2 dimensions, got {nl.program_ndim()} instead!"
    batch_id = nl.program_id(axis=0)
    head_id = nl.program_id(axis=1)
    # On Neuron, we need B_P_SIZE = 128 blocks to make DMA efficient
    if num_blocks_per_large_tile < B_P_SIZE:
        # we checked num_blocks_per_tile is a power of 2
        assert B_P_SIZE % num_blocks_per_large_tile == 0
        block_size_tiling_factor = B_P_SIZE // num_blocks_per_large_tile
        # We assume block_size >= block_size_tiling_factor
        assert block_size % block_size_tiling_factor == 0
    else:
        block_size_tiling_factor = 1
    tiled_block_size = block_size // block_size_tiling_factor

    softmax_scale = softmax_scale or (1.0 / (d**0.5))
    # Indirect DMA load must be placed along Partition Dimension
    block_tables_sbuf = transform_block_tables_for_indirect_load(
        block_tables_sbuf,
        block_size_tiling_factor=block_size_tiling_factor,
        num_head=k_h,
        head_id=head_id,
    )

    (num_active_blocks, ) = block_tables.shape
    context_kv_len = num_active_blocks * block_size
    assert (config.seq_tile_size >= 512
            ), f" seq tile_size {config.seq_tile_size} cannot be less than 512"
    assert (context_kv_len % LARGE_TILE_SZ == 0
            ), f"Need {context_kv_len=} to be divisible by {LARGE_TILE_SZ=}"
    assert (
        LARGE_TILE_SZ % B_P_SIZE == 0
    ), f"Need LARGE_TILE_SZ ({LARGE_TILE_SZ}) to be divisible by {B_P_SIZE=}"
    assert (B_P_SIZE % block_size == 0
            ), f"Need B_P_SIZE ({B_P_SIZE}) to be divisible by {block_size=}"
    num_large_k_tile = context_kv_len // LARGE_TILE_SZ
    num_blocks_per_large_tile = LARGE_TILE_SZ // block_size
    assert block_size % 32 == 0, "block_size is expected to be a multiple of 32"
    assert is_power_of_2(
        num_blocks_per_large_tile
    ), "The number of blocks in each large tile is expected of be power of 2"
    assert is_power_of_2(seqlen_q), "seqlen_q is expected to be power of 2"

    block_tables_sbuf = load_block_tables(block_tables, num_large_k_tile)
    # Flatten KV cache to be 2D for loading into SBUF
    new_cache_shape = (
        num_blocks * k_h * block_size_tiling_factor,
        tiled_block_size * d,
    )
    key_cache = key_cache.reshape(new_cache_shape)
    value_cache = value_cache.reshape(new_cache_shape)

    # Global Flash Attention accumulators
    o_buffer = nl.zeros(
@@ -411,7 +596,7 @@ def flash_paged_attention(
        lazy_initialization=True,
    )
    l_buffer = nl.zeros(
        (par_dim(B_P_SIZE), n_tile_q, q_h_per_k_h),
        (n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), 1),
        dtype=acc_type,
        buffer=nl.sbuf,
        lazy_initialization=True,
@@ -423,50 +608,42 @@ def flash_paged_attention(
        lazy_initialization=True,
    )

    for j in nl.sequential_range(0, num_large_k_tile):
        cur_k_tile = nl.ndarray((par_dim(B_D_SIZE), LARGE_TILE_SZ),
                                dtype=kernel_dtype)
        cur_v_tile = nl.ndarray(
            (LARGE_TILE_SZ // B_P_SIZE, par_dim(B_P_SIZE), B_D_SIZE),
    for large_k_tile_idx in nl.sequential_range(0, num_large_k_tile):
        num_loads = ceil_div(num_blocks_per_large_tile, B_P_SIZE)
        cur_k_tile = nl.ndarray(
            (par_dim(B_D_SIZE), LARGE_TILE_SZ),
            dtype=kernel_dtype,
        )

        for k_i in nl.affine_range(num_blocks_per_large_tile):
            loaded = nl.load(key_cache[block_tables_sbuf[j, k_i], :,
                                       head_id, :])
            cur_k_tile[:, nl.ds(k_i *
                                block_size, block_size)] = nl.transpose(loaded)

        load_tile_size = B_P_SIZE
        num_blocks_per_partition = load_tile_size // block_size
        for partition_idx in nl.affine_range(LARGE_TILE_SZ // load_tile_size):
            for block_in_partition in nl.affine_range(
                    num_blocks_per_partition):
                v_i = (partition_idx * num_blocks_per_partition +
                       block_in_partition)
                loaded_v = nl.load(value_cache[block_tables_sbuf[j, v_i], :,
                                               head_id, :])
                cur_v_tile[
                    partition_idx,
                    nl.ds(block_in_partition * block_size, block_size),
                    :,
                ] = loaded_v
        cur_v_tile = nl.ndarray(
            (par_dim(B_P_SIZE), num_loads * tiled_block_size * B_D_SIZE),
            dtype=kernel_dtype,
        )
        load_kv_tile_from_cache(
            cur_k_tile=cur_k_tile,
            cur_v_tile=cur_v_tile,
            key_cache=key_cache,
            value_cache=value_cache,
            block_tables=block_tables_sbuf,
            large_k_tile_idx=large_k_tile_idx,
            num_blocks_per_large_tile=num_blocks_per_large_tile,
            tiled_block_size=tiled_block_size,
            B_P_SIZE=B_P_SIZE,
            B_D_SIZE=B_D_SIZE,
        )

        for i in nl.affine_range(n_tile_q):
            cur_mask = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
                                  dtype=mask.dtype)
            for m_i in nl.affine_range(LARGE_TILE_SZ // B_F_SIZE):
                cur_mask[:, nl.ds(m_i * B_F_SIZE, B_F_SIZE)] = nl.load(mask[
                    nl.ds(i * B_P_SIZE, B_P_SIZE),
                    nl.ds(j * LARGE_TILE_SZ + m_i * B_F_SIZE, B_F_SIZE),
                ])
            cur_mask = nl.load(mask[
                nl.ds(i * B_P_SIZE, B_P_SIZE),
                nl.ds(large_k_tile_idx * LARGE_TILE_SZ, LARGE_TILE_SZ),
            ])
            for i_q_h in nl.affine_range(q_h_per_k_h):
                q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype)
                q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h]
                q_sbuf_tile = nl.load(
                    q_hbm_tile[:, nl.ds(i * B_P_SIZE, B_P_SIZE)],
                    dtype=kernel_dtype,
                )  # load (d, 128) tile in SBUF
                q_sbuf_tile = nl.load(q_hbm_tile[:,
                                                 nl.ds(i *
                                                       B_P_SIZE, B_P_SIZE)])
                if q_sbuf_tile.dtype != kernel_dtype:
                    q_sbuf_tile = nl.copy(q_sbuf_tile, dtype=kernel_dtype)
                q_tile[:, :] = q_sbuf_tile * softmax_scale

                _flash_attention_core(
@@ -474,15 +651,15 @@ def flash_paged_attention(
                    k=cur_k_tile,
                    v=cur_v_tile,
                    o_buffer=o_buffer[i, i_q_h],
                    l_buffer=l_buffer[:, i, i_q_h],
                    l_buffer=l_buffer[i, i_q_h],
                    m_buffer=m_buffer[i, i_q_h],
                    q_tile_idx=i,
                    kernel_dtype=kernel_dtype,
                    acc_type=acc_type,
                    flash_config=config,
                    use_causal_mask=False,
                    tile_mask=cur_mask,
                    initialize=j == 0,
                    use_causal_mask=False,
                    q_tile_idx=i,
                    initialize=large_k_tile_idx == 0,
                    LARGE_TILE_SZ=LARGE_TILE_SZ,
                    B_P_SIZE=B_P_SIZE,
                    B_F_SIZE=B_F_SIZE,
                    B_D_SIZE=B_D_SIZE,
@@ -492,62 +669,58 @@ def flash_paged_attention(
    if key is not None and value is not None:
        B_F_SIZE = min(seqlen_q, B_F_SIZE)
        LARGE_TILE_SZ = seqlen_q
        active_config = FlashConfig(
            seq_tile_size=LARGE_TILE_SZ,
            should_transpose_v=config.should_transpose_v,
        )

        cur_k_tile = nl.ndarray((par_dim(B_D_SIZE), LARGE_TILE_SZ),
                                dtype=kernel_dtype)
        cur_v_tile = nl.ndarray(
            (LARGE_TILE_SZ // B_P_SIZE, par_dim(B_P_SIZE), B_D_SIZE),
            (par_dim(B_P_SIZE), LARGE_TILE_SZ // B_P_SIZE * B_D_SIZE),
            dtype=kernel_dtype,
        )

        cur_k_tile[:, :] = nl.load(key[batch_id, head_id, :, :])
        loaded = nl.load(key[batch_id, head_id, :, :])
        if loaded.dtype != kernel_dtype:
            loaded = nl.copy(loaded, dtype=kernel_dtype)
        cur_k_tile[:, :] = loaded

        load_tile_size = B_P_SIZE
        v_hbm_tile = value[batch_id, head_id]
        for v_i in nl.affine_range(LARGE_TILE_SZ // load_tile_size):
        for v_i in nl.affine_range(LARGE_TILE_SZ // B_P_SIZE):
            load_v_tile(
                v_hbm_tile=v_hbm_tile,
                cur_v_tile=cur_v_tile,
                j=0,
                large_tile_idx=0,
                v_i=v_i,
                config=active_config,
                LARGE_TILE_SZ=LARGE_TILE_SZ,
            )

        for i in nl.affine_range(n_tile_q):
            cur_mask = nl.load(
                mask[
                    nl.ds(i * B_P_SIZE, B_P_SIZE),
                    nl.ds(context_kv_len, LARGE_TILE_SZ),
                ],
                dtype=mask.dtype,
            )
            cur_mask = nl.load(mask[
                nl.ds(i * B_P_SIZE, B_P_SIZE),
                nl.ds(context_kv_len, LARGE_TILE_SZ),
            ])
            for i_q_h in nl.affine_range(q_h_per_k_h):

                q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype)
                q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h]
                q_sbuf_tile = nl.load(
                    q_hbm_tile[:, nl.ds(i * B_P_SIZE, B_P_SIZE)],
                    dtype=kernel_dtype,
                )  # load (d, 128) tile in SBUF
                q_sbuf_tile = nl.load(q_hbm_tile[:,
                                                 nl.ds(i *
                                                       B_P_SIZE, B_P_SIZE)])
                if q_sbuf_tile.dtype != kernel_dtype:
                    q_sbuf_tile = nl.copy(q_sbuf_tile, dtype=kernel_dtype)
                q_tile[:, :] = q_sbuf_tile * softmax_scale
                _flash_attention_core(
                    q_local_tile=q_tile,
                    k=cur_k_tile,
                    v=cur_v_tile,
                    o_buffer=o_buffer[i, i_q_h],
                    l_buffer=l_buffer[:, i, i_q_h],
                    l_buffer=l_buffer[i, i_q_h],
                    m_buffer=m_buffer[i, i_q_h],
                    q_tile_idx=i,
                    kernel_dtype=kernel_dtype,
                    acc_type=acc_type,
                    flash_config=active_config,
                    use_causal_mask=True,
                    tile_mask=cur_mask,
                    use_causal_mask=True,
                    q_tile_idx=i,
                    initialize=False,
                    LARGE_TILE_SZ=LARGE_TILE_SZ,
                    B_P_SIZE=B_P_SIZE,
                    B_F_SIZE=B_F_SIZE,
                    B_D_SIZE=B_D_SIZE,
@@ -559,8 +732,8 @@ def flash_paged_attention(
    for i_q_h in nl.affine_range(q_h_per_k_h):
        for i in nl.affine_range(n_tile_q):
            out = nl.multiply(
                o_buffer[i, i_q_h, :, :],
                nl.exp(m_buffer[i, i_q_h, :, :] - l_buffer[:, i, i_q_h]),
                o_buffer[i, i_q_h],
                nl.exp(m_buffer[i, i_q_h] - l_buffer[i, i_q_h]),
                dtype=kernel_dtype,
            )

@@ -589,7 +762,7 @@ def flash_paged_attention(
                    head_id * q_h_per_k_h + i_q_h,
                    nl.ds(i * B_P_SIZE, B_P_SIZE),
                ],
                l_buffer[:, i, i_q_h],
                l_buffer[i, i_q_h],
            )
            nl.store(
                hbm_qk_res[batch_id, head_id * q_h_per_k_h + i_q_h, :, :],
@@ -601,6 +774,49 @@ def flash_paged_attention(
    return o


def reorder_context_mask(mask, LARGE_TILE_SZ, block_size):
    """
    Reorder the mask to make it compatible with the flash attention kernel.

    We vectorize KV cache read to improve DMA utilization. However, the layout
    that maximizes DMA bandwidth changes the order tokens are consumed.

    The token layout (inner 2 dimensions) after vectorized load is (B_P_SIZE,
    tiled_block_size) in a tile of `B_P_SIZE * tiled_block_size` tokens. And
    each step the engine consumes a column (rather than a row) of B_P_SIZE
    tokens. Therefore, the tokens are visited in a strided way.

    To make sure mask matches the order tokens are consumed, we need to properly
    transpose mask.
    """
    total_query_len, total_seq_len = mask.shape
    context_kv_len = total_seq_len - total_query_len

    B_P_SIZE = 128
    assert (LARGE_TILE_SZ
            >= B_P_SIZE), f"{LARGE_TILE_SZ=} must be larger than {B_P_SIZE=}"
    num_tiled_blocks = max(B_P_SIZE, LARGE_TILE_SZ // block_size)
    tiled_block_size = LARGE_TILE_SZ // num_tiled_blocks
    if tiled_block_size > 1:
        # Mask reordering is needed when tiled_block_size > 1
        device = mask.device
        mask = mask.cpu()
        context_mask = mask[:, :context_kv_len]
        context_mask = context_mask.view(
            total_query_len,
            context_kv_len // LARGE_TILE_SZ,
            num_tiled_blocks // B_P_SIZE,
            B_P_SIZE,
            tiled_block_size,
        )
        context_mask = context_mask.transpose(3, 4).reshape(
            total_query_len, context_kv_len)
        new_mask = mask[:, context_kv_len:]
        return torch.concat([context_mask, new_mask], dim=1).to(device)
    else:
        return mask


def flash_attn_varlen_nkifunc(
    query,
    key,
@@ -612,13 +828,32 @@ def flash_attn_varlen_nkifunc(
    n_kv_head=None,
    head_size=None,
    LARGE_TILE_SZ=2048,
    return_debug_tensors=False,
    mixed_precision=True,
):
    config = FlashConfig(
        seq_tile_size=LARGE_TILE_SZ,
        should_transpose_v=False,
    )
    """
    Compute flash paged attention for variable length sequences.

    This function is a wrapper around the flash attention NKI kernel. It takes
    in the following arguments:
    - query: (1, n_heads, d, seq_q)
    - key: (1, n_kv_heads, d, seq_k)
    - value: (1, n_kv_heads, seq_v, d)
    - key_cache: (n_blocks, n_kv_heads, block_size, d)
    - value_cache: (n_blocks, n_kv_heads, block_size, d)
    - block_tables: (n_active_blocks, )
    - attn_mask: (seq_q, n_active_blocks * block_size + seq_q)

    Notes:
    - attn_mask must be reordered outside using `reorder_context_mask`
    - Key/value cache layout must be (n_blocks, n_kv_heads, block_size, d)
      for better DMA throughput
    """
    if n_kv_head is None:
        n_kv_head = key_cache.shape[1]
    assert key_cache.shape[1] == n_kv_head
    if head_size is None:
        head_size = key_cache.shape[-1]

    kwargs = dict(
        query=query,
        key=key,
@@ -628,15 +863,9 @@ def flash_attn_varlen_nkifunc(
        block_tables=block_table,
        mask=attn_mask,
        softmax_scale=1.0 / (head_size**0.5),
        config=config,
        mixed_precision=mixed_precision,
        return_debug_tensors=return_debug_tensors,
        LARGE_TILE_SZ=LARGE_TILE_SZ,
    )
    _, n_kv_head, _, _ = key.shape

    if return_debug_tensors:
        o, *debug_tensors = flash_paged_attention[1, n_kv_head](**kwargs)
        return o, *debug_tensors
    else:
        o = flash_paged_attention[1, n_kv_head](**kwargs)
        return o
    o = flash_paged_attention[1, n_kv_head](**kwargs)
    return o

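Editor's note: a hedged usage sketch condensing the call sequence exercised by the updated test above; it is not a new API. All tensors are assumed to be prepared exactly as in the test (query/key/value permuted to the kernel layout, caches shaped (n_blocks, n_kv_heads, block_size, d)); the wrapper function `run_paged_attention` is a hypothetical name used only for illustration.

from vllm.attention.ops.nki_flash_attn import (flash_attn_varlen_nkifunc,
                                               reorder_context_mask)


def run_paged_attention(query, key, value, key_cache, value_cache,
                        active_block_table, attn_mask, num_kv_heads, head_size,
                        large_tile_size, block_size, mixed_precision=True):
    # Reorder the context part of the mask so it matches the strided token
    # order produced by the vectorized KV cache load (see reorder_context_mask).
    attn_mask = reorder_context_mask(attn_mask, large_tile_size, block_size)
    # Positional args follow the test: q, k, v, k_cache, v_cache, block table,
    # mask; LARGE_TILE_SZ selects the large KV tile size used by the kernel.
    return flash_attn_varlen_nkifunc(
        query, key, value, key_cache, value_cache,
        active_block_table, attn_mask,
        n_kv_head=num_kv_heads,
        head_size=head_size,
        mixed_precision=mixed_precision,
        LARGE_TILE_SZ=large_tile_size,
    )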