#include "cpu_attn_vec.hpp" #include "cpu_attn_vec16.hpp" #ifdef CPU_CAPABILITY_AMXBF16 #include "cpu_attn_amx.hpp" #define AMX_DISPATCH(...) \ case cpu_attention::ISA::AMX: { \ using attn_impl = cpu_attention::AttentionImpl; \ return __VA_ARGS__(); \ } #else #define AMX_DISPATCH(...) case cpu_attention::ISA::AMX: #endif #ifdef __aarch64__ #include "cpu_attn_neon.hpp" #define NEON_DISPATCH(...) \ case cpu_attention::ISA::NEON: { \ using attn_impl = cpu_attention::AttentionImpl; \ return __VA_ARGS__(); \ } #else #define NEON_DISPATCH(...) case cpu_attention::ISA::NEON: #endif // #ifdef __aarch64__ #define CPU_ATTN_DISPATCH_CASE(HEAD_DIM, ...) \ case HEAD_DIM: { \ constexpr size_t head_dim = HEAD_DIM; \ return __VA_ARGS__(); \ } #define CPU_ATTN_DISPATCH_CASE_HEADDIM(HEAD_DIM, ...) \ [&] { \ switch (HEAD_DIM) { \ CPU_ATTN_DISPATCH_CASE(32, __VA_ARGS__) \ CPU_ATTN_DISPATCH_CASE(64, __VA_ARGS__) \ CPU_ATTN_DISPATCH_CASE(96, __VA_ARGS__) \ CPU_ATTN_DISPATCH_CASE(128, __VA_ARGS__) \ CPU_ATTN_DISPATCH_CASE(160, __VA_ARGS__) \ CPU_ATTN_DISPATCH_CASE(192, __VA_ARGS__) \ CPU_ATTN_DISPATCH_CASE(224, __VA_ARGS__) \ CPU_ATTN_DISPATCH_CASE(256, __VA_ARGS__) \ default: { \ TORCH_CHECK(false, "Invalid CPU attention head_dim: " + \ std::to_string(HEAD_DIM)); \ } \ } \ }() #define CPU_ATTN_DISPATCH_IMPL(ISA_TYPE, ...) \ [&] { \ switch (ISA_TYPE) { \ AMX_DISPATCH(__VA_ARGS__) \ NEON_DISPATCH(__VA_ARGS__) \ case cpu_attention::ISA::VEC: { \ using attn_impl = \ cpu_attention::AttentionImpl; \ return __VA_ARGS__(); \ } \ case cpu_attention::ISA::VEC16: { \ using attn_impl = \ cpu_attention::AttentionImpl; \ return __VA_ARGS__(); \ } \ default: { \ TORCH_CHECK(false, "Invalid CPU attention ISA type."); \ } \ } \ }() torch::Tensor get_scheduler_metadata( const int64_t num_req, const int64_t num_heads_q, const int64_t num_heads_kv, const int64_t head_dim, const torch::Tensor& seq_lens, at::ScalarType dtype, const torch::Tensor& query_start_loc, const bool casual, const int64_t window_size, const std::string& isa_hint, const bool enable_kv_split) { cpu_attention::ISA isa; if (isa_hint == "amx") { isa = cpu_attention::ISA::AMX; } else if (isa_hint == "vec") { isa = cpu_attention::ISA::VEC; } else if (isa_hint == "vec16") { isa = cpu_attention::ISA::VEC16; } else if (isa_hint == "neon") { isa = cpu_attention::ISA::NEON; } else { TORCH_CHECK(false, "Unsupported CPU attention ISA hint: " + isa_hint); } cpu_attention::AttentionScheduler::ScheduleInput input; input.num_reqs = num_req; input.num_heads_q = num_heads_q; input.num_heads_kv = num_heads_kv; input.head_dim = head_dim; input.query_start_loc = query_start_loc.data_ptr(); input.seq_lens = seq_lens.data_ptr(); if (window_size != -1) { input.left_sliding_window_size = window_size - 1; if (casual) { input.right_sliding_window_size = 0; } else { input.right_sliding_window_size = window_size - 1; } } else { input.left_sliding_window_size = -1; if (casual) { input.right_sliding_window_size = 0; } else { input.right_sliding_window_size = -1; } } input.casual = casual; input.isa = isa; input.enable_kv_split = enable_kv_split; TORCH_CHECK(casual, "Only supports casual mask for now."); VLLM_DISPATCH_FLOATING_TYPES(dtype, "get_scheduler_metadata", [&]() { CPU_ATTN_DISPATCH_CASE_HEADDIM(head_dim, [&] { CPU_ATTN_DISPATCH_IMPL(isa, [&]() { input.elem_size = sizeof(scalar_t); input.q_buffer_elem_size = sizeof(attn_impl::q_buffer_t); input.logits_buffer_elem_size = sizeof(attn_impl::logits_buffer_t); input.output_buffer_elem_size = sizeof(attn_impl::partial_output_buffer_t); 
        input.max_num_q_per_iter = attn_impl::MaxQHeadNumPerIteration;
        input.kv_block_alignment = attn_impl::BlockSizeAlignment;
      });
    });
  });

  cpu_attention::AttentionScheduler scheduler;
  torch::Tensor metadata = scheduler.schedule(input);
  return metadata;
}

void cpu_attn_reshape_and_cache(
    const torch::Tensor& key,    // [token_num, head_num, head_size]
    const torch::Tensor& value,  // [token_num, head_num, head_size]
    torch::Tensor& key_cache,    // [num_blocks, num_kv_heads, block_size, head_size]
    torch::Tensor& value_cache,  // [num_blocks, num_kv_heads, block_size, head_size]
    const torch::Tensor& slot_mapping, const std::string& isa) {
  TORCH_CHECK_EQ(key.dim(), 3);
  TORCH_CHECK_EQ(value.dim(), 3);
  TORCH_CHECK_EQ(key_cache.dim(), 4);
  TORCH_CHECK_EQ(value_cache.dim(), 4);
  TORCH_CHECK_EQ(key.stride(2), 1);
  TORCH_CHECK_EQ(value.stride(2), 1);

  const int64_t token_num = key.size(0);
  const int64_t key_token_num_stride = key.stride(0);
  const int64_t value_token_num_stride = value.stride(0);
  const int64_t head_num = value.size(1);
  const int64_t key_head_num_stride = key.stride(1);
  const int64_t value_head_num_stride = value.stride(1);
  const int64_t num_blocks = key_cache.size(0);
  const int64_t num_blocks_stride = key_cache.stride(0);
  const int64_t cache_head_num_stride = key_cache.stride(1);
  const int64_t block_size = key_cache.size(2);
  const int64_t block_size_stride = key_cache.stride(2);
  const int64_t head_dim = key.size(-1);

  cpu_attention::ISA isa_tag = [&]() {
    if (isa == "amx") {
      return cpu_attention::ISA::AMX;
    } else if (isa == "vec") {
      return cpu_attention::ISA::VEC;
    } else if (isa == "vec16") {
      return cpu_attention::ISA::VEC16;
    } else if (isa == "neon") {
      return cpu_attention::ISA::NEON;
    } else {
      TORCH_CHECK(false, "Invalid ISA type: " + isa);
    }
  }();

  VLLM_DISPATCH_FLOATING_TYPES(
      key.scalar_type(), "cpu_attn_reshape_and_cache", [&]() {
        CPU_ATTN_DISPATCH_CASE_HEADDIM(head_dim, [&] {
          CPU_ATTN_DISPATCH_IMPL(isa_tag, [&]() {
            attn_impl::reshape_and_cache(
                key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
                key_cache.data_ptr<scalar_t>(),
                value_cache.data_ptr<scalar_t>(),
                slot_mapping.data_ptr<int64_t>(), token_num,
                key_token_num_stride, value_token_num_stride, head_num,
                key_head_num_stride, value_head_num_stride, num_blocks,
                num_blocks_stride, cache_head_num_stride, block_size,
                block_size_stride);
          });
        });
      });
}

void cpu_attention_with_kv_cache(
    const torch::Tensor& query,        // [num_tokens, num_heads, head_size]
    const torch::Tensor& key_cache,    // [num_blocks, num_kv_heads, block_size, head_size]
    const torch::Tensor& value_cache,  // [num_blocks, num_kv_heads, block_size, head_size]
    torch::Tensor& output,             // [num_tokens, num_heads, head_size]
    const torch::Tensor& query_start_loc,  // [num_reqs + 1]
    const torch::Tensor& seq_lens,         // [num_reqs]
    const double scale, const bool causal,
    const std::optional<torch::Tensor>& alibi_slopes,  // [num_heads]
    const int64_t sliding_window_left, const int64_t sliding_window_right,
    const torch::Tensor& block_table,  // [num_reqs, max_block_num]
    const double softcap, const torch::Tensor& scheduler_metadata,
    const std::optional<torch::Tensor>& s_aux  // [num_heads]
) {
  TORCH_CHECK_EQ(query.dim(), 3);
  TORCH_CHECK_EQ(query.stride(2), 1);
  TORCH_CHECK_EQ(key_cache.dim(), 4);
  TORCH_CHECK_EQ(value_cache.dim(), 4);

  cpu_attention::AttentionInput input;
  // The metadata tensor produced by get_scheduler_metadata is assumed to
  // hold a cpu_attention::ScheduleMetadata blob.
  input.metadata = reinterpret_cast<cpu_attention::ScheduleMetadata*>(
      scheduler_metadata.data_ptr());
  input.num_tokens = query.size(0);
  input.num_heads = query.size(1);
  input.num_kv_heads = key_cache.size(1);
  input.block_size = key_cache.size(2);
  input.query = query.data_ptr();
  input.query_num_tokens_stride = query.stride(0);
  input.query_num_heads_stride = query.stride(1);
  input.cache_num_blocks_stride = key_cache.stride(0);
  input.cache_num_kv_heads_stride = key_cache.stride(1);
  input.blt_num_tokens_stride = block_table.stride(0);
  input.key_cache = key_cache.data_ptr();
  input.value_cache = value_cache.data_ptr();
  input.output = output.data_ptr();
  // query_start_loc, seq_lens, and block_table are assumed to be int32
  // tensors.
  input.query_start_loc = query_start_loc.data_ptr<int32_t>();
  input.seq_lens = seq_lens.data_ptr<int32_t>();
  input.block_table = block_table.data_ptr<int32_t>();
  input.alibi_slopes =
      alibi_slopes.has_value() ? alibi_slopes->data_ptr<float>() : nullptr;
  // For now the sink must be bf16.
  input.s_aux = s_aux.has_value() ? s_aux->data_ptr<at::BFloat16>() : nullptr;
  input.scale = scale;
  input.causal = causal;
  input.sliding_window_left = sliding_window_left;
  input.sliding_window_right = sliding_window_right;
  if (input.causal) {
    // A causal mask is equivalent to a right window of 0; pinning it here
    // makes the boundary calculation easier.
    input.sliding_window_right = 0;
  }
  input.softcap = static_cast<float>(softcap);

  VLLM_DISPATCH_FLOATING_TYPES(
      query.scalar_type(), "cpu_attention_with_kv_cache", [&]() {
        CPU_ATTN_DISPATCH_CASE_HEADDIM(query.size(2), [&] {
          CPU_ATTN_DISPATCH_IMPL(input.metadata->isa, [&]() {
            TORCH_CHECK_EQ(input.block_size % attn_impl::BlockSizeAlignment,
                           0);
            // The main loop is instantiated with the dispatched kernel
            // implementation selected above.
            cpu_attention::AttentionMainLoop<attn_impl> mainloop;
            mainloop(&input);
          });
        });
      });
}
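
// Example (illustrative sketch only; tensors such as `query`, `key_cache`,
// and `slot_mapping` are hypothetical placeholders): the expected call
// sequence for one forward pass with bf16 data, head_dim 128, and the
// "vec" ISA hint. Schedule once per batch shape, write the new KV entries
// into the paged cache, then run the attention main loop.
//
//   torch::Tensor metadata = get_scheduler_metadata(
//       num_reqs, /*num_heads_q=*/32, /*num_heads_kv=*/8, /*head_dim=*/128,
//       seq_lens, at::kBFloat16, query_start_loc, /*casual=*/true,
//       /*window_size=*/-1, /*isa_hint=*/"vec", /*enable_kv_split=*/true);
//   cpu_attn_reshape_and_cache(key, value, key_cache, value_cache,
//                              slot_mapping, /*isa=*/"vec");
//   cpu_attention_with_kv_cache(
//       query, key_cache, value_cache, output, query_start_loc, seq_lens,
//       /*scale=*/0.088388 /* 1/sqrt(128) */, /*causal=*/true,
//       /*alibi_slopes=*/std::nullopt, /*sliding_window_left=*/-1,
//       /*sliding_window_right=*/-1, block_table, /*softcap=*/0.0,
//       metadata, /*s_aux=*/std::nullopt);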