[Neuron][Kernel] Support Longer Sequences in NKI-based Flash PagedAttention and Improve Efficiency (#12921)

Signed-off-by: Lingfan Yu <lingfany@amazon.com>
Lingfan Yu 2025-02-11 21:12:37 -08:00 committed by GitHub
parent 842b0fd402
commit e92694b6fe
2 changed files with 152 additions and 178 deletions


@@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
import random
from typing import Optional
import pytest
@@ -171,12 +170,22 @@ def ref_context_attention(
return output
@pytest.mark.parametrize(
"block_size, large_tile_size",
[
(32, 2048), # 64 blocks
(32, 4096), # 128 blocks
(32, 8192), # 256 blocks
(64, 8192), # 128 blocks
],
)
@pytest.mark.parametrize(
"num_heads,num_queries_per_kv,head_size,mixed_precision",
[
(4, 2, 8, False),
(4, 2, 8, True),
(32, 8, 64, True),
(16, 2, 128, True),
],
)
@torch.inference_mode()
@@ -184,6 +193,8 @@ def test_contexted_kv_attention(
num_heads: int,
num_queries_per_kv: int,
head_size: int,
block_size: int,
large_tile_size,
mixed_precision: bool,
) -> None:
import os
@@ -192,40 +203,46 @@ def test_contexted_kv_attention(
from vllm.attention.ops.nki_flash_attn import flash_attn_varlen_nkifunc
assert large_tile_size % block_size == 0
device = xm.xla_device()
os.environ["NEURON_CC_FLAGS"] = (
" --model-type=transformer -O1 "
" --internal-hlo2tensorizer-options='--verify-hlo' ")
compiler_flags = [
"--model-type=transformer -O1",
"--internal-hlo2tensorizer-options='--verify-hlo'",
"--retry_failed_compilation",
]
compiler_flags_str = " ".join(compiler_flags)
os.environ["NEURON_CC_FLAGS"] = compiler_flags_str
random.seed(0)
torch.manual_seed(0)
torch.set_printoptions(sci_mode=False)
min_ctx_len = 2
max_ctx_len = 64
min_query_len = 2
max_query_len = 64
prefill_batch_size = 2
decode_batch_size = 6
min_ctx_len = 32
max_ctx_len = 1024
min_query_len = 16
max_query_len = 512
prefill_batch_size = 4
decode_batch_size = 12
batch_size = prefill_batch_size + decode_batch_size
block_size = 32
max_model_len = (max_query_len + max_ctx_len) * 4
max_block_per_request = max_model_len // block_size
dtype = torch.float32
cache_size = (batch_size * max_block_per_request) + 2
ctx_lens = [
random.randint(min_ctx_len, max_ctx_len)
for _ in range(prefill_batch_size)
] + [
random.randint(min_ctx_len, max_ctx_len)
for _ in range(decode_batch_size)
]
query_lens = [
random.randint(min_query_len, max_query_len)
for _ in range(prefill_batch_size)
] + [1 for _ in range(decode_batch_size)]
prefill_ctx_lens = torch.randint(min_ctx_len,
max_ctx_len + 1, (prefill_batch_size, ),
dtype=torch.long).tolist()
decode_ctx_lens = torch.randint(min_ctx_len,
max_ctx_len + 1, (decode_batch_size, ),
dtype=torch.long).tolist()
ctx_lens = prefill_ctx_lens + decode_ctx_lens
query_lens = torch.randint(
min_query_len,
max_query_len + 1,
(prefill_batch_size, ),
dtype=torch.long,
).tolist() + [1 for _ in range(decode_batch_size)]
seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
num_kv_heads = num_heads // num_queries_per_kv
@@ -254,7 +271,6 @@ def test_contexted_kv_attention(
values = values[torch.randperm(cache_size)]
block_table = values[:batch_size * max_block_per_request].view(
batch_size, max_block_per_request)
torch.tensor(seq_lens, dtype=torch.long)
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1],
dtype=torch.long),
@@ -311,9 +327,7 @@ def test_contexted_kv_attention(
# build neuron program
return_debug_tensors = False
B_P_SIZE = 128
LARGE_TILE_SZ = 2048
max_num_queries = (
(sum(query_lens) + block_size - 1) // block_size) * block_size
LARGE_TILE_SZ = large_tile_size
def get_active_block_tables(block_tables, query_lens, seq_lens, block_size,
num_blocks):
@@ -332,26 +346,28 @@ def test_contexted_kv_attention(
0,
)
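
Note: the body of `get_active_block_tables` is mostly outside this hunk. As a reading aid, here is a hypothetical sketch of what it presumably computes — gather each sequence's context blocks from its block-table row and zero-pad the flat list to `num_blocks`. The `_sketch` name marks this as an illustration, not the actual implementation.

```python
import torch
import torch.nn.functional as F

def get_active_block_tables_sketch(block_tables, query_lens, seq_lens,
                                   block_size, num_blocks):
    # Presumed behavior: keep, per sequence, only the block-table entries
    # covering its context, concatenate them, and zero-pad to num_blocks.
    context_lens = torch.tensor(seq_lens) - torch.tensor(query_lens)
    blocks_per_seq = (context_lens + block_size - 1) // block_size
    active = []
    for seq_id in range(len(seq_lens)):
        active += block_tables[seq_id, :int(blocks_per_seq[seq_id])].tolist()
    return F.pad(
        torch.tensor(active, dtype=torch.long),
        (0, num_blocks - len(active)),
        "constant",
        0,
    )
```
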
def shift_bit_length(x):
return 1 << (x - 1).bit_length()
def ceil_div(a, b):
return (a + b - 1) // b
def pad_to_multiple(a, b):
return ceil_div(a, b) * b
def pad_to_next_power_of_2(a):
assert a > 0
return 2**int(a - 1).bit_length()
# calculate input shapes
max_num_queries_shifted = shift_bit_length(max_num_queries)
max_num_queries_factor = B_P_SIZE // max_num_queries_shifted
max_num_queries_padded = max_num_queries_shifted * max_num_queries_factor
assert (max_num_queries_padded == B_P_SIZE
), "invalid {max_num_queries_padded=}"
max_num_queries = pad_to_multiple(sum(query_lens), block_size)
max_num_queries = pad_to_next_power_of_2(max_num_queries)
head_size_padded = B_P_SIZE
assert head_size_padded >= head_size
context_lens = torch.tensor(seq_lens) - torch.tensor(query_lens)
num_active_blocks_shifted = shift_bit_length(
((context_lens + block_size - 1) // block_size).sum().item())
num_active_blocks_factor = (LARGE_TILE_SZ // block_size //
num_active_blocks_shifted)
num_active_blocks = num_active_blocks_shifted * num_active_blocks_factor
assert (num_active_blocks *
block_size) == LARGE_TILE_SZ, "invalid {num_active_blocks=}"
num_active_blocks = ceil_div(context_lens, block_size).sum().item()
num_active_blocks = pad_to_multiple(num_active_blocks,
LARGE_TILE_SZ // block_size)
context_kv_len = num_active_blocks * block_size
assert context_kv_len == LARGE_TILE_SZ, f"invalid {context_kv_len=}"
assert (context_kv_len %
LARGE_TILE_SZ == 0), f"invalid context_kv_len={context_kv_len}"
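
As a sanity check on the new padding logic above, here is a small standalone sketch of the arithmetic; the sample lengths below are made up purely for illustration.

```python
def ceil_div(a, b):
    return (a + b - 1) // b

def pad_to_multiple(a, b):
    return ceil_div(a, b) * b

def pad_to_next_power_of_2(a):
    assert a > 0
    return 2**int(a - 1).bit_length()

# Hypothetical sizes, chosen only to show the rounding steps.
block_size, large_tile_size = 32, 8192
sum_query_lens = 1500        # stand-in for sum(query_lens)
num_context_blocks = 180     # stand-in for ceil_div(context_lens, block_size).sum()

# Queries: round up to a block boundary, then to the next power of 2.
max_num_queries = pad_to_multiple(sum_query_lens, block_size)      # 1504
max_num_queries = pad_to_next_power_of_2(max_num_queries)          # 2048

# Context blocks: round up to a whole number of large tiles.
blocks_per_large_tile = large_tile_size // block_size              # 256
num_active_blocks = pad_to_multiple(num_context_blocks,
                                    blocks_per_large_tile)         # 256
context_kv_len = num_active_blocks * block_size                    # 8192
assert context_kv_len % large_tile_size == 0
```
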
# pad QKV tensors
pad_dims = (
@@ -360,7 +376,7 @@ def test_contexted_kv_attention(
0,
0,
0,
max_num_queries_padded - query.shape[0],
max_num_queries - query.shape[0],
)
query = F.pad(query, pad_dims, "constant", 0)
k = F.pad(k, pad_dims, "constant", 0)
@@ -397,7 +413,7 @@ def test_contexted_kv_attention(
0,
context_kv_len - prior_mask.shape[1],
0,
B_P_SIZE - prior_mask.shape[0],
max_num_queries - prior_mask.shape[0],
),
"constant",
0,
@@ -406,9 +422,9 @@ def test_contexted_kv_attention(
active_mask,
(
0,
B_P_SIZE - active_mask.shape[1],
max_num_queries - active_mask.shape[1],
0,
B_P_SIZE - active_mask.shape[0],
max_num_queries - active_mask.shape[0],
),
"constant",
0,
@@ -430,6 +446,8 @@ def test_contexted_kv_attention(
n_kv_head=num_kv_heads,
head_size=head_size,
mixed_precision=mixed_precision,
LARGE_TILE_SZ=LARGE_TILE_SZ,
return_debug_tensors=return_debug_tensors,
)
if return_debug_tensors:
@@ -439,17 +457,15 @@ def test_contexted_kv_attention(
output_nki = flash_attn_varlen_nkifunc(*input_args, **input_kwargs)
debug_tensors = []
output_nki = torch.tensor(output_nki).cpu()
debug_tensors = [torch.tensor(dt).cpu() for dt in debug_tensors]
num_actual_tokens = sum(query_lens)
print(f"{num_actual_tokens=}")
# - o: shape (bs, n_heads, seq_q, d) -> (bs, seq_q, n_heads, d)
output_nki = output_nki.permute(
0, 2, 1, 3)[:, :, :, :head_size].cpu()[0, :num_actual_tokens, :, :]
output_nki = output_nki.cpu().permute(0, 2, 1, 3)[:, :, :, :head_size]
output_nki = output_nki[0, :num_actual_tokens, :, :]
output_ref_padded = F.pad(
output_ref,
(0, 0, 0, 0, 0, 0, 0, max_num_queries_padded - output_ref.shape[0]),
(0, 0, 0, 0, 0, 0, 0, max_num_queries - output_ref.shape[0]),
"constant",
0,
)


@@ -28,7 +28,6 @@ class FlashConfig:
def transpose_p_local(p_local_transposed,
p_local,
LARGE_TILE_SZ,
forward_mask,
B_F_SIZE=512):
for i in nl.affine_range(LARGE_TILE_SZ // B_F_SIZE):
if nisa.get_nc_version() == nisa.nc_version.gen3:
@@ -46,13 +45,13 @@ def transpose_p_local(p_local_transposed,
if nisa.get_nc_version() == nisa.nc_version.gen3:
p_local_t_tmp[:, j_128_slice] = nisa.dma_transpose(
p_local[:, i_j_128_slice], mask=forward_mask)
p_local[:, i_j_128_slice])
else:
p_local_t_tmp[:, j_128_slice] = nisa.nc_transpose(
p_local[:, i_j_128_slice], mask=forward_mask)
p_local[:, i_j_128_slice])
p_local_transposed[:, nl.ds(i * B_F_SIZE, B_F_SIZE)] = nl.copy(
p_local_t_tmp, dtype=p_local_transposed.dtype, mask=forward_mask)
p_local_t_tmp, dtype=p_local_transposed.dtype)
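
For reference, a plain-NumPy sketch of what `transpose_p_local` computes now that the `forward_mask` argument is gone: each `B_F_SIZE`-wide stripe of `p_local` is transposed in 128x128 blocks. The dma/nc transpose dispatch and SBUF/PSUM placement are omitted, and the sizes are only illustrative.

```python
import numpy as np

B_P_SIZE, B_F_SIZE, LARGE_TILE_SZ = 128, 512, 2048

p_local = np.random.rand(B_P_SIZE, LARGE_TILE_SZ).astype(np.float32)
p_local_transposed = np.empty_like(p_local)

for i in range(LARGE_TILE_SZ // B_F_SIZE):
    for j in range(B_F_SIZE // B_P_SIZE):
        # transpose each 128x128 block in place within its stripe
        start = i * B_F_SIZE + j * B_P_SIZE
        block = p_local[:, start:start + B_P_SIZE]
        p_local_transposed[:, start:start + B_P_SIZE] = block.T
```
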
@nki.jit
@@ -60,36 +59,25 @@ def _flash_attention_core(
q_local_tile,
k,
v,
q_h_per_k_h,
seqlen_q,
nheads,
o_buffer,
l_buffer,
m_buffer,
batch_id,
head_id,
gqa_head_idx,
q_tile_idx,
local_k_large_tile_idx,
kernel_dtype,
acc_type,
flash_config: FlashConfig,
use_causal_mask=False,
continuous_batching_mask=None,
use_causal_mask,
tile_mask,
initialize=False,
B_P_SIZE=128,
B_F_SIZE=512,
B_D_SIZE=128,
dropout_p=0.0,
dropout_p_tensor=None,
seed_tensor=None,
logit_bias_tile=None,
qk_res_buffer=None,
):
"""
The flash attention core function to calculate self attention between a tile
of q and a block of K and V.
The q_local_tile has (B_P_SIZE, B_F_SIZE), which is loaded into the SBUF
already. The block size of K and V
is defined in the seq_tile_size of the flash_config. The results are stored
in the following three buffers
@@ -99,24 +87,9 @@ def _flash_attention_core(
"""
LARGE_TILE_SZ = flash_config.seq_tile_size
num_k_tile_per_large_tile = LARGE_TILE_SZ // B_F_SIZE
seqlen_k = k.shape[-1]
seqlen_q // B_P_SIZE
seqlen_k // B_F_SIZE
# TODO : support logit_bias with continuous_batching_mask
assert not use_causal_mask, "causal mask is not supported."
assert (continuous_batching_mask
is not None), "continuous_batching_mask input is required."
if continuous_batching_mask is not None:
assert (
logit_bias_tile
is None), "continuous_batching_mask does not support logit_bias!"
# mask are used to only apply computation to the lower half of the matrix,
# which reduce the arithmetic intensity by half
forward_mask = (q_tile_idx * B_P_SIZE >= local_k_large_tile_idx *
LARGE_TILE_SZ if use_causal_mask else None)
qk_res_buf = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
buffer=nl.sbuf,
dtype=acc_type)
@@ -125,20 +98,27 @@ def _flash_attention_core(
for k_i in nl.affine_range(num_k_tile_per_large_tile):
k_i_b_f_slice = nl.ds(k_i * B_F_SIZE, B_F_SIZE)
qk_psum = nl.zeros((par_dim(B_P_SIZE), B_F_SIZE),
dtype=np.float32,
buffer=nl.psum) # (128, 512)
qk_psum[:, :] = nl.matmul(q_local_tile,
k[:, k_i_b_f_slice],
transpose_x=True,
mask=None) # (p(128), 512)
if use_causal_mask:
multiplication_required_selection = (q_tile_idx * B_P_SIZE
>= k_i * B_F_SIZE)
else:
multiplication_required_selection = True
qk_res_buf[:, k_i_b_f_slice] = nl.where(
continuous_batching_mask[:, k_i_b_f_slice],
qk_psum[:, nl.ds(0, B_F_SIZE)],
-9984.0,
dtype=acc_type,
)
if multiplication_required_selection:
qk_psum = nl.ndarray((par_dim(B_P_SIZE), B_F_SIZE),
dtype=np.float32,
buffer=nl.psum) # (128, 512)
qk_psum[:, :] = nl.matmul(q_local_tile,
k[:, k_i_b_f_slice],
transpose_x=True) # (p(128), 512)
qk_res_buf[:, k_i_b_f_slice] = nl.where(
tile_mask[:, k_i_b_f_slice],
qk_psum[:, nl.ds(0, B_F_SIZE)],
-9984.0,
dtype=acc_type,
)
else:
qk_res_buf[:, k_i_b_f_slice] = -9984.0
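
In plain NumPy terms, the per-sub-tile logic above does roughly the following: under the causal path, a sub-tile that lies entirely above the diagonal for this q tile skips the matmul and is filled with the kernel's large negative constant; otherwise QK^T is computed and the elementwise `tile_mask` is applied. Shapes are hardcoded for illustration only.

```python
import numpy as np

B_P_SIZE, B_F_SIZE, B_D_SIZE = 128, 512, 128
LARGE_TILE_SZ = 2048

q = np.random.rand(B_D_SIZE, B_P_SIZE).astype(np.float32)        # (d, 128)
k = np.random.rand(B_D_SIZE, LARGE_TILE_SZ).astype(np.float32)   # (d, L)
tile_mask = np.random.rand(B_P_SIZE, LARGE_TILE_SZ) > 0.5
use_causal_mask, q_tile_idx = True, 1

qk_res = np.empty((B_P_SIZE, LARGE_TILE_SZ), dtype=np.float32)
for k_i in range(LARGE_TILE_SZ // B_F_SIZE):
    cols = slice(k_i * B_F_SIZE, (k_i + 1) * B_F_SIZE)
    if (not use_causal_mask) or q_tile_idx * B_P_SIZE >= k_i * B_F_SIZE:
        qk = q.T @ k[:, cols]                                     # (128, 512)
        qk_res[:, cols] = np.where(tile_mask[:, cols], qk, -9984.0)
    else:
        # whole sub-tile is above the causal diagonal: skip the matmul
        qk_res[:, cols] = -9984.0
```
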
# Calculate max of the current tile
max_local[:, k_i] = nisa.tensor_reduce(
@@ -147,7 +127,6 @@ def _flash_attention_core(
axis=(1, ),
dtype=acc_type,
negate=False,
mask=forward_mask,
)
if qk_res_buffer is not None:
@@ -159,7 +138,6 @@ def _flash_attention_core(
axis=(1, ),
dtype=acc_type,
negate=False,
mask=forward_mask,
)
o_previous_scaled = nl.ndarray((par_dim(B_P_SIZE), B_D_SIZE),
@@ -170,8 +148,7 @@ def _flash_attention_core(
m_current = max_
else:
m_previous = nl.copy(m_buffer[:, 0])
m_buffer[:, 0] = nl.maximum(m_previous, max_,
mask=forward_mask) # (128,1)
m_buffer[:, 0] = nl.maximum(m_previous, max_) # (128,1)
m_current = m_buffer[:, 0]
# Compute scaling factor
@@ -180,11 +157,8 @@ def _flash_attention_core(
m_previous,
bias=-1 * m_current,
scale=1.0,
mask=forward_mask,
)
o_previous_scaled[...] = nl.multiply(o_buffer[:, :],
alpha,
mask=forward_mask)
o_previous_scaled[...] = nl.multiply(o_buffer[:, :], alpha)
p_local = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
dtype=kernel_dtype)
@@ -207,10 +181,9 @@ def _flash_attention_core(
reduce_op=nl.add,
reduce_res=p_partial_sum[:, k_r_i],
dtype=kernel_dtype,
mask=forward_mask,
)
ps = nl.sum(p_partial_sum, axis=1, dtype=acc_type, mask=forward_mask)
ps = nl.sum(p_partial_sum, axis=1, dtype=acc_type)
p_local_transposed = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
dtype=kernel_dtype)
@@ -218,7 +191,6 @@ def _flash_attention_core(
p_local_transposed=p_local_transposed,
p_local=p_local,
LARGE_TILE_SZ=LARGE_TILE_SZ,
forward_mask=forward_mask,
B_F_SIZE=B_F_SIZE,
)
@@ -230,27 +202,20 @@ def _flash_attention_core(
p_local_transposed[:, nl.ds(k_i * B_P_SIZE, B_P_SIZE)],
v[k_i, :, :],
transpose_x=True,
mask=forward_mask,
) # (128, 128) (p(Br), d)
if initialize:
o_buffer[:, :] = nl.copy(pv_psum[:, :])
l_buffer[:, 0] = nl.add(nl.log(ps), max_)
else:
o_buffer[:, :] = nl.add(o_previous_scaled, pv_psum, mask=forward_mask)
o_buffer[:, :] = nl.add(o_previous_scaled, pv_psum)
l_prev = l_buffer[:, 0]
l_exp = nl.add(
nl.exp(
nl.subtract(l_prev, m_current, mask=forward_mask),
mask=forward_mask,
),
nl.exp(nl.subtract(l_prev, m_current)),
ps,
mask=forward_mask,
)
l_buffer[:, 0] = nl.add(m_current,
nl.log(l_exp, mask=forward_mask),
mask=forward_mask)
l_buffer[:, 0] = nl.add(m_current, nl.log(l_exp))
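
The `m_buffer` / `l_buffer` / `o_buffer` updates above are the standard running-softmax (flash attention) recurrence. Below is a minimal NumPy sketch of one step, with the tiled reductions and engine-specific ops collapsed into dense operations; `online_softmax_step` is an illustrative name, not a function from the kernel.

```python
import numpy as np

def online_softmax_step(qk_tile, v_tile, o, l, m, initialize):
    """One accumulation step over a masked score tile.

    qk_tile: (B_P, L) masked scores for the current K/V tile
    v_tile:  (L, D)   values for the current tile
    o: (B_P, D) running unnormalized output; l, m: (B_P,) running
    log-sum-exp and running max.
    """
    max_ = qk_tile.max(axis=1)
    m_new = max_ if initialize else np.maximum(m, max_)
    p = np.exp(qk_tile - m_new[:, None])       # re-scaled probabilities
    ps = p.sum(axis=1)
    pv = p @ v_tile
    if initialize:
        o_new = pv
        l_new = np.log(ps) + max_
    else:
        alpha = np.exp(m - m_new)              # rescale the previous output
        o_new = o * alpha[:, None] + pv
        l_new = m_new + np.log(np.exp(l - m_new) + ps)
    return o_new, l_new, m_new
```

The final normalization (dividing `o` by `exp(l - m)`) presumably happens in the write-out epilogue, which is outside the hunks shown here.
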
@nki.jit
@@ -279,6 +244,21 @@ def load_v_tile(v_hbm_tile, cur_v_tile, j, v_i, config):
)
@nki.jit
def load_block_tables(block_tables_hbm, num_tiles):
(num_blocks, ) = block_tables_hbm.shape
assert num_blocks % num_tiles == 0
num_blocks_per_tile = num_blocks // num_tiles
block_tables_hbm = block_tables_hbm.reshape(
(num_tiles, num_blocks_per_tile))
block_tables_buffer = nl.load(block_tables_hbm, dtype=nl.int32)
return block_tables_buffer
def is_power_of_2(x):
return x > 0 and (x & (x - 1)) == 0
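
A small torch-level illustration of the layout `load_block_tables` produces, and why the indexing later in this diff flips from `block_tables_sbuf[k_i, j]` to `block_tables_sbuf[j, k_i]`; the sizes are made up.

```python
import torch

num_large_k_tile, num_blocks_per_tile = 4, 64
block_tables_hbm = torch.arange(num_large_k_tile * num_blocks_per_tile,
                                dtype=torch.int32)

# After the reshape, row j holds the block ids of large K tile j, so
# entry [j, k_i] is the k_i-th KV-cache block of that tile.
block_tables_2d = block_tables_hbm.reshape(num_large_k_tile,
                                           num_blocks_per_tile)
assert block_tables_2d[1, 0] == num_blocks_per_tile   # first block of tile 1
```
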
@nki.jit
def flash_paged_attention(
query,
@@ -316,24 +296,24 @@ def flash_paged_attention(
- We use paged cache blocks (key_cache, value_cache) to store KV cache.
IO tensor dtypes:
- This kernel assumes all IO tensors have the same dtype except for
block_tables (int32) and mask (int32)
- If mixed_precision is True, then all Tensor Engine operations will be
performed in bfloat16 and accumulation will be performed in float32.
Otherwise the intermediates will be in the same type as the inputs.
Compile-time Constants:
- softmax_scale: scaling for softmax; if None, the default is `1.0/(d**0.5)`
- mixed_precision: flag to set non-matmul ops in fp32 precision, default
is set to `true`; if false, we use the same precision as the input types
- config: Instance of dataclass :class:`nki.kernels.attention.FlashConfig`
with Performance config parameters for flash attention with default
values
seq_tile_size: `default=2048`, the kv tile size for the attention
computation reduction
GQA support Notes:
the spmd launch of the kernel should be over kv_heads instead of
nheads
Example usage:
@@ -415,18 +395,13 @@ def flash_paged_attention(
), f"Need B_P_SIZE ({B_P_SIZE}) to be divisible by {block_size=}"
num_large_k_tile = context_kv_len // LARGE_TILE_SZ
num_blocks_per_large_tile = LARGE_TILE_SZ // block_size
assert (num_blocks_per_large_tile <= B_P_SIZE
), f"The number of blocks in each large tile " \
f"({num_blocks_per_large_tile}) shouldn't exceed partition size {B_P_SIZE}"
assert block_size % 32 == 0, "block_size is expected to be a multiple of 32"
assert is_power_of_2(
num_blocks_per_large_tile
), "The number of blocks in each large tile is expected of be power of 2"
assert is_power_of_2(seqlen_q), "seqlen_q is expected to be power of 2"
block_tables_sbuf = nl.full((par_dim(B_P_SIZE), num_large_k_tile),
0,
dtype=np.int32,
buffer=nl.sbuf)
for j in nl.affine_range(num_large_k_tile):
i_p = nl.arange(num_blocks_per_large_tile)[:, None]
block_tables_sbuf[i_p, j] = nl.load(
block_tables[j * num_blocks_per_large_tile + i_p], dtype=np.int32)
block_tables_sbuf = load_block_tables(block_tables, num_large_k_tile)
# Global Flash Attention accumulators
o_buffer = nl.zeros(
@@ -457,7 +432,7 @@ def flash_paged_attention(
)
for k_i in nl.affine_range(num_blocks_per_large_tile):
loaded = nl.load(key_cache[block_tables_sbuf[k_i, j], :,
loaded = nl.load(key_cache[block_tables_sbuf[j, k_i], :,
head_id, :])
cur_k_tile[:, nl.ds(k_i *
block_size, block_size)] = nl.transpose(loaded)
@@ -469,7 +444,7 @@ def flash_paged_attention(
num_blocks_per_partition):
v_i = (partition_idx * num_blocks_per_partition +
block_in_partition)
loaded_v = nl.load(value_cache[block_tables_sbuf[v_i, j], :,
loaded_v = nl.load(value_cache[block_tables_sbuf[j, v_i], :,
head_id, :])
cur_v_tile[
partition_idx,
@@ -477,14 +452,15 @@ def flash_paged_attention(
:,
] = loaded_v
cur_mask = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
dtype=mask.dtype)
for m_i in nl.affine_range(LARGE_TILE_SZ // B_F_SIZE):
cur_mask[:, nl.ds(m_i * B_F_SIZE, B_F_SIZE)] = nl.load(
mask[:, nl.ds(j * LARGE_TILE_SZ + m_i * B_F_SIZE, B_F_SIZE)])
for i_q_h in nl.affine_range(q_h_per_k_h):
for i in nl.affine_range(n_tile_q):
for i in nl.affine_range(n_tile_q):
cur_mask = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
dtype=mask.dtype)
for m_i in nl.affine_range(LARGE_TILE_SZ // B_F_SIZE):
cur_mask[:, nl.ds(m_i * B_F_SIZE, B_F_SIZE)] = nl.load(mask[
nl.ds(i * B_P_SIZE, B_P_SIZE),
nl.ds(j * LARGE_TILE_SZ + m_i * B_F_SIZE, B_F_SIZE),
])
for i_q_h in nl.affine_range(q_h_per_k_h):
q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype)
q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h]
q_sbuf_tile = nl.load(
@@ -497,35 +473,24 @@ def flash_paged_attention(
q_local_tile=q_tile,
k=cur_k_tile,
v=cur_v_tile,
q_h_per_k_h=q_h_per_k_h,
seqlen_q=seqlen_q,
nheads=h,
o_buffer=o_buffer[i, i_q_h],
l_buffer=l_buffer[:, i, i_q_h],
m_buffer=m_buffer[i, i_q_h],
batch_id=batch_id,
head_id=head_id,
gqa_head_idx=i_q_h,
q_tile_idx=i,
local_k_large_tile_idx=j,
kernel_dtype=kernel_dtype,
acc_type=acc_type,
flash_config=config,
use_causal_mask=False,
continuous_batching_mask=cur_mask,
tile_mask=cur_mask,
initialize=j == 0,
B_P_SIZE=B_P_SIZE,
B_F_SIZE=B_F_SIZE,
B_D_SIZE=B_D_SIZE,
dropout_p=0.0,
dropout_p_tensor=None,
seed_tensor=None,
logit_bias_tile=None,
)
# compute attention between input query, key and value
if key is not None and value is not None:
B_F_SIZE = seqlen_q
B_F_SIZE = min(seqlen_q, B_F_SIZE)
LARGE_TILE_SZ = seqlen_q
active_config = FlashConfig(
seq_tile_size=LARGE_TILE_SZ,
@@ -552,11 +517,16 @@ def flash_paged_attention(
config=active_config,
)
cur_mask = nl.ndarray((par_dim(B_P_SIZE), B_F_SIZE), dtype=mask.dtype)
cur_mask[:, :] = nl.load(mask[:, nl.ds(context_kv_len, B_F_SIZE)])
for i in nl.affine_range(n_tile_q):
cur_mask = nl.load(
mask[
nl.ds(i * B_P_SIZE, B_P_SIZE),
nl.ds(context_kv_len, LARGE_TILE_SZ),
],
dtype=mask.dtype,
)
for i_q_h in nl.affine_range(q_h_per_k_h):
for i_q_h in nl.affine_range(q_h_per_k_h):
for i in nl.affine_range(n_tile_q):
q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype)
q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h]
q_sbuf_tile = nl.load(
@@ -568,32 +538,21 @@ def flash_paged_attention(
q_local_tile=q_tile,
k=cur_k_tile,
v=cur_v_tile,
q_h_per_k_h=q_h_per_k_h,
seqlen_q=seqlen_q,
nheads=h,
o_buffer=o_buffer[i, i_q_h],
l_buffer=l_buffer[:, i, i_q_h],
m_buffer=m_buffer[i, i_q_h],
batch_id=batch_id,
head_id=head_id,
gqa_head_idx=i_q_h,
q_tile_idx=i,
local_k_large_tile_idx=0,
kernel_dtype=kernel_dtype,
acc_type=acc_type,
flash_config=active_config,
use_causal_mask=False,
continuous_batching_mask=cur_mask,
use_causal_mask=True,
tile_mask=cur_mask,
initialize=False,
B_P_SIZE=B_P_SIZE,
B_F_SIZE=B_F_SIZE,
B_D_SIZE=B_D_SIZE,
dropout_p=0.0,
dropout_p_tensor=None,
seed_tensor=None,
logit_bias_tile=None,
qk_res_buffer=qk_res_buffer[i, i_q_h]
if qk_res_buffer is not None else None,
qk_res_buffer=(qk_res_buffer[i, i_q_h]
if qk_res_buffer is not None else None),
)
# -- -- -- -- write output to buffer on HBM -- -- -- -- -- -- #
@@ -652,7 +611,6 @@ def flash_attn_varlen_nkifunc(
attn_mask,
n_kv_head=None,
head_size=None,
B_P_SIZE=128,
LARGE_TILE_SZ=2048,
return_debug_tensors=False,
mixed_precision=True,
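
To summarize the interface change at the call site: `B_P_SIZE` is no longer a caller-facing argument, and `LARGE_TILE_SZ` is now passed through from the test's `large_tile_size` parameter instead of the test hardcoding 2048. A sketch of the keyword side of the call as the updated test builds it; the values are examples, and the positional inputs (query, KV caches, block tables, active key/value, attention mask) are elided in these hunks.

```python
from vllm.attention.ops.nki_flash_attn import flash_attn_varlen_nkifunc

input_kwargs = dict(
    n_kv_head=4,               # num_heads // num_queries_per_kv
    head_size=128,
    mixed_precision=True,
    LARGE_TILE_SZ=8192,        # now passed explicitly; test used to fix 2048
    return_debug_tensors=False,
)
# output_nki = flash_attn_varlen_nkifunc(*input_args, **input_kwargs)
```
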