From c7753a9809344cd8eefbda6d472cb9ab348f2274 Mon Sep 17 00:00:00 2001
From: nishith-fujitsu <139734058+nishith-fujitsu@users.noreply.github.com>
Date: Thu, 10 Jul 2025 21:29:04 +0530
Subject: [PATCH] [Hardware][CPU] vLLM int8 quantization enablement for ARM
 CPU (#14129)

Signed-off-by: nishith-fujitsu
---
 cmake/cpu_extension.cmake   |  28 +++-
 csrc/cpu/cpu_types_arm.hpp  | 267 +++++++++++++++++++++++++++++++++++-
 csrc/cpu/dnnl_helper.hpp    |  58 ++++++--
 csrc/cpu/quant.cpp          |  21 +--
 csrc/cpu/torch_bindings.cpp |   3 +-
 5 files changed, 347 insertions(+), 30 deletions(-)

diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake
index fc7291972309a..21fcee66d6030 100644
--- a/cmake/cpu_extension.cmake
+++ b/cmake/cpu_extension.cmake
@@ -165,17 +165,32 @@ else()
 endif()
 
 #
-# Build oneDNN for W8A8 GEMM kernels (only for x86-AVX512 platforms)
-#
-if (AVX512_FOUND AND NOT AVX512_DISABLED)
+# Build oneDNN for W8A8 GEMM kernels (only for x86-AVX512/ARM platforms)
+# Flag to enable ACL kernels for AARCH64 platforms
+if (VLLM_BUILD_ACL STREQUAL "ON")
+  set(USE_ACL ON)
+else()
+  set(USE_ACL OFF)
+endif()
+
+if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND)
   FetchContent_Declare(
     oneDNN
     GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git
-    GIT_TAG v3.7.1
+    GIT_TAG v3.8.1
     GIT_PROGRESS TRUE
     GIT_SHALLOW TRUE
   )
 
+  if(USE_ACL)
+    find_library(ARM_COMPUTE_LIBRARY NAMES arm_compute PATHS $ENV{ACL_ROOT_DIR}/build/)
+    if(NOT ARM_COMPUTE_LIBRARY)
+      message(FATAL_ERROR "Could not find ARM Compute Library: please set ACL_ROOT_DIR")
+    endif()
+    set(ONEDNN_AARCH64_USE_ACL "ON")
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ENV{ACL_ROOT_DIR}/build/")
+  endif()
+
   set(ONEDNN_LIBRARY_TYPE "STATIC")
   set(ONEDNN_BUILD_DOC "OFF")
   set(ONEDNN_BUILD_EXAMPLES "OFF")
@@ -264,6 +279,11 @@ elseif(POWER10_FOUND)
     "csrc/cpu/quant.cpp"
     ${VLLM_EXT_SRC})
 endif()
+if (ASIMD_FOUND)
+  set(VLLM_EXT_SRC
+    "csrc/cpu/quant.cpp"
+    ${VLLM_EXT_SRC})
+endif()
 
 message(STATUS "CPU extension source files: ${VLLM_EXT_SRC}")
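Note: VLLM_BUILD_ACL is read as a plain CMake variable, so an AArch64 build that wants Arm Compute Library (ACL) backed oneDNN kernels would pass something like -DVLLM_BUILD_ACL=ON at configure time and export ACL_ROOT_DIR to a Compute Library checkout whose build/ directory contains libarm_compute; the find_library() call above searches only $ENV{ACL_ROOT_DIR}/build/. Without the flag, ASIMD targets still build the oneDNN W8A8 path, just without ACL acceleration.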
diff --git a/csrc/cpu/cpu_types_arm.hpp b/csrc/cpu/cpu_types_arm.hpp
index 65ffe524af738..2251aac45e6fa 100644
--- a/csrc/cpu/cpu_types_arm.hpp
+++ b/csrc/cpu/cpu_types_arm.hpp
@@ -33,6 +33,8 @@ namespace vec_op {
 #endif
 
 #define FORCE_INLINE __attribute__((always_inline)) inline
+// Number of elements in a single ASIMD vector of the given data type
+#define NUM_ELEMENTS_REG(vec) (sizeof(vec) / sizeof(vec[0]))
 
 namespace {
 template <typename T, T... indexes>
@@ -86,8 +88,8 @@ struct FP16Vec16 : public Vec<FP16Vec16> {
   }
 
   void save(void* ptr, const int elem_num) const {
-    int full_blocks = elem_num / 8;
-    int remainder = elem_num % 8;
+    int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
+    int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
 
     if (full_blocks > 0) {
       vst1q_f16(reinterpret_cast<__fp16*>(ptr), reg.val[0]);
@@ -197,6 +199,25 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
            vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[2]), v.val[3])}) {};
 
   void save(void* ptr) const { *reinterpret_cast<bfloat16x8x2_t*>(ptr) = reg; };
+  void save(void* ptr, const int elem_num) const {
+    int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
+    int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
+    for (int i = 0; i < full_blocks; i++)
+      vst1q_bf16(
+          reinterpret_cast<__bf16*>(ptr) + NUM_ELEMENTS_REG(reg.val[0]) * i,
+          reg.val[i]);
+    if (remainder > 0) {
+      bfloat16x8_t temp = reg.val[full_blocks];
+      bfloat16_t* base = reinterpret_cast<bfloat16_t*>(ptr) + full_blocks * 8;
+      base[0] = vgetq_lane_bf16(temp, 0);
+      if (remainder > 1) base[1] = vgetq_lane_bf16(temp, 1);
+      if (remainder > 2) base[2] = vgetq_lane_bf16(temp, 2);
+      if (remainder > 3) base[3] = vgetq_lane_bf16(temp, 3);
+      if (remainder > 4) base[4] = vgetq_lane_bf16(temp, 4);
+      if (remainder > 5) base[5] = vgetq_lane_bf16(temp, 5);
+      if (remainder > 6) base[6] = vgetq_lane_bf16(temp, 6);
+    }
+  };
 };
 
 struct BF16Vec32 : public Vec<BF16Vec32> {
@@ -213,6 +234,25 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
       : reg({vec8_data.reg, vec8_data.reg, vec8_data.reg, vec8_data.reg}) {};
 
   void save(void* ptr) const { *reinterpret_cast<bfloat16x8x4_t*>(ptr) = reg; };
+  void save(void* ptr, const int elem_num) const {
+    int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
+    int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
+    for (int i = 0; i < full_blocks; i++)
+      vst1q_bf16(
+          reinterpret_cast<__bf16*>(ptr) + NUM_ELEMENTS_REG(reg.val[0]) * i,
+          reg.val[i]);
+    if (remainder > 0) {
+      bfloat16x8_t temp = reg.val[full_blocks];
+      bfloat16_t* base = reinterpret_cast<bfloat16_t*>(ptr) + full_blocks * 8;
+      base[0] = vgetq_lane_bf16(temp, 0);
+      if (remainder > 1) base[1] = vgetq_lane_bf16(temp, 1);
+      if (remainder > 2) base[2] = vgetq_lane_bf16(temp, 2);
+      if (remainder > 3) base[3] = vgetq_lane_bf16(temp, 3);
+      if (remainder > 4) base[4] = vgetq_lane_bf16(temp, 4);
+      if (remainder > 5) base[5] = vgetq_lane_bf16(temp, 5);
+      if (remainder > 6) base[6] = vgetq_lane_bf16(temp, 6);
+    }
+  };
 };
 
 #endif
@@ -372,6 +412,48 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
   }
 };
 
+struct INT32Vec16 : public Vec<INT32Vec16> {
+  constexpr static int VEC_ELEM_NUM = 16;
+  union AliasReg {
+    int32x4x4_t reg;
+    int32_t values[VEC_ELEM_NUM];
+  };
+  int32x4x4_t reg;
+
+  explicit INT32Vec16(const void* ptr) {
+    reg.val[0] = vld1q_s32(reinterpret_cast<const int32_t*>(ptr));
+    reg.val[1] = vld1q_s32(reinterpret_cast<const int32_t*>(ptr) + 4);
+    reg.val[2] = vld1q_s32(reinterpret_cast<const int32_t*>(ptr) + 8);
+    reg.val[3] = vld1q_s32(reinterpret_cast<const int32_t*>(ptr) + 12);
+  }
+
+  void save(int32_t* ptr) const {
+    vst1q_s32(ptr, reg.val[0]);
+    vst1q_s32(ptr + 4, reg.val[1]);
+    vst1q_s32(ptr + 8, reg.val[2]);
+    vst1q_s32(ptr + 12, reg.val[3]);
+  };
+
+  void save(int32_t* ptr, const int elem_num) const {
+    int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
+    int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
+
+    for (int i = 0; i < full_blocks; i++)
+      vst1q_s32(
+          reinterpret_cast<__int32_t*>(ptr) + NUM_ELEMENTS_REG(reg.val[0]) * i,
+          reg.val[i]);
+
+    if (remainder > 0) {
+      int32x4_t temp = reg.val[full_blocks];
+      int32_t* base = reinterpret_cast<int32_t*>(ptr) + full_blocks * 4;
+      base[0] = vgetq_lane_s32(temp, 0);
+      if (remainder > 1) base[1] = vgetq_lane_s32(temp, 1);
+      if (remainder > 2) base[2] = vgetq_lane_s32(temp, 2);
+      if (remainder > 3) base[3] = vgetq_lane_s32(temp, 3);
+    }
+  }
+};
+
 struct FP32Vec16 : public Vec<FP32Vec16> {
   constexpr static int VEC_ELEM_NUM = 16;
   union AliasReg {
@@ -434,7 +516,12 @@
     reg.val[2] = vcvt_f32_f16(vget_low_f16(v.reg.val[1]));
     reg.val[3] = vcvt_f32_f16(vget_high_f16(v.reg.val[1]));
   };
-
+  explicit FP32Vec16(const INT32Vec16& v) {
+    reg.val[0] = vcvtq_f32_s32(v.reg.val[0]);
+    reg.val[1] = vcvtq_f32_s32(v.reg.val[1]);
+    reg.val[2] = vcvtq_f32_s32(v.reg.val[2]);
+    reg.val[3] = vcvtq_f32_s32(v.reg.val[3]);
+  };
   FP32Vec16 operator+(const FP32Vec16& b) const {
     return FP32Vec16(float32x4x4_t({vaddq_f32(reg.val[0], b.reg.val[0]),
                                     vaddq_f32(reg.val[1], b.reg.val[1]),
@@ -463,6 +550,85 @@
                                     vdivq_f32(reg.val[3], b.reg.val[3])}));
   };
 
+  FP32Vec16 clamp(const FP32Vec16& min, const FP32Vec16& max) const {
+    return FP32Vec16(float32x4x4_t(
+        {vminq_f32(max.reg.val[0], vmaxq_f32(min.reg.val[0], reg.val[0])),
+         vminq_f32(max.reg.val[1], vmaxq_f32(min.reg.val[1], reg.val[1])),
+         vminq_f32(max.reg.val[2], vmaxq_f32(min.reg.val[2], reg.val[2])),
+         vminq_f32(max.reg.val[3], vmaxq_f32(min.reg.val[3], reg.val[3]))}));
+  };
+
+  FP32Vec16 max(const FP32Vec16& b) const {
+    return FP32Vec16(float32x4x4_t({vmaxq_f32(b.reg.val[0], reg.val[0]),
+                                    vmaxq_f32(b.reg.val[1], reg.val[1]),
+                                    vmaxq_f32(b.reg.val[2], reg.val[2]),
+                                    vmaxq_f32(b.reg.val[3], reg.val[3])}));
+  };
+
+  FP32Vec16 max(const FP32Vec16& b, const int elem_num) const {
+    int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
+    int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
+    float32x4x4_t temp;
+
+    for (int i = 0; i < full_blocks; i++)
+      temp.val[i] = vmaxq_f32(b.reg.val[i], reg.val[i]);
+
+    if (remainder > 0) {
+      float max_v = std::max(vgetq_lane_f32(reg.val[full_blocks], 0),
+                             vgetq_lane_f32(b.reg.val[full_blocks], 0));
+      temp.val[full_blocks] = vsetq_lane_f32(max_v, temp.val[full_blocks], 0);
+    }
+    if (remainder > 1) {
+      float max_v = std::max(vgetq_lane_f32(reg.val[full_blocks], 1),
+                             vgetq_lane_f32(b.reg.val[full_blocks], 1));
+      temp.val[full_blocks] = vsetq_lane_f32(max_v, temp.val[full_blocks], 1);
+    }
+    if (remainder > 2) {
+      float max_v = std::max(vgetq_lane_f32(reg.val[full_blocks], 2),
+                             vgetq_lane_f32(b.reg.val[full_blocks], 2));
+      temp.val[full_blocks] = vsetq_lane_f32(max_v, temp.val[full_blocks], 2);
+    }
+    return FP32Vec16(temp);
+  };
+
+  FP32Vec16 min(const FP32Vec16& b) const {
+    return FP32Vec16(float32x4x4_t({
+        vminq_f32(b.reg.val[0], reg.val[0]),
+        vminq_f32(b.reg.val[1], reg.val[1]),
+        vminq_f32(b.reg.val[2], reg.val[2]),
+        vminq_f32(b.reg.val[3], reg.val[3]),
+    }));
+  };
+  FP32Vec16 min(const FP32Vec16& b, const int elem_num) const {
+    int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
+    const int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
+    float32x4x4_t temp;
+    for (int i = 0; i < full_blocks; i++)
+      temp.val[i] = vminq_f32(b.reg.val[i], reg.val[i]);
+
+    if (remainder > 0) {
+      float min_v = std::min(vgetq_lane_f32(reg.val[full_blocks], 0),
+                             vgetq_lane_f32(b.reg.val[full_blocks], 0));
+      temp.val[full_blocks] = vsetq_lane_f32(min_v, temp.val[full_blocks], 0);
+    }
+    if (remainder > 1) {
+      float min_v = std::min(vgetq_lane_f32(reg.val[full_blocks], 1),
+                             vgetq_lane_f32(b.reg.val[full_blocks], 1));
+      temp.val[full_blocks] = vsetq_lane_f32(min_v, temp.val[full_blocks], 1);
+    }
+    if (remainder > 2) {
+      float min_v = std::min(vgetq_lane_f32(reg.val[full_blocks], 2),
+                             vgetq_lane_f32(b.reg.val[full_blocks], 2));
+      temp.val[full_blocks] = vsetq_lane_f32(min_v, temp.val[full_blocks], 2);
+    }
+
+    return FP32Vec16(temp);
+  };
+  FP32Vec16 abs() const {
+    return FP32Vec16(
+        float32x4x4_t({vabsq_f32(reg.val[0]), vabsq_f32(reg.val[1]),
+                       vabsq_f32(reg.val[2]), vabsq_f32(reg.val[3])}));
+  }
   float reduce_sum() const {
     AliasReg ar;
     ar.reg = reg;
@@ -473,6 +639,24 @@
     return answer;
   };
 
+  float reduce_max() const {
+    AliasReg ar;
+    ar.reg = reg;
+    float max_v = std::numeric_limits<float>::lowest();
+    unroll_loop<int, VEC_ELEM_NUM>(
+        [&max_v, &ar](int i) { max_v = std::max(max_v, ar.values[i]); });
+    return max_v;
+  }
+
+  float reduce_min() const {
+    AliasReg ar;
+    ar.reg = reg;
+    float min_v = std::numeric_limits<float>::max();
+    unroll_loop<int, VEC_ELEM_NUM>(
+        [&min_v, &ar](int i) { min_v = std::min(min_v, ar.values[i]); });
+    return min_v;
+  }
+
   template <int group_size>
   float reduce_sub_sum(int idx) {
     static_assert(VEC_ELEM_NUM % group_size == 0);
@@ -493,6 +677,83 @@
     vst1q_f32(ptr + 8, reg.val[2]);
     vst1q_f32(ptr + 12, reg.val[3]);
   };
+
+  void save(float* ptr, const int elem_num) const {
+    int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
+    int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
+
+    for (int i = 0; i < full_blocks; i++)
+      vst1q_f32(
+          reinterpret_cast<float*>(ptr) + NUM_ELEMENTS_REG(reg.val[0]) * i,
+          reg.val[i]);
+
+    if (remainder > 0) {
+      float32x4_t temp = reg.val[full_blocks];
+      float* base = reinterpret_cast<float*>(ptr) +
+                    full_blocks * NUM_ELEMENTS_REG(reg.val[0]);
+      base[0] = vgetq_lane_f32(temp, 0);
+      if (remainder > 1) base[1] = vgetq_lane_f32(temp, 1);
+      if (remainder > 2) base[2] = vgetq_lane_f32(temp, 2);
+    }
+  }
+};
+
+struct INT8Vec16 : public Vec<INT8Vec16> {
+  constexpr static int VEC_ELEM_NUM = 16;
+  union AliasReg {
+    int8x16_t reg;
+    int8_t values[VEC_ELEM_NUM];
+  };
+  int8x16_t reg;
+
+  explicit INT8Vec16(const FP32Vec16& vec) {
+    // Convert each 128-bit float32 vector to int32
+    int32x4_t part0 =
+        vcvtq_s32_f32(vec.reg.val[0]);  // Convert first 128-bit block
+    int32x4_t part1 =
+        vcvtq_s32_f32(vec.reg.val[1]);  // Convert second 128-bit block
+    int32x4_t part2 =
+        vcvtq_s32_f32(vec.reg.val[2]);  // Convert third 128-bit block
+    int32x4_t part3 =
+        vcvtq_s32_f32(vec.reg.val[3]);  // Convert fourth 128-bit block
+
+    // Narrow each 32-bit vector to 8 bits and combine
+    int8x8_t lower =
+        vqmovn_s16(vcombine_s16(vqmovn_s32(part0), vqmovn_s32(part1)));
+    int8x8_t upper =
+        vqmovn_s16(vcombine_s16(vqmovn_s32(part2), vqmovn_s32(part3)));
+    reg = vcombine_s8(lower, upper);  // Combine to form a single 128-bit vector
+  }
+
+  void save(int8_t* ptr) const { vst1q_s8(ptr, reg); };
+
+  void save(int8_t* ptr, const int elem_num) const {
+    int full_blocks = elem_num / NUM_ELEMENTS_REG(reg);
+    int remainder = elem_num % NUM_ELEMENTS_REG(reg);
+
+    for (int i = 0; i < full_blocks; i++)
+      vst1q_s8(reinterpret_cast<int8_t*>(ptr) + NUM_ELEMENTS_REG(reg) * i, reg);
+    if (remainder > 0) {
+      int8x16_t temp = reg;
+      int8_t* base =
+          reinterpret_cast<int8_t*>(ptr) + full_blocks * NUM_ELEMENTS_REG(reg);
+      base[0] = vgetq_lane_s8(temp, 0);
+      if (remainder > 1) base[1] = vgetq_lane_s8(temp, 1);
+      if (remainder > 2) base[2] = vgetq_lane_s8(temp, 2);
+      if (remainder > 3) base[3] = vgetq_lane_s8(temp, 3);
+      if (remainder > 4) base[4] = vgetq_lane_s8(temp, 4);
+      if (remainder > 5) base[5] = vgetq_lane_s8(temp, 5);
+      if (remainder > 6) base[6] = vgetq_lane_s8(temp, 6);
+      if (remainder > 7) base[7] = vgetq_lane_s8(temp, 7);
+      if (remainder > 8) base[8] = vgetq_lane_s8(temp, 8);
+      if (remainder > 9) base[9] = vgetq_lane_s8(temp, 9);
+      if (remainder > 10) base[10] = vgetq_lane_s8(temp, 10);
+      if (remainder > 11) base[11] = vgetq_lane_s8(temp, 11);
+      if (remainder > 12) base[12] = vgetq_lane_s8(temp, 12);
+      if (remainder > 13) base[13] = vgetq_lane_s8(temp, 13);
+      if (remainder > 14) base[14] = vgetq_lane_s8(temp, 14);
+    }
+  };
 };
 
 template <typename T>
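A note on the pattern used throughout the elem_num overloads above: each save() issues full 128-bit vst1q stores for every complete lane block, then falls back to per-lane vgetq_lane stores for the tail, because ASIMD (unlike SVE or AVX-512) has no masked store. The other subtlety is INT8Vec16's float-to-int8 conversion: vqmovn_s32/vqmovn_s16 narrow with saturation rather than wraparound. A self-contained sketch of that narrowing chain (illustrative only, not part of the patch; needs any AArch64 compiler with arm_neon.h):

    #include <arm_neon.h>
    #include <cstdio>

    int main() {
      float in[4] = {1.5f, -2.5f, 300.0f, -300.0f};
      int32x4_t i32 = vcvtq_s32_f32(vld1q_f32(in));      // truncates toward zero
      int16x4_t i16 = vqmovn_s32(i32);                   // saturating narrow to int16
      int8x8_t i8 = vqmovn_s16(vcombine_s16(i16, i16));  // saturating narrow to int8
      int8_t out[8];
      vst1_s8(out, i8);
      printf("%d %d %d %d\n", out[0], out[1], out[2], out[3]);  // 1 -2 127 -128
      return 0;
    }

Saturation keeps out-of-range intermediates pinned at 127/-128 instead of wrapping to the opposite sign, which is the behavior int8 quantization requires.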
diff --git a/csrc/cpu/dnnl_helper.hpp b/csrc/cpu/dnnl_helper.hpp
index 8b5011dc065f0..1cb8dc5b25a66 100644
--- a/csrc/cpu/dnnl_helper.hpp
+++ b/csrc/cpu/dnnl_helper.hpp
@@ -57,6 +57,7 @@ class DNNLPrimitiveHelper {
   // Note: Due to the limitation of oneDNN
   // (https://github.com/oneapi-src/oneDNN/issues/1636), the quantized bias is
   // not supported.
+
   template <typename OutputT, typename BiasT>
   static void gemm_s8s8_jit(const int8_t* a, const int8_t* b, OutputT* c,
                             const BiasT* bias, dnnl_dim_t M, dnnl_dim_t N,
@@ -90,6 +91,27 @@
     }
 
     dnnl::matmul::primitive_desc matmul_pd;
+// Create memory descriptors with format_tag::any for the primitive. This
+// enables the matmul primitive to choose memory layouts for an
+// optimized primitive implementation, and these layouts may differ from the
+// ones provided by the user.
+#ifdef __aarch64__
+    auto mat_src_md = dnnl::memory::desc({M, K}, dnnl::memory::data_type::s8,
+                                         dnnl::memory::format_tag::any);
+    auto mat_weights_md = dnnl::memory::desc(
+        {K, N}, dnnl::memory::data_type::s8, dnnl::memory::format_tag::any);
+    auto mat_dst_md =
+        dnnl::memory::desc({M, N}, OutputType, dnnl::memory::format_tag::any);
+    if (bias) {
+      dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1});
+      matmul_pd = dnnl::matmul::primitive_desc(default_engine(), mat_src_md,
+                                               mat_weights_md, bias_md,
+                                               mat_dst_md, attr);
+    } else {
+      matmul_pd = dnnl::matmul::primitive_desc(
+          default_engine(), mat_src_md, mat_weights_md, mat_dst_md, attr);
+    }
+#else
     if (bias) {
       dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1});
       matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md,
@@ -98,6 +120,7 @@
       matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md,
                                                c_md, attr);
     }
+#endif
 
     dnnl::matmul matmul(matmul_pd);
     auto& engine = default_engine();
@@ -111,24 +134,34 @@
                          (void*)b_scales);
 
     auto& stream = default_stream();
+
+    auto mat_src_mem = a_m;
+    auto mat_weights_mem = b_m;
+    auto mat_dst_mem = c_m;
+#ifdef __aarch64__
+    if (matmul_pd.weights_desc() != b_m.get_desc()) {
+      mat_weights_mem = dnnl::memory(matmul_pd.weights_desc(), engine);
+      dnnl::reorder(b_m, mat_weights_mem).execute(stream, b_m, mat_weights_mem);
+    }
+#endif
     if constexpr (InputNoScale) {
       if (bias) {
        dnnl::memory::desc bias_md({N}, BiasType, {1});
        dnnl::memory bias_m(bias_md, engine, (void*)bias);
        matmul.execute(
            stream, {
-                       {DNNL_ARG_SRC, a_m},
-                       {DNNL_ARG_WEIGHTS, b_m},
+                       {DNNL_ARG_SRC, mat_src_mem},
+                       {DNNL_ARG_WEIGHTS, mat_weights_mem},
                        {DNNL_ARG_BIAS, bias_m},
-                       {DNNL_ARG_DST, c_m},
+                       {DNNL_ARG_DST, mat_dst_mem},
                        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
                    });
      } else {
        matmul.execute(
            stream, {
-                       {DNNL_ARG_SRC, a_m},
-                       {DNNL_ARG_WEIGHTS, b_m},
-                       {DNNL_ARG_DST, c_m},
+                       {DNNL_ARG_SRC, mat_src_mem},
+                       {DNNL_ARG_WEIGHTS, mat_weights_mem},
+                       {DNNL_ARG_DST, mat_dst_mem},
                        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
                    });
      }
@@ -138,19 +171,19 @@
       dnnl::memory bias_m(bias_md, engine, (void*)bias);
       matmul.execute(
           stream, {
-                      {DNNL_ARG_SRC, a_m},
-                      {DNNL_ARG_WEIGHTS, b_m},
+                      {DNNL_ARG_SRC, mat_src_mem},
+                      {DNNL_ARG_WEIGHTS, mat_weights_mem},
                      {DNNL_ARG_BIAS, bias_m},
-                      {DNNL_ARG_DST, c_m},
+                      {DNNL_ARG_DST, mat_dst_mem},
                      {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m},
                      {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
                  });
    } else {
      matmul.execute(
          stream, {
-                      {DNNL_ARG_SRC, a_m},
-                      {DNNL_ARG_WEIGHTS, b_m},
-                      {DNNL_ARG_DST, c_m},
+                      {DNNL_ARG_SRC, mat_src_mem},
+                      {DNNL_ARG_WEIGHTS, mat_weights_mem},
+                      {DNNL_ARG_DST, mat_dst_mem},
                      {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m},
                      {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
                  });
@@ -170,5 +203,4 @@
     return stream;
   }
 };
-
 #endif
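The AArch64 branch above follows oneDNN's standard flow for primitives with reusable weights: describe tensors with format_tag::any so the matmul primitive can pick an optimized (possibly ACL-backed) layout, then compare the chosen weights descriptor with the user's descriptor and reorder once when they differ. Reduced to its core, the idiom looks like this (a sketch only, using the same public oneDNN C++ API this header already uses):

    // Return user_mem unchanged when its layout already matches what the
    // primitive chose; otherwise allocate a new buffer and reorder into it.
    dnnl::memory make_primitive_compatible(const dnnl::memory& user_mem,
                                           const dnnl::memory::desc& wanted_md,
                                           dnnl::engine& eng,
                                           dnnl::stream& strm) {
      if (wanted_md == user_mem.get_desc()) return user_mem;
      dnnl::memory opt_mem(wanted_md, eng);
      dnnl::reorder(user_mem, opt_mem).execute(strm, user_mem, opt_mem);
      return opt_mem;
    }

As written in the patch, the reorder runs inside every gemm_s8s8_jit call; since the weights are static, caching the reordered dnnl::memory per layer would amortize that cost.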
diff --git a/csrc/cpu/quant.cpp b/csrc/cpu/quant.cpp
index f61dbcc948e83..c1f7c64ea2f49 100644
--- a/csrc/cpu/quant.cpp
+++ b/csrc/cpu/quant.cpp
@@ -36,7 +36,7 @@ struct KernelVecType {
   using cvt_vec_type = vec_op::FP32Vec16;
 };
 
-#ifdef __AVX512F__
+#if defined(__AVX512F__) || defined(__aarch64__)
 template <typename scalar_t>
 void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
                                    const float* scale, const int32_t* azp,
@@ -598,8 +598,9 @@ void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
                                    const float* scale, const int32_t* azp,
                                    const int num_tokens,
                                    const int hidden_size) {
-  TORCH_CHECK(
-      false, "static_scaled_int8_quant_impl requires AVX512/powerpc64 support.")
+  TORCH_CHECK(false,
+              "static_scaled_int8_quant_impl requires AVX512/powerpc64/AArch64 "
+              "support.")
 }
 
 template <typename scalar_t>
@@ -607,9 +608,9 @@ void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
                                     float* scale, int32_t* azp,
                                     const int num_tokens,
                                     const int hidden_size) {
-  TORCH_CHECK(
-      false,
-      "dynamic_scaled_int8_quant_impl requires AVX512/powerpc64 support.")
+  TORCH_CHECK(false,
+              "dynamic_scaled_int8_quant_impl requires "
+              "AVX512/powerpc64/AArch64 support.")
 }
 
 template <typename scalar_t>
@@ -617,7 +618,8 @@ void static_quant_epilogue(const float* input, scalar_t* output,
                            const float a_scale, const float* b_scale,
                            const int32_t* azp_with_adj, const int num_tokens,
                            const int hidden_size) {
-  TORCH_CHECK(false, "static_quant_epilogue requires AVX512/powerpc64 support.")
+  TORCH_CHECK(
+      false, "static_quant_epilogue requires AVX512/powerpc64/AArch64 support.")
 }
 
 template <typename scalar_t>
@@ -626,8 +628,9 @@ void dynamic_quant_epilogue(const float* input, scalar_t* output,
                             const int32_t* azp, const int32_t* azp_with_adj,
                             const scalar_t* bias, const int num_tokens,
                             const int hidden_size) {
-  TORCH_CHECK(false,
-              "dynamic_quant_epilogue requires AVX512/powerpc64 support.")
+  TORCH_CHECK(
+      false,
+      "dynamic_quant_epilogue requires AVX512/powerpc64/AArch64 support.")
 }
 #endif
 }  // namespace

diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp
index ebfc81f858367..f1738aee980b6 100644
--- a/csrc/cpu/torch_bindings.cpp
+++ b/csrc/cpu/torch_bindings.cpp
@@ -151,8 +151,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding);
 
   // Quantization
-#ifdef __AVX512F__
+#if defined(__AVX512F__) || defined(__aarch64__)
   at::Tag stride_tag = at::Tag::needs_fixed_stride_order;
+
   // Compute int8 quantized tensor for given scaling factor.
   ops.def(
       "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale,"
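For reference, the element-wise math behind the op registered above: static scaled int8 quantization divides each value by a fixed scale and saturates the result into [-128, 127]; on this NEON path the float-to-int conversion (vcvtq_s32_f32 inside INT8Vec16) truncates toward zero. A scalar sketch of the azp-free case (illustrative only, not the vectorized kernel):

    #include <algorithm>
    #include <cstdint>

    inline int8_t quantize_one(float x, float scale) {
      float v = x / scale;                         // apply inverse scale
      v = std::min(127.0f, std::max(-128.0f, v));  // saturate to int8 range
      return static_cast<int8_t>(v);               // truncate toward zero
    }

The dynamic variant derives the scale per token first, typically as max|x| / 127, which is what the new abs(), reduce_max(), and reduce_min() helpers in cpu_types_arm.hpp support.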