From 5eda21e773447d81ffc661ac094716420dc7b7cb Mon Sep 17 00:00:00 2001 From: "Li, Jiang" Date: Fri, 18 Oct 2024 00:21:04 +0800 Subject: [PATCH] [Hardware][CPU] compressed-tensor INT8 W8A8 AZP support (#9344) --- .buildkite/run-cpu-test.sh | 8 +- Dockerfile.cpu | 13 - cmake/cpu_extension.cmake | 40 +- csrc/cpu/cpu_types_x86.hpp | 41 +- csrc/cpu/quant.cpp | 417 +++++++++++++++--- csrc/cpu/torch_bindings.cpp | 15 + .../getting_started/cpu-installation.rst | 14 - 7 files changed, 452 insertions(+), 96 deletions(-) diff --git a/.buildkite/run-cpu-test.sh b/.buildkite/run-cpu-test.sh index c2818c38965ea..c331a9c49c0d0 100644 --- a/.buildkite/run-cpu-test.sh +++ b/.buildkite/run-cpu-test.sh @@ -32,10 +32,10 @@ docker exec cpu-test bash -c " --ignore=tests/models/decoder_only/language/test_danube3_4b.py" # Mamba and Danube3-4B on CPU is not supported # Run compressed-tensor test -# docker exec cpu-test bash -c " -# pytest -s -v \ -# tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_static_setup \ -# tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_dynanmic_per_token" +docker exec cpu-test bash -c " + pytest -s -v \ + tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_static_setup \ + tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_dynamic_per_token" # Run AWQ test docker exec cpu-test bash -c " diff --git a/Dockerfile.cpu b/Dockerfile.cpu index b9134d4ae41cb..2e7d66e7d8ffa 100644 --- a/Dockerfile.cpu +++ b/Dockerfile.cpu @@ -33,19 +33,6 @@ RUN --mount=type=cache,target=/root/.cache/pip \ pip install --upgrade pip && \ pip install -r requirements-build.txt -# install oneDNN -RUN git clone -b rls-v3.5 https://github.com/oneapi-src/oneDNN.git - -RUN --mount=type=cache,target=/root/.cache/ccache \ - cmake -B ./oneDNN/build -S ./oneDNN -G Ninja -DONEDNN_LIBRARY_TYPE=STATIC \ - -DONEDNN_BUILD_DOC=OFF \ - -DONEDNN_BUILD_EXAMPLES=OFF \ - -DONEDNN_BUILD_TESTS=OFF \ - -DONEDNN_BUILD_GRAPH=OFF \ - -DONEDNN_ENABLE_WORKLOAD=INFERENCE \ - -DONEDNN_ENABLE_PRIMITIVE=MATMUL && \ - cmake --build ./oneDNN/build --target install --config Release - FROM cpu-test-1 AS build WORKDIR /workspace/vllm diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index bc5f24d3f591c..7237d246ddf55 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -1,5 +1,8 @@ +include(FetchContent) + +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CXX_EXTENSIONS ON) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) -set(CMAKE_CXX_STANDARD 17) # # Define environment variables for special configurations @@ -82,15 +85,40 @@ else() message(FATAL_ERROR "vLLM CPU backend requires AVX512 or AVX2 or Power9+ ISA support.") endif() +# +# Build oneDNN for W8A8 GEMM kernels (only for x86-AVX512 platforms) +# +if (AVX512_FOUND AND NOT AVX512_DISABLED) + FetchContent_Declare( + oneDNN + GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git + GIT_TAG v3.5.3 + GIT_PROGRESS TRUE + GIT_SHALLOW TRUE + ) + + set(ONEDNN_LIBRARY_TYPE "STATIC") + set(ONEDNN_BUILD_DOC "OFF") + set(ONEDNN_BUILD_EXAMPLES "OFF") + set(ONEDNN_BUILD_TESTS "OFF") + set(ONEDNN_ENABLE_WORKLOAD "INFERENCE") + set(ONEDNN_ENABLE_PRIMITIVE "MATMUL;REORDER") + set(ONEDNN_BUILD_GRAPH "OFF") + set(ONEDNN_ENABLE_JIT_PROFILING "OFF") + set(ONEDNN_ENABLE_ITT_TASKS "OFF") + set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF") + set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF") + set(CMAKE_POLICY_DEFAULT_CMP0077 NEW) + + FetchContent_MakeAvailable(oneDNN) + + list(APPEND LIBS dnnl) +endif() + message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}") list(APPEND LIBS numa) -# Appending the dnnl library for the AVX2 and AVX512, as it is not utilized by Power architecture. -if (AVX2_FOUND OR AVX512_FOUND) - list(APPEND LIBS dnnl) -endif() - # # _C extension # diff --git a/csrc/cpu/cpu_types_x86.hpp b/csrc/cpu/cpu_types_x86.hpp index 5b1d3d6442b2b..a325153b470cc 100644 --- a/csrc/cpu/cpu_types_x86.hpp +++ b/csrc/cpu/cpu_types_x86.hpp @@ -265,6 +265,30 @@ struct FP32Vec8 : public Vec { void save(float *ptr) const { _mm256_storeu_ps(ptr, reg); } }; +#ifdef __AVX512F__ +struct INT32Vec16: public Vec { + constexpr static int VEC_ELEM_NUM = 16; + union AliasReg { + __m512i reg; + int32_t values[VEC_ELEM_NUM]; + }; + + __m512i reg; + + explicit INT32Vec16(const void* data_ptr) : reg(_mm512_loadu_epi32(data_ptr)) {} + + void save(int32_t* ptr) const { + _mm512_storeu_epi32(ptr, reg); + } + + void save(int32_t* ptr, const int elem_num) const { + constexpr uint32_t M = 0xFFFFFFFF; + __mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num)); + _mm512_mask_storeu_epi32(ptr, mask, reg); + } +}; +#endif + #ifdef __AVX512F__ struct FP32Vec16 : public Vec { constexpr static int VEC_ELEM_NUM = 16; @@ -283,8 +307,6 @@ struct FP32Vec16 : public Vec { explicit FP32Vec16(__m512 data) : reg(data) {} - explicit FP32Vec16(const FP32Vec16 &data) : reg(data.reg) {} - explicit FP32Vec16(const FP32Vec4 &data) : reg((__m512)_mm512_inserti32x4( _mm512_inserti32x4( @@ -303,6 +325,9 @@ struct FP32Vec16 : public Vec { explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {} + explicit FP32Vec16(const INT32Vec16 &v) + : reg(_mm512_cvt_roundepi32_ps(v.reg, _MM_FROUND_TO_NEAREST_INT |_MM_FROUND_NO_EXC)) {} + FP32Vec16 operator*(const FP32Vec16 &b) const { return FP32Vec16(_mm512_mul_ps(reg, b.reg)); } @@ -333,6 +358,16 @@ struct FP32Vec16 : public Vec { return FP32Vec16(_mm512_mask_max_ps(reg, mask, reg, b.reg)); } + FP32Vec16 min(const FP32Vec16& b) const { + return FP32Vec16(_mm512_min_ps(reg, b.reg)); + } + + FP32Vec16 min(const FP32Vec16& b, const int elem_num) const { + constexpr uint32_t M = 0xFFFFFFFF; + __mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num)); + return FP32Vec16(_mm512_mask_min_ps(reg, mask, reg, b.reg)); + } + FP32Vec16 abs() const { return FP32Vec16(_mm512_abs_ps(reg)); } @@ -341,6 +376,8 @@ struct FP32Vec16 : public Vec { float reduce_max() const { return _mm512_reduce_max_ps(reg); } + float reduce_min() const { return _mm512_reduce_min_ps(reg); } + template float reduce_sub_sum(int idx) { static_assert(VEC_ELEM_NUM % group_size == 0); constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size)); diff --git a/csrc/cpu/quant.cpp b/csrc/cpu/quant.cpp index 2d7abe6145fee..b493fd793818a 100644 --- a/csrc/cpu/quant.cpp +++ b/csrc/cpu/quant.cpp @@ -5,25 +5,29 @@ namespace { template struct KernelVecType { using load_vec_type = void; + using azp_adj_load_vec_type = void; using cvt_vec_type = void; }; template <> struct KernelVecType { using load_vec_type = vec_op::FP32Vec16; + using azp_adj_load_vec_type = vec_op::INT32Vec16; using cvt_vec_type = vec_op::FP32Vec16; }; template <> struct KernelVecType { using load_vec_type = vec_op::BF16Vec16; + using azp_adj_load_vec_type = vec_op::INT32Vec16; using cvt_vec_type = vec_op::FP32Vec16; }; #ifdef __AVX512F__ -template +template void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, - const float* scale, const int num_tokens, + const float* scale, const int32_t* azp, + const int num_tokens, const int hidden_size) { using load_vec_t = typename KernelVecType::load_vec_type; using cvt_vec_t = typename KernelVecType::cvt_vec_type; @@ -37,62 +41,110 @@ void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, const cvt_vec_t i8_min_vec(i8_min); const cvt_vec_t i8_max_vec(i8_max); + cvt_vec_t zp_vec; + if constexpr (AZP) { + zp_vec = cvt_vec_t(static_cast(*azp)); + } + #pragma omp parallel for for (int i = 0; i < num_tokens; ++i) { int j = 0; for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { load_vec_t elems(input + i * hidden_size + j); cvt_vec_t elems_fp32(elems); - elems_fp32 = (elems_fp32 * inv_scale).clamp(i8_min_vec, i8_max_vec); + elems_fp32 = elems_fp32 * inv_scale; + + if constexpr (AZP) { + elems_fp32 = elems_fp32 + zp_vec; + } + + elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); vec_op::INT8Vec16 elems_int8(elems_fp32); elems_int8.save(output + i * hidden_size + j); } load_vec_t elems(input + i * hidden_size + j); cvt_vec_t elems_fp32(elems); - elems_fp32 = (elems_fp32 * inv_scale).clamp(i8_min_vec, i8_max_vec); - vec_op::INT8Vec16 elems_int8(elems_fp32); + elems_fp32 = elems_fp32 * inv_scale; - if (j + vec_elem_num == hidden_size) { - elems_int8.save(output + i * hidden_size + j); - } else { - elems_int8.save(output + i * hidden_size + j, hidden_size - j); + if constexpr (AZP) { + elems_fp32 = elems_fp32 + zp_vec; } + + elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); + vec_op::INT8Vec16 elems_int8(elems_fp32); + elems_int8.save(output + i * hidden_size + j, hidden_size - j); } } -template +template void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, - float* scale, const int num_tokens, + float* scale, int32_t* azp, + const int num_tokens, const int hidden_size) { using load_vec_t = typename KernelVecType::load_vec_type; using cvt_vec_t = typename KernelVecType::cvt_vec_type; constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; + constexpr float i8_min = + static_cast(std::numeric_limits::min()); + constexpr float i8_max = + static_cast(std::numeric_limits::max()); + const cvt_vec_t i8_min_vec(i8_min); + const cvt_vec_t i8_max_vec(i8_max); + #pragma omp parallel for for (int i = 0; i < num_tokens; ++i) { - cvt_vec_t max_abs(0.0); + cvt_vec_t max_value(std::numeric_limits::lowest()); + cvt_vec_t min_value(std::numeric_limits::max()); { int j = 0; for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { load_vec_t elems(input + i * hidden_size + j); cvt_vec_t elems_fp32(elems); - max_abs = max_abs.max(elems_fp32.abs()); + if constexpr (AZP) { + max_value = max_value.max(elems_fp32); + min_value = min_value.min(elems_fp32); + } else { + max_value = max_value.max(elems_fp32.abs()); + } } load_vec_t elems(input + i * hidden_size + j); cvt_vec_t elems_fp32(elems); if (j + vec_elem_num == hidden_size) { - max_abs = max_abs.max(elems_fp32.abs()); + if constexpr (AZP) { + max_value = max_value.max(elems_fp32); + min_value = min_value.min(elems_fp32); + } else { + max_value = max_value.max(elems_fp32.abs()); + } } else { - max_abs = max_abs.max(elems_fp32.abs(), hidden_size - j); + if constexpr (AZP) { + max_value = max_value.max(elems_fp32, hidden_size - j); + min_value = min_value.min(elems_fp32, hidden_size - j); + } else { + max_value = max_value.max(elems_fp32.abs(), hidden_size - j); + } } } - float scale_val = max_abs.reduce_max() / 127.0f; - scale[i] = scale_val; + float scale_val, azp_val; + if constexpr (AZP) { + float max_scalar = max_value.reduce_max(); + float min_scalar = min_value.reduce_min(); + scale_val = (max_scalar - min_scalar) / 255.0f; + azp_val = std::nearbyint(-128.0f - min_scalar / scale_val); + azp[i] = static_cast(azp_val); + scale[i] = scale_val; + } else { + scale_val = max_value.reduce_max() / 127.0f; + scale[i] = scale_val; + } + const cvt_vec_t inv_scale(1.0 / scale_val); + const cvt_vec_t azp_vec(azp_val); { int j = 0; @@ -100,6 +152,11 @@ void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, load_vec_t elems(input + i * hidden_size + j); cvt_vec_t elems_fp32(elems); elems_fp32 = (elems_fp32 * inv_scale); + + if constexpr (AZP) { + elems_fp32 = elems_fp32 + azp_vec; + } + elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); vec_op::INT8Vec16 elems_int8(elems_fp32); elems_int8.save(output + i * hidden_size + j); } @@ -107,34 +164,111 @@ void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, load_vec_t elems(input + i * hidden_size + j); cvt_vec_t elems_fp32(elems); elems_fp32 = (elems_fp32 * inv_scale); - vec_op::INT8Vec16 elems_int8(elems_fp32); - if (j + vec_elem_num == hidden_size) { - elems_int8.save(output + i * hidden_size + j); - } else { - elems_int8.save(output + i * hidden_size + j, hidden_size - j); + if constexpr (AZP) { + elems_fp32 = elems_fp32 + azp_vec; } + elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); + vec_op::INT8Vec16 elems_int8(elems_fp32); + elems_int8.save(output + i * hidden_size + j, hidden_size - j); } } } -template -void dynamic_output_scale_impl(const float* input, scalar_t* output, - const float* scale, const scalar_t* bias, - const int num_tokens, const int hidden_size) { +template +void static_quant_epilogue(const float* input, scalar_t* output, + const float a_scale, const float* b_scale, + const int32_t* azp_with_adj, const int num_tokens, + const int hidden_size) { CPU_KERNEL_GUARD_IN(dynamic_output_scale_impl) using load_vec_t = typename KernelVecType::load_vec_type; + using azp_adj_load_vec_t = + typename KernelVecType::azp_adj_load_vec_type; + using cvt_vec_t = typename KernelVecType::cvt_vec_type; + constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; + + #pragma omp parallel for + for (int i = 0; i < num_tokens; ++i) { + cvt_vec_t a_scale_vec(a_scale); + cvt_vec_t b_scale_vec(*b_scale); + cvt_vec_t scale_vec = a_scale_vec * b_scale_vec; + + int j = 0; + for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { + cvt_vec_t elems_fp32(input + i * hidden_size + j); + azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j); + cvt_vec_t azp_adj_fp32(azp_adj_vec); + + if constexpr (PerChannel) { + b_scale_vec = cvt_vec_t(b_scale + j); + scale_vec = b_scale_vec * a_scale_vec; + } + + elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32; + + load_vec_t elems_out(elems_fp32); + elems_out.save(output + i * hidden_size + j); + } + + cvt_vec_t elems_fp32(input + i * hidden_size + j); + azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j); + cvt_vec_t azp_adj_fp32(azp_adj_vec); + + if constexpr (PerChannel) { + b_scale_vec = cvt_vec_t(b_scale + j); + scale_vec = b_scale_vec * a_scale_vec; + } + + elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32; + + load_vec_t elems_out(elems_fp32); + elems_out.save(output + i * hidden_size + j, hidden_size - j); + } +} + +template +void dynamic_quant_epilogue(const float* input, scalar_t* output, + const float* a_scale, const float* b_scale, + const int32_t* azp, const int32_t* azp_adj, + const scalar_t* bias, const int num_tokens, + const int hidden_size) { + CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue) + using load_vec_t = typename KernelVecType::load_vec_type; + using azp_adj_load_vec_t = + typename KernelVecType::azp_adj_load_vec_type; using cvt_vec_t = typename KernelVecType::cvt_vec_type; constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; #pragma omp parallel for for (int i = 0; i < num_tokens; ++i) { int j = 0; - cvt_vec_t token_scale_vec(scale[i]); + cvt_vec_t token_scale_vec(a_scale[i]); + cvt_vec_t token_zp_scale_vec; + if constexpr (AZP) { + float zp_scale_val = a_scale[i] * static_cast(azp[i]); + if constexpr (!PerChannel) { + zp_scale_val *= *b_scale; + } + token_zp_scale_vec = cvt_vec_t(zp_scale_val); + } + for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { cvt_vec_t elems_fp32(input + i * hidden_size + j); elems_fp32 = elems_fp32 * token_scale_vec; + if constexpr (AZP) { + azp_adj_load_vec_t azp_adj_vec(azp_adj + j); + cvt_vec_t azp_adj_fp32(azp_adj_vec); + azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec; + + if constexpr (PerChannel) { + cvt_vec_t b_scale_vec(b_scale + j); + azp_adj_fp32 = azp_adj_fp32 * b_scale_vec; + } + + elems_fp32 = elems_fp32 - azp_adj_fp32; + } + if constexpr (Bias) { load_vec_t bias_vec(bias + j); cvt_vec_t bias_vec_fp32(bias_vec); @@ -148,6 +282,19 @@ void dynamic_output_scale_impl(const float* input, scalar_t* output, cvt_vec_t elems_fp32(input + i * hidden_size + j); elems_fp32 = elems_fp32 * token_scale_vec; + if constexpr (AZP) { + azp_adj_load_vec_t azp_adj_vec(azp_adj + j); + cvt_vec_t azp_adj_fp32(azp_adj_vec); + azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec; + + if constexpr (PerChannel) { + cvt_vec_t b_scale_vec(b_scale + j); + azp_adj_fp32 = azp_adj_fp32 * b_scale_vec; + } + + elems_fp32 = elems_fp32 - azp_adj_fp32; + } + if constexpr (Bias) { load_vec_t bias_vec(bias + j); cvt_vec_t bias_vec_fp32(bias_vec); @@ -155,32 +302,41 @@ void dynamic_output_scale_impl(const float* input, scalar_t* output, } load_vec_t elems_out(elems_fp32); - - if (j + vec_elem_num == hidden_size) { - elems_out.save(output + i * hidden_size + j); - } else { - elems_out.save(output + i * hidden_size + j, hidden_size - j); - } + elems_out.save(output + i * hidden_size + j, hidden_size - j); } } #else template void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, - const float* scale, const int num_tokens, + const float* scale, const int32_t* azp, + const int num_tokens, const int hidden_size) { TORCH_CHECK(false, "static_scaled_int8_quant_impl requires AVX512 support.") } template void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, - float* scale, const int num_tokens, + float* scale, int32_t* azp, + const int num_tokens, const int hidden_size) { TORCH_CHECK(false, "dynamic_scaled_int8_quant_impl requires AVX512 support.") } +template +void static_quant_epilogue(const float* input, scalar_t* output, + const float a_scale, const float* b_scale, + const int32_t* azp_with_adj, const int num_tokens, + const int hidden_size) { + TORCH_CHECK(false, "static_quant_epilogue requires AVX512 support.") +} + template -void dynamic_output_scale_impl() { - TORCH_CHECK(false, "dynamic_output_scale_impl requires AVX512 support.") +void dynamic_quant_epilogue(const float* input, scalar_t* output, + const float* a_scale, const float* b_scale, + const int32_t* azp, const int32_t* azp_with_adj, + const scalar_t* bias, const int num_tokens, + const int hidden_size) { + TORCH_CHECK(false, "dynamic_quant_epilogue requires AVX512 support.") } #endif } // namespace @@ -214,39 +370,52 @@ void int8_scaled_mm(torch::Tensor& c, // [M, OC], row-major bias->dim() == 1); } - VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "cutlass_scaled_mm", [&] { + VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm", [&] { if (a_scales.numel() != 1) { // per-token // Note: oneDNN doesn't support per-token activation quantization + // Ideally we want to fuse the GEMM and the scale procedure with oneDNN + // JIT, the intermediate data is cached in registers or L1. But for now + // the oneDNN GEMM code generation only supports two quantization + // patterns: per-tensor or per-output-channel of weight. + // So we have to apply the per-token scale with a 'epilogue'. In C=s_a * + // s_b * (A@B) + bias, the C_inter = s_b * (A@B) is computed by oneDNN + // GEMM, then the per-token scale (and bias) is applied with the epilogue + // C=s_a * C_inter + bias. torch::Tensor tmp_fp32_out = torch::empty_like(c, ::at::ScalarType::Float); - DNNLPrimitiveHelper::gemm_s8s8_jit( + // Compute C_inter=s_b * (A@B) + DNNLPrimitiveHelper::gemm_s8s8_jit( a.data_ptr(), b.data_ptr(), - tmp_fp32_out.data_ptr(), (void*)(0), a.size(0), b.size(1), - a.size(1), (float*)(0), b_scales.data_ptr(), 0, - b_scales.numel()); + tmp_fp32_out.data_ptr(), nullptr, a.size(0), b.size(1), + a.size(1), nullptr, b_scales.data_ptr(), 0, b_scales.numel()); if (bias.has_value()) { - dynamic_output_scale_impl( + // Compute C=s_a * C_inter + bias + dynamic_quant_epilogue( tmp_fp32_out.data_ptr(), c.data_ptr(), - a_scales.data_ptr(), bias->data_ptr(), c.size(0), - c.size(1)); + a_scales.data_ptr(), nullptr, nullptr, nullptr, + bias->data_ptr(), c.size(0), c.size(1)); } else { - dynamic_output_scale_impl( + // Compute C=s_a * C_inter + dynamic_quant_epilogue( tmp_fp32_out.data_ptr(), c.data_ptr(), - a_scales.data_ptr(), (scalar_t*)(0), c.size(0), c.size(1)); + a_scales.data_ptr(), nullptr, nullptr, nullptr, nullptr, + c.size(0), c.size(1)); } } else { // per-tensor if (bias.has_value()) { + // Compute C=s_a * s_b * (A@B) + bias DNNLPrimitiveHelper::gemm_s8s8_jit( a.data_ptr(), b.data_ptr(), c.data_ptr(), bias->data_ptr(), a.size(0), b.size(1), a.size(1), a_scales.data_ptr(), b_scales.data_ptr(), a_scales.numel(), b_scales.numel()); } else { - DNNLPrimitiveHelper::gemm_s8s8_jit( + // Compute C=s_a * s_b * (A@B) + DNNLPrimitiveHelper::gemm_s8s8_jit( a.data_ptr(), b.data_ptr(), c.data_ptr(), - (void*)(0), a.size(0), b.size(1), a.size(1), + nullptr, a.size(0), b.size(1), a.size(1), a_scales.data_ptr(), b_scales.data_ptr(), a_scales.numel(), b_scales.numel()); } @@ -254,6 +423,127 @@ void int8_scaled_mm(torch::Tensor& c, // [M, OC], row-major }); } +void int8_scaled_mm_azp(torch::Tensor& c, // [M, OC], row-major + const torch::Tensor& a, // [M, IC], row-major + const torch::Tensor& b, // [IC, OC], column-major + const torch::Tensor& a_scales, // [1] or [M] + const torch::Tensor& b_scales, // [1] or [OC] + const torch::Tensor& azp_adj, // [OC] + const c10::optional& azp, // [1] or [M] + const c10::optional& bias // [OC] +) { + CPU_KERNEL_GUARD_IN(cutlass_scaled_mm_azp) + // Checks for conformality + TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8, + "int8_scaled_mm_azp only supports INT8 inputs.") + TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2); + TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) && + b.size(1) == c.size(1)); + TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0)); + TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1)); + + // Check for strides and alignment + TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major + TORCH_CHECK(b.stride(0) == 1); // Column-major + TORCH_CHECK(c.stride(0) % 16 == 0 && + b.stride(1) % 16 == 0); // 16 Byte Alignment + TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); + + if (bias) { + TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous()); + } + if (azp) { + TORCH_CHECK(azp->numel() == a.size(0) && azp->is_contiguous()); + } + TORCH_CHECK(azp_adj.numel() == b.size(1) && azp_adj.is_contiguous()); + + // azp & bias types + TORCH_CHECK(azp_adj.dtype() == torch::kInt32); + TORCH_CHECK(!azp || azp->dtype() == torch::kInt32); + TORCH_CHECK(!bias || bias->dtype() == c.dtype(), + "currently bias dtype must match output dtype ", c.dtype()); + + VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm_azp", [&] { + torch::Tensor tmp_fp32_out = torch::empty_like(c, ::at::ScalarType::Float); + if (a_scales.numel() != 1) { + // per-token + // Note: oneDNN doesn't support per-token activation quantization + // Compute C_inter=s_b * (A@B) + DNNLPrimitiveHelper::gemm_s8s8_jit( + a.data_ptr(), b.data_ptr(), + tmp_fp32_out.data_ptr(), nullptr, a.size(0), b.size(1), + a.size(1), nullptr, b_scales.data_ptr(), 0, b_scales.numel()); + if (bias.has_value()) { + // Compute C=s_a * C_inter - s_a * s_b * azp * azp_adj + bias + if (b_scales.numel() != 1) { + // Per-Channel + dynamic_quant_epilogue( + tmp_fp32_out.data_ptr(), c.data_ptr(), + a_scales.data_ptr(), b_scales.data_ptr(), + azp->data_ptr(), azp_adj.data_ptr(), + bias->data_ptr(), c.size(0), c.size(1)); + } else { + // Per-Tensor + dynamic_quant_epilogue( + tmp_fp32_out.data_ptr(), c.data_ptr(), + a_scales.data_ptr(), b_scales.data_ptr(), + azp->data_ptr(), azp_adj.data_ptr(), + bias->data_ptr(), c.size(0), c.size(1)); + } + } else { + // Compute C=s_a * C_inter - s_a * s_b * azp * azp_adj + if (b_scales.numel() != 1) { + // Per-Channel + dynamic_quant_epilogue( + tmp_fp32_out.data_ptr(), c.data_ptr(), + a_scales.data_ptr(), b_scales.data_ptr(), + azp->data_ptr(), azp_adj.data_ptr(), nullptr, + c.size(0), c.size(1)); + } else { + // Per-Tensor + dynamic_quant_epilogue( + tmp_fp32_out.data_ptr(), c.data_ptr(), + a_scales.data_ptr(), b_scales.data_ptr(), + azp->data_ptr(), azp_adj.data_ptr(), nullptr, + c.size(0), c.size(1)); + } + } + } else { + // per-tensor + if (bias.has_value()) { + // Compute C_inter=s_a * s_b * (A@B) + bias + DNNLPrimitiveHelper::gemm_s8s8_jit( + a.data_ptr(), b.data_ptr(), + tmp_fp32_out.data_ptr(), bias->data_ptr(), + a.size(0), b.size(1), a.size(1), a_scales.data_ptr(), + b_scales.data_ptr(), a_scales.numel(), b_scales.numel()); + } else { + // Compute C_inter=s_a * s_b * (A@B) + DNNLPrimitiveHelper::gemm_s8s8_jit( + a.data_ptr(), b.data_ptr(), + tmp_fp32_out.data_ptr(), nullptr, a.size(0), b.size(1), + a.size(1), a_scales.data_ptr(), b_scales.data_ptr(), + a_scales.numel(), b_scales.numel()); + } + + // Compute C=C_inter - s_a * s_b * azp_adj + if (b_scales.numel() != 1) { + // Per-Channel + static_quant_epilogue( + tmp_fp32_out.data_ptr(), c.data_ptr(), + *a_scales.data_ptr(), b_scales.data_ptr(), + azp_adj.data_ptr(), a.size(0), b.size(1)); + } else { + // Per-Tensor + static_quant_epilogue( + tmp_fp32_out.data_ptr(), c.data_ptr(), + *a_scales.data_ptr(), b_scales.data_ptr(), + azp_adj.data_ptr(), a.size(0), b.size(1)); + } + } + }); +} + // static-per-tensor quantization. void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] const torch::Tensor& input, // [..., hidden_size] @@ -263,15 +553,22 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] TORCH_CHECK(input.is_contiguous()); TORCH_CHECK(out.is_contiguous()); TORCH_CHECK(scale.numel() == 1); - TORCH_CHECK(!azp.has_value(), "Zero point is not supported on CPU."); + TORCH_CHECK(!azp.has_value() || azp->numel() == 1); const int hidden_size = input.size(-1); const int num_tokens = input.numel() / hidden_size; VLLM_DISPATCH_FLOATING_TYPES( input.scalar_type(), "static_scaled_int8_quant_impl", [&] { - static_scaled_int8_quant_impl( - input.data_ptr(), out.data_ptr(), - scale.data_ptr(), num_tokens, hidden_size); + if (azp.has_value()) { + static_scaled_int8_quant_impl( + input.data_ptr(), out.data_ptr(), + scale.data_ptr(), azp->data_ptr(), num_tokens, + hidden_size); + } else { + static_scaled_int8_quant_impl( + input.data_ptr(), out.data_ptr(), + scale.data_ptr(), nullptr, num_tokens, hidden_size); + } }); } @@ -284,14 +581,20 @@ void dynamic_scaled_int8_quant( CPU_KERNEL_GUARD_IN(dynamic_scaled_int8_quant) TORCH_CHECK(input.is_contiguous()); TORCH_CHECK(out.is_contiguous()); - TORCH_CHECK(!azp.has_value(), "Zero point is not supported on CPU."); int const hidden_size = input.size(-1); int const num_tokens = input.numel() / hidden_size; VLLM_DISPATCH_FLOATING_TYPES( input.scalar_type(), "dynamic_scaled_int8_quant_impl", [&] { - dynamic_scaled_int8_quant_impl( - input.data_ptr(), out.data_ptr(), - scale.data_ptr(), num_tokens, hidden_size); + if (azp.has_value()) { + dynamic_scaled_int8_quant_impl( + input.data_ptr(), out.data_ptr(), + scale.data_ptr(), azp->data_ptr(), num_tokens, + hidden_size); + } else { + dynamic_scaled_int8_quant_impl( + input.data_ptr(), out.data_ptr(), + scale.data_ptr(), nullptr, num_tokens, hidden_size); + } }); } diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index ab697e3e6aef7..03beefbc6de7d 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -11,6 +11,13 @@ void int8_scaled_mm(torch::Tensor& c, const torch::Tensor& a, const torch::Tensor& b_scales, const c10::optional& bias); +void int8_scaled_mm_azp(torch::Tensor& c, const torch::Tensor& a, + const torch::Tensor& b, const torch::Tensor& a_scales, + const torch::Tensor& b_scales, + const torch::Tensor& azp_adj, + const c10::optional& azp, + const c10::optional& bias); + TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // vLLM custom ops @@ -111,6 +118,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor b, Tensor a_scales," " Tensor b_scales, Tensor? bias) -> ()"); ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm); + // w8a8 GEMM, supporting asymmetric per-tensor or per-row/column + // quantization. + ops.def( + "cutlass_scaled_mm_azp(Tensor! out, Tensor a," + " Tensor b, Tensor a_scales," + " Tensor b_scales, Tensor azp_adj," + " Tensor? azp, Tensor? bias) -> ()"); + ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp); #endif } diff --git a/docs/source/getting_started/cpu-installation.rst b/docs/source/getting_started/cpu-installation.rst index c8947beb34942..f544325a0776c 100644 --- a/docs/source/getting_started/cpu-installation.rst +++ b/docs/source/getting_started/cpu-installation.rst @@ -59,20 +59,6 @@ Build from source $ pip install cmake>=3.26 wheel packaging ninja "setuptools-scm>=8" numpy $ pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu -- Third, build and install oneDNN library from source: - -.. code-block:: console - - $ git clone -b rls-v3.5 https://github.com/oneapi-src/oneDNN.git - $ cmake -B ./oneDNN/build -S ./oneDNN -G Ninja -DONEDNN_LIBRARY_TYPE=STATIC \ - -DONEDNN_BUILD_DOC=OFF \ - -DONEDNN_BUILD_EXAMPLES=OFF \ - -DONEDNN_BUILD_TESTS=OFF \ - -DONEDNN_BUILD_GRAPH=OFF \ - -DONEDNN_ENABLE_WORKLOAD=INFERENCE \ - -DONEDNN_ENABLE_PRIMITIVE=MATMUL - $ cmake --build ./oneDNN/build --target install --config Release - - Finally, build and install vLLM CPU backend: .. code-block:: console