[Hardware][Power] Enable compressed tensor W8A8 INT8 quantization for POWER (#17153)

Signed-off-by: Akash Kaothalkar <akash.kaothalkar@ibm.com>
Co-authored-by: Akash Kaothalkar <akash.kaothalkar@ibm.com>
Co-authored-by: mgoin <mgoin64@gmail.com>

This commit is contained in: parent 5a499e70d5, commit e515668edf
@@ -167,6 +167,33 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED)

    FetchContent_MakeAvailable(oneDNN)

    list(APPEND LIBS dnnl)
elseif(POWER10_FOUND)
    FetchContent_Declare(
        oneDNN
        GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git
        GIT_TAG v3.7.2
        GIT_PROGRESS TRUE
        GIT_SHALLOW TRUE
    )

    set(ONEDNN_LIBRARY_TYPE "STATIC")
    set(ONEDNN_BUILD_DOC "OFF")
    set(ONEDNN_BUILD_EXAMPLES "OFF")
    set(ONEDNN_BUILD_TESTS "OFF")
    set(ONEDNN_ENABLE_WORKLOAD "INFERENCE")
    set(ONEDNN_ENABLE_PRIMITIVE "MATMUL;REORDER")
    set(ONEDNN_BUILD_GRAPH "OFF")
    set(ONEDNN_ENABLE_JIT_PROFILING "OFF")
    set(ONEDNN_ENABLE_ITT_TASKS "OFF")
    set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF")
    set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF")
    set(CMAKE_POLICY_DEFAULT_CMP0077 NEW)

    set(DNNL_CPU_RUNTIME "OMP")

    FetchContent_MakeAvailable(oneDNN)

    list(APPEND LIBS dnnl)
endif()

@@ -197,6 +224,10 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED)
        "csrc/cpu/quant.cpp"
        "csrc/cpu/shm.cpp"
        ${VLLM_EXT_SRC})
elseif(POWER10_FOUND)
    set(VLLM_EXT_SRC
        "csrc/cpu/quant.cpp"
        ${VLLM_EXT_SRC})
endif()

#

@@ -214,4 +245,4 @@ define_gpu_extension_target(
    WITH_SOABI
)

message(STATUS "Enabling C extension.")
message(STATUS "Enabling C extension.")

@@ -4,6 +4,7 @@

#include <altivec.h>
#include <cmath>
#include <algorithm>
#include <torch/all.h>

namespace vec_op {

@@ -62,6 +63,10 @@ typedef struct f32x4x4_t {
  __vector float val[4];
} f32x4x4_t;

typedef struct i32x4x4_t {
  __vector int32_t val[4];
} i32x4x4_t;

struct FP32Vec8;
struct FP32Vec16;

@@ -98,6 +103,28 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
    vec_xst(reg.val[0], 0, (signed short*)ptr);
    vec_xst(reg.val[1], 16, (signed short*)ptr);
  }

  void save(void* ptr, const int elem_num) const {
    const int clamped_elem = std::max(0, std::min(elem_num, 16));

    // Calculate elements to store in each 128-bit part (8 elements each)
    const int elements_val0 = std::min(clamped_elem, 8);
    const int elements_val1 = std::max(clamped_elem - 8, 0);

    // Convert elements to bytes (2 bytes per element)
    const size_t bytes_val0 = elements_val0 * sizeof(signed short);
    const size_t bytes_val1 = elements_val1 * sizeof(signed short);

    signed short* dest = static_cast<signed short*>(ptr);
    // Store the first part using vec_xst_len
    if (bytes_val0 > 0) {
      vec_xst_len(reg.val[0], dest, bytes_val0);
    }
    // Store the second part if needed
    if (bytes_val1 > 0) {
      vec_xst_len(reg.val[1], dest + elements_val0, bytes_val1);
    }
  }
};

const static __vector signed short zero = vec_splats((signed short)0);
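A note on the BF16Vec16::save(void*, elem_num) overload above: it clamps elem_num to [0, 16], splits the count across the two 128-bit halves of the register pair, and uses the length-limited store vec_xst_len so nothing is written past the requested element count. The scalar sketch below spells out the same split; it is illustrative only and not part of the commit (partial_store_bf16 and its plain-array stand-in for the VSX registers are assumptions).

#include <algorithm>
#include <cstdint>
#include <cstring>

// Scalar model of BF16Vec16::save(ptr, elem_num): write only the first
// `elem_num` 16-bit lanes (clamped to [0, 16]) of a 16-lane register image.
// `lanes` stands in for the two 128-bit VSX registers reg.val[0..1].
void partial_store_bf16(const uint16_t lanes[16], void* dst, int elem_num) {
  const int n = std::max(0, std::min(elem_num, 16));
  // First 128-bit half: up to 8 lanes, mirrors vec_xst_len(reg.val[0], ...).
  const int n0 = std::min(n, 8);
  // Second 128-bit half: the remainder, mirrors vec_xst_len(reg.val[1], ...).
  const int n1 = n - n0;
  auto* out = static_cast<uint16_t*>(dst);
  std::memcpy(out, lanes, n0 * sizeof(uint16_t));
  std::memcpy(out + n0, lanes + 8, n1 * sizeof(uint16_t));
}

The same tail-handling pattern (clamp, split into 128-bit chunks, length-limited store) reappears in INT32Vec16::save and FP32Vec16::save further down.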
@@ -257,6 +284,64 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
  }
};

struct INT32Vec16 : public Vec<INT32Vec16> {
  constexpr static int VEC_ELEM_NUM = 16;
  union AliasReg {
    i32x4x4_t reg;
    int32_t values[VEC_ELEM_NUM];
  };

  i32x4x4_t reg;

  explicit INT32Vec16(const void* data_ptr) {
    reg.val[0] = vec_xl(0, reinterpret_cast<const __vector int32_t*>(data_ptr));
    reg.val[1] =
        vec_xl(16, reinterpret_cast<const __vector int32_t*>(data_ptr));
    reg.val[2] =
        vec_xl(32, reinterpret_cast<const __vector int32_t*>(data_ptr));
    reg.val[3] =
        vec_xl(48, reinterpret_cast<const __vector int32_t*>(data_ptr));
  }

  void save(int32_t* ptr) const {
    vec_xst(reg.val[0], 0, reinterpret_cast<__vector int32_t*>(ptr));
    vec_xst(reg.val[1], 16, reinterpret_cast<__vector int32_t*>(ptr));
    vec_xst(reg.val[2], 32, reinterpret_cast<__vector int32_t*>(ptr));
    vec_xst(reg.val[3], 48, reinterpret_cast<__vector int32_t*>(ptr));
  }

  void save(int32_t* ptr, const int elem_num) const {
    const int elements_in_chunk1 =
        (elem_num >= 0) ? ((elem_num >= 4) ? 4 : elem_num) : 0;
    const int elements_in_chunk2 =
        (elem_num > 4) ? ((elem_num >= 8) ? 4 : elem_num - 4) : 0;
    const int elements_in_chunk3 =
        (elem_num > 8) ? ((elem_num >= 12) ? 4 : elem_num - 8) : 0;
    const int elements_in_chunk4 =
        (elem_num > 12) ? ((elem_num >= 16) ? 4 : elem_num - 12) : 0;

    const size_t bytes_chunk1 =
        static_cast<size_t>(elements_in_chunk1 * sizeof(int32_t));
    const size_t bytes_chunk2 =
        static_cast<size_t>(elements_in_chunk2 * sizeof(int32_t));
    const size_t bytes_chunk3 =
        static_cast<size_t>(elements_in_chunk3 * sizeof(int32_t));
    const size_t bytes_chunk4 =
        static_cast<size_t>(elements_in_chunk4 * sizeof(int32_t));

    vec_xst_len(reg.val[0], reinterpret_cast<int32_t*>(ptr), bytes_chunk1);
    vec_xst_len(reg.val[1],
                reinterpret_cast<int32_t*>(reinterpret_cast<char*>(ptr) + 16),
                bytes_chunk2);
    vec_xst_len(reg.val[2],
                reinterpret_cast<int32_t*>(reinterpret_cast<char*>(ptr) + 32),
                bytes_chunk3);
    vec_xst_len(reg.val[3],
                reinterpret_cast<int32_t*>(reinterpret_cast<char*>(ptr) + 48),
                bytes_chunk4);
  }
};

struct FP32Vec16 : public Vec<FP32Vec16> {
  constexpr static int VEC_ELEM_NUM = 16;
  union AliasReg {
@@ -319,6 +404,13 @@ struct FP32Vec16 : public Vec<FP32Vec16> {

  explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}

  explicit FP32Vec16(const INT32Vec16& v) {
    reg.val[0] = vec_ctf(v.reg.val[0], 0);
    reg.val[1] = vec_ctf(v.reg.val[1], 0);
    reg.val[2] = vec_ctf(v.reg.val[2], 0);
    reg.val[3] = vec_ctf(v.reg.val[3], 0);
  }

  FP32Vec16 operator*(const FP32Vec16& b) const {
    return FP32Vec16(f32x4x4_t({vec_mul(reg.val[0], b.reg.val[0]),
                                vec_mul(reg.val[1], b.reg.val[1]),
@@ -347,6 +439,117 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
                                vec_div(reg.val[3], b.reg.val[3])}));
  }

  FP32Vec16 clamp(const FP32Vec16& min, const FP32Vec16& max) const {
    return FP32Vec16(f32x4x4_t(
        {vec_min(max.reg.val[0], vec_max(min.reg.val[0], reg.val[0])),
         vec_min(max.reg.val[1], vec_max(min.reg.val[1], reg.val[1])),
         vec_min(max.reg.val[2], vec_max(min.reg.val[2], reg.val[2])),
         vec_min(max.reg.val[3], vec_max(min.reg.val[3], reg.val[3]))}));
  }

  FP32Vec16 max(const FP32Vec16& b) const {
    return FP32Vec16(f32x4x4_t({vec_max(reg.val[0], b.reg.val[0]),
                                vec_max(reg.val[1], b.reg.val[1]),
                                vec_max(reg.val[2], b.reg.val[2]),
                                vec_max(reg.val[3], b.reg.val[3])}));
  }

  FP32Vec16 max(const FP32Vec16& b, int elem_num) const {
    FP32Vec16 result;

    // Create a vector of element indices for each chunk
    __vector unsigned int indices = {0, 1, 2, 3};
    __vector unsigned int elem_num_vec =
        vec_splats(static_cast<unsigned int>(elem_num));

    // Compute masks for each chunk
    __vector unsigned int chunk_offset0 = {0, 0, 0,
                                           0};  // Chunk 0: Elements 0-3
    __vector unsigned int chunk_offset1 = {4, 4, 4,
                                           4};  // Chunk 1: Elements 4-7
    __vector unsigned int chunk_offset2 = {8, 8, 8,
                                           8};  // Chunk 2: Elements 8-11
    __vector unsigned int chunk_offset3 = {12, 12, 12,
                                           12};  // Chunk 3: Elements 12-15

    // Compute masks for each chunk
    __vector bool int mask0 = vec_cmplt(indices + chunk_offset0, elem_num_vec);
    __vector bool int mask1 = vec_cmplt(indices + chunk_offset1, elem_num_vec);
    __vector bool int mask2 = vec_cmplt(indices + chunk_offset2, elem_num_vec);
    __vector bool int mask3 = vec_cmplt(indices + chunk_offset3, elem_num_vec);

    // Apply masks to compute the result for each chunk
    result.reg.val[0] = vec_sel(this->reg.val[0],
                                vec_max(this->reg.val[0], b.reg.val[0]), mask0);
    result.reg.val[1] = vec_sel(this->reg.val[1],
                                vec_max(this->reg.val[1], b.reg.val[1]), mask1);
    result.reg.val[2] = vec_sel(this->reg.val[2],
                                vec_max(this->reg.val[2], b.reg.val[2]), mask2);
    result.reg.val[3] = vec_sel(this->reg.val[3],
                                vec_max(this->reg.val[3], b.reg.val[3]), mask3);

    return FP32Vec16(result.reg);
  }

  FP32Vec16 min(const FP32Vec16& b) const {
    return FP32Vec16(f32x4x4_t({vec_min(reg.val[0], b.reg.val[0]),
                                vec_min(reg.val[1], b.reg.val[1]),
                                vec_min(reg.val[2], b.reg.val[2]),
                                vec_min(reg.val[3], b.reg.val[3])}));
  }

  FP32Vec16 min(const FP32Vec16& b, int elem_num) const {
    FP32Vec16 result;

    vector unsigned int indices = {0, 1, 2, 3};
    vector unsigned int elem_num_vec =
        vec_splats(static_cast<unsigned int>(elem_num));

    vector unsigned int chunk_offset0 = {0, 0, 0, 0};
    vector unsigned int chunk_offset1 = {4, 4, 4, 4};
    vector unsigned int chunk_offset2 = {8, 8, 8, 8};
    vector unsigned int chunk_offset3 = {12, 12, 12, 12};

    vector bool int mask0 = vec_cmplt(indices + chunk_offset0, elem_num_vec);
    vector bool int mask1 = vec_cmplt(indices + chunk_offset1, elem_num_vec);
    vector bool int mask2 = vec_cmplt(indices + chunk_offset2, elem_num_vec);
    vector bool int mask3 = vec_cmplt(indices + chunk_offset3, elem_num_vec);

    result.reg.val[0] = vec_sel(this->reg.val[0],
                                vec_min(this->reg.val[0], b.reg.val[0]), mask0);
    result.reg.val[1] = vec_sel(this->reg.val[1],
                                vec_min(this->reg.val[1], b.reg.val[1]), mask1);
    result.reg.val[2] = vec_sel(this->reg.val[2],
                                vec_min(this->reg.val[2], b.reg.val[2]), mask2);
    result.reg.val[3] = vec_sel(this->reg.val[3],
                                vec_min(this->reg.val[3], b.reg.val[3]), mask3);

    return FP32Vec16(result.reg);
  }

  FP32Vec16 abs() const {
    return FP32Vec16(f32x4x4_t({vec_abs(reg.val[0]), vec_abs(reg.val[1]),
                                vec_abs(reg.val[2]), vec_abs(reg.val[3])}));
  }

  float reduce_max() {
    __vector float max01 = vec_max(reg.val[0], reg.val[1]);
    __vector float max23 = vec_max(reg.val[2], reg.val[3]);
    __vector float max_all = vec_max(max01, max23);
    __vector float temp = vec_max(max_all, vec_sld(max_all, max_all, 8));
    temp = vec_max(temp, vec_sld(temp, temp, 4));
    return vec_extract(temp, 0);
  }

  float reduce_min() {
    __vector float min01 = vec_min(reg.val[0], reg.val[1]);
    __vector float min23 = vec_min(reg.val[2], reg.val[3]);
    __vector float min_all = vec_min(min01, min23);
    __vector float temp = vec_min(min_all, vec_sld(min_all, min_all, 8));
    temp = vec_min(temp, vec_sld(temp, temp, 4));
    return vec_extract(temp, 0);
  }

  float reduce_sum() const {
    AliasReg ar;
    ar.reg = reg;
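The elem_num overloads of max() and min() above emulate a masked update: per-lane indices are compared against elem_num with vec_cmplt, and vec_sel keeps the original lane wherever the index is out of range, so padding lanes cannot contaminate a later reduce_max() or reduce_min(). A scalar model of that behaviour follows; it is illustrative only, and masked_max_ref / reduce_max_ref are assumed names, not code from the commit.

#include <algorithm>
#include <array>

// Scalar model of FP32Vec16::max(b, elem_num): lanes [0, elem_num) take
// max(a, b); lanes >= elem_num pass through unchanged, which keeps padding
// lanes from polluting a later reduction.
std::array<float, 16> masked_max_ref(const std::array<float, 16>& a,
                                     const std::array<float, 16>& b,
                                     int elem_num) {
  std::array<float, 16> r = a;
  for (int lane = 0; lane < 16; ++lane) {
    if (lane < elem_num) r[lane] = std::max(a[lane], b[lane]);
  }
  return r;
}

// Scalar model of reduce_max(): the VSX version folds val[0..3] pairwise and
// then rotates with vec_sld to collapse the last four lanes.
float reduce_max_ref(const std::array<float, 16>& v) {
  float m = v[0];
  for (int lane = 1; lane < 16; ++lane) m = std::max(m, v[lane]);
  return m;
}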
@@ -377,6 +580,68 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
    vec_xst(reg.val[2], 32, ptr);
    vec_xst(reg.val[3], 48, ptr);
  }

  void save(float* ptr, const int elem_num) const {
    const int elements_in_chunk1 =
        (elem_num >= 0) ? ((elem_num >= 4) ? 4 : elem_num) : 0;
    const int elements_in_chunk2 =
        (elem_num > 4) ? ((elem_num >= 8) ? 4 : elem_num - 4) : 0;
    const int elements_in_chunk3 =
        (elem_num > 8) ? ((elem_num >= 12) ? 4 : elem_num - 8) : 0;
    const int elements_in_chunk4 =
        (elem_num > 12) ? ((elem_num >= 16) ? 4 : elem_num - 12) : 0;

    const size_t bytes_chunk1 =
        static_cast<size_t>(elements_in_chunk1 * sizeof(float));
    const size_t bytes_chunk2 =
        static_cast<size_t>(elements_in_chunk2 * sizeof(float));
    const size_t bytes_chunk3 =
        static_cast<size_t>(elements_in_chunk3 * sizeof(float));
    const size_t bytes_chunk4 =
        static_cast<size_t>(elements_in_chunk4 * sizeof(float));

    vec_xst_len(reg.val[0], ptr, bytes_chunk1);
    vec_xst_len(reg.val[1],
                reinterpret_cast<float*>(reinterpret_cast<char*>(ptr) + 16),
                bytes_chunk2);
    vec_xst_len(reg.val[2],
                reinterpret_cast<float*>(reinterpret_cast<char*>(ptr) + 32),
                bytes_chunk3);
    vec_xst_len(reg.val[3],
                reinterpret_cast<float*>(reinterpret_cast<char*>(ptr) + 48),
                bytes_chunk4);
  }
};

struct INT8Vec16 : public Vec<INT8Vec16> {
  constexpr static int VEC_NUM_ELEM = 16;  // 128 bits / 8 bits = 16

  union AliasReg {
    __vector signed char reg;
    int8_t values[VEC_NUM_ELEM];
  };

  __vector signed char reg;

  explicit INT8Vec16(const FP32Vec16& vec) {
    __vector signed int ret[4];
    ret[0] = vec_cts(vec.reg.val[0], 0);
    ret[1] = vec_cts(vec.reg.val[1], 0);
    ret[2] = vec_cts(vec.reg.val[2], 0);
    ret[3] = vec_cts(vec.reg.val[3], 0);

    __vector signed short packed1 = vec_packs(ret[0], ret[1]);
    __vector signed short packed2 = vec_packs(ret[2], ret[3]);

    reg = vec_packs(packed1, packed2);
  }

  void save(void* ptr) const {
    *reinterpret_cast<__vector signed char*>(ptr) = reg;
  }
  void save(signed char* ptr, const int elem_num) {
    vec_xst_len(reg, ptr, static_cast<size_t>(elem_num));
  }
};

template <typename T>
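The INT8Vec16 constructor above narrows FP32Vec16 to int8 in two stages: vec_cts truncates each float lane toward zero into int32, then two rounds of vec_packs narrow to int16 and int8 with saturation. The quantization kernels below additionally clamp to [-128, 127] in float before converting. A scalar model of that path, illustrative only (float_to_int8_sat is an assumed name, not code from the commit):

#include <algorithm>
#include <array>
#include <cstdint>

// Scalar model of the float -> int8 path used by INT8Vec16: truncate toward
// zero (vec_cts), then saturate while narrowing (vec_packs twice).
// Inputs are assumed to be pre-clamped to the int32 range, as the quant
// kernels clamp to [-128, 127] before constructing INT8Vec16.
std::array<int8_t, 16> float_to_int8_sat(const std::array<float, 16>& v) {
  std::array<int8_t, 16> out{};
  for (int lane = 0; lane < 16; ++lane) {
    const int32_t x = static_cast<int32_t>(v[lane]);   // truncation, like vec_cts
    const int32_t s = std::clamp(x, -128, 127);        // saturation, like vec_packs
    out[lane] = static_cast<int8_t>(s);
  }
  return out;
}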
@@ -239,6 +239,280 @@ void static_quant_epilogue(const float* input, scalar_t* output,
  }
}

template <bool AZP, bool PerChannel, bool Bias, typename scalar_t>
void dynamic_quant_epilogue(const float* input, scalar_t* output,
                            const float* a_scale, const float* b_scale,
                            const int32_t* azp, const int32_t* azp_adj,
                            const scalar_t* bias, const int num_tokens,
                            const int hidden_size) {
  CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue)
  using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
  using azp_adj_load_vec_t =
      typename KernelVecType<scalar_t>::azp_adj_load_vec_type;
  using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
  constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;

#pragma omp parallel for
  for (int i = 0; i < num_tokens; ++i) {
    int j = 0;
    cvt_vec_t token_scale_vec(a_scale[i]);
    cvt_vec_t token_zp_scale_vec;
    if constexpr (AZP) {
      float zp_scale_val = a_scale[i] * static_cast<float>(azp[i]);
      if constexpr (!PerChannel) {
        zp_scale_val *= *b_scale;
      }
      token_zp_scale_vec = cvt_vec_t(zp_scale_val);
    }

    for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
      cvt_vec_t elems_fp32(input + i * hidden_size + j);
      elems_fp32 = elems_fp32 * token_scale_vec;

      if constexpr (AZP) {
        azp_adj_load_vec_t azp_adj_vec(azp_adj + j);
        cvt_vec_t azp_adj_fp32(azp_adj_vec);
        azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec;

        if constexpr (PerChannel) {
          cvt_vec_t b_scale_vec(b_scale + j);
          azp_adj_fp32 = azp_adj_fp32 * b_scale_vec;
        }

        elems_fp32 = elems_fp32 - azp_adj_fp32;
      }

      if constexpr (Bias) {
        load_vec_t bias_vec(bias + j);
        cvt_vec_t bias_vec_fp32(bias_vec);
        elems_fp32 = elems_fp32 + bias_vec_fp32;
      }

      load_vec_t elems_out(elems_fp32);
      elems_out.save(output + i * hidden_size + j);
    }

    cvt_vec_t elems_fp32(input + i * hidden_size + j);
    elems_fp32 = elems_fp32 * token_scale_vec;

    if constexpr (AZP) {
      azp_adj_load_vec_t azp_adj_vec(azp_adj + j);
      cvt_vec_t azp_adj_fp32(azp_adj_vec);
      azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec;

      if constexpr (PerChannel) {
        cvt_vec_t b_scale_vec(b_scale + j);
        azp_adj_fp32 = azp_adj_fp32 * b_scale_vec;
      }

      elems_fp32 = elems_fp32 - azp_adj_fp32;
    }

    if constexpr (Bias) {
      load_vec_t bias_vec(bias + j);
      cvt_vec_t bias_vec_fp32(bias_vec);
      elems_fp32 = elems_fp32 + bias_vec_fp32;
    }

    load_vec_t elems_out(elems_fp32);
    elems_out.save(output + i * hidden_size + j, hidden_size - j);
  }
}
#elif defined(__powerpc64__)
template <bool AZP, typename scalar_t>
void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
                                   const float* scale, const int32_t* azp,
                                   const int num_tokens,
                                   const int hidden_size) {
  using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
  using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
  constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;

  constexpr float i8_min =
      static_cast<float>(std::numeric_limits<int8_t>::min());
  constexpr float i8_max =
      static_cast<float>(std::numeric_limits<int8_t>::max());

  const cvt_vec_t inv_scale(1.0 / *scale);
  const cvt_vec_t i8_min_vec(i8_min);
  const cvt_vec_t i8_max_vec(i8_max);

  cvt_vec_t zp_vec;
  if constexpr (AZP) {
    zp_vec = cvt_vec_t(static_cast<float>(*azp));
  }
#pragma omp parallel for
  for (int i = 0; i < num_tokens; ++i) {
    int j = 0;
    for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
      load_vec_t elems(input + i * hidden_size + j);
      cvt_vec_t elems_fp32(elems);
      elems_fp32 = elems_fp32 * inv_scale;
      if constexpr (AZP) {
        elems_fp32 = elems_fp32 + zp_vec;
      }
      elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
      vec_op::INT8Vec16 elems_int8(elems_fp32);
      elems_int8.save(output + i * hidden_size + j);
    }
    load_vec_t elems(input + i * hidden_size + j);
    cvt_vec_t elems_fp32(elems);
    elems_fp32 = elems_fp32 * inv_scale;

    if constexpr (AZP) {
      elems_fp32 = elems_fp32 + zp_vec;
    }

    elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
    vec_op::INT8Vec16 elems_int8(elems_fp32);
    elems_int8.save(output + i * hidden_size + j, hidden_size - j);
  }
}
template <bool AZP, typename scalar_t>
void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
                                    float* scale, int32_t* azp,
                                    const int num_tokens,
                                    const int hidden_size) {
  using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
  using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
  constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;

  constexpr float i8_min =
      static_cast<float>(std::numeric_limits<int8_t>::min());
  constexpr float i8_max =
      static_cast<float>(std::numeric_limits<int8_t>::max());
  const cvt_vec_t i8_min_vec(i8_min);
  const cvt_vec_t i8_max_vec(i8_max);

#pragma omp parallel for
  for (int i = 0; i < num_tokens; ++i) {
    cvt_vec_t max_value(std::numeric_limits<float>::lowest());
    cvt_vec_t min_value(std::numeric_limits<float>::max());
    {
      int j = 0;
      for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
        load_vec_t elems(input + i * hidden_size + j);
        cvt_vec_t elems_fp32(elems);
        if constexpr (AZP) {
          max_value = max_value.max(elems_fp32);
          min_value = min_value.min(elems_fp32);
        } else {
          max_value = max_value.max(elems_fp32.abs());
        }
      }

      load_vec_t elems(input + i * hidden_size + j);
      cvt_vec_t elems_fp32(elems);

      if (j + vec_elem_num == hidden_size) {
        if constexpr (AZP) {
          max_value = max_value.max(elems_fp32);
          min_value = min_value.min(elems_fp32);
        } else {
          max_value = max_value.max(elems_fp32.abs());
        }
      } else {
        if constexpr (AZP) {
          max_value = max_value.max(elems_fp32, hidden_size - j);
          min_value = min_value.min(elems_fp32, hidden_size - j);
        } else {
          max_value = max_value.max(elems_fp32.abs(), hidden_size - j);
        }
      }
    }

    float scale_val, azp_val;
    if constexpr (AZP) {
      float max_scalar = max_value.reduce_max();
      float min_scalar = min_value.reduce_min();
      scale_val = (max_scalar - min_scalar) / 255.0f;
      azp_val = std::nearbyint(-128.0f - min_scalar / scale_val);
      azp[i] = static_cast<int32_t>(azp_val);
      scale[i] = scale_val;
    } else {
      scale_val = max_value.reduce_max() / 127.0f;
      scale[i] = scale_val;
    }

    const cvt_vec_t inv_scale(1.0 / scale_val);
    const cvt_vec_t azp_vec(azp_val);

    {
      int j = 0;
      for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
        load_vec_t elems(input + i * hidden_size + j);
        cvt_vec_t elems_fp32(elems);
        elems_fp32 = (elems_fp32 * inv_scale);

        if constexpr (AZP) {
          elems_fp32 = elems_fp32 + azp_vec;
        }
        elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
        vec_op::INT8Vec16 elems_int8(elems_fp32);
        elems_int8.save(output + i * hidden_size + j);
      }

      load_vec_t elems(input + i * hidden_size + j);
      cvt_vec_t elems_fp32(elems);
      elems_fp32 = (elems_fp32 * inv_scale);

      if constexpr (AZP) {
        elems_fp32 = elems_fp32 + azp_vec;
      }
      elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
      vec_op::INT8Vec16 elems_int8(elems_fp32);
      elems_int8.save(output + i * hidden_size + j, hidden_size - j);
    }
  }
}
template <bool PerChannel, typename scalar_t>
void static_quant_epilogue(const float* input, scalar_t* output,
                           const float a_scale, const float* b_scale,
                           const int32_t* azp_with_adj, const int num_tokens,
                           const int hidden_size) {
  CPU_KERNEL_GUARD_IN(dynamic_output_scale_impl)
  using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
  using azp_adj_load_vec_t =
      typename KernelVecType<scalar_t>::azp_adj_load_vec_type;
  using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
  constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;

#pragma omp parallel for
  for (int i = 0; i < num_tokens; ++i) {
    cvt_vec_t a_scale_vec(a_scale);
    cvt_vec_t b_scale_vec(*b_scale);
    cvt_vec_t scale_vec = a_scale_vec * b_scale_vec;

    int j = 0;
    for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
      cvt_vec_t elems_fp32(input + i * hidden_size + j);
      azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j);
      cvt_vec_t azp_adj_fp32(azp_adj_vec);

      if constexpr (PerChannel) {
        b_scale_vec = cvt_vec_t(b_scale + j);
        scale_vec = b_scale_vec * a_scale_vec;
      }
      elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32;
      load_vec_t elems_out(elems_fp32);
      elems_out.save(output + i * hidden_size + j);
    }

    cvt_vec_t elems_fp32(input + i * hidden_size + j);
    azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j);
    cvt_vec_t azp_adj_fp32(azp_adj_vec);

    if constexpr (PerChannel) {
      b_scale_vec = cvt_vec_t(b_scale + j);
      scale_vec = b_scale_vec * a_scale_vec;
    }

    elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32;

    load_vec_t elems_out(elems_fp32);
    elems_out.save(output + i * hidden_size + j, hidden_size - j);
  }
}
template <bool AZP, bool PerChannel, bool Bias, typename scalar_t>
void dynamic_quant_epilogue(const float* input, scalar_t* output,
                            const float* a_scale, const float* b_scale,
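For reference, the per-token math implemented by the vectorized loops above: the symmetric path uses scale = max|x| / 127, and the asymmetric (AZP) path uses scale = (max - min) / 255 with azp = round(-128 - min / scale), followed by scaling, zero-point shift, and clamping to the int8 range. The scalar sketch below restates one row of that computation; it is illustrative only, and quant_row_asymmetric / quant_row_symmetric are assumed names, not functions from the commit.

#include <algorithm>
#include <cmath>
#include <cstdint>

// Scalar reference for one token (row) of dynamic INT8 quantization,
// asymmetric (AZP) variant. Assumes a non-degenerate value range, as the
// vectorized kernel does.
void quant_row_asymmetric(const float* x, int8_t* q, int hidden_size,
                          float* scale_out, int32_t* azp_out) {
  float mx = x[0], mn = x[0];
  for (int j = 1; j < hidden_size; ++j) {
    mx = std::max(mx, x[j]);
    mn = std::min(mn, x[j]);
  }
  const float scale = (mx - mn) / 255.0f;
  const float azp = std::nearbyint(-128.0f - mn / scale);
  for (int j = 0; j < hidden_size; ++j) {
    const float v = std::clamp(x[j] / scale + azp, -128.0f, 127.0f);
    q[j] = static_cast<int8_t>(v);  // truncating conversion, as in INT8Vec16
  }
  *scale_out = scale;
  *azp_out = static_cast<int32_t>(azp);
}

// Symmetric variant: scale from the absolute maximum, no zero point.
void quant_row_symmetric(const float* x, int8_t* q, int hidden_size,
                         float* scale_out) {
  float amax = 0.0f;
  for (int j = 0; j < hidden_size; ++j) amax = std::max(amax, std::fabs(x[j]));
  const float scale = amax / 127.0f;
  for (int j = 0; j < hidden_size; ++j) {
    const float v = std::clamp(x[j] / scale, -128.0f, 127.0f);
    q[j] = static_cast<int8_t>(v);
  }
  *scale_out = scale;
}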
@@ -324,7 +598,8 @@ void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
                                   const float* scale, const int32_t* azp,
                                   const int num_tokens,
                                   const int hidden_size) {
  TORCH_CHECK(false, "static_scaled_int8_quant_impl requires AVX512 support.")
  TORCH_CHECK(
      false, "static_scaled_int8_quant_impl requires AVX512/powerpc64 support.")
}

template <typename scalar_t>

@@ -332,7 +607,9 @@ void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
                                    float* scale, int32_t* azp,
                                    const int num_tokens,
                                    const int hidden_size) {
  TORCH_CHECK(false, "dynamic_scaled_int8_quant_impl requires AVX512 support.")
  TORCH_CHECK(
      false,
      "dynamic_scaled_int8_quant_impl requires AVX512/powerpc64 support.")
}

template <bool PerChannel, typename scalar_t>

@@ -340,7 +617,7 @@ void static_quant_epilogue(const float* input, scalar_t* output,
                           const float a_scale, const float* b_scale,
                           const int32_t* azp_with_adj, const int num_tokens,
                           const int hidden_size) {
  TORCH_CHECK(false, "static_quant_epilogue requires AVX512 support.")
  TORCH_CHECK(false, "static_quant_epilogue requires AVX512/powerpc64 support.")
}

template <typename scalar_t>

@@ -349,7 +626,8 @@ void dynamic_quant_epilogue(const float* input, scalar_t* output,
                            const int32_t* azp, const int32_t* azp_with_adj,
                            const scalar_t* bias, const int num_tokens,
                            const int hidden_size) {
  TORCH_CHECK(false, "dynamic_quant_epilogue requires AVX512 support.")
  TORCH_CHECK(false,
              "dynamic_quant_epilogue requires AVX512/powerpc64 support.")
}
#endif
}  // namespace
@@ -611,3 +889,58 @@ void dynamic_scaled_int8_quant(
    }
  });
}

#if defined(__powerpc64__)
void int8_scaled_mm_ppc64le(torch::Tensor& c,        // [M, OC], row-major
                            const torch::Tensor& a,  // [M, IC], row-major
                            const torch::Tensor& b,  // [IC, OC], column-major
                            const torch::Tensor& a_scales,
                            const torch::Tensor& b_scales,
                            const std::optional<torch::Tensor>& bias  // [OC]
) {
  CPU_KERNEL_GUARD_IN(cutlass_scaled_mm)
  // Checks for conformality
  TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8,
              "int8_scaled_mm_ppc64le only supports INT8 inputs.");
  TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
  TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
              b.size(1) == c.size(1));
  // We dont need this
  TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
  TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));

  // Check for strides and alignment
  TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1);  // Row-major
  TORCH_CHECK(b.stride(0) == 1);                      // Column-major
  TORCH_CHECK(c.stride(0) % 16 == 0 &&
              b.stride(1) % 16 == 0);  // 16 Byte Alignment
  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());

  if (bias) {
    TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() &&
                bias->dim() == 1);
  }
  VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm_ppc64le", [&] {
    torch::Tensor tmp_fp32_out = torch::empty_like(c, ::at::ScalarType::Float);
    // Compute C_inter=s_b * (A@B)
    DNNLPrimitiveHelper<true>::gemm_s8s8_jit<float, void>(
        a.data_ptr<int8_t>(), b.data_ptr<int8_t>(),
        tmp_fp32_out.data_ptr<float>(), nullptr, a.size(0), b.size(1),
        a.size(1), nullptr, b_scales.data_ptr<float>(), 0, b_scales.numel());
    if (bias.has_value()) {
      // Compute C=s_a * C_inter + bias
      dynamic_quant_epilogue<false, true, true>(
          tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
          a_scales.data_ptr<float>(), nullptr, nullptr, nullptr,
          bias->data_ptr<scalar_t>(), c.size(0), c.size(1));
    } else {
      // Compute C=s_a * C_inter
      dynamic_quant_epilogue<false, true, false, scalar_t>(
          tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
          a_scales.data_ptr<float>(), nullptr, nullptr, nullptr, nullptr,
          c.size(0), c.size(1));
    }
  });
}

#endif
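int8_scaled_mm_ppc64le splits the W8A8 GEMM into two steps: the oneDNN int8 GEMM produces C_inter = s_b * (A_q @ B_q) with only the per-output-channel weight scale applied, and dynamic_quant_epilogue then applies the per-token activation scale and the optional bias, giving C = s_a * C_inter + bias. The scalar reference below restates the combined math; it is illustrative only, w8a8_scaled_mm_ref is an assumed name, and it indexes B as row-major for simplicity while the kernel above expects column-major B.

#include <cstdint>
#include <vector>

// Scalar reference for C[m,n] = a_scale[m] * (b_scale[n] * sum_k A[m,k]*B[k,n]) + bias[n],
// the math that int8_scaled_mm_ppc64le splits between the oneDNN GEMM
// (weight scale) and dynamic_quant_epilogue (activation scale + bias).
std::vector<float> w8a8_scaled_mm_ref(const std::vector<int8_t>& A,      // M x K
                                      const std::vector<int8_t>& B,      // K x N
                                      const std::vector<float>& a_scale, // per row (M) or 1
                                      const std::vector<float>& b_scale, // per col (N) or 1
                                      const std::vector<float>& bias,    // N, may be empty
                                      int M, int N, int K) {
  std::vector<float> C(static_cast<size_t>(M) * N, 0.0f);
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      int32_t acc = 0;  // int32 accumulation of the int8 dot product
      for (int k = 0; k < K; ++k) {
        acc += static_cast<int32_t>(A[m * K + k]) * static_cast<int32_t>(B[k * N + n]);
      }
      const float sb = b_scale.size() == 1 ? b_scale[0] : b_scale[n];
      const float sa = a_scale.size() == 1 ? a_scale[0] : a_scale[m];
      float c_val = sa * (sb * static_cast<float>(acc));  // scales applied after the GEMM
      if (!bias.empty()) c_val += bias[n];
      C[m * N + n] = c_val;
    }
  }
  return C;
}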
@@ -18,6 +18,14 @@ void int8_scaled_mm_azp(torch::Tensor& c, const torch::Tensor& a,
                        const std::optional<torch::Tensor>& azp,
                        const std::optional<torch::Tensor>& bias);

#if defined(__powerpc64__)
void int8_scaled_mm_ppc64le(torch::Tensor& c, const torch::Tensor& a,
                            const torch::Tensor& b,
                            const torch::Tensor& a_scales,
                            const torch::Tensor& b_scales,
                            const std::optional<torch::Tensor>& bias);
#endif

void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query,
                        torch::Tensor& kv_cache, double scale,
                        torch::Tensor& block_tables, torch::Tensor& seq_lens);
@@ -150,6 +158,33 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
      " Tensor b_scales, Tensor azp_adj,"
      " Tensor? azp, Tensor? bias) -> ()");
  ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp);
#elif defined(__powerpc64__)
  // Compute int8 quantized tensor for given scaling factor.
  ops.def(
      "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale,"
      "Tensor? azp) -> ()");
  ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant);

  // Compute int8 quantized tensor and scaling factor
  ops.def(
      "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, "
      "Tensor!? azp) -> ()");
  ops.impl("dynamic_scaled_int8_quant", torch::kCPU,
           &dynamic_scaled_int8_quant);
  // W8A8 GEMM, supporting symmetric quantization.
  ops.def(
      "cutlass_scaled_mm(Tensor! out, Tensor a,"
      " Tensor b, Tensor a_scales,"
      " Tensor b_scales, Tensor? bias) -> ()");
  ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm_ppc64le);
  // w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
  // quantization.
  ops.def(
      "cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
      " Tensor b, Tensor a_scales,"
      " Tensor b_scales, Tensor azp_adj,"
      " Tensor? azp, Tensor? bias) -> ()");
  ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp);
#endif

  // SHM CCL