diff --git a/tests/neuron/test_block_table.py b/tests/neuron/test_block_table.py new file mode 100644 index 0000000000000..30dcdd573edf3 --- /dev/null +++ b/tests/neuron/test_block_table.py @@ -0,0 +1,153 @@ +# SPDX-License-Identifier: Apache-2.0 +import os + +import neuronxcc.nki.language as nl +import pytest +import torch +import torch.nn.functional as F +from neuronxcc import nki + +from vllm.attention.ops.nki_flash_attn import ( + load_block_tables, transform_block_tables_for_indirect_load) + + +def is_power_of_2(n): + return n > 0 and (n & (n - 1) == 0) + + +def nki_load_and_transform_block_tables( + block_tables, + num_tiles, + num_blocks_per_tile, + num_head, + head_id, + block_size_tiling_factor, +): + assert is_power_of_2( + num_blocks_per_tile), f"{num_blocks_per_tile=} must be power of 2" + block_tables_sbuf = load_block_tables(block_tables, num_tiles, + num_blocks_per_tile) + + # we need to pass an Index as head_id + head_id = nl.arange(1)[None, :] + head_id + + block_tables_transposed = transform_block_tables_for_indirect_load( + block_tables_sbuf, block_size_tiling_factor, num_head, head_id) + B_P_SIZE = 128 + assert block_tables_transposed.shape[1] == B_P_SIZE + + out = nl.ndarray( + block_tables_transposed.shape, + dtype=nl.int32, + buffer=nl.shared_hbm, + ) + for i in nl.affine_range(block_tables_transposed.shape[0]): + nl.store(dst=out[i], value=block_tables_transposed[i]) + return out + + +def ref_block_tables_transform( + block_tables, + num_tiles, + num_blocks_per_tile, + num_head, + head_id, + block_size_tiling_factor, +): + assert block_tables.numel() == num_tiles * num_blocks_per_tile + block_tables = block_tables.view(num_tiles, num_blocks_per_tile) + B_F_SIZE = 128 + num_tiles_padded = (num_tiles + B_F_SIZE - 1) // B_F_SIZE * B_F_SIZE + block_tables = F.pad( + block_tables, + (0, 0, 0, num_tiles_padded - num_tiles), + "constant", + 0, + ) + + block_tables = block_tables * num_head + head_id + block_tables = block_tables.view(num_tiles_padded, num_blocks_per_tile, 1) + offset = torch.arange(0, block_size_tiling_factor).view(1, 1, -1) + block_tables = block_tables * block_size_tiling_factor + offset + block_tables_transposed = block_tables.view(num_tiles_padded, -1).t() + + num_blocks_per_tile = block_tables_transposed.shape[0] + assert num_blocks_per_tile % B_F_SIZE == 0 + return block_tables_transposed.view(num_blocks_per_tile // B_F_SIZE, + B_F_SIZE, num_tiles_padded) + + +@pytest.mark.parametrize( + "q_head_per_kv_head,head_id", + [ + (1, 0), + (3, 1), + ], +) +@pytest.mark.parametrize( + "num_tiles,num_blocks_per_tile", + [ + (1, 1), + (13, 16), + (17, 128), + (35, 512), + (128, 128), + (130, 64), + (280, 256), + (315, 1), + ], +) +@torch.inference_mode() +def test_load_and_transform_block_tables( + num_tiles, + num_blocks_per_tile, + q_head_per_kv_head, + head_id, +) -> None: + import torch_xla.core.xla_model as xm + + device = xm.xla_device() + + compiler_flags = [ + "-O1", + "--retry_failed_compilation", + ] + compiler_flags_str = " ".join(compiler_flags) + os.environ["NEURON_CC_FLAGS"] = compiler_flags_str + + torch.manual_seed(10000) + torch.set_printoptions(sci_mode=False) + + # On Neuron, we need B_P_SIZE = 128 blocks to make DMA efficient + B_P_SIZE = 128 + if num_blocks_per_tile < B_P_SIZE: + assert B_P_SIZE % num_blocks_per_tile == 0 + block_size_tiling_factor = B_P_SIZE // num_blocks_per_tile + else: + block_size_tiling_factor = 1 + max_num_blocks = 100000 + block_tables = torch.randint( + 0, + max_num_blocks, + (num_tiles * num_blocks_per_tile, ), + 
dtype=torch.int32, + ) + nki_out = nki.jit(nki_load_and_transform_block_tables)[1, 1]( + block_tables.to(device=device), + num_tiles, + num_blocks_per_tile, + q_head_per_kv_head, + head_id, + block_size_tiling_factor, + ).cpu() + ref_out = ref_block_tables_transform( + block_tables, + num_tiles, + num_blocks_per_tile, + q_head_per_kv_head, + head_id, + block_size_tiling_factor, + ) + assert (nki_out.shape == ref_out.shape + ), f"{nki_out.shape=} != {ref_out.shape=}" + assert torch.all(nki_out == ref_out) diff --git a/tests/neuron/test_prefix_prefill.py b/tests/neuron/test_prefix_prefill.py index 04d1bd3f0eb04..347a139f39b4e 100644 --- a/tests/neuron/test_prefix_prefill.py +++ b/tests/neuron/test_prefix_prefill.py @@ -107,7 +107,7 @@ def ref_masked_attention( masked_score, dim=-1, return_max_reduce=True) else: norm_score = ref_softmax(masked_score, dim=-1) - out = torch.einsum("hqk,khd->qhd", norm_score, value) + out = torch.einsum("hqk,khd->qhd", norm_score.to(value.dtype), value) if return_max_reduce: return ( out, @@ -118,7 +118,7 @@ def ref_masked_attention( scaled_qk, ) else: - return out + return (out, ) def ref_context_attention( @@ -128,8 +128,6 @@ def ref_context_attention( query_lens, seq_lens, head_size, - num_kv_heads, - num_heads, num_queries_per_kv, return_max_reduce=False, ): @@ -146,18 +144,19 @@ def ref_context_attention( attn_mask = torch.logical_not(attn_mask) attn_mask = attn_mask.float() * -30000 - output, cached_max, cached_sum_reciprocal, lse, masked_score, scaled_qk = ( - ref_masked_attention( - query, - key, - value, - scale, - attn_mask, - return_max_reduce=return_max_reduce, - )) + output, *debug_tensors = ref_masked_attention( + query, + key, + value, + scale, + attn_mask, + return_max_reduce=return_max_reduce, + ) output = output.unsqueeze(1) if return_max_reduce: + cached_max, cached_sum_reciprocal, lse, masked_score, scaled_qk = ( + debug_tensors) return ( output, cached_max, @@ -170,65 +169,22 @@ def ref_context_attention( return output -@pytest.mark.parametrize( - "block_size, large_tile_size", - [ - (32, 2048), # 64 blocks - (32, 4096), # 128 blocks - (32, 8192), # 256 blocks - (64, 8192), # 128 blocks - ], -) -@pytest.mark.parametrize( - "num_heads,num_queries_per_kv,head_size,mixed_precision", - [ - (4, 2, 8, False), - (4, 2, 8, True), - (32, 8, 64, True), - (16, 2, 128, True), - ], -) -@torch.inference_mode() -def test_contexted_kv_attention( - num_heads: int, - num_queries_per_kv: int, - head_size: int, - block_size: int, - large_tile_size, - mixed_precision: bool, -) -> None: - import os - - import torch_xla.core.xla_model as xm - - from vllm.attention.ops.nki_flash_attn import flash_attn_varlen_nkifunc - - assert large_tile_size % block_size == 0 - - device = xm.xla_device() - - compiler_flags = [ - "--model-type=transformer -O1", - "--internal-hlo2tensorizer-options='--verify-hlo'", - "--retry_failed_compilation", - ] - compiler_flags_str = " ".join(compiler_flags) - os.environ["NEURON_CC_FLAGS"] = compiler_flags_str - - torch.manual_seed(0) - torch.set_printoptions(sci_mode=False) - - min_ctx_len = 32 - max_ctx_len = 1024 - min_query_len = 16 - max_query_len = 512 - prefill_batch_size = 4 - decode_batch_size = 12 +def sample_inputs( + prefill_batch_size, + decode_batch_size, + min_query_len, + max_query_len, + min_ctx_len, + max_ctx_len, + block_size, + num_heads, + num_kv_heads, + head_size, + dtype, +): batch_size = prefill_batch_size + decode_batch_size max_model_len = (max_query_len + max_ctx_len) * 4 - max_block_per_request = max_model_len // 
block_size - dtype = torch.float32 cache_size = (batch_size * max_block_per_request) + 2 prefill_ctx_lens = torch.randint(min_ctx_len, max_ctx_len + 1, (prefill_batch_size, ), @@ -244,7 +200,6 @@ def test_contexted_kv_attention( dtype=torch.long, ).tolist() + [1 for _ in range(decode_batch_size)] seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)] - num_kv_heads = num_heads // num_queries_per_kv num_tokens = sum(query_lens) query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) @@ -304,47 +259,139 @@ def test_contexted_kv_attention( cur_ctx += block_size block_id += 1 + return ( + query, + k, + v, + k_cache, + v_cache, + block_table, + key, + value, + query_lens, + seq_lens, + ) + + +def get_active_block_tables(block_tables, query_lens, seq_lens, block_size, + num_blocks): + context_lens = seq_lens - query_lens + blocks_per_seq = (context_lens + block_size - 1) // block_size + num_seqs = len(seq_lens) + active_blocks: list[int] = [] + for seq_id in range(num_seqs): + active_blocks = ( + active_blocks + + block_tables[seq_id, :blocks_per_seq[seq_id]].tolist()) + return F.pad( + torch.tensor(active_blocks, dtype=torch.int32), + (0, num_blocks - len(active_blocks)), + "constant", + 0, + ) + + +@pytest.mark.parametrize( + "prefill_batch_size,decode_batch_size,block_size,large_tile_size", + [ + (1, 199, 1, 512), # 512 blocks + (4, 12, 256, 2048), # 128 blocks + (4, 12, 16, 2048), # 128 blocks + (4, 12, 4, 1024), # 256 blocks + (4, 12, 32, 2048), # 64 blocks + (4, 12, 32, 4096), # 128 blocks + (4, 12, 32, 8192), # 256 blocks + (4, 12, 64, 8192), # 128 blocks + ], +) +@pytest.mark.parametrize( + "num_heads,num_queries_per_kv,head_size", + [ + (4, 2, 8), + (32, 8, 64), + (4, 4, 128), + (8, 1, 32), + ], +) +@pytest.mark.parametrize("mixed_precision", [True, False]) +@torch.inference_mode() +def test_contexted_kv_attention( + prefill_batch_size: int, + decode_batch_size: int, + num_heads: int, + num_queries_per_kv: int, + head_size: int, + block_size: int, + large_tile_size, + mixed_precision: bool, +) -> None: + import os + + import torch_xla.core.xla_model as xm + + from vllm.attention.ops.nki_flash_attn import (flash_attn_varlen_nkifunc, + reorder_context_mask) + + assert large_tile_size % block_size == 0 + + device = xm.xla_device() + + compiler_flags = [ + "-O1", + "--retry_failed_compilation", + ] + compiler_flags_str = " ".join(compiler_flags) + os.environ["NEURON_CC_FLAGS"] = compiler_flags_str + + torch.manual_seed(0) + torch.set_printoptions(sci_mode=False) + dtype = torch.float32 + + min_ctx_len = 32 + max_ctx_len = 1024 + min_query_len = 16 + max_query_len = 512 + num_kv_heads = num_heads // num_queries_per_kv ( - output_ref, - cached_max, - cached_sum_reciprocal, - lse, - masked_score, - scaled_qk, - ) = ref_context_attention( + query, + k_active, + v_active, + k_cache, + v_cache, + block_table, + key, + value, + query_lens, + seq_lens, + ) = sample_inputs( + prefill_batch_size=prefill_batch_size, + decode_batch_size=decode_batch_size, + min_query_len=min_query_len, + max_query_len=max_query_len, + min_ctx_len=min_ctx_len, + max_ctx_len=max_ctx_len, + block_size=block_size, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_size=head_size, + dtype=dtype, + ) + + output_ref = ref_context_attention( query, key, value, query_lens, seq_lens, head_size, - num_kv_heads, - num_heads, num_queries_per_kv, - return_max_reduce=True, + return_max_reduce=False, ) # build neuron program - return_debug_tensors = False B_P_SIZE = 128 - LARGE_TILE_SZ = large_tile_size - - def 
get_active_block_tables(block_tables, query_lens, seq_lens, block_size, - num_blocks): - context_lens = seq_lens - query_lens - blocks_per_seq = (context_lens + block_size - 1) // block_size - num_seqs = len(seq_lens) - active_blocks: list[int] = [] - for seq_id in range(num_seqs): - active_blocks = ( - active_blocks + - block_tables[seq_id, :blocks_per_seq[seq_id]].tolist()) - return F.pad( - torch.tensor(active_blocks), - (0, num_blocks - len(active_blocks)), - "constant", - 0, - ) + assert (large_tile_size >= B_P_SIZE + ), f"Expect {large_tile_size=} to be larger than {B_P_SIZE=}" def ceil_div(a, b): return (a + b - 1) // b @@ -357,32 +404,27 @@ def test_contexted_kv_attention( return 2**int(a - 1).bit_length() # calculate input shapes - max_num_queries = pad_to_multiple(sum(query_lens), block_size) - max_num_queries = pad_to_next_power_of_2(max_num_queries) - head_size_padded = B_P_SIZE - assert head_size_padded >= head_size + max_num_queries = pad_to_next_power_of_2(sum(query_lens)) context_lens = torch.tensor(seq_lens) - torch.tensor(query_lens) num_active_blocks = ceil_div(context_lens, block_size).sum().item() num_active_blocks = pad_to_multiple(num_active_blocks, - LARGE_TILE_SZ // block_size) + large_tile_size // block_size) context_kv_len = num_active_blocks * block_size assert (context_kv_len % - LARGE_TILE_SZ == 0), f"invalid context_kv_len={context_kv_len}" + large_tile_size == 0), f"invalid context_kv_len={context_kv_len}" # pad QKV tensors pad_dims = ( 0, - head_size_padded - query.shape[2], + 0, 0, 0, 0, max_num_queries - query.shape[0], ) query = F.pad(query, pad_dims, "constant", 0) - k = F.pad(k, pad_dims, "constant", 0) - v = F.pad(v, pad_dims, "constant", 0) - k_cache = F.pad(k_cache, (0, head_size_padded - head_size), "constant", 0) - v_cache = F.pad(v_cache, (0, head_size_padded - head_size), "constant", 0) + k = F.pad(k_active, pad_dims, "constant", 0) + v = F.pad(v_active, pad_dims, "constant", 0) # permute QKV tensors # query: (1, n_heads, d, seq_q) @@ -391,6 +433,8 @@ def test_contexted_kv_attention( query = query.unsqueeze(0).permute(0, 2, 3, 1).contiguous() k = k.unsqueeze(0).permute(0, 2, 3, 1).contiguous() v = v.unsqueeze(0).permute(0, 2, 1, 3).contiguous() + k_cache = k_cache.permute(0, 2, 1, 3).contiguous() + v_cache = v_cache.permute(0, 2, 1, 3).contiguous() # transform block table active_block_table = get_active_block_tables( @@ -405,33 +449,31 @@ def test_contexted_kv_attention( prior_mask, active_mask = ( BlockDiagonalCausalFromBottomRightMask.from_seqlens( query_lens, seq_lens, block_size=block_size)) - attn_mask = torch.concat( - [ - F.pad( - prior_mask, - ( - 0, - context_kv_len - prior_mask.shape[1], - 0, - max_num_queries - prior_mask.shape[0], - ), - "constant", - 0, - ).bool(), - F.pad( - active_mask, - ( - 0, - max_num_queries - active_mask.shape[1], - 0, - max_num_queries - active_mask.shape[0], - ), - "constant", - 0, - ).bool(), - ], - dim=1, - ) + prior_mask_padded = F.pad( + prior_mask, + ( + 0, + context_kv_len - prior_mask.shape[1], + 0, + max_num_queries - prior_mask.shape[0], + ), + "constant", + 0, + ).bool() + active_mask_padded = F.pad( + active_mask, + ( + 0, + max_num_queries - active_mask.shape[1], + 0, + max_num_queries - active_mask.shape[0], + ), + "constant", + 0, + ).bool() + attn_mask = torch.concat([prior_mask_padded, active_mask_padded], dim=1) + + attn_mask = reorder_context_mask(attn_mask, large_tile_size, block_size) input_args = ( query.to(device=device), @@ -439,29 +481,21 @@ def test_contexted_kv_attention( 
v.to(device=device), k_cache.to(device=device), v_cache.to(device=device), - active_block_table.to(torch.int32).to(device=device), + active_block_table.to(device=device), attn_mask.to(device=device), ) input_kwargs = dict( n_kv_head=num_kv_heads, head_size=head_size, mixed_precision=mixed_precision, - LARGE_TILE_SZ=LARGE_TILE_SZ, - return_debug_tensors=return_debug_tensors, + LARGE_TILE_SZ=large_tile_size, ) - if return_debug_tensors: - output_nki, *debug_tensors = flash_attn_varlen_nkifunc( - *input_args, **input_kwargs) - else: - output_nki = flash_attn_varlen_nkifunc(*input_args, **input_kwargs) - debug_tensors = [] - - debug_tensors = [torch.tensor(dt).cpu() for dt in debug_tensors] + output_nki = flash_attn_varlen_nkifunc(*input_args, **input_kwargs) num_actual_tokens = sum(query_lens) # - o: shape (bs, n_heads, seq_q, d) -> (bs, seq_q, n_heads, d) - output_nki = output_nki.cpu().permute(0, 2, 1, 3)[:, :, :, :head_size] + output_nki = output_nki.cpu().permute(0, 2, 1, 3) output_nki = output_nki[0, :num_actual_tokens, :, :] output_ref_padded = F.pad( output_ref, diff --git a/vllm/attention/ops/nki_flash_attn.py b/vllm/attention/ops/nki_flash_attn.py index 5e2a1f7e66d1f..20f9dcd163fea 100644 --- a/vllm/attention/ops/nki_flash_attn.py +++ b/vllm/attention/ops/nki_flash_attn.py @@ -1,27 +1,203 @@ # SPDX-License-Identifier: Apache-2.0 -from dataclasses import dataclass - import neuronxcc.nki.isa as nisa import neuronxcc.nki.language as nl import numpy as np +import torch from neuronxcc import nki from neuronxcc.nki.language import par_dim -@dataclass(frozen=True) -class FlashConfig: - """ - Config class for flash attention with default values - """ +def ceil_div(a, b): + return (a + b - 1) // b - seq_tile_size: int = 2048 - should_transpose_v: bool = False - __annotations__ = { - "seq_tile_size": int, - "should_transpose_v": bool, - } +def is_power_of_2(x): + return x > 0 and (x & (x - 1)) == 0 + + +@nki.jit +def load_block_tables(block_tables_hbm, num_tiles, num_blocks_per_tile): + """ + Load block tables from HBM into SRAM + + `block_tables_hbm` has shape `(num_tiles * num_blocks_per_tile, )`. + In case `num_tiles > B_P_SIZE`, we need further tile `num_tile` dimension. + """ + B_P_SIZE = 128 + + # reshape as `(num_tiles, num_blocks_per_tile)` + assert len(block_tables_hbm.shape) == 1 + (num_total_blocks, ) = block_tables_hbm.shape + assert num_blocks_per_tile * num_tiles == num_total_blocks + block_tables_hbm = block_tables_hbm.reshape( + (num_tiles, num_blocks_per_tile)) + + block_tables_sbuf = nl.zeros( + (ceil_div(num_tiles, + B_P_SIZE), par_dim(B_P_SIZE), num_blocks_per_tile), + dtype=nl.int32, + ) + for i in nl.affine_range(ceil_div(num_tiles, B_P_SIZE)): + i_p = nl.arange(B_P_SIZE)[:, None] + i_f = nl.arange(num_blocks_per_tile)[None, :] + block_tables_sbuf[i, i_p, i_f] = nl.load( + block_tables_hbm[i_p + i * B_P_SIZE, i_f], + dtype=nl.int32, + mask=(i_p + i * B_P_SIZE < num_tiles), + ) + return block_tables_sbuf + + +@nki.jit +def transform_block_tables_for_indirect_load( + block_tables, + block_size_tiling_factor, + num_head, + head_id, +): + """ + This function does two things: + 1. calculate new `block_tables` for a `head_id` after flattening + `num_block`, `num_head`, and `block_size_tiling_factor` dimensions + 2. 
transpose the result so that `block_table` for each tile is mapped to + SBUF Partition dimension for vectorized DMA + + Tiling trick to further improve DMA performance: + Given KV cache shape `(num_block, num_head, block_size, D)`, when loading M + blocks of a given `head_id` from HBM, the load `cache[block_tables, + head_id]` has shape `(M, block_size, D)`. If M < B_P_SIZE = 128, DMA may not + fully utilize hardware parallelization. The solution is to tile `block_size` + into `(block_size_tiling_factor, tiled_block_size)` s.t. `M * + block_size_tiling_factor = B_P_SIZE`. After tiling, KV cache has shape + `(num_block, num_head, block_size_tiling_factor, tiled_block_size, D)`. + + Note: + We don't further tile D dimension as small DMA size also hurts performance. + """ + B_P_SIZE = 128 + num_partitions, num_tiles_per_partition, num_blocks_per_tile = ( + block_tables.shape) + assert num_tiles_per_partition == B_P_SIZE + assert is_power_of_2( + num_blocks_per_tile), f"{num_blocks_per_tile=} is not power of 2" + + num_loads = ceil_div(num_blocks_per_tile, B_P_SIZE) + block_tables_transposed = nl.ndarray( + ( + num_loads, + par_dim(B_P_SIZE), + num_partitions * num_tiles_per_partition, + ), + dtype=nl.int32, + ) + + # prepare iota ahead of time to avoid repeatedly using Gpsimd + if num_head > 1: + head_id = nisa.iota(head_id, dtype=nl.int32).reshape((1, 1)) + head_id = nl.transpose( + head_id.broadcast_to((1, num_tiles_per_partition))) + if num_blocks_per_tile > 1: + head_id = head_id.broadcast_to( + (num_tiles_per_partition, num_blocks_per_tile)) + + if block_size_tiling_factor > 1: + broadcast_shape = ( + num_tiles_per_partition, + num_blocks_per_tile, + block_size_tiling_factor, + ) + offset = nisa.iota(nl.arange(block_size_tiling_factor)[None, None, :], + dtype=nl.int32).broadcast_to(broadcast_shape) + + for partition_id in nl.affine_range(num_partitions): + block_tables_partition = block_tables[partition_id] + if num_head > 1: + # fuse num_block and num_head dimension + block_tables_partition = block_tables_partition * num_head + head_id + + if block_size_tiling_factor > 1: + # need to apply block size tiling trick + assert num_blocks_per_tile * block_size_tiling_factor == B_P_SIZE + block_tables_partition = ((block_tables_partition * + block_size_tiling_factor).reshape( + (num_tiles_per_partition, + num_blocks_per_tile, + 1)).broadcast_to(broadcast_shape)) + new_block_tables = block_tables_partition + offset + new_block_tables = new_block_tables.reshape( + (num_tiles_per_partition, B_P_SIZE)) + else: + new_block_tables = block_tables_partition + + # transpose the block table so that it can be used by vector DGE + for i in nl.affine_range(num_loads): + i_p = nl.arange(B_P_SIZE)[:, None] + i_f = (partition_id * num_tiles_per_partition + + nl.arange(num_tiles_per_partition)[None, :]) + block_tables_transposed[i, i_p, i_f] = nl.transpose( + new_block_tables[:, nl.ds(i * B_P_SIZE, B_P_SIZE)]) + return block_tables_transposed + + +@nki.jit +def load_kv_tile_from_cache( + cur_k_tile, + cur_v_tile, + key_cache, + value_cache, + block_tables, + large_k_tile_idx, + num_blocks_per_large_tile, + tiled_block_size, + B_P_SIZE, + B_D_SIZE, +): + """ + Load KV cache and transform Key and Value into layout required by Matmul + + Vectorized DMA Load layout: + Key and Value: (par_dim(B_P_SIZE), seqlen_kv // B_P_SIZE * B_D_SIZE) + + Layout used by attention matmuls: + Key: (par_dim(B_D_SIZE), seqlen_kv) + Value: (seqlen_kv // B_P_SIZE, par_dim(B_P_SIZE), B_D_SIZE) + equivalent to (par_dim(B_P_SIZE), 
seqlen_kv // B_P_SIZE * B_D_SIZE) + """ + # load key cache + num_loads = ceil_div(num_blocks_per_large_tile, B_P_SIZE) + for load_idx in nl.affine_range(num_loads): + i_p = nl.arange(B_P_SIZE)[:, None] + i_f = nl.arange(tiled_block_size * B_D_SIZE)[None, :] + loaded = nl.load(key_cache[block_tables[load_idx, i_p, + large_k_tile_idx], i_f]) + if cur_k_tile.dtype != loaded.dtype: + loaded = nl.copy(loaded, dtype=cur_k_tile.dtype) + # Transpose SBUF tensor using PE + for tb_i in nl.affine_range(tiled_block_size): + cur_k_tile[ + :, + nl.ds( + load_idx * B_P_SIZE * tiled_block_size + tb_i * B_P_SIZE, + B_P_SIZE, + ), + ] = nl.transpose(loaded[:, nl.ds(tb_i * B_D_SIZE, B_D_SIZE)]) + + # load value cache + for load_idx in nl.affine_range(num_loads): + loaded = nl.load(value_cache[block_tables[load_idx, i_p, + large_k_tile_idx], i_f]) + if cur_v_tile.dtype != loaded.dtype: + loaded = nl.copy(loaded, dtype=cur_v_tile.dtype) + i_p = nl.arange(B_P_SIZE)[:, None] + i_f = nl.arange(tiled_block_size * B_D_SIZE)[None, :] + cur_v_tile[ + :, + nl.ds( + load_idx * tiled_block_size * B_D_SIZE, + tiled_block_size * B_D_SIZE, + ), + ] = loaded @nki.jit @@ -62,13 +238,13 @@ def _flash_attention_core( o_buffer, l_buffer, m_buffer, - q_tile_idx, kernel_dtype, acc_type, - flash_config: FlashConfig, - use_causal_mask, tile_mask, + use_causal_mask, + q_tile_idx=None, initialize=False, + LARGE_TILE_SZ=2048, B_P_SIZE=128, B_F_SIZE=512, B_D_SIZE=128, @@ -77,19 +253,19 @@ def _flash_attention_core( """ The flash attention core function to calculate self attention between a tile of q and a block of K and V. - The q_local_tile has (B_P_SIZE, B_F_SIZE), which is loaded into the SBUF - already. The block size of K and V - is defined in the seq_tile_size of the flash_config. The results are stored - in the following three buffers + The q_local_tile has (B_P_SIZE, B_D_SIZE) + The K and V have shape (B_D_SIZE, LARGE_TILE_SZ), whose free dimension will + be split into size B_F_SIZE tiles + + The results are stored in the following three buffers o_buffer: (B_P_SIZE, d) l_buffer: (B_P_SIZE, 1) m_buffer: (B_P_SIZE, 1) + + All IO buffers are in SBUF. 
""" - LARGE_TILE_SZ = flash_config.seq_tile_size num_k_tile_per_large_tile = LARGE_TILE_SZ // B_F_SIZE - # mask are used to only apply computation to the lower half of the matrix, - # which reduce the arithmetic intensity by half qk_res_buf = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ), buffer=nl.sbuf, dtype=acc_type) @@ -99,6 +275,8 @@ def _flash_attention_core( k_i_b_f_slice = nl.ds(k_i * B_F_SIZE, B_F_SIZE) if use_causal_mask: + # mask are used to only apply computation to the lower half of the + # matrix, which reduce the arithmetic intensity by up to 50% multiplication_required_selection = (q_tile_idx * B_P_SIZE >= k_i * B_F_SIZE) else: @@ -165,7 +343,9 @@ def _flash_attention_core( REDUCTION_TILE = min(2048, LARGE_TILE_SZ // 2) p_partial_sum = nl.ndarray( - (par_dim(B_P_SIZE), LARGE_TILE_SZ // REDUCTION_TILE), dtype=acc_type) + (par_dim(B_P_SIZE), LARGE_TILE_SZ // REDUCTION_TILE), + dtype=acc_type, + ) for k_r_i in nl.affine_range(LARGE_TILE_SZ // REDUCTION_TILE): k_r_i_reduce_slice = nl.ds(k_r_i * REDUCTION_TILE, REDUCTION_TILE) @@ -194,13 +374,15 @@ def _flash_attention_core( B_F_SIZE=B_F_SIZE, ) - pv_psum = nl.zeros((par_dim(B_P_SIZE), B_D_SIZE), - dtype=np.float32, - buffer=nl.psum) + pv_psum = nl.zeros( + (par_dim(B_P_SIZE), B_D_SIZE), + dtype=np.float32, + buffer=nl.psum, + ) for k_i in nl.affine_range(LARGE_TILE_SZ // B_P_SIZE): pv_psum[:, :] += nl.matmul( p_local_transposed[:, nl.ds(k_i * B_P_SIZE, B_P_SIZE)], - v[k_i, :, :], + v[:, nl.ds(k_i * B_D_SIZE, B_D_SIZE)], transpose_x=True, ) # (128, 128) (p(Br), d) @@ -219,44 +401,16 @@ def _flash_attention_core( @nki.jit -def load_v_tile(v_hbm_tile, cur_v_tile, j, v_i, config): - LARGE_TILE_SZ = config.seq_tile_size +def load_v_tile(v_hbm_tile, cur_v_tile, large_tile_idx, v_i, LARGE_TILE_SZ): B_P_SIZE = 128 - - if not config.should_transpose_v: - cur_v_tile[v_i, :, :] = nl.load( - v_hbm_tile[nl.ds(j * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE), :], - dtype=cur_v_tile.dtype, - ) - return - - if nisa.get_nc_version() == nisa.nc_version.gen3: - cur_v_tile_transposed = nisa.dma_transpose( - v_hbm_tile[:, - nl.ds(j * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE)]) - cur_v_tile[v_i, :, :] = nisa.tensor_copy(cur_v_tile_transposed, - dtype=cur_v_tile.dtype) - return - - cur_v_tile[v_i, :, :] = nl.load_transpose2d( - v_hbm_tile[:, nl.ds(j * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE)], - dtype=cur_v_tile.dtype, - ) - - -@nki.jit -def load_block_tables(block_tables_hbm, num_tiles): - (num_blocks, ) = block_tables_hbm.shape - assert num_blocks % num_tiles == 0 - num_blocks_per_tile = num_blocks // num_tiles - block_tables_hbm = block_tables_hbm.reshape( - (num_tiles, num_blocks_per_tile)) - block_tables_buffer = nl.load(block_tables_hbm, dtype=nl.int32) - return block_tables_buffer - - -def is_power_of_2(x): - return x > 0 and (x & (x - 1)) == 0 + B_D_SIZE = v_hbm_tile.shape[-1] + loaded = nl.load(v_hbm_tile[ + nl.ds(large_tile_idx * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE), + :, + ]) + if cur_v_tile.dtype != loaded.dtype: + loaded = nl.copy(loaded, dtype=cur_v_tile.dtype) + cur_v_tile[:, nl.ds(v_i * B_D_SIZE, B_D_SIZE)] = loaded @nki.jit @@ -270,24 +424,21 @@ def flash_paged_attention( mask, softmax_scale=None, mixed_precision=True, - config=None, + LARGE_TILE_SZ=2048, return_debug_tensors=False, ): """ Flash PagedAttention Forward Kernel. 
- - PagedAttention Paper: https://arxiv.org/abs/2309.06180 - - Chunked Prefill Paper: https://arxiv.org/abs/2403.02310 IO tensor layouts: - query: shape (1, n_heads, d, seq_q) - key: shape (1, n_kv_heads, d, seq_k) - value: shape (1, n_kv_heads, seq_v, d) - - key_cache: (num_blocks, block_size, n_kv_heads, d) - - value_cache: (num_blocks, block_size, n_kv_heads, d) + - key_cache: (num_blocks, n_kv_heads, block_size, d) + - value_cache: (num_blocks, n_kv_heads, block_size, d) - block_tables: (num_active_blocks, ) - - mask: (seq_q, num_active_blocks * block_size) + - mask: (seq_q, num_active_blocks * block_size + seq_q) - o: shape (1, n_heads, seq_q, d) - - l_m: shape (1, n_heads, seq_q, 2) - This kernel requires seq_k == seq_v - We use continuous batching by default, so the batch dimension is @@ -306,11 +457,8 @@ def flash_paged_attention( - softmax_scale: scaling for softmax, is None, default is `1.0/(d**0.5)` - mixed_precision: flag to set non-matmul ops in fp32 precision, default is set to `true`, if false, we use same precision as input types - - config: Instance of dataclass :class:`nki.kernels.attention.FlashConfig` - with Performance config parameters for flash attention with default - values - seq_tile_size: `default=2048`, size of the kv tile size for attention - computation reduction + - LARGE_TILE_SZ: `default=2048`, size of the kv tile size for attention + computation reduction GQA support Notes: the spmd kernel for launching kernel should be on kv_heads instead of @@ -322,31 +470,65 @@ def flash_paged_attention( GQA: q: [b, h, d, s], k: [b, kv_h, d, s], v: [b, kv_h, s, d] usage: `flash_fwd[b, kv_h](q, k, v, ...)` """ - config = config or FlashConfig() B_F_SIZE = 512 B_P_SIZE = 128 b, h, d, seqlen_q = query.shape B_D_SIZE = d - LARGE_TILE_SZ = config.seq_tile_size n_tile_q = seqlen_q // B_P_SIZE # since q will be loaded on tensor engine - num_blocks, block_size, k_h, _ = key_cache.shape + num_blocks, k_h, block_size, _ = key_cache.shape q_h_per_k_h = h // k_h - assert tuple(key_cache.shape) == ( - num_blocks, - block_size, - k_h, - d, - ), "Input shape mismatch!" - assert tuple(value_cache.shape) == ( - num_blocks, - block_size, - k_h, - d, - ), "Input shape mismatch!" assert b == 1, f"invalid batch size {b=}" - assert d <= 128, f" we do not support head_dim > 128, got head dim {d}" + assert d <= 128, f" we do not support head_dim > 128, got head dim {d=}" + cache_shape = (num_blocks, k_h, block_size, d) + assert (tuple(key_cache.shape) == cache_shape + ), f"{key_cache.shape=} mismatch, expect {cache_shape}" + assert (tuple(value_cache.shape) == cache_shape + ), f"{value_cache.shape=} mismatch, expect {cache_shape}" + assert key is None or tuple(key.shape) == ( + 1, + k_h, + d, + seqlen_q, + ), f"key shape {key.shape} mismatch!" + assert value is None or tuple(value.shape) == ( + 1, + k_h, + seqlen_q, + d, + ), f"value shape {value.shape} mismatch!" + + assert ( + nl.program_ndim() == 2 + ), f"Expect spmd grid with 2 dimensions, got {nl.program_ndim()} instead!" 
+ batch_id = nl.program_id(axis=0) + head_id = nl.program_id(axis=1) + + (num_active_blocks, ) = block_tables.shape + context_kv_len = num_active_blocks * block_size + assert ( + LARGE_TILE_SZ % B_F_SIZE == 0 + ), f"Need {LARGE_TILE_SZ=} to be divisible by {B_F_SIZE=} in transpose_p" + assert (context_kv_len % LARGE_TILE_SZ == 0 + ), f"Need {context_kv_len=} to be divisible by {LARGE_TILE_SZ=}" + + num_blocks_per_large_tile = LARGE_TILE_SZ // block_size + assert is_power_of_2( + num_blocks_per_large_tile + ), f"{num_blocks_per_large_tile=} is expected of be power of 2" + if seqlen_q > B_F_SIZE: + MAX_REDUCTION_TILE = 2048 + if seqlen_q // 2 > MAX_REDUCTION_TILE: + assert ( + seqlen_q % MAX_REDUCTION_TILE == 0 + ), f"{seqlen_q=} should be divisible by {MAX_REDUCTION_TILE=}" + else: + assert (seqlen_q % B_F_SIZE == 0 + ), f"{seqlen_q=} should be divisible by {B_F_SIZE=})" + kernel_dtype = nl.bfloat16 if mixed_precision else query.dtype acc_type = np.dtype(np.float32) if mixed_precision else kernel_dtype + softmax_scale = softmax_scale or (1.0 / (d**0.5)) + num_large_k_tile = context_kv_len // LARGE_TILE_SZ o = nl.ndarray((b, h, seqlen_q, d), dtype=query.dtype, @@ -373,35 +555,38 @@ def flash_paged_attention( buffer=nl.sbuf, lazy_initialization=True, ) + block_tables_sbuf = load_block_tables( + block_tables_hbm=block_tables, + num_tiles=num_large_k_tile, + num_blocks_per_tile=num_blocks_per_large_tile, + ) - assert ( - nl.program_ndim() == 2 - ), f"Expect spmd grid with 2 dimensions, got {nl.program_ndim()} instead!" - batch_id = nl.program_id(axis=0) - head_id = nl.program_id(axis=1) + # On Neuron, we need B_P_SIZE = 128 blocks to make DMA efficient + if num_blocks_per_large_tile < B_P_SIZE: + # we checked num_blocks_per_tile is a power of 2 + assert B_P_SIZE % num_blocks_per_large_tile == 0 + block_size_tiling_factor = B_P_SIZE // num_blocks_per_large_tile + # We assume block_size >= block_size_tiling_factor + assert block_size % block_size_tiling_factor == 0 + else: + block_size_tiling_factor = 1 + tiled_block_size = block_size // block_size_tiling_factor - softmax_scale = softmax_scale or (1.0 / (d**0.5)) + # Indirect DMA load must be placed along Partition Dimension + block_tables_sbuf = transform_block_tables_for_indirect_load( + block_tables_sbuf, + block_size_tiling_factor=block_size_tiling_factor, + num_head=k_h, + head_id=head_id, + ) - (num_active_blocks, ) = block_tables.shape - context_kv_len = num_active_blocks * block_size - assert (config.seq_tile_size >= 512 - ), f" seq tile_size {config.seq_tile_size} cannot be less than 512" - assert (context_kv_len % LARGE_TILE_SZ == 0 - ), f"Need {context_kv_len=} to be divisible by {LARGE_TILE_SZ=}" - assert ( - LARGE_TILE_SZ % B_P_SIZE == 0 - ), f"Need LARGE_TILE_SZ ({LARGE_TILE_SZ}) to be divisible by {B_P_SIZE=}" - assert (B_P_SIZE % block_size == 0 - ), f"Need B_P_SIZE ({B_P_SIZE}) to be divisible by {block_size=}" - num_large_k_tile = context_kv_len // LARGE_TILE_SZ - num_blocks_per_large_tile = LARGE_TILE_SZ // block_size - assert block_size % 32 == 0, "block_size is expected to be a multiple of 32" - assert is_power_of_2( - num_blocks_per_large_tile - ), "The number of blocks in each large tile is expected of be power of 2" - assert is_power_of_2(seqlen_q), "seqlen_q is expected to be power of 2" - - block_tables_sbuf = load_block_tables(block_tables, num_large_k_tile) + # Flatten KV cache to be 2D for loading into SBUF + new_cache_shape = ( + num_blocks * k_h * block_size_tiling_factor, + tiled_block_size * d, + ) + key_cache = 
key_cache.reshape(new_cache_shape) + value_cache = value_cache.reshape(new_cache_shape) # Global Flash Attention accumulators o_buffer = nl.zeros( @@ -411,7 +596,7 @@ def flash_paged_attention( lazy_initialization=True, ) l_buffer = nl.zeros( - (par_dim(B_P_SIZE), n_tile_q, q_h_per_k_h), + (n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), 1), dtype=acc_type, buffer=nl.sbuf, lazy_initialization=True, @@ -423,50 +608,42 @@ def flash_paged_attention( lazy_initialization=True, ) - for j in nl.sequential_range(0, num_large_k_tile): - cur_k_tile = nl.ndarray((par_dim(B_D_SIZE), LARGE_TILE_SZ), - dtype=kernel_dtype) - cur_v_tile = nl.ndarray( - (LARGE_TILE_SZ // B_P_SIZE, par_dim(B_P_SIZE), B_D_SIZE), + for large_k_tile_idx in nl.sequential_range(0, num_large_k_tile): + num_loads = ceil_div(num_blocks_per_large_tile, B_P_SIZE) + cur_k_tile = nl.ndarray( + (par_dim(B_D_SIZE), LARGE_TILE_SZ), dtype=kernel_dtype, ) - - for k_i in nl.affine_range(num_blocks_per_large_tile): - loaded = nl.load(key_cache[block_tables_sbuf[j, k_i], :, - head_id, :]) - cur_k_tile[:, nl.ds(k_i * - block_size, block_size)] = nl.transpose(loaded) - - load_tile_size = B_P_SIZE - num_blocks_per_partition = load_tile_size // block_size - for partition_idx in nl.affine_range(LARGE_TILE_SZ // load_tile_size): - for block_in_partition in nl.affine_range( - num_blocks_per_partition): - v_i = (partition_idx * num_blocks_per_partition + - block_in_partition) - loaded_v = nl.load(value_cache[block_tables_sbuf[j, v_i], :, - head_id, :]) - cur_v_tile[ - partition_idx, - nl.ds(block_in_partition * block_size, block_size), - :, - ] = loaded_v + cur_v_tile = nl.ndarray( + (par_dim(B_P_SIZE), num_loads * tiled_block_size * B_D_SIZE), + dtype=kernel_dtype, + ) + load_kv_tile_from_cache( + cur_k_tile=cur_k_tile, + cur_v_tile=cur_v_tile, + key_cache=key_cache, + value_cache=value_cache, + block_tables=block_tables_sbuf, + large_k_tile_idx=large_k_tile_idx, + num_blocks_per_large_tile=num_blocks_per_large_tile, + tiled_block_size=tiled_block_size, + B_P_SIZE=B_P_SIZE, + B_D_SIZE=B_D_SIZE, + ) for i in nl.affine_range(n_tile_q): - cur_mask = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ), - dtype=mask.dtype) - for m_i in nl.affine_range(LARGE_TILE_SZ // B_F_SIZE): - cur_mask[:, nl.ds(m_i * B_F_SIZE, B_F_SIZE)] = nl.load(mask[ - nl.ds(i * B_P_SIZE, B_P_SIZE), - nl.ds(j * LARGE_TILE_SZ + m_i * B_F_SIZE, B_F_SIZE), - ]) + cur_mask = nl.load(mask[ + nl.ds(i * B_P_SIZE, B_P_SIZE), + nl.ds(large_k_tile_idx * LARGE_TILE_SZ, LARGE_TILE_SZ), + ]) for i_q_h in nl.affine_range(q_h_per_k_h): q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype) q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h] - q_sbuf_tile = nl.load( - q_hbm_tile[:, nl.ds(i * B_P_SIZE, B_P_SIZE)], - dtype=kernel_dtype, - ) # load (d, 128) tile in SBUF + q_sbuf_tile = nl.load(q_hbm_tile[:, + nl.ds(i * + B_P_SIZE, B_P_SIZE)]) + if q_sbuf_tile.dtype != kernel_dtype: + q_sbuf_tile = nl.copy(q_sbuf_tile, dtype=kernel_dtype) q_tile[:, :] = q_sbuf_tile * softmax_scale _flash_attention_core( @@ -474,15 +651,15 @@ def flash_paged_attention( k=cur_k_tile, v=cur_v_tile, o_buffer=o_buffer[i, i_q_h], - l_buffer=l_buffer[:, i, i_q_h], + l_buffer=l_buffer[i, i_q_h], m_buffer=m_buffer[i, i_q_h], - q_tile_idx=i, kernel_dtype=kernel_dtype, acc_type=acc_type, - flash_config=config, - use_causal_mask=False, tile_mask=cur_mask, - initialize=j == 0, + use_causal_mask=False, + q_tile_idx=i, + initialize=large_k_tile_idx == 0, + LARGE_TILE_SZ=LARGE_TILE_SZ, B_P_SIZE=B_P_SIZE, B_F_SIZE=B_F_SIZE, 
B_D_SIZE=B_D_SIZE, @@ -492,62 +669,58 @@ def flash_paged_attention( if key is not None and value is not None: B_F_SIZE = min(seqlen_q, B_F_SIZE) LARGE_TILE_SZ = seqlen_q - active_config = FlashConfig( - seq_tile_size=LARGE_TILE_SZ, - should_transpose_v=config.should_transpose_v, - ) cur_k_tile = nl.ndarray((par_dim(B_D_SIZE), LARGE_TILE_SZ), dtype=kernel_dtype) cur_v_tile = nl.ndarray( - (LARGE_TILE_SZ // B_P_SIZE, par_dim(B_P_SIZE), B_D_SIZE), + (par_dim(B_P_SIZE), LARGE_TILE_SZ // B_P_SIZE * B_D_SIZE), dtype=kernel_dtype, ) - cur_k_tile[:, :] = nl.load(key[batch_id, head_id, :, :]) + loaded = nl.load(key[batch_id, head_id, :, :]) + if loaded.dtype != kernel_dtype: + loaded = nl.copy(loaded, dtype=kernel_dtype) + cur_k_tile[:, :] = loaded - load_tile_size = B_P_SIZE v_hbm_tile = value[batch_id, head_id] - for v_i in nl.affine_range(LARGE_TILE_SZ // load_tile_size): + for v_i in nl.affine_range(LARGE_TILE_SZ // B_P_SIZE): load_v_tile( v_hbm_tile=v_hbm_tile, cur_v_tile=cur_v_tile, - j=0, + large_tile_idx=0, v_i=v_i, - config=active_config, + LARGE_TILE_SZ=LARGE_TILE_SZ, ) for i in nl.affine_range(n_tile_q): - cur_mask = nl.load( - mask[ - nl.ds(i * B_P_SIZE, B_P_SIZE), - nl.ds(context_kv_len, LARGE_TILE_SZ), - ], - dtype=mask.dtype, - ) + cur_mask = nl.load(mask[ + nl.ds(i * B_P_SIZE, B_P_SIZE), + nl.ds(context_kv_len, LARGE_TILE_SZ), + ]) for i_q_h in nl.affine_range(q_h_per_k_h): q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype) q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h] - q_sbuf_tile = nl.load( - q_hbm_tile[:, nl.ds(i * B_P_SIZE, B_P_SIZE)], - dtype=kernel_dtype, - ) # load (d, 128) tile in SBUF + q_sbuf_tile = nl.load(q_hbm_tile[:, + nl.ds(i * + B_P_SIZE, B_P_SIZE)]) + if q_sbuf_tile.dtype != kernel_dtype: + q_sbuf_tile = nl.copy(q_sbuf_tile, dtype=kernel_dtype) q_tile[:, :] = q_sbuf_tile * softmax_scale _flash_attention_core( q_local_tile=q_tile, k=cur_k_tile, v=cur_v_tile, o_buffer=o_buffer[i, i_q_h], - l_buffer=l_buffer[:, i, i_q_h], + l_buffer=l_buffer[i, i_q_h], m_buffer=m_buffer[i, i_q_h], - q_tile_idx=i, kernel_dtype=kernel_dtype, acc_type=acc_type, - flash_config=active_config, - use_causal_mask=True, tile_mask=cur_mask, + use_causal_mask=True, + q_tile_idx=i, initialize=False, + LARGE_TILE_SZ=LARGE_TILE_SZ, B_P_SIZE=B_P_SIZE, B_F_SIZE=B_F_SIZE, B_D_SIZE=B_D_SIZE, @@ -559,8 +732,8 @@ def flash_paged_attention( for i_q_h in nl.affine_range(q_h_per_k_h): for i in nl.affine_range(n_tile_q): out = nl.multiply( - o_buffer[i, i_q_h, :, :], - nl.exp(m_buffer[i, i_q_h, :, :] - l_buffer[:, i, i_q_h]), + o_buffer[i, i_q_h], + nl.exp(m_buffer[i, i_q_h] - l_buffer[i, i_q_h]), dtype=kernel_dtype, ) @@ -589,7 +762,7 @@ def flash_paged_attention( head_id * q_h_per_k_h + i_q_h, nl.ds(i * B_P_SIZE, B_P_SIZE), ], - l_buffer[:, i, i_q_h], + l_buffer[i, i_q_h], ) nl.store( hbm_qk_res[batch_id, head_id * q_h_per_k_h + i_q_h, :, :], @@ -601,6 +774,49 @@ def flash_paged_attention( return o +def reorder_context_mask(mask, LARGE_TILE_SZ, block_size): + """ + Reorder the mask to make it compatible with the flash attention kernel. + + We vectorize KV cache read to improve DMA utilization. However, the layout + that maximizes DMA bandwidth changes the order tokens are consumed. + + The token layout (inner 2 dimensions) after vectorized load is (B_P_SIZE, + tiled_block_size) in a tile of `B_P_SIZE * tiled_block_size` tokens. And + each step the engine consumes a column (rather than a row) of B_P_SIZE + tokens. Therefore, the tokens are visited in a strided way. 
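+    For example, with LARGE_TILE_SZ=2048 and block_size=32, tiled_block_size is 16: each large tile is laid out as a (B_P_SIZE=128, 16) grid, so one step consumes tokens 0, 16, 32, ..., 2032 of that tile instead of a contiguous run of 128 tokens.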
+ + To make sure mask matches the order tokens are consumed, we need to properly + transpose mask. + """ + total_query_len, total_seq_len = mask.shape + context_kv_len = total_seq_len - total_query_len + + B_P_SIZE = 128 + assert (LARGE_TILE_SZ + >= B_P_SIZE), f"{LARGE_TILE_SZ=} must be larger than {B_P_SIZE=}" + num_tiled_blocks = max(B_P_SIZE, LARGE_TILE_SZ // block_size) + tiled_block_size = LARGE_TILE_SZ // num_tiled_blocks + if tiled_block_size > 1: + # Mask reordering is needed when tiled_block_size > 1 + device = mask.device + mask = mask.cpu() + context_mask = mask[:, :context_kv_len] + context_mask = context_mask.view( + total_query_len, + context_kv_len // LARGE_TILE_SZ, + num_tiled_blocks // B_P_SIZE, + B_P_SIZE, + tiled_block_size, + ) + context_mask = context_mask.transpose(3, 4).reshape( + total_query_len, context_kv_len) + new_mask = mask[:, context_kv_len:] + return torch.concat([context_mask, new_mask], dim=1).to(device) + else: + return mask + + def flash_attn_varlen_nkifunc( query, key, @@ -612,13 +828,32 @@ def flash_attn_varlen_nkifunc( n_kv_head=None, head_size=None, LARGE_TILE_SZ=2048, - return_debug_tensors=False, mixed_precision=True, ): - config = FlashConfig( - seq_tile_size=LARGE_TILE_SZ, - should_transpose_v=False, - ) + """ + Compute flash paged attention for variable length sequences. + + This function is a wrapper around the flash attention NKI kernel. It takes + in the following arguments: + - query: (1, n_heads, d, seq_q) + - key: (1, n_kv_heads, d, seq_k) + - value: (1, n_kv_heads, seq_v, d) + - key_cache: (n_blocks, n_kv_heads, block_size, d) + - value_cache: (n_blocks, n_kv_heads, block_size, d) + - block_tables: (n_active_blocks, ) + - attn_mask: (seq_q, n_active_blocks * block_size + seq_q) + + Notes: + - attn_mask must be reordered outside using `reorder_context_mask` + - Key/value cache layout must be (n_blocks, n_kv_heads, block_size, d) + for better DMA throughput + """ + if n_kv_head is None: + n_kv_head = key_cache.shape[1] + assert key_cache.shape[1] == n_kv_head + if head_size is None: + head_size = key_cache.shape[-1] + kwargs = dict( query=query, key=key, @@ -628,15 +863,9 @@ def flash_attn_varlen_nkifunc( block_tables=block_table, mask=attn_mask, softmax_scale=1.0 / (head_size**0.5), - config=config, mixed_precision=mixed_precision, - return_debug_tensors=return_debug_tensors, + LARGE_TILE_SZ=LARGE_TILE_SZ, ) - _, n_kv_head, _, _ = key.shape - if return_debug_tensors: - o, *debug_tensors = flash_paged_attention[1, n_kv_head](**kwargs) - return o, *debug_tensors - else: - o = flash_paged_attention[1, n_kv_head](**kwargs) - return o + o = flash_paged_attention[1, n_kv_head](**kwargs) + return o
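
For reference, the sketch below shows how the reworked API is driven end to end, mirroring the updated test: the KV cache uses the new (num_blocks, n_kv_heads, block_size, d) layout, the mask is reordered with reorder_context_mask before the call, and LARGE_TILE_SZ is passed directly instead of a FlashConfig. This is not part of the patch; shapes are illustrative only (chosen to satisfy the divisibility asserts), the zero/all-True tensors are placeholders for real inputs, and a Neuron device with torch_xla is assumed.

import torch
import torch_xla.core.xla_model as xm

from vllm.attention.ops.nki_flash_attn import (flash_attn_varlen_nkifunc,
                                               reorder_context_mask)

device = xm.xla_device()

# Illustrative sizes: LARGE_TILE_SZ % block_size == 0 and
# context_kv_len % LARGE_TILE_SZ == 0, as the kernel asserts.
n_heads, n_kv_heads, d = 4, 2, 128
block_size, LARGE_TILE_SZ = 32, 2048
seq_q = 128                              # padded query length, multiple of 128
num_active_blocks = 128                  # context_kv_len = 4096 = 2 * LARGE_TILE_SZ
context_kv_len = num_active_blocks * block_size
num_cache_blocks = 1024

# IO layouts documented in the kernel docstring.
query = torch.zeros(1, n_heads, d, seq_q)
key = torch.zeros(1, n_kv_heads, d, seq_q)       # active (new) keys
value = torch.zeros(1, n_kv_heads, seq_q, d)     # active (new) values
key_cache = torch.zeros(num_cache_blocks, n_kv_heads, block_size, d)
value_cache = torch.zeros(num_cache_blocks, n_kv_heads, block_size, d)
block_table = torch.randint(0, num_cache_blocks, (num_active_blocks, ),
                            dtype=torch.int32)

# Placeholder mask; a real one is built from the query/sequence lengths as in
# the test, then reordered to match the vectorized KV-cache load order.
attn_mask = torch.ones(seq_q, context_kv_len + seq_q, dtype=torch.bool)
attn_mask = reorder_context_mask(attn_mask, LARGE_TILE_SZ, block_size)

out = flash_attn_varlen_nkifunc(
    query.to(device), key.to(device), value.to(device),
    key_cache.to(device), value_cache.to(device),
    block_table.to(device), attn_mask.to(device),
    n_kv_head=n_kv_heads, head_size=d,
    mixed_precision=True, LARGE_TILE_SZ=LARGE_TILE_SZ,
)
# Output is (1, n_heads, seq_q, d); permute to (1, seq_q, n_heads, d) the same
# way the test does before comparing against the reference.
out = out.cpu().permute(0, 2, 1, 3)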