#include "cpu_types.hpp" #include "scratchpad_manager.h" #include "utils.hpp" #ifdef CPU_CAPABILITY_AMXBF16 #include "cpu/micro_gemm/cpu_micro_gemm_amx.hpp" #endif #include "cpu/micro_gemm/cpu_micro_gemm_vec.hpp" #define VLLM_DISPATCH_CASE_16B_TYPES(...) \ AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) #define VLLM_DISPATCH_16B_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_16B_TYPES(__VA_ARGS__)) template void print_logits(const char* name, T* ptr, int32_t row, int32_t col, int32_t stride) { std::stringstream ss; ss << std::fixed << std::setprecision(5) << name << ": [\n"; auto* curr_logits_buffer = ptr; for (int32_t m = 0; m < row; ++m) { for (int32_t n = 0; n < col; ++n) { ss << curr_logits_buffer[n] << ", "; } ss << "\n"; curr_logits_buffer += stride; } ss << "]\n"; std::printf("%s", ss.str().c_str()); } namespace { using cpu_utils::ISA; using cpu_utils::VecTypeTrait; template class Dequantizer4b { public: constexpr static int32_t pack_num = 32 / 4; using scalar_vec_t = typename VecTypeTrait::vec_t; public: static void dequant(int32_t* __restrict__ q_weight, scalar_t* __restrict__ weight, scalar_t* __restrict__ scales, int32_t* __restrict__ zeros, int32_t* __restrict__ g_idx, const int64_t scales_stride, const int64_t zeros_stride, const int32_t k_size, const int32_t group_size) { vec_op::FP32Vec16 lut; if constexpr (has_zp) { // AWQ alignas(64) static const float LUT[16] = { 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f}; lut = vec_op::FP32Vec16(LUT); } else { // GPTQ alignas(64) static const float LUT[16] = { -8.0f, -7.0f, -6.0f, -5.0f, -4.0f, -3.0f, -2.0f, -1.0f, 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}; lut = vec_op::FP32Vec16(LUT); } // per 64-bits elem contains 16 output channels int64_t* __restrict__ curr_q_weight = reinterpret_cast(q_weight); int64_t* __restrict__ curr_zeros = reinterpret_cast(zeros); scalar_t* __restrict__ curr_weight = weight; scalar_t* __restrict__ curr_scale = scales; vec_op::FP32Vec16 scale_0; vec_op::FP32Vec16 scale_1; vec_op::FP32Vec16 zero_0; vec_op::FP32Vec16 zero_1; int32_t group_counter = 0; for (int32_t k_idx = 0; k_idx < k_size; k_idx += 2) { int64_t qwb_0 = *curr_q_weight; int64_t qwb_1 = *(curr_q_weight + 1); vec_op::FP32Vec16 wb_0(qwb_0, lut); vec_op::FP32Vec16 wb_1(qwb_1, lut); if constexpr (!use_desc_act) { if (group_counter == 0) { scale_0 = vec_op::FP32Vec16(scalar_vec_t(curr_scale)); scale_1 = vec_op::FP32Vec16(scale_0); curr_scale += scales_stride; if constexpr (has_zp) { zero_0 = vec_op::FP32Vec16(*curr_zeros, lut); zero_1 = vec_op::FP32Vec16(zero_0); curr_zeros += zeros_stride / 2; } } } else { int32_t g_idx_0 = g_idx[k_idx]; int32_t g_idx_1 = g_idx[k_idx + 1]; scale_0 = vec_op::FP32Vec16( scalar_vec_t(curr_scale + g_idx_0 * scales_stride)); scale_1 = vec_op::FP32Vec16( scalar_vec_t(curr_scale + g_idx_1 * scales_stride)); if constexpr (has_zp) { zero_0 = vec_op::FP32Vec16(*(curr_zeros + g_idx_0 * zeros_stride / 2), lut); zero_1 = vec_op::FP32Vec16(*(curr_zeros + g_idx_1 * zeros_stride / 2), lut); } } if constexpr (has_zp) { wb_0 = wb_0 - zero_0; wb_1 = wb_1 - zero_1; } wb_0 = wb_0 * scale_0; wb_1 = wb_1 * scale_1; scalar_vec_t output_vec_0(wb_0); scalar_vec_t output_vec_1(wb_1); // AMX needs to interlave K elements to pack as 32 bits if constexpr (isa == ISA::AMX) { vec_op::interleave_save(output_vec_0, output_vec_1, curr_weight); } else { output_vec_0.save(curr_weight); 
        output_vec_1.save(curr_weight + 16);
      }

      // update
      curr_q_weight += 2;
      curr_weight += 32;
      if constexpr (!use_desc_act) {
        group_counter += 2;
        if (group_counter == group_size) {
          group_counter = 0;
        }
      }
    }
  }
};
}  // namespace

// Dequantizes one N-partition of B into a thread-local buffer, then runs the
// tiled micro-GEMM over it. Template parameters are supplied by the dispatch
// code in cpu_gemm_wna16 below.
template <typename scalar_t, typename gemm_t, typename dequantizer_t>
void cpu_gemm_wna16_impl(
    scalar_t* __restrict__ input, int32_t* __restrict__ q_weight,
    scalar_t* __restrict__ output, scalar_t* __restrict__ scales,
    int32_t* __restrict__ zeros, int32_t* __restrict__ g_idx,
    scalar_t* __restrict__ bias, const int32_t m_size, const int32_t n_size,
    const int32_t k_size, const int64_t input_stride,
    const int64_t output_stride, const int64_t scales_group_stride,
    const int64_t zeros_group_stride, const int32_t group_num,
    const int32_t group_size, const int64_t pack_factor) {
  constexpr int32_t gemm_n_tile_size = gemm_t::NSize;
  constexpr int32_t gemm_m_tile_size = gemm_t::MaxMSize;
  constexpr int32_t n_block_size = 16;
  static_assert(gemm_n_tile_size % n_block_size == 0);
  const int32_t thread_num = omp_get_max_threads();

  // A simple schedule policy: partition N so that the B tiles of one task fit
  // in L2 while every thread still gets a task.
  const int32_t n_partition_size = [&]() {
    const int64_t cache_size = cpu_utils::get_l2_size();
    int64_t ps_cache_limit = cache_size / (k_size * sizeof(scalar_t));
    int64_t ps_thread_limit = n_size / thread_num;
    ps_cache_limit =
        std::max((ps_cache_limit / gemm_n_tile_size) * gemm_n_tile_size,
                 (int64_t)gemm_n_tile_size);
    ps_thread_limit =
        std::max((ps_thread_limit / gemm_n_tile_size) * gemm_n_tile_size,
                 (int64_t)gemm_n_tile_size);
    return std::min(ps_cache_limit, ps_thread_limit);
  }();
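
  // Worked example (hypothetical numbers): with a 2 MiB L2, K = 4096 and
  // 2-byte (bf16/fp16) B elements, ps_cache_limit = 2097152 / (4096 * 2) =
  // 256 columns; with N = 4096 and 16 threads, ps_thread_limit = 4096 / 16 =
  // 256 as well, both before rounding down to a multiple of gemm_n_tile_size.
  // Each task then dequantizes a ~256-column B panel and reuses it from L2.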
  const int32_t task_num = (n_size + n_partition_size - 1) / n_partition_size;

  // get buffer sizes, rounded up to 64-byte cache lines
  const int64_t b_buffer_size =
      (((n_partition_size * k_size * sizeof(scalar_t) + 63) / 64) * 64);
  const int64_t c_buffer_size =
      (((gemm_m_tile_size * gemm_n_tile_size * sizeof(float) + 63) / 64) * 64);
  const int64_t b_buffer_offset = 0;
  const int64_t c_buffer_offset = b_buffer_size;
  const int64_t buffer_size = b_buffer_size + c_buffer_size;
  DNNLScratchPadManager::get_dnnl_scratchpad_manager()->realloc(buffer_size *
                                                                thread_num);

  alignas(64) cpu_utils::Counter counter;
  cpu_utils::Counter* counter_ptr = &counter;

#pragma omp parallel for schedule(static, 1)
  for (int32_t thread_id = 0; thread_id < thread_num; ++thread_id) {
    scalar_t* __restrict__ b_buffer = nullptr;
    float* __restrict__ c_buffer = nullptr;
    {
      uint8_t* buffer_ptr =
          DNNLScratchPadManager::get_dnnl_scratchpad_manager()
              ->get_data<uint8_t>() +
          thread_id * buffer_size;
      b_buffer = reinterpret_cast<scalar_t*>(buffer_ptr + b_buffer_offset);
      c_buffer = reinterpret_cast<float*>(buffer_ptr + c_buffer_offset);
    }

    const int64_t q_weight_block_stride = n_block_size / pack_factor * k_size;
    const int64_t b_buffer_block_stride = n_block_size * k_size;
    const int32_t zeros_block_stride = n_block_size / pack_factor;

    gemm_t gemm;
    for (;;) {
      int32_t task_id = counter_ptr->acquire_counter();
      if (task_id >= task_num) {
        break;
      }

      const int32_t n_start_idx = task_id * n_partition_size;
      const int32_t n_block_start_idx = n_start_idx / n_block_size;
      const int32_t n_num = std::min(n_partition_size, n_size - n_start_idx);
      const int32_t n_block_num = n_num / n_block_size;
      // std::printf("thread_id: %d, task_id: %d, n_start_idx: %d, n_num: %d\n",
      //             thread_id, task_id, n_start_idx, n_num);

      // dequant weight
      {
        int32_t* __restrict__ curr_q_weight =
            q_weight + n_block_start_idx * q_weight_block_stride;
        scalar_t* __restrict__ curr_b_buffer = b_buffer;
        scalar_t* __restrict__ curr_scales = scales + n_start_idx;
        int32_t* __restrict__ curr_zeros = zeros + n_start_idx / pack_factor;
        for (int32_t block_idx = 0; block_idx < n_block_num; ++block_idx) {
          dequantizer_t::dequant(curr_q_weight, curr_b_buffer, curr_scales,
                                 curr_zeros, g_idx, scales_group_stride,
                                 zeros_group_stride, k_size, group_size);
          // if (block_idx == 0 && n_start_idx == 0) {
          //   print_logits("depacked weight", curr_b_buffer, k_size,
          //                n_block_size, n_block_size);
          // }

          // update
          curr_q_weight += q_weight_block_stride;
          curr_b_buffer += b_buffer_block_stride;
          curr_scales += n_block_size;
          curr_zeros += zeros_block_stride;
        }
      }

      // compute loop
      {
        const int32_t n_tile_num = n_num / gemm_n_tile_size;
        scalar_t* __restrict__ curr_input = input;
        scalar_t* __restrict__ init_bias = bias;
        if (bias != nullptr) {
          init_bias += n_start_idx;
        }
        scalar_t* __restrict__ init_output = output + n_start_idx;
        for (int32_t m_idx = 0; m_idx < m_size; m_idx += gemm_m_tile_size) {
          const int32_t curr_m_size =
              std::min(gemm_m_tile_size, m_size - m_idx);
          scalar_t* __restrict__ curr_b_buffer = b_buffer;
          scalar_t* __restrict__ curr_bias = init_bias;
          scalar_t* __restrict__ curr_output = init_output;
          for (int32_t n_tile_idx = 0; n_tile_idx < n_tile_num; ++n_tile_idx) {
            gemm.gemm(curr_input, curr_b_buffer, c_buffer, curr_m_size, k_size,
                      input_stride, b_buffer_block_stride, gemm_n_tile_size,
                      false);
            if (bias != nullptr) {
              cpu_micro_gemm::bias_epilogue(c_buffer, curr_output, curr_bias,
                                            curr_m_size, gemm_n_tile_size,
                                            output_stride);
              curr_bias += gemm_n_tile_size;
            } else {
              cpu_micro_gemm::default_epilogue(c_buffer, curr_output,
                                               curr_m_size, gemm_n_tile_size,
                                               output_stride);
            }
            curr_b_buffer +=
                b_buffer_block_stride * (gemm_n_tile_size / n_block_size);
            curr_output += gemm_n_tile_size;
          }
          curr_input += gemm_m_tile_size * input_stride;
          init_output += gemm_m_tile_size * output_stride;
        }
      }
    }
  }
}

void cpu_gemm_wna16(
    const torch::Tensor& input,     // [M, K]
    const torch::Tensor& q_weight,  // [N / 16, K * 16 / pack_factor],
                                    // packed as int32
    torch::Tensor& output,          // [M, N]
    const torch::Tensor& scales,    // [group_num, N]
    const std::optional<torch::Tensor>& zeros,  // [group_num, N / pack_factor],
                                                // packed as int32
    const std::optional<torch::Tensor>& g_idx,  // [K]
    const std::optional<torch::Tensor>& bias,   // [N]
    const int64_t pack_factor, const std::string& isa_hint) {
  using cpu_utils::ISA;
  TORCH_CHECK_EQ(pack_factor, 8);  // only 4-bit weights are supported
  const int32_t a_m_size = input.size(0);
  const int32_t a_k_size = input.size(1);
  const int64_t a_m_stride = input.stride(0);
  const int32_t b_n_size = q_weight.size(0) * 16;
  TORCH_CHECK_EQ(a_k_size % 32, 0);
  TORCH_CHECK_EQ(b_n_size % 32, 0);

  const int32_t group_num = scales.size(0);
  const int32_t group_size = a_k_size / group_num;
  TORCH_CHECK_EQ(group_size % 2, 0);

  const int64_t scales_group_stride = scales.stride(0);
  const int64_t output_m_stride = output.stride(0);

  bool has_zp = zeros.has_value();
  bool use_desc_act = g_idx.has_value();
  // zero-points (AWQ) and desc_act reordering (GPTQ) are mutually exclusive
  TORCH_CHECK(!(has_zp && use_desc_act));
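
  // Shape example (hypothetical sizes): for K = 4096, N = 11008 with 4-bit
  // weights and group_size = 128, q_weight is [11008 / 16, 4096 * 16 / 8] =
  // [688, 8192] int32, scales is [32, 11008], and AWQ zeros would be
  // [32, 11008 / 8].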
  ISA isa = [&]() {
    if (isa_hint == "amx") {
      return ISA::AMX;
    } else if (isa_hint == "vec") {
      return ISA::VEC;
    } else {
      TORCH_CHECK(false, "unsupported isa hint: " + isa_hint);
    }
  }();

  int32_t* zeros_ptr = has_zp ? zeros->data_ptr<int32_t>() : nullptr;
  const int64_t zeros_group_stride = has_zp ? zeros->stride(0) : 0;
  int32_t* g_idx_ptr = use_desc_act ? g_idx->data_ptr<int32_t>() : nullptr;

  VLLM_DISPATCH_16B_TYPES(input.scalar_type(), "cpu_gemm_wna16", [&]() {
    if (isa == ISA::AMX) {
      using gemm_t = cpu_micro_gemm::MicroGemm<scalar_t, ISA::AMX>;
      if (has_zp) {
        using dequantizer_t = Dequantizer4b<scalar_t, ISA::AMX, true, false>;
        cpu_gemm_wna16_impl<scalar_t, gemm_t, dequantizer_t>(
            input.data_ptr<scalar_t>(), q_weight.data_ptr<int32_t>(),
            output.data_ptr<scalar_t>(), scales.data_ptr<scalar_t>(),
            zeros_ptr, g_idx_ptr,
            bias.has_value() ? bias->data_ptr<scalar_t>() : nullptr, a_m_size,
            b_n_size, a_k_size, a_m_stride, output_m_stride,
            scales_group_stride, zeros_group_stride, group_num, group_size,
            pack_factor);
        return;
      }
      if (use_desc_act) {
        using dequantizer_t = Dequantizer4b<scalar_t, ISA::AMX, false, true>;
        cpu_gemm_wna16_impl<scalar_t, gemm_t, dequantizer_t>(
            input.data_ptr<scalar_t>(), q_weight.data_ptr<int32_t>(),
            output.data_ptr<scalar_t>(), scales.data_ptr<scalar_t>(),
            zeros_ptr, g_idx_ptr,
            bias.has_value() ? bias->data_ptr<scalar_t>() : nullptr, a_m_size,
            b_n_size, a_k_size, a_m_stride, output_m_stride,
            scales_group_stride, zeros_group_stride, group_num, group_size,
            pack_factor);
        return;
      } else {
        using dequantizer_t = Dequantizer4b<scalar_t, ISA::AMX, false, false>;
        cpu_gemm_wna16_impl<scalar_t, gemm_t, dequantizer_t>(
            input.data_ptr<scalar_t>(), q_weight.data_ptr<int32_t>(),
            output.data_ptr<scalar_t>(), scales.data_ptr<scalar_t>(),
            zeros_ptr, g_idx_ptr,
            bias.has_value() ? bias->data_ptr<scalar_t>() : nullptr, a_m_size,
            b_n_size, a_k_size, a_m_stride, output_m_stride,
            scales_group_stride, zeros_group_stride, group_num, group_size,
            pack_factor);
        return;
      }
    } else if (isa == ISA::VEC) {
      using gemm_t = cpu_micro_gemm::MicroGemm<scalar_t, ISA::VEC>;
      if (has_zp) {
        using dequantizer_t = Dequantizer4b<scalar_t, ISA::VEC, true, false>;
        cpu_gemm_wna16_impl<scalar_t, gemm_t, dequantizer_t>(
            input.data_ptr<scalar_t>(), q_weight.data_ptr<int32_t>(),
            output.data_ptr<scalar_t>(), scales.data_ptr<scalar_t>(),
            zeros_ptr, g_idx_ptr,
            bias.has_value() ? bias->data_ptr<scalar_t>() : nullptr, a_m_size,
            b_n_size, a_k_size, a_m_stride, output_m_stride,
            scales_group_stride, zeros_group_stride, group_num, group_size,
            pack_factor);
        return;
      }
      if (use_desc_act) {
        using dequantizer_t = Dequantizer4b<scalar_t, ISA::VEC, false, true>;
        cpu_gemm_wna16_impl<scalar_t, gemm_t, dequantizer_t>(
            input.data_ptr<scalar_t>(), q_weight.data_ptr<int32_t>(),
            output.data_ptr<scalar_t>(), scales.data_ptr<scalar_t>(),
            zeros_ptr, g_idx_ptr,
            bias.has_value() ? bias->data_ptr<scalar_t>() : nullptr, a_m_size,
            b_n_size, a_k_size, a_m_stride, output_m_stride,
            scales_group_stride, zeros_group_stride, group_num, group_size,
            pack_factor);
        return;
      } else {
        using dequantizer_t = Dequantizer4b<scalar_t, ISA::VEC, false, false>;
        cpu_gemm_wna16_impl<scalar_t, gemm_t, dequantizer_t>(
            input.data_ptr<scalar_t>(), q_weight.data_ptr<int32_t>(),
            output.data_ptr<scalar_t>(), scales.data_ptr<scalar_t>(),
            zeros_ptr, g_idx_ptr,
            bias.has_value() ? bias->data_ptr<scalar_t>() : nullptr, a_m_size,
            b_n_size, a_k_size, a_m_stride, output_m_stride,
            scales_group_stride, zeros_group_stride, group_num, group_size,
            pack_factor);
        return;
      }
    }
  });
}
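
// Example invocation (hypothetical shapes; op registration/bindings live
// elsewhere):
//   torch::Tensor out = torch::empty({M, N}, input.options());
//   cpu_gemm_wna16(input, q_weight, out, scales, /*zeros=*/std::nullopt,
//                  /*g_idx=*/std::nullopt, /*bias=*/std::nullopt,
//                  /*pack_factor=*/8, /*isa_hint=*/"amx");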