mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-27 01:47:52 +08:00
[Hardware][CPU] Vllm int8 quantization enablement for ARM CPU (#14129)
Signed-off-by: nishith-fujitsu <nishith.jaiswal@fujitsu.com>
This commit is contained in:
parent
4b9a9435bb
commit
c7753a9809
@ -165,17 +165,32 @@ else()
|
|||||||
endif()
|
endif()
|
||||||
|
|
||||||
#
|
#
|
||||||
# Build oneDNN for W8A8 GEMM kernels (only for x86-AVX512 platforms)
|
# Build oneDNN for W8A8 GEMM kernels (only for x86-AVX512 /ARM platforms)
|
||||||
#
|
# Flag to enable ACL kernels for AARCH64 platforms
|
||||||
if (AVX512_FOUND AND NOT AVX512_DISABLED)
|
# USE_ACL mirrors the VLLM_BUILD_ACL switch: it is ON only when that switch
# compares equal to the string "ON", and OFF for every other value.
if(NOT (VLLM_BUILD_ACL STREQUAL "ON"))
  set(USE_ACL OFF)
else()
  set(USE_ACL ON)
endif()
|
||||||
|
|
||||||
|
if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND)
|
||||||
FetchContent_Declare(
|
FetchContent_Declare(
|
||||||
oneDNN
|
oneDNN
|
||||||
GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git
|
GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git
|
||||||
GIT_TAG v3.7.1
|
GIT_TAG v3.8.1
|
||||||
GIT_PROGRESS TRUE
|
GIT_PROGRESS TRUE
|
||||||
GIT_SHALLOW TRUE
|
GIT_SHALLOW TRUE
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if(USE_ACL)
  # Locate the arm_compute library; $ENV{ACL_ROOT_DIR}/build/ is searched in
  # addition to the default locations.
  find_library(
    ARM_COMPUTE_LIBRARY
    NAMES arm_compute
    PATHS $ENV{ACL_ROOT_DIR}/build/)
  if(NOT ARM_COMPUTE_LIBRARY)
    message(FATAL_ERROR "Could not find ARM Compute Library: please set ACL_ROOT_DIR")
  endif()

  # oneDNN build option; presumably switches its AArch64 kernels to ACL.
  set(ONEDNN_AARCH64_USE_ACL "ON")

  # Add the ACL build directory as a runtime library search path (rpath).
  string(APPEND CMAKE_CXX_FLAGS " -Wl,-rpath,$ENV{ACL_ROOT_DIR}/build/")
endif()
|
||||||
|
|
||||||
set(ONEDNN_LIBRARY_TYPE "STATIC")
|
set(ONEDNN_LIBRARY_TYPE "STATIC")
|
||||||
set(ONEDNN_BUILD_DOC "OFF")
|
set(ONEDNN_BUILD_DOC "OFF")
|
||||||
set(ONEDNN_BUILD_EXAMPLES "OFF")
|
set(ONEDNN_BUILD_EXAMPLES "OFF")
|
||||||
@ -264,6 +279,11 @@ elseif(POWER10_FOUND)
|
|||||||
"csrc/cpu/quant.cpp"
|
"csrc/cpu/quant.cpp"
|
||||||
${VLLM_EXT_SRC})
|
${VLLM_EXT_SRC})
|
||||||
endif()
|
endif()
|
||||||
|
# When ASIMD is detected, put the quantization kernels at the front of the
# extension source list.
if(ASIMD_FOUND)
  set(VLLM_EXT_SRC "csrc/cpu/quant.cpp" ${VLLM_EXT_SRC})
endif()
|
||||||
|
|
||||||
message(STATUS "CPU extension source files: ${VLLM_EXT_SRC}")
|
message(STATUS "CPU extension source files: ${VLLM_EXT_SRC}")
|
||||||
|
|
||||||
|
|||||||
@ -33,6 +33,8 @@ namespace vec_op {
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
#define FORCE_INLINE __attribute__((always_inline)) inline
|
#define FORCE_INLINE __attribute__((always_inline)) inline
|
||||||
|
// Number of elements (lanes) in a single ASIMD vector of the given datatype:
// the size of the whole vector divided by the size of one element.
// The argument is parenthesized before subscripting so that any expression
// (not just a plain identifier) expands correctly.
#define NUM_ELEMENTS_REG(vec) (sizeof(vec) / sizeof((vec)[0]))
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
template <typename T, T... indexes, typename F>
|
template <typename T, T... indexes, typename F>
|
||||||
@ -86,8 +88,8 @@ struct FP16Vec16 : public Vec<FP16Vec16> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void save(void* ptr, const int elem_num) const {
|
void save(void* ptr, const int elem_num) const {
|
||||||
int full_blocks = elem_num / 8;
|
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
|
||||||
int remainder = elem_num % 8;
|
int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
|
||||||
|
|
||||||
if (full_blocks > 0) {
|
if (full_blocks > 0) {
|
||||||
vst1q_f16(reinterpret_cast<__fp16*>(ptr), reg.val[0]);
|
vst1q_f16(reinterpret_cast<__fp16*>(ptr), reg.val[0]);
|
||||||
@ -197,6 +199,25 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
|
|||||||
vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[2]), v.val[3])}) {};
|
vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[2]), v.val[3])}) {};
|
||||||
|
|
||||||
void save(void* ptr) const { *reinterpret_cast<bfloat16x8x2_t*>(ptr) = reg; };
|
void save(void* ptr) const { *reinterpret_cast<bfloat16x8x2_t*>(ptr) = reg; };
|
||||||
|
void save(void* ptr, const int elem_num) const {
|
||||||
|
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
|
||||||
|
int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
|
||||||
|
for (int i = 0; i < full_blocks; i++)
|
||||||
|
vst1q_bf16(
|
||||||
|
reinterpret_cast<__bf16*>(ptr) + NUM_ELEMENTS_REG(reg.val[0]) * i,
|
||||||
|
reg.val[i]);
|
||||||
|
if (remainder > 0) {
|
||||||
|
bfloat16x8_t temp = reg.val[full_blocks];
|
||||||
|
bfloat16_t* base = reinterpret_cast<bfloat16_t*>(ptr) + full_blocks * 8;
|
||||||
|
if (remainder > 0) base[0] = vgetq_lane_bf16(temp, 0);
|
||||||
|
if (remainder > 1) base[1] = vgetq_lane_bf16(temp, 1);
|
||||||
|
if (remainder > 2) base[2] = vgetq_lane_bf16(temp, 2);
|
||||||
|
if (remainder > 3) base[3] = vgetq_lane_bf16(temp, 3);
|
||||||
|
if (remainder > 4) base[4] = vgetq_lane_bf16(temp, 4);
|
||||||
|
if (remainder > 5) base[5] = vgetq_lane_bf16(temp, 5);
|
||||||
|
if (remainder > 6) base[6] = vgetq_lane_bf16(temp, 6);
|
||||||
|
}
|
||||||
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
struct BF16Vec32 : public Vec<BF16Vec32> {
|
struct BF16Vec32 : public Vec<BF16Vec32> {
|
||||||
@ -213,6 +234,25 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
|
|||||||
: reg({vec8_data.reg, vec8_data.reg, vec8_data.reg, vec8_data.reg}) {};
|
: reg({vec8_data.reg, vec8_data.reg, vec8_data.reg, vec8_data.reg}) {};
|
||||||
|
|
||||||
void save(void* ptr) const { *reinterpret_cast<bfloat16x8x4_t*>(ptr) = reg; };
|
void save(void* ptr) const { *reinterpret_cast<bfloat16x8x4_t*>(ptr) = reg; };
|
||||||
|
void save(void* ptr, const int elem_num) const {
|
||||||
|
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
|
||||||
|
int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
|
||||||
|
for (int i = 0; i < full_blocks; i++)
|
||||||
|
vst1q_bf16(
|
||||||
|
reinterpret_cast<__bf16*>(ptr) + NUM_ELEMENTS_REG(reg.val[0]) * i,
|
||||||
|
reg.val[i]);
|
||||||
|
if (remainder > 0) {
|
||||||
|
bfloat16x8_t temp = reg.val[full_blocks];
|
||||||
|
bfloat16_t* base = reinterpret_cast<bfloat16_t*>(ptr) + full_blocks * 8;
|
||||||
|
base[0] = vgetq_lane_bf16(temp, 0);
|
||||||
|
if (remainder > 1) base[1] = vgetq_lane_bf16(temp, 1);
|
||||||
|
if (remainder > 2) base[2] = vgetq_lane_bf16(temp, 2);
|
||||||
|
if (remainder > 3) base[3] = vgetq_lane_bf16(temp, 3);
|
||||||
|
if (remainder > 4) base[4] = vgetq_lane_bf16(temp, 4);
|
||||||
|
if (remainder > 5) base[5] = vgetq_lane_bf16(temp, 5);
|
||||||
|
if (remainder > 6) base[6] = vgetq_lane_bf16(temp, 6);
|
||||||
|
}
|
||||||
|
};
|
||||||
};
|
};
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
@ -372,6 +412,48 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Vector of 16 int32 elements, held as four 4-lane NEON registers
// (int32x4x4_t).
struct INT32Vec16 : public Vec<INT32Vec16> {
  constexpr static int VEC_ELEM_NUM = 16;
  // Lets the register group be viewed as a plain array of its 16 elements.
  union AliasReg {
    int32x4x4_t reg;
    int32_t values[VEC_ELEM_NUM];
  };
  int32x4x4_t reg;

  // Loads 16 consecutive int32 values starting at `ptr`.
  explicit INT32Vec16(const void* ptr) {
    const int32_t* src = reinterpret_cast<const int32_t*>(ptr);
    reg.val[0] = vld1q_s32(src);
    reg.val[1] = vld1q_s32(src + 4);
    reg.val[2] = vld1q_s32(src + 8);
    reg.val[3] = vld1q_s32(src + 12);
  }

  // Stores all 16 elements to `ptr`.
  void save(int32_t* ptr) const {
    vst1q_s32(ptr, reg.val[0]);
    vst1q_s32(ptr + 4, reg.val[1]);
    vst1q_s32(ptr + 8, reg.val[2]);
    vst1q_s32(ptr + 12, reg.val[3]);
  }

  // Stores only the first `elem_num` elements to `ptr`. Complete registers
  // are written with vector stores; the lanes of a trailing partial register
  // are written individually so that memory past `elem_num` elements is left
  // untouched.
  void save(int32_t* ptr, const int elem_num) const {
    const auto lanes = NUM_ELEMENTS_REG(reg.val[0]);
    const int full_blocks = elem_num / lanes;
    const int remainder = elem_num % lanes;

    for (int i = 0; i < full_blocks; i++) {
      vst1q_s32(ptr + lanes * i, reg.val[i]);
    }

    if (remainder > 0) {
      const int32x4_t temp = reg.val[full_blocks];
      int32_t* base = ptr + lanes * full_blocks;
      // remainder is in [1, 3] here: a register has 4 lanes, so a fourth
      // tail lane can never occur.
      base[0] = vgetq_lane_s32(temp, 0);
      if (remainder > 1) base[1] = vgetq_lane_s32(temp, 1);
      if (remainder > 2) base[2] = vgetq_lane_s32(temp, 2);
    }
  }
};
|
||||||
|
|
||||||
struct FP32Vec16 : public Vec<FP32Vec16> {
|
struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||||
constexpr static int VEC_ELEM_NUM = 16;
|
constexpr static int VEC_ELEM_NUM = 16;
|
||||||
union AliasReg {
|
union AliasReg {
|
||||||
@ -434,7 +516,12 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
|||||||
reg.val[2] = vcvt_f32_f16(vget_low_f16(v.reg.val[1]));
|
reg.val[2] = vcvt_f32_f16(vget_low_f16(v.reg.val[1]));
|
||||||
reg.val[3] = vcvt_f32_f16(vget_high_f16(v.reg.val[1]));
|
reg.val[3] = vcvt_f32_f16(vget_high_f16(v.reg.val[1]));
|
||||||
};
|
};
|
||||||
|
// Builds a float vector by converting each of the four int32 registers of `v`
// to float32, lane for lane.
explicit FP32Vec16(const INT32Vec16& v) {
  for (int blk = 0; blk < 4; ++blk) {
    reg.val[blk] = vcvtq_f32_s32(v.reg.val[blk]);
  }
}
|
||||||
FP32Vec16 operator+(const FP32Vec16& b) const {
|
FP32Vec16 operator+(const FP32Vec16& b) const {
|
||||||
return FP32Vec16(float32x4x4_t({vaddq_f32(reg.val[0], b.reg.val[0]),
|
return FP32Vec16(float32x4x4_t({vaddq_f32(reg.val[0], b.reg.val[0]),
|
||||||
vaddq_f32(reg.val[1], b.reg.val[1]),
|
vaddq_f32(reg.val[1], b.reg.val[1]),
|
||||||
@ -463,6 +550,85 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
|||||||
vdivq_f32(reg.val[3], b.reg.val[3])}));
|
vdivq_f32(reg.val[3], b.reg.val[3])}));
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Element-wise clamp: each lane is first raised to at least the matching lane
// of `min`, then lowered to at most the matching lane of `max`.
FP32Vec16 clamp(const FP32Vec16& min, const FP32Vec16& max) const {
  float32x4x4_t bounded;
  for (int blk = 0; blk < 4; ++blk) {
    const float32x4_t raised = vmaxq_f32(min.reg.val[blk], reg.val[blk]);
    bounded.val[blk] = vminq_f32(max.reg.val[blk], raised);
  }
  return FP32Vec16(bounded);
}
|
||||||
|
|
||||||
|
// Element-wise maximum of this vector and `b` over all 16 lanes.
FP32Vec16 max(const FP32Vec16& b) const {
  float32x4x4_t larger;
  for (int blk = 0; blk < 4; ++blk) {
    larger.val[blk] = vmaxq_f32(b.reg.val[blk], reg.val[blk]);
  }
  return FP32Vec16(larger);
}
|
||||||
|
|
||||||
|
// Element-wise maximum of this vector and `b`, restricted to the first
// `elem_num` lanes. Lanes at index >= `elem_num` carry this vector's own
// value through unchanged, so the result never contains indeterminate data
// (a later reduction over all 16 lanes stays well defined).
FP32Vec16 max(const FP32Vec16& b, const int elem_num) const {
  int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
  int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
  // Start from our own lanes: every lane not overwritten below keeps its
  // current value instead of being left uninitialized.
  float32x4x4_t temp = reg;

  for (int i = 0; i < full_blocks; i++)
    temp.val[i] = vmaxq_f32(b.reg.val[i], reg.val[i]);

  // Partial register: only the first `remainder` lanes (1..3) take part.
  // Lane indices must be compile-time constants, hence the explicit cascade.
  if (remainder > 0) {
    float max_v = std::max(vgetq_lane_f32(reg.val[full_blocks], 0),
                           vgetq_lane_f32(b.reg.val[full_blocks], 0));
    temp.val[full_blocks] = vsetq_lane_f32(max_v, temp.val[full_blocks], 0);
  }
  if (remainder > 1) {
    float max_v = std::max(vgetq_lane_f32(reg.val[full_blocks], 1),
                           vgetq_lane_f32(b.reg.val[full_blocks], 1));
    temp.val[full_blocks] = vsetq_lane_f32(max_v, temp.val[full_blocks], 1);
  }
  if (remainder > 2) {
    float max_v = std::max(vgetq_lane_f32(reg.val[full_blocks], 2),
                           vgetq_lane_f32(b.reg.val[full_blocks], 2));
    temp.val[full_blocks] = vsetq_lane_f32(max_v, temp.val[full_blocks], 2);
  }
  return FP32Vec16(temp);
}
|
||||||
|
|
||||||
|
// Element-wise minimum of this vector and `b` over all 16 lanes.
FP32Vec16 min(const FP32Vec16& b) const {
  float32x4x4_t smaller;
  for (int blk = 0; blk < 4; ++blk) {
    smaller.val[blk] = vminq_f32(b.reg.val[blk], reg.val[blk]);
  }
  return FP32Vec16(smaller);
}
|
||||||
|
// Element-wise minimum of this vector and `b`, restricted to the first
// `elem_num` lanes. Lanes at index >= `elem_num` carry this vector's own
// value through unchanged, so the result never contains indeterminate data
// (a later reduction over all 16 lanes stays well defined).
FP32Vec16 min(const FP32Vec16& b, const int elem_num) const {
  int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
  const int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
  // Start from our own lanes: every lane not overwritten below keeps its
  // current value instead of being left uninitialized.
  float32x4x4_t temp = reg;

  for (int i = 0; i < full_blocks; i++)
    temp.val[i] = vminq_f32(b.reg.val[i], reg.val[i]);

  // Partial register: only the first `remainder` lanes (1..3) take part.
  // Lane indices must be compile-time constants, hence the explicit cascade.
  if (remainder > 0) {
    float min_v = std::min(vgetq_lane_f32(reg.val[full_blocks], 0),
                           vgetq_lane_f32(b.reg.val[full_blocks], 0));
    temp.val[full_blocks] = vsetq_lane_f32(min_v, temp.val[full_blocks], 0);
  }
  if (remainder > 1) {
    float min_v = std::min(vgetq_lane_f32(reg.val[full_blocks], 1),
                           vgetq_lane_f32(b.reg.val[full_blocks], 1));
    temp.val[full_blocks] = vsetq_lane_f32(min_v, temp.val[full_blocks], 1);
  }
  if (remainder > 2) {
    float min_v = std::min(vgetq_lane_f32(reg.val[full_blocks], 2),
                           vgetq_lane_f32(b.reg.val[full_blocks], 2));
    temp.val[full_blocks] = vsetq_lane_f32(min_v, temp.val[full_blocks], 2);
  }

  return FP32Vec16(temp);
}
|
||||||
|
// Element-wise absolute value of all 16 lanes.
FP32Vec16 abs() const {
  float32x4x4_t magnitude;
  for (int blk = 0; blk < 4; ++blk) {
    magnitude.val[blk] = vabsq_f32(reg.val[blk]);
  }
  return FP32Vec16(magnitude);
}
|
||||||
float reduce_sum() const {
|
float reduce_sum() const {
|
||||||
AliasReg ar;
|
AliasReg ar;
|
||||||
ar.reg = reg;
|
ar.reg = reg;
|
||||||
@ -473,6 +639,24 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
|||||||
return answer;
|
return answer;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
float reduce_max() const {
|
||||||
|
AliasReg ar;
|
||||||
|
ar.reg = reg;
|
||||||
|
float max_v = std::numeric_limits<float>::lowest();
|
||||||
|
unroll_loop<int, VEC_ELEM_NUM>(
|
||||||
|
[&max_v, &ar](int i) { max_v = std::max(max_v, ar.values[i]); });
|
||||||
|
return max_v;
|
||||||
|
}
|
||||||
|
|
||||||
|
float reduce_min() const {
|
||||||
|
AliasReg ar;
|
||||||
|
ar.reg = reg;
|
||||||
|
float min_v = std::numeric_limits<float>::max();
|
||||||
|
unroll_loop<int, VEC_ELEM_NUM>(
|
||||||
|
[&min_v, &ar](int i) { min_v = std::min(min_v, ar.values[i]); });
|
||||||
|
return min_v;
|
||||||
|
}
|
||||||
|
|
||||||
template <int group_size>
|
template <int group_size>
|
||||||
float reduce_sub_sum(int idx) {
|
float reduce_sub_sum(int idx) {
|
||||||
static_assert(VEC_ELEM_NUM % group_size == 0);
|
static_assert(VEC_ELEM_NUM % group_size == 0);
|
||||||
@ -493,6 +677,83 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
|||||||
vst1q_f32(ptr + 8, reg.val[2]);
|
vst1q_f32(ptr + 8, reg.val[2]);
|
||||||
vst1q_f32(ptr + 12, reg.val[3]);
|
vst1q_f32(ptr + 12, reg.val[3]);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
void save(float* ptr, const int elem_num) const {
|
||||||
|
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
|
||||||
|
int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
|
||||||
|
|
||||||
|
for (int i = 0; i < full_blocks; i++)
|
||||||
|
vst1q_f32(
|
||||||
|
reinterpret_cast<float32_t*>(ptr) + NUM_ELEMENTS_REG(reg.val[0]) * i,
|
||||||
|
reg.val[i]);
|
||||||
|
|
||||||
|
if (remainder > 0) {
|
||||||
|
float32x4_t temp = reg.val[full_blocks];
|
||||||
|
float* base = reinterpret_cast<float32_t*>(ptr) +
|
||||||
|
full_blocks * NUM_ELEMENTS_REG(reg.val[0]);
|
||||||
|
if (remainder > 0) base[0] = vgetq_lane_f32(temp, 0);
|
||||||
|
if (remainder > 1) base[1] = vgetq_lane_f32(temp, 1);
|
||||||
|
if (remainder > 2) base[2] = vgetq_lane_f32(temp, 2);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct INT8Vec16 : public Vec<INT8Vec16> {
|
||||||
|
constexpr static int VEC_ELEM_NUM = 16;
|
||||||
|
union AliasReg {
|
||||||
|
int8x16_t reg;
|
||||||
|
int8_t values[VEC_ELEM_NUM];
|
||||||
|
};
|
||||||
|
int8x16_t reg;
|
||||||
|
|
||||||
|
explicit INT8Vec16(const FP32Vec16& vec) {
|
||||||
|
// Convert each 128-bit float32 vector to int32
|
||||||
|
int32x4_t part0 =
|
||||||
|
vcvtq_s32_f32(vec.reg.val[0]); // Convert first 128-bit block
|
||||||
|
int32x4_t part1 =
|
||||||
|
vcvtq_s32_f32(vec.reg.val[1]); // Convert second 128-bit block
|
||||||
|
int32x4_t part2 =
|
||||||
|
vcvtq_s32_f32(vec.reg.val[2]); // Convert third 128-bit block
|
||||||
|
int32x4_t part3 =
|
||||||
|
vcvtq_s32_f32(vec.reg.val[3]); // Convert fourth 128-bit block
|
||||||
|
|
||||||
|
// Narrow each 32-bit vector to 8 bits and combine
|
||||||
|
int8x8_t lower =
|
||||||
|
vqmovn_s16(vcombine_s16(vqmovn_s32(part0), vqmovn_s32(part1)));
|
||||||
|
int8x8_t upper =
|
||||||
|
vqmovn_s16(vcombine_s16(vqmovn_s32(part2), vqmovn_s32(part3)));
|
||||||
|
reg = vcombine_s8(lower, upper); // Combine to form a single 128-bit vector
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stores all 16 int8 lanes to `ptr` with a single vector store.
void save(int8_t* ptr) const {
  vst1q_s8(ptr, reg);
}
|
||||||
|
|
||||||
|
void save(int8_t* ptr, const int elem_num) const {
|
||||||
|
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg);
|
||||||
|
int remainder = elem_num % NUM_ELEMENTS_REG(reg);
|
||||||
|
|
||||||
|
for (int i = 0; i < full_blocks; i++)
|
||||||
|
vst1q_s8(reinterpret_cast<int8_t*>(ptr) + NUM_ELEMENTS_REG(reg) * i, reg);
|
||||||
|
if (remainder > 0) {
|
||||||
|
int8x16_t temp = reg;
|
||||||
|
int8_t* base =
|
||||||
|
reinterpret_cast<int8_t*>(ptr) + full_blocks * NUM_ELEMENTS_REG(reg);
|
||||||
|
if (remainder > 0) base[0] = vgetq_lane_s8(temp, 0);
|
||||||
|
if (remainder > 1) base[1] = vgetq_lane_s8(temp, 1);
|
||||||
|
if (remainder > 2) base[2] = vgetq_lane_s8(temp, 2);
|
||||||
|
if (remainder > 3) base[3] = vgetq_lane_s8(temp, 3);
|
||||||
|
if (remainder > 4) base[4] = vgetq_lane_s8(temp, 4);
|
||||||
|
if (remainder > 5) base[5] = vgetq_lane_s8(temp, 5);
|
||||||
|
if (remainder > 6) base[6] = vgetq_lane_s8(temp, 6);
|
||||||
|
if (remainder > 7) base[7] = vgetq_lane_s8(temp, 7);
|
||||||
|
if (remainder > 8) base[8] = vgetq_lane_s8(temp, 8);
|
||||||
|
if (remainder > 9) base[9] = vgetq_lane_s8(temp, 9);
|
||||||
|
if (remainder > 10) base[10] = vgetq_lane_s8(temp, 10);
|
||||||
|
if (remainder > 11) base[11] = vgetq_lane_s8(temp, 11);
|
||||||
|
if (remainder > 12) base[12] = vgetq_lane_s8(temp, 12);
|
||||||
|
if (remainder > 13) base[13] = vgetq_lane_s8(temp, 13);
|
||||||
|
if (remainder > 14) base[14] = vgetq_lane_s8(temp, 14);
|
||||||
|
}
|
||||||
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
|||||||
@ -57,6 +57,7 @@ class DNNLPrimitiveHelper {
|
|||||||
// Note: Due to the limitation of oneDNN
|
// Note: Due to the limitation of oneDNN
|
||||||
// (https://github.com/oneapi-src/oneDNN/issues/1636), the quantized bias is
|
// (https://github.com/oneapi-src/oneDNN/issues/1636), the quantized bias is
|
||||||
// not supported.
|
// not supported.
|
||||||
|
|
||||||
template <typename OutputT, typename BiasT>
|
template <typename OutputT, typename BiasT>
|
||||||
static void gemm_s8s8_jit(const int8_t* a, const int8_t* b, OutputT* c,
|
static void gemm_s8s8_jit(const int8_t* a, const int8_t* b, OutputT* c,
|
||||||
const BiasT* bias, dnnl_dim_t M, dnnl_dim_t N,
|
const BiasT* bias, dnnl_dim_t M, dnnl_dim_t N,
|
||||||
@ -90,6 +91,27 @@ class DNNLPrimitiveHelper {
|
|||||||
}
|
}
|
||||||
|
|
||||||
dnnl::matmul::primitive_desc matmul_pd;
|
dnnl::matmul::primitive_desc matmul_pd;
|
||||||
|
// Create memory descriptors with format_tag::any for the primitive. This
|
||||||
|
// enables the matmul primitive to choose memory layouts for an
|
||||||
|
// optimized primitive implementation, and these layouts may differ from the
|
||||||
|
// ones provided by the user.
|
||||||
|
#ifdef __aarch64__
|
||||||
|
auto mat_src_md = dnnl::memory::desc({M, K}, dnnl::memory::data_type::s8,
|
||||||
|
dnnl::memory::format_tag::any);
|
||||||
|
auto mat_weights_md = dnnl::memory::desc(
|
||||||
|
{K, N}, dnnl::memory::data_type::s8, dnnl::memory::format_tag::any);
|
||||||
|
auto mat_dst_md =
|
||||||
|
dnnl::memory::desc({M, N}, OutputType, dnnl::memory::format_tag::any);
|
||||||
|
if (bias) {
|
||||||
|
dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1});
|
||||||
|
matmul_pd = dnnl::matmul::primitive_desc(default_engine(), mat_src_md,
|
||||||
|
mat_weights_md, bias_md,
|
||||||
|
mat_dst_md, attr);
|
||||||
|
} else {
|
||||||
|
matmul_pd = dnnl::matmul::primitive_desc(
|
||||||
|
default_engine(), mat_src_md, mat_weights_md, mat_dst_md, attr);
|
||||||
|
}
|
||||||
|
#else
|
||||||
if (bias) {
|
if (bias) {
|
||||||
dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1});
|
dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1});
|
||||||
matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md,
|
matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md,
|
||||||
@ -98,6 +120,7 @@ class DNNLPrimitiveHelper {
|
|||||||
matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md,
|
matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md,
|
||||||
c_md, attr);
|
c_md, attr);
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
dnnl::matmul matmul(matmul_pd);
|
dnnl::matmul matmul(matmul_pd);
|
||||||
|
|
||||||
auto& engine = default_engine();
|
auto& engine = default_engine();
|
||||||
@ -111,24 +134,34 @@ class DNNLPrimitiveHelper {
|
|||||||
(void*)b_scales);
|
(void*)b_scales);
|
||||||
|
|
||||||
auto& stream = default_stream();
|
auto& stream = default_stream();
|
||||||
|
|
||||||
|
auto mat_src_mem = a_m;
|
||||||
|
auto mat_weights_mem = b_m;
|
||||||
|
auto mat_dst_mem = c_m;
|
||||||
|
#ifdef __aarch64__
|
||||||
|
if (matmul_pd.weights_desc() != b_m.get_desc()) {
|
||||||
|
mat_weights_mem = dnnl::memory(matmul_pd.weights_desc(), engine);
|
||||||
|
dnnl::reorder(b_m, mat_weights_mem).execute(stream, b_m, mat_weights_mem);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
if constexpr (InputNoScale) {
|
if constexpr (InputNoScale) {
|
||||||
if (bias) {
|
if (bias) {
|
||||||
dnnl::memory::desc bias_md({N}, BiasType, {1});
|
dnnl::memory::desc bias_md({N}, BiasType, {1});
|
||||||
dnnl::memory bias_m(bias_md, engine, (void*)bias);
|
dnnl::memory bias_m(bias_md, engine, (void*)bias);
|
||||||
matmul.execute(
|
matmul.execute(
|
||||||
stream, {
|
stream, {
|
||||||
{DNNL_ARG_SRC, a_m},
|
{DNNL_ARG_SRC, mat_src_mem},
|
||||||
{DNNL_ARG_WEIGHTS, b_m},
|
{DNNL_ARG_WEIGHTS, mat_weights_mem},
|
||||||
{DNNL_ARG_BIAS, bias_m},
|
{DNNL_ARG_BIAS, bias_m},
|
||||||
{DNNL_ARG_DST, c_m},
|
{DNNL_ARG_DST, mat_dst_mem},
|
||||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
|
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
matmul.execute(
|
matmul.execute(
|
||||||
stream, {
|
stream, {
|
||||||
{DNNL_ARG_SRC, a_m},
|
{DNNL_ARG_SRC, mat_src_mem},
|
||||||
{DNNL_ARG_WEIGHTS, b_m},
|
{DNNL_ARG_WEIGHTS, mat_weights_mem},
|
||||||
{DNNL_ARG_DST, c_m},
|
{DNNL_ARG_DST, mat_dst_mem},
|
||||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
|
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@ -138,19 +171,19 @@ class DNNLPrimitiveHelper {
|
|||||||
dnnl::memory bias_m(bias_md, engine, (void*)bias);
|
dnnl::memory bias_m(bias_md, engine, (void*)bias);
|
||||||
matmul.execute(
|
matmul.execute(
|
||||||
stream, {
|
stream, {
|
||||||
{DNNL_ARG_SRC, a_m},
|
{DNNL_ARG_SRC, mat_src_mem},
|
||||||
{DNNL_ARG_WEIGHTS, b_m},
|
{DNNL_ARG_WEIGHTS, mat_weights_mem},
|
||||||
{DNNL_ARG_BIAS, bias_m},
|
{DNNL_ARG_BIAS, bias_m},
|
||||||
{DNNL_ARG_DST, c_m},
|
{DNNL_ARG_DST, mat_dst_mem},
|
||||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m},
|
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m},
|
||||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
|
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
matmul.execute(
|
matmul.execute(
|
||||||
stream, {
|
stream, {
|
||||||
{DNNL_ARG_SRC, a_m},
|
{DNNL_ARG_SRC, mat_src_mem},
|
||||||
{DNNL_ARG_WEIGHTS, b_m},
|
{DNNL_ARG_WEIGHTS, mat_weights_mem},
|
||||||
{DNNL_ARG_DST, c_m},
|
{DNNL_ARG_DST, mat_dst_mem},
|
||||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m},
|
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m},
|
||||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
|
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
|
||||||
});
|
});
|
||||||
@ -170,5 +203,4 @@ class DNNLPrimitiveHelper {
|
|||||||
return stream;
|
return stream;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
@ -36,7 +36,7 @@ struct KernelVecType<c10::Half> {
|
|||||||
using cvt_vec_type = vec_op::FP32Vec16;
|
using cvt_vec_type = vec_op::FP32Vec16;
|
||||||
};
|
};
|
||||||
|
|
||||||
#ifdef __AVX512F__
|
#if defined(__AVX512F__) || defined(__aarch64__)
|
||||||
template <bool AZP, typename scalar_t>
|
template <bool AZP, typename scalar_t>
|
||||||
void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
|
void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
|
||||||
const float* scale, const int32_t* azp,
|
const float* scale, const int32_t* azp,
|
||||||
@ -598,8 +598,9 @@ void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
|
|||||||
const float* scale, const int32_t* azp,
|
const float* scale, const int32_t* azp,
|
||||||
const int num_tokens,
|
const int num_tokens,
|
||||||
const int hidden_size) {
|
const int hidden_size) {
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(false,
|
||||||
false, "static_scaled_int8_quant_impl requires AVX512/powerpc64 support.")
|
"static_scaled_int8_quant_impl requires AVX512/powerpc64/AArch64 "
|
||||||
|
"support.")
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename scalar_t>
|
template <typename scalar_t>
|
||||||
@ -607,9 +608,9 @@ void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
|
|||||||
float* scale, int32_t* azp,
|
float* scale, int32_t* azp,
|
||||||
const int num_tokens,
|
const int num_tokens,
|
||||||
const int hidden_size) {
|
const int hidden_size) {
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(false,
|
||||||
false,
|
"dynamic_scaled_int8_quant_impl requires "
|
||||||
"dynamic_scaled_int8_quant_impl requires AVX512/powerpc64 support.")
|
"AVX512/powerpc64/AArch64 support.")
|
||||||
}
|
}
|
||||||
|
|
||||||
template <bool PerChannel, typename scalar_t>
|
template <bool PerChannel, typename scalar_t>
|
||||||
@ -617,7 +618,8 @@ void static_quant_epilogue(const float* input, scalar_t* output,
|
|||||||
const float a_scale, const float* b_scale,
|
const float a_scale, const float* b_scale,
|
||||||
const int32_t* azp_with_adj, const int num_tokens,
|
const int32_t* azp_with_adj, const int num_tokens,
|
||||||
const int hidden_size) {
|
const int hidden_size) {
|
||||||
TORCH_CHECK(false, "static_quant_epilogue requires AVX512/powerpc64 support.")
|
TORCH_CHECK(
|
||||||
|
false, "static_quant_epilogue requires AVX512/powerpc64/AArch64 support.")
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename scalar_t>
|
template <typename scalar_t>
|
||||||
@ -626,8 +628,9 @@ void dynamic_quant_epilogue(const float* input, scalar_t* output,
|
|||||||
const int32_t* azp, const int32_t* azp_with_adj,
|
const int32_t* azp, const int32_t* azp_with_adj,
|
||||||
const scalar_t* bias, const int num_tokens,
|
const scalar_t* bias, const int num_tokens,
|
||||||
const int hidden_size) {
|
const int hidden_size) {
|
||||||
TORCH_CHECK(false,
|
TORCH_CHECK(
|
||||||
"dynamic_quant_epilogue requires AVX512/powerpc64 support.")
|
false,
|
||||||
|
"dynamic_quant_epilogue requires AVX512/powerpc64/AArch64 support.")
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|||||||
@ -151,8 +151,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding);
|
ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding);
|
||||||
|
|
||||||
// Quantization
|
// Quantization
|
||||||
#ifdef __AVX512F__
|
#if defined(__AVX512F__) || defined(__aarch64__)
|
||||||
at::Tag stride_tag = at::Tag::needs_fixed_stride_order;
|
at::Tag stride_tag = at::Tag::needs_fixed_stride_order;
|
||||||
|
|
||||||
// Compute int8 quantized tensor for given scaling factor.
|
// Compute int8 quantized tensor for given scaling factor.
|
||||||
ops.def(
|
ops.def(
|
||||||
"static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale,"
|
"static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale,"
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user