mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 18:14:56 +08:00
[Bugfix][Kernel]: Fix AllSpark kernel compilation errors and enable for CUDA < 12.0 (#14430)
Signed-off-by: wyj371990 <wyj371990@alibaba-inc.com>
This commit is contained in:
parent
73deea2fdb
commit
977a16772c
@ -319,7 +319,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
|
||||
# Only build AllSpark kernels if we are building for at least some compatible archs.
|
||||
cuda_archs_loose_intersection(ALLSPARK_ARCHS "8.0;8.6;8.7;8.9" "${CUDA_ARCHS}")
|
||||
if (${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND ALLSPARK_ARCHS)
|
||||
if (ALLSPARK_ARCHS)
|
||||
set(ALLSPARK_SRCS
|
||||
"csrc/quantization/gptq_allspark/allspark_repack.cu"
|
||||
"csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu")
|
||||
@ -330,7 +330,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
message(STATUS "Building AllSpark kernels for archs: ${ALLSPARK_ARCHS}")
|
||||
else()
|
||||
message(STATUS "Not building AllSpark kernels as no compatible archs found"
|
||||
" in CUDA target architectures, or CUDA not >= 12.0")
|
||||
" in CUDA target architectures")
|
||||
endif()
|
||||
|
||||
|
||||
|
||||
@ -437,9 +437,10 @@ struct ComputeTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK {
|
||||
for (int n_idx = 0; n_idx < WARP_NITER; ++n_idx) {
|
||||
#pragma unroll
|
||||
for (int k_idx = 0; k_idx < 2; ++k_idx) {
|
||||
FType low16 = static_cast<FType>(C_frag[m_idx][n_idx][k_idx * 2]);
|
||||
FType low16 =
|
||||
ScalarType<FType>::float2num(C_frag[m_idx][n_idx][k_idx * 2]);
|
||||
FType high16 =
|
||||
static_cast<FType>(C_frag[m_idx][n_idx][k_idx * 2 + 1]);
|
||||
ScalarType<FType>::float2num(C_frag[m_idx][n_idx][k_idx * 2 + 1]);
|
||||
uint32_t tmp = (reinterpret_cast<uint32_t&>(low16) & 0xffff) |
|
||||
(reinterpret_cast<uint32_t&>(high16) << 16);
|
||||
int sts_offset =
|
||||
@ -793,7 +794,7 @@ __global__ void restore_N32_K16_dequantize_rhs_w8a16_perc_kernel(
|
||||
FT scale_reg[4];
|
||||
*(reinterpret_cast<uint2*>(scale_reg)) =
|
||||
*(reinterpret_cast<const uint2*>(scales + params_nidx));
|
||||
FT zero_reg[4] = {0};
|
||||
FT zero_reg[4];
|
||||
if (zeros != nullptr) {
|
||||
*(reinterpret_cast<uint2*>(zero_reg)) =
|
||||
*(reinterpret_cast<const uint2*>(zeros + params_nidx));
|
||||
@ -809,8 +810,10 @@ __global__ void restore_N32_K16_dequantize_rhs_w8a16_perc_kernel(
|
||||
reinterpret_cast<typename HalfType<FT>::T2*>(&(fval_reg[ni * 4])));
|
||||
#pragma unroll
|
||||
for (int ki = 0; ki < 4; ++ki) {
|
||||
fval_reg[ni * 4 + ki] =
|
||||
(fval_reg[ni * 4 + ki] - zero_reg[ni]) * scale_reg[ni];
|
||||
if (zeros != nullptr) {
|
||||
fval_reg[ni * 4 + ki] = __hsub(fval_reg[ni * 4 + ki], zero_reg[ni]);
|
||||
}
|
||||
fval_reg[ni * 4 + ki] = __hmul(fval_reg[ni * 4 + ki], scale_reg[ni]);
|
||||
int sts_offset = sts_base_offset + ((ki / 2) * 8 + (ki % 2)) * 32 +
|
||||
((ni + lane_id % 4) % 4) * 8;
|
||||
smem[sts_offset] = fval_reg[ni * 4 + ki];
|
||||
|
||||
@ -7,6 +7,8 @@
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_bf16.h>
|
||||
#include <iostream>
|
||||
#include "../gptq_marlin/marlin_dtypes.cuh"
|
||||
using marlin::ScalarType;
|
||||
|
||||
namespace allspark {
|
||||
|
||||
@ -66,14 +68,14 @@ __global__ void f16_gemm_splitk_reduce_kernel(const FType* C_split, FType* C,
|
||||
return;
|
||||
}
|
||||
|
||||
FType sum(0);
|
||||
float sum = 0.f;
|
||||
|
||||
int n_mat = N_MATRIX > 0 ? N_MATRIX : (int)n_matrix;
|
||||
for (int i = 0; i < n_mat; ++i) {
|
||||
sum += C_split[idx + i * matrix_size];
|
||||
sum += ScalarType<FType>::num2float(C_split[idx + i * matrix_size]);
|
||||
}
|
||||
|
||||
C[idx] = sum;
|
||||
C[idx] = ScalarType<FType>::float2num(sum);
|
||||
}
|
||||
|
||||
template <typename FType>
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user