/*
 * Adapted from
 * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "attention_kernels.cuh"
#include "../cuda_compat.h"

#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))

// Launches the two-phase paged attention v2 kernels: the main kernel computes
// attention for each (head, sequence, partition) block and writes per-partition
// partial results; the reduce kernel then combines those partials into the
// final output.
#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE)                                   \
  vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE,           \
                                  NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE,      \
                                  PARTITION_SIZE>                              \
      <<<grid, block, shared_mem_size, stream>>>(                              \
          exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
          value_cache_ptr, num_kv_heads, scale, block_tables_ptr,              \
          seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride,    \
          kv_block_stride, kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank,  \
          blocksparse_local_blocks, blocksparse_vert_stride,                   \
          blocksparse_block_size, blocksparse_head_sliding_step);              \
  vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS,            \
                                         PARTITION_SIZE>                       \
      <<<reduce_grid, block, reduce_shared_mem_size, stream>>>(                \
          out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr,    \
          max_num_partitions);

template <typename T, typename CACHE_T, int BLOCK_SIZE,
          vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE,
          int NUM_THREADS = 128, int PARTITION_SIZE = 512>
void paged_attention_v2_launcher(
    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, int num_kv_heads, float scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
    const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
    torch::Tensor& v_scale, const int tp_rank,
    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
    const int blocksparse_block_size,
    const int blocksparse_head_sliding_step) {
  int num_seqs = query.size(0);
  int num_heads = query.size(1);
  int head_size = query.size(2);
  int max_num_blocks_per_seq = block_tables.size(1);
  int q_stride = query.stride(0);
  int kv_block_stride = key_cache.stride(0);
  int kv_head_stride = key_cache.stride(1);

  // NOTE: alibi_slopes is optional.
  const float* alibi_slopes_ptr =
      alibi_slopes
          ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
          : nullptr;

  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
  float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
  float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
  T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
  CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
  CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
  int* block_tables_ptr = block_tables.data_ptr<int>();
  int* seq_lens_ptr = seq_lens.data_ptr<int>();
  const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
  const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());

  const int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
  int logits_size = PARTITION_SIZE * sizeof(float);
  int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);

  // For paged attention v2 kernel.
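  // Each thread block handles one partition of PARTITION_SIZE tokens for a
  // single (head, sequence) pair, so the grid is laid out as
  // (num_heads, num_seqs, max_num_partitions). Shared memory must fit either
  // the partition's softmax logits or the cross-warp partial outputs,
  // whichever is larger.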
  dim3 grid(num_heads, num_seqs, max_num_partitions);
  int shared_mem_size = std::max(logits_size, outputs_size);
  // For paged attention v2 reduce kernel.
  dim3 reduce_grid(num_heads, num_seqs);
  int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);

  dim3 block(NUM_THREADS);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  switch (head_size) {
    // NOTE(woosuk): To reduce the compilation time, we only compile for the
    // head sizes that we use in the model. However, we can easily extend this
    // to support any head size which is a multiple of 16.
    case 32:
      LAUNCH_PAGED_ATTENTION_V2(32);
      break;
    case 64:
      LAUNCH_PAGED_ATTENTION_V2(64);
      break;
    case 80:
      LAUNCH_PAGED_ATTENTION_V2(80);
      break;
    case 96:
      LAUNCH_PAGED_ATTENTION_V2(96);
      break;
    case 112:
      LAUNCH_PAGED_ATTENTION_V2(112);
      break;
    case 120:
      LAUNCH_PAGED_ATTENTION_V2(120);
      break;
    case 128:
      LAUNCH_PAGED_ATTENTION_V2(128);
      break;
    case 192:
      LAUNCH_PAGED_ATTENTION_V2(192);
      break;
    case 256:
      LAUNCH_PAGED_ATTENTION_V2(256);
      break;
    default:
      TORCH_CHECK(false, "Unsupported head size: ", head_size);
      break;
  }
}

#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE)    \
  paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE,                \
                              IS_BLOCK_SPARSE>(                                \
      out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache,       \
      num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes,  \
      k_scale, v_scale, tp_rank, blocksparse_local_blocks,                     \
      blocksparse_vert_stride, blocksparse_block_size,                         \
      blocksparse_head_sliding_step);

#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
  if (is_block_sparse) {                                                   \
    CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true);       \
  } else {                                                                 \
    CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false);      \
  }

// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE)         \
  switch (block_size) {                                           \
    case 8:                                                       \
      CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE);         \
      break;                                                      \
    case 16:                                                      \
      CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE);        \
      break;                                                      \
    case 32:                                                      \
      CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE);        \
      break;                                                      \
    default:                                                      \
      TORCH_CHECK(false, "Unsupported block size: ", block_size); \
      break;                                                      \
  }

void paged_attention_v2(
    torch::Tensor& out,         // [num_seqs, num_heads, head_size]
    torch::Tensor& exp_sums,    // [num_seqs, num_heads, max_num_partitions]
    torch::Tensor& max_logits,  // [num_seqs, num_heads, max_num_partitions]
    torch::Tensor& tmp_out,     // [num_seqs, num_heads, max_num_partitions, head_size]
    torch::Tensor& query,       // [num_seqs, num_heads, head_size]
    torch::Tensor& key_cache,   // [num_blocks, num_heads, head_size/x, block_size, x]
    torch::Tensor& value_cache,  // [num_blocks, num_heads, head_size, block_size]
    int64_t num_kv_heads,        // [num_heads]
    double scale,
    torch::Tensor& block_tables,  // [num_seqs, max_num_blocks_per_seq]
    torch::Tensor& seq_lens,      // [num_seqs]
    int64_t block_size, int64_t max_seq_len,
    const std::optional<torch::Tensor>& alibi_slopes,
    const std::string& kv_cache_dtype, torch::Tensor& k_scale,
    torch::Tensor& v_scale, const int64_t tp_rank,
    const int64_t blocksparse_local_blocks,
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
    const int64_t blocksparse_head_sliding_step) {
  const bool is_block_sparse = (blocksparse_vert_stride > 1);

  DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
                             CALL_V2_LAUNCHER_BLOCK_SIZE)
}

#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP
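// Sketch of how the per-partition results are merged (illustrative only; the
// actual implementation is vllm::paged_attention_v2_reduce_kernel in
// attention_kernels.cuh). Assuming the standard log-sum-exp merge of
// partitioned softmax results, for each (seq, head):
//
//   global_max      = max_p max_logits[p]
//   rescaled_sum[p] = exp_sums[p] * exp(max_logits[p] - global_max)
//   out             = sum_p tmp_out[p] * rescaled_sum[p] / sum_q rescaled_sum[q]
//
// i.e. each partition's partial output in tmp_out is re-weighted by its share
// of the globally normalized softmax mass.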