[CPU] Refactor CPU WNA16 (#28826)
Signed-off-by: jiang1.li <jiang1.li@intel.com>
Parent: 40b6b38f2c
Commit: 20852c8f4c
@@ -73,12 +73,11 @@ function cpu_tests() {
pytest -x -s -v \
tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_logprobs"

# Note: disable it until supports V1
# Run AWQ test
# docker exec cpu-test-"$NUMA_NODE" bash -c "
# set -e
# pytest -x -s -v \
# tests/quantization/test_ipex_quant.py"
# Run AWQ/GPTQ test
docker exec cpu-test-"$NUMA_NODE" bash -c "
set -e
pytest -x -s -v \
tests/quantization/test_cpu_wna16.py"

# Run multi-lora tests
docker exec cpu-test-"$NUMA_NODE" bash -c "
@@ -375,6 +375,7 @@ set(VLLM_EXT_SRC
if (AVX512_FOUND AND NOT AVX512_DISABLED)
set(VLLM_EXT_SRC
"csrc/cpu/shm.cpp"
"csrc/cpu/cpu_wna16.cpp"
${VLLM_EXT_SRC})
if (ENABLE_AVX512BF16 AND ENABLE_AVX512VNNI)
set(VLLM_EXT_SRC
@@ -1,7 +1,6 @@
#ifndef CPU_ATTN_HPP
#define CPU_ATTN_HPP

#include <unistd.h>
#include <type_traits>
#include <cstddef>

@@ -12,6 +11,7 @@
#include "cpu_types.hpp"
#include "scratchpad_manager.h"
#include "cpu_attn_macros.h"
#include "utils.hpp"

namespace cpu_attention {
enum class ISA { AMX, VEC, VEC16 };
@@ -104,6 +104,8 @@ struct FP16Vec16 : public Vec<FP16Vec16> {
|
||||
explicit FP16Vec16(bool, void* ptr)
|
||||
: reg(_mm256_stream_load_si256((__m256i*)ptr)) {}
|
||||
|
||||
explicit FP16Vec16(const c10::Half v) : reg(_mm256_set1_epi16(v.x)) {}
|
||||
|
||||
explicit FP16Vec16(const FP32Vec16&);
|
||||
|
||||
void save(void* ptr) const { _mm256_storeu_si256((__m256i*)ptr, reg); }
|
||||
@@ -141,6 +143,8 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
|
||||
explicit BF16Vec16(bool, void* ptr)
|
||||
: reg(_mm256_stream_load_si256((__m256i*)ptr)) {}
|
||||
|
||||
explicit BF16Vec16(const c10::BFloat16 v) : reg(_mm256_set1_epi16(v.x)) {}
|
||||
|
||||
explicit BF16Vec16(const FP32Vec16&);
|
||||
|
||||
void save(void* ptr) const { _mm256_storeu_si256((__m256i*)ptr, reg); }
|
||||
@@ -350,6 +354,22 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
|
||||
explicit FP32Vec16(__m512 data) : reg(data) {}
|
||||
|
||||
// de-pack 4 bit values
|
||||
explicit FP32Vec16(int64_t value, const FP32Vec16& lut) {
|
||||
int64_t mask_0 = 0x0F0F0F0F0F0F0F0F;
|
||||
int64_t mask_1 = 0xF0F0F0F0F0F0F0F0;
|
||||
int64_t value_0 = value & mask_0;
|
||||
int64_t value_1 = value & mask_1;
|
||||
__m128i vec_0 = _mm_movpi64_epi64((__m64)value_0);
|
||||
__m128i vec_1 = _mm_movpi64_epi64((__m64)value_1);
|
||||
vec_0 = _mm_cvtepu8_epi16(vec_0);
|
||||
vec_1 = _mm_cvtepu8_epi16(vec_1);
|
||||
vec_1 = _mm_slli_epi16(vec_1, 4);
|
||||
__m128i vec = _mm_or_si128(vec_0, vec_1);
|
||||
__m512i vec_i32 = _mm512_cvtepu8_epi32(vec);
|
||||
reg = _mm512_permutexvar_ps(vec_i32, lut.reg);
|
||||
}
|
||||
|
||||
explicit FP32Vec16(const FP32Vec4& data)
|
||||
: reg((__m512)_mm512_inserti32x4(
|
||||
_mm512_inserti32x4(
|
||||
@@ -426,14 +446,6 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
|
||||
float get_last_elem() const { return _mm512_cvtss_f32(reg); }
|
||||
|
||||
template <int group_size>
|
||||
float reduce_sub_sum(int idx) {
|
||||
static_assert(VEC_ELEM_NUM % group_size == 0);
|
||||
constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size));
|
||||
__mmask16 mask = _cvtu32_mask16(base_mask << (idx * group_size));
|
||||
return _mm512_mask_reduce_add_ps(mask, reg);
|
||||
}
|
||||
|
||||
void save(float* ptr) const { _mm512_storeu_ps(ptr, reg); }
|
||||
|
||||
void save(float* ptr, const int elem_num) const {
|
||||
@@ -755,6 +767,25 @@ inline void non_temporal_save(BF16Vec16& vec, void* ptr) {
|
||||
inline void non_temporal_save(FP32Vec16& vec, void* ptr) {
|
||||
_mm512_stream_ps((float*)ptr, vec.reg);
|
||||
}
|
||||
|
||||
static void interleave_save(const BF16Vec16& vec0, const BF16Vec16& vec1,
|
||||
void* ptr) {
|
||||
__m512i vec_0 = _mm512_cvtepu16_epi32(vec0.reg);
|
||||
__m512i vec_1 = _mm512_cvtepu16_epi32(vec1.reg);
|
||||
vec_1 = _mm512_slli_epi32(vec_1, 16);
|
||||
vec_0 = _mm512_or_si512(vec_0, vec_1);
|
||||
_mm512_storeu_epi32(ptr, vec_0);
|
||||
}
|
||||
|
||||
static void interleave_save(const FP16Vec16& vec0, const FP16Vec16& vec1,
|
||||
void* ptr) {
|
||||
__m512i vec_0 = _mm512_cvtepu16_epi32(vec0.reg);
|
||||
__m512i vec_1 = _mm512_cvtepu16_epi32(vec1.reg);
|
||||
vec_1 = _mm512_slli_epi32(vec_1, 16);
|
||||
vec_0 = _mm512_or_si512(vec_0, vec_1);
|
||||
_mm512_storeu_epi32(ptr, vec_0);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
inline void mem_barrier() { _mm_mfence(); }
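For reference, the new FP32Vec16(int64_t value, const FP32Vec16& lut) constructor above de-packs sixteen 4-bit codes: it masks the low and high nibbles of each byte, re-interleaves them back into byte order, widens the bytes to 32-bit indices, and gathers from a 16-entry lookup table with _mm512_permutexvar_ps. A minimal scalar sketch of the same mapping (plain Python, illustrative only; the function name is made up):

def depack_nibbles(value: int, lut: list[float]) -> list[float]:
    # Element i is lut[nibble i], where nibble 0 is the low nibble of the
    # lowest byte of the 64-bit word, matching the vectorized path above.
    assert len(lut) == 16
    return [lut[(value >> (4 * i)) & 0xF] for i in range(16)]

# GPTQ-style LUT: 4-bit codes 0..15 map to -8..7.
gptq_lut = [float(i) - 8.0 for i in range(16)]
print(depack_nibbles(0x0123456789ABCDEF, gptq_lut))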
csrc/cpu/cpu_wna16.cpp (new file, 402 lines)
@@ -0,0 +1,402 @@
#include "cpu_types.hpp"
|
||||
#include "scratchpad_manager.h"
|
||||
#include "utils.hpp"
|
||||
|
||||
#ifdef CPU_CAPABILITY_AMXBF16
|
||||
#include "cpu/micro_gemm/cpu_micro_gemm_amx.hpp"
|
||||
#endif
|
||||
#include "cpu/micro_gemm/cpu_micro_gemm_vec.hpp"
|
||||
|
||||
#define VLLM_DISPATCH_CASE_16B_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
|
||||
|
||||
#define VLLM_DISPATCH_16B_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_16B_TYPES(__VA_ARGS__))
|
||||
|
||||
template <typename T>
|
||||
void print_logits(const char* name, T* ptr, int32_t row, int32_t col,
|
||||
int32_t stride) {
|
||||
std::stringstream ss;
|
||||
ss << std::fixed << std::setprecision(5) << name << ": [\n";
|
||||
auto* curr_logits_buffer = ptr;
|
||||
for (int32_t m = 0; m < row; ++m) {
|
||||
for (int32_t n = 0; n < col; ++n) {
|
||||
ss << curr_logits_buffer[n] << ", ";
|
||||
}
|
||||
ss << "\n";
|
||||
curr_logits_buffer += stride;
|
||||
}
|
||||
ss << "]\n";
|
||||
std::printf("%s", ss.str().c_str());
|
||||
}
|
||||
|
||||
namespace {
|
||||
using cpu_utils::ISA;
|
||||
using cpu_utils::VecTypeTrait;
|
||||
|
||||
template <typename scalar_t, ISA isa, bool has_zp, bool use_desc_act>
|
||||
class Dequantizer4b {
|
||||
public:
|
||||
constexpr static int32_t pack_num = 32 / 4;
|
||||
using scalar_vec_t = typename VecTypeTrait<scalar_t>::vec_t;
|
||||
|
||||
public:
|
||||
static void dequant(int32_t* __restrict__ q_weight,
|
||||
scalar_t* __restrict__ weight,
|
||||
scalar_t* __restrict__ scales,
|
||||
int32_t* __restrict__ zeros, int32_t* __restrict__ g_idx,
|
||||
const int64_t scales_stride, const int64_t zeros_stride,
|
||||
const int32_t k_size, const int32_t group_size) {
|
||||
vec_op::FP32Vec16 lut;
|
||||
if constexpr (has_zp) {
|
||||
// AWQ
|
||||
alignas(64) static const float LUT[16] = {
|
||||
0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f,
|
||||
8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f};
|
||||
lut = vec_op::FP32Vec16(LUT);
|
||||
} else {
|
||||
// GPTQ
|
||||
alignas(64) static const float LUT[16] = {
|
||||
-8.0f, -7.0f, -6.0f, -5.0f, -4.0f, -3.0f, -2.0f, -1.0f,
|
||||
0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f};
|
||||
lut = vec_op::FP32Vec16(LUT);
|
||||
}
|
||||
|
||||
// each 64-bit element contains 16 output channels
|
||||
int64_t* __restrict__ curr_q_weight = reinterpret_cast<int64_t*>(q_weight);
|
||||
int64_t* __restrict__ curr_zeros = reinterpret_cast<int64_t*>(zeros);
|
||||
scalar_t* __restrict__ curr_weight = weight;
|
||||
scalar_t* __restrict__ curr_scale = scales;
|
||||
vec_op::FP32Vec16 scale_0;
|
||||
vec_op::FP32Vec16 scale_1;
|
||||
vec_op::FP32Vec16 zero_0;
|
||||
vec_op::FP32Vec16 zero_1;
|
||||
int32_t group_counter = 0;
|
||||
for (int32_t k_idx = 0; k_idx < k_size; k_idx += 2) {
|
||||
int64_t qwb_0 = *curr_q_weight;
|
||||
int64_t qwb_1 = *(curr_q_weight + 1);
|
||||
vec_op::FP32Vec16 wb_0(qwb_0, lut);
|
||||
vec_op::FP32Vec16 wb_1(qwb_1, lut);
|
||||
|
||||
if constexpr (!use_desc_act) {
|
||||
if (group_counter == 0) {
|
||||
scale_0 = vec_op::FP32Vec16(scalar_vec_t(curr_scale));
|
||||
scale_1 = vec_op::FP32Vec16(scale_0);
|
||||
curr_scale += scales_stride;
|
||||
|
||||
if constexpr (has_zp) {
|
||||
zero_0 = vec_op::FP32Vec16(*curr_zeros, lut);
|
||||
zero_1 = vec_op::FP32Vec16(zero_0);
|
||||
curr_zeros += zeros_stride / 2;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
int32_t g_idx_0 = g_idx[k_idx];
|
||||
int32_t g_idx_1 = g_idx[k_idx + 1];
|
||||
scale_0 = vec_op::FP32Vec16(
|
||||
scalar_vec_t(curr_scale + g_idx_0 * scales_stride));
|
||||
scale_1 = vec_op::FP32Vec16(
|
||||
scalar_vec_t(curr_scale + g_idx_1 * scales_stride));
|
||||
if constexpr (has_zp) {
|
||||
zero_0 = vec_op::FP32Vec16(*(curr_zeros + g_idx_0 * zeros_stride / 2),
|
||||
lut);
|
||||
zero_1 = vec_op::FP32Vec16(*(curr_zeros + g_idx_1 * zeros_stride / 2),
|
||||
lut);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (has_zp) {
|
||||
wb_0 = wb_0 - zero_0;
|
||||
wb_1 = wb_1 - zero_1;
|
||||
}
|
||||
|
||||
wb_0 = wb_0 * scale_0;
|
||||
wb_1 = wb_1 * scale_1;
|
||||
|
||||
scalar_vec_t output_vec_0(wb_0);
|
||||
scalar_vec_t output_vec_1(wb_1);
|
||||
|
||||
// AMX needs to interleave pairs of K elements so they pack into 32-bit units
|
||||
if constexpr (isa == ISA::AMX) {
|
||||
vec_op::interleave_save(output_vec_0, output_vec_1, curr_weight);
|
||||
} else {
|
||||
output_vec_0.save(curr_weight);
|
||||
output_vec_1.save(curr_weight + 16);
|
||||
}
|
||||
|
||||
// update
|
||||
curr_q_weight += 2;
|
||||
curr_weight += 32;
|
||||
if constexpr (!use_desc_act) {
|
||||
group_counter += 2;
|
||||
if (group_counter == group_size) {
|
||||
group_counter = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}; // namespace
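Dequantizer4b above picks its lookup table by flavor: with zero points (AWQ) the LUT holds the raw codes 0..15 and the per-group zero point is subtracted before scaling, while without zero points (symmetric GPTQ) the LUT already maps codes to -8..7. With use_desc_act, the group index for each k comes from g_idx instead of advancing sequentially. A hedged NumPy sketch of the per-element math only (it ignores the kernel's 16-channel blocking and bit packing):

import numpy as np

def dequant_wna16_ref(q, scales, group_size, zeros=None, g_idx=None):
    # q: int codes in [0, 16), shape [K, N]; scales: [K // group_size, N];
    # zeros: optional unpacked int codes, same shape as scales; g_idx: optional [K].
    K, _ = q.shape
    group = g_idx if g_idx is not None else np.arange(K) // group_size
    s = scales[group]                      # broadcast group scales to [K, N]
    if zeros is not None:                  # AWQ: w = (q - zp) * scale
        return (q - zeros[group]) * s
    return (q - 8) * s                     # GPTQ sym: codes 0..15 -> -8..7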
|
||||
|
||||
template <typename scalar_t, typename dequantizer_t, typename gemm_t>
|
||||
void cpu_gemm_wna16_impl(
|
||||
scalar_t* __restrict__ input, int32_t* __restrict__ q_weight,
|
||||
scalar_t* __restrict__ output, scalar_t* __restrict__ scales,
|
||||
int32_t* __restrict__ zeros, int32_t* __restrict__ g_idx,
|
||||
scalar_t* __restrict__ bias, const int32_t m_size, const int32_t n_size,
|
||||
const int32_t k_size, const int64_t input_stride,
|
||||
const int64_t output_stride, const int64_t scales_group_stride,
|
||||
const int64_t zeros_group_stride, const int32_t group_num,
|
||||
const int32_t group_size, const int64_t pack_factor) {
|
||||
constexpr int32_t gemm_n_tile_size = gemm_t::NSize;
|
||||
constexpr int32_t gemm_m_tile_size = gemm_t::MaxMSize;
|
||||
constexpr int32_t n_block_size = 16;
|
||||
static_assert(gemm_n_tile_size % n_block_size == 0);
|
||||
const int32_t thread_num = omp_get_max_threads();
|
||||
|
||||
// a simple scheduling policy: keep more B tiles resident in L2 while
// making sure each thread gets some work
|
||||
const int32_t n_partition_size = [&]() {
|
||||
const int64_t cache_size = cpu_utils::get_l2_size();
|
||||
int64_t ps_cache_limit = cache_size / (k_size * sizeof(scalar_t));
|
||||
int64_t ps_thread_limit = n_size / thread_num;
|
||||
ps_cache_limit =
|
||||
std::max((ps_cache_limit / gemm_n_tile_size) * gemm_n_tile_size,
|
||||
(int64_t)gemm_n_tile_size);
|
||||
ps_thread_limit =
|
||||
std::max((ps_thread_limit / gemm_n_tile_size) * gemm_n_tile_size,
|
||||
(int64_t)gemm_n_tile_size);
|
||||
return std::min(ps_cache_limit, ps_thread_limit);
|
||||
}();
|
||||
const int32_t task_num = (n_size + n_partition_size - 1) / n_partition_size;
|
||||
|
||||
// get buffer size
|
||||
const int64_t b_buffer_size =
|
||||
(((n_partition_size * k_size * sizeof(scalar_t) + 63) / 64) * 64);
|
||||
const int64_t c_buffer_size =
|
||||
(((gemm_m_tile_size * gemm_n_tile_size * sizeof(float) + 63) / 64) * 64);
|
||||
const int64_t b_buffer_offset = 0;
|
||||
const int64_t c_buffer_offset = b_buffer_size;
|
||||
const int64_t buffer_size = b_buffer_size + c_buffer_size;
|
||||
DNNLScratchPadManager::get_dnnl_scratchpad_manager()->realloc(buffer_size *
|
||||
thread_num);
|
||||
|
||||
alignas(64) cpu_utils::Counter counter;
|
||||
cpu_utils::Counter* counter_ptr = &counter;
|
||||
|
||||
#pragma omp parallel for schedule(static, 1)
|
||||
for (int32_t thread_id = 0; thread_id < thread_num; ++thread_id) {
|
||||
scalar_t* __restrict__ b_buffer = nullptr;
|
||||
float* __restrict__ c_buffer = nullptr;
|
||||
{
|
||||
uint8_t* buffer_ptr = DNNLScratchPadManager::get_dnnl_scratchpad_manager()
|
||||
->get_data<uint8_t>() +
|
||||
thread_id * buffer_size;
|
||||
b_buffer = reinterpret_cast<scalar_t*>(buffer_ptr + b_buffer_offset);
|
||||
c_buffer = reinterpret_cast<float*>(buffer_ptr + c_buffer_offset);
|
||||
}
|
||||
|
||||
const int64_t q_weight_block_stride = n_block_size / pack_factor * k_size;
|
||||
const int64_t b_buffer_block_stride = n_block_size * k_size;
|
||||
const int32_t zeros_block_stride = n_block_size / pack_factor;
|
||||
|
||||
gemm_t gemm;
|
||||
|
||||
for (;;) {
|
||||
int32_t task_id = counter_ptr->acquire_counter();
|
||||
|
||||
if (task_id >= task_num) {
|
||||
break;
|
||||
}
|
||||
|
||||
const int32_t n_start_idx = task_id * n_partition_size;
|
||||
const int32_t n_block_start_idx = n_start_idx / n_block_size;
|
||||
const int32_t n_num = std::min(n_partition_size, n_size - n_start_idx);
|
||||
const int32_t n_block_num = n_num / n_block_size;
|
||||
// std::printf("thread_id: %d, task_id: %d, n_start_idx: %d, n_num: %d\n",
|
||||
// thread_id, task_id, n_start_idx, n_num);
|
||||
|
||||
// dequant weight
|
||||
{
|
||||
int32_t* __restrict__ curr_q_weight =
|
||||
q_weight + n_block_start_idx * q_weight_block_stride;
|
||||
scalar_t* __restrict__ curr_b_buffer = b_buffer;
|
||||
scalar_t* __restrict__ curr_scales = scales + n_start_idx;
|
||||
int32_t* __restrict__ curr_zeros = zeros + n_start_idx / pack_factor;
|
||||
for (int32_t block_idx = 0; block_idx < n_block_num; ++block_idx) {
|
||||
dequantizer_t::dequant(curr_q_weight, curr_b_buffer, curr_scales,
|
||||
curr_zeros, g_idx, scales_group_stride,
|
||||
zeros_group_stride, k_size, group_size);
|
||||
|
||||
// if (block_idx == 0 && n_start_idx == 0) {
|
||||
// print_logits("depacked weight", curr_b_buffer, k_size,
|
||||
// n_block_size, n_block_size);
|
||||
// }
|
||||
|
||||
// update
|
||||
curr_q_weight += q_weight_block_stride;
|
||||
curr_b_buffer += b_buffer_block_stride;
|
||||
curr_scales += n_block_size;
|
||||
curr_zeros += zeros_block_stride;
|
||||
}
|
||||
}
|
||||
|
||||
// compute loop
|
||||
{
|
||||
const int32_t n_tile_num = n_num / gemm_n_tile_size;
|
||||
scalar_t* __restrict__ curr_input = input;
|
||||
scalar_t* __restrict__ init_bias = bias;
|
||||
if (bias != nullptr) {
|
||||
init_bias += n_start_idx;
|
||||
}
|
||||
scalar_t* __restrict__ init_output = output + n_start_idx;
|
||||
for (int32_t m_idx = 0; m_idx < m_size; m_idx += gemm_m_tile_size) {
|
||||
const int32_t curr_m_size =
|
||||
std::min(gemm_m_tile_size, m_size - m_idx);
|
||||
scalar_t* __restrict__ curr_b_buffer = b_buffer;
|
||||
scalar_t* __restrict__ curr_bias = init_bias;
|
||||
scalar_t* __restrict__ curr_output = init_output;
|
||||
for (int32_t n_tile_idx = 0; n_tile_idx < n_tile_num; ++n_tile_idx) {
|
||||
gemm.gemm(curr_input, curr_b_buffer, c_buffer, curr_m_size, k_size,
|
||||
input_stride, b_buffer_block_stride, gemm_n_tile_size,
|
||||
false);
|
||||
|
||||
if (bias != nullptr) {
|
||||
cpu_micro_gemm::bias_epilogue<gemm_n_tile_size>(
|
||||
c_buffer, curr_output, curr_bias, curr_m_size,
|
||||
gemm_n_tile_size, output_stride);
|
||||
curr_bias += gemm_n_tile_size;
|
||||
} else {
|
||||
cpu_micro_gemm::default_epilogue<gemm_n_tile_size>(
|
||||
c_buffer, curr_output, curr_m_size, gemm_n_tile_size,
|
||||
output_stride);
|
||||
}
|
||||
|
||||
curr_b_buffer +=
|
||||
b_buffer_block_stride * (gemm_n_tile_size / n_block_size);
|
||||
curr_output += gemm_n_tile_size;
|
||||
}
|
||||
curr_input += gemm_m_tile_size * input_stride;
|
||||
init_output += gemm_m_tile_size * output_stride;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void cpu_gemm_wna16(
|
||||
const torch::Tensor& input, // [M, K]
|
||||
const torch::Tensor&
|
||||
q_weight, // [N / 16, K * 16 / pack_factor], packed as int32
|
||||
torch::Tensor& output, // [M, N]
|
||||
const torch::Tensor& scales, // [group_num, N]
|
||||
const std::optional<torch::Tensor>&
|
||||
zeros, // [group_num, N / pack_factor], packed as int32
|
||||
const std::optional<torch::Tensor>& g_idx, // [K]
|
||||
const std::optional<torch::Tensor>& bias, // [N]
|
||||
const int64_t pack_factor, const std::string& isa_hint) {
|
||||
using cpu_utils::ISA;
|
||||
TORCH_CHECK_EQ(pack_factor, 8); // only supports 4bits
|
||||
const int32_t a_m_size = input.size(0);
|
||||
const int32_t a_k_size = input.size(1);
|
||||
const int64_t a_m_stride = input.stride(0);
|
||||
const int32_t b_n_size = q_weight.size(0) * 16;
|
||||
TORCH_CHECK_EQ(a_k_size % 32, 0);
|
||||
TORCH_CHECK_EQ(b_n_size % 32, 0);
|
||||
const int32_t group_num = scales.size(0);
|
||||
const int32_t group_size = a_k_size / group_num;
|
||||
TORCH_CHECK_EQ(group_size % 2, 0);
|
||||
const int64_t scales_group_stride = scales.stride(0);
|
||||
const int64_t output_m_stride = output.stride(0);
|
||||
|
||||
bool has_zp = zeros.has_value();
|
||||
bool use_desc_act = g_idx.has_value();
|
||||
TORCH_CHECK(!(has_zp && use_desc_act));
|
||||
|
||||
ISA isa = [&]() {
|
||||
if (isa_hint == "amx") {
|
||||
return ISA::AMX;
|
||||
} else if (isa_hint == "vec") {
|
||||
return ISA::VEC;
|
||||
} else {
|
||||
TORCH_CHECK(false, "unsupported isa hint: " + isa_hint);
|
||||
}
|
||||
}();
|
||||
|
||||
int32_t* zeros_ptr = has_zp ? zeros->data_ptr<int32_t>() : nullptr;
|
||||
const int64_t zeros_group_stride = has_zp ? zeros->stride(0) : 0;
|
||||
int32_t* g_idx_ptr = use_desc_act ? g_idx->data_ptr<int32_t>() : nullptr;
|
||||
|
||||
VLLM_DISPATCH_16B_TYPES(input.scalar_type(), "cpu_gemm_wna16", [&]() {
|
||||
if (isa == ISA::AMX) {
|
||||
using gemm_t = cpu_micro_gemm::MicroGemm<ISA::AMX, scalar_t>;
|
||||
if (has_zp) {
|
||||
using dequantizer_t = Dequantizer4b<scalar_t, ISA::AMX, true, false>;
|
||||
cpu_gemm_wna16_impl<scalar_t, dequantizer_t, gemm_t>(
|
||||
input.data_ptr<scalar_t>(), q_weight.data_ptr<int32_t>(),
|
||||
output.data_ptr<scalar_t>(), scales.data_ptr<scalar_t>(), zeros_ptr,
|
||||
g_idx_ptr, bias.has_value() ? bias->data_ptr<scalar_t>() : nullptr,
|
||||
a_m_size, b_n_size, a_k_size, a_m_stride, output_m_stride,
|
||||
scales_group_stride, zeros_group_stride, group_num, group_size,
|
||||
pack_factor);
|
||||
return;
|
||||
}
|
||||
if (use_desc_act) {
|
||||
using dequantizer_t = Dequantizer4b<scalar_t, ISA::AMX, false, true>;
|
||||
cpu_gemm_wna16_impl<scalar_t, dequantizer_t, gemm_t>(
|
||||
input.data_ptr<scalar_t>(), q_weight.data_ptr<int32_t>(),
|
||||
output.data_ptr<scalar_t>(), scales.data_ptr<scalar_t>(), zeros_ptr,
|
||||
g_idx_ptr, bias.has_value() ? bias->data_ptr<scalar_t>() : nullptr,
|
||||
a_m_size, b_n_size, a_k_size, a_m_stride, output_m_stride,
|
||||
scales_group_stride, zeros_group_stride, group_num, group_size,
|
||||
pack_factor);
|
||||
return;
|
||||
} else {
|
||||
using dequantizer_t = Dequantizer4b<scalar_t, ISA::AMX, false, false>;
|
||||
cpu_gemm_wna16_impl<scalar_t, dequantizer_t, gemm_t>(
|
||||
input.data_ptr<scalar_t>(), q_weight.data_ptr<int32_t>(),
|
||||
output.data_ptr<scalar_t>(), scales.data_ptr<scalar_t>(), zeros_ptr,
|
||||
g_idx_ptr, bias.has_value() ? bias->data_ptr<scalar_t>() : nullptr,
|
||||
a_m_size, b_n_size, a_k_size, a_m_stride, output_m_stride,
|
||||
scales_group_stride, zeros_group_stride, group_num, group_size,
|
||||
pack_factor);
|
||||
return;
|
||||
}
|
||||
} else if (isa == ISA::VEC) {
|
||||
using gemm_t = cpu_micro_gemm::MicroGemm<ISA::VEC, scalar_t>;
|
||||
if (has_zp) {
|
||||
using dequantizer_t = Dequantizer4b<scalar_t, ISA::VEC, true, false>;
|
||||
cpu_gemm_wna16_impl<scalar_t, dequantizer_t, gemm_t>(
|
||||
input.data_ptr<scalar_t>(), q_weight.data_ptr<int32_t>(),
|
||||
output.data_ptr<scalar_t>(), scales.data_ptr<scalar_t>(), zeros_ptr,
|
||||
g_idx_ptr, bias.has_value() ? bias->data_ptr<scalar_t>() : nullptr,
|
||||
a_m_size, b_n_size, a_k_size, a_m_stride, output_m_stride,
|
||||
scales_group_stride, zeros_group_stride, group_num, group_size,
|
||||
pack_factor);
|
||||
return;
|
||||
}
|
||||
if (use_desc_act) {
|
||||
using dequantizer_t = Dequantizer4b<scalar_t, ISA::VEC, false, true>;
|
||||
cpu_gemm_wna16_impl<scalar_t, dequantizer_t, gemm_t>(
|
||||
input.data_ptr<scalar_t>(), q_weight.data_ptr<int32_t>(),
|
||||
output.data_ptr<scalar_t>(), scales.data_ptr<scalar_t>(), zeros_ptr,
|
||||
g_idx_ptr, bias.has_value() ? bias->data_ptr<scalar_t>() : nullptr,
|
||||
a_m_size, b_n_size, a_k_size, a_m_stride, output_m_stride,
|
||||
scales_group_stride, zeros_group_stride, group_num, group_size,
|
||||
pack_factor);
|
||||
return;
|
||||
} else {
|
||||
using dequantizer_t = Dequantizer4b<scalar_t, ISA::VEC, false, false>;
|
||||
cpu_gemm_wna16_impl<scalar_t, dequantizer_t, gemm_t>(
|
||||
input.data_ptr<scalar_t>(), q_weight.data_ptr<int32_t>(),
|
||||
output.data_ptr<scalar_t>(), scales.data_ptr<scalar_t>(), zeros_ptr,
|
||||
g_idx_ptr, bias.has_value() ? bias->data_ptr<scalar_t>() : nullptr,
|
||||
a_m_size, b_n_size, a_k_size, a_m_stride, output_m_stride,
|
||||
scales_group_stride, zeros_group_stride, group_num, group_size,
|
||||
pack_factor);
|
||||
return;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
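As the comments in cpu_gemm_wna16 indicate, q_weight arrives pre-packed in blocks of 16 output channels: shape [N / 16, K * 16 / pack_factor] with pack_factor == 8, so each k index of a block occupies one 64-bit word whose sixteen nibbles are the codes of that block's channels. A NumPy sketch of repacking a plain [K, N] code matrix into that layout, as read from the dequant loop (illustrative; little-endian nibble and int32 ordering assumed):

import numpy as np

def pack_wna16_blocks(codes: np.ndarray) -> np.ndarray:
    # codes: unsigned array [K, N] with values in [0, 16).
    K, N = codes.shape
    assert N % 16 == 0
    blocks = codes.reshape(K, N // 16, 16).transpose(1, 0, 2)   # [N/16, K, 16]
    shifts = 4 * np.arange(16, dtype=np.uint64)
    words = (blocks.astype(np.uint64) << shifts).sum(axis=-1)   # [N/16, K] u64
    # Reinterpret each 64-bit word as two int32s: [N/16, K * 16 / 8]
    return words.view(np.int32).reshape(N // 16, K * 2)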
|
||||
@@ -396,9 +396,9 @@ MatMulPrimitiveHandler::MatMulPrimitiveHandler(const Args& args)
|
||||
: DNNLMatMulPrimitiveHandler(
|
||||
static_cast<DNNLMatMulPrimitiveHandler::Args>(args), args.ab_type),
|
||||
m_size_cache_(nullptr) {
|
||||
assert(ab_type_ == dnnl::memory::data_type::f32 ||
|
||||
ab_type_ == dnnl::memory::data_type::bf16 ||
|
||||
ab_type_ == dnnl::memory::data_type::f16);
|
||||
assert(b_type_ == dnnl::memory::data_type::f32 ||
|
||||
b_type_ == dnnl::memory::data_type::bf16 ||
|
||||
b_type_ == dnnl::memory::data_type::f16);
|
||||
|
||||
dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_,
|
||||
{b_k_stride_, b_n_stride_});
|
||||
|
||||
csrc/cpu/micro_gemm/cpu_micro_gemm_amx.hpp (new file, 245 lines)
@@ -0,0 +1,245 @@
#ifndef CPU_MICRO_GEMM_AMX_HPP
|
||||
#define CPU_MICRO_GEMM_AMX_HPP
|
||||
#include "cpu/micro_gemm/cpu_micro_gemm_impl.hpp"
|
||||
|
||||
namespace cpu_micro_gemm {
|
||||
namespace {
|
||||
// AMX specific
|
||||
constexpr static int64_t AMX_TILE_ROW_BYTES = 64;
|
||||
constexpr static int64_t AMX_TILE_ROW_NUM = 16;
|
||||
constexpr static int64_t AMX_TILE_BYTES = AMX_TILE_ROW_BYTES * AMX_TILE_ROW_NUM;
|
||||
|
||||
typedef struct __tile_config {
|
||||
uint8_t palette_id = 1;
|
||||
uint8_t start_row = 0;
|
||||
uint8_t reserved_0[14] = {0};
|
||||
uint16_t colsb[16] = {0};
|
||||
uint8_t rows[16] = {0};
|
||||
} __tilecfg;
|
||||
|
||||
// 2-2-4 pattern, for 16 < m <= 32
|
||||
// TILE 0, 1: load A matrix, row num should be 16, m - 16
|
||||
// TILE 2, 3: load B matrix, row num should be 16
|
||||
// TILE 4, 5, 6, 7: store results C matrix, row num should be 16, 16, m - 16, m
|
||||
// - 16
|
||||
template <typename scalar_t>
|
||||
class TileGemm224 {
|
||||
public:
|
||||
FORCE_INLINE static void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) {
|
||||
TORCH_CHECK(false, "Unsupported data type for TileGemm224");
|
||||
}
|
||||
|
||||
FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) {
|
||||
TORCH_CHECK(false, "Unsupported data type for TileGemm224");
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
class TileGemm224<c10::BFloat16> {
|
||||
public:
|
||||
using scalar_t = c10::BFloat16;
|
||||
FORCE_INLINE static void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) {
|
||||
const int32_t k_times = k / (AMX_TILE_ROW_NUM * 4 / sizeof(c10::BFloat16));
|
||||
c10::BFloat16* __restrict__ a_tile_0 = a_ptr;
|
||||
c10::BFloat16* __restrict__ a_tile_1 = a_ptr + lda * AMX_TILE_ROW_NUM;
|
||||
const int64_t a_tile_stride = lda * sizeof(c10::BFloat16);
|
||||
|
||||
// B is always packed as 16 output channels block
|
||||
c10::BFloat16* __restrict__ b_tile_2 = b_ptr;
|
||||
c10::BFloat16* __restrict__ b_tile_3 = b_ptr + b_n_group_stride;
|
||||
const int32_t b_tile_stride = AMX_TILE_ROW_BYTES;
|
||||
|
||||
float* __restrict__ c_tile_4 = c_ptr;
|
||||
float* __restrict__ c_tile_5 =
|
||||
c_tile_4 + AMX_TILE_ROW_BYTES / sizeof(float);
|
||||
float* __restrict__ c_tile_6 = c_ptr + AMX_TILE_ROW_NUM * ldc;
|
||||
float* __restrict__ c_tile_7 =
|
||||
c_tile_6 + AMX_TILE_ROW_BYTES / sizeof(float);
|
||||
const int32_t c_tile_stride = ldc * sizeof(float);
|
||||
|
||||
if (accum_c) {
|
||||
_tile_loadd(4, c_tile_4, c_tile_stride);
|
||||
_tile_loadd(5, c_tile_5, c_tile_stride);
|
||||
_tile_loadd(6, c_tile_6, c_tile_stride);
|
||||
_tile_loadd(7, c_tile_7, c_tile_stride);
|
||||
} else {
|
||||
_tile_zero(4);
|
||||
_tile_zero(5);
|
||||
_tile_zero(6);
|
||||
_tile_zero(7);
|
||||
}
|
||||
|
||||
for (int32_t k = 0; k < k_times; ++k) {
|
||||
_tile_loadd(0, a_tile_0, a_tile_stride);
|
||||
_tile_stream_loadd(2, b_tile_2, b_tile_stride);
|
||||
_tile_dpbf16ps(4, 0, 2);
|
||||
_tile_stream_loadd(3, b_tile_3, b_tile_stride);
|
||||
_tile_dpbf16ps(5, 0, 3);
|
||||
_tile_loadd(1, a_tile_1, a_tile_stride);
|
||||
_tile_dpbf16ps(6, 1, 2);
|
||||
_tile_dpbf16ps(7, 1, 3);
|
||||
|
||||
// update ptrs
|
||||
a_tile_0 += AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
|
||||
a_tile_1 += AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
|
||||
b_tile_2 += AMX_TILE_BYTES / sizeof(c10::BFloat16);
|
||||
b_tile_3 += AMX_TILE_BYTES / sizeof(c10::BFloat16);
|
||||
}
|
||||
|
||||
_tile_stored(4, c_tile_4, c_tile_stride);
|
||||
_tile_stored(5, c_tile_5, c_tile_stride);
|
||||
_tile_stored(6, c_tile_6, c_tile_stride);
|
||||
_tile_stored(7, c_tile_7, c_tile_stride);
|
||||
}
|
||||
|
||||
FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) {
|
||||
const int32_t m_0 = AMX_TILE_ROW_NUM;
|
||||
const int32_t m_1 = m - AMX_TILE_ROW_NUM;
|
||||
config.rows[0] = m_0;
|
||||
config.rows[1] = m_1;
|
||||
config.rows[2] = AMX_TILE_ROW_NUM;
|
||||
config.rows[3] = AMX_TILE_ROW_NUM;
|
||||
config.rows[4] = m_0;
|
||||
config.rows[5] = m_0;
|
||||
config.rows[6] = m_1;
|
||||
config.rows[7] = m_1;
|
||||
_tile_loadconfig(&config);
|
||||
}
|
||||
};
|
||||
|
||||
// 1-2-2 pattern, for 0 < m <= 16
|
||||
// TILE 0, (1): load A matrix, use extra 1 tile for prefetch, row num should be
|
||||
// m, m
|
||||
// TILE 2, 3, (4, 5): load B matrix, use extra 2 tiles for prefetch, row
|
||||
// num should be 16
|
||||
// TILE 6, 7, (6, 7): store results C matrix, row num should be
|
||||
// m
|
||||
template <typename scalar_t>
|
||||
class TileGemm122 {
|
||||
public:
|
||||
FORCE_INLINE static void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) {
|
||||
TORCH_CHECK(false, "Unsupported data type for TileGemm122");
|
||||
}
|
||||
|
||||
FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) {
|
||||
TORCH_CHECK(false, "Unsupported data type for TileGemm122");
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
class TileGemm122<c10::BFloat16> {
|
||||
public:
|
||||
using scalar_t = c10::BFloat16;
|
||||
FORCE_INLINE static void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) {
|
||||
c10::BFloat16* __restrict__ a_tile_0 = a_ptr;
|
||||
c10::BFloat16* __restrict__ a_tile_1 =
|
||||
a_ptr + AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
|
||||
const int64_t a_tile_stride = lda * sizeof(c10::BFloat16);
|
||||
|
||||
c10::BFloat16* __restrict__ b_tile_2 = b_ptr;
|
||||
c10::BFloat16* __restrict__ b_tile_3 = b_ptr + b_n_group_stride;
|
||||
c10::BFloat16* __restrict__ b_tile_4 =
|
||||
b_tile_2 + AMX_TILE_BYTES / sizeof(c10::BFloat16);
|
||||
c10::BFloat16* __restrict__ b_tile_5 =
|
||||
b_tile_3 + AMX_TILE_BYTES / sizeof(c10::BFloat16);
|
||||
int64_t b_stride = AMX_TILE_ROW_BYTES;
|
||||
|
||||
float* __restrict__ c_tile_6 = c_ptr;
|
||||
float* __restrict__ c_tile_7 = c_ptr + AMX_TILE_ROW_BYTES / sizeof(float);
|
||||
int64_t c_stride = ldc * sizeof(float);
|
||||
|
||||
const int32_t k_times = k / (AMX_TILE_ROW_NUM * 4 / sizeof(c10::BFloat16));
|
||||
const int32_t k_group_times = k_times / 2;
|
||||
const bool has_tail = (k_times % 2 == 1);
|
||||
|
||||
if (accum_c) {
|
||||
_tile_loadd(6, c_tile_6, c_stride);
|
||||
_tile_loadd(7, c_tile_7, c_stride);
|
||||
} else {
|
||||
_tile_zero(6);
|
||||
_tile_zero(7);
|
||||
}
|
||||
|
||||
for (int32_t k = 0; k < k_group_times; ++k) {
|
||||
_tile_loadd(0, a_tile_0, a_tile_stride);
|
||||
_tile_stream_loadd(2, b_tile_2, b_stride);
|
||||
_tile_dpbf16ps(6, 0, 2);
|
||||
_tile_stream_loadd(3, b_tile_3, b_stride);
|
||||
_tile_dpbf16ps(7, 0, 3);
|
||||
_tile_loadd(1, a_tile_1, a_tile_stride);
|
||||
_tile_stream_loadd(4, b_tile_4, b_stride);
|
||||
_tile_dpbf16ps(6, 1, 4);
|
||||
_tile_stream_loadd(5, b_tile_5, b_stride);
|
||||
_tile_dpbf16ps(7, 1, 5);
|
||||
|
||||
// update ptrs
|
||||
a_tile_0 += 2 * AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
|
||||
a_tile_1 += 2 * AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
|
||||
b_tile_2 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
|
||||
b_tile_3 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
|
||||
b_tile_4 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
|
||||
b_tile_5 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
|
||||
}
|
||||
|
||||
if (has_tail) {
|
||||
_tile_loadd(0, a_tile_0, a_tile_stride);
|
||||
_tile_stream_loadd(2, b_tile_2, b_stride);
|
||||
_tile_dpbf16ps(6, 0, 2);
|
||||
_tile_stream_loadd(3, b_tile_3, b_stride);
|
||||
_tile_dpbf16ps(7, 0, 3);
|
||||
}
|
||||
|
||||
_tile_stored(6, c_tile_6, c_stride);
|
||||
_tile_stored(7, c_tile_7, c_stride);
|
||||
}
|
||||
|
||||
FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) {
|
||||
config.rows[0] = m;
|
||||
config.rows[1] = m;
|
||||
config.rows[2] = AMX_TILE_ROW_NUM;
|
||||
config.rows[3] = AMX_TILE_ROW_NUM;
|
||||
config.rows[4] = AMX_TILE_ROW_NUM;
|
||||
config.rows[5] = AMX_TILE_ROW_NUM;
|
||||
config.rows[6] = m;
|
||||
config.rows[7] = m;
|
||||
_tile_loadconfig(&config);
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// Gemm kernel uses AMX, requires B matrix to be packed
|
||||
template <typename scalar_t>
|
||||
class MicroGemm<cpu_utils::ISA::AMX, scalar_t> {
|
||||
public:
|
||||
static constexpr int32_t MaxMSize = 32;
|
||||
static constexpr int32_t NSize = 32;
|
||||
|
||||
public:
|
||||
MicroGemm() : curr_m_(-1) {
|
||||
vec_op::unroll_loop<int, 8>([&](int i) { amx_tile_config_.colsb[i] = 64; });
|
||||
}
|
||||
|
||||
void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) {
|
||||
if (m > AMX_TILE_ROW_NUM) {
|
||||
if (m != curr_m_) {
|
||||
curr_m_ = m;
|
||||
TileGemm224<scalar_t>::init_tile_config(m, amx_tile_config_);
|
||||
}
|
||||
TileGemm224<scalar_t>::gemm(CPU_MICRO_GEMM_PARAMS);
|
||||
} else {
|
||||
if (m != curr_m_) {
|
||||
curr_m_ = m;
|
||||
TileGemm122<scalar_t>::init_tile_config(m, amx_tile_config_);
|
||||
}
|
||||
TileGemm122<scalar_t>::gemm(CPU_MICRO_GEMM_PARAMS);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
alignas(64) __tilecfg amx_tile_config_;
|
||||
int32_t curr_m_;
|
||||
};
|
||||
|
||||
} // namespace cpu_micro_gemm
|
||||
|
||||
#endif
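The AMX paths above rely on _tile_dpbf16ps, which consumes B in a VNNI-like layout: each 64-byte tile row holds 16 output channels times 2 consecutive k values, which is what interleave_save in cpu_types.hpp produces during dequantization. A NumPy sketch of one such tile update, modeling bf16 inputs as float (an illustrative reference, not the intrinsic itself):

import numpy as np

def dpbf16ps_ref(a_tile, b_tile_vnni, c_tile):
    # a_tile: [M, 2 * KB] A values, consecutive k pairs adjacent.
    # b_tile_vnni: [KB, 16, 2] where [kb, n, j] holds B[2 * kb + j, n].
    # c_tile: [M, 16] float32 accumulator, returned updated.
    M, K2 = a_tile.shape
    a = a_tile.reshape(M, K2 // 2, 2)
    return c_tile + np.einsum("mkj,knj->mn", a, b_tile_vnni)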
|
||||
csrc/cpu/micro_gemm/cpu_micro_gemm_impl.hpp (new file, 91 lines)
@@ -0,0 +1,91 @@
#ifndef CPU_MICRO_GEMM_IMPL_HPP
|
||||
#define CPU_MICRO_GEMM_IMPL_HPP
|
||||
#include "cpu/utils.hpp"
|
||||
#include "cpu/cpu_types.hpp"
|
||||
|
||||
namespace cpu_micro_gemm {
|
||||
#define DEFINE_CPU_MICRO_GEMM_PARAMS \
|
||||
scalar_t *__restrict__ a_ptr, scalar_t *__restrict__ b_ptr, \
|
||||
float *__restrict__ c_ptr, const int32_t m, const int32_t k, \
|
||||
const int64_t lda, const int64_t b_n_group_stride, const int64_t ldc, \
|
||||
const bool accum_c
|
||||
|
||||
#define CPU_MICRO_GEMM_PARAMS \
|
||||
a_ptr, b_ptr, c_ptr, m, k, lda, b_n_group_stride, ldc, accum_c
|
||||
|
||||
template <cpu_utils::ISA isa, typename scalar_t>
|
||||
class MicroGemm {
|
||||
public:
|
||||
static constexpr int32_t MaxMSize = 16;
|
||||
static constexpr int32_t NSize = 16;
|
||||
|
||||
public:
|
||||
void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) {
|
||||
TORCH_CHECK(false, "Unimplemented MicroGemm.");
|
||||
}
|
||||
};
|
||||
|
||||
template <int32_t n_size, typename scalar_t>
|
||||
FORCE_INLINE void default_epilogue(float* __restrict__ c_ptr,
|
||||
scalar_t* __restrict__ d_ptr,
|
||||
const int32_t m, const int64_t ldc,
|
||||
const int64_t ldd) {
|
||||
using scalar_vec_t = typename cpu_utils::VecTypeTrait<scalar_t>::vec_t;
|
||||
static_assert(n_size % 16 == 0);
|
||||
|
||||
float* __restrict__ curr_c = c_ptr;
|
||||
scalar_t* __restrict__ curr_d = d_ptr;
|
||||
for (int32_t i = 0; i < m; ++i) {
|
||||
float* __restrict__ curr_c_iter = curr_c;
|
||||
scalar_t* __restrict__ curr_d_iter = curr_d;
|
||||
vec_op::unroll_loop<int32_t, n_size / 16>([&](int32_t n_g_idx) {
|
||||
vec_op::FP32Vec16 c_vec_fp32(curr_c_iter);
|
||||
scalar_vec_t c_vec(c_vec_fp32);
|
||||
c_vec.save(curr_d_iter);
|
||||
curr_c_iter += 16;
|
||||
curr_d_iter += 16;
|
||||
});
|
||||
curr_c += ldc;
|
||||
curr_d += ldd;
|
||||
}
|
||||
}
|
||||
|
||||
template <int32_t n_size, typename scalar_t>
|
||||
FORCE_INLINE void bias_epilogue(float* __restrict__ c_ptr,
|
||||
scalar_t* __restrict__ d_ptr,
|
||||
scalar_t* __restrict__ bias_ptr,
|
||||
const int32_t m, const int64_t ldc,
|
||||
const int64_t ldd) {
|
||||
using scalar_vec_t = typename cpu_utils::VecTypeTrait<scalar_t>::vec_t;
|
||||
static_assert(n_size % 16 == 0);
|
||||
constexpr int32_t n_group_num = n_size / 16;
|
||||
static_assert(n_group_num <= 16);
|
||||
|
||||
vec_op::FP32Vec16 bias_vecs[n_group_num];
|
||||
scalar_t* __restrict__ curr_bias = bias_ptr;
|
||||
vec_op::unroll_loop<int32_t, n_group_num>([&](int32_t i) {
|
||||
scalar_vec_t vec(curr_bias);
|
||||
bias_vecs[i] = vec_op::FP32Vec16(vec);
|
||||
curr_bias += 16;
|
||||
});
|
||||
|
||||
float* __restrict__ curr_c = c_ptr;
|
||||
scalar_t* __restrict__ curr_d = d_ptr;
|
||||
for (int32_t i = 0; i < m; ++i) {
|
||||
float* __restrict__ curr_c_iter = curr_c;
|
||||
scalar_t* __restrict__ curr_d_iter = curr_d;
|
||||
vec_op::unroll_loop<int32_t, n_group_num>([&](int32_t n_g_idx) {
|
||||
vec_op::FP32Vec16 c_vec_fp32(curr_c_iter);
|
||||
c_vec_fp32 = c_vec_fp32 + bias_vecs[n_g_idx];
|
||||
scalar_vec_t c_vec(c_vec_fp32);
|
||||
c_vec.save(curr_d_iter);
|
||||
curr_c_iter += 16;
|
||||
curr_d_iter += 16;
|
||||
});
|
||||
curr_c += ldc;
|
||||
curr_d += ldd;
|
||||
}
|
||||
}
|
||||
} // namespace cpu_micro_gemm
|
||||
|
||||
#endif
|
||||
csrc/cpu/micro_gemm/cpu_micro_gemm_vec.hpp (new file, 115 lines)
@@ -0,0 +1,115 @@
#ifndef CPU_MICRO_GEMM_VEC_HPP
|
||||
#define CPU_MICRO_GEMM_VEC_HPP
|
||||
#include "cpu/micro_gemm/cpu_micro_gemm_impl.hpp"
|
||||
|
||||
namespace cpu_micro_gemm {
|
||||
namespace {
|
||||
// 8-2-16 pattern, 8 regs for A, 2 regs for B, 16 regs for C, [8, K] @ [K, 32]
|
||||
template <typename scalar_t>
|
||||
class TileGemm82 {
|
||||
public:
|
||||
FORCE_INLINE static void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) {
|
||||
switch (m) {
|
||||
case 1:
|
||||
gemm_micro<1>(CPU_MICRO_GEMM_PARAMS);
|
||||
break;
|
||||
case 2:
|
||||
gemm_micro<2>(CPU_MICRO_GEMM_PARAMS);
|
||||
break;
|
||||
case 3:
|
||||
gemm_micro<3>(CPU_MICRO_GEMM_PARAMS);
|
||||
break;
|
||||
case 4:
|
||||
gemm_micro<4>(CPU_MICRO_GEMM_PARAMS);
|
||||
break;
|
||||
case 5:
|
||||
gemm_micro<5>(CPU_MICRO_GEMM_PARAMS);
|
||||
break;
|
||||
case 6:
|
||||
gemm_micro<6>(CPU_MICRO_GEMM_PARAMS);
|
||||
break;
|
||||
case 7:
|
||||
gemm_micro<7>(CPU_MICRO_GEMM_PARAMS);
|
||||
break;
|
||||
case 8:
|
||||
gemm_micro<8>(CPU_MICRO_GEMM_PARAMS);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
template <int32_t M>
|
||||
static void gemm_micro(DEFINE_CPU_MICRO_GEMM_PARAMS) {
|
||||
static_assert(M > 0 && M <= 8);
|
||||
using load_vec_t = typename cpu_utils::VecTypeTrait<scalar_t>::vec_t;
|
||||
|
||||
scalar_t* __restrict__ curr_b_0 = b_ptr;
|
||||
scalar_t* __restrict__ curr_b_1 = b_ptr + b_n_group_stride;
|
||||
float* __restrict__ curr_c_0 = c_ptr;
|
||||
float* __restrict__ curr_c_1 = c_ptr + 16;
|
||||
|
||||
vec_op::FP32Vec16 c_regs[M * 2];
|
||||
if (accum_c) {
|
||||
float* __restrict__ curr_m_c_0 = curr_c_0;
|
||||
float* __restrict__ curr_m_c_1 = curr_c_1;
|
||||
vec_op::unroll_loop<int32_t, M>([&](int32_t i) {
|
||||
c_regs[i * 2] = vec_op::FP32Vec16(curr_m_c_0);
|
||||
c_regs[i * 2 + 1] = vec_op::FP32Vec16(curr_m_c_1);
|
||||
|
||||
// update
|
||||
curr_m_c_0 += ldc;
|
||||
curr_m_c_1 += ldc;
|
||||
});
|
||||
}
|
||||
|
||||
scalar_t* __restrict__ curr_a = a_ptr;
|
||||
for (int32_t k_idx = 0; k_idx < k; ++k_idx) {
|
||||
load_vec_t b_0_reg(curr_b_0);
|
||||
vec_op::FP32Vec16 fp32_b_0_reg(b_0_reg);
|
||||
load_vec_t b_1_reg(curr_b_1);
|
||||
vec_op::FP32Vec16 fp32_b_1_reg(b_1_reg);
|
||||
|
||||
scalar_t* __restrict__ curr_m_a = curr_a;
|
||||
vec_op::unroll_loop<int32_t, M>([&](int32_t i) {
|
||||
scalar_t v = *curr_m_a;
|
||||
load_vec_t a_reg_original(v);
|
||||
vec_op::FP32Vec16 a_reg(a_reg_original);
|
||||
c_regs[i * 2] = c_regs[i * 2] + a_reg * fp32_b_0_reg;
|
||||
c_regs[i * 2 + 1] = c_regs[i * 2 + 1] + a_reg * fp32_b_1_reg;
|
||||
|
||||
// update
|
||||
curr_m_a += lda;
|
||||
});
|
||||
|
||||
// update
|
||||
curr_a += 1;
|
||||
curr_b_0 += 16;
|
||||
curr_b_1 += 16;
|
||||
}
|
||||
|
||||
vec_op::unroll_loop<int32_t, M>([&](int32_t i) {
|
||||
c_regs[i * 2].save(curr_c_0);
|
||||
c_regs[i * 2 + 1].save(curr_c_1);
|
||||
|
||||
// update
|
||||
curr_c_0 += ldc;
|
||||
curr_c_1 += ldc;
|
||||
});
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// Gemm kernel uses vector instructions, requires B matrix to be packed
|
||||
template <typename scalar_t>
|
||||
class MicroGemm<cpu_utils::ISA::VEC, scalar_t> {
|
||||
public:
|
||||
static constexpr int32_t MaxMSize = 8;
|
||||
static constexpr int32_t NSize = 32;
|
||||
|
||||
public:
|
||||
void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) {
|
||||
TileGemm82<scalar_t>::gemm(CPU_MICRO_GEMM_PARAMS);
|
||||
}
|
||||
};
|
||||
} // namespace cpu_micro_gemm
|
||||
|
||||
#endif
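For the non-AMX path, TileGemm82 above keeps up to 8 rows of A in flight: per k step it loads two 16-wide B vectors (the two halves of a 32-wide N tile), broadcasts each A scalar, and accumulates with FMAs into 16 C registers. A small NumPy reference of that register-blocking pattern (illustrative only):

import numpy as np

def micro_gemm_vec_ref(a, b0, b1, c=None):
    # a: [M, K] with M <= 8; b0, b1: [K, 16] halves of a 32-wide N tile;
    # c: optional [M, 32] accumulator (accum_c == True in the kernel).
    M, K = a.shape
    out = np.zeros((M, 32), dtype=np.float32) if c is None else c.astype(np.float32)
    for k in range(K):
        out[:, :16] += np.outer(a[:, k], b0[k])   # broadcast-A FMA, low half
        out[:, 16:] += np.outer(a[:, k], b1[k])   # broadcast-A FMA, high half
    return out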
|
||||
@@ -103,6 +103,13 @@ void cpu_attention_with_kv_cache(
|
||||
// Note: just for avoiding import errors
|
||||
void placeholder_op() { TORCH_CHECK(false, "Unimplemented"); }
|
||||
|
||||
void cpu_gemm_wna16(const torch::Tensor& input, const torch::Tensor& q_weight,
|
||||
torch::Tensor& output, const torch::Tensor& scales,
|
||||
const std::optional<torch::Tensor>& zeros,
|
||||
const std::optional<torch::Tensor>& g_idx,
|
||||
const std::optional<torch::Tensor>& bias,
|
||||
const int64_t pack_factor, const std::string& isa_hint);
|
||||
|
||||
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
// vLLM custom ops
|
||||
|
||||
@@ -283,6 +290,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.def("static_scaled_fp8_quant() -> ()", placeholder_op);
|
||||
ops.def("dynamic_scaled_fp8_quant() -> ()", placeholder_op);
|
||||
ops.def("dynamic_per_token_scaled_fp8_quant() -> ()", placeholder_op);
|
||||
|
||||
// WNA16
|
||||
#if defined(__AVX512F__)
|
||||
ops.def(
|
||||
"cpu_gemm_wna16(Tensor input, Tensor q_weight, Tensor(a2!) output, "
|
||||
"Tensor scales, Tensor? zeros, Tensor? g_idx, Tensor? bias, SymInt "
|
||||
"pack_factor, str isa_hint) -> ()");
|
||||
ops.impl("cpu_gemm_wna16", torch::kCPU, &cpu_gemm_wna16);
|
||||
#endif
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) {
|
||||
|
||||
csrc/cpu/utils.hpp (new file, 55 lines)
@@ -0,0 +1,55 @@
#ifndef UTILS_HPP
|
||||
#define UTILS_HPP
|
||||
|
||||
#include <atomic>
|
||||
#include <cassert>
|
||||
#include <cstdint>
|
||||
#include <unistd.h>
|
||||
|
||||
#include "cpu_types.hpp"
|
||||
|
||||
namespace cpu_utils {
|
||||
enum class ISA { AMX, VEC };
|
||||
|
||||
template <typename T>
|
||||
struct VecTypeTrait {
|
||||
using vec_t = void;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct VecTypeTrait<float> {
|
||||
using vec_t = vec_op::FP32Vec16;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct VecTypeTrait<c10::BFloat16> {
|
||||
using vec_t = vec_op::BF16Vec16;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct VecTypeTrait<c10::Half> {
|
||||
using vec_t = vec_op::FP16Vec16;
|
||||
};
|
||||
|
||||
struct Counter {
|
||||
std::atomic<int64_t> counter;
|
||||
char _padding[56];
|
||||
|
||||
Counter() : counter(0) {}
|
||||
|
||||
void reset_counter() { counter.store(0); }
|
||||
|
||||
int64_t acquire_counter() { return counter++; }
|
||||
};
|
||||
|
||||
inline int64_t get_l2_size() {
|
||||
static int64_t size = []() {
|
||||
long l2_cache_size = sysconf(_SC_LEVEL2_CACHE_SIZE);
|
||||
assert(l2_cache_size != -1);
|
||||
return l2_cache_size >> 1; // use 50% of L2 cache
|
||||
}();
|
||||
return size;
|
||||
}
|
||||
} // namespace cpu_utils
|
||||
|
||||
#endif
|
||||
@@ -97,7 +97,6 @@ Currently, there are no pre-built CPU wheels.
|
||||
- `VLLM_CPU_OMP_THREADS_BIND`: specify the CPU cores dedicated to the OpenMP threads, can be set as CPU id lists, `auto` (by default), or `nobind` (to disable binding to individual CPU cores and to inherit user-defined OpenMP variables). For example, `VLLM_CPU_OMP_THREADS_BIND=0-31` means there will be 32 OpenMP threads bound on 0-31 CPU cores. `VLLM_CPU_OMP_THREADS_BIND=0-31|32-63` means there will be 2 tensor parallel processes, 32 OpenMP threads of rank0 are bound on 0-31 CPU cores, and the OpenMP threads of rank1 are bound on 32-63 CPU cores. By setting to `auto`, the OpenMP threads of each rank are bound to the CPU cores in each NUMA node respectively. If set to `nobind`, the number of OpenMP threads is determined by the standard `OMP_NUM_THREADS` environment variable.
|
||||
- `VLLM_CPU_NUM_OF_RESERVED_CPU`: specify the number of CPU cores which are not dedicated to the OpenMP threads for each rank. The variable only takes effect when VLLM_CPU_OMP_THREADS_BIND is set to `auto`. Default value is `None`. If the value is not set and use `auto` thread binding, no CPU will be reserved for `world_size == 1`, 1 CPU per rank will be reserved for `world_size > 1`.
|
||||
- `CPU_VISIBLE_MEMORY_NODES`: specify visible NUMA memory nodes for vLLM CPU workers, similar to ```CUDA_VISIBLE_DEVICES```. The variable only takes effect when VLLM_CPU_OMP_THREADS_BIND is set to `auto`. The variable provides more control for the auto thread-binding feature, such as masking nodes and changing nodes binding sequence.
|
||||
- `VLLM_CPU_MOE_PREPACK` (x86 only): whether to use prepack for MoE layer. This will be passed to `ipex.llm.modules.GatedMLPMOE`. Default is `1` (True). On unsupported CPUs, you might need to set this to `0` (False).
|
||||
- `VLLM_CPU_SGL_KERNEL` (x86 only, Experimental): whether to use small-batch optimized kernels for linear layer and MoE layer, especially for low-latency requirements like online serving. The kernels require AMX instruction set, BFloat16 weight type and weight shapes divisible by 32. Default is `0` (False).
|
||||
|
||||
## FAQ
|
||||
@@ -191,10 +190,9 @@ vLLM CPU supports data parallel (DP), tensor parallel (TP) and pipeline parallel
|
||||
- GPTQ (x86 only)
|
||||
- compressed-tensor INT8 W8A8 (x86, s390x)
|
||||
|
||||
### (x86 only) What is the purpose of `VLLM_CPU_MOE_PREPACK` and `VLLM_CPU_SGL_KERNEL`?
|
||||
### (x86 only) What is the purpose of `VLLM_CPU_SGL_KERNEL`?
|
||||
|
||||
- Both of them require `amx` CPU flag.
|
||||
- `VLLM_CPU_MOE_PREPACK` can provide better performance for MoE models
|
||||
- `VLLM_CPU_SGL_KERNEL` can provide better performance for MoE models and small-batch scenarios.
|
||||
|
||||
### Why do I see `get_mempolicy: Operation not permitted` when running in Docker?
|
||||
|
||||
@@ -22,7 +22,6 @@ datasets # for benchmark scripts
|
||||
|
||||
# Intel Extension for PyTorch, only for x86_64 CPUs
|
||||
intel-openmp==2024.2.1; platform_machine == "x86_64"
|
||||
intel_extension_for_pytorch==2.8.0; platform_machine == "x86_64"
|
||||
triton==3.2.0; platform_machine == "x86_64" # Triton is required for torch 2.6+cpu, as it is imported in torch.compile.
|
||||
|
||||
# Use this to gather CPU info and optimize based on ARM Neoverse cores
|
||||
|
||||
tests/quantization/test_cpu_wna16.py (new file, 23 lines)
@@ -0,0 +1,23 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest

from vllm.platforms import current_platform

if not current_platform.is_cpu():
    pytest.skip("skipping CPU-only tests", allow_module_level=True)

MODELS = [
    "TheBloke/TinyLlama-1.1B-Chat-v1.0-AWQ",
    "TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ",  # with g_idx
]
DTYPE = ["bfloat16"]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", DTYPE)
def test_ipex_quant(vllm_runner, model, dtype):
    with vllm_runner(model, dtype=dtype) as llm:
        output = llm.generate_greedy(["The capital of France is"], max_tokens=32)
        assert output
        print(output)
@@ -2702,6 +2702,31 @@ def cpu_attention_with_kv_cache(
|
||||
)
|
||||
|
||||
|
||||
def cpu_gemm_wna16(
|
||||
input: torch.Tensor,
|
||||
q_weight: torch.Tensor,
|
||||
scales: torch.Tensor,
|
||||
zeros: torch.Tensor | None,
|
||||
g_idx: torch.Tensor | None,
|
||||
bias: torch.Tensor | None,
|
||||
pack_factor: int,
|
||||
isa_hint: str,
|
||||
) -> torch.Tensor:
|
||||
output = torch.empty((input.size(0), scales.size(1)), dtype=input.dtype)
|
||||
torch.ops._C.cpu_gemm_wna16(
|
||||
input,
|
||||
q_weight,
|
||||
output,
|
||||
scales,
|
||||
zeros,
|
||||
g_idx,
|
||||
bias,
|
||||
pack_factor,
|
||||
isa_hint,
|
||||
)
|
||||
return output
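A hedged usage sketch of the wrapper above, on a CPU build that compiles the new op (all sizes and the random tensors below are made up; real q_weight/scales/zeros come from an AWQ or GPTQ checkpoint):

import torch
from vllm import _custom_ops as ops

M, K, N, group_size = 4, 128, 256, 64          # illustrative sizes
x = torch.randn(M, K, dtype=torch.bfloat16)
q_weight = torch.randint(0, 2**31 - 1, (N // 16, K * 16 // 8), dtype=torch.int32)
scales = torch.rand(K // group_size, N, dtype=torch.bfloat16)
zeros = torch.randint(0, 2**31 - 1, (K // group_size, N // 8), dtype=torch.int32)  # AWQ-style
out = ops.cpu_gemm_wna16(x, q_weight, scales, zeros, None, None, 8, "vec")
assert out.shape == (M, N)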
|
||||
|
||||
|
||||
if hasattr(torch.ops._qutlass_C, "matmul_mxf4_bf16_tn"):
|
||||
|
||||
@register_fake("_qutlass_C::matmul_mxf4_bf16_tn")
|
||||
|
||||
@@ -1020,6 +1020,8 @@ class ModelConfig:
|
||||
# Ensure heavy backends are probed last to avoid unnecessary
|
||||
# imports during override detection (e.g., MXFP4 imports Triton)
|
||||
"mxfp4",
|
||||
"cpu_gptq",
|
||||
"cpu_awq",
|
||||
]
|
||||
quantization_methods = [
|
||||
q for q in supported_quantization if q not in overrides
|
||||
|
||||
@@ -50,7 +50,6 @@ if TYPE_CHECKING:
|
||||
VLLM_CPU_KVCACHE_SPACE: int | None = 0
|
||||
VLLM_CPU_OMP_THREADS_BIND: str = ""
|
||||
VLLM_CPU_NUM_OF_RESERVED_CPU: int | None = None
|
||||
VLLM_CPU_MOE_PREPACK: bool = True
|
||||
VLLM_CPU_SGL_KERNEL: bool = False
|
||||
VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache")
|
||||
VLLM_XLA_CHECK_RECOMPILATION: bool = False
|
||||
@@ -665,10 +664,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
)
|
||||
if "VLLM_CPU_NUM_OF_RESERVED_CPU" in os.environ
|
||||
else None,
|
||||
# (CPU backend only) whether to use prepack for MoE layer. This will be
|
||||
# passed to ipex.llm.modules.GatedMLPMOE. On unsupported CPUs, you might
|
||||
# need to set this to "0" (False).
|
||||
"VLLM_CPU_MOE_PREPACK": lambda: bool(int(os.getenv("VLLM_CPU_MOE_PREPACK", "1"))),
|
||||
# (CPU backend only) whether to use SGL kernels, optimized for small batch.
|
||||
"VLLM_CPU_SGL_KERNEL": lambda: bool(int(os.getenv("VLLM_CPU_SGL_KERNEL", "0"))),
|
||||
# If the env var is set, Ray Compiled Graph uses the specified
|
||||
|
||||
@@ -6,7 +6,6 @@ import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm import envs
|
||||
|
||||
|
||||
def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
|
||||
@@ -130,54 +129,6 @@ def select_experts(
|
||||
)
|
||||
|
||||
|
||||
class IPEXFusedMOE:
|
||||
def __init__(self, layer: torch.nn.Module) -> None:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
|
||||
layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
use_prepack=envs.VLLM_CPU_MOE_PREPACK,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
use_grouped_topk: bool,
|
||||
top_k: int,
|
||||
router_logits: torch.Tensor,
|
||||
renormalize: bool,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
) -> torch.Tensor:
|
||||
assert activation == "silu", f"{activation} is not supported."
|
||||
assert not apply_router_weight_on_input
|
||||
assert routed_scaling_factor == 1.0, (
|
||||
f"routed_scaling_factor {routed_scaling_factor} is not supported."
|
||||
)
|
||||
return layer.ipex_fusion(
|
||||
x,
|
||||
use_grouped_topk,
|
||||
top_k,
|
||||
router_logits,
|
||||
renormalize,
|
||||
topk_group,
|
||||
num_expert_group,
|
||||
custom_routing_function,
|
||||
scoring_func,
|
||||
e_score_correction_bias,
|
||||
)
|
||||
|
||||
|
||||
class SGLFusedMOE:
|
||||
def __init__(self, layer: torch.nn.Module) -> None:
|
||||
pass
|
||||
|
||||
@@ -260,7 +260,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
layer.w2_weight.copy_(packed_w2_weight)
|
||||
layer.cpu_fused_moe = cpu_fused_moe.SGLFusedMOE(layer)
|
||||
else:
|
||||
layer.cpu_fused_moe = cpu_fused_moe.IPEXFusedMOE(layer)
|
||||
layer.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
|
||||
else:
|
||||
layer.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
|
||||
|
||||
|
||||
@@ -38,6 +38,8 @@ QuantizationMethods = Literal[
|
||||
"inc",
|
||||
"mxfp4",
|
||||
"petit_nvfp4",
|
||||
"cpu_gptq",
|
||||
"cpu_awq",
|
||||
]
|
||||
QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))
|
||||
|
||||
@@ -107,6 +109,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
|
||||
from .compressed_tensors.compressed_tensors import (
|
||||
CompressedTensorsConfig,
|
||||
)
|
||||
from .cpu_wna16 import CPUAWQConfig, CPUGPTQConfig
|
||||
from .deepspeedfp import DeepSpeedFPConfig
|
||||
from .experts_int8 import ExpertsInt8Config
|
||||
from .fbgemm_fp8 import FBGEMMFp8Config
|
||||
@@ -159,6 +162,8 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
|
||||
"inc": INCConfig,
|
||||
"mxfp4": Mxfp4Config,
|
||||
"petit_nvfp4": PetitNvFp4Config,
|
||||
"cpu_gptq": CPUGPTQConfig,
|
||||
"cpu_awq": CPUAWQConfig,
|
||||
}
|
||||
# Update the `method_to_config` with customized quantization methods.
|
||||
method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)
|
||||
|
||||
vllm/model_executor/layers/quantization/cpu_wna16.py (new file, 625 lines)
@@ -0,0 +1,625 @@
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
|
||||
|
||||
from vllm._custom_ops import (
|
||||
cpu_gemm_wna16,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import (
|
||||
LinearBase,
|
||||
LinearMethodBase,
|
||||
UnquantizedLinearMethod,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
|
||||
get_linear_quant_method,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
marlin_repeat_scales_on_all_ranks,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
is_layer_skipped,
|
||||
pack_cols,
|
||||
unpack_cols,
|
||||
)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.models.utils import WeightsMapper
|
||||
from vllm.model_executor.parameter import (
|
||||
ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
PackedColumnParameter,
|
||||
PackedvLLMParameter,
|
||||
RowvLLMParameter,
|
||||
)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.transformers_utils.config import get_safetensors_params_metadata
|
||||
from vllm.utils.collection_utils import is_list_of
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class CPUGPTQConfig(QuantizationConfig):
|
||||
"""Config class for CPU GPTQ quant"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_bits: int,
|
||||
group_size: int,
|
||||
desc_act: bool,
|
||||
is_sym: bool,
|
||||
lm_head_quantized: bool,
|
||||
dynamic: dict[str, dict[str, int | bool]],
|
||||
full_config: dict[str, Any],
|
||||
modules_in_block_to_quantize: list[str] | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if desc_act and group_size == -1:
|
||||
# In this case, act_order == True is the same as act_order == False
|
||||
# (since we have only one group per output channel)
|
||||
desc_act = False
|
||||
|
||||
# GPTQModel use `dynamic` config property to allow per module
|
||||
# quantization config so each module can be individually optimized.
|
||||
# Format is dict[str, dict] where key is a regex string that can
|
||||
# perform both positive ("+:" prefixed) or negative ("-:" prefixed)
|
||||
# matching of a module.
|
||||
# Default to positive match, override base quant config mode, if no
|
||||
# prefix is used. Value is in dict format of field key and override
|
||||
# value.
|
||||
# Negative matching will skip quantization init for this module
|
||||
# entirely:
|
||||
# non-quantized inference. More details and quantization examples can be
|
||||
# found at: https://github.com/ModelCloud/GPTQModel
|
||||
# Example:
|
||||
# # last 1/2 of the layers 10-21 has 8bit vs 4bit for 0-9
|
||||
# # last 1/4 of the layers 16-21 has 8bit and group_size 64
|
||||
# dynamic = {
|
||||
# #`.*\.` matches the layers_node prefix
|
||||
# # positive match layer 10-15
|
||||
# r"+:.*\.(?:1[0-5])\..*": {"bits": 8,},
|
||||
# # positive match layer 16-21
|
||||
# r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,},
|
||||
# r"-:.*\.moe\..*": {}, # negative match (skip) all `moe` layers
|
||||
# }
|
||||
assert weight_bits == 4
|
||||
self.dynamic = dynamic
|
||||
self.weight_bits = weight_bits
|
||||
self.is_sym = is_sym
|
||||
self.pack_factor = 32 // weight_bits # packed into int32
        self.group_size = group_size
        self.desc_act = desc_act
        self.lm_head_quantized = lm_head_quantized
        self.full_config = full_config
        self.modules_in_block_to_quantize = modules_in_block_to_quantize or []

    def __repr__(self) -> str:
        return (
            f"CPUGPTQConfig("
            f"group_size={self.group_size}, "
            f"desc_act={self.desc_act}, "
            f"lm_head_quantized={self.lm_head_quantized}, "
            f"dynamic={self.dynamic}, "
            f"modules_in_block_to_quantize={self.modules_in_block_to_quantize})"
        )

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "cpu_gptq"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return -1

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "CPUGPTQConfig":
        weight_bits = cls.get_from_keys(config, ["bits"])
        desc_act = cls.get_from_keys_or(config, ["desc_act"], default=False)
        dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
        group_size = cls.get_from_keys(config, ["group_size"])
        is_sym = cls.get_from_keys(config, ["sym"])
        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
        modules_in_block_to_quantize = cls.get_from_keys_or(
            config, ["modules_in_block_to_quantize"], default=None
        )
        return cls(
            weight_bits,
            group_size,
            desc_act,
            is_sym,
            lm_head_quantized,
            dynamic,
            config,
            modules_in_block_to_quantize,
        )

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        quant_method = hf_quant_cfg.get("quant_method", "").lower()
        if current_platform.is_cpu() and (quant_method == "gptq"):
            return cls.get_name()
        return None

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        return get_linear_quant_method(self, layer, prefix, CPUGPTQLinearMethod)  # type: ignore

    def apply_vllm_mapper(self, hf_to_vllm_mapper):
        if self.modules_in_block_to_quantize is not None:
            self.modules_in_block_to_quantize = hf_to_vllm_mapper.apply_list(
                self.modules_in_block_to_quantize
            )

    def maybe_update_config(self, model_name: str, revision: str | None = None):
        if self.modules_in_block_to_quantize:
            if is_list_of(self.modules_in_block_to_quantize, list):
                # the original modules_in_block_to_quantize is a list[list[str]];
                # flatten it
                self.modules_in_block_to_quantize = [
                    item
                    for sublist in self.modules_in_block_to_quantize
                    for item in sublist
                ]
            return

        unquant_dtypes = [torch.float16, torch.bfloat16, torch.float32]
        metadata = get_safetensors_params_metadata(model_name, revision=revision)
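        # Infer the quantized layers from the checkpoint metadata: any layer
        # whose safetensors params use a non-float dtype is assumed to hold
        # packed quantized weights.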
        quant_layers: set[str] = {
            param_name.rsplit(".", 1)[0]
            for param_name, info in metadata.items()
            if (dtype := info.get("dtype", None))
            and _SAFETENSORS_TO_TORCH_DTYPE[dtype] not in unquant_dtypes
        }
        self.modules_in_block_to_quantize = list(quant_layers)


class CPUGPTQLinearMethod(LinearMethodBase):
    """Linear method for GPTQ on CPU.

    Args:
        quant_config: The CPUWNA16 quantization config.
    """

    def __init__(self, quant_config: CPUGPTQConfig) -> None:
        self.quant_config = quant_config
        assert self.quant_config.is_sym, "GPTQ asym quant is not supported on CPU"

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        output_size_per_partition = sum(output_partition_sizes)
        assert output_size_per_partition * self.quant_config.weight_bits % 32 == 0
        assert output_size_per_partition % 32 == 0
        assert input_size_per_partition % 32 == 0

        is_row_parallel = input_size != input_size_per_partition
        weight_loader = extra_weight_attrs.get("weight_loader")

        # Normalize group_size
        if self.quant_config.group_size != -1:
            group_size = self.quant_config.group_size
        else:
            group_size = input_size

        # Determine sharding
        if marlin_repeat_scales_on_all_ranks(
            self.quant_config.desc_act, self.quant_config.group_size, is_row_parallel
        ):
            # By setting scale_dim == None, weight_loader will
            # repeat the scales on each rank in TP>1 case.
            scales_and_zp_input_dim = None
            scales_and_zp_size = input_size // group_size
        else:
            # By setting scale_dim == 0, weight_loader will
            # shard the scales in TP>1 case.
            scales_and_zp_input_dim = 0
            scales_and_zp_size = input_size_per_partition // group_size

        # Quantized weights
        qweight = PackedvLLMParameter(
            data=torch.empty(
                input_size_per_partition // self.quant_config.pack_factor,
                output_size_per_partition,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=0,
            packed_factor=self.quant_config.pack_factor,
            weight_loader=weight_loader,
        )

        # Activation order
        g_idx = RowvLLMParameter(
            data=torch.empty(
                input_size_per_partition,
                dtype=torch.int32,
            ),
            input_dim=0,
            weight_loader=weight_loader,
        )
        set_weight_attrs(
            g_idx,
            {"ignore_warning": True},
        )

        qzeros_args = {
            "data": torch.empty(
                scales_and_zp_size,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            "weight_loader": weight_loader,
        }
        weight_scale_args = {
            "data": torch.empty(
                scales_and_zp_size,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            "weight_loader": weight_loader,
        }

        if scales_and_zp_input_dim is None:
            scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args)
            qzeros = PackedColumnParameter(
                output_dim=1,
                packed_dim=1,
                packed_factor=self.quant_config.pack_factor,
                **qzeros_args,
            )

        else:
            scales = GroupQuantScaleParameter(
                output_dim=1, input_dim=0, **weight_scale_args
            )
            qzeros = PackedvLLMParameter(
                input_dim=0,
                output_dim=1,
                packed_dim=1,
                packed_factor=self.quant_config.pack_factor,
                **qzeros_args,
            )

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("g_idx", g_idx)
        layer.register_parameter("scales", scales)
        layer.register_parameter("qzeros", qzeros)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        torch.set_printoptions(profile="full", linewidth=5000, sci_mode=False)
        packed_weight = layer.qweight.data
        bits = self.quant_config.weight_bits
        pack_factor = int(self.quant_config.pack_factor)
        p_w_k, p_w_n = packed_weight.size()
        input_size = p_w_k * pack_factor
        output_size = p_w_n
        isa_hint = _get_isa_hint(layer.scales.dtype)
        layer.isa_hint = isa_hint

        layer.qzeros = None
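        # Symmetric GPTQ needs no explicit zero points, so qzeros is dropped;
        # g_idx is only kept when desc_act reorders the input channels.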
        if not self.quant_config.desc_act:
            layer.g_idx = None

        # convert input dim packed to output dim packed
        weight = unpack_cols(packed_weight, bits, p_w_k, p_w_n * pack_factor).view(
            p_w_k, p_w_n, pack_factor
        )
        weight = weight.permute(0, 2, 1).reshape(input_size, output_size).contiguous()
        weight = pack_cols(weight, bits, input_size, output_size)
        # make 16 output channels a block and transpose to make the
        # block contiguous
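        # The repacked layout is (output_size // 16, input_size * 16 // pack_factor):
        # each row holds one block of 16 output channels across all input rows.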
        weight = (
            weight.view(input_size, -1, 16 // pack_factor)
            .permute(1, 0, 2)
            .reshape(-1, input_size * 16 // pack_factor)
            .contiguous()
        )
        layer.qweight.data = weight

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        x = cpu_gemm_wna16(
            input=x,
            q_weight=layer.qweight,
            scales=layer.scales,
            zeros=layer.qzeros,
            g_idx=layer.g_idx,
            bias=bias,
            pack_factor=8,
            isa_hint=layer.isa_hint,
        )
        return x


class CPUAWQConfig(QuantizationConfig):
    """Config class for CPU AWQ"""

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        zero_point: bool,
        lm_head_quantized: bool,
        modules_to_not_convert: list[str] | None,
        full_config: dict[str, Any],
    ) -> None:
        super().__init__()
        assert weight_bits == 4
        self.pack_factor = 32 // weight_bits  # packed into int32
        self.group_size = group_size
        self.zero_point = zero_point
        self.lm_head_quantized = lm_head_quantized
        self.weight_bits = weight_bits
        self.modules_to_not_convert = modules_to_not_convert or []
        self.full_config = full_config

    def __repr__(self) -> str:
        return (
            f"CPUAWQConfig("
            f"group_size={self.group_size}, "
            f"zero_point={self.zero_point}, "
            f"lm_head_quantized={self.lm_head_quantized}, "
            f"modules_to_not_convert={self.modules_to_not_convert})"
        )

    @classmethod
    def get_name(cls) -> "QuantizationMethods":
        return "cpu_awq"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return -1

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "CPUAWQConfig":
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        zero_point = cls.get_from_keys(config, ["zero_point"])
        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
        modules_to_not_convert = cls.get_from_keys_or(
            config, ["modules_to_not_convert"], None
        )
        return cls(
            weight_bits,
            group_size,
            zero_point,
            lm_head_quantized,
            modules_to_not_convert,
            config,
        )

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> Optional["QuantizationMethods"]:
        quant_method = hf_quant_cfg.get("quant_method", "").lower()
        if current_platform.is_cpu() and (quant_method == "awq"):
            return cls.get_name()
        return None

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        if isinstance(layer, LinearBase) or (
            isinstance(layer, ParallelLMHead) and self.lm_head_quantized
        ):
            if is_layer_skipped(
                prefix,
                self.modules_to_not_convert,
                self.packed_modules_mapping,
                skip_with_substr=True,
            ):
                return UnquantizedLinearMethod()
            return CPUAWQLinearMethod(self)
        return None

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        if self.modules_to_not_convert:
            self.modules_to_not_convert = hf_to_vllm_mapper.apply_list(
                self.modules_to_not_convert
            )

    def maybe_update_config(self, model_name: str, revision: str | None = None):
        if self.modules_to_not_convert:
            return

        unquant_dtypes = [torch.float16, torch.bfloat16, torch.float32]
        metadata = get_safetensors_params_metadata(model_name, revision=revision)
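        # Layers stored only in float dtypes carry no packed quantized weights,
        # so they are treated as unquantized and excluded from AWQ conversion.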
        layers = {param_name.rsplit(".", 1)[0] for param_name in metadata}
        quant_layers: set[str] = {
            param_name.rsplit(".", 1)[0]
            for param_name, info in metadata.items()
            if (dtype := info.get("dtype", None))
            and _SAFETENSORS_TO_TORCH_DTYPE[dtype] not in unquant_dtypes
        }
        self.modules_to_not_convert = list(layers - quant_layers)


class CPUAWQLinearMethod(LinearMethodBase):
    """Linear method for CPU AWQ.

    Args:
        quant_config: The CPU AWQ quantization config.
    """

    def __init__(self, quant_config: CPUAWQConfig) -> None:
        self.quant_config = quant_config
        assert self.quant_config.zero_point

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        del output_size
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")

        # Normalize group_size
        if self.quant_config.group_size != -1:
            group_size = self.quant_config.group_size
        else:
            group_size = input_size

        qweight = PackedvLLMParameter(
            data=torch.empty(
                input_size_per_partition,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=self.quant_config.pack_factor,
            weight_loader=weight_loader,
        )

        num_groups = input_size_per_partition // group_size

        qzeros = PackedvLLMParameter(
            data=torch.empty(
                num_groups,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=self.quant_config.pack_factor,
            weight_loader=weight_loader,
        )

        scales = GroupQuantScaleParameter(
            data=torch.empty(
                num_groups,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            input_dim=0,
            output_dim=1,
            weight_loader=weight_loader,
        )

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("qzeros", qzeros)
        layer.register_parameter("scales", scales)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        torch.set_printoptions(profile="full", linewidth=5000, sci_mode=False)
        packed_weight = layer.qweight.data
        packed_zeros = layer.qzeros.data
        group_num = packed_zeros.size(0)
        bits = self.quant_config.weight_bits
        pack_factor = int(self.quant_config.pack_factor)
        input_size, packed_output_size = packed_weight.size()
        output_size = packed_output_size * pack_factor
        isa_hint = _get_isa_hint(layer.scales.dtype)
        layer.isa_hint = isa_hint

        interleave_map = (0, 4, 1, 5, 2, 6, 3, 7)
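        # AWQ stores the 8 nibbles of each int32 in an interleaved column order;
        # indexing the unpacked values with this map restores sequential column
        # order before the repacking below.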
        weight = unpack_cols(
            packed_weight,
            bits,
            input_size,
            output_size,
        )
        zeros = unpack_cols(
            packed_zeros,
            bits,
            group_num,
            output_size,
        )
        weight = (
            weight.view(input_size, -1, pack_factor)[:, :, interleave_map]
            .reshape(input_size, output_size)
            .contiguous()
        )
        zeros = (
            zeros.view(group_num, -1, pack_factor)[:, :, interleave_map]
            .reshape(group_num, output_size)
            .contiguous()
        )

        zeros = pack_cols(zeros, bits, group_num, output_size).contiguous()
        # make 16 output channels a block and transpose to make the
        # block contiguous
        weight = pack_cols(weight, bits, input_size, output_size)
        weight = (
            weight.view(input_size, -1, 16 // pack_factor)
            .permute(1, 0, 2)
            .reshape(-1, input_size * 16 // pack_factor)
            .contiguous()
        )
        layer.qweight.data = weight
        layer.qzeros.data = zeros

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        x = cpu_gemm_wna16(
            input=x,
            q_weight=layer.qweight,
            scales=layer.scales,
            zeros=layer.qzeros,
            g_idx=None,
            bias=bias,
            pack_factor=8,
            isa_hint=layer.isa_hint,
        )
        return x


def _get_isa_hint(dtype: torch.dtype) -> str:
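    # AMX is selected only when the scales (and hence activations) are bf16;
    # other dtypes fall back to the generic vectorized path.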
    supports_amx = torch._C._cpu._is_amx_tile_supported()
    if supports_amx and dtype in (torch.bfloat16,):
        return "amx"
    else:
        return "vec"

@ -134,7 +134,7 @@ class IPEXConfig(QuantizationConfig):
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        if not current_platform.is_cpu() and not current_platform.is_xpu():
        if not current_platform.is_xpu():
            return None

        quant_method = hf_quant_cfg.get("quant_method", "").lower()