[Neuron][Kernel] Support Longer Sequences in NKI-based Flash PagedAttention and Improve Efficiency (#12921)

Signed-off-by: Lingfan Yu <lingfany@amazon.com>
This commit is contained in:
Lingfan Yu 2025-02-11 21:12:37 -08:00 committed by GitHub
parent 842b0fd402
commit e92694b6fe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 152 additions and 178 deletions

View File

@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
import random
from typing import Optional
import pytest
@ -171,12 +170,22 @@ def ref_context_attention(
return output
@pytest.mark.parametrize(
"block_size, large_tile_size",
[
(32, 2048), # 64 blocks
(32, 4096), # 128 blocks
(32, 8192), # 256 blocks
(64, 8192), # 128 blocks
],
)
@pytest.mark.parametrize(
"num_heads,num_queries_per_kv,head_size,mixed_precision",
[
(4, 2, 8, False),
(4, 2, 8, True),
(32, 8, 64, True),
(16, 2, 128, True),
],
)
@torch.inference_mode()
@ -184,6 +193,8 @@ def test_contexted_kv_attention(
num_heads: int,
num_queries_per_kv: int,
head_size: int,
block_size: int,
large_tile_size,
mixed_precision: bool,
) -> None:
import os
@ -192,40 +203,46 @@ def test_contexted_kv_attention(
from vllm.attention.ops.nki_flash_attn import flash_attn_varlen_nkifunc
assert large_tile_size % block_size == 0
device = xm.xla_device()
os.environ["NEURON_CC_FLAGS"] = (
" --model-type=transformer -O1 "
" --internal-hlo2tensorizer-options='--verify-hlo' ")
compiler_flags = [
"--model-type=transformer -O1",
"--internal-hlo2tensorizer-options='--verify-hlo'",
"--retry_failed_compilation",
]
compiler_flags_str = " ".join(compiler_flags)
os.environ["NEURON_CC_FLAGS"] = compiler_flags_str
random.seed(0)
torch.manual_seed(0)
torch.set_printoptions(sci_mode=False)
min_ctx_len = 2
max_ctx_len = 64
min_query_len = 2
max_query_len = 64
prefill_batch_size = 2
decode_batch_size = 6
min_ctx_len = 32
max_ctx_len = 1024
min_query_len = 16
max_query_len = 512
prefill_batch_size = 4
decode_batch_size = 12
batch_size = prefill_batch_size + decode_batch_size
block_size = 32
max_model_len = (max_query_len + max_ctx_len) * 4
max_block_per_request = max_model_len // block_size
dtype = torch.float32
cache_size = (batch_size * max_block_per_request) + 2
ctx_lens = [
random.randint(min_ctx_len, max_ctx_len)
for _ in range(prefill_batch_size)
] + [
random.randint(min_ctx_len, max_ctx_len)
for _ in range(decode_batch_size)
]
query_lens = [
random.randint(min_query_len, max_query_len)
for _ in range(prefill_batch_size)
] + [1 for _ in range(decode_batch_size)]
prefill_ctx_lens = torch.randint(min_ctx_len,
max_ctx_len + 1, (prefill_batch_size, ),
dtype=torch.long).tolist()
decode_ctx_lens = torch.randint(min_ctx_len,
max_ctx_len + 1, (decode_batch_size, ),
dtype=torch.long).tolist()
ctx_lens = prefill_ctx_lens + decode_ctx_lens
query_lens = torch.randint(
min_query_len,
max_query_len + 1,
(prefill_batch_size, ),
dtype=torch.long,
).tolist() + [1 for _ in range(decode_batch_size)]
seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
num_kv_heads = num_heads // num_queries_per_kv
@ -254,7 +271,6 @@ def test_contexted_kv_attention(
values = values[torch.randperm(cache_size)]
block_table = values[:batch_size * max_block_per_request].view(
batch_size, max_block_per_request)
torch.tensor(seq_lens, dtype=torch.long)
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1],
dtype=torch.long),
@ -311,9 +327,7 @@ def test_contexted_kv_attention(
# build neuron program
return_debug_tensors = False
B_P_SIZE = 128
LARGE_TILE_SZ = 2048
max_num_queries = (
(sum(query_lens) + block_size - 1) // block_size) * block_size
LARGE_TILE_SZ = large_tile_size
def get_active_block_tables(block_tables, query_lens, seq_lens, block_size,
num_blocks):
@ -332,26 +346,28 @@ def test_contexted_kv_attention(
0,
)
def shift_bit_length(x):
return 1 << (x - 1).bit_length()
def ceil_div(a, b):
return (a + b - 1) // b
def pad_to_multiple(a, b):
return ceil_div(a, b) * b
def pad_to_next_power_of_2(a):
assert a > 0
return 2**int(a - 1).bit_length()
# calculate input shapes
max_num_queries_shifted = shift_bit_length(max_num_queries)
max_num_queries_factor = B_P_SIZE // max_num_queries_shifted
max_num_queries_padded = max_num_queries_shifted * max_num_queries_factor
assert (max_num_queries_padded == B_P_SIZE
), "invalid {max_num_queries_padded=}"
max_num_queries = pad_to_multiple(sum(query_lens), block_size)
max_num_queries = pad_to_next_power_of_2(max_num_queries)
head_size_padded = B_P_SIZE
assert head_size_padded >= head_size
context_lens = torch.tensor(seq_lens) - torch.tensor(query_lens)
num_active_blocks_shifted = shift_bit_length(
((context_lens + block_size - 1) // block_size).sum().item())
num_active_blocks_factor = (LARGE_TILE_SZ // block_size //
num_active_blocks_shifted)
num_active_blocks = num_active_blocks_shifted * num_active_blocks_factor
assert (num_active_blocks *
block_size) == LARGE_TILE_SZ, "invalid {num_active_blocks=}"
num_active_blocks = ceil_div(context_lens, block_size).sum().item()
num_active_blocks = pad_to_multiple(num_active_blocks,
LARGE_TILE_SZ // block_size)
context_kv_len = num_active_blocks * block_size
assert context_kv_len == LARGE_TILE_SZ, f"invalid {context_kv_len=}"
assert (context_kv_len %
LARGE_TILE_SZ == 0), f"invalid context_kv_len={context_kv_len}"
# pad QKV tensors
pad_dims = (
@ -360,7 +376,7 @@ def test_contexted_kv_attention(
0,
0,
0,
max_num_queries_padded - query.shape[0],
max_num_queries - query.shape[0],
)
query = F.pad(query, pad_dims, "constant", 0)
k = F.pad(k, pad_dims, "constant", 0)
@ -397,7 +413,7 @@ def test_contexted_kv_attention(
0,
context_kv_len - prior_mask.shape[1],
0,
B_P_SIZE - prior_mask.shape[0],
max_num_queries - prior_mask.shape[0],
),
"constant",
0,
@ -406,9 +422,9 @@ def test_contexted_kv_attention(
active_mask,
(
0,
B_P_SIZE - active_mask.shape[1],
max_num_queries - active_mask.shape[1],
0,
B_P_SIZE - active_mask.shape[0],
max_num_queries - active_mask.shape[0],
),
"constant",
0,
@ -430,6 +446,8 @@ def test_contexted_kv_attention(
n_kv_head=num_kv_heads,
head_size=head_size,
mixed_precision=mixed_precision,
LARGE_TILE_SZ=LARGE_TILE_SZ,
return_debug_tensors=return_debug_tensors,
)
if return_debug_tensors:
@ -439,17 +457,15 @@ def test_contexted_kv_attention(
output_nki = flash_attn_varlen_nkifunc(*input_args, **input_kwargs)
debug_tensors = []
output_nki = torch.tensor(output_nki).cpu()
debug_tensors = [torch.tensor(dt).cpu() for dt in debug_tensors]
num_actual_tokens = sum(query_lens)
print(f"{num_actual_tokens=}")
# - o: shape (bs, n_heads, seq_q, d) -> (bs, seq_q, n_heads, d)
output_nki = output_nki.permute(
0, 2, 1, 3)[:, :, :, :head_size].cpu()[0, :num_actual_tokens, :, :]
output_nki = output_nki.cpu().permute(0, 2, 1, 3)[:, :, :, :head_size]
output_nki = output_nki[0, :num_actual_tokens, :, :]
output_ref_padded = F.pad(
output_ref,
(0, 0, 0, 0, 0, 0, 0, max_num_queries_padded - output_ref.shape[0]),
(0, 0, 0, 0, 0, 0, 0, max_num_queries - output_ref.shape[0]),
"constant",
0,
)

View File

@ -28,7 +28,6 @@ class FlashConfig:
def transpose_p_local(p_local_transposed,
p_local,
LARGE_TILE_SZ,
forward_mask,
B_F_SIZE=512):
for i in nl.affine_range(LARGE_TILE_SZ // B_F_SIZE):
if nisa.get_nc_version() == nisa.nc_version.gen3:
@ -46,13 +45,13 @@ def transpose_p_local(p_local_transposed,
if nisa.get_nc_version() == nisa.nc_version.gen3:
p_local_t_tmp[:, j_128_slice] = nisa.dma_transpose(
p_local[:, i_j_128_slice], mask=forward_mask)
p_local[:, i_j_128_slice])
else:
p_local_t_tmp[:, j_128_slice] = nisa.nc_transpose(
p_local[:, i_j_128_slice], mask=forward_mask)
p_local[:, i_j_128_slice])
p_local_transposed[:, nl.ds(i * B_F_SIZE, B_F_SIZE)] = nl.copy(
p_local_t_tmp, dtype=p_local_transposed.dtype, mask=forward_mask)
p_local_t_tmp, dtype=p_local_transposed.dtype)
@nki.jit
@ -60,30 +59,19 @@ def _flash_attention_core(
q_local_tile,
k,
v,
q_h_per_k_h,
seqlen_q,
nheads,
o_buffer,
l_buffer,
m_buffer,
batch_id,
head_id,
gqa_head_idx,
q_tile_idx,
local_k_large_tile_idx,
kernel_dtype,
acc_type,
flash_config: FlashConfig,
use_causal_mask=False,
continuous_batching_mask=None,
use_causal_mask,
tile_mask,
initialize=False,
B_P_SIZE=128,
B_F_SIZE=512,
B_D_SIZE=128,
dropout_p=0.0,
dropout_p_tensor=None,
seed_tensor=None,
logit_bias_tile=None,
qk_res_buffer=None,
):
"""
@ -99,24 +87,9 @@ def _flash_attention_core(
"""
LARGE_TILE_SZ = flash_config.seq_tile_size
num_k_tile_per_large_tile = LARGE_TILE_SZ // B_F_SIZE
seqlen_k = k.shape[-1]
seqlen_q // B_P_SIZE
seqlen_k // B_F_SIZE
# TODO : support logit_bias with continuous_batching_mask
assert not use_causal_mask, "causal mask is not supported."
assert (continuous_batching_mask
is not None), "continuous_batching_mask input is required."
if continuous_batching_mask is not None:
assert (
logit_bias_tile
is None), "continuous_batching_mask does not support logit_bias!"
# mask are used to only apply computation to the lower half of the matrix,
# which reduce the arithmetic intensity by half
forward_mask = (q_tile_idx * B_P_SIZE >= local_k_large_tile_idx *
LARGE_TILE_SZ if use_causal_mask else None)
qk_res_buf = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
buffer=nl.sbuf,
dtype=acc_type)
@ -125,20 +98,27 @@ def _flash_attention_core(
for k_i in nl.affine_range(num_k_tile_per_large_tile):
k_i_b_f_slice = nl.ds(k_i * B_F_SIZE, B_F_SIZE)
qk_psum = nl.zeros((par_dim(B_P_SIZE), B_F_SIZE),
if use_causal_mask:
multiplication_required_selection = (q_tile_idx * B_P_SIZE
>= k_i * B_F_SIZE)
else:
multiplication_required_selection = True
if multiplication_required_selection:
qk_psum = nl.ndarray((par_dim(B_P_SIZE), B_F_SIZE),
dtype=np.float32,
buffer=nl.psum) # (128, 512)
qk_psum[:, :] = nl.matmul(q_local_tile,
k[:, k_i_b_f_slice],
transpose_x=True,
mask=None) # (p(128), 512)
transpose_x=True) # (p(128), 512)
qk_res_buf[:, k_i_b_f_slice] = nl.where(
continuous_batching_mask[:, k_i_b_f_slice],
tile_mask[:, k_i_b_f_slice],
qk_psum[:, nl.ds(0, B_F_SIZE)],
-9984.0,
dtype=acc_type,
)
else:
qk_res_buf[:, k_i_b_f_slice] = -9984.0
# Calculate max of the current tile
max_local[:, k_i] = nisa.tensor_reduce(
@ -147,7 +127,6 @@ def _flash_attention_core(
axis=(1, ),
dtype=acc_type,
negate=False,
mask=forward_mask,
)
if qk_res_buffer is not None:
@ -159,7 +138,6 @@ def _flash_attention_core(
axis=(1, ),
dtype=acc_type,
negate=False,
mask=forward_mask,
)
o_previous_scaled = nl.ndarray((par_dim(B_P_SIZE), B_D_SIZE),
@ -170,8 +148,7 @@ def _flash_attention_core(
m_current = max_
else:
m_previous = nl.copy(m_buffer[:, 0])
m_buffer[:, 0] = nl.maximum(m_previous, max_,
mask=forward_mask) # (128,1)
m_buffer[:, 0] = nl.maximum(m_previous, max_) # (128,1)
m_current = m_buffer[:, 0]
# Compute scaling factor
@ -180,11 +157,8 @@ def _flash_attention_core(
m_previous,
bias=-1 * m_current,
scale=1.0,
mask=forward_mask,
)
o_previous_scaled[...] = nl.multiply(o_buffer[:, :],
alpha,
mask=forward_mask)
o_previous_scaled[...] = nl.multiply(o_buffer[:, :], alpha)
p_local = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
dtype=kernel_dtype)
@ -207,10 +181,9 @@ def _flash_attention_core(
reduce_op=nl.add,
reduce_res=p_partial_sum[:, k_r_i],
dtype=kernel_dtype,
mask=forward_mask,
)
ps = nl.sum(p_partial_sum, axis=1, dtype=acc_type, mask=forward_mask)
ps = nl.sum(p_partial_sum, axis=1, dtype=acc_type)
p_local_transposed = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
dtype=kernel_dtype)
@ -218,7 +191,6 @@ def _flash_attention_core(
p_local_transposed=p_local_transposed,
p_local=p_local,
LARGE_TILE_SZ=LARGE_TILE_SZ,
forward_mask=forward_mask,
B_F_SIZE=B_F_SIZE,
)
@ -230,27 +202,20 @@ def _flash_attention_core(
p_local_transposed[:, nl.ds(k_i * B_P_SIZE, B_P_SIZE)],
v[k_i, :, :],
transpose_x=True,
mask=forward_mask,
) # (128, 128) (p(Br), d)
if initialize:
o_buffer[:, :] = nl.copy(pv_psum[:, :])
l_buffer[:, 0] = nl.add(nl.log(ps), max_)
else:
o_buffer[:, :] = nl.add(o_previous_scaled, pv_psum, mask=forward_mask)
o_buffer[:, :] = nl.add(o_previous_scaled, pv_psum)
l_prev = l_buffer[:, 0]
l_exp = nl.add(
nl.exp(
nl.subtract(l_prev, m_current, mask=forward_mask),
mask=forward_mask,
),
nl.exp(nl.subtract(l_prev, m_current)),
ps,
mask=forward_mask,
)
l_buffer[:, 0] = nl.add(m_current,
nl.log(l_exp, mask=forward_mask),
mask=forward_mask)
l_buffer[:, 0] = nl.add(m_current, nl.log(l_exp))
@nki.jit
@ -279,6 +244,21 @@ def load_v_tile(v_hbm_tile, cur_v_tile, j, v_i, config):
)
@nki.jit
def load_block_tables(block_tables_hbm, num_tiles):
(num_blocks, ) = block_tables_hbm.shape
assert num_blocks % num_tiles == 0
num_blocks_per_tile = num_blocks // num_tiles
block_tables_hbm = block_tables_hbm.reshape(
(num_tiles, num_blocks_per_tile))
block_tables_buffer = nl.load(block_tables_hbm, dtype=nl.int32)
return block_tables_buffer
def is_power_of_2(x):
return x > 0 and (x & (x - 1)) == 0
@nki.jit
def flash_paged_attention(
query,
@ -415,18 +395,13 @@ def flash_paged_attention(
), f"Need B_P_SIZE ({B_P_SIZE}) to be divisible by {block_size=}"
num_large_k_tile = context_kv_len // LARGE_TILE_SZ
num_blocks_per_large_tile = LARGE_TILE_SZ // block_size
assert (num_blocks_per_large_tile <= B_P_SIZE
), f"The number of blocks in each large tile " \
f"({num_blocks_per_large_tile}) shouldn't exceed partition size {B_P_SIZE}"
assert block_size % 32 == 0, "block_size is expected to be a multiple of 32"
assert is_power_of_2(
num_blocks_per_large_tile
), "The number of blocks in each large tile is expected of be power of 2"
assert is_power_of_2(seqlen_q), "seqlen_q is expected to be power of 2"
block_tables_sbuf = nl.full((par_dim(B_P_SIZE), num_large_k_tile),
0,
dtype=np.int32,
buffer=nl.sbuf)
for j in nl.affine_range(num_large_k_tile):
i_p = nl.arange(num_blocks_per_large_tile)[:, None]
block_tables_sbuf[i_p, j] = nl.load(
block_tables[j * num_blocks_per_large_tile + i_p], dtype=np.int32)
block_tables_sbuf = load_block_tables(block_tables, num_large_k_tile)
# Global Flash Attention accumulators
o_buffer = nl.zeros(
@ -457,7 +432,7 @@ def flash_paged_attention(
)
for k_i in nl.affine_range(num_blocks_per_large_tile):
loaded = nl.load(key_cache[block_tables_sbuf[k_i, j], :,
loaded = nl.load(key_cache[block_tables_sbuf[j, k_i], :,
head_id, :])
cur_k_tile[:, nl.ds(k_i *
block_size, block_size)] = nl.transpose(loaded)
@ -469,7 +444,7 @@ def flash_paged_attention(
num_blocks_per_partition):
v_i = (partition_idx * num_blocks_per_partition +
block_in_partition)
loaded_v = nl.load(value_cache[block_tables_sbuf[v_i, j], :,
loaded_v = nl.load(value_cache[block_tables_sbuf[j, v_i], :,
head_id, :])
cur_v_tile[
partition_idx,
@ -477,14 +452,15 @@ def flash_paged_attention(
:,
] = loaded_v
for i in nl.affine_range(n_tile_q):
cur_mask = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
dtype=mask.dtype)
for m_i in nl.affine_range(LARGE_TILE_SZ // B_F_SIZE):
cur_mask[:, nl.ds(m_i * B_F_SIZE, B_F_SIZE)] = nl.load(
mask[:, nl.ds(j * LARGE_TILE_SZ + m_i * B_F_SIZE, B_F_SIZE)])
cur_mask[:, nl.ds(m_i * B_F_SIZE, B_F_SIZE)] = nl.load(mask[
nl.ds(i * B_P_SIZE, B_P_SIZE),
nl.ds(j * LARGE_TILE_SZ + m_i * B_F_SIZE, B_F_SIZE),
])
for i_q_h in nl.affine_range(q_h_per_k_h):
for i in nl.affine_range(n_tile_q):
q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype)
q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h]
q_sbuf_tile = nl.load(
@ -497,35 +473,24 @@ def flash_paged_attention(
q_local_tile=q_tile,
k=cur_k_tile,
v=cur_v_tile,
q_h_per_k_h=q_h_per_k_h,
seqlen_q=seqlen_q,
nheads=h,
o_buffer=o_buffer[i, i_q_h],
l_buffer=l_buffer[:, i, i_q_h],
m_buffer=m_buffer[i, i_q_h],
batch_id=batch_id,
head_id=head_id,
gqa_head_idx=i_q_h,
q_tile_idx=i,
local_k_large_tile_idx=j,
kernel_dtype=kernel_dtype,
acc_type=acc_type,
flash_config=config,
use_causal_mask=False,
continuous_batching_mask=cur_mask,
tile_mask=cur_mask,
initialize=j == 0,
B_P_SIZE=B_P_SIZE,
B_F_SIZE=B_F_SIZE,
B_D_SIZE=B_D_SIZE,
dropout_p=0.0,
dropout_p_tensor=None,
seed_tensor=None,
logit_bias_tile=None,
)
# compute attention between input query, key and value
if key is not None and value is not None:
B_F_SIZE = seqlen_q
B_F_SIZE = min(seqlen_q, B_F_SIZE)
LARGE_TILE_SZ = seqlen_q
active_config = FlashConfig(
seq_tile_size=LARGE_TILE_SZ,
@ -552,11 +517,16 @@ def flash_paged_attention(
config=active_config,
)
cur_mask = nl.ndarray((par_dim(B_P_SIZE), B_F_SIZE), dtype=mask.dtype)
cur_mask[:, :] = nl.load(mask[:, nl.ds(context_kv_len, B_F_SIZE)])
for i_q_h in nl.affine_range(q_h_per_k_h):
for i in nl.affine_range(n_tile_q):
cur_mask = nl.load(
mask[
nl.ds(i * B_P_SIZE, B_P_SIZE),
nl.ds(context_kv_len, LARGE_TILE_SZ),
],
dtype=mask.dtype,
)
for i_q_h in nl.affine_range(q_h_per_k_h):
q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype)
q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h]
q_sbuf_tile = nl.load(
@ -568,32 +538,21 @@ def flash_paged_attention(
q_local_tile=q_tile,
k=cur_k_tile,
v=cur_v_tile,
q_h_per_k_h=q_h_per_k_h,
seqlen_q=seqlen_q,
nheads=h,
o_buffer=o_buffer[i, i_q_h],
l_buffer=l_buffer[:, i, i_q_h],
m_buffer=m_buffer[i, i_q_h],
batch_id=batch_id,
head_id=head_id,
gqa_head_idx=i_q_h,
q_tile_idx=i,
local_k_large_tile_idx=0,
kernel_dtype=kernel_dtype,
acc_type=acc_type,
flash_config=active_config,
use_causal_mask=False,
continuous_batching_mask=cur_mask,
use_causal_mask=True,
tile_mask=cur_mask,
initialize=False,
B_P_SIZE=B_P_SIZE,
B_F_SIZE=B_F_SIZE,
B_D_SIZE=B_D_SIZE,
dropout_p=0.0,
dropout_p_tensor=None,
seed_tensor=None,
logit_bias_tile=None,
qk_res_buffer=qk_res_buffer[i, i_q_h]
if qk_res_buffer is not None else None,
qk_res_buffer=(qk_res_buffer[i, i_q_h]
if qk_res_buffer is not None else None),
)
# -- -- -- -- write output to buffer on HBM -- -- -- -- -- -- #
@ -652,7 +611,6 @@ def flash_attn_varlen_nkifunc(
attn_mask,
n_kv_head=None,
head_size=None,
B_P_SIZE=128,
LARGE_TILE_SZ=2048,
return_debug_tensors=False,
mixed_precision=True,