diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu
index 72627df24b9a..dafab501ee00 100644
--- a/csrc/quantization/gptq_marlin/gptq_marlin.cu
+++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu
@@ -42,7 +42,7 @@ namespace marlin {
 __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
                                     int const* __restrict__ perm_int_ptr,
                                     int4* __restrict__ out_int4_ptr, int size_m,
-                                    int size_k, int block_rows) {}
+                                    int size_k, int lda, int block_rows) {}

 template <typename scalar_t,  // compute dtype, half or nv_bfloat16
@@ -467,16 +467,19 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
   }
   int cur_block_rows = finish_row - start_row;

-  int row_stride = size_k * sizeof(half) / 16;
+  int input_row_stride = lda * sizeof(half) / 16;
+  int output_row_stride = size_k * sizeof(half) / 16;

   auto permute_row = [&](int row) {
     int iters = size_k / default_threads;
     int rest = size_k % default_threads;

-    int offset = row * row_stride;
+    int input_offset = row * input_row_stride;
+    int output_offset = row * output_row_stride;

-    half const* a_row_half = reinterpret_cast<half const*>(a_int4_ptr + offset);
-    half* out_half = reinterpret_cast<half*>(out_int4_ptr + offset);
+    half const* a_row_half =
+        reinterpret_cast<half const*>(a_int4_ptr + input_offset);
+    half* out_half = reinterpret_cast<half*>(out_int4_ptr + output_offset);

     int base_k = 0;

@@ -537,6 +540,7 @@ __global__ void Marlin(
     int prob_m,           // batch dimension m
     int prob_n,           // output dimension n
     int prob_k,           // reduction dimension k
+    int lda,              // A.stride(0), equal to prob_k if A is contiguous
     int* locks,           // extra global storage for barrier synchronization
     bool use_atomic_add,  // whether to use atomic add to reduce
     bool use_fp32_reduce  // whether to use fp32 global reduce
@@ -600,7 +604,7 @@ __global__ void Marlin(
   // We can easily implement parallel problem execution by just remapping
   // indices and advancing global pointers
   if (slice_col_par >= n_tiles) {
-    A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8;
+    A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * lda / 8;
     C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8;
     locks += (slice_col_par / n_tiles) * n_tiles;
     slice_col = slice_col_par % n_tiles;
@@ -631,7 +635,7 @@ __global__ void Marlin(
       }
     }
     if (slice_col == n_tiles) {
-      A += 16 * thread_m_blocks * prob_k / 8;
+      A += 16 * thread_m_blocks * lda / 8;
       C += 16 * thread_m_blocks * prob_n / 8;
       locks += n_tiles;
       slice_col = 0;
@@ -643,7 +647,7 @@ __global__ void Marlin(
   // A sizes/strides

   // stride of the A matrix in global memory
-  int a_gl_stride = prob_k / 8;
+  int a_gl_stride = lda / 8;
   // stride of an A matrix tile in shared memory
   constexpr int a_sh_stride = 16 * thread_k_blocks / 8;
   // delta between subsequent A tiles in global memory
@@ -1780,8 +1784,8 @@ __global__ void Marlin(
              HAS_ZP, GROUP_BLOCKS, IS_ZP_FLOAT>                               \
           <<<blocks, NUM_THREADS, max_shared_mem, stream>>>(                  \
               A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr,       \
-              num_groups, prob_m, prob_n, prob_k, locks, use_atomic_add,      \
-              use_fp32_reduce);                                               \
+              num_groups, prob_m, prob_n, prob_k, lda, locks,                 \
+              use_atomic_add, use_fp32_reduce);                               \
     }                                                                         \
   }

@@ -2071,7 +2075,7 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
 template <typename scalar_t>
 void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
                void* zp, void* g_idx, void* perm, void* a_tmp, int prob_m,
-               int prob_n, int prob_k, void* workspace,
+               int prob_n, int prob_k, int lda, void* workspace,
                vllm::ScalarType const& q_type, bool has_act_order,
                bool is_k_full, bool has_zp, int num_groups, int group_size,
                int dev, cudaStream_t stream, int thread_k, int thread_n,
@@ -2184,8 +2188,9 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
     // Permute A columns
     int block_rows = div_ceil(prob_m, blocks);
     permute_cols_kernel<<<blocks, default_threads, 0, stream>>>(
-        A_ptr, perm_ptr, a_tmp_ptr, prob_m, prob_k, block_rows);
+        A_ptr, perm_ptr, a_tmp_ptr, prob_m, prob_k, lda, block_rows);
     A_ptr = a_tmp_ptr;
+    lda = prob_k;
   }

   // If we have a full K, then we can run the non-act-order version of Marlin
@@ -2244,7 +2249,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
                 ", num_bits = ", num_bits);
     }

-    A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par;
+    A_ptr += 16 * thread_m_blocks * (lda / 8) * par;
     C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par;
   }
 }
@@ -2300,7 +2305,10 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,

   // Verify device and strides
   TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
-  TORCH_CHECK(a.is_contiguous(), "A is not contiguous");
+  TORCH_CHECK(a.stride(1) == 1, "A.stride(1) is not 1");
+  // We use int4 (16 bytes) to load A, so A must be aligned to 16 bytes
+  TORCH_CHECK(a.stride(0) % 8 == 0, "A.stride(0) must be divisible by 8");
+  TORCH_CHECK(((uint64_t)a.data_ptr()) % 16 == 0, "A must be aligned to 16 bytes");

   TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
   TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
@@ -2432,7 +2440,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
         a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(),
         c_tmp.data_ptr(), b_scales.data_ptr(),
         b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
-        a_tmp.data_ptr(), size_m, size_n, size_k,
+        a_tmp.data_ptr(), size_m, size_n, size_k, a.stride(0),
         workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
         num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
         thread_k, thread_n, sms, marlin::max_par, use_atomic_add,
@@ -2443,10 +2451,10 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
         c.data_ptr(), c_tmp.data_ptr(),
         b_scales.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(),
         perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, size_k,
-        workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
-        num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
-        thread_k, thread_n, sms, marlin::max_par, use_atomic_add,
-        use_fp32_reduce, is_zp_float);
+        a.stride(0), workspace.data_ptr(), b_q_type, has_act_order, is_k_full,
+        has_zp, num_groups, group_size, dev,
+        at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
+        marlin::max_par, use_atomic_add, use_fp32_reduce, is_zp_float);
   } else {
     TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16");
   }
diff --git a/tests/kernels/test_marlin_gemm.py b/tests/kernels/test_marlin_gemm.py
index c0cf5b099f99..3165201aa353 100644
--- a/tests/kernels/test_marlin_gemm.py
+++ b/tests/kernels/test_marlin_gemm.py
@@ -606,6 +606,51 @@ def test_marlin_qqq_gemm(
     assert max_diff < 0.04


+def test_marlin_gemm_subset_input():
+    quant_type = scalar_types.uint4b8
+    group_size = 128
+
+    size_m, size_k, size_n = 32, 1024, 2048
+    big_m = size_m * 2
+    big_k = size_k * 2
+
+    a_input = rand_data((big_m, big_k))[8:size_m + 8, 8:size_k + 8]
+    b_weight = rand_data((size_k, size_n))
+
+    w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
+        b_weight, quant_type, group_size, False)
+
+    marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
+    workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
+                                GPTQ_MARLIN_MAX_PARALLEL)
+
+    output = ops.gptq_marlin_gemm(
+        a_input,
+        marlin_q_w,
+        marlin_s,
+        marlin_zp,
+        g_idx,
+        sort_indices,
+        workspace.scratch,
+        quant_type,
+        a_input.shape[0],
+        b_weight.shape[1],
+        a_input.shape[1],
+        is_k_full=True,
+        has_zp=False,
+        use_atomic_add=False,
+        use_fp32_reduce=True,
+        is_zp_float=False,
+    )
+    output_ref = torch.matmul(a_input, w_ref)
+
+    torch.cuda.synchronize()
+
+    max_diff = compute_max_diff(output, output_ref)
+
+    assert max_diff < 0.04
+
+
 def test_marlin_gemm_opcheck():
     size_m = 2048
     size_n = 4096
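A minimal stride sketch of the case the new test exercises (illustrative only, not part of the patch; plain PyTorch, runs on any device): a row-major view of a larger tensor keeps stride(1) == 1, but its stride(0) is the parent's row length rather than size_k, and that stride(0) is what the changed code forwards to the kernels as lda.

    import torch

    big_m, big_k = 64, 2048
    size_m, size_k = 32, 1024

    parent = torch.rand(big_m, big_k).half()
    a_view = parent[8:size_m + 8, 8:size_k + 8]

    assert a_view.stride(1) == 1       # innermost dim is still dense
    assert a_view.stride(0) == big_k   # row stride of the parent, not size_k
    assert not a_view.is_contiguous()  # would have failed the old a.is_contiguous() check

    # a_view.stride(0) is what gptq_marlin_gemm now passes down as lda; the patch
    # requires it to be a multiple of 8 fp16 elements so each row stays 16-byte
    # (int4) aligned for the vectorized loads of A.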