From 7f829be7d3d734020606fcca520f3c500581beb8 Mon Sep 17 00:00:00 2001 From: "Li, Jiang" Date: Wed, 12 Nov 2025 09:43:06 +0800 Subject: [PATCH] [CPU] Refactor CPU attention backend (#27954) Signed-off-by: jiang1.li --- .buildkite/release-pipeline.yaml | 2 +- .../scripts/hardware_ci/run-cpu-test.sh | 3 +- cmake/cpu_extension.cmake | 28 +- csrc/cpu/attention.cpp | 798 ------- csrc/cpu/cache.cpp | 214 -- csrc/cpu/cpu_attn.cpp | 249 +++ csrc/cpu/cpu_attn_amx.hpp | 511 +++++ csrc/cpu/cpu_attn_impl.hpp | 1977 +++++++++++++++++ csrc/cpu/cpu_attn_macros.h | 63 + csrc/cpu/cpu_attn_vec.hpp | 248 +++ csrc/cpu/cpu_attn_vec16.hpp | 171 ++ csrc/cpu/cpu_types_x86.hpp | 50 +- csrc/cpu/dnnl_helper.cpp | 18 +- csrc/cpu/dnnl_helper.h | 24 - csrc/cpu/scratchpad_manager.cpp | 23 + csrc/cpu/scratchpad_manager.h | 31 + csrc/cpu/shm.cpp | 2 +- csrc/cpu/torch_bindings.cpp | 105 +- docker/Dockerfile.cpu | 4 + docs/getting_started/installation/cpu.md | 2 + .../attention/test_attention_selector.py | 6 +- tests/kernels/attention/test_cpu_attn.py | 575 +++++ tests/kernels/test_onednn.py | 1 - .../models/language/generation/test_common.py | 17 +- .../models/language/pooling/test_embedding.py | 3 +- tests/models/registry.py | 4 +- vllm/_custom_ops.py | 82 + vllm/attention/backends/registry.py | 3 +- vllm/engine/arg_utils.py | 3 - vllm/platforms/cpu.py | 37 +- vllm/utils/__init__.py | 1 - vllm/v1/attention/backends/cpu_attn.py | 985 +++----- vllm/v1/attention/backends/utils.py | 2 +- vllm/v1/worker/cpu_model_runner.py | 14 +- 34 files changed, 4354 insertions(+), 1902 deletions(-) delete mode 100644 csrc/cpu/attention.cpp delete mode 100644 csrc/cpu/cache.cpp create mode 100644 csrc/cpu/cpu_attn.cpp create mode 100644 csrc/cpu/cpu_attn_amx.hpp create mode 100644 csrc/cpu/cpu_attn_impl.hpp create mode 100644 csrc/cpu/cpu_attn_macros.h create mode 100644 csrc/cpu/cpu_attn_vec.hpp create mode 100644 csrc/cpu/cpu_attn_vec16.hpp create mode 100644 csrc/cpu/scratchpad_manager.cpp create mode 100644 csrc/cpu/scratchpad_manager.h create mode 100644 tests/kernels/attention/test_cpu_attn.py diff --git a/.buildkite/release-pipeline.yaml b/.buildkite/release-pipeline.yaml index 12f730738b8a5..38c400ba1faf5 100644 --- a/.buildkite/release-pipeline.yaml +++ b/.buildkite/release-pipeline.yaml @@ -132,7 +132,7 @@ steps: queue: cpu_queue_postmerge commands: - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" - - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --build-arg VLLM_CPU_AVX512BF16=true --build-arg VLLM_CPU_AVX512VNNI=true --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version) --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:latest --progress plain --target vllm-openai -f docker/Dockerfile.cpu ." + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --build-arg VLLM_CPU_AVX512BF16=true --build-arg VLLM_CPU_AVX512VNNI=true --build-arg VLLM_CPU_AMXBF16=true --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version) --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:latest --progress plain --target vllm-openai -f docker/Dockerfile.cpu ." - "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:latest" - "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version)" env: diff --git a/.buildkite/scripts/hardware_ci/run-cpu-test.sh b/.buildkite/scripts/hardware_ci/run-cpu-test.sh index 7e0f720feaa71..7479c43977d78 100644 --- a/.buildkite/scripts/hardware_ci/run-cpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-cpu-test.sh @@ -49,6 +49,7 @@ function cpu_tests() { # Run kernel tests docker exec cpu-test-"$NUMA_NODE" bash -c " set -e + pytest -x -v -s tests/kernels/attention/test_cpu_attn.py pytest -x -v -s tests/kernels/test_onednn.py" # Run basic model test @@ -116,4 +117,4 @@ function cpu_tests() { # All of CPU tests are expected to be finished less than 40 mins. export -f cpu_tests -timeout 2h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE" +timeout 2.5h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE" diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index dbda19fbcbf20..51447cde0b294 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -15,6 +15,7 @@ endif() # set(ENABLE_AVX512BF16 $ENV{VLLM_CPU_AVX512BF16}) set(ENABLE_AVX512VNNI $ENV{VLLM_CPU_AVX512VNNI}) +set(ENABLE_AMXBF16 $ENV{VLLM_CPU_AMXBF16}) include_directories("${CMAKE_SOURCE_DIR}/csrc") @@ -140,6 +141,22 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED) set(ENABLE_AVX512VNNI OFF) message(WARNING "Disable AVX512-VNNI ISA support, no avx512_vnni found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512VNNI=1.") endif() + + find_isa(${CPUINFO} "amx_bf16" AMXBF16_FOUND) + if (AMXBF16_FOUND OR ENABLE_AMXBF16) + if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND + CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3) + list(APPEND CXX_COMPILE_FLAGS "-mamx-bf16" "-mamx-tile") + set(ENABLE_AMXBF16 ON) + add_compile_definitions(-DCPU_CAPABILITY_AMXBF16) + else() + set(ENABLE_AMXBF16 OFF) + message(WARNING "Disable AMX_BF16 ISA support, requires gcc/g++ >= 12.3") + endif() + else() + set(ENABLE_AMXBF16 OFF) + message(WARNING "Disable AMX_BF16 ISA support, no amx_bf16 found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AMXBF16=1.") + endif() elseif (AVX2_FOUND) list(APPEND CXX_COMPILE_FLAGS "-mavx2") @@ -275,7 +292,10 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON set(ONEDNN_VERBOSE "OFF") set(CMAKE_POLICY_DEFAULT_CMP0077 NEW) + set(VLLM_BUILD_TYPE ${CMAKE_BUILD_TYPE}) + set(CMAKE_BUILD_TYPE "Release") # remove oneDNN debug symbols to reduce size FetchContent_MakeAvailable(oneDNN) + set(CMAKE_BUILD_TYPE ${VLLM_BUILD_TYPE}) add_library(dnnl_ext OBJECT "csrc/cpu/dnnl_helper.cpp") target_include_directories( dnnl_ext @@ -305,14 +325,14 @@ endif() # set(VLLM_EXT_SRC "csrc/cpu/activation.cpp" - "csrc/cpu/attention.cpp" - "csrc/cpu/cache.cpp" "csrc/cpu/utils.cpp" "csrc/cpu/layernorm.cpp" "csrc/cpu/mla_decode.cpp" "csrc/cpu/pos_encoding.cpp" - "csrc/cpu/torch_bindings.cpp" - "csrc/moe/dynamic_4bit_int_moe_cpu.cpp") + "csrc/moe/dynamic_4bit_int_moe_cpu.cpp" + "csrc/cpu/cpu_attn.cpp" + "csrc/cpu/scratchpad_manager.cpp" + "csrc/cpu/torch_bindings.cpp") if (AVX512_FOUND AND NOT AVX512_DISABLED) set(VLLM_EXT_SRC diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp deleted file mode 100644 index 82862fea7f2be..0000000000000 --- a/csrc/cpu/attention.cpp +++ /dev/null @@ -1,798 +0,0 @@ -#include "cpu_types.hpp" - -namespace { - -template -struct KernelVecType { - using q_load_vec_type = void; - using q_vec_type = void; - using k_load_vec_type = void; - using k_vec_type = void; - using qk_acc_vec_type = void; - using v_load_vec_type = void; -}; - -template <> -struct KernelVecType { - using q_load_vec_type = vec_op::FP32Vec4; - using q_vec_type = vec_op::FP32Vec16; - using k_load_vec_type = vec_op::FP32Vec16; - using k_vec_type = vec_op::FP32Vec16; - using qk_acc_vec_type = vec_op::FP32Vec16; - using v_load_vec_type = vec_op::FP32Vec16; -}; - -template <> -struct KernelVecType { -#if defined(__powerpc64__) || defined(__s390x__) - // Power and s390x architecture-specific vector types - using q_load_vec_type = vec_op::FP32Vec8; - using k_load_vec_type = vec_op::FP32Vec16; - using v_load_vec_type = vec_op::FP32Vec16; -#else - // Fallback for other architectures, including x86 - using q_load_vec_type = vec_op::FP16Vec8; - using k_load_vec_type = vec_op::FP16Vec16; - using v_load_vec_type = vec_op::FP16Vec16; -#endif - using q_vec_type = vec_op::FP32Vec16; - using k_vec_type = vec_op::FP32Vec16; - using qk_acc_vec_type = vec_op::FP32Vec16; -}; - -#ifdef __AVX512BF16__ -template <> -struct KernelVecType { - using q_load_vec_type = vec_op::BF16Vec8; - using q_vec_type = vec_op::BF16Vec32; - using k_load_vec_type = vec_op::BF16Vec32; - using k_vec_type = vec_op::BF16Vec32; - using qk_acc_vec_type = vec_op::FP32Vec16; - using v_load_vec_type = vec_op::BF16Vec16; -}; -#else - #ifdef __aarch64__ - #ifndef ARM_BF16_SUPPORT - // pass - #else -template <> -struct KernelVecType { - using q_load_vec_type = vec_op::BF16Vec8; - using q_vec_type = vec_op::FP32Vec16; - using k_load_vec_type = vec_op::BF16Vec16; - using k_vec_type = vec_op::FP32Vec16; - using qk_acc_vec_type = vec_op::FP32Vec16; - using v_load_vec_type = vec_op::BF16Vec16; -}; - #endif - #else -template <> -struct KernelVecType { - using q_load_vec_type = vec_op::BF16Vec8; - using q_vec_type = vec_op::FP32Vec16; - using k_load_vec_type = vec_op::BF16Vec16; - using k_vec_type = vec_op::FP32Vec16; - using qk_acc_vec_type = vec_op::FP32Vec16; - using v_load_vec_type = vec_op::BF16Vec16; -}; - #endif -#endif - -template -FORCE_INLINE std::pair reduceSoftmax(T* data, const int size, - const int capacity) { - T max = data[0]; - for (int i = 1; i < size; ++i) { - max = max >= data[i] ? max : data[i]; - } - - T sum = 0; - for (int i = 0; i < size; ++i) { - data[i] = std::exp(data[i] - max); - sum += data[i]; - } - - int i = 0; - for (; i < size; ++i) { - data[i] /= sum; - } - - for (; i < capacity; ++i) { - data[i] = 0; - } - - return {max, sum}; -} - -template -FORCE_INLINE std::pair reduceSoftmaxAlibi(T* data, const int size, - const int capacity, - const float alibi_slope, - const int start_index, - const int seq_len) { - data[0] += alibi_slope * (start_index - seq_len + 1); - T max = data[0]; - for (int i = 1; i < size; ++i) { - T qk = data[i] + alibi_slope * (start_index + i - seq_len + 1); - data[i] = qk; - max = max >= qk ? max : qk; - } - - T sum = 0; - for (int i = 0; i < size; ++i) { - data[i] = std::exp(data[i] - max); - sum += data[i]; - } - - int i = 0; - for (; i < size; ++i) { - data[i] /= sum; - } - - for (; i < capacity; ++i) { - data[i] = 0; - } - - return {max, sum}; -} - -template -FORCE_INLINE void reducePartitionSoftmax(const T* max_data, T* sum_data, - const int size) { - T max = max_data[0]; - for (int i = 1; i < size; ++i) { - max = max >= max_data[i] ? max : max_data[i]; - } - - T rescaled_sum = 0; - for (int i = 0; i < size; ++i) { - T rescale_factor = std::exp(max_data[i] - max); - rescaled_sum += rescale_factor * sum_data[i]; - sum_data[i] *= rescale_factor; - } - for (int i = 0; i < size; ++i) { - sum_data[i] /= rescaled_sum + 1e-8; - } -} - -template -struct reduceQKBlockKernel { - using q_load_vec_type = typename KernelVecType::q_load_vec_type; - using q_vec_type = typename KernelVecType::q_vec_type; - using k_load_vec_type = typename KernelVecType::k_load_vec_type; - using k_vec_type = typename KernelVecType::k_vec_type; - using qk_acc_vec_type = typename KernelVecType::qk_acc_vec_type; - - constexpr static int TOKEN_PER_GROUP = k_load_vec_type::get_elem_num() / x; - constexpr static int MAX_GROUP_NUM = 16 / TOKEN_PER_GROUP; - constexpr static int UNROLL_GROUP_NUM = MAX_GROUP_NUM / 4; - - static_assert(MAX_GROUP_NUM == 8 || MAX_GROUP_NUM == 4); - static_assert(k_load_vec_type::get_elem_num() % x == 0); - static_assert(q_load_vec_type::get_elem_num() * sizeof(scalar_t) == 16); - - FORCE_INLINE static void call(const scalar_t* __restrict__ q, - const scalar_t* __restrict__ k_block, - float* __restrict__ logits, float scale, - const int token_num) { - const int group_num = (token_num + TOKEN_PER_GROUP - 1) / TOKEN_PER_GROUP; - - qk_acc_vec_type group_accums[MAX_GROUP_NUM]; - if (token_num == BLOCK_SIZE) { - for (int q_offset = 0; q_offset < HEAD_SIZE; - q_offset += x, k_block += x * BLOCK_SIZE) { - q_load_vec_type q_load_group_vec(q + q_offset); - q_vec_type q_group_vec(q_load_group_vec); - - vec_op::unroll_loop( - [k_block, &q_group_vec, &group_accums](int token_group_idx) { - k_load_vec_type k_load_group_vec(k_block + token_group_idx * x * - TOKEN_PER_GROUP); - k_vec_type k_group_vec(k_load_group_vec); - vec_op::fma(group_accums[token_group_idx], q_group_vec, - k_group_vec); - vec_op::prefetch(k_block + x * BLOCK_SIZE + - token_group_idx * x * TOKEN_PER_GROUP); - }); - } - } else { - for (int q_offset = 0; q_offset < HEAD_SIZE; - q_offset += x, k_block += x * BLOCK_SIZE) { - q_load_vec_type q_load_group_vec(q + q_offset); - q_vec_type q_group_vec(q_load_group_vec); - for (int token_group_start = 0; token_group_start < group_num; - token_group_start += UNROLL_GROUP_NUM) { - vec_op::unroll_loop( - [token_group_start, k_block, &q_group_vec, - &group_accums](int token_group_idx) { - token_group_idx += token_group_start; - k_load_vec_type k_load_group_vec(k_block + token_group_idx * x * - TOKEN_PER_GROUP); - k_vec_type k_group_vec(k_load_group_vec); - vec_op::fma(group_accums[token_group_idx], q_group_vec, - k_group_vec); - vec_op::prefetch(k_block + x * BLOCK_SIZE + - token_group_idx * x * TOKEN_PER_GROUP); - }); - } - } - } - - for (int token_group_idx = 0; token_group_idx < group_num; - ++token_group_idx) { - vec_op::unroll_loop( - [&group_accums, logits, scale, token_group_idx](int token_idx) { - float dot_v = - group_accums[token_group_idx] - .template reduce_sub_sum(token_idx); - logits[token_group_idx * TOKEN_PER_GROUP + token_idx] = - dot_v * scale; - }); - } - } -}; - -template -FORCE_INLINE void reduceValueBlock(const float* prob, const scalar_t* v_block, - acc_t&& acc) { - using v_load_vec_type = typename KernelVecType::v_load_vec_type; - constexpr int ELEM_NUM = v_load_vec_type::get_elem_num(); - static_assert(BLOCK_SIZE == ELEM_NUM); - vec_op::FP32Vec16 prob_vec(prob); - - vec_op::unroll_loop([&](int head_elem_idx) { - v_load_vec_type v_vec(v_block + BLOCK_SIZE * head_elem_idx); - vec_op::FP32Vec16 fp32_v_vec(v_vec); - acc[head_elem_idx] = acc[head_elem_idx] + prob_vec * fp32_v_vec; - }); -} -}; // namespace - -// Paged attention v1 -namespace { -template -struct paged_attention_v1_impl { - static void call( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, - // head_size/x, block_size, x] - const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, - // head_size, block_size] - const int num_kv_heads, const float scale, - const int* __restrict__ block_tables, // [num_seqs, - // max_num_blocks_per_seq] - const int* __restrict__ seq_lens, // [num_seqs] - const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, const int kv_block_stride, const int kv_head_stride, - const int num_seqs, const int num_heads) { - constexpr int x = 16 / sizeof(scalar_t); - const int num_queries_per_kv = num_heads / num_kv_heads; - - static_assert(BLOCK_SIZE == 16); - - int max_seq_len = max_num_blocks_per_seq * BLOCK_SIZE; - int max_seq_len_padded = (max_seq_len + 15) & 0xFFFFFFF0; - TORCH_CHECK((max_seq_len_padded * sizeof(float)) % 64 == 0); - - const int parallel_work_item_num = omp_get_max_threads(); - - size_t logits_bytes = - parallel_work_item_num * max_seq_len_padded * sizeof(float); - float* logits = (float*)std::aligned_alloc( - 64, logits_bytes); // Cacheline alignment for each context token. - // [parallel_work_item_num, max_seq_len_padded] - -#pragma omp parallel for collapse(2) schedule(dynamic, 1) - for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { - for (int head_idx = 0; head_idx < num_heads; ++head_idx) { - int seq_len = seq_lens[seq_idx]; - const int* seq_block_table = - block_tables + max_num_blocks_per_seq * seq_idx; - const int block_num = (seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE; - const int64_t kv_head_idx = head_idx / num_queries_per_kv; - const scalar_t* __restrict__ q_vec_ptr = - q + seq_idx * q_stride + head_idx * HEAD_SIZE; - const int last_block_token_num = seq_len - (block_num - 1) * BLOCK_SIZE; - float* __restrict__ thread_block_logits = - logits + omp_get_thread_num() * max_seq_len_padded; - - // Compute logits - for (int block_idx = 0; block_idx < block_num; ++block_idx) { - const int64_t physical_block_idx = seq_block_table[block_idx]; - const scalar_t* __restrict__ k_block_cache_ptr = - k_cache + physical_block_idx * kv_block_stride + - kv_head_idx * kv_head_stride; - float* __restrict__ head_block_logits = - thread_block_logits + block_idx * BLOCK_SIZE; - - reduceQKBlockKernel::call( - q_vec_ptr, k_block_cache_ptr, head_block_logits, scale, - block_idx == block_num - 1 ? last_block_token_num : BLOCK_SIZE); - } - - // Compute softmax - if (alibi_slopes) { - reduceSoftmaxAlibi(thread_block_logits, seq_len, - block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0, - seq_len); - } else { - reduceSoftmax(thread_block_logits, seq_len, block_num * BLOCK_SIZE); - } - - // Compute value - constexpr int head_elem_num_per_partition = 16; - constexpr int head_partition_num = - HEAD_SIZE / head_elem_num_per_partition; - for (int head_part_idx = 0; head_part_idx < head_partition_num; - ++head_part_idx) { - vec_op::FP32Vec16 accums[head_elem_num_per_partition]; - scalar_t* __restrict__ out_ptr = - out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + - head_part_idx * head_elem_num_per_partition; - for (int block_idx = 0; block_idx < block_num; ++block_idx) { - const int64_t physical_block_idx = seq_block_table[block_idx]; - const float* __restrict__ prob_vec_ptr = - thread_block_logits + block_idx * BLOCK_SIZE; - const scalar_t* __restrict__ v_block_cache_ptr = - v_cache + physical_block_idx * kv_block_stride + - kv_head_idx * kv_head_stride + - BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; - reduceValueBlock( - prob_vec_ptr, v_block_cache_ptr, accums); - - if (block_idx != block_num - 1) { - const int64_t next_physical_block_idx = - seq_block_table[block_idx + 1]; - const scalar_t* __restrict__ next_v_block_cache_ptr = - v_cache + next_physical_block_idx * kv_block_stride + - kv_head_idx * kv_head_stride + - BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; - vec_op::unroll_loop( - [&](int head_elem_idx) { - if (head_elem_idx % 2 == 0) { - vec_op::prefetch(next_v_block_cache_ptr + - BLOCK_SIZE * head_elem_idx); - } - }); - } - } - - vec_op::unroll_loop( - [&](int head_elem_idx) { - float value = accums[head_elem_idx].reduce_sum(); - vec_op::storeFP32(value, out_ptr + head_elem_idx); - }); - } - } - } - std::free(logits); - } -}; - -#define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ - paged_attention_v1_impl::call( \ - out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ - block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ - alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs, \ - num_heads); - -template -void paged_attention_v1_impl_launcher( - torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, - torch::Tensor& value_cache, int num_kv_heads, float scale, - torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, - const std::optional& alibi_slopes) { - int num_seqs = query.size(0); - int num_heads = query.size(1); - int head_size = query.size(2); - int max_num_blocks_per_seq = block_tables.size(1); - int q_stride = query.stride(0); - int kv_block_stride = key_cache.stride(0); - int kv_head_stride = key_cache.stride(1); - - // NOTE: alibi_slopes is optional. - const float* alibi_slopes_ptr = - alibi_slopes - ? reinterpret_cast(alibi_slopes.value().data_ptr()) - : nullptr; - - T* out_ptr = reinterpret_cast(out.data_ptr()); - T* query_ptr = reinterpret_cast(query.data_ptr()); - T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); - T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); - int* block_tables_ptr = block_tables.data_ptr(); - int* seq_lens_ptr = seq_lens.data_ptr(); - - switch (head_size) { - case 32: - LAUNCH_V1_ATTENTION_KERNEL(T, 32, BLOCK_SIZE); - break; - case 64: - LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); - break; - case 80: - LAUNCH_V1_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); - break; - case 96: - LAUNCH_V1_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); - break; - case 112: - LAUNCH_V1_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); - break; - case 128: - LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); - break; - case 192: - LAUNCH_V1_ATTENTION_KERNEL(T, 192, BLOCK_SIZE); - break; - case 256: - LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); - break; - default: - TORCH_CHECK(false, "Unsupported head size: ", head_size); - break; - } -} - -#define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ - paged_attention_v1_impl_launcher( \ - out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ - seq_lens, max_seq_len, alibi_slopes); - -#define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ - switch (block_size) { \ - case 16: \ - CALL_V1_KERNEL_LAUNCHER(T, 16); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported block size: ", block_size); \ - break; \ - } -} // namespace - -void paged_attention_v1( - torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, - torch::Tensor& value_cache, int64_t num_kv_heads, double scale, - torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, - int64_t max_seq_len, const std::optional& alibi_slopes, - const std::string& kv_cache_dtype, torch::Tensor& k_scale, - torch::Tensor& v_scale, const int64_t tp_rank, - const int64_t blocksparse_local_blocks, - const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, - const int64_t blocksparse_head_sliding_step) { - TORCH_CHECK(blocksparse_vert_stride <= 1, - "CPU backend does not support blocksparse attention yet."); - VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl", - [&] { - CPU_KERNEL_GUARD_IN(paged_attention_v1_impl) - CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(scalar_t); - CPU_KERNEL_GUARD_OUT(paged_attention_v1_impl) - }); -} - -// Paged attention v2 -namespace { -template -struct paged_attention_v2_impl { - static void call( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] - float* __restrict__ exp_sums, // [num_seqs, num_heads, - // max_num_partitions] - float* __restrict__ max_logits, // [num_seqs, num_heads, - // max_num_partitions] - scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, - // max_num_partitions, head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, - // head_size/x, block_size, x] - const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, - // head_size, block_size] - const int num_kv_heads, const float scale, - const int* __restrict__ block_tables, // [num_seqs, - // max_num_blocks_per_seq] - const int* __restrict__ seq_lens, // [num_seqs] - const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, const int kv_block_stride, const int kv_head_stride, - const int num_seqs, const int num_heads, const int max_num_partitions) { - constexpr int x = 16 / sizeof(scalar_t); - const int num_queries_per_kv = num_heads / num_kv_heads; - - static_assert(BLOCK_SIZE == 16); - static_assert(PARTITION_SIZE * sizeof(float) % 64 == 0); - static_assert(PARTITION_SIZE % BLOCK_SIZE == 0); - -#pragma omp parallel for collapse(3) schedule(static, 1) - for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { - for (int partition_idx = 0; partition_idx < max_num_partitions; - ++partition_idx) { - for (int head_idx = 0; head_idx < num_heads; ++head_idx) { - const int seq_len = seq_lens[seq_idx]; - const int start_token_idx = partition_idx * PARTITION_SIZE; - - if (start_token_idx >= seq_len) continue; - - const int partition_num = - (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; - const bool no_reduce = (partition_num == 1); - const int token_num = - (std::min(seq_len, start_token_idx + PARTITION_SIZE) - - start_token_idx); - const int block_num = (token_num + BLOCK_SIZE - 1) / BLOCK_SIZE; - const int last_block_token_num = - token_num - (block_num - 1) * BLOCK_SIZE; - const int* seq_block_table = block_tables + - max_num_blocks_per_seq * seq_idx + - start_token_idx / BLOCK_SIZE; - const int64_t kv_head_idx = head_idx / num_queries_per_kv; - const scalar_t* __restrict__ q_vec_ptr = - q + seq_idx * q_stride + head_idx * HEAD_SIZE; - - float logits[PARTITION_SIZE] __attribute__((aligned(64))) = {0}; - - // Compute logits - for (int block_idx = 0; block_idx < block_num; ++block_idx) { - const int64_t physical_block_idx = seq_block_table[block_idx]; - const scalar_t* __restrict__ k_block_cache_ptr = - k_cache + physical_block_idx * kv_block_stride + - kv_head_idx * kv_head_stride; - float* __restrict__ head_block_logits = - logits + block_idx * BLOCK_SIZE; - - reduceQKBlockKernel::call( - q_vec_ptr, k_block_cache_ptr, head_block_logits, scale, - block_idx == block_num - 1 ? last_block_token_num : BLOCK_SIZE); - } - - std::pair max_and_sum; - if (alibi_slopes) { - max_and_sum = reduceSoftmaxAlibi( - logits, token_num, block_num * BLOCK_SIZE, - alibi_slopes[head_idx], start_token_idx, seq_len); - } else { - max_and_sum = - reduceSoftmax(logits, token_num, block_num * BLOCK_SIZE); - } - - auto&& [max_logit, exp_sum] = max_and_sum; - - scalar_t* __restrict__ output_buffer = nullptr; - if (!no_reduce) { - auto idx = seq_idx * num_heads * max_num_partitions + - head_idx * max_num_partitions + partition_idx; - max_logits[idx] = max_logit; - exp_sums[idx] = exp_sum; - output_buffer = - tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + - head_idx * max_num_partitions * HEAD_SIZE + - partition_idx * HEAD_SIZE; - } else { - output_buffer = - out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; - } - - // Compute value - constexpr int head_elem_num_per_partition = 16; - constexpr int head_partition_num = - HEAD_SIZE / head_elem_num_per_partition; - for (int head_part_idx = 0; head_part_idx < head_partition_num; - ++head_part_idx) { - vec_op::FP32Vec16 accums[head_elem_num_per_partition]; - scalar_t* __restrict__ out_ptr = - output_buffer + head_part_idx * head_elem_num_per_partition; - for (int block_idx = 0; block_idx < block_num; ++block_idx) { - const int64_t physical_block_idx = seq_block_table[block_idx]; - const float* __restrict__ prob_vec_ptr = - logits + block_idx * BLOCK_SIZE; - const scalar_t* __restrict__ v_block_cache_ptr = - v_cache + physical_block_idx * kv_block_stride + - kv_head_idx * kv_head_stride + - BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; - reduceValueBlock( - prob_vec_ptr, v_block_cache_ptr, accums); - - if (block_idx != block_num - 1) { - const int64_t next_physical_block_idx = - seq_block_table[block_idx + 1]; - const scalar_t* __restrict__ next_v_block_cache_ptr = - v_cache + next_physical_block_idx * kv_block_stride + - kv_head_idx * kv_head_stride + - BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; - vec_op::unroll_loop( - [&](int head_elem_idx) { - if (head_elem_idx % 2 == 0) { - vec_op::prefetch(next_v_block_cache_ptr + - BLOCK_SIZE * head_elem_idx); - } - }); - } - } - - vec_op::unroll_loop( - [&](int head_elem_idx) { - float value = accums[head_elem_idx].reduce_sum(); - vec_op::storeFP32(value, out_ptr + head_elem_idx); - }); - } - } - } - } - - // Rescale partition softmax and store the factors to exp_sums -#pragma omp parallel for collapse(2) schedule(static, 1) - for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { - for (int head_idx = 0; head_idx < num_heads; ++head_idx) { - const int seq_len = seq_lens[seq_idx]; - const int partition_num = - (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; - - if (partition_num == 1) continue; - - reducePartitionSoftmax( - max_logits + seq_idx * num_heads * max_num_partitions + - head_idx * max_num_partitions, - exp_sums + seq_idx * num_heads * max_num_partitions + - head_idx * max_num_partitions, - partition_num); - } - } - - // Reduce values - using v_load_vec_type = typename KernelVecType::v_load_vec_type; - static_assert(v_load_vec_type::get_elem_num() == BLOCK_SIZE); - constexpr int head_elem_num_per_group = - 16; // Note: didn't align with the cacheline size, due to some - // HEAD_SIZE didn't align with 64 bytes - static_assert(HEAD_SIZE % head_elem_num_per_group == 0); - constexpr int head_group_num = HEAD_SIZE / head_elem_num_per_group; - const float* __restrict__ rescale_factors = exp_sums; -#pragma omp parallel for collapse(3) schedule(static, 1) - for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { - for (int head_idx = 0; head_idx < num_heads; ++head_idx) { - for (int group_idx = 0; group_idx < head_group_num; ++group_idx) { - const int seq_len = seq_lens[seq_idx]; - const int partition_num = - (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; - - if (partition_num == 1) continue; - - const float* __restrict__ seq_head_rescale_factors = - rescale_factors + seq_idx * num_heads * max_num_partitions + - head_idx * max_num_partitions; - const scalar_t* __restrict__ seq_head_tmp_out = - tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + - head_idx * max_num_partitions * HEAD_SIZE + - group_idx * head_elem_num_per_group; - scalar_t* __restrict__ seq_head_output = - out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + - group_idx * head_elem_num_per_group; - - vec_op::FP32Vec16 acc; - for (int i = 0; i < partition_num; ++i) { - vec_op::FP32Vec16 rescale_factor(seq_head_rescale_factors[i]); - v_load_vec_type value(seq_head_tmp_out + i * HEAD_SIZE); - vec_op::FP32Vec16 fp32_value(value); - acc = acc + fp32_value * rescale_factor; - } - v_load_vec_type cast_acc(acc); - cast_acc.save(seq_head_output); - } - } - } - } -}; - -#define LAUNCH_V2_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ - paged_attention_v2_impl::call( \ - out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \ - key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ - seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ - kv_block_stride, kv_head_stride, num_seqs, num_heads, \ - max_num_partitions); - -template -void paged_attention_v2_impl_launcher( - torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, - torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, - torch::Tensor& value_cache, int num_kv_heads, float scale, - torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size, - int max_seq_len, const std::optional& alibi_slopes) { - int num_seqs = query.size(0); - int num_heads = query.size(1); - int head_size = query.size(2); - int max_num_blocks_per_seq = block_tables.size(1); - int q_stride = query.stride(0); - int kv_block_stride = key_cache.stride(0); - int kv_head_stride = key_cache.stride(1); - int max_num_partitions = exp_sums.size(-1); - - // NOTE: alibi_slopes is optional. - const float* alibi_slopes_ptr = - alibi_slopes - ? reinterpret_cast(alibi_slopes.value().data_ptr()) - : nullptr; - - T* out_ptr = reinterpret_cast(out.data_ptr()); - float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); - float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); - T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); - T* query_ptr = reinterpret_cast(query.data_ptr()); - T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); - T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); - int* block_tables_ptr = block_tables.data_ptr(); - int* seq_lens_ptr = seq_lens.data_ptr(); - - switch (head_size) { - case 32: - LAUNCH_V2_ATTENTION_KERNEL(T, 32, BLOCK_SIZE); - break; - case 64: - LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); - break; - case 80: - LAUNCH_V2_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); - break; - case 96: - LAUNCH_V2_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); - break; - case 112: - LAUNCH_V2_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); - break; - case 128: - LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); - break; - case 192: - LAUNCH_V2_ATTENTION_KERNEL(T, 192, BLOCK_SIZE); - break; - case 256: - LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); - break; - default: - TORCH_CHECK(false, "Unsupported head size: ", head_size); - break; - } -} - -#define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ - paged_attention_v2_impl_launcher( \ - out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ - num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, \ - alibi_slopes); - -#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ - switch (block_size) { \ - case 16: \ - CALL_V2_KERNEL_LAUNCHER(T, 16); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported block size: ", block_size); \ - break; \ - } -} // namespace - -void paged_attention_v2( - torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, - torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, - torch::Tensor& value_cache, int64_t num_kv_heads, double scale, - torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, - int64_t max_seq_len, const std::optional& alibi_slopes, - const std::string& kv_cache_dtype, torch::Tensor& k_scale, - torch::Tensor& v_scale, const int64_t tp_rank, - const int64_t blocksparse_local_blocks, - const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, - const int64_t blocksparse_head_sliding_step) { - TORCH_CHECK(blocksparse_vert_stride <= 1, - "CPU backend does not support blocksparse attention yet."); - VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl", - [&] { - CPU_KERNEL_GUARD_IN(paged_attention_v2_impl) - CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(scalar_t); - CPU_KERNEL_GUARD_OUT(paged_attention_v2_impl) - }); -} \ No newline at end of file diff --git a/csrc/cpu/cache.cpp b/csrc/cpu/cache.cpp deleted file mode 100644 index 69f6d06e3c967..0000000000000 --- a/csrc/cpu/cache.cpp +++ /dev/null @@ -1,214 +0,0 @@ -#include -#include - -#include "cpu_types.hpp" - -#if defined(__x86_64__) - #define DISPATCH_MACRO VLLM_DISPATCH_FLOATING_TYPES_WITH_E5M2 -#else - #define DISPATCH_MACRO VLLM_DISPATCH_FLOATING_TYPES -#endif - -namespace { -template -void copy_blocks_cpu_impl(std::vector const& key_caches, - std::vector const& value_caches, - const torch::Tensor& mapping_pairs, - const int element_num_per_block, - const int layer_num) { - const size_t pair_num = mapping_pairs.size(0); - const size_t block_bytes = sizeof(scalar_t) * element_num_per_block; -#pragma omp parallel for collapse(2) - for (int layer = 0; layer < layer_num; ++layer) { - for (size_t pair = 0; pair < pair_num; ++pair) { - int64_t source_offset = - element_num_per_block * mapping_pairs[pair][0].item(); - int64_t target_offset = - element_num_per_block * mapping_pairs[pair][1].item(); - scalar_t* key_cache_ptr = key_caches[layer].data_ptr(); - scalar_t* source_ptr = key_cache_ptr + source_offset; - scalar_t* target_ptr = key_cache_ptr + target_offset; - std::memcpy(target_ptr, source_ptr, block_bytes); - - scalar_t* value_cache_ptr = value_caches[layer].data_ptr(); - source_ptr = value_cache_ptr + source_offset; - target_ptr = value_cache_ptr + target_offset; - std::memcpy(target_ptr, source_ptr, block_bytes); - } - } -} - -template -void reshape_and_cache_cpu_impl( - const scalar_t* __restrict__ key, const scalar_t* __restrict__ value, - scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache, - const int64_t* __restrict__ slot_mapping, const int num_tokens, - const int key_stride, const int value_stride, const int num_heads, - const int head_size, const int block_size, const int x) { - const int block_elem_num = num_heads * head_size * block_size; - -#pragma omp parallel for collapse(2) - for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { - for (int head_idx = 0; head_idx < num_heads; ++head_idx) { - const int64_t slot_idx = slot_mapping[token_idx]; - if (slot_idx >= 0) { - int src_key_head_idx = token_idx * key_stride + head_idx * head_size; - int src_value_head_idx = - token_idx * value_stride + head_idx * head_size; - const scalar_t* src_key_head_ptr = key + src_key_head_idx; - const scalar_t* src_value_head_ptr = value + src_value_head_idx; - const int64_t block_index = slot_idx / block_size; - const int64_t block_offset = slot_idx % block_size; - scalar_t* target_key_head_ptr = key_cache + - block_elem_num * block_index + - head_idx * block_size * head_size; - scalar_t* target_value_head_ptr = value_cache + - block_elem_num * block_index + - head_idx * block_size * head_size; - - for (int src_key_idx = 0; src_key_idx < head_size; src_key_idx += x) { - const int64_t target_offset = - src_key_idx * block_size + block_offset * x; - for (int i = 0; i < x; ++i) { - target_key_head_ptr[target_offset + i] = - src_key_head_ptr[src_key_idx + i]; - } - } - - for (int src_value_idx = 0; src_value_idx < head_size; - ++src_value_idx) { - const int64_t target_offset = - src_value_idx * block_size + block_offset; - target_value_head_ptr[target_offset] = - src_value_head_ptr[src_value_idx]; - } - } - } - } -} -}; // namespace - -template -void concat_and_cache_mla_cpu_impl( - const scalar_t* __restrict__ kv_c, // [num_tokens, kv_lora_rank] - const scalar_t* __restrict__ k_pe, // [num_tokens, pe_dim] - scalar_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank - // + pe_dim)] - const int64_t* __restrict__ slot_mapping, // [num_tokens] - const int num_tokens, // - const int block_stride, // - const int entry_stride, // - const int kv_c_stride, // - const int k_pe_stride, // - const int kv_lora_rank, // - const int pe_dim, // - const int block_size // -) { -#pragma omp parallel for - for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { - const int64_t slot_idx = slot_mapping[token_idx]; - // NOTE: slot_idx can be -1 if the token is padded - if (slot_idx < 0) { - continue; - } - const int64_t block_idx = slot_idx / block_size; - const int64_t block_offset = slot_idx % block_size; - - auto copy = [&](const scalar_t* __restrict__ src, - scalar_t* __restrict__ dst, int src_stride, int dst_stride, - int size, int offset) { - for (int i = 0; i < size; i++) { - const int64_t src_idx = token_idx * src_stride + i; - const int64_t dst_idx = - block_idx * block_stride + block_offset * entry_stride + i + offset; - dst[dst_idx] = src[src_idx]; - } - }; - - copy(kv_c, kv_cache, kv_c_stride, block_stride, kv_lora_rank, 0); - copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank); - } -} - -// Note: the key_caches and value_caches vectors are constant but -// not the Tensors they contain. The vectors need to be const refs -// in order to satisfy pytorch's C++ operator registration code. -void copy_blocks(std::vector const& key_caches, - std::vector const& value_caches, - const torch::Tensor& block_mapping) { - unsigned num_layers = key_caches.size(); - TORCH_CHECK(num_layers == value_caches.size()); - if (num_layers == 0) { - return; - } - - const int element_num_per_block = key_caches[0][0].numel(); - DISPATCH_MACRO(key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] { - CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl) - copy_blocks_cpu_impl(key_caches, value_caches, block_mapping, - element_num_per_block, num_layers); - CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl) - }); -} - -void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, - torch::Tensor& key_cache, torch::Tensor& value_cache, - torch::Tensor& slot_mapping, - const std::string& kv_cache_dtype, - torch::Tensor& k_scale, torch::Tensor& v_scale) { - int num_tokens = key.size(0); - int num_heads = key.size(1); - int head_size = key.size(2); - int block_size = key_cache.size(3); - int x = key_cache.size(4); - - int key_stride = key.stride(0); - int value_stride = value.stride(0); - - DISPATCH_MACRO(key.scalar_type(), "reshape_and_cache_cpu_impl", [&] { - CPU_KERNEL_GUARD_IN(reshape_and_cache_cpu_impl) - reshape_and_cache_cpu_impl( - key.data_ptr(), value.data_ptr(), - key_cache.data_ptr(), value_cache.data_ptr(), - slot_mapping.data_ptr(), num_tokens, key_stride, value_stride, - num_heads, head_size, block_size, x); - CPU_KERNEL_GUARD_OUT(reshape_and_cache_cpu_impl) - }); -} - -void concat_and_cache_mla( - torch::Tensor& kv_c, // [num_tokens, kv_lora_rank] - torch::Tensor& k_pe, // [num_tokens, pe_dim] - torch::Tensor& kv_cache, // [num_blocks, block_size, (kv_lora_rank + - // pe_dim)] - torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens] - const std::string& kv_cache_dtype, torch::Tensor& scale) { - int num_tokens = slot_mapping.size(0); - int kv_lora_rank = kv_c.size(1); - int pe_dim = k_pe.size(1); - int block_size = kv_cache.size(1); - - TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim); - TORCH_CHECK(kv_cache_dtype != "fp8"); - - int kv_c_stride = kv_c.stride(0); - int k_pe_stride = k_pe.stride(0); - int block_stride = kv_cache.stride(0); - int entry_stride = kv_cache.stride(1); - - VLLM_DISPATCH_FLOATING_TYPES( - kv_c.scalar_type(), "concat_and_cache_mla_cpu_impl", [&] { - CPU_KERNEL_GUARD_IN(concat_and_cache_mla_cpu_impl) - concat_and_cache_mla_cpu_impl( - kv_c.data_ptr(), k_pe.data_ptr(), - kv_cache.data_ptr(), slot_mapping.data_ptr(), - num_tokens, block_stride, entry_stride, kv_c_stride, k_pe_stride, - kv_lora_rank, pe_dim, block_size); - CPU_KERNEL_GUARD_OUT(concat_and_cache_mla_cpu_impl) - }); -} - -void swap_blocks(torch::Tensor& src, torch::Tensor& dst, - const torch::Tensor& block_mapping) { - TORCH_CHECK(false, "swap_blocks is unsupported on CPU.") -} diff --git a/csrc/cpu/cpu_attn.cpp b/csrc/cpu/cpu_attn.cpp new file mode 100644 index 0000000000000..50f17c758c148 --- /dev/null +++ b/csrc/cpu/cpu_attn.cpp @@ -0,0 +1,249 @@ +#include "cpu_attn_vec.hpp" +#include "cpu_attn_vec16.hpp" + +#ifdef CPU_CAPABILITY_AMXBF16 + #include "cpu_attn_amx.hpp" + #define AMX_DISPATCH(...) \ + case cpu_attention::ISA::AMX: { \ + using attn_impl = cpu_attention::AttentionImpl; \ + return __VA_ARGS__(); \ + } +#else + #define AMX_DISPATCH(...) case cpu_attention::ISA::AMX: +#endif + +#define CPU_ATTN_DISPATCH_CASE(HEAD_DIM, ...) \ + case HEAD_DIM: { \ + constexpr size_t head_dim = HEAD_DIM; \ + return __VA_ARGS__(); \ + } + +#define CPU_ATTN_DISPATCH_CASE_HEADDIM(HEAD_DIM, ...) \ + [&] { \ + switch (HEAD_DIM) { \ + CPU_ATTN_DISPATCH_CASE(32, __VA_ARGS__) \ + CPU_ATTN_DISPATCH_CASE(64, __VA_ARGS__) \ + CPU_ATTN_DISPATCH_CASE(96, __VA_ARGS__) \ + CPU_ATTN_DISPATCH_CASE(128, __VA_ARGS__) \ + CPU_ATTN_DISPATCH_CASE(160, __VA_ARGS__) \ + CPU_ATTN_DISPATCH_CASE(192, __VA_ARGS__) \ + CPU_ATTN_DISPATCH_CASE(224, __VA_ARGS__) \ + CPU_ATTN_DISPATCH_CASE(256, __VA_ARGS__) \ + default: { \ + TORCH_CHECK(false, "Invalid CPU attention head_dim: " + \ + std::to_string(HEAD_DIM)); \ + } \ + } \ + }() + +#define CPU_ATTN_DISPATCH_IMPL(ISA_TYPE, ...) \ + [&] { \ + switch (ISA_TYPE) { \ + AMX_DISPATCH(__VA_ARGS__) \ + case cpu_attention::ISA::VEC: { \ + using attn_impl = \ + cpu_attention::AttentionImpl; \ + return __VA_ARGS__(); \ + } \ + case cpu_attention::ISA::VEC16: { \ + using attn_impl = \ + cpu_attention::AttentionImpl; \ + return __VA_ARGS__(); \ + } \ + default: { \ + TORCH_CHECK(false, "Invalid CPU attention ISA type."); \ + } \ + } \ + }() + +torch::Tensor get_scheduler_metadata( + const int64_t num_req, const int64_t num_heads_q, + const int64_t num_heads_kv, const int64_t head_dim, + const torch::Tensor& seq_lens, at::ScalarType dtype, + const torch::Tensor& query_start_loc, const bool casual, + const int64_t window_size, const std::string& isa_hint, + const bool enable_kv_split) { + cpu_attention::ISA isa; + if (isa_hint == "amx") { + isa = cpu_attention::ISA::AMX; + } else if (isa_hint == "vec") { + isa = cpu_attention::ISA::VEC; + } else if (isa_hint == "vec16") { + isa = cpu_attention::ISA::VEC16; + } else { + TORCH_CHECK(false, "Unsupported CPU attention ISA hint: " + isa_hint); + } + + cpu_attention::AttentionScheduler::ScheduleInput input; + input.num_reqs = num_req; + input.num_heads_q = num_heads_q; + input.num_heads_kv = num_heads_kv; + input.head_dim = head_dim; + input.query_start_loc = query_start_loc.data_ptr(); + input.seq_lens = seq_lens.data_ptr(); + if (window_size != -1) { + input.left_sliding_window_size = window_size - 1; + if (casual) { + input.right_sliding_window_size = 0; + } else { + input.right_sliding_window_size = window_size - 1; + } + } else { + input.left_sliding_window_size = -1; + if (casual) { + input.right_sliding_window_size = 0; + } else { + input.right_sliding_window_size = -1; + } + } + input.casual = casual; + input.isa = isa; + input.enable_kv_split = enable_kv_split; + TORCH_CHECK(casual, "Only supports casual mask for now."); + + VLLM_DISPATCH_FLOATING_TYPES(dtype, "get_scheduler_metadata", [&]() { + CPU_ATTN_DISPATCH_CASE_HEADDIM(head_dim, [&] { + CPU_ATTN_DISPATCH_IMPL(isa, [&]() { + input.elem_size = sizeof(scalar_t); + input.q_buffer_elem_size = sizeof(attn_impl::q_buffer_t); + input.logits_buffer_elem_size = sizeof(attn_impl::logits_buffer_t); + input.output_buffer_elem_size = + sizeof(attn_impl::partial_output_buffer_t); + input.max_num_q_per_iter = attn_impl::MaxQHeadNumPerIteration; + input.kv_block_alignment = attn_impl::BlockSizeAlignment; + }); + }); + }); + + cpu_attention::AttentionScheduler scheduler; + torch::Tensor metadata = scheduler.schedule(input); + return metadata; +} + +void cpu_attn_reshape_and_cache( + const torch::Tensor& key, // [token_num, head_num, head_size] + const torch::Tensor& value, // [token_num, head_num, head_size] + torch::Tensor& + key_cache, // [num_blocks, num_kv_heads, block_size, head_size] + torch::Tensor& + value_cache, // [num_blocks, num_kv_heads, block_size, head_size] + const torch::Tensor& slot_mapping, const std::string& isa) { + TORCH_CHECK_EQ(key.dim(), 3); + TORCH_CHECK_EQ(value.dim(), 3); + TORCH_CHECK_EQ(key_cache.dim(), 4); + TORCH_CHECK_EQ(value_cache.dim(), 4); + TORCH_CHECK_EQ(key.stride(2), 1); + TORCH_CHECK_EQ(value.stride(2), 1); + + const int64_t token_num = key.size(0); + const int64_t key_token_num_stride = key.stride(0); + const int64_t value_token_num_stride = value.stride(0); + const int64_t head_num = value.size(1); + const int64_t key_head_num_stride = key.stride(1); + const int64_t value_head_num_stride = value.stride(1); + const int64_t num_blocks = key_cache.size(0); + const int64_t num_blocks_stride = key_cache.stride(0); + const int64_t cache_head_num_stride = key_cache.stride(1); + const int64_t block_size = key_cache.size(2); + const int64_t block_size_stride = key_cache.stride(2); + const int64_t head_dim = key.size(-1); + + cpu_attention::ISA isa_tag = [&]() { + if (isa == "amx") { + return cpu_attention::ISA::AMX; + } else if (isa == "vec") { + return cpu_attention::ISA::VEC; + } else if (isa == "vec16") { + return cpu_attention::ISA::VEC16; + } else { + TORCH_CHECK(false, "Invalid ISA type: " + isa); + } + }(); + + VLLM_DISPATCH_FLOATING_TYPES( + key.scalar_type(), "cpu_attn_reshape_and_cache", [&]() { + CPU_ATTN_DISPATCH_CASE_HEADDIM(head_dim, [&] { + CPU_ATTN_DISPATCH_IMPL(isa_tag, [&]() { + attn_impl::reshape_and_cache( + key.data_ptr(), value.data_ptr(), + key_cache.data_ptr(), + value_cache.data_ptr(), + slot_mapping.data_ptr(), token_num, + key_token_num_stride, value_token_num_stride, head_num, + key_head_num_stride, value_head_num_stride, num_blocks, + num_blocks_stride, cache_head_num_stride, block_size, + block_size_stride); + }); + }); + }); +} + +void cpu_attention_with_kv_cache( + const torch::Tensor& query, // [num_tokens, num_heads, head_size] + const torch::Tensor& + key_cache, // [num_blocks, num_kv_heads, block_size, head_size] + const torch::Tensor& + value_cache, // [num_blocks, num_kv_heads, block_size, head_size] + torch::Tensor& output, // [num_tokens, num_heads, head_size] + const torch::Tensor& query_start_loc, // [num_tokens + 1] + const torch::Tensor& seq_lens, // [num_tokens] + const double scale, const bool causal, + const std::optional& alibi_slopes, // [num_heads] + const int64_t sliding_window_left, const int64_t sliding_window_right, + const torch::Tensor& block_table, // [num_tokens, max_block_num] + const double softcap, const torch::Tensor& scheduler_metadata, + const std::optional& s_aux // [num_heads] +) { + TORCH_CHECK_EQ(query.dim(), 3); + TORCH_CHECK_EQ(query.stride(2), 1); + TORCH_CHECK_EQ(key_cache.dim(), 4); + TORCH_CHECK_EQ(value_cache.dim(), 4); + + cpu_attention::AttentionInput input; + input.metadata = reinterpret_cast( + scheduler_metadata.data_ptr()); + input.num_tokens = query.size(0); + input.num_heads = query.size(1); + input.num_kv_heads = key_cache.size(1); + input.block_size = key_cache.size(2); + input.query = query.data_ptr(); + input.query_num_tokens_stride = query.stride(0); + input.query_num_heads_stride = query.stride(1); + input.cache_num_blocks_stride = key_cache.stride(0); + input.cache_num_kv_heads_stride = key_cache.stride(1); + input.blt_num_tokens_stride = block_table.stride(0); + input.key_cache = key_cache.data_ptr(); + input.value_cache = value_cache.data_ptr(); + input.output = output.data_ptr(); + input.query_start_loc = query_start_loc.data_ptr(); + input.seq_lens = seq_lens.data_ptr(); + input.block_table = block_table.data_ptr(); + input.alibi_slopes = + alibi_slopes.has_value() ? alibi_slopes->data_ptr() : nullptr; + // For now sink must be bf16 + input.s_aux = s_aux.has_value() ? s_aux->data_ptr() : nullptr; + input.scale = scale; + input.causal = causal; + input.sliding_window_left = sliding_window_left; + input.sliding_window_right = sliding_window_right; + if (input.causal) { + // to make boundary calculation easier + input.sliding_window_right = 0; + } + float softcap_fp32 = softcap; + input.softcap = softcap_fp32; + + VLLM_DISPATCH_FLOATING_TYPES( + query.scalar_type(), "cpu_attention_with_kv_cache", [&]() { + CPU_ATTN_DISPATCH_CASE_HEADDIM(query.size(2), [&] { + CPU_ATTN_DISPATCH_IMPL(input.metadata->isa, [&]() { + TORCH_CHECK_EQ(input.block_size % attn_impl::BlockSizeAlignment, 0); + cpu_attention::AttentionMainLoop mainloop; + mainloop(&input); + }); + }); + }); +} diff --git a/csrc/cpu/cpu_attn_amx.hpp b/csrc/cpu/cpu_attn_amx.hpp new file mode 100644 index 0000000000000..8da458b99119c --- /dev/null +++ b/csrc/cpu/cpu_attn_amx.hpp @@ -0,0 +1,511 @@ +#ifndef CPU_ATTN_AMX_HPP +#define CPU_ATTN_AMX_HPP + +#include "cpu_attn_impl.hpp" + +namespace cpu_attention { +namespace { +// AMX specific +constexpr static int64_t AMX_TILE_ROW_BYTES = 64; +constexpr static int64_t AMX_TILE_ROW_NUM = 16; +constexpr static int64_t AMX_TILE_BYTES = AMX_TILE_ROW_BYTES * AMX_TILE_ROW_NUM; + +typedef struct __tile_config { + uint8_t palette_id = 1; + uint8_t start_row = 0; + uint8_t reserved_0[14] = {0}; + uint16_t colsb[16] = {0}; + uint8_t rows[16] = {0}; +} __tilecfg; + +// 2-2-4 pattern, for 16 < m <= 32 +// TILE 0, 1: load A matrix, row num should be 16, m - 16 +// TILE 2, 3: load B matrix, row num should be 16 +// TILE 4, 5, 6, 7: store results C matrix, row num should be 16, 16, m - 16, m +// - 16 +template +class TileGemm224 { + public: + template + FORCE_INLINE static void gemm(const int32_t m_size, void* __restrict__ a_tile, + void* __restrict__ b_tile, + float* __restrict__ c_tile, const int64_t lda, + const int64_t ldb, const int64_t ldc, + const int32_t block_size, + const int32_t dynamic_k_size, + const bool accum_c) { + TORCH_CHECK(false, "Unsupported kv cache type for TileGemm224"); + } + + FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) { + TORCH_CHECK(false, "Unsupported kv cache type for TileGemm224"); + } +}; + +template <> +class TileGemm224 { + public: + template + FORCE_INLINE static void gemm(const int32_t m_size, + c10::BFloat16* __restrict__ a_tile, + c10::BFloat16* __restrict__ b_tile, + float* __restrict__ c_tile, const int64_t lda, + const int64_t ldb, const int64_t ldc, + const int32_t block_size, + const int32_t dynamic_k_size, + const bool accum_c) { + const int32_t k_times = + dynamic_k_size / (AMX_TILE_ROW_NUM * 4 / sizeof(c10::BFloat16)); + c10::BFloat16* __restrict__ a_tile_0 = a_tile; + c10::BFloat16* __restrict__ a_tile_1 = a_tile + lda * AMX_TILE_ROW_NUM; + const int64_t a_tile_stride = [&]() { + if constexpr (phase == AttentionGemmPhase::QK) { + // q_buffer is prepacked + return AMX_TILE_ROW_BYTES; + } else if constexpr (phase == AttentionGemmPhase::PV) { + // logits_buffer is row-major + return lda * sizeof(c10::BFloat16); + } else { + TORCH_CHECK(false, "Unreachable"); + } + }(); + + c10::BFloat16* __restrict__ b_tile_2 = b_tile; + c10::BFloat16* __restrict__ b_tile_3 = [&]() { + if constexpr (phase == AttentionGemmPhase::QK) { + // k_cache is prepacked + return b_tile + (k_size * AMX_TILE_ROW_BYTES / 4); + } else if constexpr (phase == AttentionGemmPhase::PV) { + // v_cache is prepacked + return b_tile + (block_size * AMX_TILE_ROW_BYTES / 4); + } else { + TORCH_CHECK(false, "Unreachable"); + } + }(); + // k_cache, v_cache are prepacked + const int32_t b_tile_stride = AMX_TILE_ROW_BYTES; + + // logits_buffer, output_buffer are not prepacked + float* __restrict__ c_tile_4 = c_tile; + float* __restrict__ c_tile_5 = + c_tile_4 + AMX_TILE_ROW_BYTES / sizeof(float); + float* __restrict__ c_tile_6 = c_tile + AMX_TILE_ROW_NUM * ldc; + float* __restrict__ c_tile_7 = + c_tile_6 + AMX_TILE_ROW_BYTES / sizeof(float); + const int32_t c_tile_stride = ldc * sizeof(float); + + if (accum_c) { + _tile_loadd(4, c_tile_4, c_tile_stride); + _tile_loadd(5, c_tile_5, c_tile_stride); + _tile_loadd(6, c_tile_6, c_tile_stride); + _tile_loadd(7, c_tile_7, c_tile_stride); + } else { + _tile_zero(4); + _tile_zero(5); + _tile_zero(6); + _tile_zero(7); + } + + for (int32_t k = 0; k < k_times; ++k) { + _tile_loadd(0, a_tile_0, a_tile_stride); + _tile_stream_loadd(2, b_tile_2, b_tile_stride); + _tile_dpbf16ps(4, 0, 2); + _tile_stream_loadd(3, b_tile_3, b_tile_stride); + _tile_dpbf16ps(5, 0, 3); + _tile_loadd(1, a_tile_1, a_tile_stride); + _tile_dpbf16ps(6, 1, 2); + _tile_dpbf16ps(7, 1, 3); + + // update ptrs + if constexpr (phase == AttentionGemmPhase::QK) { + // Q buffer is prepacked + a_tile_0 += AMX_TILE_BYTES / sizeof(c10::BFloat16); + a_tile_1 += AMX_TILE_BYTES / sizeof(c10::BFloat16); + } else if constexpr (phase == AttentionGemmPhase::PV) { + // P buffer is not prepacked + a_tile_0 += AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16); + a_tile_1 += AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16); + } else { + TORCH_CHECK(false, "Unreachable"); + } + b_tile_2 += AMX_TILE_BYTES / sizeof(c10::BFloat16); + b_tile_3 += AMX_TILE_BYTES / sizeof(c10::BFloat16); + } + + _tile_stored(4, c_tile_4, c_tile_stride); + _tile_stored(5, c_tile_5, c_tile_stride); + _tile_stored(6, c_tile_6, c_tile_stride); + _tile_stored(7, c_tile_7, c_tile_stride); + } + + FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) { + const int32_t m_0 = AMX_TILE_ROW_NUM; + const int32_t m_1 = m - AMX_TILE_ROW_NUM; + config.rows[0] = m_0; + config.rows[1] = m_1; + config.rows[2] = AMX_TILE_ROW_NUM; + config.rows[3] = AMX_TILE_ROW_NUM; + config.rows[4] = m_0; + config.rows[5] = m_0; + config.rows[6] = m_1; + config.rows[7] = m_1; + _tile_loadconfig(&config); + } +}; + +// 1-2-2 pattern, for 0 < m <= 16 +// TILE 0, (1): load A matrix, use extra 1 tile for prefetch, row num should be +// m, m +// TILE 2, 3, (4, 5): load B matrix, use extra 2 tiles for prefetch, row +// num should be 16 +// TILE 6, 7, (6, 7): store results C matrix, row num should be +// m +template +class TileGemm122 { + public: + template + FORCE_INLINE static void gemm(const int32_t m_size, void* __restrict__ a_tile, + void* __restrict__ b_tile, + float* __restrict__ c_tile, const int64_t lda, + const int64_t ldb, const int64_t ldc, + const int32_t block_size, + const int32_t dynamic_k_size, + const bool accum_c) { + TORCH_CHECK(false, "Unsupported kv cache type for TileGemm122"); + } + + FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) { + TORCH_CHECK(false, "Unsupported kv cache type for TileGemm122"); + } +}; + +template <> +class TileGemm122 { + public: + template + FORCE_INLINE static void gemm(const int32_t m_size, + c10::BFloat16* __restrict__ a_tile, + c10::BFloat16* __restrict__ b_tile, + float* __restrict__ c_tile, const int64_t lda, + const int64_t ldb, const int64_t ldc, + const int32_t block_size, + const int32_t dynamic_k_size, + const bool accum_c) { + c10::BFloat16* __restrict__ a_tile_0 = a_tile; + c10::BFloat16* __restrict__ a_tile_1 = [&]() { + if constexpr (phase == AttentionGemmPhase::QK) { + // q_buffer is prepacked + return a_tile + AMX_TILE_BYTES / sizeof(c10::BFloat16); + } else if constexpr (phase == AttentionGemmPhase::PV) { + // logits_buffer is row-major + return a_tile + AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16); + } else { + TORCH_CHECK(false, "Unreachable"); + } + }(); + const int64_t a_tile_stride = [&]() { + if constexpr (phase == AttentionGemmPhase::QK) { + // q_buffer is prepacked + return AMX_TILE_ROW_BYTES; + } else if constexpr (phase == AttentionGemmPhase::PV) { + // logits_buffer is row-major + return lda * sizeof(c10::BFloat16); + } else { + TORCH_CHECK(false, "Unreachable"); + } + }(); + + c10::BFloat16* __restrict__ b_tile_2 = b_tile; + c10::BFloat16* __restrict__ b_tile_3 = [&]() { + if constexpr (phase == AttentionGemmPhase::QK) { + // k_cache is prepacked + return b_tile + (k_size * AMX_TILE_ROW_BYTES / 4); + } else if constexpr (phase == AttentionGemmPhase::PV) { + // v_cache is prepacked + return b_tile + (block_size * AMX_TILE_ROW_BYTES / 4); + } else { + TORCH_CHECK(false, "Unreachable"); + } + }(); + c10::BFloat16* __restrict__ b_tile_4 = + b_tile_2 + AMX_TILE_BYTES / sizeof(c10::BFloat16); + c10::BFloat16* __restrict__ b_tile_5 = + b_tile_3 + AMX_TILE_BYTES / sizeof(c10::BFloat16); + int64_t b_stride = AMX_TILE_ROW_BYTES; + + float* __restrict__ c_tile_6 = c_tile; + float* __restrict__ c_tile_7 = c_tile + AMX_TILE_ROW_BYTES / sizeof(float); + int64_t c_stride = ldc * sizeof(float); + + const int32_t k_times = + dynamic_k_size / (AMX_TILE_ROW_NUM * 4 / sizeof(c10::BFloat16)); + const int32_t k_group_times = k_times / 2; + const bool has_tail = (k_times % 2 == 1); + + if (accum_c) { + _tile_loadd(6, c_tile_6, c_stride); + _tile_loadd(7, c_tile_7, c_stride); + } else { + _tile_zero(6); + _tile_zero(7); + } + + for (int32_t k = 0; k < k_group_times; ++k) { + _tile_loadd(0, a_tile_0, a_tile_stride); + _tile_stream_loadd(2, b_tile_2, b_stride); + _tile_dpbf16ps(6, 0, 2); + _tile_stream_loadd(3, b_tile_3, b_stride); + _tile_dpbf16ps(7, 0, 3); + _tile_loadd(1, a_tile_1, a_tile_stride); + _tile_stream_loadd(4, b_tile_4, b_stride); + _tile_dpbf16ps(6, 1, 4); + _tile_stream_loadd(5, b_tile_5, b_stride); + _tile_dpbf16ps(7, 1, 5); + + // update ptrs + if constexpr (phase == AttentionGemmPhase::QK) { + // Q buffer is prepacked + a_tile_0 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16); + a_tile_1 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16); + } else if constexpr (phase == AttentionGemmPhase::PV) { + // P buffer is not prepacked + a_tile_0 += 2 * AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16); + a_tile_1 += 2 * AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16); + } + b_tile_2 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16); + b_tile_3 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16); + b_tile_4 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16); + b_tile_5 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16); + } + + if (has_tail) { + _tile_loadd(0, a_tile_0, a_tile_stride); + _tile_stream_loadd(2, b_tile_2, b_stride); + _tile_dpbf16ps(6, 0, 2); + _tile_stream_loadd(3, b_tile_3, b_stride); + _tile_dpbf16ps(7, 0, 3); + } + + _tile_stored(6, c_tile_6, c_stride); + _tile_stored(7, c_tile_7, c_stride); + } + + FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) { + config.rows[0] = m; + config.rows[1] = m; + config.rows[2] = AMX_TILE_ROW_NUM; + config.rows[3] = AMX_TILE_ROW_NUM; + config.rows[4] = AMX_TILE_ROW_NUM; + config.rows[5] = AMX_TILE_ROW_NUM; + config.rows[6] = m; + config.rows[7] = m; + _tile_loadconfig(&config); + } +}; +} // namespace + +template +class AttentionImpl { + public: + using query_t = scalar_t; + using q_buffer_t = scalar_t; + using kv_cache_t = scalar_t; + using logits_buffer_t = float; + using partial_output_buffer_t = float; + using prob_buffer_t = scalar_t; + + constexpr static int64_t BlockSizeAlignment = + AMX_TILE_ROW_BYTES / + sizeof(kv_cache_t); // KV token num unit of QK and PV phases + constexpr static int64_t HeadDimAlignment = + 2 * (AMX_TILE_ROW_BYTES / 4); // headdim num unit of PV phase + constexpr static int64_t MaxQHeadNumPerIteration = 32; + constexpr static int64_t HeadDim = head_dim; + constexpr static ISA ISAType = ISA::AMX; + constexpr static bool scale_on_logits = true; + + public: + AttentionImpl() : current_q_head_num_(0) { + // Use all columns in AMX tiles + vec_op::unroll_loop([&](int i) { amx_tile_config_.colsb[i] = 64; }); + } + + ~AttentionImpl() { _tile_release(); } + + template