[Kernel] Zero point support in fused MarlinMoE kernel + AWQ Fused MoE (#8973)

Co-authored-by: Dipika <dipikasikka1@gmail.com>
Co-authored-by: Dipika Sikka <ds3822@columbia.edu>
This commit is contained in:
ElizaWszola 2024-10-04 20:34:44 +02:00 committed by GitHub
parent 0dcc8cbe5a
commit 05d686432f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 969 additions and 223 deletions

View File

@ -433,6 +433,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu" "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu"
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h" "csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h"
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu" "csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu"
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h"
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu"
"csrc/moe/marlin_moe_ops.cu") "csrc/moe/marlin_moe_ops.cu")
set_gencode_flags_for_srcs( set_gencode_flags_for_srcs(

View File

@ -38,6 +38,7 @@ using FragA = Vec<half2, 4>;
using FragB = Vec<half2, 2>; using FragB = Vec<half2, 2>;
using FragC = Vec<float, 4>; using FragC = Vec<float, 4>;
using FragS = Vec<half2, 1>; // quantization scales using FragS = Vec<half2, 1>; // quantization scales
using FragZP = Vec<half2, 4>;
// Predicated asynchronous global->shared copy; used for inputs A where we apply // Predicated asynchronous global->shared copy; used for inputs A where we apply
// predication to handle batchsizes that are not multiples of 16. // predication to handle batchsizes that are not multiples of 16.
@ -175,6 +176,46 @@ __device__ inline FragB dequant<vllm::kU8B128.id()>(int q) {
return frag_b; return frag_b;
} }
template <>
__device__ inline FragB dequant<vllm::kU4.id()>(int q) {
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
const int EX = 0x64006400;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
const int SUB = 0x64006400;
const int MUL = 0x2c002c00;
const int ADD = 0xd400d400;
FragB frag_b;
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
*reinterpret_cast<const half2*>(&SUB));
frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
*reinterpret_cast<const half2*>(&MUL),
*reinterpret_cast<const half2*>(&ADD));
return frag_b;
}
template <>
__device__ inline FragB dequant<vllm::kU8.id()>(int q) {
static constexpr uint32_t mask_for_elt_01 = 0x5250;
static constexpr uint32_t mask_for_elt_23 = 0x5351;
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400;
FragB frag_b;
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
return frag_b;
}
// Multiply dequantized values by the corresponding quantization scale; used // Multiply dequantized values by the corresponding quantization scale; used
// only for grouped quantization. // only for grouped quantization.
__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { __device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
@ -183,11 +224,10 @@ __device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
frag_b[1] = __hmul2(frag_b[1], s); frag_b[1] = __hmul2(frag_b[1], s);
} }
// Given 2 floats multiply by 2 scales (halves) __device__ inline void sub_zp(FragB& frag_b, half2& frag_zp, int i) {
__device__ inline void scale_float(float* c, FragS& s) { half2 zp = __half2half2(reinterpret_cast<__half*>(&frag_zp)[i]);
__half* s_ptr = reinterpret_cast<__half*>(&s); frag_b[0] = __hsub2(frag_b[0], zp);
c[0] = __fmul_rn(c[0], __half2float(s_ptr[0])); frag_b[1] = __hsub2(frag_b[1], zp);
c[1] = __fmul_rn(c[1], __half2float(s_ptr[1]));
} }
// Same as above, but for act_order (each K is multiplied individually) // Same as above, but for act_order (each K is multiplied individually)
@ -205,6 +245,13 @@ __device__ inline void scale4(FragB& frag_b, FragS& frag_s_1, FragS& frag_s_2,
frag_b[1] = __hmul2(frag_b[1], s_val_3_4); frag_b[1] = __hmul2(frag_b[1], s_val_3_4);
} }
// Given 2 floats multiply by 2 scales (halves)
__device__ inline void scale_float(float* c, FragS& s) {
__half* s_ptr = reinterpret_cast<__half*>(&s);
c[0] = __fmul_rn(c[0], __half2float(s_ptr[0]));
c[1] = __fmul_rn(c[1], __half2float(s_ptr[1]));
}
// Wait until barrier reaches `count`, then lock for current threadblock. // Wait until barrier reaches `count`, then lock for current threadblock.
__device__ inline void barrier_acquire(int* lock, int count) { __device__ inline void barrier_acquire(int* lock, int count) {
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
@ -248,10 +295,11 @@ template <const vllm::ScalarTypeId w_type_id, // weight ScalarType id
const int stages, // number of stages for the async global->shared const int stages, // number of stages for the async global->shared
// fetch pipeline // fetch pipeline
const bool has_act_order, // whether act_order is enabled const bool has_act_order, // whether act_order is enabled
const bool has_zp, // whether zero-points are enabled
const int group_blocks = -1 // number of consecutive 16x16 blocks const int group_blocks = -1 // number of consecutive 16x16 blocks
// with a separate quantization scale // with a separate quantization scale
> >
__device__ inline void MarlinMoESingle( __device__ void MarlinMoESingle(
const int4* __restrict__ A, // fp16 input matrix of shape mxk const int4* __restrict__ A, // fp16 input matrix of shape mxk
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
int4* __restrict__ C, // fp16 output buffer of shape mxn int4* __restrict__ C, // fp16 output buffer of shape mxn
@ -259,6 +307,8 @@ __device__ inline void MarlinMoESingle(
const float* __restrict__ topk_weights, // float topk weights const float* __restrict__ topk_weights, // float topk weights
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
// (k/groupsize)xn // (k/groupsize)xn
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
// (k/groupsize)x(n/pack_factor)
const int* __restrict__ g_idx, // int32 group indices of shape k const int* __restrict__ g_idx, // int32 group indices of shape k
const int* __restrict__ expert_offsets, const int* __restrict__ expert_offsets,
int num_groups, // number of scale groups per output channel int num_groups, // number of scale groups per output channel
@ -400,8 +450,12 @@ __device__ inline void MarlinMoESingle(
int tb_n_warps = thread_n_blocks / 4; int tb_n_warps = thread_n_blocks / 4;
int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps;
constexpr int sorted_sh_stride = threads; // Zero-points sizes/strides
constexpr int sorted_gl_stride = threads; int zp_gl_stride = (prob_n / pack_factor) / 4;
constexpr int zp_sh_stride = ((16 * thread_n_blocks) / pack_factor) / 4;
constexpr int zp_tb_groups = s_tb_groups;
constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0;
int zp_gl_rd_delta = zp_gl_stride;
// Global A read index of current thread. // Global A read index of current thread.
int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
@ -442,6 +496,19 @@ __device__ inline void MarlinMoESingle(
int s_sh_wr = threadIdx.x; int s_sh_wr = threadIdx.x;
bool s_sh_wr_pred = threadIdx.x < s_sh_stride; bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
// Zero-points
int zp_gl_rd;
if constexpr (has_zp) {
if constexpr (group_blocks == -1) {
zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
} else {
zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
zp_sh_stride * slice_col + threadIdx.x;
}
}
int zp_sh_wr = threadIdx.x;
bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride;
// We use a different scale layout for grouped and column-wise quantization as // We use a different scale layout for grouped and column-wise quantization as
// we scale a `half2` tile in column-major layout in the former and in // we scale a `half2` tile in column-major layout in the former and in
// row-major in the latter case. // row-major in the latter case.
@ -453,23 +520,29 @@ __device__ inline void MarlinMoESingle(
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) % 4; (threadIdx.x % 32) % 4;
// Zero-points have the same read layout as the scales
// (without column-wise case)
constexpr int num_col_threads = 8;
constexpr int num_row_threads = 4;
constexpr int num_ints_per_thread = 8 / pack_factor;
int zp_sh_rd;
if constexpr (has_zp) {
zp_sh_rd = num_ints_per_thread * num_col_threads *
((threadIdx.x / 32) % (thread_n_blocks / 4)) +
num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads);
}
int sh_first_group_id = -1; int sh_first_group_id = -1;
int sh_num_groups = -1; int sh_num_groups = -1;
constexpr int sh_max_num_groups = 32; constexpr int sh_max_num_groups = 32;
int shs_size;
if constexpr (has_act_order)
shs_size = sh_max_num_groups * s_sh_stride + threads;
else
shs_size = group_blocks > 0 ? stages * s_sh_stage : threads;
extern __shared__ int4 sh[]; extern __shared__ int4 sh[];
// Shared memory storage for global fetch pipelines. // Shared memory storage for global fetch pipelines.
int4* sh_a = sh; int4* sh_a = sh;
int4* sh_b = sh_a + (stages * a_sh_stage); int4* sh_b = sh_a + (stages * a_sh_stage);
int4* sh_g_idx = sh_b + (stages * b_sh_stage); int4* sh_g_idx = sh_b + (stages * b_sh_stage);
int4* sh_s = sh_g_idx + (stages * g_idx_stage); int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
int* sh_sorted = (int*)(sh_s + shs_size); int4* sh_s = sh_zp + (stages * zp_sh_stage);
// Precompute which thread should not read memory in which iterations; this is // Precompute which thread should not read memory in which iterations; this is
// needed if there are more threads than required for a certain tilesize or // needed if there are more threads than required for a certain tilesize or
@ -527,6 +600,8 @@ __device__ inline void MarlinMoESingle(
FragC frag_c[thread_m_blocks][4][2]; FragC frag_c[thread_m_blocks][4][2];
FragS frag_s[2][4]; // No act-order FragS frag_s[2][4]; // No act-order
FragS act_frag_s[2][4][4]; // For act-order FragS act_frag_s[2][4][4]; // For act-order
int frag_qzp[2][num_ints_per_thread]; // Zero-points
FragZP frag_zp; // Zero-points in fp16
// Zero accumulators. // Zero accumulators.
auto zero_accums = [&]() { auto zero_accums = [&]() {
@ -633,6 +708,28 @@ __device__ inline void MarlinMoESingle(
} }
} }
} }
if constexpr (has_zp && group_blocks != -1) {
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
if constexpr (group_blocks >= thread_k_blocks) {
// Only fetch zero-points if this tile starts a new group
if (pipe % (group_blocks / thread_k_blocks) == 0) {
if (zp_sh_wr_pred) {
cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]);
}
zp_gl_rd += zp_gl_rd_delta;
}
} else {
for (int i = 0; i < zp_tb_groups; i++) {
if (zp_sh_wr_pred) {
cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr],
&zp_ptr[zp_gl_rd]);
}
zp_gl_rd += zp_gl_rd_delta;
}
}
}
} }
} }
// Insert a fence even when we are winding down the pipeline to ensure that // Insert a fence even when we are winding down the pipeline to ensure that
@ -640,15 +737,9 @@ __device__ inline void MarlinMoESingle(
cp_async_fence(); cp_async_fence();
}; };
// TODO we are currently hitting illegal memory accesses when fetching auto fetch_zp_to_shared = [&]() {
// sorted_ids to shared data: fix this if (zp_sh_wr_pred) {
auto fetch_sorted_ids_to_shared = [&]() { cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]);
const int mpt = ceildiv(prob_m, threads);
for (int i = 0; i < mpt; i++) {
if ((i * sorted_gl_stride) + threadIdx.x < prob_m) {
sh_sorted[(i * sorted_sh_stride) + threadIdx.x] =
sorted_ids[(i * sorted_gl_stride) + threadIdx.x];
}
} }
}; };
@ -799,8 +890,83 @@ __device__ inline void MarlinMoESingle(
} }
}; };
auto fetch_zp_to_registers = [&](int k, int full_pipe) {
// This code does not handle group_blocks == 0,
// which signifies act_order.
// has_zp implies AWQ, which doesn't have act_order,
static_assert(!has_zp || group_blocks != 0);
if constexpr (has_zp) {
int pipe = full_pipe % stages;
if constexpr (group_blocks == -1) {
for (int i = 0; i < num_ints_per_thread; i++) {
frag_qzp[k % 2][i] = (reinterpret_cast<int*>(sh_zp))[zp_sh_rd + i];
}
} else if constexpr (group_blocks >= thread_k_blocks) {
int4* sh_zp_stage =
sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks)));
for (int i = 0; i < num_ints_per_thread; i++) {
frag_qzp[k % 2][i] =
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
}
} else {
int warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4;
int warp_row = warp_id / n_warps;
int cur_k = warp_row * 16;
cur_k += k_iter_size * (k % b_sh_wr_iters);
int k_blocks = cur_k / 16;
int cur_group_id = 0;
// Suppress bogus and persistent divide-by-zero warning
#pragma nv_diagnostic push
#pragma nv_diag_suppress divide_by_zero
cur_group_id = k_blocks / group_blocks;
#pragma nv_diagnostic pop
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
sh_zp_stage += cur_group_id * zp_sh_stride;
for (int i = 0; i < num_ints_per_thread; i++) {
frag_qzp[k % 2][i] =
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
}
}
}
};
// Execute the actual tensor core matmul of a sub-tile. // Execute the actual tensor core matmul of a sub-tile.
auto matmul = [&](int k) { auto matmul = [&](int k) {
if constexpr (has_zp) {
FragB frag_zp_0;
FragB frag_zp_1;
int zp_quant_0, zp_quant_1;
if constexpr (w_type.size_bits() == 4) {
zp_quant_0 = frag_qzp[k % 2][0];
zp_quant_1 = zp_quant_0 >> 8;
} else {
static_assert(w_type.size_bits() == 8);
zp_quant_0 = frag_qzp[k % 2][0];
zp_quant_1 = frag_qzp[k % 2][1];
}
frag_zp_0 = dequant<w_type_id>(zp_quant_0);
frag_zp_1 = dequant<w_type_id>(zp_quant_1);
frag_zp[0] = frag_zp_0[0];
frag_zp[1] = frag_zp_0[1];
frag_zp[2] = frag_zp_1[0];
frag_zp[3] = frag_zp_1[1];
}
// We have the m dimension as the inner loop in order to encourage overlapping // We have the m dimension as the inner loop in order to encourage overlapping
// dequantization and matmul operations. // dequantization and matmul operations.
#pragma unroll #pragma unroll
@ -818,6 +984,10 @@ __device__ inline void MarlinMoESingle(
FragB frag_b0 = dequant<w_type_id>(b_quant_0); FragB frag_b0 = dequant<w_type_id>(b_quant_0);
FragB frag_b1 = dequant<w_type_id>(b_quant_1); FragB frag_b1 = dequant<w_type_id>(b_quant_1);
// Apply zero-point to frag_b0
if constexpr (has_zp) {
sub_zp(frag_b0, frag_zp[j], 0);
}
// Apply scale to frag_b0 // Apply scale to frag_b0
if constexpr (has_act_order) { if constexpr (has_act_order) {
@ -829,6 +999,11 @@ __device__ inline void MarlinMoESingle(
} }
} }
// Apply zero-point to frag_b1
if constexpr (has_zp) {
sub_zp(frag_b1, frag_zp[j], 1);
}
// Apply scale to frag_b1 // Apply scale to frag_b1
if constexpr (has_act_order) { if constexpr (has_act_order) {
scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j],
@ -1062,9 +1237,6 @@ __device__ inline void MarlinMoESingle(
// Start global fetch and register load pipelines. // Start global fetch and register load pipelines.
auto start_pipes = [&]() { auto start_pipes = [&]() {
// TODO re-enable after fixing this function
// fetch_sorted_ids_to_shared();
// __syncthreads();
#pragma unroll #pragma unroll
for (int i = 0; i < stages - 1; i++) { for (int i = 0; i < stages - 1; i++) {
@ -1075,6 +1247,12 @@ __device__ inline void MarlinMoESingle(
} }
fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]);
} }
if constexpr (has_zp && group_blocks == -1) {
if (i == 0) {
fetch_zp_to_shared();
}
}
fetch_to_shared(i, i, i < slice_iters); fetch_to_shared(i, i, i < slice_iters);
} }
@ -1083,6 +1261,7 @@ __device__ inline void MarlinMoESingle(
init_same_group(0); init_same_group(0);
fetch_to_registers(0, 0); fetch_to_registers(0, 0);
fetch_scales_to_registers(0, 0); fetch_scales_to_registers(0, 0);
fetch_zp_to_registers(0, 0);
a_gl_rd += a_gl_rd_delta_o * (stages - 1); a_gl_rd += a_gl_rd_delta_o * (stages - 1);
slice_k_start_shared_fetch += tb_k * (stages - 1); slice_k_start_shared_fetch += tb_k * (stages - 1);
}; };
@ -1102,6 +1281,7 @@ __device__ inline void MarlinMoESingle(
for (int k = 0; k < b_sh_wr_iters; k++) { for (int k = 0; k < b_sh_wr_iters; k++) {
fetch_to_registers(k + 1, pipe % stages); fetch_to_registers(k + 1, pipe % stages);
fetch_scales_to_registers(k + 1, pipe); fetch_scales_to_registers(k + 1, pipe);
fetch_zp_to_registers(k + 1, pipe);
if (k == b_sh_wr_iters - 2) { if (k == b_sh_wr_iters - 2) {
fetch_to_shared((pipe + stages - 1) % stages, pipe, fetch_to_shared((pipe + stages - 1) % stages, pipe,
slice_iters >= stages); slice_iters >= stages);
@ -1236,7 +1416,9 @@ __device__ inline void MarlinMoESingle(
} else { } else {
s_gl_rd = s_sh_stride * slice_col + threadIdx.x; s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
} }
start_pipes(); start_pipes();
} }
} }
@ -1250,6 +1432,7 @@ template <const vllm::ScalarTypeId w_type_id, // weight ScalarType id
const int stages, // number of stages for the async global->shared const int stages, // number of stages for the async global->shared
// fetch pipeline // fetch pipeline
const bool has_act_order, // whether act_order is enabled const bool has_act_order, // whether act_order is enabled
const bool has_zp, // whether zero-points are enabled
const int group_blocks = -1 // number of consecutive 16x16 blocks const int group_blocks = -1 // number of consecutive 16x16 blocks
// with a separate quantization scale // with a separate quantization scale
> >
@ -1261,6 +1444,8 @@ __global__ void MarlinMoE(
const float* __restrict__ topk_weights, // float topk weights const float* __restrict__ topk_weights, // float topk weights
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
// (k/groupsize)xn // (k/groupsize)xn
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
// (k/groupsize)x(n/pack_factor)
const int* __restrict__ g_idx, // int32 group indices of shape k const int* __restrict__ g_idx, // int32 group indices of shape k
const int* __restrict__ expert_offsets, const int* __restrict__ expert_offsets,
int num_groups, // number of scale groups per output channel int num_groups, // number of scale groups per output channel
@ -1309,29 +1494,29 @@ __global__ void MarlinMoE(
if (max_block == 1) { if (max_block == 1) {
MarlinMoESingle<w_type_id, threads, 1, thread_n_blocks, thread_k_blocks, MarlinMoESingle<w_type_id, threads, 1, thread_n_blocks, thread_k_blocks,
stages, has_act_order, group_blocks>( stages, has_act_order, has_zp, group_blocks>(
A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx,
expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
current_m_block); current_m_block);
} else if (max_block == 2) { } else if (max_block == 2) {
MarlinMoESingle<w_type_id, threads, 2, thread_n_blocks, thread_k_blocks, MarlinMoESingle<w_type_id, threads, 2, thread_n_blocks, thread_k_blocks,
stages, has_act_order, group_blocks>( stages, has_act_order, has_zp, group_blocks>(
A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx,
expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
current_m_block); current_m_block);
} else if (max_block == 3) { } else if (max_block == 3) {
MarlinMoESingle<w_type_id, threads, 3, thread_n_blocks, thread_k_blocks, MarlinMoESingle<w_type_id, threads, 3, thread_n_blocks, thread_k_blocks,
stages, has_act_order, group_blocks>( stages, has_act_order, has_zp, group_blocks>(
A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx,
expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
current_m_block); current_m_block);
} else { } else {
MarlinMoESingle<w_type_id, threads, 4, thread_n_blocks, thread_k_blocks, MarlinMoESingle<w_type_id, threads, 4, thread_n_blocks, thread_k_blocks,
stages, has_act_order, group_blocks>( stages, has_act_order, has_zp, group_blocks>(
A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx,
expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
current_m_block); current_m_block);
@ -1347,6 +1532,7 @@ template <const vllm::ScalarTypeId w_type_id, // weight ScalarType id
const int stages, // number of stages for the async global->shared const int stages, // number of stages for the async global->shared
// fetch pipeline // fetch pipeline
const bool has_act_order, // whether act_order is enabled const bool has_act_order, // whether act_order is enabled
const bool has_zp, // whether zero-points are enabled
const int group_blocks = -1 // number of consecutive 16x16 blocks const int group_blocks = -1 // number of consecutive 16x16 blocks
// with a separate quantization scale // with a separate quantization scale
> >
@ -1358,6 +1544,8 @@ __global__ void MarlinMoE(
const float* __restrict__ topk_weights, // float topk weights const float* __restrict__ topk_weights, // float topk weights
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
// (k/groupsize)xn // (k/groupsize)xn
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
// (k/groupsize)x(n/pack_factor)
const int* __restrict__ g_idx, // int32 group indices of shape k const int* __restrict__ g_idx, // int32 group indices of shape k
const int* __restrict__ expert_offsets, const int* __restrict__ expert_offsets,
int num_groups, // number of scale groups per output channel int num_groups, // number of scale groups per output channel
@ -1374,7 +1562,6 @@ __global__ void MarlinMoE(
int current_m_block, // current m block to start kernel computation from int current_m_block, // current m block to start kernel computation from
int max_par, // maximum parallelism int max_par, // maximum parallelism
int cfg_max_m_blocks // upper bound on m blocks int cfg_max_m_blocks // upper bound on m blocks
) { ) {
// Marlin is not implemented yet for SM < 8.0 // Marlin is not implemented yet for SM < 8.0
assert(false); assert(false);
@ -1389,37 +1576,41 @@ __global__ void MarlinMoE(
const int USER_THREADS = const int USER_THREADS =
256; // Note: This is only used with user-provided thread_k/n 256; // Note: This is only used with user-provided thread_k/n
const int STAGES = 4; // 4 pipeline stages fit into shared memory const int STAGES = 4; // 4 pipeline stages fit into shared memory
// const int SHARED_MEM =
// 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0)
static constexpr int min_thread_n = 64; static constexpr int min_thread_n = 64;
static constexpr int min_thread_k = 64; static constexpr int min_thread_k = 64;
#define __CALL_IF_MOE(W_TYPE, THREAD_N_BLOCKS, THREAD_K_BLOCKS, HAS_ACT_ORDER, \ #define __CALL_IF_MOE(W_TYPE, THREAD_N_BLOCKS, THREAD_K_BLOCKS, HAS_ACT_ORDER, \
GROUP_BLOCKS, NUM_THREADS) \ HAS_ZP, GROUP_BLOCKS, NUM_THREADS) \
else if (q_type == W_TYPE && thread_n_blocks == THREAD_N_BLOCKS && \ else if (q_type == W_TYPE && thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \ thread_k_blocks == THREAD_K_BLOCKS && \
has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \ has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \
num_threads == NUM_THREADS) { \ group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \
cudaFuncSetAttribute( \ cudaFuncSetAttribute( \
MarlinMoE<W_TYPE.id(), NUM_THREADS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ MarlinMoE<W_TYPE.id(), NUM_THREADS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
STAGES, HAS_ACT_ORDER, GROUP_BLOCKS>, \ STAGES, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
MarlinMoE<W_TYPE.id(), NUM_THREADS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ MarlinMoE<W_TYPE.id(), NUM_THREADS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
STAGES, HAS_ACT_ORDER, GROUP_BLOCKS> \ STAGES, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS> \
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \ <<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \
g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \
num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \
replicate_input, apply_weights, m_block, max_par, \ replicate_input, apply_weights, m_block, max_par, \
cfg_max_m_blocks); \ cfg_max_m_blocks); \
} }
#define GPTQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ #define GPTQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS)
#define AWQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)
} // namespace marlin_moe } // namespace marlin_moe

View File

@ -0,0 +1,31 @@
#include "marlin_moe_kernel_ku4.h"
namespace marlin_moe {
// We return bool so we can create these different kernel calls as a sequence
// of if-elseif's.
bool call_marlin_moe_kernel_ku4(
vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
int m_block, int max_par, int cfg_max_m_blocks) {
bool has_zp = true;
if (false) {
}
AWQ_CALL_IF_MOE(vllm::kU4, 16, 4, 256)
AWQ_CALL_IF_MOE(vllm::kU4, 8, 8, 256)
AWQ_CALL_IF_MOE(vllm::kU4, 8, 4, 128)
AWQ_CALL_IF_MOE(vllm::kU4, 4, 8, 128)
else {
return false;
}
return true;
}
} // namespace marlin_moe

View File

@ -0,0 +1,20 @@
#pragma once
#include "marlin_moe_kernel.h"
namespace marlin_moe {
// We return bool so we can create these different kernel calls as a sequence
// of if-elseif's.
bool call_marlin_moe_kernel_ku4(
vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
int m_block, int max_par, int cfg_max_m_blocks);
} // namespace marlin_moe

View File

@ -9,11 +9,13 @@ bool call_marlin_moe_kernel_ku4b8(
bool has_act_order, int group_blocks, int num_threads, int blocks, bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr, int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int* g_idx_ptr, const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
int* expert_offsets_ptr, int num_groups, int expert_idx, int num_experts, const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
int topk, int prob_m, int prob_n, int prob_k, int tot_m, int* locks, int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
bool replicate_input, bool apply_weights, int m_block, int max_par, int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
int cfg_max_m_blocks) { int m_block, int max_par, int cfg_max_m_blocks) {
bool has_zp = false;
if (false) { if (false) {
} }
GPTQ_CALL_IF_MOE(vllm::kU4B8, 16, 4, 256) GPTQ_CALL_IF_MOE(vllm::kU4B8, 16, 4, 256)

View File

@ -11,10 +11,10 @@ bool call_marlin_moe_kernel_ku4b8(
bool has_act_order, int group_blocks, int num_threads, int blocks, bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr, int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int* g_idx_ptr, const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
int* expert_offsets_ptr, int num_groups, int expert_idx, int num_experts, const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
int topk, int prob_m, int prob_n, int prob_k, int tot_m, int* locks, int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
bool replicate_input, bool apply_weights, int m_block, int max_par, int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
int cfg_max_m_blocks); int m_block, int max_par, int cfg_max_m_blocks);
} // namespace marlin_moe } // namespace marlin_moe

View File

@ -9,11 +9,13 @@ bool call_marlin_moe_kernel_ku8b128(
bool has_act_order, int group_blocks, int num_threads, int blocks, bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr, int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int* g_idx_ptr, const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
int* expert_offsets_ptr, int num_groups, int expert_idx, int num_experts, const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
int topk, int prob_m, int prob_n, int prob_k, int tot_m, int* locks, int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
bool replicate_input, bool apply_weights, int m_block, int max_par, int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
int cfg_max_m_blocks) { int m_block, int max_par, int cfg_max_m_blocks) {
bool has_zp = false;
if (false) { if (false) {
} }
GPTQ_CALL_IF_MOE(vllm::kU8B128, 16, 4, 256) GPTQ_CALL_IF_MOE(vllm::kU8B128, 16, 4, 256)

View File

@ -9,10 +9,10 @@ bool call_marlin_moe_kernel_ku8b128(
bool has_act_order, int group_blocks, int num_threads, int blocks, bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr, int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int* g_idx_ptr, const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
int* expert_offsets_ptr, int num_groups, int expert_idx, int num_experts, const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
int topk, int prob_m, int prob_n, int prob_k, int tot_m, int* locks, int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
bool replicate_input, bool apply_weights, int m_block, int max_par, int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
int cfg_max_m_blocks); int m_block, int max_par, int cfg_max_m_blocks);
} }

View File

@ -30,6 +30,7 @@
#include "core/registration.h" #include "core/registration.h"
#include "marlin_kernels/marlin_moe_kernel_ku4b8.h" #include "marlin_kernels/marlin_moe_kernel_ku4b8.h"
#include "marlin_kernels/marlin_moe_kernel_ku8b128.h" #include "marlin_kernels/marlin_moe_kernel_ku8b128.h"
#include "marlin_kernels/marlin_moe_kernel_ku4.h"
template <typename T> template <typename T>
inline std::string str(T x) { inline std::string str(T x) {
@ -157,6 +158,7 @@ thread_config_t small_batch_thread_configs[] = {
{128, 64, 128}, // Reduce N 2X, same K {128, 64, 128}, // Reduce N 2X, same K
{64, 256, 256}, // Reduce K 2X, increase N 2X {64, 256, 256}, // Reduce K 2X, increase N 2X
{64, 128, 128}, // Reduce K 2X, same N {64, 128, 128}, // Reduce K 2X, same N
{64, 64, 128}, // Reduce both 2X
}; };
thread_config_t large_batch_thread_configs[] = { thread_config_t large_batch_thread_configs[] = {
@ -167,6 +169,7 @@ thread_config_t large_batch_thread_configs[] = {
{128, 128, 256}, // Reduce N 2X, increase K 2X {128, 128, 256}, // Reduce N 2X, increase K 2X
{64, 128, 128}, // Reduce N 2X, same K {64, 128, 128}, // Reduce N 2X, same K
{128, 64, 128}, // Reduce N 4X, increase K 2X {128, 64, 128}, // Reduce N 4X, increase K 2X
{64, 64, 128}, // Reduce N 4X, same K
}; };
int get_scales_cache_size(thread_config_t const& th_config, int prob_m, int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
@ -313,26 +316,27 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
} }
#define CALL_MOE_KERNEL_FUNCTION(KERNEL_FUNCTION) \ #define CALL_MOE_KERNEL_FUNCTION(KERNEL_FUNCTION) \
else if (KERNEL_FUNCTION(q_type, thread_n_blocks, thread_k_blocks, \ else if (KERNEL_FUNCTION( \
has_act_order, group_blocks, num_threads, blocks, \ q_type, thread_n_blocks, thread_k_blocks, has_act_order, \
max_shared_mem, stream, A_ptr, B_ptr, C_ptr, \ group_blocks, num_threads, blocks, max_shared_mem, stream, \
sorted_ids_ptr, topk_weights_ptr, s_ptr, g_idx_ptr, \ A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \
expert_offsets_ptr, num_groups, expert_idx, \ zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \
num_experts, topk, prob_m, prob_n, prob_k, tot_m, \ num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \
locks, replicate_input, apply_weights, m_block, \ replicate_input, apply_weights, m_block, max_par, \
max_par, exec_cfg.max_m_blocks)) { \ exec_cfg.max_m_blocks)) { \
} }
void marlin_mm_moe(const void* A, const void* B, void* C, void marlin_mm_moe(const void* A, const void* B, void* C,
const void* sorted_ids, const void* topk_weights, const void* sorted_ids, const void* topk_weights,
const void* topk_ids, const void* s, const void* g_idx, const void* topk_ids, const void* s, void* zp,
const void* perm, void* a_tmp, void* expert_offsets, const void* g_idx, const void* perm, void* a_tmp,
int prob_m, int prob_n, int prob_k, void* workspace, void* expert_offsets, int prob_m, int prob_n, int prob_k,
vllm::ScalarType const& q_type, bool has_act_order, void* workspace, vllm::ScalarType const& q_type,
bool is_k_full, int num_groups, int group_size, bool has_act_order, bool is_k_full, bool has_zp,
int num_experts, int topk, int moe_block_size, int dev, int num_groups, int group_size, int num_experts, int topk,
cudaStream_t stream, int thread_k, int thread_n, int sms, int moe_block_size, int dev, cudaStream_t stream,
int max_par, bool replicate_input, bool apply_weights) { int thread_k, int thread_n, int sms, int max_par,
bool replicate_input, bool apply_weights) {
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
", ", prob_n, ", ", prob_k, "]"); ", ", prob_n, ", ", prob_k, "]");
@ -436,6 +440,8 @@ void marlin_mm_moe(const void* A, const void* B, void* C,
const float* topk_weights_ptr = (const float*)topk_weights; const float* topk_weights_ptr = (const float*)topk_weights;
const int* sorted_ids_ptr = (const int*)sorted_ids; const int* sorted_ids_ptr = (const int*)sorted_ids;
const int4* s_ptr = (const int4*)s + num_groups * prob_n / 8 * expert_idx; const int4* s_ptr = (const int4*)s + num_groups * prob_n / 8 * expert_idx;
const int4* zp_ptr =
(const int4*)zp + num_groups * prob_n / (pack_factor * 4) * expert_idx;
const int* g_idx_ptr = (const int*)g_idx + prob_k * expert_idx; const int* g_idx_ptr = (const int*)g_idx + prob_k * expert_idx;
const int* perm_ptr = (const int*)perm + prob_k * expert_idx; const int* perm_ptr = (const int*)perm + prob_k * expert_idx;
int* locks = (int*)workspace; int* locks = (int*)workspace;
@ -456,6 +462,7 @@ void marlin_mm_moe(const void* A, const void* B, void* C,
} }
CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4b8) CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4b8)
CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku8b128) CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku8b128)
CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4)
else { else {
TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " +
str(prob_n) + ", " + str(prob_k) + "]" + str(prob_n) + ", " + str(prob_k) + "]" +
@ -475,13 +482,21 @@ torch::Tensor marlin_gemm_moe(
const torch::Tensor& a, const torch::Tensor& b_q_weights, const torch::Tensor& a, const torch::Tensor& b_q_weights,
const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights, const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights,
const torch::Tensor& topk_ids, const torch::Tensor& b_scales, const torch::Tensor& topk_ids, const torch::Tensor& b_scales,
const torch::Tensor& g_idx, const torch::Tensor& perm, torch::Tensor& b_zeros, const torch::Tensor& g_idx,
torch::Tensor& workspace, vllm::ScalarTypeTorchPtr const& b_q_type, const torch::Tensor& perm, torch::Tensor& workspace,
int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full, vllm::ScalarTypeTorchPtr const& b_q_type, int64_t size_m, int64_t size_n,
int64_t num_experts, int64_t topk, int64_t moe_block_size, int64_t size_k, bool is_k_full, int64_t num_experts, int64_t topk,
bool replicate_input, bool apply_weights) { int64_t moe_block_size, bool replicate_input, bool apply_weights) {
TORCH_CHECK(*b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128, bool has_zp = b_zeros.size(1) != 0;
if (has_zp) {
TORCH_CHECK(
*b_q_type == vllm::kU4,
"b_q_type must be u4 when has_zp = True. Got = ", b_q_type->str());
} else {
TORCH_CHECK(
*b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128,
"b_q_type must be uint4b8 or uint8b128. Got = ", b_q_type->str()); "b_q_type must be uint4b8 or uint8b128. Got = ", b_q_type->str());
}
int pack_factor = 32 / b_q_type->size_bits(); int pack_factor = 32 / b_q_type->size_bits();
@ -543,14 +558,27 @@ torch::Tensor marlin_gemm_moe(
} }
} }
// Verify b_zeros
if (has_zp) {
int rank = b_zeros.sizes().size();
TORCH_CHECK(rank == 3, "b_zeros rank = ", rank, " is not 3");
TORCH_CHECK(b_zeros.size(1) == num_groups,
"b_zeros dim 1 = ", b_zeros.size(1),
" is not num_groups = ", num_groups);
TORCH_CHECK(b_zeros.size(2) == size_n / pack_factor,
"b_zeros dim 2 = ", b_zeros.size(2),
" is not size_n / pack_factor = ", size_n / pack_factor);
}
marlin_moe::marlin_mm_moe( marlin_moe::marlin_mm_moe(
a.data_ptr(), b_q_weights.data_ptr(), c.data_ptr(), sorted_ids.data_ptr(), a.data_ptr(), b_q_weights.data_ptr(), c.data_ptr(), sorted_ids.data_ptr(),
topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(), topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(),
g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(),
expert_offsets.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(), expert_offsets.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(),
*b_q_type, has_act_order, is_k_full, num_groups, group_size, num_experts, *b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size,
topk, moe_block_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, num_experts, topk, moe_block_size, dev,
thread_n, sms, max_par, replicate_input, apply_weights); at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par,
replicate_input, apply_weights);
return c; return c;
} }

View File

@ -12,7 +12,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
m.def( m.def(
"marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, " "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
"Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! " "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
"g_idx, Tensor! perm, Tensor! workspace, " "b_zeros, Tensor! g_idx, Tensor! perm, Tensor! workspace, "
"__torch__.torch.classes._core_C.ScalarType b_q_type, int size_m, " "__torch__.torch.classes._core_C.ScalarType b_q_type, int size_m, "
"int size_n, int size_k, bool is_k_full, int num_experts, int topk, " "int size_n, int size_k, bool is_k_full, int num_experts, int topk, "
"int moe_block_size, bool replicate_input, bool apply_weights)" "int moe_block_size, bool replicate_input, bool apply_weights)"

View File

@ -2260,7 +2260,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
"b_zeros dim 0 = ", b_zeros.size(0), "b_zeros dim 0 = ", b_zeros.size(0),
" is not num_groups = ", num_groups); " is not num_groups = ", num_groups);
TORCH_CHECK(b_zeros.size(1) == size_n / pack_factor, TORCH_CHECK(b_zeros.size(1) == size_n / pack_factor,
"b_zeros dim 1 = ", b_scales.size(1), "b_zeros dim 1 = ", b_zeros.size(1),
" is not size_n / pack_factor = ", size_n / pack_factor); " is not size_n / pack_factor = ", size_n / pack_factor);
} }

View File

@ -0,0 +1,160 @@
"""Test AWQ with fused MoE Marlin kernels.
Run `pytest tests/kernels/test_awq_marlin.py`.
"""
import pytest
import torch
from tests.kernels.utils import (compute_max_diff, stack_and_dev, torch_moe,
torch_moe_single)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
fused_marlin_moe, single_marlin_moe)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
awq_marlin_quantize)
from vllm.scalar_type import scalar_types
@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 1024, 512])
@pytest.mark.parametrize("e", [8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
def test_fused_marlin_moe_awq(
m: int,
n: int,
k: int,
e: int,
topk: int,
group_size: int,
):
torch.manual_seed(7)
num_bits = 4
quant_type = scalar_types.uint4
dtype = torch.float16
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
w_ref1_l = []
qweights1_l = []
scales1_l = []
zp1_l = []
for i in range(w1.shape[0]):
w_ref1, qweight1, scales1, zp1 = awq_marlin_quantize(
w1[i].transpose(1, 0), quant_type, group_size)
w_ref1_l.append(w_ref1)
qweights1_l.append(qweight1)
scales1_l.append(scales1)
zp1_l.append(zp1)
w_ref1 = stack_and_dev(w_ref1_l)
qweight1 = stack_and_dev(qweights1_l).contiguous()
scales1 = stack_and_dev(scales1_l)
zp1 = stack_and_dev(zp1_l)
w_ref2_l = []
qweights2_l = []
scales2_l = []
zp2_l = []
for i in range(w2.shape[0]):
w_ref2, qweight2, scales2, zp2 = awq_marlin_quantize(
w2[i].transpose(1, 0), quant_type, group_size)
w_ref2_l.append(w_ref2)
qweights2_l.append(qweight2)
scales2_l.append(scales2)
zp2_l.append(zp2)
w_ref2 = stack_and_dev(w_ref2_l)
qweight2 = stack_and_dev(qweights2_l).contiguous()
scales2 = stack_and_dev(scales2_l)
zp2 = stack_and_dev(zp2_l)
score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids = fused_topk(a, score, topk, False)
marlin_output = fused_marlin_moe(
a,
qweight1,
qweight2,
scales1,
scales2,
score,
topk_weights,
topk_ids,
w1_zeros=zp1,
w2_zeros=zp2,
num_bits=num_bits,
)
torch_output = torch_moe(
a,
w_ref1.transpose(1, 2),
w_ref2.transpose(1, 2),
score,
topk,
)
assert compute_max_diff(marlin_output, torch_output) < 4e-2
@pytest.mark.skip("This test is here for the sake of debugging, "
"don't run it in automated tests.")
@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 1024, 512])
@pytest.mark.parametrize("e", [8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
def test_single_marlin_moe_multiply_awq(
m: int,
n: int,
k: int,
e: int,
topk: int,
group_size: int,
):
torch.manual_seed(7)
num_bits = 4
quant_type = scalar_types.uint4
dtype = torch.float16
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10
w_ref_l = []
qweights_l = []
scales_l = []
zp_l = []
for i in range(w.shape[0]):
w_ref, qweight, scales, zp = awq_marlin_quantize(
w[i].transpose(1, 0), quant_type, group_size)
w_ref_l.append(w_ref)
qweights_l.append(qweight)
scales_l.append(scales)
zp_l.append(zp)
w_ref = stack_and_dev(w_ref_l)
qweight = stack_and_dev(qweights_l).contiguous()
scales = stack_and_dev(scales_l).contiguous()
zp = stack_and_dev(zp_l).contiguous()
score = torch.randn((m, e), device="cuda", dtype=dtype)
marlin_output = single_marlin_moe(a,
qweight,
scales,
score,
topk,
renormalize=False,
w_zeros=zp,
num_bits=num_bits)
torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)
assert compute_max_diff(marlin_output, torch_output) < 1e-2

View File

@ -2,16 +2,14 @@
Run `pytest tests/kernels/test_moe.py`. Run `pytest tests/kernels/test_moe.py`.
""" """
from typing import List
import pytest import pytest
import torch import torch
from transformers import MixtralConfig from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
from tests.kernels.utils import opcheck from tests.kernels.utils import (compute_max_diff, opcheck, stack_and_dev,
torch_moe, torch_moe_single)
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
fused_marlin_moe, single_marlin_moe) fused_marlin_moe, single_marlin_moe)
@ -24,37 +22,6 @@ from vllm.scalar_type import scalar_types
from vllm.utils import seed_everything from vllm.utils import seed_everything
def torch_moe(a, w1, w2, score, topk):
B, D = a.shape
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
score = torch.softmax(score, dim=-1, dtype=torch.float32)
topk_weight, topk_ids = torch.topk(score, topk)
topk_weight = topk_weight.view(-1)
topk_ids = topk_ids.view(-1)
for i in range(w1.shape[0]):
mask = topk_ids == i
if mask.sum():
out[mask] = SiluAndMul()(
a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
return (out.view(B, -1, w2.shape[1]) *
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
def torch_moe_single(a, w, score, topk):
B, D = a.shape
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
out = torch.zeros(B * topk, w.shape[1], dtype=a.dtype, device=a.device)
score = torch.softmax(score, dim=-1, dtype=torch.float32)
_, topk_ids = torch.topk(score, topk)
topk_ids = topk_ids.view(-1)
for i in range(w.shape[0]):
mask = topk_ids == i
if mask.sum():
out[mask] = a[mask] @ w[i].transpose(0, 1)
return (out.view(B, -1, w.shape[1])).sum(dim=1)
@pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1]) @pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [2048, 256, 1024]) @pytest.mark.parametrize("n", [2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 511, 1024]) @pytest.mark.parametrize("k", [128, 511, 1024])
@ -127,20 +94,10 @@ def test_mixtral_moe(dtype: torch.dtype):
atol=mixtral_moe_tol[dtype]) atol=mixtral_moe_tol[dtype])
def stack_and_dev(tensors: List[torch.Tensor]):
dev = tensors[0].device
return torch.stack(tensors, dim=0).to(dev)
def compute_max_diff(output, output_ref):
return torch.mean(torch.abs(output - output_ref)) / torch.mean(
torch.abs(output_ref))
@pytest.mark.parametrize("m", [64, 512, 222, 33, 1]) @pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [128, 2048, 256, 1024]) @pytest.mark.parametrize("n", [128, 2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 1024, 512]) @pytest.mark.parametrize("k", [128, 1024, 512])
@pytest.mark.parametrize("e", [4, 8, 64]) @pytest.mark.parametrize("e", [8, 64])
@pytest.mark.parametrize("topk", [2, 6]) @pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) @pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("act_order", [True, False]) @pytest.mark.parametrize("act_order", [True, False])
@ -159,9 +116,6 @@ def test_fused_marlin_moe(
): ):
seed_everything(7) seed_everything(7)
if topk > e:
return
# Filter act_order # Filter act_order
if act_order: if act_order:
if group_size == -1: if group_size == -1:
@ -241,15 +195,15 @@ def test_fused_marlin_moe(
a, a,
qweight1, qweight1,
qweight2, qweight2,
scales1,
scales2,
score, score,
g_idx1,
g_idx2,
sort_indices1,
sort_indices2,
topk_weights, topk_weights,
topk_ids, topk_ids,
w1_scale=scales1, g_idx1=g_idx1,
w2_scale=scales2, g_idx2=g_idx2,
sort_indices1=sort_indices1,
sort_indices2=sort_indices2,
num_bits=num_bits, num_bits=num_bits,
is_k_full=is_k_full, is_k_full=is_k_full,
) )
@ -280,9 +234,13 @@ def test_fused_marlin_moe(
device="cuda", device="cuda",
requires_grad=False) requires_grad=False)
zp = torch.empty((0, 0),
dtype=dtype,
device="cuda",
requires_grad=False)
opcheck(torch.ops._moe_C.marlin_gemm_moe, opcheck(torch.ops._moe_C.marlin_gemm_moe,
(a, qweight1, sorted_token_ids, topk_weights, topk_ids, (a, qweight1, sorted_token_ids, topk_weights, topk_ids,
scales1, g_idx1, sort_indices1, workspace, quant_type, m, scales1, zp, g_idx1, sort_indices1, workspace, quant_type, m,
2 * n, k, True, e, topk, block_size_m, True, False)) 2 * n, k, True, e, topk, block_size_m, True, False))
@ -291,7 +249,7 @@ def test_fused_marlin_moe(
@pytest.mark.parametrize("m", [64, 512, 222, 33, 1]) @pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [128, 2048, 256, 1024]) @pytest.mark.parametrize("n", [128, 2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 1024, 512]) @pytest.mark.parametrize("k", [128, 1024, 512])
@pytest.mark.parametrize("e", [4, 8, 64]) @pytest.mark.parametrize("e", [8, 64])
@pytest.mark.parametrize("topk", [2, 6]) @pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) @pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("act_order", [True, False]) @pytest.mark.parametrize("act_order", [True, False])
@ -308,8 +266,6 @@ def test_single_marlin_moe_multiply(
num_bits: int, num_bits: int,
is_k_full: bool, is_k_full: bool,
): ):
if topk > e:
return
# Filter act_order # Filter act_order
if act_order: if act_order:
@ -355,13 +311,14 @@ def test_single_marlin_moe_multiply(
qweight, qweight,
scales, scales,
score, score,
g_idx,
sort_indices,
topk, topk,
renormalize=False, renormalize=False,
g_idx=g_idx,
sort_indices=sort_indices,
num_bits=num_bits, num_bits=num_bits,
is_k_full=is_k_full, is_k_full=is_k_full,
) )
torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk) torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)
assert compute_max_diff(marlin_output, torch_output) < 1e-2 assert compute_max_diff(marlin_output, torch_output) < 1e-2

View File

@ -12,6 +12,7 @@ import torch
from torch._prims_common import TensorLikeType from torch._prims_common import TensorLikeType
from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL, from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL,
make_tensor_with_pad) make_tensor_with_pad)
@ -974,6 +975,50 @@ def fp8_allclose(
equal_nan=equal_nan)).item()) equal_nan=equal_nan)).item())
# Marlin MoE test utils
def stack_and_dev(tensors: List[torch.Tensor]):
dev = tensors[0].device
return torch.stack(tensors, dim=0).to(dev)
def compute_max_diff(output, output_ref):
return torch.mean(torch.abs(output - output_ref)) / torch.mean(
torch.abs(output_ref))
def torch_moe(a, w1, w2, score, topk):
B, D = a.shape
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
score = torch.softmax(score, dim=-1, dtype=torch.float32)
topk_weight, topk_ids = torch.topk(score, topk)
topk_weight = topk_weight.view(-1)
topk_ids = topk_ids.view(-1)
for i in range(w1.shape[0]):
mask = topk_ids == i
if mask.sum():
out[mask] = SiluAndMul()(
a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
return (out.view(B, -1, w2.shape[1]) *
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
def torch_moe_single(a, w, score, topk):
B, D = a.shape
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
out = torch.zeros(B * topk, w.shape[1], dtype=a.dtype, device=a.device)
score = torch.softmax(score, dim=-1, dtype=torch.float32)
_, topk_ids = torch.topk(score, topk)
topk_ids = topk_ids.view(-1)
for i in range(w.shape[0]):
mask = topk_ids == i
if mask.sum():
out[mask] = a[mask] @ w[i].transpose(0, 1)
return (out.view(B, -1, w.shape[1])).sum(dim=1)
# A special version of op check that has a restricted default set of test_utils # A special version of op check that has a restricted default set of test_utils
# and a patched version of allclose that supports fp8 types. # and a patched version of allclose that supports fp8 types.
def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket, def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket,

View File

@ -3,3 +3,4 @@ compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantize
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W8A16-quantized, main compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W8A16-quantized, main
compressed-tensors, mgoin/DeepSeek-Coder-V2-Lite-Instruct-FP8, main compressed-tensors, mgoin/DeepSeek-Coder-V2-Lite-Instruct-FP8, main
gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main
awq_marlin, casperhansen/deepseek-coder-v2-instruct-awq, main

View File

@ -1,7 +1,20 @@
#!/bin/bash #!/bin/bash
SUCCESS=0 SUCCESS=0
IFS=$'\n' read -d '' -r -a MODEL_CONFIGS < "weight_loading/models.txt" while getopts "c:" OPT; do
case ${OPT} in
c )
CONFIG="$OPTARG"
;;
\? )
usage
exit 1
;;
esac
done
IFS=$'\n' read -d '' -r -a MODEL_CONFIGS < $CONFIG
for MODEL_CONFIG in "${MODEL_CONFIGS[@]}" for MODEL_CONFIG in "${MODEL_CONFIGS[@]}"
do do

View File

@ -568,6 +568,20 @@ def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
return output return output
def awq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
size_k: int, size_n: int,
num_bits: int) -> torch.Tensor:
num_experts = b_q_weight.shape[0]
assert size_k % 16 == 0
output = torch.empty((num_experts, size_k // 16, size_n * (num_bits // 2)),
device=b_q_weight.device,
dtype=b_q_weight.dtype)
for e in range(num_experts):
output[e] = torch.ops._C.awq_marlin_repack(b_q_weight[e], size_k,
size_n, num_bits)
return output
def gptq_marlin_gemm(a: torch.Tensor, def gptq_marlin_gemm(a: torch.Tensor,
b_q_weight: torch.Tensor, b_q_weight: torch.Tensor,
b_scales: torch.Tensor, b_scales: torch.Tensor,
@ -828,11 +842,12 @@ if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):
sorted_ids: torch.Tensor, sorted_ids: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, b_scales: torch.Tensor, topk_ids: torch.Tensor, b_scales: torch.Tensor,
g_idx: torch.Tensor, perm: torch.Tensor, b_zero_points: torch.Tensor, g_idx: torch.Tensor,
workspace: torch.Tensor, b_q_type: ScalarType, perm: torch.Tensor, workspace: torch.Tensor,
size_m: int, size_n: int, size_k: int, b_q_type: ScalarType, size_m: int, size_n: int,
is_k_full: bool, num_experts: int, topk: int, size_k: int, is_k_full: bool, num_experts: int,
moe_block_size: int, replicate_input: bool, topk: int, moe_block_size: int,
replicate_input: bool,
apply_weights: bool) -> torch.Tensor: apply_weights: bool) -> torch.Tensor:
return torch.empty((size_m, topk, size_n), return torch.empty((size_m, topk, size_n),
dtype=a.dtype, dtype=a.dtype,

View File

@ -10,15 +10,24 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
from vllm.scalar_type import scalar_types from vllm.scalar_type import scalar_types
def get_scalar_type(num_bits: int, has_zp: bool):
if has_zp:
assert num_bits == 4
return scalar_types.uint4
else:
return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128
def single_marlin_moe( def single_marlin_moe(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
w: torch.Tensor, w: torch.Tensor,
scales: torch.Tensor, scales: torch.Tensor,
gating_output: torch.Tensor, gating_output: torch.Tensor,
g_idx: torch.Tensor,
perm: torch.Tensor,
topk: int, topk: int,
renormalize: bool, renormalize: bool,
g_idx: Optional[torch.Tensor] = None,
sort_indices: Optional[torch.Tensor] = None,
w_zeros: Optional[torch.Tensor] = None,
override_config: Optional[Dict[str, Any]] = None, override_config: Optional[Dict[str, Any]] = None,
num_bits: int = 8, num_bits: int = 8,
is_k_full: bool = True, is_k_full: bool = True,
@ -34,10 +43,12 @@ def single_marlin_moe(
- scales (torch.Tensor): The quantization scales. - scales (torch.Tensor): The quantization scales.
- gating_output (torch.Tensor): The output of the gating operation - gating_output (torch.Tensor): The output of the gating operation
(before softmax). (before softmax).
- g_idx (torch.Tensor): The act_order indices. - g_idx (Optional[torch.Tensor]): Optional act_order indices.
- perm (torch.Tensor): The act_order input permutation. - sort_indices (Optional[torch.Tensor]): Optional act_order input
permutation.
- topk (int): The number of top-k experts to select. - topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1. - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w.
- override_config (Optional[Dict[str, Any]]): Optional override - override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration. for the kernel configuration.
- num_bits (bool): The number of bits in expert weights quantization. - num_bits (bool): The number of bits in expert weights quantization.
@ -79,16 +90,34 @@ def single_marlin_moe(
max_workspace_size = (N // 64) * 16 max_workspace_size = (N // 64) * 16
workspace = torch.zeros(max_workspace_size, workspace = torch.zeros(max_workspace_size,
dtype=torch.int, dtype=torch.int,
device="cuda", device=hidden_states.device,
requires_grad=False) requires_grad=False)
scalar_type = (scalar_types.uint4b8 has_zero_point = w_zeros is not None
if num_bits == 4 else scalar_types.uint8b128) if w_zeros is None:
w_zeros = torch.empty((0, 0),
dtype=hidden_states.dtype,
device=hidden_states.device,
requires_grad=False)
if g_idx is None:
g_idx = torch.empty((0, 0),
dtype=torch.int32,
device=hidden_states.device,
requires_grad=False)
if sort_indices is None:
sort_indices = torch.empty((0),
dtype=torch.int32,
device=hidden_states.device,
requires_grad=False)
scalar_type = get_scalar_type(num_bits, has_zero_point)
intermediate_cache = torch.ops._moe_C.marlin_gemm_moe( intermediate_cache = torch.ops._moe_C.marlin_gemm_moe(
hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales, hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales,
g_idx, perm, workspace, scalar_type, M, N, K, is_k_full, E, topk, w_zeros, g_idx, sort_indices, workspace, scalar_type, M, N, K,
block_size_m, True, False) is_k_full, E, topk, block_size_m, True, False)
return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
@ -97,16 +126,18 @@ def fused_marlin_moe(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
gating_output: torch.Tensor, gating_output: torch.Tensor,
g_idx1: torch.Tensor,
g_idx2: torch.Tensor,
perm1: torch.Tensor,
perm2: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
g_idx1: Optional[torch.Tensor] = None,
g_idx2: Optional[torch.Tensor] = None,
sort_indices1: Optional[torch.Tensor] = None,
sort_indices2: Optional[torch.Tensor] = None,
w1_zeros: Optional[torch.Tensor] = None,
w2_zeros: Optional[torch.Tensor] = None,
override_config: Optional[Dict[str, Any]] = None, override_config: Optional[Dict[str, Any]] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
num_bits: int = 8, num_bits: int = 8,
is_k_full: bool = True, is_k_full: bool = True,
) -> torch.Tensor: ) -> torch.Tensor:
@ -118,21 +149,22 @@ def fused_marlin_moe(
- hidden_states (torch.Tensor): The input tensor to the MoE layer. - hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights. - w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights. - w2 (torch.Tensor): The second set of expert weights.
- w1_scale (torch.Tensor): Scale to be used for w1.
- w2_scale (torch.Tensor): Scale to be used for w2.
- gating_output (torch.Tensor): The output of the gating operation - gating_output (torch.Tensor): The output of the gating operation
(before softmax). (before softmax).
- g_idx1 (torch.Tensor): The first set of act_order indices. - g_idx1 (Optional[torch.Tensor]): The first set of act_order indices.
- g_idx2 (torch.Tensor): The second set of act_order indices. - g_idx2 (Optional[torch.Tensor]): The second set of act_order indices.
- perm1 (torch.Tensor): The first act_order input permutation. - sort_indices1 (Optional[torch.Tensor]): The first act_order input
- perm2 (torch.Tensor): The second act_order input permutation. permutation.
- sort_indices2 (Optional[torch.Tensor]): The second act_order input
permutation.
- topk_weights (torch.Tensor): Top-k weights. - topk_weights (torch.Tensor): Top-k weights.
- topk_ids (torch.Tensor): Indices of topk-k elements. - topk_ids (torch.Tensor): Indices of topk-k elements.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- override_config (Optional[Dict[str, Any]]): Optional override - override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration. for the kernel configuration.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1.
w1. - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
w2.
- num_bits (bool): The number of bits in expert weights quantization. - num_bits (bool): The number of bits in expert weights quantization.
Returns: Returns:
@ -152,6 +184,20 @@ def fused_marlin_moe(
assert hidden_states.dtype == torch.float16 assert hidden_states.dtype == torch.float16
assert num_bits in [4, 8] assert num_bits in [4, 8]
has_no_act_order = (g_idx1 is None and g_idx2 is None
and sort_indices1 is None and sort_indices2 is None)
has_all_act_order = (g_idx1 is not None and g_idx2 is not None
and sort_indices1 is not None
and sort_indices2 is not None)
assert has_no_act_order or has_all_act_order, (
"g_idx and sorted_indices "
"must be all not None or must be all None")
has_no_zp = w1_zeros is None and w2_zeros is None
has_all_zp = w1_zeros is not None and w2_zeros is not None
assert has_no_zp or has_all_zp, ("zero points must be both not None or "
"must be both None")
M, K = hidden_states.shape M, K = hidden_states.shape
E = w1.shape[0] E = w1.shape[0]
N = w2.shape[1] * 16 N = w2.shape[1] * 16
@ -172,14 +218,42 @@ def fused_marlin_moe(
sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E) sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)
max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16 max_workspace_size = (max(2 * N, K) // 64) * 16
workspace = torch.zeros(max_workspace_size, workspace = torch.zeros(max_workspace_size,
dtype=torch.int, dtype=torch.int,
device="cuda", device="cuda",
requires_grad=False) requires_grad=False)
scalar_type = (scalar_types.uint4b8 if has_no_zp:
if num_bits == 4 else scalar_types.uint8b128) w1_zeros = torch.empty((0, 0),
dtype=hidden_states.dtype,
device=hidden_states.device,
requires_grad=False)
w2_zeros = torch.empty((0, 0),
dtype=hidden_states.dtype,
device=hidden_states.device,
requires_grad=False)
if has_no_act_order:
g_idx1 = torch.empty((0, 0),
dtype=torch.int32,
device=hidden_states.device,
requires_grad=False)
g_idx2 = torch.empty((0, 0),
dtype=torch.int32,
device=hidden_states.device,
requires_grad=False)
sort_indices1 = torch.empty((0),
dtype=torch.int32,
device=hidden_states.device,
requires_grad=False)
sort_indices2 = torch.empty((0, 0),
dtype=torch.int32,
device=hidden_states.device,
requires_grad=False)
scalar_type1 = get_scalar_type(num_bits, has_all_zp)
scalar_type2 = get_scalar_type(num_bits, has_all_zp)
intermediate_cache2 = torch.empty( intermediate_cache2 = torch.empty(
(M * topk_ids.shape[1], N), (M * topk_ids.shape[1], N),
@ -194,10 +268,11 @@ def fused_marlin_moe(
topk_weights, topk_weights,
topk_ids, topk_ids,
w1_scale, w1_scale,
w1_zeros,
g_idx1, g_idx1,
perm1, sort_indices1,
workspace, workspace,
scalar_type, scalar_type1,
M, M,
2 * N, 2 * N,
K, K,
@ -218,10 +293,11 @@ def fused_marlin_moe(
topk_weights, topk_weights,
topk_ids, topk_ids,
w2_scale, w2_scale,
w2_zeros,
g_idx2, g_idx2,
perm2, sort_indices2,
workspace, workspace,
scalar_type, scalar_type2,
M, M,
K, K,
N, N,

View File

@ -1,16 +1,21 @@
from typing import Any, Dict, List, Optional from typing import Any, Callable, Dict, List, Optional
import torch import torch
from torch.nn import Parameter
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported, apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales, marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales,
marlin_permute_scales, moe_awq_to_marlin_zero_points,
verify_marlin_supported, verify_marlin_supports_shape) verify_marlin_supported, verify_marlin_supports_shape)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (GroupQuantScaleParameter, from vllm.model_executor.parameter import (GroupQuantScaleParameter,
@ -35,12 +40,13 @@ class AWQMarlinConfig(QuantizationConfig):
self.group_size = group_size self.group_size = group_size
self.has_zp = has_zp self.has_zp = has_zp
self.lm_head_quantized = lm_head_quantized self.lm_head_quantized = lm_head_quantized
self.weight_bits = weight_bits
if weight_bits not in self.TYPE_MAP: if self.weight_bits not in self.TYPE_MAP:
raise ValueError(f"Unsupported num_bits = {weight_bits}. " raise ValueError(f"Unsupported num_bits = {self.weight_bits}. "
f"Supported num_bits = {self.TYPE_MAP.keys()}") f"Supported num_bits = {self.TYPE_MAP.keys()}")
self.quant_type = self.TYPE_MAP[weight_bits] self.quant_type = self.TYPE_MAP[self.weight_bits]
verify_marlin_supported(self.quant_type, verify_marlin_supported(self.quant_type,
group_size=self.group_size, group_size=self.group_size,
@ -98,10 +104,12 @@ class AWQMarlinConfig(QuantizationConfig):
return None return None
def get_quant_method(self, layer: torch.nn.Module, def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["AWQMarlinLinearMethod"]: prefix: str) -> Optional["QuantizeMethodBase"]:
if (isinstance(layer, LinearBase) or if (isinstance(layer, LinearBase) or
(isinstance(layer, ParallelLMHead) and self.lm_head_quantized)): (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
return AWQMarlinLinearMethod(self) return AWQMarlinLinearMethod(self)
elif isinstance(layer, FusedMoE):
return AWQMoEMethod(self)
return None return None
def get_scaled_act_names(self) -> List[str]: def get_scaled_act_names(self) -> List[str]:
@ -272,3 +280,181 @@ class AWQMarlinLinearMethod(LinearMethodBase):
output_size_per_partition=layer.output_size_per_partition, output_size_per_partition=layer.output_size_per_partition,
input_size_per_partition=layer.input_size_per_partition, input_size_per_partition=layer.input_size_per_partition,
bias=bias) bias=bias)
class AWQMoEMethod(FusedMoEMethodBase):
def __init__(self, quant_config: AWQMarlinConfig):
self.quant_config = quant_config
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size: int,
params_dtype: torch.dtype, **extra_weight_attrs):
extra_weight_attrs.update({
"is_transposed":
True,
"quant_method":
FusedMoeWeightScaleSupported.GROUP.value,
})
w13_qweight = Parameter(torch.empty(num_experts,
hidden_size,
2 * intermediate_size //
self.quant_config.pack_factor,
dtype=torch.int32),
requires_grad=False)
layer.register_parameter("w13_qweight", w13_qweight)
set_weight_attrs(w13_qweight, extra_weight_attrs)
w2_qweight = Parameter(torch.empty(num_experts,
intermediate_size,
hidden_size //
self.quant_config.pack_factor,
dtype=torch.int32),
requires_grad=False)
layer.register_parameter("w2_qweight", w2_qweight)
set_weight_attrs(w2_qweight, extra_weight_attrs)
num_groups_w13 = hidden_size // self.quant_config.group_size
num_groups_w2 = intermediate_size // self.quant_config.group_size
# WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively.
w13_scales = Parameter(torch.empty(num_experts,
num_groups_w13,
intermediate_size * 2,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w13_scales", w13_scales)
set_weight_attrs(w13_scales, extra_weight_attrs)
w2_scales = Parameter(torch.empty(num_experts,
num_groups_w2,
hidden_size,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w2_scales", w2_scales)
set_weight_attrs(w2_scales, extra_weight_attrs)
# WEIGHT_ZERO_POINT
# Allocate 2 zero points for w1 and w3 respectively.
w13_qzeros = Parameter(torch.empty(num_experts,
num_groups_w13,
2 * intermediate_size //
self.quant_config.pack_factor,
dtype=torch.int32),
requires_grad=False)
layer.register_parameter("w13_qzeros", w13_qzeros)
set_weight_attrs(w13_qzeros, extra_weight_attrs)
w2_qzeros = Parameter(torch.empty(num_experts,
num_groups_w2,
hidden_size //
self.quant_config.pack_factor,
dtype=torch.int32),
requires_grad=False)
layer.register_parameter("w2_qzeros", w2_qzeros)
set_weight_attrs(w2_qzeros, extra_weight_attrs)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
num_experts = layer.w13_qweight.shape[0]
device = layer.w13_qweight.device
layer.w13_g_idx_sort_indices = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
requires_grad=False,
)
layer.w2_g_idx_sort_indices = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
requires_grad=False,
)
marlin_w13_qweight = ops.awq_marlin_moe_repack(
layer.w13_qweight,
layer.w13_g_idx_sort_indices,
size_k=layer.w13_qweight.shape[1],
size_n=layer.w13_qweight.shape[2] * self.quant_config.pack_factor,
num_bits=self.quant_config.weight_bits,
)
replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
marlin_w2_qweight = ops.awq_marlin_moe_repack(
layer.w2_qweight,
layer.w2_g_idx_sort_indices,
size_k=layer.w2_qweight.shape[1],
size_n=layer.w2_qweight.shape[2] * self.quant_config.pack_factor,
num_bits=self.quant_config.weight_bits,
)
replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
# Why does this take the intermediate size for size_k?
marlin_w13_scales = marlin_moe_permute_scales(
s=layer.w13_scales,
size_k=layer.intermediate_size_per_partition,
size_n=layer.w13_scales.shape[2],
group_size=self.quant_config.group_size,
)
replace_parameter(layer, "w13_scales", marlin_w13_scales)
marlin_w2_scales = marlin_moe_permute_scales(
s=layer.w2_scales,
size_k=layer.intermediate_size_per_partition,
size_n=layer.w2_scales.shape[2],
group_size=self.quant_config.group_size,
)
replace_parameter(layer, "w2_scales", marlin_w2_scales)
marlin_w13_zp = moe_awq_to_marlin_zero_points(
layer.w13_qzeros,
size_k=layer.w13_qzeros.shape[1],
size_n=layer.w13_qzeros.shape[2] * self.quant_config.pack_factor,
num_bits=self.quant_config.weight_bits)
replace_parameter(layer, "w13_qzeros", marlin_w13_zp)
marlin_w2_zp = moe_awq_to_marlin_zero_points(
layer.w2_qzeros,
size_k=layer.w2_qzeros.shape[1],
size_n=layer.w2_qzeros.shape[2] * self.quant_config.pack_factor,
num_bits=self.quant_config.weight_bits)
replace_parameter(layer, "w2_qzeros", marlin_w2_zp)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
fused_marlin_moe)
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function)
return fused_marlin_moe(
x,
layer.w13_qweight,
layer.w2_qweight,
layer.w13_scales,
layer.w2_scales,
router_logits,
topk_weights,
topk_ids,
w1_zeros=layer.w13_qzeros,
w2_zeros=layer.w2_qzeros,
num_bits=self.quant_config.weight_bits,
)

View File

@ -498,14 +498,14 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
x, x,
layer.w13_weight_packed, layer.w13_weight_packed,
layer.w2_weight_packed, layer.w2_weight_packed,
layer.w13_weight_scale,
layer.w2_weight_scale,
router_logits, router_logits,
layer.w13_g_idx,
layer.w2_g_idx,
layer.w13_g_idx_sort_indices,
layer.w2_g_idx_sort_indices,
topk_weights, topk_weights,
topk_ids, topk_ids,
w1_scale=layer.w13_weight_scale, g_idx1=layer.w13_g_idx,
w2_scale=layer.w2_weight_scale, g_idx2=layer.w2_g_idx,
sort_indices1=layer.w13_g_idx_sort_indices,
sort_indices2=layer.w2_g_idx_sort_indices,
num_bits=self.num_bits, num_bits=self.num_bits,
) )

View File

@ -557,14 +557,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
x, x,
layer.w13_qweight, layer.w13_qweight,
layer.w2_qweight, layer.w2_qweight,
layer.w13_scales,
layer.w2_scales,
router_logits, router_logits,
layer.w13_g_idx,
layer.w2_g_idx,
layer.w13_g_idx_sort_indices,
layer.w2_g_idx_sort_indices,
topk_weights, topk_weights,
topk_ids, topk_ids,
w1_scale=layer.w13_scales, g_idx1=layer.w13_g_idx,
w2_scale=layer.w2_scales, g_idx2=layer.w2_g_idx,
sort_indices1=layer.w13_g_idx_sort_indices,
sort_indices2=layer.w2_g_idx_sort_indices,
num_bits=self.quant_config.quant_type.size_bits, num_bits=self.quant_config.quant_type.size_bits,
).to(orig_dtype) ).to(orig_dtype)

View File

@ -208,6 +208,7 @@ def marlin_moe_permute_scales(
device=s.device, device=s.device,
dtype=s.dtype, dtype=s.dtype,
) )
for e in range(num_experts): for e in range(num_experts):
output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size) output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size)
return output return output
@ -258,6 +259,20 @@ def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
return marlin_zp return marlin_zp
def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
size_n: int, num_bits: int):
num_experts = q_zp_packed.shape[0]
output = torch.empty(
(num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]),
device=q_zp_packed.device,
dtype=q_zp_packed.dtype,
)
for e in range(num_experts):
output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n,
num_bits)
return output
def apply_gptq_marlin_linear( def apply_gptq_marlin_linear(
input: torch.Tensor, input: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,

View File

@ -23,7 +23,9 @@ def get_model_architecture(
architectures = getattr(model_config.hf_config, "architectures", []) architectures = getattr(model_config.hf_config, "architectures", [])
# Special handling for quantized Mixtral. # Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack. # FIXME(woosuk): This is a temporary hack.
mixtral_supported = ["fp8", "compressed-tensors", "gptq_marlin"] mixtral_supported = [
"fp8", "compressed-tensors", "gptq_marlin", "awq_marlin"
]
if (model_config.quantization is not None if (model_config.quantization is not None
and model_config.quantization not in mixtral_supported and model_config.quantization not in mixtral_supported