From b7c0942b65380ab8c53ecf2657121e1c21150672 Mon Sep 17 00:00:00 2001 From: Charlie Fu Date: Sat, 9 Aug 2025 01:15:06 -0500 Subject: [PATCH] [ROCm][Misc] Rename the context_len to seq_len in ROCm custom paged attention kernel (#22097) Signed-off-by: charlifu --- csrc/rocm/attention.cu | 179 +++++++++++++++++------------------ csrc/rocm/ops.h | 4 +- csrc/rocm/torch_bindings.cpp | 4 +- 3 files changed, 91 insertions(+), 96 deletions(-) diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index 65cb1c1d1478d..e3a0e15f5304f 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -270,7 +270,7 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel( const int num_kv_heads, const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ seq_lens, // [num_seqs] const int* __restrict__ query_start_loc_ptr, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] @@ -304,12 +304,12 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel( const auto max_num_partitions = gridDim.y; - const int context_len = context_lens[seq_idx]; + const int seq_len = seq_lens[seq_idx]; const int partition_start_token_idx = partition_idx * T_PAR_SIZE; // partition_size; // exit if partition is out of context for seq - if (partition_start_token_idx >= context_len) { + if (partition_start_token_idx >= seq_len) { return; } @@ -361,8 +361,8 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel( // output layout from QKmfma : QH16xT4x4 16 qheads across 16 lanes, 16 tokens // across 4 rows x 4 tokens per lane - const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); - const int last_ctx_block = num_context_blocks - 1; + const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE); + const int last_seq_block = num_seq_blocks - 1; const int* block_table_seq = block_tables + seq_idx * max_num_blocks_per_seq; @@ -373,9 +373,9 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel( const int klocal_token_idx = TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id; const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx; - const int kblock_idx = (kglobal_token_idx < context_len) + const int kblock_idx = (kglobal_token_idx < seq_len) ? kglobal_token_idx / BLOCK_SIZE - : last_ctx_block; + : last_seq_block; kphysical_block_number[token_depth] = block_table_seq[kblock_idx]; } @@ -476,9 +476,9 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel( // tokens const int vglobal_token_idx = partition_start_token_idx + vlocal_token_idx; - const int vblock_idx = (vglobal_token_idx < context_len) + const int vblock_idx = (vglobal_token_idx < seq_len) ? vglobal_token_idx / BLOCK_SIZE - : last_ctx_block; + : last_seq_block; vphysical_block_number[vtoken_depth][vblock_depth] = block_table_seq[vblock_idx]; } @@ -554,7 +554,7 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel( if constexpr (ALIBI_ENABLED) { for (int token_depth = 0; token_depth < TLOOP; token_depth++) { const int local_token_idx = qkout_token_idx + token_depth * 16; - const int alibi_offset = local_token_idx - context_len + 1; + const int alibi_offset = local_token_idx - seq_len + 1; for (int i = 0; i < 4; i++) { d_out[token_depth][i] += alibi_slope * (alibi_offset + i); } @@ -568,9 +568,8 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel( for (int token_depth = 0; token_depth < TLOOP; token_depth++) { const int local_token_idx = qkout_token_idx + token_depth * 16; for (int i = 0; i < 4; i++) { - const float tmp = (local_token_idx + i < context_len) - ? d_out[token_depth][i] - : -FLT_MAX; + const float tmp = + (local_token_idx + i < seq_len) ? d_out[token_depth][i] : -FLT_MAX; qk_max = fmaxf(qk_max, tmp); } } @@ -582,7 +581,7 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel( for (int token_depth = 0; token_depth < TLOOP; token_depth++) { const int local_token_idx = qkout_token_idx + token_depth * 16; for (int i = 0; i < 4; i++) { - const float tmp = (local_token_idx + i < context_len) + const float tmp = (local_token_idx + i < seq_len) ? __expf(d_out[token_depth][i] - qk_max) : 0.0f; d_out[token_depth][i] = tmp; @@ -780,7 +779,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( const int num_kv_heads, const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ seq_lens, // [num_seqs] const int* __restrict__ query_start_loc_ptr, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] @@ -809,10 +808,10 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( const auto partition_size = blockDim.x; const auto max_num_partitions = gridDim.y; - const int context_len = context_lens[seq_idx]; + const int seq_len = seq_lens[seq_idx]; const int partition_start_token_idx = partition_idx * partition_size; // exit if partition is out of context for seq - if (partition_start_token_idx >= context_len) { + if (partition_start_token_idx >= seq_len) { return; } // every 4 lanes fetch 4 different qheads @@ -855,7 +854,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( const int warp_start_token_idx = partition_start_token_idx + warpid * WARP_SIZE; - if (warp_start_token_idx >= context_len) { // warp out of context + if (warp_start_token_idx >= seq_len) { // warp out of context #pragma unroll for (int h = 0; h < GQA_RATIO4; h++) { shared_qk_max[warpid][h] = -FLT_MAX; @@ -863,8 +862,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( } } else { // warp within context - const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); - const int last_ctx_block = num_context_blocks - 1; + const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE); + const int last_seq_block = num_seq_blocks - 1; const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; // token id within partition @@ -873,9 +872,9 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( const int global_token_idx = partition_start_token_idx + local_token_idx; // fetch block number for k - const int block_idx = (global_token_idx < context_len) + const int block_idx = (global_token_idx < seq_len) ? global_token_idx / BLOCK_SIZE - : last_ctx_block; + : last_seq_block; // fetch k physical block number // int32 physical_block_number leads to overflow when multiplied with @@ -888,7 +887,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( for (int b = 0; b < VBLOCKS; b++) { const int vblock_idx = warp_start_block_idx + b; const int vblock_idx_ctx = - (vblock_idx <= last_ctx_block) ? vblock_idx : last_ctx_block; + (vblock_idx <= last_seq_block) ? vblock_idx : last_seq_block; vphysical_blocks[b] = block_table[vblock_idx_ctx]; } @@ -1057,7 +1056,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( const int lane4_token_idx = 4 * (global_token_idx >> 2); if constexpr (ALIBI_ENABLED) { - const int alibi_offset = lane4_token_idx - context_len + 1; + const int alibi_offset = lane4_token_idx - seq_len + 1; for (int h = 0; h < QHLOOP; h++) { for (int i = 0; i < 4; i++) { d_out[h][i] += alibi_slope[h] * (alibi_offset + i); @@ -1070,7 +1069,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( for (int h = 0; h < QHLOOP; h++) { qk_max[h] = -FLT_MAX; for (int i = 0; i < 4; i++) { - qk_max[h] = (lane4_token_idx + i < context_len) + qk_max[h] = (lane4_token_idx + i < seq_len) ? fmaxf(qk_max[h], d_out[h][i]) : qk_max[h]; } @@ -1101,7 +1100,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( for (int h = 0; h < QHLOOP; h++) { exp_sum[h] = 0.0f; for (int i = 0; i < 4; i++) { - d_out[h][i] = (lane4_token_idx + i < context_len) + d_out[h][i] = (lane4_token_idx + i < seq_len) ? __expf(d_out[h][i] - qk_max[h]) : 0.0f; exp_sum[h] += d_out[h][i]; @@ -1181,7 +1180,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( } } - if (warp_start_token_idx >= context_len) { // warp out of context + if (warp_start_token_idx >= seq_len) { // warp out of context for (int qh = 0; qh < QHLOOP; qh++) { for (int vh = 0; vh < VHELOOP; vh++) { vout_shared[qh][vh][laneid][warpid] = {0}; @@ -1279,7 +1278,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( // max_num_partitions] const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, // max_num_partitions, head_size] - const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ seq_lens, // [num_seqs] const int* __restrict__ query_start_loc_ptr, // [num_seqs] const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) { const auto num_heads = gridDim.x; @@ -1293,8 +1292,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( return; } - const int context_len = context_lens[seq_idx]; - const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); + const int seq_len = seq_lens[seq_idx]; + const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE); const auto warpid = threadIdx.x / WARP_SIZE; __shared__ float shared_global_exp_sum; @@ -1581,7 +1580,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( // head_size, block_size] const int num_kv_heads, const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ seq_lens, // [num_seqs] const int* __restrict__ query_start_loc_ptr, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] @@ -1615,11 +1614,11 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( const int max_num_partitions = gridDim.y; - const int context_len = context_lens[seq_idx]; // length of a seq + const int seq_len = seq_lens[seq_idx]; // length of a seq const int partition_start_token_idx = partition_idx * T_PAR_SIZE; // exit if partition is out of context for seq - if (partition_start_token_idx >= context_len) { + if (partition_start_token_idx >= seq_len) { return; } @@ -1715,8 +1714,8 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( } } - const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); - const int last_ctx_block = num_context_blocks - 1; + const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE); + const int last_seq_block = num_seq_blocks - 1; const int* block_table_seq = block_tables + seq_idx * max_num_blocks_per_seq; @@ -1727,9 +1726,9 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( const int klocal_token_idx = TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id; const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx; - const int kblock_idx = (kglobal_token_idx < context_len) + const int kblock_idx = (kglobal_token_idx < seq_len) ? kglobal_token_idx / BLOCK_SIZE - : last_ctx_block; + : last_seq_block; kphysical_block_number[token_depth] = block_table_seq[kblock_idx]; } @@ -1781,9 +1780,9 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( vblock_depth * BLOCK_SIZE; const int vglobal_token_idx = partition_start_token_idx + vlocal_token_idx; - const int vblock_idx = (vglobal_token_idx < context_len) + const int vblock_idx = (vglobal_token_idx < seq_len) ? vglobal_token_idx / BLOCK_SIZE - : last_ctx_block; + : last_seq_block; vphysical_block_number[vtoken_depth][vblock_depth] = block_table_seq[vblock_idx]; } @@ -1836,9 +1835,8 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( for (int token_depth = 0; token_depth < TLOOP; token_depth++) { const int local_token_idx = qkout_token_idx + token_depth * 16; for (int i = 0; i < 8; i++) { - const float tmp = (local_token_idx + 2 * i < context_len) - ? dout[token_depth][i] - : -FLT_MAX; + const float tmp = + (local_token_idx + 2 * i < seq_len) ? dout[token_depth][i] : -FLT_MAX; qk_max = fmaxf(qk_max, tmp); } } @@ -1848,7 +1846,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( for (int token_depth = 0; token_depth < TLOOP; token_depth++) { const int local_token_idx = qkout_token_idx + token_depth * 16; for (int i = 0; i < 8; i++) { - const float tmp = (local_token_idx + 2 * i < context_len) + const float tmp = (local_token_idx + 2 * i < seq_len) ? __expf(dout[token_depth][i] - qk_max) : 0.0f; dout[token_depth][i] = tmp; @@ -2019,7 +2017,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( // head_size, block_size] const int num_kv_heads, const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ seq_lens, // [num_seqs] const int* __restrict__ query_start_loc_ptr, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] @@ -2046,7 +2044,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( // max_num_partitions] const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, // max_num_partitions, head_size] - const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ seq_lens, // [num_seqs] const int* __restrict__ query_start_loc_ptr, // [num_seqs] const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) { const auto num_heads = gridDim.x; @@ -2060,8 +2058,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( return; } - const int context_len = context_lens[seq_idx]; - const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); + const int seq_len = seq_lens[seq_idx]; + const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE); const int warpid = threadIdx.x / WARP_SIZE; __shared__ float shared_global_exp_sum; @@ -2349,7 +2347,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( // head_size, block_size] const int num_kv_heads, const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ seq_lens, // [num_seqs] const int* __restrict__ query_start_loc_ptr, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] @@ -2382,11 +2380,11 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( const int max_num_partitions = gridDim.y; - const int context_len = context_lens[seq_idx]; // length of a seq + const int seq_len = seq_lens[seq_idx]; // length of a seq const int partition_start_token_idx = partition_idx * T_PAR_SIZE; // exit if partition is out of context for seq - if (partition_start_token_idx >= context_len) { + if (partition_start_token_idx >= seq_len) { return; } @@ -2482,8 +2480,8 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( } } - const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); - const int last_ctx_block = num_context_blocks - 1; + const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE); + const int last_seq_block = num_seq_blocks - 1; const int* block_table_seq = block_tables + seq_idx * max_num_blocks_per_seq; @@ -2494,9 +2492,9 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( const int klocal_token_idx = TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id; const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx; - const int kblock_idx = (kglobal_token_idx < context_len) + const int kblock_idx = (kglobal_token_idx < seq_len) ? kglobal_token_idx / BLOCK_SIZE - : last_ctx_block; + : last_seq_block; kphysical_block_number[token_depth] = block_table_seq[kblock_idx]; } @@ -2548,9 +2546,9 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( rowid * VTOKENS_PER_LANE + vblock_depth * BLOCK_SIZE; const int vglobal_token_idx = partition_start_token_idx + vlocal_token_idx; - const int vblock_idx = (vglobal_token_idx < context_len) + const int vblock_idx = (vglobal_token_idx < seq_len) ? vglobal_token_idx / BLOCK_SIZE - : last_ctx_block; + : last_seq_block; vphysical_block_number[vtoken_depth][vblock_depth] = block_table_seq[vblock_idx]; } @@ -2604,7 +2602,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( const int local_token_idx = qkout_token_idx + token_depth * 16; for (int i = 0; i < 8; i++) { const float tmp = - (local_token_idx + i < context_len) ? dout[token_depth][i] : -FLT_MAX; + (local_token_idx + i < seq_len) ? dout[token_depth][i] : -FLT_MAX; qk_max = fmaxf(qk_max, tmp); } } @@ -2614,7 +2612,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( for (int token_depth = 0; token_depth < TLOOP; token_depth++) { const int local_token_idx = qkout_token_idx + token_depth * 16; for (int i = 0; i < 8; i++) { - const float tmp = (local_token_idx + i < context_len) + const float tmp = (local_token_idx + i < seq_len) ? __expf(dout[token_depth][i] - qk_max) : 0.0f; dout[token_depth][i] = tmp; @@ -2751,7 +2749,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( // head_size, block_size] const int num_kv_heads, const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ seq_lens, // [num_seqs] const int* __restrict__ query_start_loc_ptr, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] @@ -2778,7 +2776,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( // max_num_partitions] const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, // max_num_partitions, head_size] - const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ seq_lens, // [num_seqs] const int* __restrict__ query_start_loc_ptr, // [num_seqs] const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) { const auto num_heads = gridDim.x; @@ -2792,8 +2790,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( return; } - const int context_len = context_lens[seq_idx]; - const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); + const int seq_len = seq_lens[seq_idx]; + const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE); const int warpid = threadIdx.x / WARP_SIZE; __shared__ float shared_global_exp_sum; @@ -2980,7 +2978,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel( const int num_kv_heads, const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ seq_lens, // [num_seqs] const int* __restrict__ query_start_loc_ptr, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] @@ -3007,7 +3005,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( const int num_kv_heads, const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ seq_lens, // [num_seqs] const int* __restrict__ query_start_loc_ptr, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] @@ -3031,7 +3029,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] - const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ seq_lens, // [num_seqs] const int* __restrict__ query_start_loc_ptr, // [num_seqs] const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) { UNREACHABLE_CODE @@ -3046,7 +3044,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( GQA_RATIO> \ <<>>( \ query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ - block_tables_ptr, context_lens_ptr, query_start_loc_ptr, \ + block_tables_ptr, seq_lens_ptr, query_start_loc_ptr, \ max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \ kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, \ max_ctx_blocks, k_scale_ptr, v_scale_ptr); @@ -3057,18 +3055,17 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( GQA_RATIO> \ <<>>( \ query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ - block_tables_ptr, context_lens_ptr, query_start_loc_ptr, \ + block_tables_ptr, seq_lens_ptr, query_start_loc_ptr, \ max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \ kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, \ max_ctx_blocks, k_scale_ptr, v_scale_ptr); -#define LAUNCH_CUSTOM_REDUCTION(NPAR_LOOPS) \ - paged_attention_ll4mi_reduce_kernel \ - <<>>( \ - out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, \ - context_lens_ptr, query_start_loc_ptr, max_num_partitions, \ - fp8_out_scale_ptr); +#define LAUNCH_CUSTOM_REDUCTION(NPAR_LOOPS) \ + paged_attention_ll4mi_reduce_kernel \ + <<>>( \ + out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \ + query_start_loc_ptr, max_num_partitions, fp8_out_scale_ptr); template & query_start_loc, int max_context_len, + torch::Tensor& block_tables, torch::Tensor& seq_lens, + const std::optional& query_start_loc, int max_seq_len, const std::optional& alibi_slopes, torch::Tensor& k_scale, torch::Tensor& v_scale, const std::optional& fp8_out_scale) { int num_seqs = block_tables.size(0); @@ -3109,7 +3106,7 @@ void paged_attention_custom_launcher( KVT* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); KVT* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); int* block_tables_ptr = block_tables.data_ptr(); - int* context_lens_ptr = context_lens.data_ptr(); + int* seq_lens_ptr = seq_lens.data_ptr(); const float* k_scale_ptr = reinterpret_cast(k_scale.data_ptr()); const float* v_scale_ptr = reinterpret_cast(v_scale.data_ptr()); // NOTE: fp8_out_scale is optional. @@ -3119,13 +3116,12 @@ void paged_attention_custom_launcher( : nullptr; OUTT* out_ptr = reinterpret_cast(out.data_ptr()); - const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE); + const int max_ctx_blocks = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE); // partition size is fixed at 256 since both mfma4 and mfma16 kernels support // it mfma4 kernel also supports partition size 512 constexpr int PARTITION_SIZE = 256; - const int max_num_partitions = - DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); + const int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE); const int gqa_ratio = num_heads / num_kv_heads; assert(num_heads % num_kv_heads == 0); assert(head_size == HEAD_SIZE); @@ -3234,8 +3230,8 @@ void paged_attention_custom_launcher_navi( torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, const int num_kv_heads, float scale, - torch::Tensor& block_tables, torch::Tensor& context_lens, - const std::optional& query_start_loc, int max_context_len, + torch::Tensor& block_tables, torch::Tensor& seq_lens, + const std::optional& query_start_loc, int max_seq_len, const std::optional& alibi_slopes, torch::Tensor& k_scale, torch::Tensor& v_scale) { int num_seqs = block_tables.size(0); @@ -3263,7 +3259,7 @@ void paged_attention_custom_launcher_navi( KVT* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); KVT* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); int* block_tables_ptr = block_tables.data_ptr(); - int* context_lens_ptr = context_lens.data_ptr(); + int* seq_lens_ptr = seq_lens.data_ptr(); const float* k_scale_ptr = reinterpret_cast(k_scale.data_ptr()); const float* v_scale_ptr = reinterpret_cast(v_scale.data_ptr()); @@ -3271,11 +3267,10 @@ void paged_attention_custom_launcher_navi( const auto fp8_out_scale_ptr = nullptr; OUTT* out_ptr = reinterpret_cast(out.data_ptr()); - const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE); + const int max_ctx_blocks = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE); constexpr int PARTITION_SIZE = 256; - const int max_num_partitions = - DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); + const int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE); const int gqa_ratio = num_heads / num_kv_heads; assert(num_heads % num_kv_heads == 0); assert(head_size == HEAD_SIZE); @@ -3407,14 +3402,14 @@ void paged_attention_custom_launcher_navi( paged_attention_custom_launcher( \ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ - num_kv_heads, scale, block_tables, context_lens, query_start_loc, \ - max_context_len, alibi_slopes, k_scale, v_scale, fp8_out_scale); \ + num_kv_heads, scale, block_tables, seq_lens, query_start_loc, \ + max_seq_len, alibi_slopes, k_scale, v_scale, fp8_out_scale); \ } else { \ paged_attention_custom_launcher_navi< \ T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, ALIBI_ENABLED>( \ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ - num_kv_heads, scale, block_tables, context_lens, query_start_loc, \ - max_context_len, alibi_slopes, k_scale, v_scale); \ + num_kv_heads, scale, block_tables, seq_lens, query_start_loc, \ + max_seq_len, alibi_slopes, k_scale, v_scale); \ } #define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \ @@ -3502,9 +3497,9 @@ void paged_attention( int64_t num_kv_heads, double scale, torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] - torch::Tensor& context_lens, // [num_seqs] + torch::Tensor& seq_lens, // [num_seqs] const std::optional& query_start_loc, // [num_seqs] - int64_t block_size, int64_t max_context_len, + int64_t block_size, int64_t max_seq_len, const std::optional& alibi_slopes, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale, diff --git a/csrc/rocm/ops.h b/csrc/rocm/ops.h index e538197dbcb04..34dcc9401aae8 100644 --- a/csrc/rocm/ops.h +++ b/csrc/rocm/ops.h @@ -15,8 +15,8 @@ void paged_attention( torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int64_t num_kv_heads, double scale, - torch::Tensor& block_tables, torch::Tensor& context_lens, + torch::Tensor& block_tables, torch::Tensor& seq_lens, const std::optional& query_start_loc, int64_t block_size, - int64_t max_context_len, const std::optional& alibi_slopes, + int64_t max_seq_len, const std::optional& alibi_slopes, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale, const std::optional& fp8_out_scale); diff --git a/csrc/rocm/torch_bindings.cpp b/csrc/rocm/torch_bindings.cpp index 34575477bcc94..66bdc448da3ca 100644 --- a/csrc/rocm/torch_bindings.cpp +++ b/csrc/rocm/torch_bindings.cpp @@ -41,10 +41,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) { " Tensor query, Tensor key_cache," " Tensor value_cache, int num_kv_heads," " float scale, Tensor block_tables," - " Tensor context_lens," + " Tensor seq_lens," " Tensor? query_start_loc," " int block_size," - " int max_context_len," + " int max_seq_len," " Tensor? alibi_slopes," " str kv_cache_dtype," " Tensor k_scale, Tensor v_scale,"