[Kernel] add bfloat16 support for gptq marlin kernel (#4788)
This commit is contained in:
parent 5c342570d7
commit 99caa49106
@@ -20,6 +20,11 @@
  */

 #include "gptq_marlin.cuh"
+#include "gptq_marlin_dtypes.cuh"
+
+#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) static_assert(\
+  std::is_same<scalar_t, half>::value || std::is_same<scalar_t, nv_bfloat16>::value, \
+  "only float16 and bfloat16 is supported");

 template <typename T> inline std::string str(T x) { return std::to_string(x); }

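The STATIC_ASSERT_SCALAR_TYPE_VALID guard turns an unsupported compute dtype into a compile-time error instead of silently wrong code. A minimal sketch of the intended usage pattern, assuming it sits in gptq_marlin.cu after the macro definition (the helper name below is hypothetical, not part of the commit):

    #include <type_traits>
    #include <cuda_fp16.h>
    #include <cuda_bf16.h>

    // Hypothetical helper: any function templated on the compute dtype rejects
    // everything except half and nv_bfloat16 when it is instantiated.
    template <typename scalar_t>
    __device__ inline void check_dtype_example() {
      STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
      // ...dtype-specific work goes here...
    }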
@@ -32,7 +37,8 @@ __global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr,
                                     int4 *__restrict__ out_int4_ptr, int size_m,
                                     int size_k, int block_rows) {}

-template <const int num_bits,        // number of bits used for weights
+template <typename scalar_t,         // compute dtype, half or nv_bfloat16
+          const int num_bits,        // number of bits used for weights
           const int threads,         // number of threads in a threadblock
           const int thread_m_blocks, // number of 16x16 blocks in the m
                                      // dimension (batchsize) of the threadblock
@@ -72,31 +78,36 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,

 #else

-// Matrix fragments for tensor core instructions; their precise layout is
-// documented here:
-// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
-using FragA = Vec<half2, 4>;
-using FragB = Vec<half2, 2>;
-using FragC = Vec<float, 4>;
-using FragS = Vec<half2, 1>; // quantization scales
-
 // m16n8k16 tensor core mma instruction with fp16 inputs and fp32
 // output/accumulation.
-__device__ inline void mma(const FragA &a_frag, const FragB &frag_b,
-                           FragC &frag_c) {
+template <typename scalar_t>
+__device__ inline void mma(const typename ScalarType<scalar_t>::FragA &a_frag,
+                           const typename ScalarType<scalar_t>::FragB &frag_b,
+                           typename ScalarType<scalar_t>::FragC &frag_c) {
   const uint32_t *a = reinterpret_cast<const uint32_t *>(&a_frag);
   const uint32_t *b = reinterpret_cast<const uint32_t *>(&frag_b);
   float *c = reinterpret_cast<float *>(&frag_c);
-  asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
-               "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
-               : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
-               : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]),
-                 "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
+  if constexpr (std::is_same<scalar_t, half>::value) {
+    asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
+                 "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
+                 : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
+                 : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]),
+                   "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
+  } else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
+    asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
+                 "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
+                 : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
+                 : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]),
+                   "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
+  } else {
+    STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
+  }
 }

 // Instruction for loading a full 16x16 matrix fragment of operand A from shared
 // memory, directly in tensor core layout.
-__device__ inline void ldsm4(FragA &frag_a, const void *smem_ptr) {
+template <typename scalar_t>
+__device__ inline void ldsm4(typename ScalarType<scalar_t>::FragA &frag_a, const void *smem_ptr) {
   uint32_t *a = reinterpret_cast<uint32_t *>(&frag_a);
   uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
   asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
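Note on the new bf16 path: the PTX variant mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 only exists on compute capability 8.0 and newer, which is why the nv_bfloat16 specialization of ScalarType in gptq_marlin_dtypes.cuh (further down in this diff) guards its conversion helpers with __CUDA_ARCH__ >= 800. Accumulation stays in fp32 for both input dtypes (FragC remains Vec<float, 4>), so only the encoding of the A/B operands changes.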
@@ -129,8 +140,15 @@ __device__ inline uint32_t prmt(uint32_t a) {
 // Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
 // values. We mostly follow the strategy in the link below, with some small
 // changes:
-// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
-__device__ inline FragB dequant_4bit(int q) {
+// - FP16: https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
+// - BF16: https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
+template <typename scalar_t>
+__device__ inline typename ScalarType<scalar_t>::FragB dequant_4bit(int q) {
+  STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
+}
+
+template <>
+__device__ inline typename ScalarType<half>::FragB dequant_4bit<half>(int q) {
   const int LO = 0x000f000f;
   const int HI = 0x00f000f0;
   const int EX = 0x64006400;
@@ -142,7 +160,7 @@ __device__ inline FragB dequant_4bit(int q) {
   const int SUB = 0x64086408;
   const int MUL = 0x2c002c00;
   const int ADD = 0xd480d480;
-  FragB frag_b;
+  typename ScalarType<half>::FragB frag_b;
   frag_b[0] = __hsub2(*reinterpret_cast<half2 *>(&lo),
                       *reinterpret_cast<const half2 *>(&SUB));
   frag_b[1] = __hfma2(*reinterpret_cast<half2 *>(&hi),
@@ -151,7 +169,41 @@ __device__ inline FragB dequant_4bit(int q) {
   return frag_b;
 }

-__device__ inline FragB dequant_8bit(int q) {
+template <>
+__device__ inline typename ScalarType<nv_bfloat16>::FragB dequant_4bit<nv_bfloat16>(int q) {
+  static constexpr uint32_t MASK = 0x000f000f;
+  static constexpr uint32_t EX = 0x43004300;
+
+  // Guarantee that the `(a & b) | c` operations are LOP3s.
+
+  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
+  q >>= 4;
+  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
+
+  typename ScalarType<nv_bfloat16>::FragB frag_b;
+  static constexpr uint32_t MUL = 0x3F803F80;
+  static constexpr uint32_t ADD = 0xC308C308;
+
+  frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162 *>(&lo),
+                      *reinterpret_cast<const nv_bfloat162 *>(&MUL),
+                      *reinterpret_cast<const nv_bfloat162 *>(&ADD));
+  frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162 *>(&hi),
+                      *reinterpret_cast<const nv_bfloat162 *>(&MUL),
+                      *reinterpret_cast<const nv_bfloat162 *>(&ADD));
+  return frag_b;
+}
+
+// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or bf16
+// Reference:
+// - FP16: https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
+// - BF16: https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
+template <typename scalar_t>
+__device__ inline typename ScalarType<scalar_t>::FragB dequant_8bit(int q) {
+  STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
+}
+
+template <>
+__device__ inline typename ScalarType<half>::FragB dequant_8bit<half>(int q) {
   static constexpr uint32_t mask_for_elt_01 = 0x5250;
   static constexpr uint32_t mask_for_elt_23 = 0x5351;
   static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
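The bf16 constants in the 4-bit specialization are easy to sanity-check on the host. EX = 0x4300 is 128.0 in bf16, and at that exponent the mantissa LSB has weight exactly 1, so OR-ing a 4-bit weight n into the mantissa yields 128 + n; the fma with MUL = 0x3F80 (1.0) and ADD = 0xC308 (-136.0) then produces n - 8, the same zero-point shift the fp16 path gets from SUB = 0x6408 (1032 = 1024 + 8). A small standalone check, plain host C++ and not part of the commit:

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    // A bf16 value is the upper 16 bits of an fp32 with the same sign/exponent.
    static float bf16_bits_to_float(uint16_t bits) {
      uint32_t u = uint32_t(bits) << 16;
      float f;
      std::memcpy(&f, &u, sizeof(f));
      return f;
    }

    int main() {
      const uint16_t EX = 0x4300, MUL = 0x3F80, ADD = 0xC308;
      for (int n = 0; n < 16; ++n) {
        float x = bf16_bits_to_float(uint16_t(EX | n));  // 128 + n
        float dq = x * bf16_bits_to_float(MUL) + bf16_bits_to_float(ADD);
        std::printf("n=%2d -> %5.1f (expect %d)\n", n, dq, n - 8);
      }
    }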
@@ -161,7 +213,7 @@ __device__ inline FragB dequant_8bit(int q) {

   static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;

-  FragB frag_b;
+  typename ScalarType<half>::FragB frag_b;
   frag_b[0] = __hsub2(*reinterpret_cast<half2 *>(&lo),
                       *reinterpret_cast<const half2 *>(&I8s_TO_F16s_MAGIC_NUM));
   frag_b[1] = __hsub2(*reinterpret_cast<half2 *>(&hi),
@@ -169,34 +221,69 @@ __device__ inline FragB dequant_8bit(int q) {
   return frag_b;
 }

+template <>
+__device__ inline typename ScalarType<nv_bfloat16>::FragB dequant_8bit<nv_bfloat16>(int q) {
+  typename ScalarType<nv_bfloat16>::FragB frag_b;
+
+  float fp32_intermediates[4];
+  uint32_t *fp32_intermediates_casted = reinterpret_cast<uint32_t *>(fp32_intermediates);
+
+  static constexpr uint32_t fp32_base = 0x4B000000;
+  fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
+  fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
+  fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
+  fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
+
+  fp32_intermediates[0] -= 8388736.f;
+  fp32_intermediates[1] -= 8388736.f;
+  fp32_intermediates[2] -= 8388736.f;
+  fp32_intermediates[3] -= 8388736.f;
+
+  uint32_t *bf16_result_ptr = reinterpret_cast<uint32_t *>(&frag_b);
+  bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632);
+  bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632);
+
+  return frag_b;
+}
+
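The int8-to-bf16 path rides on fp32: fp32_base = 0x4B000000 is 8388608.0 (2^23), and __byte_perm drops one unsigned weight byte b into the low mantissa byte, so the intermediate float becomes 8388608 + b. Subtracting 8388736 (= 2^23 + 128) removes both the base and the 128 zero point, and because |b - 128| <= 128 is exactly representable in bf16, keeping only the top 16 bits of that fp32 (the final __byte_perm with 0x7632) is an exact conversion. A standalone host-side check, not part of the commit:

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    int main() {
      for (int b = 0; b < 256; b += 51) {
        // Emulate __byte_perm(q, fp32_base, ...): one weight byte lands in the
        // low mantissa byte of 2^23, giving the float 8388608 + b.
        uint32_t bits = 0x4B000000u | uint32_t(b);
        float f;
        std::memcpy(&f, &bits, sizeof(f));
        f -= 8388736.f;  // = b - 128 exactly
        std::printf("b=%3d -> %6.1f (expect %d)\n", b, f, b - 128);
      }
    }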
 // Multiply dequantized values by the corresponding quantization scale; used
 // only for grouped quantization.
-__device__ inline void scale(FragB &frag_b, FragS &frag_s, int i) {
-  half2 s = __half2half2(reinterpret_cast<__half *>(&frag_s)[i]);
+template <typename scalar_t>
+__device__ inline void scale(typename ScalarType<scalar_t>::FragB &frag_b,
+                             typename ScalarType<scalar_t>::FragS &frag_s, int i) {
+  using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
+  scalar_t2 s = ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t *>(&frag_s)[i]);
   frag_b[0] = __hmul2(frag_b[0], s);
   frag_b[1] = __hmul2(frag_b[1], s);
 }

 // Same as above, but for act_order (each K is multiplied individually)
-__device__ inline void scale4(FragB &frag_b, FragS &frag_s_1, FragS &frag_s_2,
-                              FragS &frag_s_3, FragS &frag_s_4, int i) {
-  __half2 s_val_1_2;
-  s_val_1_2.x = reinterpret_cast<__half *>(&frag_s_1)[i];
-  s_val_1_2.y = reinterpret_cast<__half *>(&frag_s_2)[i];
+template <typename scalar_t>
+__device__ inline void scale4(typename ScalarType<scalar_t>::FragB &frag_b,
+                              typename ScalarType<scalar_t>::FragS &frag_s_1,
+                              typename ScalarType<scalar_t>::FragS &frag_s_2,
+                              typename ScalarType<scalar_t>::FragS &frag_s_3,
+                              typename ScalarType<scalar_t>::FragS &frag_s_4,
+                              int i) {
+  using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
+  scalar_t2 s_val_1_2;
+  s_val_1_2.x = reinterpret_cast<scalar_t *>(&frag_s_1)[i];
+  s_val_1_2.y = reinterpret_cast<scalar_t *>(&frag_s_2)[i];

-  __half2 s_val_3_4;
-  s_val_3_4.x = reinterpret_cast<__half *>(&frag_s_3)[i];
-  s_val_3_4.y = reinterpret_cast<__half *>(&frag_s_4)[i];
+  scalar_t2 s_val_3_4;
+  s_val_3_4.x = reinterpret_cast<scalar_t *>(&frag_s_3)[i];
+  s_val_3_4.y = reinterpret_cast<scalar_t *>(&frag_s_4)[i];

   frag_b[0] = __hmul2(frag_b[0], s_val_1_2);
   frag_b[1] = __hmul2(frag_b[1], s_val_3_4);
 }

 // Given 2 floats multiply by 2 scales (halves)
-__device__ inline void scale_float(float *c, FragS &s) {
-  __half *s_ptr = reinterpret_cast<__half *>(&s);
-  c[0] = __fmul_rn(c[0], __half2float(s_ptr[0]));
-  c[1] = __fmul_rn(c[1], __half2float(s_ptr[1]));
+template <typename scalar_t>
+__device__ inline void scale_float(float *c, typename ScalarType<scalar_t>::FragS &s) {
+  scalar_t *s_ptr = reinterpret_cast<scalar_t *>(&s);
+  c[0] = __fmul_rn(c[0], ScalarType<scalar_t>::num2float(s_ptr[0]));
+  c[1] = __fmul_rn(c[1], ScalarType<scalar_t>::num2float(s_ptr[1]));
 }

 // Wait until barrier reaches `count`, then lock for current threadblock.
@@ -287,7 +374,8 @@ __global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr,
   }
 }

-template <const int num_bits,        // number of bits used for weights
+template <typename scalar_t,         // compute dtype, half or nv_bfloat16
+          const int num_bits,        // number of bits used for weights
           const int threads,         // number of threads in a threadblock
           const int thread_m_blocks, // number of 16x16 blocks in the m
                                      // dimension (batchsize) of the threadblock
@@ -323,6 +411,12 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
   // ensures good utilization of all SMs for many kinds of shape and GPU
   // configurations, while requiring as few slow global cross-threadblock
   // reductions as possible.
+  using Dtype = ScalarType<scalar_t>;
+  using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
+  using FragA = typename ScalarType<scalar_t>::FragA;
+  using FragB = typename ScalarType<scalar_t>::FragB;
+  using FragC = typename ScalarType<scalar_t>::FragC;
+  using FragS = typename ScalarType<scalar_t>::FragS;

   constexpr int pack_factor = 32 / num_bits;

@@ -691,7 +785,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
     int4 *sh_a_stage = sh_a + a_sh_stage * pipe;
 #pragma unroll
     for (int i = 0; i < thread_m_blocks; i++)
-      ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
+      ldsm4<scalar_t>(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
     int4 *sh_b_stage = sh_b + b_sh_stage * pipe;

 #pragma unroll
@@ -835,43 +929,43 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
         int b_quant = frag_b_quant[k % 2][0][j];
         int b_quant_shift = b_quant >> 8;

-        frag_b0 = dequant_4bit(b_quant);
-        frag_b1 = dequant_4bit(b_quant_shift);
+        frag_b0 = dequant_4bit<scalar_t>(b_quant);
+        frag_b1 = dequant_4bit<scalar_t>(b_quant_shift);

       } else {
         int *frag_b_quant_ptr = reinterpret_cast<int *>(frag_b_quant[k % 2]);
         int b_quant_0 = frag_b_quant_ptr[j * 2 + 0];
         int b_quant_1 = frag_b_quant_ptr[j * 2 + 1];

-        frag_b0 = dequant_8bit(b_quant_0);
-        frag_b1 = dequant_8bit(b_quant_1);
+        frag_b0 = dequant_8bit<scalar_t>(b_quant_0);
+        frag_b1 = dequant_8bit<scalar_t>(b_quant_1);
       }

       // Apply scale to frag_b0
       if constexpr (has_act_order) {
-        scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j],
+        scale4<scalar_t>(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j],
                act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 0);
       } else {
         if constexpr (group_blocks != -1) {
-          scale(frag_b0, frag_s[k % 2][j], 0);
+          scale<scalar_t>(frag_b0, frag_s[k % 2][j], 0);
         }
       }

       // Apply scale to frag_b1
       if constexpr (has_act_order) {
-        scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j],
+        scale4<scalar_t>(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j],
                act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 1);

       } else {
         if constexpr (group_blocks != -1) {
-          scale(frag_b1, frag_s[k % 2][j], 1);
+          scale<scalar_t>(frag_b1, frag_s[k % 2][j], 1);
         }
       }

 #pragma unroll
       for (int i = 0; i < thread_m_blocks; i++) {
-        mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);
-        mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);
+        mma<scalar_t>(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);
+        mma<scalar_t>(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);
       }
     }
   };
@@ -979,15 +1073,15 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
         for (int j = 0; j < 2 * 4; j++) {
           reinterpret_cast<float *>(
               &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] +=
-              __half2float(reinterpret_cast<__half *>(&c_red)[j]);
+              Dtype::num2float(reinterpret_cast<scalar_t *>(&c_red)[j]);
         }
       }
       if (!last) {
         int4 c;
 #pragma unroll
         for (int j = 0; j < 2 * 4; j++) {
-          reinterpret_cast<__half *>(&c)[j] =
-              __float2half(reinterpret_cast<float *>(
+          reinterpret_cast<scalar_t *>(&c)[j] =
+              Dtype::float2num(reinterpret_cast<float *>(
                   &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]);
         }
         C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] =
@@ -1022,7 +1116,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
     // We first reorder in shared memory to guarantee the most efficient final
     // global write patterns
     auto write = [&](int idx, float c0, float c1, FragS &s) {
-      half2 res = __halves2half2(__float2half(c0), __float2half(c1));
+      scalar_t2 res = Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1));

       // For per-column quantization we finally apply the scale here (only for
       // 4-bit)
@@ -1030,7 +1124,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
         res = __hmul2(res, s[0]);
       }

-      ((half2 *)sh)[idx] = res;
+      ((scalar_t2 *)sh)[idx] = res;
     };

     if (threadIdx.x / 32 < thread_n_blocks / 4) {
@@ -1192,14 +1286,14 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
       for (int i = 0; i < thread_m_blocks; i++) {
 #pragma unroll
         for (int j = 0; j < 4; j++) {
-          scale_float(reinterpret_cast<float *>(&frag_c[i][j][0][0]),
+          scale_float<scalar_t>(reinterpret_cast<float *>(&frag_c[i][j][0][0]),
                       frag_s[j / 2][2 * (j % 2) + 0]);
-          scale_float(reinterpret_cast<float *>(&frag_c[i][j][0][2]),
+          scale_float<scalar_t>(reinterpret_cast<float *>(&frag_c[i][j][0][2]),
                       frag_s[j / 2][2 * (j % 2) + 0]);

-          scale_float(reinterpret_cast<float *>(&frag_c[i][j][1][0]),
+          scale_float<scalar_t>(reinterpret_cast<float *>(&frag_c[i][j][1][0]),
                       frag_s[j / 2][2 * (j % 2) + 1]);
-          scale_float(reinterpret_cast<float *>(&frag_c[i][j][1][2]),
+          scale_float<scalar_t>(reinterpret_cast<float *>(&frag_c[i][j][1][2]),
                       frag_s[j / 2][2 * (j % 2) + 1]);
         }
       }
@@ -1255,10 +1349,10 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
       has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS &&        \
       num_threads == NUM_THREADS) {                                            \
     cudaFuncSetAttribute(                                                      \
-        Marlin<NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS,        \
+        Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
               THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, GROUP_BLOCKS>,      \
         cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);          \
-    Marlin<NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS,            \
+    Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS,  \
           THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, GROUP_BLOCKS>           \
         <<<blocks, NUM_THREADS, max_shared_mem, stream>>>(                     \
             A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m, prob_n, \
@@ -1462,6 +1556,7 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
   __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS)            \
   __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)

+template <typename scalar_t>
 void marlin_mm_f16i4(const void *A, const void *B, void *C, void *s,
                      void *g_idx, void *perm, void *a_tmp, int prob_m,
                      int prob_n, int prob_k, void *workspace, int num_bits,
@@ -1731,14 +1826,25 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
               " is below min_workspace_size = ", min_workspace_size);

   int dev = a.get_device();
-  gptq_marlin::marlin_mm_f16i4(
-      a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), b_scales.data_ptr(),
-      g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n,
-      size_k, workspace.data_ptr(), num_bits, has_act_order, is_k_full,
-      num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
-      thread_k, thread_n, sms, gptq_marlin::max_par);
+  if (a.scalar_type() == at::ScalarType::Half) {
+    gptq_marlin::marlin_mm_f16i4<half>(
+        a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(), b_scales.data_ptr<at::Half>(),
+        g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::Half>(), size_m, size_n,
+        size_k, workspace.data_ptr(), num_bits, has_act_order, is_k_full,
+        num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
+        thread_k, thread_n, sms, gptq_marlin::max_par);
+  } else if (a.scalar_type() == at::ScalarType::BFloat16) {
+    gptq_marlin::marlin_mm_f16i4<nv_bfloat16>(
+        a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(), c.data_ptr<at::BFloat16>(), b_scales.data_ptr<at::BFloat16>(),
+        g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(), size_m, size_n,
+        size_k, workspace.data_ptr(), num_bits, has_act_order, is_k_full,
+        num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
+        thread_k, thread_n, sms, gptq_marlin::max_par);
+  } else {
+    TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16");
+  }

   return c;
 }

 #endif
csrc/quantization/gptq_marlin/gptq_marlin_dtypes.cuh (new file, 62 lines)
@@ -0,0 +1,62 @@
+
+#ifndef _data_types_cuh
+#define _data_types_cuh
+#include "gptq_marlin.cuh"
+#include <cuda_fp16.h>
+#include <cuda_bf16.h>
+
+
+namespace gptq_marlin {
+
+template <typename scalar_t>
+class ScalarType {
+};
+
+template <>
+class ScalarType<half> {
+public:
+  using scalar_t = half;
+  using scalar_t2 = half2;
+
+  // Matrix fragments for tensor core instructions; their precise layout is
+  // documented here:
+  // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
+  using FragA = Vec<half2, 4>;
+  using FragB = Vec<half2, 2>;
+  using FragC = Vec<float, 4>;
+  using FragS = Vec<half2, 1>;
+
+  static __device__ float inline num2float(const half x) { return __half2float(x); }
+
+  static __device__ half2 inline num2num2(const half x) { return __half2half2(x); }
+
+  static __device__ half2 inline nums2num2(const half x1, const half x2) { return __halves2half2(x1, x2); }
+
+  static __host__ __device__ half inline float2num(const float x) { return __float2half(x); }
+};
+
+template <>
+class ScalarType<nv_bfloat16> {
+public:
+  using scalar_t = nv_bfloat16;
+  using scalar_t2 = nv_bfloat162;
+
+  using FragA = Vec<nv_bfloat162, 4>;
+  using FragB = Vec<nv_bfloat162, 2>;
+  using FragC = Vec<float, 4>;
+  using FragS = Vec<nv_bfloat162, 1>;
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+  static __device__ float inline num2float(const nv_bfloat16 x) { return __bfloat162float(x); }
+
+  static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) { return __bfloat162bfloat162(x); }
+
+  static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1, const nv_bfloat16 x2) { return __halves2bfloat162(x1, x2); }
+
+  static __host__ __device__ nv_bfloat16 inline float2num(const float x) { return __float2bfloat16(x); }
+#endif
+};
+
+}
+
+#endif
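How the new trait is consumed, as a simplified sketch rather than code taken from the diff: kernel helpers are written once against ScalarType<scalar_t> and its conversion helpers, so the same source compiles for half and, on sm_80+, for nv_bfloat16. The helper name below is made up for illustration.

    #include "gptq_marlin_dtypes.cuh"

    // Hypothetical device helper: scale a dequantized 2x2 fragment by a single
    // float, using only the dtype-agnostic pieces of ScalarType.
    template <typename scalar_t>
    __device__ void scale_frag_example(typename gptq_marlin::ScalarType<scalar_t>::FragB &frag,
                                       float s) {
      using Dtype = gptq_marlin::ScalarType<scalar_t>;
      using scalar_t2 = typename Dtype::scalar_t2;
      scalar_t2 s2 = Dtype::num2num2(Dtype::float2num(s));  // broadcast s into a 2-vector
      frag[0] = __hmul2(frag[0], s2);  // __hmul2 has half2 and nv_bfloat162 overloads (bf16 needs sm_80+)
      frag[1] = __hmul2(frag[1], s2);
    }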
@@ -14,6 +14,7 @@ import pytest
 import torch

 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
+from vllm.model_executor.layers.rotary_embedding import _ROPE_DICT

 from .utils import check_logprobs_close

@@ -52,7 +53,7 @@ MODELS = [
 @pytest.mark.skipif(gptq_marlin_not_supported,
                     reason="gptq_marlin is not supported on this GPU type.")
 @pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("dtype", ["half", "bfloat16"])
 @pytest.mark.parametrize("max_tokens", [32])
 @pytest.mark.parametrize("num_logprobs", [5])
 def test_models(
@@ -76,11 +77,15 @@ def test_models(
     gptq_marlin_outputs = gptq_marlin_model.generate_greedy_logprobs(
         example_prompts[:-1], max_tokens, num_logprobs)
     del gptq_marlin_model
+    _ROPE_DICT.clear()  # clear rope cache to avoid rope dtype error

     # Run gptq.
+    # The naive gptq kernel doesn't support bf16 yet.
+    # Here we always compare the fp16/bf16 gptq marlin kernel
+    # to the fp16 gptq kernel.
     gptq_model = vllm_runner(model_name=model_name,
                              revision=revision,
-                             dtype=dtype,
+                             dtype="half",
                              quantization="gptq",
                              max_model_len=MAX_MODEL_LEN,
                              tensor_parallel_size=1)
@@ -99,7 +99,7 @@ class GPTQMarlinConfig(QuantizationConfig):

     @classmethod
     def get_supported_act_dtypes(cls) -> List[torch.dtype]:
-        return [torch.half]
+        return [torch.half, torch.bfloat16]

     @classmethod
     def get_min_capability(cls) -> int:
@@ -186,9 +186,9 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
             group_size = input_size

         # Validate dtype
-        if params_dtype != torch.float16:
-            raise ValueError(
-                f"The params dtype must be float16, but got {params_dtype}")
+        if params_dtype not in [torch.float16, torch.bfloat16]:
+            raise ValueError(f"The params dtype must be float16 "
+                             f"or bfloat16, but got {params_dtype}")

         # Validate output_size_per_partition
         output_size_per_partition = sum(output_partition_sizes)