#ifndef DNNL_HELPER_HPP
#define DNNL_HELPER_HPP

#include <cstdint>
#include <type_traits>

#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>

#include "oneapi/dnnl/dnnl.hpp"

namespace {
template <typename T>
struct DNNLType {
  static constexpr dnnl::memory::data_type type =
      dnnl::memory::data_type::undef;
};

template <>
struct DNNLType<int8_t> {
  static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s8;
};

template <>
struct DNNLType<int32_t> {
  static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s32;
};

template <>
struct DNNLType<float> {
  static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f32;
};

template <>
struct DNNLType<c10::BFloat16> {
  static constexpr dnnl::memory::data_type type =
      dnnl::memory::data_type::bf16;
};

template <>
struct DNNLType<c10::Half> {
  static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f16;
};

template <typename T>
constexpr inline dnnl::memory::data_type get_dnnl_type() {
  return DNNLType<std::decay_t<T>>::type;
}
}  // namespace

template <bool InputNoScale>
class DNNLPrimitiveHelper {
 public:
  // I8 input GEMM kernel (C = a_scales * A @ (b_scales * B^T) + bias)
  // A: [M, K], row-major
  // B: [K, N], column-major
  // C: [M, N], row-major
  // bias: [N], row-major, optional
  // a_scales: [MS]
  // b_scales: [NS]
  // Note: Due to the limitation of oneDNN
  // (https://github.com/oneapi-src/oneDNN/issues/1636), the quantized bias is
  // not supported.
  template <typename OutputT, typename BiasT>
  static void gemm_s8s8_jit(const int8_t* a, const int8_t* b, OutputT* c,
                            const BiasT* bias, dnnl_dim_t M, dnnl_dim_t N,
                            dnnl_dim_t K, const float* a_scales,
                            const float* b_scales, dnnl_dim_t MS,
                            dnnl_dim_t NS) {
    auto&& OutputType = get_dnnl_type<OutputT>();
    auto&& BiasType = get_dnnl_type<BiasT>();

    dnnl::memory::desc a_md({M, K}, dnnl::memory::data_type::s8, {K, 1});
    dnnl::memory::desc b_md({K, N}, dnnl::memory::data_type::s8, {1, K});
    dnnl::memory::desc c_md({M, N}, OutputType, {N, 1});

    dnnl::primitive_attr attr;
    if constexpr (!InputNoScale) {
      if (MS == 1) {
        // per-tensor
        attr.set_scales_mask(DNNL_ARG_SRC, 0);
      } else {
        // per-token
        TORCH_CHECK(false, "per-token quantization is unsupported.");
      }
    }

    if (NS == 1) {
      // per-tensor
      attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0);
    } else {
      // per-channel
      attr.set_scales_mask(DNNL_ARG_WEIGHTS, 2);
    }

    dnnl::matmul::primitive_desc matmul_pd;
    // Create memory descriptors with format_tag::any for the primitive. This
    // enables the matmul primitive to choose memory layouts for an optimized
    // primitive implementation, and these layouts may differ from the ones
    // provided by the user.
#ifdef __aarch64__
    auto mat_src_md = dnnl::memory::desc({M, K}, dnnl::memory::data_type::s8,
                                         dnnl::memory::format_tag::any);
    auto mat_weights_md = dnnl::memory::desc(
        {K, N}, dnnl::memory::data_type::s8, dnnl::memory::format_tag::any);
    auto mat_dst_md =
        dnnl::memory::desc({M, N}, OutputType, dnnl::memory::format_tag::any);
    if (bias) {
      dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1});
      matmul_pd = dnnl::matmul::primitive_desc(default_engine(), mat_src_md,
                                               mat_weights_md, bias_md,
                                               mat_dst_md, attr);
    } else {
      matmul_pd = dnnl::matmul::primitive_desc(
          default_engine(), mat_src_md, mat_weights_md, mat_dst_md, attr);
    }
#else
    if (bias) {
      dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1});
      matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md,
                                               bias_md, c_md, attr);
    } else {
      matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md,
                                               c_md, attr);
    }
#endif

    dnnl::matmul matmul(matmul_pd);

    auto& engine = default_engine();

    dnnl::memory a_m(a_md, engine, (void*)a);
    dnnl::memory b_m(b_md, engine, (void*)b);
    dnnl::memory c_m(c_md, engine, (void*)c);
    dnnl::memory a_scales_m({{MS}, dnnl::memory::data_type::f32, {1}}, engine,
                            (void*)a_scales);
    dnnl::memory b_scales_m({{NS}, dnnl::memory::data_type::f32, {1}}, engine,
                            (void*)b_scales);

    auto& stream = default_stream();

    auto mat_src_mem = a_m;
    auto mat_weights_mem = b_m;
    auto mat_dst_mem = c_m;
#ifdef __aarch64__
    if (matmul_pd.weights_desc() != b_m.get_desc()) {
      mat_weights_mem = dnnl::memory(matmul_pd.weights_desc(), engine);
      dnnl::reorder(b_m, mat_weights_mem).execute(stream, b_m, mat_weights_mem);
    }
#endif

    if constexpr (InputNoScale) {
      if (bias) {
        dnnl::memory::desc bias_md({N}, BiasType, {1});
        dnnl::memory bias_m(bias_md, engine, (void*)bias);
        matmul.execute(
            stream, {
                        {DNNL_ARG_SRC, mat_src_mem},
                        {DNNL_ARG_WEIGHTS, mat_weights_mem},
                        {DNNL_ARG_BIAS, bias_m},
                        {DNNL_ARG_DST, mat_dst_mem},
                        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
                    });
      } else {
        matmul.execute(
            stream, {
                        {DNNL_ARG_SRC, mat_src_mem},
                        {DNNL_ARG_WEIGHTS, mat_weights_mem},
                        {DNNL_ARG_DST, mat_dst_mem},
                        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
                    });
      }
    } else {
      if (bias) {
        dnnl::memory::desc bias_md({N}, BiasType, {1});
        dnnl::memory bias_m(bias_md, engine, (void*)bias);
        matmul.execute(
            stream, {
                        {DNNL_ARG_SRC, mat_src_mem},
                        {DNNL_ARG_WEIGHTS, mat_weights_mem},
                        {DNNL_ARG_BIAS, bias_m},
                        {DNNL_ARG_DST, mat_dst_mem},
                        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m},
                        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
                    });
      } else {
        matmul.execute(
            stream, {
                        {DNNL_ARG_SRC, mat_src_mem},
                        {DNNL_ARG_WEIGHTS, mat_weights_mem},
                        {DNNL_ARG_DST, mat_dst_mem},
                        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m},
                        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
                    });
      }
    }
    stream.wait();
  }

 private:
  static dnnl::engine& default_engine() {
    static dnnl::engine engine(dnnl::engine::kind::cpu, 0);
    return engine;
  }

  static dnnl::stream& default_stream() {
    static dnnl::stream stream(default_engine());
    return stream;
  }
};

#endif
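
// Example usage (illustrative sketch only, not part of this header): the
// helper is called with the quantized int8 activation/weight buffers, the
// output buffer, and per-tensor scales. The names below (run_example,
// b_colmajor, dst, a_scale, b_scale) are hypothetical; B must be supplied as
// the [K, N] column-major view of the weights, i.e. the same buffer as an
// [N, K] row-major matrix.
//
//   #include "dnnl_helper.hpp"
//
//   void run_example(const int8_t* a, const int8_t* b_colmajor, float* dst,
//                    int64_t M, int64_t N, int64_t K, float a_scale,
//                    float b_scale) {
//     // InputNoScale = false: apply the per-tensor activation scale (MS = 1)
//     // and the per-tensor weight scale (NS = 1); no bias is added.
//     DNNLPrimitiveHelper<false>::gemm_s8s8_jit<float, float>(
//         a, b_colmajor, dst, /*bias=*/nullptr, M, N, K, &a_scale, &b_scale,
//         /*MS=*/1, /*NS=*/1);
//   }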