[Bugfix][Kernel]: Fix AllSpark kernel compilation errors and enable for CUDA < 12.0 (#14430)

Signed-off-by: wyj371990 <wyj371990@alibaba-inc.com>
This commit is contained in:
Yajie Wang 2025-03-15 00:55:14 +08:00 committed by GitHub
parent 73deea2fdb
commit 977a16772c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 15 additions and 10 deletions

View File

@ -319,7 +319,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# Only build AllSpark kernels if we are building for at least some compatible archs.
cuda_archs_loose_intersection(ALLSPARK_ARCHS "8.0;8.6;8.7;8.9" "${CUDA_ARCHS}")
if (${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND ALLSPARK_ARCHS)
if (ALLSPARK_ARCHS)
set(ALLSPARK_SRCS
"csrc/quantization/gptq_allspark/allspark_repack.cu"
"csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu")
@ -330,7 +330,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
message(STATUS "Building AllSpark kernels for archs: ${ALLSPARK_ARCHS}")
else()
message(STATUS "Not building AllSpark kernels as no compatible archs found"
" in CUDA target architectures, or CUDA not >= 12.0")
" in CUDA target architectures")
endif()

View File

@ -437,9 +437,10 @@ struct ComputeTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK {
for (int n_idx = 0; n_idx < WARP_NITER; ++n_idx) {
#pragma unroll
for (int k_idx = 0; k_idx < 2; ++k_idx) {
FType low16 = static_cast<FType>(C_frag[m_idx][n_idx][k_idx * 2]);
FType low16 =
ScalarType<FType>::float2num(C_frag[m_idx][n_idx][k_idx * 2]);
FType high16 =
static_cast<FType>(C_frag[m_idx][n_idx][k_idx * 2 + 1]);
ScalarType<FType>::float2num(C_frag[m_idx][n_idx][k_idx * 2 + 1]);
uint32_t tmp = (reinterpret_cast<uint32_t&>(low16) & 0xffff) |
(reinterpret_cast<uint32_t&>(high16) << 16);
int sts_offset =
@ -793,7 +794,7 @@ __global__ void restore_N32_K16_dequantize_rhs_w8a16_perc_kernel(
FT scale_reg[4];
*(reinterpret_cast<uint2*>(scale_reg)) =
*(reinterpret_cast<const uint2*>(scales + params_nidx));
FT zero_reg[4] = {0};
FT zero_reg[4];
if (zeros != nullptr) {
*(reinterpret_cast<uint2*>(zero_reg)) =
*(reinterpret_cast<const uint2*>(zeros + params_nidx));
@ -809,8 +810,10 @@ __global__ void restore_N32_K16_dequantize_rhs_w8a16_perc_kernel(
reinterpret_cast<typename HalfType<FT>::T2*>(&(fval_reg[ni * 4])));
#pragma unroll
for (int ki = 0; ki < 4; ++ki) {
fval_reg[ni * 4 + ki] =
(fval_reg[ni * 4 + ki] - zero_reg[ni]) * scale_reg[ni];
if (zeros != nullptr) {
fval_reg[ni * 4 + ki] = __hsub(fval_reg[ni * 4 + ki], zero_reg[ni]);
}
fval_reg[ni * 4 + ki] = __hmul(fval_reg[ni * 4 + ki], scale_reg[ni]);
int sts_offset = sts_base_offset + ((ki / 2) * 8 + (ki % 2)) * 32 +
((ni + lane_id % 4) % 4) * 8;
smem[sts_offset] = fval_reg[ni * 4 + ki];

View File

@ -7,6 +7,8 @@
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <iostream>
#include "../gptq_marlin/marlin_dtypes.cuh"
using marlin::ScalarType;
namespace allspark {
@ -66,14 +68,14 @@ __global__ void f16_gemm_splitk_reduce_kernel(const FType* C_split, FType* C,
return;
}
FType sum(0);
float sum = 0.f;
int n_mat = N_MATRIX > 0 ? N_MATRIX : (int)n_matrix;
for (int i = 0; i < n_mat; ++i) {
sum += C_split[idx + i * matrix_size];
sum += ScalarType<FType>::num2float(C_split[idx + i * matrix_size]);
}
C[idx] = sum;
C[idx] = ScalarType<FType>::float2num(sum);
}
template <typename FType>