diff --git a/CMakeLists.txt b/CMakeLists.txt index 093330caa4f9..5c1a200d1899 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -351,6 +351,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") set_gencode_flags_for_srcs( SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}" CUDA_ARCHS "${MARLIN_ARCHS}") + set_source_files_properties(${MARLIN_TEMPLATE_KERNEL_SRC} + PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC}) @@ -364,7 +366,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") set_gencode_flags_for_srcs( SRCS "${MARLIN_SRCS}" CUDA_ARCHS "${MARLIN_ARCHS}") + set_source_files_properties("csrc/quantization/gptq_marlin/gptq_marlin.cu" + PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") list(APPEND VLLM_EXT_SRC "${MARLIN_SRCS}") + message(STATUS "Building Marlin kernels for archs: ${MARLIN_ARCHS}") else() message(STATUS "Not building Marlin kernels as no compatible archs found" @@ -854,6 +859,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") set_gencode_flags_for_srcs( SRCS "${MOE_WNAA16_MARLIN_SRC}" CUDA_ARCHS "${MARLIN_MOE_ARCHS}") + set_source_files_properties(${MOE_WNAA16_MARLIN_SRC} + PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") list(APPEND VLLM_MOE_EXT_SRC ${MOE_WNAA16_MARLIN_SRC}) diff --git a/benchmarks/kernels/benchmark_machete.py b/benchmarks/kernels/benchmark_machete.py index f73d0511e01f..975d10f2e92e 100644 --- a/benchmarks/kernels/benchmark_machete.py +++ b/benchmarks/kernels/benchmark_machete.py @@ -236,6 +236,7 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable: a=bt.a, c=None, b_q_weight=w_q, + b_bias=None, b_scales=w_s, global_scale=None, b_zeros=w_zp, diff --git a/csrc/core/scalar_type.hpp b/csrc/core/scalar_type.hpp index d0f85e23609b..68a8750f583b 100644 --- a/csrc/core/scalar_type.hpp +++ b/csrc/core/scalar_type.hpp @@ -321,6 +321,8 @@ static inline constexpr auto kFE3M2f = ScalarType::float_(3, 2, true, ScalarType::NAN_NONE); static inline constexpr auto kFE4M3fn = ScalarType::float_(4, 3, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN); +static inline constexpr auto kFE8M0fnu = + ScalarType(8, 0, false, 0, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN); static inline constexpr auto kFE5M2 = ScalarType::float_IEEE754(5, 2); static inline constexpr auto kFE8M7 = ScalarType::float_IEEE754(8, 7); static inline constexpr auto kFE5M10 = ScalarType::float_IEEE754(5, 10); diff --git a/csrc/moe/marlin_moe_wna16/generate_kernels.py b/csrc/moe/marlin_moe_wna16/generate_kernels.py index 49f33718a21e..698deb107cc0 100644 --- a/csrc/moe/marlin_moe_wna16/generate_kernels.py +++ b/csrc/moe/marlin_moe_wna16/generate_kernels.py @@ -20,6 +20,7 @@ namespace MARLIN_NAMESPACE_NAME { TEMPLATE = ("template __global__ void Marlin<" "{{scalar_t}}, " "{{w_type_id}}, " + "{{s_type_id}}, " "{{threads}}, " "{{thread_m_blocks}}, " "{{thread_n_blocks}}, " @@ -77,6 +78,7 @@ def generate_new_kernels(): if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]: continue # nvfp4 only supports group_size == 16 + # mxfp4 only supports group_size == 32 if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]: continue # other quantization methods don't support group_size = 16 @@ -89,9 +91,22 @@ def generate_new_kernels(): c_dtype = "half" if dtype == "fp16" else "nv_bfloat16" + if scalar_type == "vllm::kFE2M1f" and group_blocks == 1: + s_type = "vllm::kFE4M3fn" + elif scalar_type == "vllm::kFE2M1f" and group_blocks == 2: + s_type = "vllm::kFE8M0fnu" + if dtype == "fp16": + # we cannot safely dequantize e8m0 to fp16, so skip this + continue + 
elif dtype == "fp16": + s_type = "vllm::kFloat16" + elif dtype == "bf16": + s_type = "vllm::kBFloat16" + template_str = jinja2.Template(TEMPLATE).render( scalar_t=c_dtype, w_type_id=scalar_type + ".id()", + s_type_id=s_type + ".id()", threads=threads, thread_m_blocks=max(m_blocks, 1), thread_n_blocks=n_blocks, diff --git a/csrc/moe/marlin_moe_wna16/kernel.h b/csrc/moe/marlin_moe_wna16/kernel.h index 537282aba8c8..6190f7ee21ec 100644 --- a/csrc/moe/marlin_moe_wna16/kernel.h +++ b/csrc/moe/marlin_moe_wna16/kernel.h @@ -7,23 +7,25 @@ #include "quantization/gptq_marlin/marlin_dtypes.cuh" #include "core/scalar_type.hpp" -#define MARLIN_KERNEL_PARAMS \ - const int4 *__restrict__ A, const int4 *__restrict__ B, \ - int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ - const int4 *__restrict__ scales_ptr, \ - const uint16_t *__restrict__ scale2_ptr, \ - const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \ - const int32_t *__restrict__ sorted_token_ids_ptr, \ - const int32_t *__restrict__ expert_ids_ptr, \ - const int32_t *__restrict__ num_tokens_past_padded_ptr, \ - const float *__restrict__ topk_weights_ptr, int top_k, \ - bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \ - int prob_n, int prob_k, int *locks, bool use_atomic_add, \ +#define MARLIN_KERNEL_PARAMS \ + const int4 *__restrict__ A, const int4 *__restrict__ B, \ + int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ + const int4 *__restrict__ b_bias_ptr, \ + const int4 *__restrict__ scales_ptr, \ + const uint16_t *__restrict__ scale2_ptr, \ + const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \ + const int32_t *__restrict__ sorted_token_ids_ptr, \ + const int32_t *__restrict__ expert_ids_ptr, \ + const int32_t *__restrict__ num_tokens_past_padded_ptr, \ + const float *__restrict__ topk_weights_ptr, int top_k, \ + bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \ + int prob_n, int prob_k, int *locks, bool has_bias, bool use_atomic_add, \ bool use_fp32_reduce, int max_shared_mem namespace MARLIN_NAMESPACE_NAME { template ::value) { + static_assert(s_type == vllm::kBFloat16); + } else if constexpr (std::is_same::value) { + static_assert(s_type == vllm::kFloat16); + } + constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8; constexpr bool is_int_type = w_type == vllm::kU4 || w_type == vllm::kU8 || w_type == vllm::kU4B8 || w_type == vllm::kU8B128; // see comments of dequant.h for more details constexpr bool dequant_skip_flop = - !is_int_type || + w_type == vllm::kFE4M3fn || + w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn || has_zp && !is_zp_float && !std::is_same::value || has_zp && !is_zp_float && !(w_type == vllm::kU8); @@ -365,6 +379,7 @@ __global__ void Marlin( const int zp_expert_stride = is_zp_float ? prob_n * prob_k / group_size / 8 : prob_n * prob_k / group_size / (pack_factor * 4); + const int b_bias_expert_stride = prob_n / 8; // parallel: num valid moe blocks int num_tokens_past_padded = num_tokens_past_padded_ptr[0]; @@ -475,7 +490,7 @@ __global__ void Marlin( for (int i = 0; i < 4; i++) { int idx = tid4 * 4 + i; idx = idx < block_num_valid_tokens ? 
idx : 0; - if constexpr (w_type == vllm::kFE2M1f) { + if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { sh_block_topk_weights[idx] = __hmul2( global_scale, Dtype::num2num2(Dtype::float2num( topk_weights_ptr[sh_block_sorted_ids[idx]]))); @@ -513,7 +528,7 @@ __global__ void Marlin( expert_id = expert_ids_ptr[block_id]; } - if constexpr (w_type == vllm::kFE2M1f) { + if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { uint16_t val = scale2_ptr[expert_id]; global_scale = Dtype::num2num2(*reinterpret_cast(&val)); } @@ -526,6 +541,9 @@ __global__ void Marlin( if constexpr (has_act_order) { g_idx += (expert_id - old_expert_id) * prob_k; } + if (has_bias) { + b_bias_ptr += (expert_id - old_expert_id) * b_bias_expert_stride; + } read_moe_block_data(block_id); }; @@ -721,7 +739,7 @@ __global__ void Marlin( s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; - s_sh_rd = s_sh_rd * 2 + warp_row % 2; + s_sh_rd = s_sh_rd * 2 + (warp_row / group_blocks) % 2; } else if constexpr (group_blocks != -1) s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + @@ -734,6 +752,18 @@ __global__ void Marlin( s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4; + int bias_sh_rd; + if constexpr (m_block_size_8) { + bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 8; + } else { + bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) % 4; + } + + int bias_sh_wr = threadIdx.x; + int bias_gl_rd = (thread_n_blocks * 16 / 8) * slice_col + threadIdx.x; + // Zero-points have the same read layout as the scales // (without column-wise case) constexpr int num_col_threads = 8; @@ -793,7 +823,19 @@ __global__ void Marlin( constexpr int sh_b_size = stages * b_sh_stage; int4* sh_b = sh_new; int4* sh_red = sh_new; - int4* sh_g_idx = sh_b + (sh_red_size > sh_b_size ? sh_red_size : sh_b_size); + + constexpr int sh_size_b_red_min = + (sh_red_size < sh_b_size ? sh_red_size : sh_b_size); + constexpr int sh_size_b_red_max = + (sh_red_size > sh_b_size ? sh_red_size : sh_b_size); + constexpr int sh_bias_size = (thread_n_blocks * 16 / 8); + constexpr int sh_b_red_bias_size = + sh_size_b_red_max > (sh_size_b_red_min + sh_bias_size) + ? sh_size_b_red_max + : (sh_size_b_red_min + sh_bias_size); + + int4* sh_bias = sh_new + sh_size_b_red_min; + int4* sh_g_idx = sh_new + sh_b_red_bias_size; int4* sh_zp = sh_g_idx + (stages * g_idx_stage); constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride) : (stages * s_sh_stage); @@ -803,9 +845,9 @@ __global__ void Marlin( static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <= stages * b_sh_stage); int4* sh_a = sh_s + sh_s_size; - constexpr int shm_size_used = - moe_block_size + stages * (g_idx_stage + zp_sh_stage) + sh_s_size + - (sh_red_size > sh_b_size ? 
sh_red_size : sh_b_size); + constexpr int shm_size_used = moe_block_size + + stages * (g_idx_stage + zp_sh_stage) + + sh_s_size + sh_b_red_bias_size; // all remaining shared memory is used to cache A (input) // sh_a_max_row is at least ` stages * 16 * thread_m_blocks ` @@ -816,7 +858,8 @@ __global__ void Marlin( FragA frag_a[2][thread_m_blocks]; I4 frag_b_quant[2][b_thread_vecs]; FragC frag_c[thread_m_blocks][4][2]; - FragS frag_s[2][4]; // No act-order + FragS frag_s[2][4]; // No act-order + FragS frag_bias[2][4]; FragS act_frag_s[2][4][4]; // For act-order int frag_qzp[2][num_ints_per_thread]; // Zero-points FragZP frag_zp; // Zero-points in fp16 @@ -1065,10 +1108,15 @@ __global__ void Marlin( if constexpr (w_type_id != vllm::kFE2M1f.id()) { reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; - } else { + } else if constexpr (group_blocks == 1 || thread_k_blocks > 4) { reinterpret_cast(&frag_s[k % 2])[0] = reinterpret_cast( sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)]; + } else { + reinterpret_cast(&frag_s[k % 2])[0] = + reinterpret_cast( + sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride) + + k % 2]; } } } @@ -1281,9 +1329,9 @@ __global__ void Marlin( int s_quant_0 = reinterpret_cast(frag_s[k2])[0]; int s_quant_1 = reinterpret_cast(frag_s[k2])[1]; - dequant_fp8_scales(s_quant_0, - reinterpret_cast(&frag_s[k2])); - dequant_fp8_scales( + dequant_fp8_scales( + s_quant_0, reinterpret_cast(&frag_s[k2])); + dequant_fp8_scales( s_quant_1, reinterpret_cast(&frag_s[k2]) + 2); } @@ -1566,7 +1614,7 @@ __global__ void Marlin( // Write out the reduce final result in the correct layout. We only actually // reshuffle matrix fragments in this step, the reduction above is performed // in fragment layout. - auto write_result = [&]() { + auto write_result = [&](bool last) { int c_gl_stride = prob_n / 8; constexpr int c_sh_stride = 2 * thread_n_blocks + 1; int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); @@ -1592,7 +1640,7 @@ __global__ void Marlin( // We first reorder in shared memory to guarantee the most efficient final // global write patterns - auto write = [&](int idx, float c0, float c1, FragS& s) { + auto write = [&](int idx, float c0, float c1, FragS& s, FragS& b_bias) { scalar_t2 res = Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); @@ -1601,14 +1649,27 @@ __global__ void Marlin( if constexpr (!has_act_order && group_blocks == -1 && w_type.size_bits() == 4 && (has_zp && dequant_skip_flop || !has_zp)) { - res = __hmul2(res, s[0]); + scalar_t2 tmp_scale = s[0]; + if constexpr (m_block_size_8) { + tmp_scale = Dtype::num2num2( + reinterpret_cast(&s[0])[(threadIdx.x % 8) / 4]); + } + res = __hmul2(res, tmp_scale); } - if constexpr (w_type == vllm::kFE2M1f) { + if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { if (!mul_topk_weights) { res = __hmul2(res, global_scale); } } + if (has_bias && last) { + scalar_t2 tmp_bias = b_bias[0]; + if constexpr (m_block_size_8) { + tmp_bias = Dtype::num2num2( + reinterpret_cast(&b_bias[0])[(threadIdx.x % 8) / 4]); + } + res = __hadd2(res, tmp_bias); + } if constexpr (m_block_size_8) { ((scalar_t*)sh_red)[idx] = res.x; @@ -1626,19 +1687,25 @@ __global__ void Marlin( if constexpr (m_block_size_8) { int wr = c_sh_wr + 16 * j; write(wr, frag_c[i][j][0][0], frag_c[i][j][0][1], - frag_s[j / 2][2 * (j % 2) + 0]); + frag_s[j / 2][2 * (j % 2) + 0], + frag_bias[j / 2][2 * (j % 2) + 0]); write(wr + 8, frag_c[i][j][0][2], frag_c[i][j][0][3], - frag_s[j / 2][2 * (j % 2) + 1]); 
+ frag_s[j / 2][2 * (j % 2) + 1], + frag_bias[j / 2][2 * (j % 2) + 1]); } else { int wr = c_sh_wr + 8 * j; write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], - frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0], + frag_bias[j / 2][2 * (j % 2) + 0]); write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], - frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); + frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0], + frag_bias[j / 2][2 * (j % 2) + 0]); write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], - frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); + frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1], + frag_bias[j / 2][2 * (j % 2) + 1]); write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], - frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); + frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1], + frag_bias[j / 2][2 * (j % 2) + 1]); } } c_sh_wr += 16 * (4 * c_sh_stride); @@ -1805,6 +1872,14 @@ __global__ void Marlin( } thread_block_reduce(); + + if (has_bias && last) { + __syncthreads(); + cp_async4_pred(&sh_bias[bias_sh_wr], &b_bias_ptr[bias_gl_rd], + threadIdx.x < 16 * thread_n_blocks / 8); + cp_async_fence(); + } + if constexpr (!has_act_order && group_blocks == -1 && (has_zp && dequant_skip_flop || !has_zp)) { if (w_type.size_bits() == 8 || (last || use_atomic_add)) { @@ -1867,11 +1942,20 @@ __global__ void Marlin( } barrier_release(&locks[locks_off], last); } + + if (has_bias && last) { + cp_async_wait<0>(); + __syncthreads(); + reinterpret_cast(&frag_bias)[0] = sh_bias[bias_sh_rd]; + reinterpret_cast(&frag_bias)[1] = sh_bias[bias_sh_rd + 4]; + __syncthreads(); + } + if (use_atomic_add && slice_count > 1 && slice_idx != 0) wait_negative_and_add(&locks[locks_off]); if (last || use_atomic_add) // only the last block in a slice actually writes the result - write_result(); + write_result(last); int old_slice_row = slice_row; slice_row = 0; slice_col_par++; @@ -1904,6 +1988,7 @@ __global__ void Marlin( for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; } + bias_gl_rd = (thread_n_blocks * 16 / 8) * slice_col + threadIdx.x; // Update slice k/n for scales loading if constexpr (has_act_order) { slice_k_start = tb_k * slice_row; diff --git a/csrc/moe/marlin_moe_wna16/ops.cu b/csrc/moe/marlin_moe_wna16/ops.cu index 2cff04f699b0..601e2aa6f991 100644 --- a/csrc/moe/marlin_moe_wna16/ops.cu +++ b/csrc/moe/marlin_moe_wna16/ops.cu @@ -51,8 +51,9 @@ __global__ void permute_cols_kernel( } // namespace marlin torch::Tensor moe_wna16_marlin_gemm( - torch::Tensor& a, std::optional const& c_or_none, - torch::Tensor& b_q_weight, torch::Tensor& b_scales, + torch::Tensor& a, std::optional c_or_none, + torch::Tensor& b_q_weight, + std::optional const& b_bias_or_none, torch::Tensor& b_scales, std::optional const& b_zeros_or_none, std::optional const& g_idx_or_none, std::optional const& perm_or_none, torch::Tensor& workspace, @@ -212,7 +213,7 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8, // Get B size int tb_k = th_config.thread_k; int tb_n = th_config.thread_n; - int tb_m = thread_m_blocks * (m_block_size_8 ? 
8 : 16); + int tb_m = thread_m_blocks * 16; // shm size for block_sorted_ids/rd_block_sorted_ids/block_topk_weights // both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32) @@ -220,6 +221,11 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8, int sh_a_size = pipe_stages * (tb_m * tb_k) * 2; int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4; int sh_red_size = tb_m * (tb_n + 8) * 2; + int sh_bias_size = tb_n * 2; + int tmp_size = + (sh_b_size > sh_red_size ? sh_red_size : sh_b_size) + sh_bias_size; + tmp_size = max(max(sh_b_size, sh_red_size), tmp_size); + int sh_s_size = get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, group_size, has_act_order, is_k_full); @@ -234,8 +240,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8, sh_zp_size = sh_s_size / 2; } - int total_size = max(sh_b_size, sh_red_size) + sh_a_size + sh_s_size + - sh_zp_size + sh_g_idx_size + sh_block_meta_size; + int total_size = tmp_size + sh_a_size + sh_s_size + sh_zp_size + + sh_g_idx_size + sh_block_meta_size; return total_size; } @@ -270,20 +276,25 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8, int cache_size = get_kernel_cache_size( th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float); - return cache_size <= max_shared_mem; + return cache_size + 512 <= max_shared_mem; } - #define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ - M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \ - else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ - thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - m_block_size_8 == M_BLOCK_SIZE_8 && \ - group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \ - is_zp_float == IS_ZP_FLOAT) { \ - kernel = Marlin; \ + #define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ + M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \ + else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ + thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + m_block_size_8 == M_BLOCK_SIZE_8 && \ + group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \ + is_zp_float == IS_ZP_FLOAT) { \ + constexpr auto S_TYPE = \ + W_TYPE == vllm::kFE2M1f \ + ? (GROUP_BLOCKS == 1 ? vllm::kFE4M3fn : vllm::kFE8M0fnu) \ + : (std::is_same::value ? 
vllm::kFloat16 \ + : vllm::kBFloat16); \ + kernel = Marlin; \ } // COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false) @@ -335,31 +346,45 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8, _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ - \ _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) - #define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) - - #define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) - - #define FP4_GET_IF(W_TYPE) \ - FP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ - FP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ - FP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ - FP4_GET_IF_M234(W_TYPE, 8, 4, 128) - #define BIGGROUP_GET_IF(W_TYPE) \ BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \ BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \ BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \ BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) + #define NVFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) + + #define NVFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) + + #define NVFP4_GET_IF(W_TYPE) \ + NVFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ + NVFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ + NVFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ + NVFP4_GET_IF_M234(W_TYPE, 8, 4, 128) + + #define MXFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) + + #define MXFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) + + #define MXFP4_GET_IF(W_TYPE) \ + MXFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ + MXFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ + MXFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ + MXFP4_GET_IF_M234(W_TYPE, 8, 4, 128) + // We currently have 4-bit models only with group_blocks == 4 #define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \ @@ -408,12 +433,17 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type, COMMON_GET_IF(vllm::kU4B8) COMMON_GET_IF(vllm::kU8B128) - BIGGROUP_GET_IF(vllm::kFE4M3fn) + NVFP4_GET_IF(vllm::kFE2M1f) - FP4_GET_IF(vllm::kFE2M1f) + BIGGROUP_GET_IF(vllm::kFE4M3fn) ACT_GET_IF(vllm::kU4B8) ACT_GET_IF(vllm::kU8B128) + if (std::is_same::value) { + if (false) { + } + MXFP4_GET_IF(vllm::kFE2M1f) + } return kernel; } @@ -482,16 +512,16 @@ exec_config_t determine_exec_config(const 
vllm::ScalarType& q_type, int prob_m, } template -void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, - void* s2, void* zp, void* g_idx, void* perm, void* a_tmp, - void* sorted_token_ids, void* expert_ids, +void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, + void* s, void* s2, void* zp, void* g_idx, void* perm, + void* a_tmp, void* sorted_token_ids, void* expert_ids, void* num_tokens_past_padded, void* topk_weights, int moe_block_size, int top_k, bool mul_topk_weights, bool is_ep, int prob_m, int prob_n, int prob_k, void* workspace, - vllm::ScalarType const& q_type, bool has_act_order, - bool is_k_full, bool has_zp, int num_groups, int group_size, - int dev, cudaStream_t stream, int thread_k, int thread_n, - int sms, bool use_atomic_add, bool use_fp32_reduce, + vllm::ScalarType const& q_type, bool has_bias, + bool has_act_order, bool is_k_full, bool has_zp, int num_groups, + int group_size, int dev, cudaStream_t stream, int thread_k, + int thread_n, int sms, bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) { int thread_m_blocks = div_ceil(moe_block_size, 16); bool m_block_size_8 = moe_block_size == 8; @@ -538,6 +568,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, const int4* B_ptr = (const int4*)B; int4* C_ptr = (int4*)C; int4* C_tmp_ptr = (int4*)C_tmp; + const int4* bias_ptr = (const int4*)b_bias; const int4* s_ptr = (const int4*)s; const uint16_t* s2_ptr = (const uint16_t*)s2; const int4* zp_ptr = (const int4*)zp; @@ -648,10 +679,10 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, // avoid ">>>" being formatted to "> > >" // clang-format off kernel<<>>( - A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr, + A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr, sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr, topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m, - prob_n, prob_k, locks, use_atomic_add, use_fp32_reduce, max_shared_mem); + prob_n, prob_k, locks, has_bias, use_atomic_add, use_fp32_reduce, max_shared_mem); // clang-format on } @@ -659,7 +690,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, torch::Tensor moe_wna16_marlin_gemm( torch::Tensor& a, std::optional const& c_or_none, - torch::Tensor& b_q_weight, torch::Tensor& b_scales, + torch::Tensor& b_q_weight, + std::optional const& b_bias_or_none, torch::Tensor& b_scales, std::optional const& global_scale_or_none, std::optional const& b_zeros_or_none, std::optional const& g_idx_or_none, @@ -766,7 +798,6 @@ torch::Tensor moe_wna16_marlin_gemm( num_groups = b_scales.size(1); torch::Tensor g_idx, perm, a_tmp; - ; if (g_idx_or_none.has_value() && perm_or_none.has_value()) { g_idx = g_idx_or_none.value(); perm = perm_or_none.value(); @@ -815,12 +846,24 @@ torch::Tensor moe_wna16_marlin_gemm( torch::Tensor global_scale; if (global_scale_or_none.has_value()) { global_scale = global_scale_or_none.value(); - TORCH_CHECK(b_q_type == vllm::kFE2M1f, - "global_scale can only be used for float4_e2m1f."); + TORCH_CHECK(b_q_type == vllm::kFE2M1f && group_size == 16, + "global_scale can only be used for nvfp4 format."); } else { global_scale = torch::empty({0}, options); - TORCH_CHECK(!(b_q_type == vllm::kFE2M1f), - "the global_scale parameter must be passed for float4_e2m1f."); + TORCH_CHECK(!(b_q_type == vllm::kFE2M1f && group_size == 16), + "the global_scale parameter must be passed for nvfp4 format."); + } + + 
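+  // Optional per-expert bias: when supplied, it must be a contiguous CUDA
+  // tensor whose last dimension equals size_n (validated below); the kernel
+  // consumes one row of bias per expert.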
bool has_bias = b_bias_or_none.has_value(); + torch::Tensor b_bias; + if (has_bias) { + b_bias = b_bias_or_none.value(); + TORCH_CHECK(b_bias.device().is_cuda(), "b_bias is not on GPU"); + TORCH_CHECK(b_bias.is_contiguous(), "b_bias is not contiguous"); + TORCH_CHECK(b_bias.size(1) == size_n, "b_bias.size(0) != size_n"); + TORCH_CHECK(b_bias.stride(1) == 1, "b_bias.stride(1) != 1"); + } else { + b_bias = torch::empty({0}, options); } torch::Tensor b_zeros; @@ -832,7 +875,6 @@ torch::Tensor moe_wna16_marlin_gemm( b_zeros = torch::empty({0}, options); } bool has_zp = b_zeros.size(-1) > 0; - if (has_zp) { TORCH_CHECK( b_q_type == vllm::kU4 || b_q_type == vllm::kU8, @@ -890,41 +932,58 @@ torch::Tensor moe_wna16_marlin_gemm( if (a.scalar_type() == at::ScalarType::Half) { void* scales_ptr; if (b_q_type == vllm::kFE2M1f) { - scales_ptr = b_scales.data_ptr(); + if (group_size == 16) + scales_ptr = b_scales.data_ptr(); + else if (group_size == 32) + scales_ptr = b_scales.data_ptr(); + else + TORCH_CHECK(false, + "float4_e2m1f only supports group_size == 16 (NVFP4) ", + "and group_size == 32 (MXFP4)"); } else { scales_ptr = b_scales.data_ptr(); } MARLIN_NAMESPACE_NAME::marlin_mm( a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), - c_tmp.data_ptr(), scales_ptr, global_scale.data_ptr(), - b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), - a_tmp.data_ptr(), sorted_token_ids.data_ptr(), - expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(), - topk_weights.data_ptr(), moe_block_size, top_k, mul_topk_weights, is_ep, - size_m, size_n, size_k, workspace.data_ptr(), b_q_type, has_act_order, - is_k_full, has_zp, num_groups, group_size, dev, + c_tmp.data_ptr(), b_bias.data_ptr(), scales_ptr, + global_scale.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), + perm.data_ptr(), a_tmp.data_ptr(), + sorted_token_ids.data_ptr(), expert_ids.data_ptr(), + num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(), + moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k, + workspace.data_ptr(), b_q_type, has_bias, has_act_order, is_k_full, + has_zp, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, use_atomic_add, use_fp32_reduce, is_zp_float); } else if (a.scalar_type() == at::ScalarType::BFloat16) { void* scales_ptr; if (b_q_type == vllm::kFE2M1f) { - scales_ptr = b_scales.data_ptr(); + if (group_size == 16) + scales_ptr = b_scales.data_ptr(); + else if (group_size == 32) + scales_ptr = b_scales.data_ptr(); + else + TORCH_CHECK(false, + "float4_e2m1f only supports group_size == 16 (NVFP4) ", + "and group_size == 32 (MXFP4)"); } else { scales_ptr = b_scales.data_ptr(); } MARLIN_NAMESPACE_NAME::marlin_mm( a.data_ptr(), b_q_weight.data_ptr(), - c.data_ptr(), c_tmp.data_ptr(), scales_ptr, + c.data_ptr(), c_tmp.data_ptr(), + b_bias.data_ptr(), scales_ptr, global_scale.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), sorted_token_ids.data_ptr(), expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(), moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k, - workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp, - num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), - thread_k, thread_n, sms, use_atomic_add, use_fp32_reduce, is_zp_float); + workspace.data_ptr(), b_q_type, has_bias, has_act_order, is_k_full, + has_zp, num_groups, group_size, dev, + at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, + use_atomic_add, use_fp32_reduce, 
is_zp_float);
  } else {
    TORCH_CHECK(false,
                "moe_wna16_marlin_gemm only supports bfloat16 and float16");
diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp
index d96e082f6ef1..7e49f68f6243 100644
--- a/csrc/moe/torch_bindings.cpp
+++ b/csrc/moe/torch_bindings.cpp
@@ -35,7 +35,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
   m.def(
       "moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
-      "Tensor! b_q_weight, Tensor! b_scales, Tensor? global_scale, Tensor? "
+      "Tensor! b_q_weight, Tensor? b_bias_or_none,"
+      "Tensor! b_scales, Tensor? global_scale, Tensor? "
       "b_zeros_or_none,"
       "Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace,"
       "Tensor sorted_token_ids,"
diff --git a/csrc/quantization/gptq_marlin/dequant.h b/csrc/quantization/gptq_marlin/dequant.h
index ae0d6c0f2002..e8b0c302b202 100644
--- a/csrc/quantization/gptq_marlin/dequant.h
+++ b/csrc/quantization/gptq_marlin/dequant.h
@@ -470,11 +470,12 @@ __device__ inline void dequant(
   frag_b[0] = __hmul2(frag_b[0], bias_reg);
 }
 
-template <typename scalar_t2>
+template <typename scalar_t2, vllm::ScalarTypeId s_type_id>
 __device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b);
 
 template <>
-__device__ inline void dequant_fp8_scales<half2>(int q, half2* frag_b) {
+__device__ inline void dequant_fp8_scales<half2, vllm::kFloat16.id()>(
+    int q, half2* frag_b) {
   int Out1 = (q & 0xFF00FF00) >> 1;
   ;
   q <<= 8;
@@ -486,8 +487,8 @@ __device__ inline void dequant_fp8_scales<half2>(int q, half2* frag_b) {
 };
 
 template <>
-__device__ inline void dequant_fp8_scales<nv_bfloat162>(int q,
-                                                        nv_bfloat162* frag_b) {
+__device__ inline void dequant_fp8_scales<nv_bfloat162, vllm::kBFloat16.id()>(
+    int q, nv_bfloat162* frag_b) {
   constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8;
   constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT;
   constexpr int MASK = 0x7F007F00;
@@ -502,6 +503,20 @@ __device__ inline void dequant_fp8_scales<nv_bfloat162>(int q,
   frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
 }
 
+template <>
+__device__ inline void dequant_fp8_scales<nv_bfloat162, vllm::kFE8M0fnu.id()>(
+    int q, nv_bfloat162* frag_b) {
+  // In this conversion, 2 ** -127 in FP8E8M0 would become 0 in BF16,
+  // but we assume that such an extreme value would not occur in real models.
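+  // Bit-level sketch of the conversion below: BF16 shares E8M0's 8-bit
+  // exponent and bias of 127, so an E8M0 byte moved into bits [14:7] of a
+  // 16-bit lane (sign and mantissa left as zero) already encodes 2^(e - 127).
+  // `(q & 0xFF00FF00) >> 1` drops the two odd scale bytes of the packed int
+  // into the exponent fields of one nv_bfloat162, and `(q << 7) & 0x7F807F80`
+  // does the same for the two even bytes.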
+ int Out1 = (q & 0xFF00FF00) >> 1; + q <<= 7; + int Out2 = q & 0x7F807F80; + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + #endif } // namespace MARLIN_NAMESPACE_NAME diff --git a/csrc/quantization/gptq_marlin/generate_kernels.py b/csrc/quantization/gptq_marlin/generate_kernels.py index 18fb6c1a81f8..7576e0548abe 100644 --- a/csrc/quantization/gptq_marlin/generate_kernels.py +++ b/csrc/quantization/gptq_marlin/generate_kernels.py @@ -20,6 +20,7 @@ namespace MARLIN_NAMESPACE_NAME { TEMPLATE = ("template __global__ void Marlin<" "{{scalar_t}}, " "{{w_type_id}}, " + "{{s_type_id}}, " "{{threads}}, " "{{thread_m_blocks}}, " "{{thread_n_blocks}}, " @@ -78,7 +79,8 @@ def generate_new_kernels(): if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]: continue # nvfp4 only supports group_size == 16 - if scalar_type == "vllm::kFE2M1f" and group_blocks != 1: + # mxfp4 only supports group_size == 32 + if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]: continue # other quantization methods don't support group_size = 16 if scalar_type != "vllm::kFE2M1f" and group_blocks == 1: @@ -97,10 +99,23 @@ def generate_new_kernels(): # 4bit quantization and fp16 is_zp_float_list.append(True) + if scalar_type == "vllm::kFE2M1f" and group_blocks == 1: + s_type = "vllm::kFE4M3fn" + elif scalar_type == "vllm::kFE2M1f" and group_blocks == 2: + s_type = "vllm::kFE8M0fnu" + if dtype == "fp16": + # we cannot safely dequantize e8m0 to fp16, so skip this + continue + elif dtype == "fp16": + s_type = "vllm::kFloat16" + elif dtype == "bf16": + s_type = "vllm::kBFloat16" + for is_zp_float in is_zp_float_list: template_str = jinja2.Template(TEMPLATE).render( scalar_t=c_dtype, w_type_id=scalar_type + ".id()", + s_type_id=s_type + ".id()", threads=threads, thread_m_blocks=max(m_blocks, 1), thread_n_blocks=n_blocks, diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index 4a242f2050d5..cc30abcf0080 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -48,7 +48,8 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, torch::Tensor gptq_marlin_gemm( torch::Tensor& a, std::optional c_or_none, - torch::Tensor& b_q_weight, torch::Tensor& b_scales, + torch::Tensor& b_q_weight, + std::optional const& b_bias_or_none, torch::Tensor& b_scales, std::optional const& b_zeros_or_none, std::optional const& g_idx_or_none, std::optional const& perm_or_none, torch::Tensor& workspace, @@ -187,7 +188,12 @@ int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks, int tb_m = thread_m_blocks * 16; int sh_a_size = pipe_stages * (tb_m * tb_k) * 2; int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4; - int sh_red_size = tb_m * (tb_n + 8); + int sh_red_size = tb_m * (tb_n + 8) * 2; + int sh_bias_size = tb_n * 2; + int tmp_size = + (sh_b_size > sh_red_size ? 
sh_red_size : sh_b_size) + sh_bias_size; + tmp_size = max(max(sh_b_size, sh_red_size), tmp_size); + int sh_s_size = get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, group_size, has_act_order, is_k_full); @@ -202,8 +208,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks, sh_zp_size = sh_s_size / 2; } - int total_size = max(sh_b_size, sh_red_size) + sh_a_size + sh_s_size + - sh_zp_size + sh_g_idx_size; + int total_size = + tmp_size + sh_a_size + sh_s_size + sh_zp_size + sh_g_idx_size; return total_size; } @@ -237,20 +243,25 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks, int cache_size = get_kernel_cache_size( th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float); - return cache_size <= max_shared_mem; + return cache_size + 512 <= max_shared_mem; } - #define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ - M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \ - else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ - thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - m_block_size_8 == M_BLOCK_SIZE_8 && \ - group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \ - is_zp_float == IS_ZP_FLOAT) { \ - kernel = Marlin; \ + #define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ + M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \ + else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ + thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + m_block_size_8 == M_BLOCK_SIZE_8 && \ + group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \ + is_zp_float == IS_ZP_FLOAT) { \ + constexpr auto S_TYPE = \ + W_TYPE == vllm::kFE2M1f \ + ? (GROUP_BLOCKS == 1 ? vllm::kFE4M3fn : vllm::kFE8M0fnu) \ + : (std::is_same::value ? 
vllm::kFloat16 \ + : vllm::kBFloat16); \ + kernel = Marlin; \ } // COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false) @@ -315,22 +326,39 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks, BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) \ BIGGROUP_GET_IF_M234(W_TYPE, 4, 8, 128) - #define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + #define NVFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) - #define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + #define NVFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) - #define FP4_GET_IF(W_TYPE) \ - FP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ - FP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ - FP4_GET_IF_M1(W_TYPE, 4, 8, 128) \ - FP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ - FP4_GET_IF_M234(W_TYPE, 8, 4, 128) \ - FP4_GET_IF_M234(W_TYPE, 4, 8, 128) + #define NVFP4_GET_IF(W_TYPE) \ + NVFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ + NVFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ + NVFP4_GET_IF_M1(W_TYPE, 4, 8, 128) \ + NVFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ + NVFP4_GET_IF_M234(W_TYPE, 8, 4, 128) \ + NVFP4_GET_IF_M234(W_TYPE, 4, 8, 128) + + #define MXFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) + + #define MXFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) + + #define MXFP4_GET_IF(W_TYPE) \ + MXFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ + MXFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ + MXFP4_GET_IF_M1(W_TYPE, 4, 8, 128) \ + MXFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ + MXFP4_GET_IF_M234(W_TYPE, 8, 4, 128) \ + MXFP4_GET_IF_M234(W_TYPE, 4, 8, 128) // We currently have 4-bit models only with group_blocks == 4 #define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ @@ -384,7 +412,7 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type, COMMON_GET_IF(vllm::kU4B8) COMMON_GET_IF(vllm::kU8B128) - FP4_GET_IF(vllm::kFE2M1f) + NVFP4_GET_IF(vllm::kFE2M1f) BIGGROUP_GET_IF(vllm::kFE4M3fn) @@ -396,6 +424,11 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type, } FZP_GET_IF(vllm::kU4) } + if (std::is_same::value) { + if (false) { + } + MXFP4_GET_IF(vllm::kFE2M1f) + } return kernel; } @@ -453,12 +486,12 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m, } template -void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, - void* s2, void* zp, void* g_idx, void* perm, void* a_tmp, - int prob_m, int prob_n, int prob_k, int lda, void* workspace, - vllm::ScalarType const& q_type, bool has_act_order, - bool is_k_full, bool has_zp, int num_groups, int group_size, - int dev, cudaStream_t stream, int thread_k_init, +void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, + void* s, void* s2, void* zp, void* g_idx, void* perm, + void* a_tmp, int prob_m, int prob_n, int prob_k, int lda, + void* workspace, vllm::ScalarType const& q_type, 
bool has_bias, + bool has_act_order, bool is_k_full, bool has_zp, int num_groups, + int group_size, int dev, cudaStream_t stream, int thread_k_init, int thread_n_init, int sms, bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) { if (has_zp) { @@ -503,6 +536,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, const int4* B_ptr = (const int4*)B; int4* C_ptr = (int4*)C; int4* C_tmp_ptr = (int4*)C_tmp; + const int4* bias_ptr = (const int4*)b_bias; const int4* s_ptr = (const int4*)s; const uint16_t* s2_ptr = (const uint16_t*)s2; const int4* zp_ptr = (const int4*)zp; @@ -623,8 +657,9 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, // avoid ">>>" being formatted to "> > >" // clang-format off kernel<<>>( - A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr, num_groups, - prob_m_split, prob_n, prob_k, lda, locks, part_use_atomic_add, + A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, s_ptr, s2_ptr, zp_ptr, + g_idx_ptr, num_groups, + prob_m_split, prob_n, prob_k, lda, locks, has_bias, part_use_atomic_add, use_fp32_reduce, max_shared_mem_new); // clang-format on @@ -638,7 +673,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, torch::Tensor gptq_marlin_gemm( torch::Tensor& a, std::optional c_or_none, - torch::Tensor& b_q_weight, torch::Tensor& b_scales, + torch::Tensor& b_q_weight, + std::optional const& b_bias_or_none, torch::Tensor& b_scales, std::optional const& global_scale_or_none, std::optional const& b_zeros_or_none, std::optional const& g_idx_or_none, @@ -785,12 +821,24 @@ torch::Tensor gptq_marlin_gemm( torch::Tensor global_scale; if (global_scale_or_none.has_value()) { global_scale = global_scale_or_none.value(); - TORCH_CHECK(b_q_type == vllm::kFE2M1f, - "global_scale can only be used for float4_e2m1f."); + TORCH_CHECK(b_q_type == vllm::kFE2M1f && group_size == 16, + "global_scale can only be used for nvfp4 format."); } else { global_scale = torch::empty({0}, options); - TORCH_CHECK(!(b_q_type == vllm::kFE2M1f), - "the global_scale parameter must be passed for float4_e2m1f."); + TORCH_CHECK(!(b_q_type == vllm::kFE2M1f && group_size == 16), + "the global_scale parameter must be passed for nvfp4 format."); + } + + bool has_bias = b_bias_or_none.has_value(); + torch::Tensor b_bias; + if (has_bias) { + b_bias = b_bias_or_none.value(); + TORCH_CHECK(b_bias.device().is_cuda(), "b_bias is not on GPU"); + TORCH_CHECK(b_bias.is_contiguous(), "b_bias is not contiguous"); + TORCH_CHECK(b_bias.size(0) == size_n, "b_bias.size(0) != size_n"); + TORCH_CHECK(b_bias.stride(0) == 1, "b_bias.stride(0) != 1"); + } else { + b_bias = torch::empty({0}, options); } torch::Tensor b_zeros; @@ -857,34 +905,50 @@ torch::Tensor gptq_marlin_gemm( if (a.scalar_type() == at::ScalarType::Half) { void* scales_ptr; if (b_q_type == vllm::kFE2M1f) { - scales_ptr = b_scales.data_ptr(); + if (group_size == 16) + scales_ptr = b_scales.data_ptr(); + else if (group_size == 32) + scales_ptr = b_scales.data_ptr(); + else + TORCH_CHECK(false, + "float4_e2m1f only supports group_size == 16 (NVFP4) ", + "and group_size == 32 (MXFP4)"); } else { scales_ptr = b_scales.data_ptr(); } marlin::marlin_mm( a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), - c_tmp.data_ptr(), scales_ptr, global_scale.data_ptr(), - b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), - a_tmp.data_ptr(), size_m, size_n, size_k, a.stride(0), - workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp, - num_groups, group_size, dev, 
at::cuda::getCurrentCUDAStream(dev), - thread_k, thread_n, sms, use_atomic_add, use_fp32_reduce, is_zp_float); + c_tmp.data_ptr(), b_bias.data_ptr(), scales_ptr, + global_scale.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), + perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, size_k, + a.stride(0), workspace.data_ptr(), b_q_type, has_bias, has_act_order, + is_k_full, has_zp, num_groups, group_size, dev, + at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, + use_atomic_add, use_fp32_reduce, is_zp_float); } else if (a.scalar_type() == at::ScalarType::BFloat16) { void* scales_ptr; if (b_q_type == vllm::kFE2M1f) { - scales_ptr = b_scales.data_ptr(); + if (group_size == 16) + scales_ptr = b_scales.data_ptr(); + else if (group_size == 32) + scales_ptr = b_scales.data_ptr(); + else + TORCH_CHECK(false, + "float4_e2m1f only supports group_size == 16 (NVFP4) ", + "and group_size == 32 (MXFP4)"); } else { scales_ptr = b_scales.data_ptr(); } marlin::marlin_mm( a.data_ptr(), b_q_weight.data_ptr(), - c.data_ptr(), c_tmp.data_ptr(), scales_ptr, + c.data_ptr(), c_tmp.data_ptr(), + b_bias.data_ptr(), scales_ptr, global_scale.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, size_k, a.stride(0), workspace.data_ptr(), b_q_type, - has_act_order, is_k_full, has_zp, num_groups, group_size, dev, + has_bias, has_act_order, is_k_full, has_zp, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, use_atomic_add, use_fp32_reduce, is_zp_float); } else { diff --git a/csrc/quantization/gptq_marlin/kernel.h b/csrc/quantization/gptq_marlin/kernel.h index f92056589d20..bb454f6aff22 100644 --- a/csrc/quantization/gptq_marlin/kernel.h +++ b/csrc/quantization/gptq_marlin/kernel.h @@ -10,15 +10,18 @@ #define MARLIN_KERNEL_PARAMS \ const int4 *__restrict__ A, const int4 *__restrict__ B, \ int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ + const int4 *__restrict__ b_bias_ptr, \ const int4 *__restrict__ scales_ptr, \ const uint16_t *__restrict__ scale2_ptr, \ const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \ int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \ - bool use_atomic_add, bool use_fp32_reduce, int max_shared_mem + bool has_bias, bool use_atomic_add, bool use_fp32_reduce, \ + int max_shared_mem namespace MARLIN_NAMESPACE_NAME { template ::FragZP; static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); + static constexpr auto s_type = vllm::ScalarType::from_id(s_type_id); + if constexpr (w_type == vllm::kFE2M1f) { + static_assert(s_type == vllm::kFE4M3fn && group_blocks == 1 || + s_type == vllm::kFE8M0fnu && group_blocks == 2); + } else if constexpr (std::is_same::value) { + static_assert(s_type == vllm::kBFloat16); + } else if constexpr (std::is_same::value) { + static_assert(s_type == vllm::kFloat16); + } + constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8; constexpr bool is_int_type = w_type == vllm::kU4 || w_type == vllm::kU8 || w_type == vllm::kU4B8 || w_type == vllm::kU8B128; // see comments of dequant.h for more details constexpr bool dequant_skip_flop = - !is_int_type || + w_type == vllm::kFE4M3fn || + w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn || has_zp && !is_zp_float && !std::is_same::value || has_zp && !is_zp_float && !(w_type == vllm::kU8); scalar_t2 global_scale; - - if constexpr (w_type == vllm::kFE2M1f) { + if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { + // NVFP4 format requires global scale uint16_t val 
= scale2_ptr[0]; global_scale = Dtype::num2num2(*reinterpret_cast(&val)); } @@ -589,7 +604,7 @@ __global__ void Marlin( s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; - s_sh_rd = s_sh_rd * 2 + warp_row % 2; + s_sh_rd = s_sh_rd * 2 + (warp_row / group_blocks) % 2; } else if constexpr (group_blocks != -1) s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + @@ -602,6 +617,18 @@ __global__ void Marlin( s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4; + int bias_sh_rd; + if constexpr (m_block_size_8) { + bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 8; + } else { + bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) % 4; + } + + int bias_sh_wr = threadIdx.x; + int bias_gl_rd = (thread_n_blocks * 16 / 8) * slice_col + threadIdx.x; + // Zero-points have the same read layout as the scales // (without column-wise case) constexpr int num_col_threads = 8; @@ -670,7 +697,19 @@ __global__ void Marlin( constexpr int sh_b_size = stages * b_sh_stage; int4* sh_b = sh; int4* sh_red = sh; - int4* sh_g_idx = sh_b + (sh_red_size > sh_b_size ? sh_red_size : sh_b_size); + + constexpr int sh_size_b_red_min = + (sh_red_size < sh_b_size ? sh_red_size : sh_b_size); + constexpr int sh_size_b_red_max = + (sh_red_size > sh_b_size ? sh_red_size : sh_b_size); + constexpr int sh_bias_size = (thread_n_blocks * 16 / 8); + constexpr int sh_b_red_bias_size = + sh_size_b_red_max > (sh_size_b_red_min + sh_bias_size) + ? sh_size_b_red_max + : (sh_size_b_red_min + sh_bias_size); + + int4* sh_bias = sh + sh_size_b_red_min; + int4* sh_g_idx = sh + sh_b_red_bias_size; int4* sh_zp = sh_g_idx + (stages * g_idx_stage); constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride) : (stages * s_sh_stage); @@ -680,15 +719,13 @@ __global__ void Marlin( static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <= stages * b_sh_stage); int4* sh_a = sh_s + sh_s_size; - // constexpr int shm_size_used = - // stages * (g_idx_stage + zp_sh_stage) + sh_s_size + - // (sh_red_size > sh_b_size ? sh_red_size : sh_b_size); // Register storage for double buffer of shared memory reads. 
FragA frag_a[2][thread_m_blocks]; I4 frag_b_quant[2][b_thread_vecs]; FragC frag_c[thread_m_blocks][4][2]; - FragS frag_s[2][4]; // No act-order + FragS frag_s[2][4]; // No act-order + FragS frag_bias[2][4]; FragS act_frag_s[2][4][4]; // For act-order int frag_qzp[2][num_ints_per_thread]; // Zero-points FragZP frag_zp; // Zero-points in fp16 @@ -923,10 +960,15 @@ __global__ void Marlin( if constexpr (w_type_id != vllm::kFE2M1f.id()) { reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; - } else { + } else if constexpr (group_blocks == 1 || thread_k_blocks > 4) { reinterpret_cast(&frag_s[k % 2])[0] = reinterpret_cast( sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)]; + } else { + reinterpret_cast(&frag_s[k % 2])[0] = + reinterpret_cast( + sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride) + + k % 2]; } } } @@ -1139,9 +1181,9 @@ __global__ void Marlin( int s_quant_0 = reinterpret_cast(frag_s[k2])[0]; int s_quant_1 = reinterpret_cast(frag_s[k2])[1]; - dequant_fp8_scales(s_quant_0, - reinterpret_cast(&frag_s[k2])); - dequant_fp8_scales( + dequant_fp8_scales( + s_quant_0, reinterpret_cast(&frag_s[k2])); + dequant_fp8_scales( s_quant_1, reinterpret_cast(&frag_s[k2]) + 2); } @@ -1411,7 +1453,7 @@ __global__ void Marlin( // Write out the reduce final result in the correct layout. We only actually // reshuffle matrix fragments in this step, the reduction above is performed // in fragment layout. - auto write_result = [&]() { + auto write_result = [&](bool last) { int c_gl_stride = prob_n / 8; constexpr int c_sh_stride = 2 * thread_n_blocks + 1; int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); @@ -1438,7 +1480,7 @@ __global__ void Marlin( int c_gl_wr_end = c_gl_stride * prob_m; // We first reorder in shared memory to guarantee the most efficient final // global write patterns - auto write = [&](int idx, float c0, float c1, FragS& s) { + auto write = [&](int idx, float c0, float c1, FragS& s, FragS& b_bias) { scalar_t2 res = Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); @@ -1447,12 +1489,25 @@ __global__ void Marlin( if constexpr (!has_act_order && group_blocks == -1 && w_type.size_bits() == 4 && (has_zp && dequant_skip_flop || !has_zp)) { - res = __hmul2(res, s[0]); + scalar_t2 tmp_scale = s[0]; + if constexpr (m_block_size_8) { + tmp_scale = Dtype::num2num2( + reinterpret_cast(&s[0])[(threadIdx.x % 8) / 4]); + } + res = __hmul2(res, tmp_scale); } - if constexpr (w_type == vllm::kFE2M1f) { + if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { res = __hmul2(res, global_scale); } + if (has_bias && last) { + scalar_t2 tmp_bias = b_bias[0]; + if constexpr (m_block_size_8) { + tmp_bias = Dtype::num2num2( + reinterpret_cast(&b_bias[0])[(threadIdx.x % 8) / 4]); + } + res = __hadd2(res, tmp_bias); + } if constexpr (m_block_size_8) { ((scalar_t*)sh_red)[idx] = res.x; @@ -1470,19 +1525,25 @@ __global__ void Marlin( if constexpr (m_block_size_8) { int wr = c_sh_wr + 16 * j; write(wr, frag_c[i][j][0][0], frag_c[i][j][0][1], - frag_s[j / 2][2 * (j % 2) + 0]); + frag_s[j / 2][2 * (j % 2) + 0], + frag_bias[j / 2][2 * (j % 2) + 0]); write(wr + 8, frag_c[i][j][0][2], frag_c[i][j][0][3], - frag_s[j / 2][2 * (j % 2) + 1]); + frag_s[j / 2][2 * (j % 2) + 1], + frag_bias[j / 2][2 * (j % 2) + 1]); } else { int wr = c_sh_wr + 8 * j; write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], - frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0], + frag_bias[j / 2][2 * (j % 
2) + 0]); write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], - frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); + frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0], + frag_bias[j / 2][2 * (j % 2) + 0]); write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], - frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); + frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1], + frag_bias[j / 2][2 * (j % 2) + 1]); write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], - frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); + frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1], + frag_bias[j / 2][2 * (j % 2) + 1]); } } c_sh_wr += 16 * (4 * c_sh_stride); @@ -1622,6 +1683,14 @@ __global__ void Marlin( } thread_block_reduce(); + + if (has_bias && last) { + __syncthreads(); + cp_async4_pred(&sh_bias[bias_sh_wr], &b_bias_ptr[bias_gl_rd], + threadIdx.x < 16 * thread_n_blocks / 8); + cp_async_fence(); + } + if constexpr (!has_act_order && group_blocks == -1 && (has_zp && dequant_skip_flop || !has_zp)) { if (w_type.size_bits() == 8 || (last || use_atomic_add)) { @@ -1684,11 +1753,20 @@ __global__ void Marlin( } barrier_release(&locks[locks_off], last); } + + if (has_bias && last) { + cp_async_wait<0>(); + __syncthreads(); + reinterpret_cast(&frag_bias)[0] = sh_bias[bias_sh_rd]; + reinterpret_cast(&frag_bias)[1] = sh_bias[bias_sh_rd + 4]; + __syncthreads(); + } + if (use_atomic_add && slice_count > 1 && slice_idx != 0) wait_negative_and_add(&locks[locks_off]); if (last || use_atomic_add) // only the last block in a slice actually writes the result - write_result(); + write_result(last); slice_row = 0; slice_col_par++; slice_col++; @@ -1706,6 +1784,7 @@ __global__ void Marlin( for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; } + bias_gl_rd = (thread_n_blocks * 16 / 8) * slice_col + threadIdx.x; // Update slice k/n for scales loading if constexpr (has_act_order) { slice_k_start = tb_k * slice_row; diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 85b6abef00b0..8c207be083d8 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -326,6 +326,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // gptq_marlin Optimized Quantized GEMM for GPTQ. ops.def( "gptq_marlin_gemm(Tensor a, Tensor? c_or_none, Tensor b_q_weight, " + "Tensor? b_bias_or_none," "Tensor b_scales, Tensor? global_scale, Tensor? b_zeros_or_none, Tensor? " "g_idx_or_none, Tensor? 
perm_or_none, Tensor workspace, int b_q_type, " "SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, " diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 49c097718e30..b82c74a42ab3 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -24,8 +24,10 @@ from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_topk, modular_triton_fused_moe) from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( fused_moe as iterative_moe) +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + marlin_permute_bias) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( - rand_marlin_weight_fp4_like) + rand_marlin_weight_mxfp4_like, rand_marlin_weight_nvfp4_like) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( marlin_quant_fp8_torch) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( @@ -476,8 +478,11 @@ def marlin_moe_generate_valid_test_cases(): if quant_type == scalar_types.float8_e4m3fn and \ group_size not in [-1, 128]: return False - if quant_type == scalar_types.float4_e2m1f and group_size != 16: - return False + if quant_type == scalar_types.float4_e2m1f: + if group_size not in [16, 32]: + return False + if dtype == torch.float16 and group_size == 32: + return False if quant_type != scalar_types.float4_e2m1f and group_size == 16: return False @@ -520,31 +525,6 @@ def test_fused_marlin_moe( torch.cuda.manual_seed(0) has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8] - if quant_type == scalar_types.float8_e4m3fn: - if group_size not in [-1, 128]: - return - if act_order: - return - - # Filter act_order - if act_order: - if quant_type == scalar_types.float8_e4m3fn: - return - if group_size == -1: - return - if group_size in (k, n): - return - if has_zp: - return - else: - if not is_k_full: - return - - if quant_type == scalar_types.float4_e2m1f and group_size != 16: - return - if quant_type != scalar_types.float4_e2m1f and group_size == 16: - return - a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20 @@ -569,13 +549,19 @@ def test_fused_marlin_moe( for i in range(w1.shape[0]): if quant_type == scalar_types.float4_e2m1f: - w_ref1, qweight1, scales1, global_scale1 = \ - rand_marlin_weight_fp4_like(w1[i], group_size) + if group_size == 16: + w_ref1, qweight1, scales1, global_scale1 = \ + rand_marlin_weight_nvfp4_like(w1[i], group_size) + else: + w_ref1, qweight1, scales1 = \ + rand_marlin_weight_mxfp4_like(w1[i], group_size) + global_scale1 = None w_ref1_l.append(w_ref1.T) qweight1_l.append(qweight1) scales1_l.append(scales1) - global_scale1_l.append(global_scale1) + if global_scale1 is not None: + global_scale1_l.append(global_scale1) elif quant_type == scalar_types.float8_e4m3fn: w_ref1, qweight1, scales1 = marlin_quant_fp8_torch( w1[i], group_size) @@ -620,13 +606,19 @@ def test_fused_marlin_moe( for i in range(w2.shape[0]): if quant_type == scalar_types.float4_e2m1f: - w_ref2, qweight2, scales2, global_scale2 = \ - rand_marlin_weight_fp4_like(w2[i], group_size) + if group_size == 16: + w_ref2, qweight2, scales2, global_scale2 = \ + rand_marlin_weight_nvfp4_like(w2[i], group_size) + else: + w_ref2, qweight2, scales2 = \ + rand_marlin_weight_mxfp4_like(w2[i], group_size) + global_scale2 = None w_ref2_l.append(w_ref2.T) qweight2_l.append(qweight2) scales2_l.append(scales2) - 
global_scale2_l.append(global_scale2) + if global_scale2 is not None: + global_scale2_l.append(global_scale2) elif quant_type == scalar_types.float8_e4m3fn: w_ref2, qweight2, scales2 = marlin_quant_fp8_torch( w2[i], group_size) @@ -677,6 +669,8 @@ def test_fused_marlin_moe( a, qweight1, qweight2, + None, + None, scales1, scales2, score, @@ -698,6 +692,119 @@ def test_fused_marlin_moe( torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0) +@pytest.mark.flaky(reruns=2) +@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") +@pytest.mark.parametrize("m", [1, 256]) +def test_fused_marlin_moe_with_bias(m): + torch.cuda.manual_seed(0) + + e, topk = 32, 4 + n, k = 2048, 2048 + group_size = 128 + act_order = False + is_k_full = True + quant_type = scalar_types.uint4b8 + dtype = torch.half + + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + b_bias1 = torch.randn((e, 2 * n), device="cuda", dtype=dtype) / 10 + b_bias2 = torch.randn((e, k), device="cuda", dtype=dtype) / 10 + + b_bias1_l = [] + w_ref1_l = [] + qweight1_l = [] + scales1_l = [] + g_idx1_l = [] + sort_indices1_l = [] + + for i in range(w1.shape[0]): + test_perm = torch.randperm(k) + w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = \ + marlin_quantize(w1[i].transpose(1, 0), quant_type, + group_size, act_order, test_perm) + + w_ref1_l.append(w_ref1.T) + qweight1_l.append(qweight1) + scales1_l.append(scales1) + g_idx1_l.append(g_idx1) + sort_indices1_l.append(sort_indices1) + b_bias1_l.append(marlin_permute_bias(b_bias1[i])) + + w_ref1 = stack_and_dev(w_ref1_l) + qweight1 = stack_and_dev(qweight1_l).contiguous() + scales1 = stack_and_dev(scales1_l) + global_scale1 = None + g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None + zeros1 = None + sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None + marlin_bias1 = stack_and_dev(b_bias1_l) if b_bias1_l else None + + b_bias2_l = [] + w_ref2_l = [] + qweight2_l = [] + scales2_l = [] + g_idx2_l = [] + sort_indices2_l = [] + + for i in range(w2.shape[0]): + test_perm = torch.randperm(n) + w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = \ + marlin_quantize(w2[i].transpose(1, 0), quant_type, + group_size, act_order, test_perm) + + w_ref2_l.append(w_ref2.T) + qweight2_l.append(qweight2) + scales2_l.append(scales2) + g_idx2_l.append(g_idx2) + sort_indices2_l.append(sort_indices2) + b_bias2_l.append(marlin_permute_bias(b_bias2[i])) + + w_ref2 = stack_and_dev(w_ref2_l) + qweight2 = stack_and_dev(qweight2_l).contiguous() + scales2 = stack_and_dev(scales2_l) + global_scale2 = None + g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None + zeros2 = None + sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None + marlin_bias2 = stack_and_dev(b_bias2_l) if b_bias2_l else None + + score = torch.randn((m, e), device="cuda", dtype=dtype) + + topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) + + with set_current_vllm_config(vllm_config): + torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, b_bias1, + b_bias2) + + marlin_output = torch.ops.vllm.fused_marlin_moe( + a, + qweight1, + qweight2, + marlin_bias1, + marlin_bias2, + scales1, + scales2, + score, + topk_weights, + topk_ids, + global_num_experts=e, + expert_map=None, + global_scale1=global_scale1, + global_scale2=global_scale2, + g_idx1=g_idx1, + g_idx2=g_idx2, + sort_indices1=sort_indices1, + 
sort_indices2=sort_indices2, + w1_zeros=zeros1, + w2_zeros=zeros2, + quant_type_id=quant_type.id, + is_k_full=is_k_full) + + torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0) + + def test_moe_align_block_size_opcheck(): num_experts = 4 block_size = 4 diff --git a/tests/kernels/quantization/test_marlin_gemm.py b/tests/kernels/quantization/test_marlin_gemm.py index 92914bd5cbba..1bd6713ce7fb 100644 --- a/tests/kernels/quantization/test_marlin_gemm.py +++ b/tests/kernels/quantization/test_marlin_gemm.py @@ -19,10 +19,11 @@ from vllm.model_executor.layers.quantization.qqq import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import ( GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, MARLIN_SUPPORTED_GROUP_SIZES, marlin_make_empty_g_idx, - marlin_make_workspace_new, marlin_permute_scales, + marlin_make_workspace_new, marlin_permute_bias, marlin_permute_scales, query_marlin_supported_quant_types) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( - FP4_MARLIN_SUPPORTED_GROUP_SIZES, rand_marlin_weight_fp4_like) + FP4_MARLIN_SUPPORTED_GROUP_SIZES, rand_marlin_weight_mxfp4_like, + rand_marlin_weight_nvfp4_like) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( marlin_quant_fp8_torch) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( @@ -39,7 +40,7 @@ from vllm.scalar_type import scalar_types ACT_ORDER_OPTS = [False, True] K_FULL_OPTS = [False, True] USE_ATOMIC_ADD_OPTS = [False, True] -USE_FP32_REDUCE_OPTS = [False, True] +USE_FP32_REDUCE_OPTS = [True] MARLIN_K_CHUNKS = [128] MARLIN_N_CHUNKS = [64, 256] @@ -202,17 +203,10 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, @pytest.mark.parametrize("is_k_full", K_FULL_OPTS) @pytest.mark.parametrize("use_atomic_add", USE_ATOMIC_ADD_OPTS) @pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS) -def test_gptq_marlin_gemm( - k_chunk, - n_chunk, - quant_type, - group_size, - mnk_factors, - act_order, - is_k_full, - use_atomic_add, - use_fp32_reduce, -): +@pytest.mark.parametrize("dtype", DTYPES) +def test_gptq_marlin_gemm(k_chunk, n_chunk, quant_type, group_size, + mnk_factors, act_order, is_k_full, use_atomic_add, + use_fp32_reduce, dtype): m_factor, n_factor, k_factor = mnk_factors has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8] @@ -231,14 +225,23 @@ def test_gptq_marlin_gemm( if size_k % group_size != 0: return - a_input = rand_data((size_m, size_k)) - b_weight = rand_data((size_k, size_n)) + a_input = rand_data((size_m, size_k), dtype) + b_weight = rand_data((size_k, size_n), dtype) if quant_type == scalar_types.float4_e2m1f: - if group_size != 16 or act_order: + if group_size not in [16, 32] or act_order: return - w_ref, marlin_q_w, marlin_s, marlin_s2 = rand_marlin_weight_fp4_like( - b_weight.T, group_size) + if group_size == 32 and dtype == torch.float16: + return + + if group_size == 16: + w_ref, marlin_q_w, marlin_s, marlin_s2 = \ + rand_marlin_weight_nvfp4_like(b_weight.T, group_size) + else: + w_ref, marlin_q_w, marlin_s = \ + rand_marlin_weight_mxfp4_like(b_weight.T, group_size) + marlin_s2 = None + g_idx = None sort_indices = None marlin_zp = None @@ -272,8 +275,8 @@ def test_gptq_marlin_gemm( workspace = marlin_make_workspace_new(w_ref.device) opcheck(torch.ops._C.gptq_marlin_gemm, - (a_input, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, g_idx, - sort_indices, workspace, quant_type.id, a_input.shape[0], + (a_input, None, marlin_q_w, None, marlin_s, marlin_s2, 
marlin_zp, + g_idx, sort_indices, workspace, quant_type.id, a_input.shape[0], b_weight.shape[1], a_input.shape[1], is_k_full, use_atomic_add, use_fp32_reduce, False), test_utils=DEFAULT_OPCHECK_TEST_UTILS) @@ -282,6 +285,7 @@ def test_gptq_marlin_gemm( a_input, None, marlin_q_w, + None, marlin_s, marlin_s2, marlin_zp, @@ -418,6 +422,7 @@ def test_hqq_marlin_gemm( a_input, None, marlin_w_q, + None, marlin_s, None, marlin_zp, @@ -531,6 +536,7 @@ def test_marlin_gemm_subset_input(): a_input, None, marlin_q_w, + None, marlin_s, None, marlin_zp, @@ -555,6 +561,53 @@ def test_marlin_gemm_subset_input(): assert max_diff < 0.04 +@pytest.mark.parametrize("size_m", [1, 256]) +def test_marlin_gemm_with_bias(size_m): + quant_type = scalar_types.uint4b8 + group_size = 128 + + size_k, size_n = 1024, 2048 + a_input = rand_data((size_m, size_k)) + b_weight = rand_data((size_k, size_n)) + b_bias = rand_data((size_n, )) * 10 + + marlin_bias = marlin_permute_bias(b_bias) + + w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize( + b_weight, quant_type, group_size, False) + + marlin_zp = marlin_make_empty_g_idx(marlin_s.device) + workspace = marlin_make_workspace_new(a_input.device) + + output = ops.gptq_marlin_gemm( + a_input, + None, + marlin_q_w, + marlin_bias, + marlin_s, + None, + marlin_zp, + g_idx, + sort_indices, + workspace, + quant_type, + a_input.shape[0], + b_weight.shape[1], + a_input.shape[1], + is_k_full=True, + use_atomic_add=False, + use_fp32_reduce=True, + is_zp_float=False, + ) + output_ref = torch.matmul(a_input, w_ref) + b_bias.view(1, -1) + + torch.cuda.synchronize() + + max_diff = compute_max_diff(output, output_ref) + + assert max_diff < 0.04 + + def test_marlin_gemm_opcheck(): size_m = 2048 size_n = 4096 diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 2e8febbdcf26..fa4125840a01 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -1064,6 +1064,8 @@ def torch_experts( topk_weight: torch.Tensor, topk_ids: torch.Tensor, global_num_experts: int = -1, + b_bias1: Optional[torch.Tensor] = None, + b_bias2: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, @@ -1108,8 +1110,13 @@ def torch_experts( if mask.sum(): if quant_dtype is None: tmp1 = a[mask] @ w1[i].transpose(0, 1) + if b_bias1 is not None: + tmp1 = tmp1 + b_bias1[i].view(1, -1).to(tmp1.dtype) tmp2 = SiluAndMul()(tmp1) out[mask] = tmp2 @ w2[i].transpose(0, 1) + if b_bias2 is not None: + out[mask] = out[mask] + b_bias2[i].view(1, -1).to( + tmp1.dtype) elif block_shape is not None: # block quantized assert (a_scale is not None and w1_scale is not None @@ -1117,6 +1124,8 @@ def torch_experts( tmp1 = native_w8a8_block_matmul(a[mask], w1[i], a_scale[mask], w1_scale[i], block_shape, out.dtype) + if b_bias1 is not None: + tmp1 = tmp1 + b_bias1[i].view(1, -1).to(tmp1.dtype) tmp2 = SiluAndMul()(tmp1) tmp2, b_scale = moe_kernel_quantize_input( tmp2, a2_scale, quant_dtype, per_act_token_quant, @@ -1125,6 +1134,9 @@ def torch_experts( out[mask] = native_w8a8_block_matmul(tmp2, w2[i], b_scale, w2_scale[i], block_shape, out.dtype) + if b_bias2 is not None: + out[mask] = out[mask] + b_bias2[i].view(1, -1).to( + tmp1.dtype) else: assert (a_scale is not None and w1_scale is not None and w2_scale is not None) @@ -1133,6 +1145,8 @@ def torch_experts( tmp1 = a[mask].to(f32) * scales w1_dq = (w1[i].to(f32) * w1_scale[i]).transpose(0, 1) tmp1 = (tmp1 @ w1_dq).to(out.dtype) + if b_bias1 is not None: + 
tmp1 = tmp1 + b_bias1[i].view(1, -1).to(out.dtype) tmp2 = SiluAndMul()(tmp1).to(out.dtype) @@ -1144,6 +1158,9 @@ def torch_experts( tmp2 = tmp2.to(f32) * b_scale w2_dq = (w2[i].to(f32) * w2_scale[i]).transpose(0, 1) out[mask] = (tmp2 @ w2_dq).to(out.dtype) + if b_bias2 is not None: + out[mask] = out[mask] + b_bias2[i].view(1, -1).to( + out.dtype) if apply_router_weights_on_input: return out @@ -1157,12 +1174,14 @@ def torch_moe(a: torch.Tensor, w2: torch.Tensor, score: torch.Tensor, topk: int, + b_bias1: Optional[torch.Tensor] = None, + b_bias2: Optional[torch.Tensor] = None, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None) -> torch.Tensor: score = torch.softmax(score, dim=-1, dtype=torch.float32) topk_weight, topk_ids = torch.topk(score, topk) return torch_experts(a, w1, w2, topk_weight, topk_ids, global_num_experts, - expert_map) + b_bias1, b_bias2, expert_map) def torch_moe_single(a, w, score, topk): diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 70605d3c5f52..a020b171e894 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -452,6 +452,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): def _gptq_marlin_gemm_fake(a: torch.Tensor, c: Optional[torch.Tensor], b_q_weight: torch.Tensor, + b_bias: Optional[torch.Tensor], b_scales: torch.Tensor, global_scale: Optional[torch.Tensor], b_zeros: Optional[torch.Tensor], @@ -1048,6 +1049,7 @@ def awq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, def gptq_marlin_gemm(a: torch.Tensor, c: Optional[torch.Tensor], b_q_weight: torch.Tensor, + b_bias: Optional[torch.Tensor], b_scales: torch.Tensor, global_scale: Optional[torch.Tensor], b_zeros: Optional[torch.Tensor], @@ -1062,7 +1064,7 @@ def gptq_marlin_gemm(a: torch.Tensor, use_atomic_add: bool = False, use_fp32_reduce: bool = False, is_zp_float: bool = False) -> torch.Tensor: - return torch.ops._C.gptq_marlin_gemm(a, c, b_q_weight, b_scales, + return torch.ops._C.gptq_marlin_gemm(a, c, b_q_weight, b_bias, b_scales, global_scale, b_zeros, g_idx, perm, workspace, b_q_type.id, size_m, size_n, size_k, is_k_full, @@ -1540,7 +1542,9 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor, def moe_wna16_marlin_gemm(input: torch.Tensor, output: Optional[torch.Tensor], - b_qweight: torch.Tensor, b_scales: torch.Tensor, + b_qweight: torch.Tensor, + b_bias: Optional[torch.Tensor], + b_scales: torch.Tensor, global_scale: Optional[torch.Tensor], b_qzeros: Optional[torch.Tensor], g_idx: Optional[torch.Tensor], @@ -1556,11 +1560,11 @@ def moe_wna16_marlin_gemm(input: torch.Tensor, output: Optional[torch.Tensor], use_fp32_reduce: bool, is_zp_float: bool) -> torch.Tensor: return torch.ops._moe_C.moe_wna16_marlin_gemm( - input, output, b_qweight, b_scales, global_scale, b_qzeros, g_idx, - perm, workspace, sorted_token_ids, expert_ids, num_tokens_past_padded, - topk_weights, moe_block_size, top_k, mul_topk_weights, is_ep, - b_q_type.id, size_m, size_n, size_k, is_k_full, use_atomic_add, - use_fp32_reduce, is_zp_float) + input, output, b_qweight, b_bias, b_scales, global_scale, b_qzeros, + g_idx, perm, workspace, sorted_token_ids, expert_ids, + num_tokens_past_padded, topk_weights, moe_block_size, top_k, + mul_topk_weights, is_ep, b_q_type.id, size_m, size_n, size_k, + is_k_full, use_atomic_add, use_fp32_reduce, is_zp_float) if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"): diff --git a/vllm/envs.py b/vllm/envs.py index 145ec3495a0c..110bb542b120 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -122,6 +122,7 @@ if 
TYPE_CHECKING: VLLM_MOE_DP_CHUNK_SIZE: int = 256 VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False VLLM_MARLIN_USE_ATOMIC_ADD: bool = False + VLLM_MXFP4_USE_MARLIN: Optional[bool] = None VLLM_V0_USE_OUTLINES_CACHE: bool = False VLLM_V1_USE_OUTLINES_CACHE: bool = False VLLM_TPU_BUCKET_PADDING_GAP: int = 0 @@ -182,6 +183,12 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: return int(value) +def maybe_convert_bool(value: Optional[str]) -> Optional[bool]: + if value is None: + return None + return bool(int(value)) + + def get_vllm_port() -> Optional[int]: """Get the port from VLLM_PORT environment variable. @@ -906,6 +913,10 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_MARLIN_USE_ATOMIC_ADD": lambda: os.environ.get("VLLM_MARLIN_USE_ATOMIC_ADD", "0") == "1", + # Whether to use marlin kernel in mxfp4 quantization method + "VLLM_MXFP4_USE_MARLIN": + lambda: maybe_convert_bool(os.environ.get("VLLM_MXFP4_USE_MARLIN", None)), + # Whether to turn on the outlines cache for V0 # This cache is unbounded and on disk, so it's not safe to use in # an environment with potentially malicious users. diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 1988c73ba7e2..a49d41c18438 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -18,6 +18,8 @@ from vllm.utils import direct_register_custom_op def fused_marlin_moe(hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, + bias1: Optional[torch.Tensor], + bias2: Optional[torch.Tensor], w1_scale: torch.Tensor, w2_scale: torch.Tensor, gating_output: torch.Tensor, @@ -26,6 +28,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor, quant_type_id: int, apply_router_weight_on_input: bool = False, global_num_experts: int = -1, + activation: Optional[str] = "silu", expert_map: Optional[torch.Tensor] = None, global_scale1: Optional[torch.Tensor] = None, global_scale2: Optional[torch.Tensor] = None, @@ -88,6 +91,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor, assert w2.is_contiguous(), "Expert weights2 must be contiguous" assert hidden_states.dtype in [torch.float16, torch.bfloat16] assert num_bits in [4, 8] + assert topk_weights.dtype == torch.float32 M, K = hidden_states.shape E = w1.shape[0] @@ -138,6 +142,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor, hidden_states, intermediate_cache1, w1, + bias1, w1_scale, global_scale1, w1_zeros, @@ -161,8 +166,28 @@ def fused_marlin_moe(hidden_states: torch.Tensor, use_fp32_reduce=True, is_zp_float=False) - torch.ops._C.silu_and_mul(intermediate_cache2, - intermediate_cache1.view(-1, 2 * N)) + if activation == "silu": + torch.ops._C.silu_and_mul(intermediate_cache2, + intermediate_cache1.view(-1, 2 * N)) + elif activation == "swiglu_oai": + # NOTE: in gpt-oss, the gate_proj and up_proj is interleaved + # - interleaved: gate, up = gate_up[..., ::2], gate_up[..., 1::2] + # - origin: gate, up = gate_up[..., :N], gate_up[..., N:] + + @torch.compile(dynamic=True) + def swiglu_oai(gate_up): + alpha = 1.702 + limit = 7.0 + gate, up = gate_up[..., ::2], gate_up[..., 1::2] + gate = gate.clamp(min=None, max=limit) + up = up.clamp(min=-limit, max=limit) + glu = gate * torch.sigmoid(gate * alpha) + return (up + 1) * glu + + intermediate_cache2 = swiglu_oai(intermediate_cache1) + else: + raise ValueError(f"Unsupported activation: {activation}. 
" + "Only silu and swiglu_oai activations are supported.") if expert_map is not None: intermediate_cache3.zero_() @@ -171,6 +196,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor, intermediate_cache2, intermediate_cache3, w2, + bias2, w2_scale, global_scale2, w2_zeros, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index ddc02168e5c4..36e75825853e 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -36,7 +36,7 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.platforms.interface import CpuArchEnum from vllm.utils import (direct_register_custom_op, has_deep_ep, has_pplx, - has_triton_kernels, is_torch_equal_or_newer, round_up) + round_up) from vllm.utils.flashinfer import has_flashinfer if current_platform.is_cuda_alike(): @@ -751,19 +751,11 @@ class FusedMoE(CustomOp): self.global_num_experts = num_experts + num_redundant_experts # we padding globally so EP buffer allocation works - if quant_config and quant_config.get_name() == "mxfp4": - if not current_platform.is_device_capability(100): - if not is_torch_equal_or_newer("2.8.0"): - raise RuntimeError( - "Mxfp4 on non-blackwell requires torch >= 2.8.0") - if not has_triton_kernels(): - raise NotImplementedError( - "triton_kernels must be installed for " - "mxfp4 on non-blackwell") - if (current_platform.is_rocm() - or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 - or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16): - hidden_size = round_up(hidden_size, 256) + if (quant_config and quant_config.get_name() == "mxfp4" + and (current_platform.is_rocm() + or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 + or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16)): + hidden_size = round_up(hidden_size, 256) # For smuggling this layer into the fused moe custom op compilation_config = vllm_config.compilation_config diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index 6cf02658a94c..ed7ffb21e85a 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -25,7 +25,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import ( apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported, check_marlin_supports_layer, check_moe_marlin_supports_layer, marlin_make_empty_g_idx, marlin_make_workspace_new, - marlin_moe_permute_scales, marlin_permute_scales, + marlin_moe_permute_scales, marlin_permute_bias, marlin_permute_scales, moe_awq_to_marlin_zero_points, verify_marlin_supported, verify_marlin_supports_shape) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead @@ -303,6 +303,9 @@ class AWQMarlinLinearMethod(LinearMethodBase): layer.g_idx = marlin_make_empty_g_idx(device) layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) + if hasattr(layer, "bias") and layer.bias is not None: + layer.bias.data = marlin_permute_bias(layer.bias) + def apply( self, layer: torch.nn.Module, @@ -469,6 +472,12 @@ class AWQMoEMethod(FusedMoEMethodBase): num_bits=self.quant_config.weight_bits) replace_parameter(layer, "w2_qzeros", marlin_w2_zp) + if hasattr(layer, "w13_bias") and layer.w13_bias is not None: + layer.w13_bias.data = marlin_permute_bias(layer.w13_bias) + + if hasattr(layer, "w2_bias") and layer.w2_bias is not None: + layer.w2_bias.data = marlin_permute_bias(layer.w2_bias) + def apply( self, layer: 
torch.nn.Module, @@ -513,6 +522,8 @@ class AWQMoEMethod(FusedMoEMethodBase): x, layer.w13_qweight, layer.w2_qweight, + getattr(layer, "w13_bias", None), + getattr(layer, "w2_bias", None), layer.w13_scales, layer.w2_scales, router_logits, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index c04f7c39a5f5..839942beaf40 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -324,6 +324,8 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): x, layer.w13_weight, layer.w2_weight, + None, + None, layer.w13_weight_scale, layer.w2_weight_scale, router_logits, @@ -795,6 +797,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): x, layer.w13_weight, layer.w2_weight, + None, + None, layer.w13_weight_scale, layer.w2_weight_scale, router_logits, @@ -1253,6 +1257,8 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): x, layer.w13_weight_packed, layer.w2_weight_packed, + None, + None, layer.w13_weight_scale, layer.w2_weight_scale, router_logits, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 9577fa025b70..5e107c799b9f 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -983,6 +983,8 @@ class Fp8MoEMethod(FusedMoEMethodBase): x, layer.w13_weight, layer.w2_weight, + None, + None, layer.w13_weight_scale, layer.w2_weight_scale, router_logits, diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 9bed5e2e4889..3299221e3af3 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -24,7 +24,7 @@ from vllm.model_executor.layers.quantization.utils.gptq_utils import ( get_dynamic_override, get_linear_quant_method, override_config) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( check_marlin_supported, check_moe_marlin_supports_layer, - marlin_make_workspace_new, marlin_moe_permute_scales, + marlin_make_workspace_new, marlin_moe_permute_scales, marlin_permute_bias, marlin_repeat_scales_on_all_ranks, verify_marlin_supported) from vllm.model_executor.parameter import (ChannelQuantScaleParameter, GroupQuantScaleParameter, @@ -618,6 +618,12 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ) replace_parameter(layer, "w2_scales", marlin_w2_scales) + if hasattr(layer, "w13_bias") and layer.w13_bias is not None: + layer.w13_bias.data = marlin_permute_bias(layer.w13_bias) + + if hasattr(layer, "w2_bias") and layer.w2_bias is not None: + layer.w2_bias.data = marlin_permute_bias(layer.w2_bias) + def apply( self, layer: torch.nn.Module, @@ -662,6 +668,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): x, layer.w13_qweight, layer.w2_qweight, + getattr(layer, "w13_bias", None), + getattr(layer, "w2_bias", None), layer.w13_scales, layer.w2_scales, router_logits, diff --git a/vllm/model_executor/layers/quantization/hqq_marlin.py b/vllm/model_executor/layers/quantization/hqq_marlin.py index ee8a0e34b32e..8385ccac32a2 100644 --- a/vllm/model_executor/layers/quantization/hqq_marlin.py +++ b/vllm/model_executor/layers/quantization/hqq_marlin.py @@ -14,7 +14,7 @@ from 
vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, - marlin_make_empty_g_idx, marlin_permute_scales) + marlin_make_empty_g_idx, marlin_permute_bias, marlin_permute_scales) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( MarlinWorkspace) from vllm.model_executor.layers.quantization.utils.quant_utils import gptq_pack @@ -284,6 +284,9 @@ class HQQMarlinMethod(LinearMethodBase): layer.marlin_zeros = marlin_zp layer.marlin_scales = marlin_s + if hasattr(layer, "bias") and layer.bias is not None: + layer.bias.data = marlin_permute_bias(layer.bias) + def apply( self, layer: torch.nn.Module, @@ -307,6 +310,7 @@ class HQQMarlinMethod(LinearMethodBase): x, None, layer.marlin_qweight, + bias, scales, None, zeros, @@ -326,7 +330,4 @@ class HQQMarlinMethod(LinearMethodBase): if orig_type != torch.float16: marlin_out = marlin_out.to(orig_type) - if bias is not None: - marlin_out.add_(bias) - return marlin_out diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py index 73e0b17ea85a..5eb99383097b 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py @@ -9,8 +9,9 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.marlin_utils import ( MARLIN_SUPPORTED_GROUP_SIZES, apply_gptq_marlin_linear, check_marlin_supports_shape, marlin_is_k_full, marlin_make_empty_g_idx, - marlin_make_workspace_new, marlin_permute_scales, marlin_sort_g_idx, - marlin_zero_points, query_marlin_supported_quant_types, unpack_cols) + marlin_make_workspace_new, marlin_permute_bias, marlin_permute_scales, + marlin_sort_g_idx, marlin_zero_points, query_marlin_supported_quant_types, + unpack_cols) from vllm.model_executor.parameter import (BasevLLMParameter, permute_param_layout_) from vllm.platforms import current_platform @@ -111,6 +112,9 @@ class MarlinLinearKernel(MPLinearKernel): self._transform_param(layer, self.w_q_name, transform_w_q) self._transform_param(layer, self.w_s_name, transform_w_s) + if hasattr(layer, "bias") and layer.bias is not None: + layer.bias.data = marlin_permute_bias(layer.bias) + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index bed502226716..8868c623796a 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -1330,6 +1330,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): x, layer.w13_weight, layer.w2_weight, + None, + None, layer.w13_weight_scale, layer.w2_weight_scale, router_logits, diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 03fbcf158338..dbe6c603c062 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -15,13 +15,17 @@ from vllm.model_executor.layers.linear import (LinearBase, from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) +from 
vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( + prepare_moe_fp4_layer_for_marlin) from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( _can_support_mxfp4, _swizzle_mxfp4) from vllm.model_executor.layers.quantization.utils.quant_utils import ( is_layer_skipped) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform -from vllm.utils import next_power_of_2, round_up +from vllm.scalar_type import scalar_types +from vllm.utils import (has_triton_kernels, is_torch_equal_or_newer, + next_power_of_2, round_up) if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16): @@ -81,6 +85,21 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): super().__init__() self.topk_indices_dtype = None self.moe = moe + self.use_marlin = self._should_use_marlin() + + def _should_use_marlin(self): + if envs.VLLM_MXFP4_USE_MARLIN is not None: + return envs.VLLM_MXFP4_USE_MARLIN + if current_platform.is_cuda() and \ + not current_platform.has_device_capability(100): + if not current_platform.is_device_capability(90): + # marlin kernel has better performance on ampere + return True + if not has_triton_kernels(): + return True + if not is_torch_equal_or_newer("2.8.0"): + return True + return False def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, @@ -101,11 +120,29 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): intermediate_size_per_partition_after_pad = \ intermediate_size_per_partition - # pad the intermediate size to be a multiple of 2 * mxfp4_block - # for to hold non-uniform sharded tensor as well as swizzling - # other padding to increase performance - if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 - or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16): + if self.use_marlin: + # The moe marlin kernel requires that for each linear + # n % 256 == 0 and k % 128 == 0. 
+ # In gate_up_proj: + # n = 2 * intermediate_size_per_partition_after_pad + # k = hidden_size + # In down_proj + # n = hidden_size + # k = intermediate_size_per_partition_after_pad + intermediate_size_per_partition_after_pad = round_up( + intermediate_size_per_partition, 128) + hidden_size = round_up(hidden_size, 256) + + layer.params_dtype = params_dtype + layer.num_experts = num_experts + layer.hidden_size = hidden_size + layer.intermediate_size_per_partition = \ + intermediate_size_per_partition_after_pad + elif (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 + or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16): + # pad the intermediate size to be a multiple of 2 * mxfp4_block + # for to hold non-uniform sharded tensor as well as swizzling + # other padding to increase performance intermediate_size_per_partition_after_pad = round_up( intermediate_size_per_partition, 256) hidden_size = round_up(hidden_size, 256) @@ -191,8 +228,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): set_weight_attrs(w2_bias, extra_weight_attrs) def process_weights_after_loading(self, layer): - if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 - or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16): + if self.use_marlin: + prepare_moe_fp4_layer_for_marlin(layer) + elif (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 + or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16): layer.gemm1_alpha = Parameter(torch.tensor( [1.702] * self.num_experts, dtype=torch.float32).cuda(), requires_grad=False) @@ -399,13 +438,45 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): if enable_eplb: raise NotImplementedError("EPLB is not supported for mxfp4") + if self.use_marlin: + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + + return torch.ops.vllm.fused_marlin_moe( + x, + layer.w13_weight, + layer.w2_weight, + layer.w13_bias, + layer.w2_bias, + layer.w13_weight_scale, + layer.w2_weight_scale, + router_logits, + topk_weights, + topk_ids, + global_scale1=None, + global_scale2=None, + quant_type_id=scalar_types.float4_e2m1f.id, + apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, + activation=activation, + expert_map=expert_map) + assert _can_support_mxfp4( use_grouped_topk, topk_group, num_expert_group, expert_map, custom_routing_function, e_score_correction_bias, apply_router_weight_on_input, scoring_func, activation, expert_load_view, logical_to_physical_map, - logical_replica_count), ("MXFP4 are not supported\ - with this configuration.") + logical_replica_count), ( + "MXFP4 are not supported with this configuration.") if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16): diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 7540a1516fcb..02057b476c6e 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -261,6 +261,13 @@ def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int, return s +def marlin_permute_bias(s: torch.Tensor) -> torch.Tensor: + origin_shape = s.shape + _, scale_perm_single = get_scale_perms() + s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] + 
return s.reshape(*origin_shape).contiguous() + + def marlin_moe_permute_scales( s: torch.Tensor, size_k: int, @@ -410,6 +417,7 @@ def apply_gptq_marlin_linear( output = ops.gptq_marlin_gemm(reshaped_x, None, weight, + bias, weight_scale, None, weight_zp, @@ -425,9 +433,6 @@ def apply_gptq_marlin_linear( use_fp32_reduce=use_fp32_reduce, is_zp_float=False) - if bias is not None: - output.add_(bias) # In-place add - return output.reshape(out_shape) @@ -456,6 +461,7 @@ def apply_awq_marlin_linear( output = ops.gptq_marlin_gemm(reshaped_x, None, weight, + bias, weight_scale, None, weight_zp, @@ -470,7 +476,4 @@ def apply_awq_marlin_linear( use_fp32_reduce=use_fp32_reduce, is_zp_float=False) - if bias is not None: - output.add_(bias) # In-place add - return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py index ca10db69dc16..94ffdcd26ecd 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py @@ -8,8 +8,8 @@ import torch import vllm._custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_scales, - should_use_atomic_add_reduce) + USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_bias, + marlin_permute_scales, should_use_atomic_add_reduce) from vllm.platforms import current_platform from vllm.scalar_type import scalar_types @@ -22,7 +22,7 @@ def is_fp4_marlin_supported(): return current_platform.has_device_capability(80) -def fp4_marlin_process_scales(marlin_scales): +def nvfp4_marlin_process_scales(marlin_scales): if not (marlin_scales >= 0).all(): logger.warning_once( "NVFP4 Marlin assumes the scales to be >=0, but has encountered " @@ -56,7 +56,20 @@ def fp4_marlin_process_scales(marlin_scales): return marlin_scales -def fp4_marlin_process_global_scale(global_scale): +def mxfp4_marlin_process_scales(marlin_scales): + # 8 is the number of scale number using by one thread + marlin_scales = marlin_scales.view(marlin_scales.size(0) // 2, 2, -1, 8) + marlin_scales = marlin_scales.permute(0, 2, 1, 3).reshape( + marlin_scales.size(0) * 2, -1) + + # fit the layout of fp8 dequantization + marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view( + marlin_scales.size(0), -1) + marlin_scales = marlin_scales.to(torch.float8_e8m0fnu) + return marlin_scales + + +def nvfp4_marlin_process_global_scale(global_scale): assert global_scale.dtype in [torch.half, torch.bfloat16] fp4_exponent = 2 if global_scale.dtype == torch.half: @@ -73,7 +86,7 @@ def apply_fp4_marlin_linear( input: torch.Tensor, weight: torch.Tensor, weight_scale: torch.Tensor, - weight_scale_2: torch.Tensor, + weight_scale_2: Optional[torch.Tensor], workspace: torch.Tensor, size_n: int, size_k: int, @@ -94,6 +107,7 @@ def apply_fp4_marlin_linear( output = ops.gptq_marlin_gemm(a=reshaped_x, c=None, b_q_weight=weight, + b_bias=bias, b_scales=weight_scale, global_scale=weight_scale_2, b_zeros=None, @@ -107,9 +121,6 @@ def apply_fp4_marlin_linear( use_atomic_add=use_atomic_add, use_fp32_reduce=use_fp32_reduce) - if bias is not None: - output.add_(bias) # In-place add - return output.reshape(out_shape) @@ -120,6 +131,9 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: "be used leveraging the Marlin kernel. 
This may degrade " "performance for compute-heavy workloads.") + is_nvfp4 = hasattr(layer, "weight_scale_2") + group_size = 16 if is_nvfp4 else 32 + part_size_n = layer.output_size_per_partition part_size_k = layer.input_size_per_partition param_dtype = layer.params_dtype @@ -145,18 +159,35 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: # WEIGHT SCALES # Permute scales - weight_scale = layer.weight_scale.T.to(param_dtype) + weight_scale = layer.weight_scale.T.contiguous() + + if not is_nvfp4: + weight_scale = weight_scale.view(torch.float8_e8m0fnu) + + weight_scale = weight_scale.to(param_dtype) weight_scale = marlin_permute_scales(s=weight_scale, size_k=part_size_k, size_n=part_size_n, - group_size=16) - weight_scale = fp4_marlin_process_scales(weight_scale) - layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) + group_size=group_size) - weight_scale_2 = layer.weight_scale_2.to(param_dtype) - weight_scale_2 = fp4_marlin_process_global_scale(weight_scale_2) - layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2, - requires_grad=False) + if is_nvfp4: + weight_scale = nvfp4_marlin_process_scales(weight_scale) + layer.weight_scale = torch.nn.Parameter(weight_scale, + requires_grad=False) + + weight_scale_2 = layer.weight_scale_2.to(param_dtype) + weight_scale_2 = nvfp4_marlin_process_global_scale(weight_scale_2) + layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2, + requires_grad=False) + else: + weight_scale = mxfp4_marlin_process_scales(weight_scale) + layer.weight_scale = torch.nn.Parameter(weight_scale, + requires_grad=False) + + if hasattr(layer, "bias") and layer.bias is not None: + assert layer.bias.shape == (part_size_n, ) + bias = marlin_permute_bias(layer.bias) + layer.bias = torch.nn.Parameter(bias, requires_grad=False) return @@ -168,6 +199,9 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: "be used leveraging the Marlin kernel. 
This may degrade " "performance for compute-heavy workloads.") + is_nvfp4 = hasattr(layer, "w13_weight_scale_2") + group_size = 16 if is_nvfp4 else 32 + e = layer.num_experts k = layer.hidden_size n = layer.intermediate_size_per_partition @@ -208,8 +242,13 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: # WEIGHT SCALES # Permute scales for name in ["w13", "w2"]: - scales = getattr(layer, name + "_weight_scale").to(param_dtype) - global_scale = getattr(layer, name + "_weight_scale_2").to(param_dtype) + scales = getattr(layer, name + "_weight_scale") + if not is_nvfp4: + scales = scales.view(torch.float8_e8m0fnu) + scales = scales.to(param_dtype) + if is_nvfp4: + global_scale = getattr(layer, + name + "_weight_scale_2").to(param_dtype) tensor_list = [] if "w13" in name: @@ -218,23 +257,47 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: size_n, size_k = k, n for i in range(e): - marlin_scales = marlin_permute_scales(s=scales[i].T, + scale = scales[i].T + + marlin_scales = marlin_permute_scales(s=scale, size_k=size_k, size_n=size_n, - group_size=16) - marlin_scales = fp4_marlin_process_scales(marlin_scales) + group_size=group_size) + if is_nvfp4: + marlin_scales = nvfp4_marlin_process_scales(marlin_scales) + else: + marlin_scales = mxfp4_marlin_process_scales(marlin_scales) tensor_list.append(marlin_scales) scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) scales = torch.nn.Parameter(scales, requires_grad=False) setattr(layer, name + "_weight_scale", scales) - global_scale = fp4_marlin_process_global_scale(global_scale) - global_scale = torch.nn.Parameter(global_scale, requires_grad=False) - setattr(layer, name + "_weight_scale_2", global_scale) + if is_nvfp4: + global_scale = nvfp4_marlin_process_global_scale(global_scale) + global_scale = torch.nn.Parameter(global_scale, + requires_grad=False) + setattr(layer, name + "_weight_scale_2", global_scale) + + # BIAS + # Permute bias + for name in ["w13_bias", "w2_bias"]: + if not hasattr(layer, name): + continue + bias = getattr(layer, name).to(param_dtype) + + tensor_list = [] + for i in range(e): + expert_bias = bias[i] + + tensor_list.append(marlin_permute_bias(expert_bias)) + + bias = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) + bias = torch.nn.Parameter(bias, requires_grad=False) + setattr(layer, name, bias) -def rand_marlin_weight_fp4_like(weight, group_size): +def rand_marlin_weight_nvfp4_like(weight, group_size): assert group_size > 0 size_n, size_k = weight.shape device = weight.device @@ -276,8 +339,58 @@ def rand_marlin_weight_fp4_like(weight, group_size): size_k=size_k, size_n=size_n, group_size=group_size) - marlin_scales = fp4_marlin_process_scales(marlin_scales) + marlin_scales = nvfp4_marlin_process_scales(marlin_scales) - global_scale = fp4_marlin_process_global_scale(global_scale) + global_scale = nvfp4_marlin_process_global_scale(global_scale) return weight_ref.T, marlin_qweight, marlin_scales, global_scale + + +def rand_marlin_weight_mxfp4_like(weight, group_size): + assert group_size > 0 + size_n, size_k = weight.shape + device = weight.device + + scales = torch.randint(100, + 125, (size_n, size_k // group_size), + dtype=torch.uint8, + device=weight.device) + scales = scales.view(torch.float8_e8m0fnu) + + fp4_weight = torch.randint(0, + 256, (size_n, size_k // 2), + dtype=torch.uint8, + device=weight.device) + fp4_weight_part_1 = ((fp4_weight & 0b10000000) | + ((fp4_weight & 0b01110000) >> 2)) + fp4_weight_part_1 = fp4_weight_part_1.view(torch.float8_e4m3fn) 
+ fp4_weight_part_1 = fp4_weight_part_1.to(weight.dtype) * (2**6) + + fp4_weight2 = fp4_weight << 4 + fp4_weight_part_2 = ((fp4_weight2 & 0b10000000) | + ((fp4_weight2 & 0b01110000) >> 2)) + fp4_weight_part_2 = fp4_weight_part_2.view(torch.float8_e4m3fn) + fp4_weight_part_2 = fp4_weight_part_2.to(weight.dtype) * (2**6) + + weight_ref = torch.cat( + [fp4_weight_part_2.unsqueeze(2), + fp4_weight_part_1.unsqueeze(2)], 2).view(size_n, size_k) + weight_ref = weight_ref * \ + scales.repeat_interleave(group_size, 1).to(weight.dtype) + + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=fp4_weight.view(torch.int32).T.contiguous(), + perm=torch.empty(0, dtype=torch.int, device=device), + size_k=size_k, + size_n=size_n, + num_bits=4, + ) + + marlin_scales = marlin_permute_scales(s=scales.T.to(weight.dtype), + size_k=size_k, + size_n=size_n, + group_size=group_size) + + marlin_scales = mxfp4_marlin_process_scales(marlin_scales) + + return weight_ref.T, marlin_qweight, marlin_scales.to(torch.float8_e8m0fnu) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py index 5372c49d9838..511e19545d5a 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py @@ -8,8 +8,8 @@ import torch import vllm._custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_scales, - should_use_atomic_add_reduce) + USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_bias, + marlin_permute_scales, should_use_atomic_add_reduce) from vllm.platforms import current_platform from vllm.scalar_type import scalar_types @@ -58,6 +58,7 @@ def apply_fp8_marlin_linear( output = ops.gptq_marlin_gemm(a=reshaped_x, c=None, b_q_weight=weight, + b_bias=bias, b_scales=weight_scale, global_scale=None, b_zeros=None, @@ -71,9 +72,6 @@ def apply_fp8_marlin_linear( use_atomic_add=use_atomic_add, use_fp32_reduce=use_fp32_reduce) - if bias is not None: - output.add_(bias) # In-place add - return output.reshape(out_shape) @@ -160,6 +158,11 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module, marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales) layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) + if hasattr(layer, "bias") and layer.bias is not None: + assert layer.bias.shape == (part_size_n, ) + bias = marlin_permute_bias(layer.bias) + layer.bias = torch.nn.Parameter(bias, requires_grad=False) + def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module, size_k_first: bool = True) -> None: @@ -274,6 +277,23 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module, setattr(layer, name + "_weight_scale", scales) + # BIAS + # Permute bias + for name in ["w13_bias", "w2_bias"]: + if not hasattr(layer, name): + continue + bias = getattr(layer, name).to(layer.orig_dtype) + + tensor_list = [] + for i in range(e): + expert_bias = bias[i] + + tensor_list.append(marlin_permute_bias(expert_bias)) + + bias = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) + bias = torch.nn.Parameter(bias, requires_grad=False) + setattr(layer, name, bias) + def pack_fp8_to_int32(fp8_tensor: torch.Tensor, size_k_first: bool = True) -> torch.Tensor: diff --git a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py index 
95eabe149d89..deeb69bcad0e 100644 --- a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py +++ b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py @@ -61,7 +61,7 @@ def _can_support_mxfp4(use_grouped_topk: bool = False, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, scoring_func: str = "softmax", - activation: str = "silu", + activation: str = "swiglu_oai", expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None): diff --git a/vllm/scalar_type.py b/vllm/scalar_type.py index 9060b55c79b0..6f11ab8e0300 100644 --- a/vllm/scalar_type.py +++ b/vllm/scalar_type.py @@ -327,6 +327,8 @@ class scalar_types: uint8 = ScalarType.uint(8, None) float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN) float8_e5m2 = ScalarType.float_IEEE754(5, 2) + float8_e8m0fnu = ScalarType(8, 0, False, 0, True, + NanRepr.EXTD_RANGE_MAX_MIN) float16_e8m7 = ScalarType.float_IEEE754(8, 7) float16_e5m10 = ScalarType.float_IEEE754(5, 10)
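
Reviewer note: the `swiglu_oai` branch added to `fused_marlin_moe` can be exercised outside the Marlin path. The sketch below restates the interleaved activation from the diff as a standalone function; `swiglu_oai_reference` and the toy shapes are illustrative names only, not part of the patch.

```python
import torch

def swiglu_oai_reference(gate_up: torch.Tensor,
                         alpha: float = 1.702,
                         limit: float = 7.0) -> torch.Tensor:
    """Interleaved SwiGLU as used in the swiglu_oai branch of fused_marlin_moe.

    gate_up has shape (..., 2 * N) with gate and up projections interleaved
    along the last dimension: gate = [..., ::2], up = [..., 1::2].
    """
    gate, up = gate_up[..., ::2], gate_up[..., 1::2]
    gate = gate.clamp(max=limit)              # only the upper side of gate is clamped
    up = up.clamp(min=-limit, max=limit)      # up is clamped symmetrically
    glu = gate * torch.sigmoid(gate * alpha)  # SiLU-style gating scaled by alpha
    return (up + 1) * glu                     # note the +1 shift on the up branch

# Example: one token with N = 4 intermediate channels (2 * N = 8 inputs).
x = torch.randn(1, 8)
y = swiglu_oai_reference(x)
assert y.shape == (1, 4)
```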
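
Reviewer note: the new `VLLM_MXFP4_USE_MARLIN` variable is deliberately tri-state: unset means auto-detect, while `0`/`1` force the choice. A minimal sketch of that pattern, reusing the `maybe_convert_bool` helper added to `vllm/envs.py`; `resolve_use_marlin` is a hypothetical stand-in for `Mxfp4MoEMethod._should_use_marlin`, not code from the patch.

```python
import os
from typing import Optional

def maybe_convert_bool(value: Optional[str]) -> Optional[bool]:
    # Unset -> None (auto-detect); "0"/"1" -> an explicit override.
    if value is None:
        return None
    return bool(int(value))

def resolve_use_marlin(auto_detected: bool) -> bool:
    # Hypothetical stand-in for the platform check in _should_use_marlin():
    # an explicit env override wins, otherwise fall back to detection.
    override = maybe_convert_bool(os.environ.get("VLLM_MXFP4_USE_MARLIN"))
    return auto_detected if override is None else override

print(resolve_use_marlin(auto_detected=False))  # False unless VLLM_MXFP4_USE_MARLIN=1
```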
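
Reviewer note: the mxfp4 scale dtype introduced here (`kFE8M0fnu` / `float8_e8m0fnu`) is an exponent-only byte with bias 127, no sign and no mantissa, so every representable scale is a power of two. The check below mirrors the `uint8 -> view(torch.float8_e8m0fnu)` reinterpretation used by `rand_marlin_weight_mxfp4_like` and assumes a PyTorch build that exposes `torch.float8_e8m0fnu`; the code range 100..124 is the same one the helper samples from and is chosen only for illustration. Since codes can reach 2**127, such scales do not generally fit in fp16, which is consistent with the fp16 + group_size 32 cases being skipped in the tests.

```python
import torch

# E8M0 value = 2 ** (code - 127); code 0xFF is reserved for NaN.
codes = torch.arange(100, 125, dtype=torch.uint8)      # same range as the test helper
scales = codes.view(torch.float8_e8m0fnu)              # reinterpret raw bytes, as in the diff
decoded = scales.to(torch.float32)
expected = torch.ldexp(torch.ones_like(codes, dtype=torch.float32),
                       codes.to(torch.int32) - 127)    # exact powers of two: 2**-27 .. 2**-3
assert torch.equal(decoded, expected)
```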
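
Reviewer note: `rand_marlin_weight_mxfp4_like` dequantizes packed e2m1 nibbles by masking them into the top bits of an e4m3 byte and rescaling by `2**6`, the difference between the two exponent biases (7 vs 1). Below is a self-contained check of that bit trick against the sixteen FP4 values; the lookup table and helper name are mine, the bit manipulation is the one used in the diff.

```python
import torch

# The 16 FP4 (E2M1) values, indexed by the 4-bit code s|e1 e0|m.
E2M1_LUT = torch.tensor([
    0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,          # sign = 0
    -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,  # sign = 1
])

def decode_e2m1_nibble_via_e4m3(nibble: torch.Tensor) -> torch.Tensor:
    # Place the nibble in the high 4 bits of a byte, then move the 3-bit
    # exponent/mantissa field down by 2 so it lines up with E4M3's layout.
    byte = nibble.to(torch.uint8) << 4
    byte = (byte & 0b10000000) | ((byte & 0b01110000) >> 2)
    # Reinterpret as float8_e4m3fn; the exponent-bias difference (7 vs 1)
    # is compensated by scaling with 2**6.
    return byte.view(torch.float8_e4m3fn).to(torch.float32) * (2 ** 6)

codes = torch.arange(16)
assert torch.equal(decode_e2m1_nibble_via_e4m3(codes), E2M1_LUT)
```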