mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 07:15:01 +08:00
[Neuron][Kernel] Support Longer Sequences in NKI-based Flash PagedAttention and Improve Efficiency (#12921)
Signed-off-by: Lingfan Yu <lingfany@amazon.com>
This commit is contained in:
parent
842b0fd402
commit
e92694b6fe
@ -1,6 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import random
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
@ -171,12 +170,22 @@ def ref_context_attention(
|
||||
return output
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"block_size, large_tile_size",
|
||||
[
|
||||
(32, 2048), # 64 blocks
|
||||
(32, 4096), # 128 blocks
|
||||
(32, 8192), # 256 blocks
|
||||
(64, 8192), # 128 blocks
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"num_heads,num_queries_per_kv,head_size,mixed_precision",
|
||||
[
|
||||
(4, 2, 8, False),
|
||||
(4, 2, 8, True),
|
||||
(32, 8, 64, True),
|
||||
(16, 2, 128, True),
|
||||
],
|
||||
)
|
||||
@torch.inference_mode()
|
||||
@ -184,6 +193,8 @@ def test_contexted_kv_attention(
|
||||
num_heads: int,
|
||||
num_queries_per_kv: int,
|
||||
head_size: int,
|
||||
block_size: int,
|
||||
large_tile_size,
|
||||
mixed_precision: bool,
|
||||
) -> None:
|
||||
import os
|
||||
@ -192,40 +203,46 @@ def test_contexted_kv_attention(
|
||||
|
||||
from vllm.attention.ops.nki_flash_attn import flash_attn_varlen_nkifunc
|
||||
|
||||
assert large_tile_size % block_size == 0
|
||||
|
||||
device = xm.xla_device()
|
||||
|
||||
os.environ["NEURON_CC_FLAGS"] = (
|
||||
" --model-type=transformer -O1 "
|
||||
" --internal-hlo2tensorizer-options='--verify-hlo' ")
|
||||
compiler_flags = [
|
||||
"--model-type=transformer -O1",
|
||||
"--internal-hlo2tensorizer-options='--verify-hlo'",
|
||||
"--retry_failed_compilation",
|
||||
]
|
||||
compiler_flags_str = " ".join(compiler_flags)
|
||||
os.environ["NEURON_CC_FLAGS"] = compiler_flags_str
|
||||
|
||||
random.seed(0)
|
||||
torch.manual_seed(0)
|
||||
torch.set_printoptions(sci_mode=False)
|
||||
|
||||
min_ctx_len = 2
|
||||
max_ctx_len = 64
|
||||
min_query_len = 2
|
||||
max_query_len = 64
|
||||
prefill_batch_size = 2
|
||||
decode_batch_size = 6
|
||||
min_ctx_len = 32
|
||||
max_ctx_len = 1024
|
||||
min_query_len = 16
|
||||
max_query_len = 512
|
||||
prefill_batch_size = 4
|
||||
decode_batch_size = 12
|
||||
batch_size = prefill_batch_size + decode_batch_size
|
||||
block_size = 32
|
||||
max_model_len = (max_query_len + max_ctx_len) * 4
|
||||
|
||||
max_block_per_request = max_model_len // block_size
|
||||
dtype = torch.float32
|
||||
cache_size = (batch_size * max_block_per_request) + 2
|
||||
ctx_lens = [
|
||||
random.randint(min_ctx_len, max_ctx_len)
|
||||
for _ in range(prefill_batch_size)
|
||||
] + [
|
||||
random.randint(min_ctx_len, max_ctx_len)
|
||||
for _ in range(decode_batch_size)
|
||||
]
|
||||
query_lens = [
|
||||
random.randint(min_query_len, max_query_len)
|
||||
for _ in range(prefill_batch_size)
|
||||
] + [1 for _ in range(decode_batch_size)]
|
||||
prefill_ctx_lens = torch.randint(min_ctx_len,
|
||||
max_ctx_len + 1, (prefill_batch_size, ),
|
||||
dtype=torch.long).tolist()
|
||||
decode_ctx_lens = torch.randint(min_ctx_len,
|
||||
max_ctx_len + 1, (decode_batch_size, ),
|
||||
dtype=torch.long).tolist()
|
||||
ctx_lens = prefill_ctx_lens + decode_ctx_lens
|
||||
query_lens = torch.randint(
|
||||
min_query_len,
|
||||
max_query_len + 1,
|
||||
(prefill_batch_size, ),
|
||||
dtype=torch.long,
|
||||
).tolist() + [1 for _ in range(decode_batch_size)]
|
||||
seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
|
||||
num_kv_heads = num_heads // num_queries_per_kv
|
||||
|
||||
@ -254,7 +271,6 @@ def test_contexted_kv_attention(
|
||||
values = values[torch.randperm(cache_size)]
|
||||
block_table = values[:batch_size * max_block_per_request].view(
|
||||
batch_size, max_block_per_request)
|
||||
torch.tensor(seq_lens, dtype=torch.long)
|
||||
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
|
||||
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1],
|
||||
dtype=torch.long),
|
||||
@ -311,9 +327,7 @@ def test_contexted_kv_attention(
|
||||
# build neuron program
|
||||
return_debug_tensors = False
|
||||
B_P_SIZE = 128
|
||||
LARGE_TILE_SZ = 2048
|
||||
max_num_queries = (
|
||||
(sum(query_lens) + block_size - 1) // block_size) * block_size
|
||||
LARGE_TILE_SZ = large_tile_size
|
||||
|
||||
def get_active_block_tables(block_tables, query_lens, seq_lens, block_size,
|
||||
num_blocks):
|
||||
@ -332,26 +346,28 @@ def test_contexted_kv_attention(
|
||||
0,
|
||||
)
|
||||
|
||||
def shift_bit_length(x):
|
||||
return 1 << (x - 1).bit_length()
|
||||
def ceil_div(a, b):
|
||||
return (a + b - 1) // b
|
||||
|
||||
def pad_to_multiple(a, b):
|
||||
return ceil_div(a, b) * b
|
||||
|
||||
def pad_to_next_power_of_2(a):
|
||||
assert a > 0
|
||||
return 2**int(a - 1).bit_length()
|
||||
|
||||
# calculate input shapes
|
||||
max_num_queries_shifted = shift_bit_length(max_num_queries)
|
||||
max_num_queries_factor = B_P_SIZE // max_num_queries_shifted
|
||||
max_num_queries_padded = max_num_queries_shifted * max_num_queries_factor
|
||||
assert (max_num_queries_padded == B_P_SIZE
|
||||
), "invalid {max_num_queries_padded=}"
|
||||
max_num_queries = pad_to_multiple(sum(query_lens), block_size)
|
||||
max_num_queries = pad_to_next_power_of_2(max_num_queries)
|
||||
head_size_padded = B_P_SIZE
|
||||
assert head_size_padded >= head_size
|
||||
context_lens = torch.tensor(seq_lens) - torch.tensor(query_lens)
|
||||
num_active_blocks_shifted = shift_bit_length(
|
||||
((context_lens + block_size - 1) // block_size).sum().item())
|
||||
num_active_blocks_factor = (LARGE_TILE_SZ // block_size //
|
||||
num_active_blocks_shifted)
|
||||
num_active_blocks = num_active_blocks_shifted * num_active_blocks_factor
|
||||
assert (num_active_blocks *
|
||||
block_size) == LARGE_TILE_SZ, "invalid {num_active_blocks=}"
|
||||
num_active_blocks = ceil_div(context_lens, block_size).sum().item()
|
||||
num_active_blocks = pad_to_multiple(num_active_blocks,
|
||||
LARGE_TILE_SZ // block_size)
|
||||
context_kv_len = num_active_blocks * block_size
|
||||
assert context_kv_len == LARGE_TILE_SZ, f"invalid {context_kv_len=}"
|
||||
assert (context_kv_len %
|
||||
LARGE_TILE_SZ == 0), f"invalid context_kv_len={context_kv_len}"
|
||||
|
||||
# pad QKV tensors
|
||||
pad_dims = (
|
||||
@ -360,7 +376,7 @@ def test_contexted_kv_attention(
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
max_num_queries_padded - query.shape[0],
|
||||
max_num_queries - query.shape[0],
|
||||
)
|
||||
query = F.pad(query, pad_dims, "constant", 0)
|
||||
k = F.pad(k, pad_dims, "constant", 0)
|
||||
@ -397,7 +413,7 @@ def test_contexted_kv_attention(
|
||||
0,
|
||||
context_kv_len - prior_mask.shape[1],
|
||||
0,
|
||||
B_P_SIZE - prior_mask.shape[0],
|
||||
max_num_queries - prior_mask.shape[0],
|
||||
),
|
||||
"constant",
|
||||
0,
|
||||
@ -406,9 +422,9 @@ def test_contexted_kv_attention(
|
||||
active_mask,
|
||||
(
|
||||
0,
|
||||
B_P_SIZE - active_mask.shape[1],
|
||||
max_num_queries - active_mask.shape[1],
|
||||
0,
|
||||
B_P_SIZE - active_mask.shape[0],
|
||||
max_num_queries - active_mask.shape[0],
|
||||
),
|
||||
"constant",
|
||||
0,
|
||||
@ -430,6 +446,8 @@ def test_contexted_kv_attention(
|
||||
n_kv_head=num_kv_heads,
|
||||
head_size=head_size,
|
||||
mixed_precision=mixed_precision,
|
||||
LARGE_TILE_SZ=LARGE_TILE_SZ,
|
||||
return_debug_tensors=return_debug_tensors,
|
||||
)
|
||||
|
||||
if return_debug_tensors:
|
||||
@ -439,17 +457,15 @@ def test_contexted_kv_attention(
|
||||
output_nki = flash_attn_varlen_nkifunc(*input_args, **input_kwargs)
|
||||
debug_tensors = []
|
||||
|
||||
output_nki = torch.tensor(output_nki).cpu()
|
||||
debug_tensors = [torch.tensor(dt).cpu() for dt in debug_tensors]
|
||||
|
||||
num_actual_tokens = sum(query_lens)
|
||||
print(f"{num_actual_tokens=}")
|
||||
# - o: shape (bs, n_heads, seq_q, d) -> (bs, seq_q, n_heads, d)
|
||||
output_nki = output_nki.permute(
|
||||
0, 2, 1, 3)[:, :, :, :head_size].cpu()[0, :num_actual_tokens, :, :]
|
||||
output_nki = output_nki.cpu().permute(0, 2, 1, 3)[:, :, :, :head_size]
|
||||
output_nki = output_nki[0, :num_actual_tokens, :, :]
|
||||
output_ref_padded = F.pad(
|
||||
output_ref,
|
||||
(0, 0, 0, 0, 0, 0, 0, max_num_queries_padded - output_ref.shape[0]),
|
||||
(0, 0, 0, 0, 0, 0, 0, max_num_queries - output_ref.shape[0]),
|
||||
"constant",
|
||||
0,
|
||||
)
|
||||
|
||||
@ -28,7 +28,6 @@ class FlashConfig:
|
||||
def transpose_p_local(p_local_transposed,
|
||||
p_local,
|
||||
LARGE_TILE_SZ,
|
||||
forward_mask,
|
||||
B_F_SIZE=512):
|
||||
for i in nl.affine_range(LARGE_TILE_SZ // B_F_SIZE):
|
||||
if nisa.get_nc_version() == nisa.nc_version.gen3:
|
||||
@ -46,13 +45,13 @@ def transpose_p_local(p_local_transposed,
|
||||
|
||||
if nisa.get_nc_version() == nisa.nc_version.gen3:
|
||||
p_local_t_tmp[:, j_128_slice] = nisa.dma_transpose(
|
||||
p_local[:, i_j_128_slice], mask=forward_mask)
|
||||
p_local[:, i_j_128_slice])
|
||||
else:
|
||||
p_local_t_tmp[:, j_128_slice] = nisa.nc_transpose(
|
||||
p_local[:, i_j_128_slice], mask=forward_mask)
|
||||
p_local[:, i_j_128_slice])
|
||||
|
||||
p_local_transposed[:, nl.ds(i * B_F_SIZE, B_F_SIZE)] = nl.copy(
|
||||
p_local_t_tmp, dtype=p_local_transposed.dtype, mask=forward_mask)
|
||||
p_local_t_tmp, dtype=p_local_transposed.dtype)
|
||||
|
||||
|
||||
@nki.jit
|
||||
@ -60,36 +59,25 @@ def _flash_attention_core(
|
||||
q_local_tile,
|
||||
k,
|
||||
v,
|
||||
q_h_per_k_h,
|
||||
seqlen_q,
|
||||
nheads,
|
||||
o_buffer,
|
||||
l_buffer,
|
||||
m_buffer,
|
||||
batch_id,
|
||||
head_id,
|
||||
gqa_head_idx,
|
||||
q_tile_idx,
|
||||
local_k_large_tile_idx,
|
||||
kernel_dtype,
|
||||
acc_type,
|
||||
flash_config: FlashConfig,
|
||||
use_causal_mask=False,
|
||||
continuous_batching_mask=None,
|
||||
use_causal_mask,
|
||||
tile_mask,
|
||||
initialize=False,
|
||||
B_P_SIZE=128,
|
||||
B_F_SIZE=512,
|
||||
B_D_SIZE=128,
|
||||
dropout_p=0.0,
|
||||
dropout_p_tensor=None,
|
||||
seed_tensor=None,
|
||||
logit_bias_tile=None,
|
||||
qk_res_buffer=None,
|
||||
):
|
||||
"""
|
||||
The flash attention core function to calculate self attention between a tile
|
||||
of q and a block of K and V.
|
||||
The q_local_tile has (B_P_SIZE, B_F_SIZE), which is loaded into the SBUF
|
||||
The q_local_tile has (B_P_SIZE, B_F_SIZE), which is loaded into the SBUF
|
||||
already. The block size of K and V
|
||||
is defined in the seq_tile_size of the flash_config. The results are stored
|
||||
in the following three buffers
|
||||
@ -99,24 +87,9 @@ def _flash_attention_core(
|
||||
"""
|
||||
LARGE_TILE_SZ = flash_config.seq_tile_size
|
||||
num_k_tile_per_large_tile = LARGE_TILE_SZ // B_F_SIZE
|
||||
seqlen_k = k.shape[-1]
|
||||
seqlen_q // B_P_SIZE
|
||||
seqlen_k // B_F_SIZE
|
||||
|
||||
# TODO : support logit_bias with continuous_batching_mask
|
||||
assert not use_causal_mask, "causal mask is not supported."
|
||||
assert (continuous_batching_mask
|
||||
is not None), "continuous_batching_mask input is required."
|
||||
if continuous_batching_mask is not None:
|
||||
assert (
|
||||
logit_bias_tile
|
||||
is None), "continuous_batching_mask does not support logit_bias!"
|
||||
|
||||
# mask are used to only apply computation to the lower half of the matrix,
|
||||
# which reduce the arithmetic intensity by half
|
||||
forward_mask = (q_tile_idx * B_P_SIZE >= local_k_large_tile_idx *
|
||||
LARGE_TILE_SZ if use_causal_mask else None)
|
||||
|
||||
qk_res_buf = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
|
||||
buffer=nl.sbuf,
|
||||
dtype=acc_type)
|
||||
@ -125,20 +98,27 @@ def _flash_attention_core(
|
||||
for k_i in nl.affine_range(num_k_tile_per_large_tile):
|
||||
k_i_b_f_slice = nl.ds(k_i * B_F_SIZE, B_F_SIZE)
|
||||
|
||||
qk_psum = nl.zeros((par_dim(B_P_SIZE), B_F_SIZE),
|
||||
dtype=np.float32,
|
||||
buffer=nl.psum) # (128, 512)
|
||||
qk_psum[:, :] = nl.matmul(q_local_tile,
|
||||
k[:, k_i_b_f_slice],
|
||||
transpose_x=True,
|
||||
mask=None) # (p(128), 512)
|
||||
if use_causal_mask:
|
||||
multiplication_required_selection = (q_tile_idx * B_P_SIZE
|
||||
>= k_i * B_F_SIZE)
|
||||
else:
|
||||
multiplication_required_selection = True
|
||||
|
||||
qk_res_buf[:, k_i_b_f_slice] = nl.where(
|
||||
continuous_batching_mask[:, k_i_b_f_slice],
|
||||
qk_psum[:, nl.ds(0, B_F_SIZE)],
|
||||
-9984.0,
|
||||
dtype=acc_type,
|
||||
)
|
||||
if multiplication_required_selection:
|
||||
qk_psum = nl.ndarray((par_dim(B_P_SIZE), B_F_SIZE),
|
||||
dtype=np.float32,
|
||||
buffer=nl.psum) # (128, 512)
|
||||
qk_psum[:, :] = nl.matmul(q_local_tile,
|
||||
k[:, k_i_b_f_slice],
|
||||
transpose_x=True) # (p(128), 512)
|
||||
qk_res_buf[:, k_i_b_f_slice] = nl.where(
|
||||
tile_mask[:, k_i_b_f_slice],
|
||||
qk_psum[:, nl.ds(0, B_F_SIZE)],
|
||||
-9984.0,
|
||||
dtype=acc_type,
|
||||
)
|
||||
else:
|
||||
qk_res_buf[:, k_i_b_f_slice] = -9984.0
|
||||
|
||||
# Calculate max of the current tile
|
||||
max_local[:, k_i] = nisa.tensor_reduce(
|
||||
@ -147,7 +127,6 @@ def _flash_attention_core(
|
||||
axis=(1, ),
|
||||
dtype=acc_type,
|
||||
negate=False,
|
||||
mask=forward_mask,
|
||||
)
|
||||
|
||||
if qk_res_buffer is not None:
|
||||
@ -159,7 +138,6 @@ def _flash_attention_core(
|
||||
axis=(1, ),
|
||||
dtype=acc_type,
|
||||
negate=False,
|
||||
mask=forward_mask,
|
||||
)
|
||||
|
||||
o_previous_scaled = nl.ndarray((par_dim(B_P_SIZE), B_D_SIZE),
|
||||
@ -170,8 +148,7 @@ def _flash_attention_core(
|
||||
m_current = max_
|
||||
else:
|
||||
m_previous = nl.copy(m_buffer[:, 0])
|
||||
m_buffer[:, 0] = nl.maximum(m_previous, max_,
|
||||
mask=forward_mask) # (128,1)
|
||||
m_buffer[:, 0] = nl.maximum(m_previous, max_) # (128,1)
|
||||
|
||||
m_current = m_buffer[:, 0]
|
||||
# Compute scaling factor
|
||||
@ -180,11 +157,8 @@ def _flash_attention_core(
|
||||
m_previous,
|
||||
bias=-1 * m_current,
|
||||
scale=1.0,
|
||||
mask=forward_mask,
|
||||
)
|
||||
o_previous_scaled[...] = nl.multiply(o_buffer[:, :],
|
||||
alpha,
|
||||
mask=forward_mask)
|
||||
o_previous_scaled[...] = nl.multiply(o_buffer[:, :], alpha)
|
||||
|
||||
p_local = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
|
||||
dtype=kernel_dtype)
|
||||
@ -207,10 +181,9 @@ def _flash_attention_core(
|
||||
reduce_op=nl.add,
|
||||
reduce_res=p_partial_sum[:, k_r_i],
|
||||
dtype=kernel_dtype,
|
||||
mask=forward_mask,
|
||||
)
|
||||
|
||||
ps = nl.sum(p_partial_sum, axis=1, dtype=acc_type, mask=forward_mask)
|
||||
ps = nl.sum(p_partial_sum, axis=1, dtype=acc_type)
|
||||
|
||||
p_local_transposed = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
|
||||
dtype=kernel_dtype)
|
||||
@ -218,7 +191,6 @@ def _flash_attention_core(
|
||||
p_local_transposed=p_local_transposed,
|
||||
p_local=p_local,
|
||||
LARGE_TILE_SZ=LARGE_TILE_SZ,
|
||||
forward_mask=forward_mask,
|
||||
B_F_SIZE=B_F_SIZE,
|
||||
)
|
||||
|
||||
@ -230,27 +202,20 @@ def _flash_attention_core(
|
||||
p_local_transposed[:, nl.ds(k_i * B_P_SIZE, B_P_SIZE)],
|
||||
v[k_i, :, :],
|
||||
transpose_x=True,
|
||||
mask=forward_mask,
|
||||
) # (128, 128) (p(Br), d)
|
||||
|
||||
if initialize:
|
||||
o_buffer[:, :] = nl.copy(pv_psum[:, :])
|
||||
l_buffer[:, 0] = nl.add(nl.log(ps), max_)
|
||||
else:
|
||||
o_buffer[:, :] = nl.add(o_previous_scaled, pv_psum, mask=forward_mask)
|
||||
o_buffer[:, :] = nl.add(o_previous_scaled, pv_psum)
|
||||
|
||||
l_prev = l_buffer[:, 0]
|
||||
l_exp = nl.add(
|
||||
nl.exp(
|
||||
nl.subtract(l_prev, m_current, mask=forward_mask),
|
||||
mask=forward_mask,
|
||||
),
|
||||
nl.exp(nl.subtract(l_prev, m_current)),
|
||||
ps,
|
||||
mask=forward_mask,
|
||||
)
|
||||
l_buffer[:, 0] = nl.add(m_current,
|
||||
nl.log(l_exp, mask=forward_mask),
|
||||
mask=forward_mask)
|
||||
l_buffer[:, 0] = nl.add(m_current, nl.log(l_exp))
|
||||
|
||||
|
||||
@nki.jit
|
||||
@ -279,6 +244,21 @@ def load_v_tile(v_hbm_tile, cur_v_tile, j, v_i, config):
|
||||
)
|
||||
|
||||
|
||||
@nki.jit
|
||||
def load_block_tables(block_tables_hbm, num_tiles):
|
||||
(num_blocks, ) = block_tables_hbm.shape
|
||||
assert num_blocks % num_tiles == 0
|
||||
num_blocks_per_tile = num_blocks // num_tiles
|
||||
block_tables_hbm = block_tables_hbm.reshape(
|
||||
(num_tiles, num_blocks_per_tile))
|
||||
block_tables_buffer = nl.load(block_tables_hbm, dtype=nl.int32)
|
||||
return block_tables_buffer
|
||||
|
||||
|
||||
def is_power_of_2(x):
|
||||
return x > 0 and (x & (x - 1)) == 0
|
||||
|
||||
|
||||
@nki.jit
|
||||
def flash_paged_attention(
|
||||
query,
|
||||
@ -316,24 +296,24 @@ def flash_paged_attention(
|
||||
- We use paged cache blocks (key_cache, value_cache) to store KV cache.
|
||||
|
||||
IO tensor dtypes:
|
||||
- This kernel assumes all IO tensors have the same dtype except for
|
||||
- This kernel assumes all IO tensors have the same dtype except for
|
||||
block_tables (int32) and mask (int32)
|
||||
- If mixed_percision is True, then all Tensor Engine operation will be
|
||||
performed in bfloat16 and accumulation will be performed in float32.
|
||||
- If mixed_percision is True, then all Tensor Engine operation will be
|
||||
performed in bfloat16 and accumulation will be performed in float32.
|
||||
Otherwise the intermediates will be in the same type as the inputs.
|
||||
|
||||
Compile-time Constants:
|
||||
- softmax_scale: scaling for softmax, is None, default is `1.0/(d**0.5)`
|
||||
- mixed_precision: flag to set non-matmul ops in fp32 precision, default
|
||||
is set to `true`, if false, we use same precision as input types
|
||||
is set to `true`, if false, we use same precision as input types
|
||||
- config: Instance of dataclass :class:`nki.kernels.attention.FlashConfig`
|
||||
with Performance config parameters for flash attention with default
|
||||
values
|
||||
seq_tile_size: `default=2048`, size of the kv tile size for attention
|
||||
seq_tile_size: `default=2048`, size of the kv tile size for attention
|
||||
computation reduction
|
||||
|
||||
GQA support Notes:
|
||||
the spmd kernel for launching kernel should be on kv_heads instead of
|
||||
the spmd kernel for launching kernel should be on kv_heads instead of
|
||||
nheads
|
||||
|
||||
Example usage:
|
||||
@ -415,18 +395,13 @@ def flash_paged_attention(
|
||||
), f"Need B_P_SIZE ({B_P_SIZE}) to be divisible by {block_size=}"
|
||||
num_large_k_tile = context_kv_len // LARGE_TILE_SZ
|
||||
num_blocks_per_large_tile = LARGE_TILE_SZ // block_size
|
||||
assert (num_blocks_per_large_tile <= B_P_SIZE
|
||||
), f"The number of blocks in each large tile " \
|
||||
f"({num_blocks_per_large_tile}) shouldn't exceed partition size {B_P_SIZE}"
|
||||
assert block_size % 32 == 0, "block_size is expected to be a multiple of 32"
|
||||
assert is_power_of_2(
|
||||
num_blocks_per_large_tile
|
||||
), "The number of blocks in each large tile is expected of be power of 2"
|
||||
assert is_power_of_2(seqlen_q), "seqlen_q is expected to be power of 2"
|
||||
|
||||
block_tables_sbuf = nl.full((par_dim(B_P_SIZE), num_large_k_tile),
|
||||
0,
|
||||
dtype=np.int32,
|
||||
buffer=nl.sbuf)
|
||||
for j in nl.affine_range(num_large_k_tile):
|
||||
i_p = nl.arange(num_blocks_per_large_tile)[:, None]
|
||||
block_tables_sbuf[i_p, j] = nl.load(
|
||||
block_tables[j * num_blocks_per_large_tile + i_p], dtype=np.int32)
|
||||
block_tables_sbuf = load_block_tables(block_tables, num_large_k_tile)
|
||||
|
||||
# Global Flash Attention accumulators
|
||||
o_buffer = nl.zeros(
|
||||
@ -457,7 +432,7 @@ def flash_paged_attention(
|
||||
)
|
||||
|
||||
for k_i in nl.affine_range(num_blocks_per_large_tile):
|
||||
loaded = nl.load(key_cache[block_tables_sbuf[k_i, j], :,
|
||||
loaded = nl.load(key_cache[block_tables_sbuf[j, k_i], :,
|
||||
head_id, :])
|
||||
cur_k_tile[:, nl.ds(k_i *
|
||||
block_size, block_size)] = nl.transpose(loaded)
|
||||
@ -469,7 +444,7 @@ def flash_paged_attention(
|
||||
num_blocks_per_partition):
|
||||
v_i = (partition_idx * num_blocks_per_partition +
|
||||
block_in_partition)
|
||||
loaded_v = nl.load(value_cache[block_tables_sbuf[v_i, j], :,
|
||||
loaded_v = nl.load(value_cache[block_tables_sbuf[j, v_i], :,
|
||||
head_id, :])
|
||||
cur_v_tile[
|
||||
partition_idx,
|
||||
@ -477,14 +452,15 @@ def flash_paged_attention(
|
||||
:,
|
||||
] = loaded_v
|
||||
|
||||
cur_mask = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
|
||||
dtype=mask.dtype)
|
||||
for m_i in nl.affine_range(LARGE_TILE_SZ // B_F_SIZE):
|
||||
cur_mask[:, nl.ds(m_i * B_F_SIZE, B_F_SIZE)] = nl.load(
|
||||
mask[:, nl.ds(j * LARGE_TILE_SZ + m_i * B_F_SIZE, B_F_SIZE)])
|
||||
|
||||
for i_q_h in nl.affine_range(q_h_per_k_h):
|
||||
for i in nl.affine_range(n_tile_q):
|
||||
for i in nl.affine_range(n_tile_q):
|
||||
cur_mask = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
|
||||
dtype=mask.dtype)
|
||||
for m_i in nl.affine_range(LARGE_TILE_SZ // B_F_SIZE):
|
||||
cur_mask[:, nl.ds(m_i * B_F_SIZE, B_F_SIZE)] = nl.load(mask[
|
||||
nl.ds(i * B_P_SIZE, B_P_SIZE),
|
||||
nl.ds(j * LARGE_TILE_SZ + m_i * B_F_SIZE, B_F_SIZE),
|
||||
])
|
||||
for i_q_h in nl.affine_range(q_h_per_k_h):
|
||||
q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype)
|
||||
q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h]
|
||||
q_sbuf_tile = nl.load(
|
||||
@ -497,35 +473,24 @@ def flash_paged_attention(
|
||||
q_local_tile=q_tile,
|
||||
k=cur_k_tile,
|
||||
v=cur_v_tile,
|
||||
q_h_per_k_h=q_h_per_k_h,
|
||||
seqlen_q=seqlen_q,
|
||||
nheads=h,
|
||||
o_buffer=o_buffer[i, i_q_h],
|
||||
l_buffer=l_buffer[:, i, i_q_h],
|
||||
m_buffer=m_buffer[i, i_q_h],
|
||||
batch_id=batch_id,
|
||||
head_id=head_id,
|
||||
gqa_head_idx=i_q_h,
|
||||
q_tile_idx=i,
|
||||
local_k_large_tile_idx=j,
|
||||
kernel_dtype=kernel_dtype,
|
||||
acc_type=acc_type,
|
||||
flash_config=config,
|
||||
use_causal_mask=False,
|
||||
continuous_batching_mask=cur_mask,
|
||||
tile_mask=cur_mask,
|
||||
initialize=j == 0,
|
||||
B_P_SIZE=B_P_SIZE,
|
||||
B_F_SIZE=B_F_SIZE,
|
||||
B_D_SIZE=B_D_SIZE,
|
||||
dropout_p=0.0,
|
||||
dropout_p_tensor=None,
|
||||
seed_tensor=None,
|
||||
logit_bias_tile=None,
|
||||
)
|
||||
|
||||
# compute attention between input query, key and value
|
||||
if key is not None and value is not None:
|
||||
B_F_SIZE = seqlen_q
|
||||
B_F_SIZE = min(seqlen_q, B_F_SIZE)
|
||||
LARGE_TILE_SZ = seqlen_q
|
||||
active_config = FlashConfig(
|
||||
seq_tile_size=LARGE_TILE_SZ,
|
||||
@ -552,11 +517,16 @@ def flash_paged_attention(
|
||||
config=active_config,
|
||||
)
|
||||
|
||||
cur_mask = nl.ndarray((par_dim(B_P_SIZE), B_F_SIZE), dtype=mask.dtype)
|
||||
cur_mask[:, :] = nl.load(mask[:, nl.ds(context_kv_len, B_F_SIZE)])
|
||||
for i in nl.affine_range(n_tile_q):
|
||||
cur_mask = nl.load(
|
||||
mask[
|
||||
nl.ds(i * B_P_SIZE, B_P_SIZE),
|
||||
nl.ds(context_kv_len, LARGE_TILE_SZ),
|
||||
],
|
||||
dtype=mask.dtype,
|
||||
)
|
||||
for i_q_h in nl.affine_range(q_h_per_k_h):
|
||||
|
||||
for i_q_h in nl.affine_range(q_h_per_k_h):
|
||||
for i in nl.affine_range(n_tile_q):
|
||||
q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype)
|
||||
q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h]
|
||||
q_sbuf_tile = nl.load(
|
||||
@ -568,32 +538,21 @@ def flash_paged_attention(
|
||||
q_local_tile=q_tile,
|
||||
k=cur_k_tile,
|
||||
v=cur_v_tile,
|
||||
q_h_per_k_h=q_h_per_k_h,
|
||||
seqlen_q=seqlen_q,
|
||||
nheads=h,
|
||||
o_buffer=o_buffer[i, i_q_h],
|
||||
l_buffer=l_buffer[:, i, i_q_h],
|
||||
m_buffer=m_buffer[i, i_q_h],
|
||||
batch_id=batch_id,
|
||||
head_id=head_id,
|
||||
gqa_head_idx=i_q_h,
|
||||
q_tile_idx=i,
|
||||
local_k_large_tile_idx=0,
|
||||
kernel_dtype=kernel_dtype,
|
||||
acc_type=acc_type,
|
||||
flash_config=active_config,
|
||||
use_causal_mask=False,
|
||||
continuous_batching_mask=cur_mask,
|
||||
use_causal_mask=True,
|
||||
tile_mask=cur_mask,
|
||||
initialize=False,
|
||||
B_P_SIZE=B_P_SIZE,
|
||||
B_F_SIZE=B_F_SIZE,
|
||||
B_D_SIZE=B_D_SIZE,
|
||||
dropout_p=0.0,
|
||||
dropout_p_tensor=None,
|
||||
seed_tensor=None,
|
||||
logit_bias_tile=None,
|
||||
qk_res_buffer=qk_res_buffer[i, i_q_h]
|
||||
if qk_res_buffer is not None else None,
|
||||
qk_res_buffer=(qk_res_buffer[i, i_q_h]
|
||||
if qk_res_buffer is not None else None),
|
||||
)
|
||||
|
||||
# -- -- -- -- write output to buffer on HBM -- -- -- -- -- -- #
|
||||
@ -652,7 +611,6 @@ def flash_attn_varlen_nkifunc(
|
||||
attn_mask,
|
||||
n_kv_head=None,
|
||||
head_size=None,
|
||||
B_P_SIZE=128,
|
||||
LARGE_TILE_SZ=2048,
|
||||
return_debug_tensors=False,
|
||||
mixed_precision=True,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user