[Kernel] add bfloat16 support for gptq marlin kernel (#4788)

Jinzhen Lin 2024-05-16 21:55:29 +08:00 committed by GitHub
parent 5c342570d7
commit 99caa49106
4 changed files with 245 additions and 72 deletions


@@ -20,6 +20,11 @@
*/
#include "gptq_marlin.cuh"
#include "gptq_marlin_dtypes.cuh"
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) static_assert(\
std::is_same<scalar_t, half>::value || std::is_same<scalar_t, nv_bfloat16>::value, \
"only float16 and bfloat16 is supported");
template <typename T> inline std::string str(T x) { return std::to_string(x); }
@@ -32,7 +37,8 @@ __global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr,
int4 *__restrict__ out_int4_ptr, int size_m,
int size_k, int block_rows) {}
template <const int num_bits, // number of bits used for weights
template <typename scalar_t, // compute dtype, half or nv_bfloat16
const int num_bits, // number of bits used for weights
const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the threadblock
@@ -72,31 +78,36 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
#else
// Matrix fragments for tensor core instructions; their precise layout is
// documented here:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
using FragA = Vec<half2, 4>;
using FragB = Vec<half2, 2>;
using FragC = Vec<float, 4>;
using FragS = Vec<half2, 1>; // quantization scales
// m16n8k16 tensor core mma instruction with fp16 inputs and fp32
// output/accumulation.
__device__ inline void mma(const FragA &a_frag, const FragB &frag_b,
FragC &frag_c) {
template <typename scalar_t>
__device__ inline void mma(const typename ScalarType<scalar_t>::FragA &a_frag,
const typename ScalarType<scalar_t>::FragB &frag_b,
typename ScalarType<scalar_t>::FragC &frag_c) {
const uint32_t *a = reinterpret_cast<const uint32_t *>(&a_frag);
const uint32_t *b = reinterpret_cast<const uint32_t *>(&frag_b);
float *c = reinterpret_cast<float *>(&frag_c);
asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]),
"r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
if constexpr (std::is_same<scalar_t, half>::value) {
asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]),
"r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]),
"r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else {
STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
}
}
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout.
__device__ inline void ldsm4(FragA &frag_a, const void *smem_ptr) {
template <typename scalar_t>
__device__ inline void ldsm4(typename ScalarType<scalar_t>::FragA &frag_a, const void *smem_ptr) {
uint32_t *a = reinterpret_cast<uint32_t *>(&frag_a);
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
@@ -129,8 +140,15 @@ __device__ inline uint32_t prmt(uint32_t a) {
// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
// values. We mostly follow the strategy in the link below, with some small
// changes:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
__device__ inline FragB dequant_4bit(int q) {
// - FP16: https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
// - BF16: https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
template <typename scalar_t>
__device__ inline typename ScalarType<scalar_t>::FragB dequant_4bit(int q) {
STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
}
template <>
__device__ inline typename ScalarType<half>::FragB dequant_4bit<half>(int q) {
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
const int EX = 0x64006400;
@@ -142,7 +160,7 @@ __device__ inline FragB dequant_4bit(int q) {
const int SUB = 0x64086408;
const int MUL = 0x2c002c00;
const int ADD = 0xd480d480;
FragB frag_b;
typename ScalarType<half>::FragB frag_b;
frag_b[0] = __hsub2(*reinterpret_cast<half2 *>(&lo),
*reinterpret_cast<const half2 *>(&SUB));
frag_b[1] = __hfma2(*reinterpret_cast<half2 *>(&hi),
@@ -151,7 +169,41 @@ __device__ inline FragB dequant_4bit(int q) {
return frag_b;
}
__device__ inline FragB dequant_8bit(int q) {
template <>
__device__ inline typename ScalarType<nv_bfloat16>::FragB dequant_4bit<nv_bfloat16>(int q) {
static constexpr uint32_t MASK = 0x000f000f;
static constexpr uint32_t EX = 0x43004300;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
q >>= 4;
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
typename ScalarType<nv_bfloat16>::FragB frag_b;
static constexpr uint32_t MUL = 0x3F803F80;
static constexpr uint32_t ADD = 0xC308C308;
frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162 *>(&lo),
*reinterpret_cast<const nv_bfloat162*>(&MUL),
*reinterpret_cast<const nv_bfloat162*>(&ADD));
frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),
*reinterpret_cast<const nv_bfloat162*>(&MUL),
*reinterpret_cast<const nv_bfloat162*>(&ADD));
return frag_b;
}
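The bf16 constants above are easiest to sanity-check on the host: OR-ing a 4-bit nibble into 0x4300 produces the bf16 value 128 + q (the nibble lands in the mantissa), and the FMA with 1.0 (0x3F80) and -136.0 (0xC308) leaves q - 8, mirroring the fp16 path where 0x6400 biases by 1024 and subtracting 0x6408 (1032.0) leaves q - 8. A minimal host-only sketch of this arithmetic (my illustration, not part of the diff):

#include <cstdint>
#include <cstdio>
#include <cstring>

// Reinterpret a bf16 bit pattern as float (bf16 is the top half of an fp32).
static float bf16_bits_to_float(uint16_t bits) {
  uint32_t u = static_cast<uint32_t>(bits) << 16;
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}

int main() {
  for (int q = 0; q < 16; ++q) {
    // (q & 0xf) | 0x4300 encodes the bf16 value 128 + q (q sits in the mantissa).
    float x = bf16_bits_to_float(static_cast<uint16_t>(q | 0x4300));
    // The kernel's __hfma2: x * 1.0 (0x3F80) + (-136.0) (0xC308) == q - 8.
    float dq = x * bf16_bits_to_float(0x3F80) + bf16_bits_to_float(0xC308);
    std::printf("q=%2d -> %5.1f (expect %d)\n", q, dq, q - 8);
  }
  return 0;
}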
// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or bf16
// Reference:
// - FP16: https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
// - BF16: https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
template <typename scalar_t>
__device__ inline typename ScalarType<scalar_t>::FragB dequant_8bit(int q) {
STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
}
template <>
__device__ inline typename ScalarType<half>::FragB dequant_8bit<half>(int q) {
static constexpr uint32_t mask_for_elt_01 = 0x5250;
static constexpr uint32_t mask_for_elt_23 = 0x5351;
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
@@ -161,7 +213,7 @@ __device__ inline FragB dequant_8bit(int q) {
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
FragB frag_b;
typename ScalarType<half>::FragB frag_b;
frag_b[0] = __hsub2(*reinterpret_cast<half2 *>(&lo),
*reinterpret_cast<const half2 *>(&I8s_TO_F16s_MAGIC_NUM));
frag_b[1] = __hsub2(*reinterpret_cast<half2 *>(&hi),
@@ -169,34 +221,69 @@ __device__ inline FragB dequant_8bit(int q) {
return frag_b;
}
template <>
__device__ inline typename ScalarType<nv_bfloat16>::FragB dequant_8bit<nv_bfloat16>(int q) {
typename ScalarType<nv_bfloat16>::FragB frag_b;
float fp32_intermediates[4];
uint32_t * fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
static constexpr uint32_t fp32_base = 0x4B000000;
fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
fp32_intermediates[0] -= 8388736.f;
fp32_intermediates[1] -= 8388736.f;
fp32_intermediates[2] -= 8388736.f;
fp32_intermediates[3] -= 8388736.f;
uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(&frag_b);
bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632);
bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632);
return frag_b;
}
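The 8-bit bf16 path uses a different trick: __byte_perm drops each source byte into the low byte of fp32_base = 0x4B000000, the fp32 encoding of 2^23 = 8388608.0, so each intermediate float equals 8388608 + b; subtracting 8388736 (= 8388608 + 128) applies the zero-point of 128, and the final __byte_perm with selector 0x7632 keeps the upper 16 bits of each fp32, i.e. truncates to bf16 (exact here, since b - 128 fits in bf16's significand). A small host sketch of the fp32 step (again, my illustration rather than part of the commit):

#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  const uint32_t fp32_base = 0x4B000000u;  // fp32 bit pattern of 8388608.0f (2^23)
  for (int b = 0; b < 256; b += 51) {
    // Mimic __byte_perm(q, fp32_base, 0x765x): byte b becomes the low mantissa byte.
    uint32_t bits = fp32_base | static_cast<uint32_t>(b);
    float f;
    std::memcpy(&f, &bits, sizeof(f));   // f == 8388608.0f + b
    float dq = f - 8388736.f;            // == b - 128 (zero-point shift)
    std::printf("b=%3d -> %6.1f\n", b, dq);
  }
  return 0;
}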
// Multiply dequantized values by the corresponding quantization scale; used
// only for grouped quantization.
__device__ inline void scale(FragB &frag_b, FragS &frag_s, int i) {
half2 s = __half2half2(reinterpret_cast<__half *>(&frag_s)[i]);
template <typename scalar_t>
__device__ inline void scale(typename ScalarType<scalar_t>::FragB &frag_b,
typename ScalarType<scalar_t>::FragS &frag_s, int i) {
using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
scalar_t2 s = ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t *>(&frag_s)[i]);
frag_b[0] = __hmul2(frag_b[0], s);
frag_b[1] = __hmul2(frag_b[1], s);
}
// Same as above, but for act_order (each K is multiplied individually)
__device__ inline void scale4(FragB &frag_b, FragS &frag_s_1, FragS &frag_s_2,
FragS &frag_s_3, FragS &frag_s_4, int i) {
__half2 s_val_1_2;
s_val_1_2.x = reinterpret_cast<__half *>(&frag_s_1)[i];
s_val_1_2.y = reinterpret_cast<__half *>(&frag_s_2)[i];
template <typename scalar_t>
__device__ inline void scale4(typename ScalarType<scalar_t>::FragB &frag_b,
typename ScalarType<scalar_t>::FragS &frag_s_1,
typename ScalarType<scalar_t>::FragS &frag_s_2,
typename ScalarType<scalar_t>::FragS &frag_s_3,
typename ScalarType<scalar_t>::FragS &frag_s_4,
int i) {
using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
scalar_t2 s_val_1_2;
s_val_1_2.x = reinterpret_cast<scalar_t *>(&frag_s_1)[i];
s_val_1_2.y = reinterpret_cast<scalar_t *>(&frag_s_2)[i];
__half2 s_val_3_4;
s_val_3_4.x = reinterpret_cast<__half *>(&frag_s_3)[i];
s_val_3_4.y = reinterpret_cast<__half *>(&frag_s_4)[i];
scalar_t2 s_val_3_4;
s_val_3_4.x = reinterpret_cast<scalar_t *>(&frag_s_3)[i];
s_val_3_4.y = reinterpret_cast<scalar_t *>(&frag_s_4)[i];
frag_b[0] = __hmul2(frag_b[0], s_val_1_2);
frag_b[1] = __hmul2(frag_b[1], s_val_3_4);
}
// Given 2 floats multiply by 2 scales (halves)
__device__ inline void scale_float(float *c, FragS &s) {
__half *s_ptr = reinterpret_cast<__half *>(&s);
c[0] = __fmul_rn(c[0], __half2float(s_ptr[0]));
c[1] = __fmul_rn(c[1], __half2float(s_ptr[1]));
template <typename scalar_t>
__device__ inline void scale_float(float *c, typename ScalarType<scalar_t>::FragS &s) {
scalar_t *s_ptr = reinterpret_cast<scalar_t *>(&s);
c[0] = __fmul_rn(c[0], ScalarType<scalar_t>::num2float(s_ptr[0]));
c[1] = __fmul_rn(c[1], ScalarType<scalar_t>::num2float(s_ptr[1]));
}
// Wait until barrier reaches `count`, then lock for current threadblock.
@@ -287,7 +374,8 @@ __global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr,
}
}
template <const int num_bits, // number of bits used for weights
template <typename scalar_t, // compute dtype, half or nv_bfloat16
const int num_bits, // number of bits used for weights
const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the threadblock
@@ -323,6 +411,12 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
// ensures good utilization of all SMs for many kinds of shape and GPU
// configurations, while requiring as few slow global cross-threadblock
// reductions as possible.
using Dtype = ScalarType<scalar_t>;
using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
using FragA = typename ScalarType<scalar_t>::FragA;
using FragB = typename ScalarType<scalar_t>::FragB;
using FragC = typename ScalarType<scalar_t>::FragC;
using FragS = typename ScalarType<scalar_t>::FragS;
constexpr int pack_factor = 32 / num_bits;
@@ -691,7 +785,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
int4 *sh_a_stage = sh_a + a_sh_stage * pipe;
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++)
ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
ldsm4<scalar_t>(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
int4 *sh_b_stage = sh_b + b_sh_stage * pipe;
#pragma unroll
@@ -835,43 +929,43 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
int b_quant = frag_b_quant[k % 2][0][j];
int b_quant_shift = b_quant >> 8;
frag_b0 = dequant_4bit(b_quant);
frag_b1 = dequant_4bit(b_quant_shift);
frag_b0 = dequant_4bit<scalar_t>(b_quant);
frag_b1 = dequant_4bit<scalar_t>(b_quant_shift);
} else {
int *frag_b_quant_ptr = reinterpret_cast<int *>(frag_b_quant[k % 2]);
int b_quant_0 = frag_b_quant_ptr[j * 2 + 0];
int b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
frag_b0 = dequant_8bit(b_quant_0);
frag_b1 = dequant_8bit(b_quant_1);
frag_b0 = dequant_8bit<scalar_t>(b_quant_0);
frag_b1 = dequant_8bit<scalar_t>(b_quant_1);
}
// Apply scale to frag_b0
if constexpr (has_act_order) {
scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j],
scale4<scalar_t>(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j],
act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 0);
} else {
if constexpr (group_blocks != -1) {
scale(frag_b0, frag_s[k % 2][j], 0);
scale<scalar_t>(frag_b0, frag_s[k % 2][j], 0);
}
}
// Apply scale to frag_b1
if constexpr (has_act_order) {
scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j],
scale4<scalar_t>(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j],
act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 1);
} else {
if constexpr (group_blocks != -1) {
scale(frag_b1, frag_s[k % 2][j], 1);
scale<scalar_t>(frag_b1, frag_s[k % 2][j], 1);
}
}
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);
mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);
mma<scalar_t>(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);
mma<scalar_t>(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);
}
}
};
@@ -979,15 +1073,15 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
for (int j = 0; j < 2 * 4; j++) {
reinterpret_cast<float *>(
&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] +=
__half2float(reinterpret_cast<__half *>(&c_red)[j]);
Dtype::num2float(reinterpret_cast<scalar_t *>(&c_red)[j]);
}
}
if (!last) {
int4 c;
#pragma unroll
for (int j = 0; j < 2 * 4; j++) {
reinterpret_cast<__half *>(&c)[j] =
__float2half(reinterpret_cast<float *>(
reinterpret_cast<scalar_t *>(&c)[j] =
Dtype::float2num(reinterpret_cast<float *>(
&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]);
}
C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] =
@@ -1022,7 +1116,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
// We first reorder in shared memory to guarantee the most efficient final
// global write patterns
auto write = [&](int idx, float c0, float c1, FragS &s) {
half2 res = __halves2half2(__float2half(c0), __float2half(c1));
scalar_t2 res = Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1));
// For per-column quantization we finally apply the scale here (only for
// 4-bit)
@@ -1030,7 +1124,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
res = __hmul2(res, s[0]);
}
((half2 *)sh)[idx] = res;
((scalar_t2 *)sh)[idx] = res;
};
if (threadIdx.x / 32 < thread_n_blocks / 4) {
@@ -1192,14 +1286,14 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
for (int i = 0; i < thread_m_blocks; i++) {
#pragma unroll
for (int j = 0; j < 4; j++) {
scale_float(reinterpret_cast<float *>(&frag_c[i][j][0][0]),
scale_float<scalar_t>(reinterpret_cast<float *>(&frag_c[i][j][0][0]),
frag_s[j / 2][2 * (j % 2) + 0]);
scale_float(reinterpret_cast<float *>(&frag_c[i][j][0][2]),
scale_float<scalar_t>(reinterpret_cast<float *>(&frag_c[i][j][0][2]),
frag_s[j / 2][2 * (j % 2) + 0]);
scale_float(reinterpret_cast<float *>(&frag_c[i][j][1][0]),
scale_float<scalar_t>(reinterpret_cast<float *>(&frag_c[i][j][1][0]),
frag_s[j / 2][2 * (j % 2) + 1]);
scale_float(reinterpret_cast<float *>(&frag_c[i][j][1][2]),
scale_float<scalar_t>(reinterpret_cast<float *>(&frag_c[i][j][1][2]),
frag_s[j / 2][2 * (j % 2) + 1]);
}
}
@@ -1255,10 +1349,10 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \
num_threads == NUM_THREADS) { \
cudaFuncSetAttribute( \
Marlin<NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, GROUP_BLOCKS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
Marlin<NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, GROUP_BLOCKS> \
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m, prob_n, \
@@ -1462,6 +1556,7 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)
template <typename scalar_t>
void marlin_mm_f16i4(const void *A, const void *B, void *C, void *s,
void *g_idx, void *perm, void *a_tmp, int prob_m,
int prob_n, int prob_k, void *workspace, int num_bits,
@@ -1731,14 +1826,25 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
" is below min_workspace_size = ", min_workspace_size);
int dev = a.get_device();
gptq_marlin::marlin_mm_f16i4(
a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), b_scales.data_ptr(),
g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n,
size_k, workspace.data_ptr(), num_bits, has_act_order, is_k_full,
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
thread_k, thread_n, sms, gptq_marlin::max_par);
if (a.scalar_type() == at::ScalarType::Half) {
gptq_marlin::marlin_mm_f16i4<half>(
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(), b_scales.data_ptr<at::Half>(),
g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::Half>(), size_m, size_n,
size_k, workspace.data_ptr(), num_bits, has_act_order, is_k_full,
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
thread_k, thread_n, sms, gptq_marlin::max_par);
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
gptq_marlin::marlin_mm_f16i4<nv_bfloat16>(
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(), c.data_ptr<at::BFloat16>(), b_scales.data_ptr<at::BFloat16>(),
g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(), size_m, size_n,
size_k, workspace.data_ptr(), num_bits, has_act_order, is_k_full,
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
thread_k, thread_n, sms, gptq_marlin::max_par);
} else {
TORCH_CHECK(false, "gptq_marlin_gemm only supports bfloat16 and float16");
}
return c;
}
#endif
#endif


@@ -0,0 +1,62 @@
#ifndef _data_types_cuh
#define _data_types_cuh
#include "gptq_marlin.cuh"
#include <cuda_fp16.h>
#include <cuda_bf16.h>
namespace gptq_marlin {
template <typename scalar_t>
class ScalarType {
};
template <>
class ScalarType<half> {
public:
using scalar_t = half;
using scalar_t2 = half2;
// Matrix fragments for tensor core instructions; their precise layout is
// documented here:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
using FragA = Vec<half2, 4>;
using FragB = Vec<half2, 2>;
using FragC = Vec<float, 4>;
using FragS = Vec<half2, 1>;
static __device__ float inline num2float(const half x) { return __half2float(x); }
static __device__ half2 inline num2num2(const half x) { return __half2half2(x); }
static __device__ half2 inline nums2num2(const half x1, const half x2) { return __halves2half2(x1, x2); }
static __host__ __device__ half inline float2num(const float x) { return __float2half(x); }
};
template <>
class ScalarType<nv_bfloat16> {
public:
using scalar_t = nv_bfloat16;
using scalar_t2 = nv_bfloat162;
using FragA = Vec<nv_bfloat162, 4>;
using FragB = Vec<nv_bfloat162, 2>;
using FragC = Vec<float, 4>;
using FragS = Vec<nv_bfloat162, 1>;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
static __device__ float inline num2float(const nv_bfloat16 x) { return __bfloat162float(x); }
static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) { return __bfloat162bfloat162(x); }
static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1, const nv_bfloat16 x2) { return __halves2bfloat162(x1, x2); }
static __host__ __device__ nv_bfloat16 inline float2num(const float x) { return __float2bfloat16(x); }
#endif
};
}
#endif
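The traits above are what let the kernel template over the compute dtype without touching the math. A minimal usage sketch built on this header (the helper name is hypothetical, not from this commit): broadcasting one float into a packed pair of the compute dtype works identically for half (half2) and nv_bfloat16 (nv_bfloat162).

#include "gptq_marlin_dtypes.cuh"

namespace gptq_marlin {

// Broadcast one float into a packed pair of the compute dtype by going
// through the ScalarType traits (float2num, then num2num2).
template <typename scalar_t>
__device__ inline typename ScalarType<scalar_t>::scalar_t2
broadcast_from_float(float x) {
  using Dtype = ScalarType<scalar_t>;
  return Dtype::num2num2(Dtype::float2num(x));
}

}  // namespace gptq_marlin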


@@ -14,6 +14,7 @@ import pytest
import torch
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.model_executor.layers.rotary_embedding import _ROPE_DICT
from .utils import check_logprobs_close
@@ -52,7 +53,7 @@ MODELS = [
@pytest.mark.skipif(gptq_marlin_not_supported,
reason="gptq_marlin is not supported on this GPU type.")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("dtype", ["half", "bfloat16"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
@@ -76,11 +77,15 @@ def test_models(
gptq_marlin_outputs = gptq_marlin_model.generate_greedy_logprobs(
example_prompts[:-1], max_tokens, num_logprobs)
del gptq_marlin_model
_ROPE_DICT.clear() # clear rope cache to avoid rope dtype error
# Run gptq.
# The naive gptq kernel doesn't support bf16 yet.
# Here we always compare the fp16/bf16 gptq marlin kernel
# to the fp16 gptq kernel.
gptq_model = vllm_runner(model_name=model_name,
revision=revision,
dtype=dtype,
dtype="half",
quantization="gptq",
max_model_len=MAX_MODEL_LEN,
tensor_parallel_size=1)


@@ -99,7 +99,7 @@ class GPTQMarlinConfig(QuantizationConfig):
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half]
return [torch.half, torch.bfloat16]
@classmethod
def get_min_capability(cls) -> int:
@@ -186,9 +186,9 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
group_size = input_size
# Validate dtype
if params_dtype != torch.float16:
raise ValueError(
f"The params dtype must be float16, but got {params_dtype}")
if params_dtype not in [torch.float16, torch.bfloat16]:
raise ValueError(f"The params dtype must be float16 "
f"or bfloat16, but got {params_dtype}")
# Validate output_size_per_partition
output_size_per_partition = sum(output_partition_sizes)