[Bugfix][Kernel]: Fix AllSpark kernel compilation errors and enable for CUDA < 12.0 (#14430)
Signed-off-by: wyj371990 <wyj371990@alibaba-inc.com>
parent 73deea2fdb · commit 977a16772c
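Summary of the change: the CMake gate that required CUDA >= 12.0 is removed, so the AllSpark GPTQ kernels now build on 11.x toolkits whenever a compatible arch (SM 8.0/8.6/8.7/8.9) is targeted. In the kernels themselves, static_cast and operator arithmetic on the templated 16-bit FType are replaced with Marlin's ScalarType<FType>::float2num/num2float helpers and the __hsub/__hmul intrinsics, and the split-K reduction now accumulates in fp32.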
CMakeLists.txt:

@@ -319,7 +319,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")

   # Only build AllSpark kernels if we are building for at least some compatible archs.
   cuda_archs_loose_intersection(ALLSPARK_ARCHS "8.0;8.6;8.7;8.9" "${CUDA_ARCHS}")
-  if (${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND ALLSPARK_ARCHS)
+  if (ALLSPARK_ARCHS)
     set(ALLSPARK_SRCS
       "csrc/quantization/gptq_allspark/allspark_repack.cu"
       "csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu")
@@ -330,7 +330,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     message(STATUS "Building AllSpark kernels for archs: ${ALLSPARK_ARCHS}")
   else()
     message(STATUS "Not building AllSpark kernels as no compatible archs found"
-                   " in CUDA target architectures, or CUDA not >= 12.0")
+                   " in CUDA target architectures")
   endif()
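With the toolkit-version check gone, the architecture intersection is the only remaining gate, and the status message drops its now-stale CUDA-version clause.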
csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu:

@@ -437,9 +437,10 @@ struct ComputeTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK {
       for (int n_idx = 0; n_idx < WARP_NITER; ++n_idx) {
 #pragma unroll
         for (int k_idx = 0; k_idx < 2; ++k_idx) {
-          FType low16 = static_cast<FType>(C_frag[m_idx][n_idx][k_idx * 2]);
+          FType low16 =
+              ScalarType<FType>::float2num(C_frag[m_idx][n_idx][k_idx * 2]);
           FType high16 =
-              static_cast<FType>(C_frag[m_idx][n_idx][k_idx * 2 + 1]);
+              ScalarType<FType>::float2num(C_frag[m_idx][n_idx][k_idx * 2 + 1]);
           uint32_t tmp = (reinterpret_cast<uint32_t&>(low16) & 0xffff) |
                          (reinterpret_cast<uint32_t&>(high16) << 16);
           int sts_offset =
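static_cast<FType>(float) relies on the conversion constructors of __half/__nv_bfloat16, which pre-12.0 CUDA headers expose only under arch and feature guards; presumably this is what failed to compile. ScalarType<FType>::float2num converts through the __float2half/__float2bfloat16 intrinsics instead. The packing of the two 16-bit results into a single 32-bit store is unchanged.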
@@ -793,7 +794,7 @@ __global__ void restore_N32_K16_dequantize_rhs_w8a16_perc_kernel(
   FT scale_reg[4];
   *(reinterpret_cast<uint2*>(scale_reg)) =
       *(reinterpret_cast<const uint2*>(scales + params_nidx));
-  FT zero_reg[4] = {0};
+  FT zero_reg[4];
   if (zeros != nullptr) {
     *(reinterpret_cast<uint2*>(zero_reg)) =
         *(reinterpret_cast<const uint2*>(zeros + params_nidx));
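The = {0} initializer goes away because zero_reg is now read only when zeros != nullptr (the subtraction is guarded in the next hunk), and brace-initializing an array of the 16-bit FT from an int literal is itself a conversion that older toolkits can reject.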
@@ -809,8 +810,10 @@ __global__ void restore_N32_K16_dequantize_rhs_w8a16_perc_kernel(
         reinterpret_cast<typename HalfType<FT>::T2*>(&(fval_reg[ni * 4])));
 #pragma unroll
     for (int ki = 0; ki < 4; ++ki) {
-      fval_reg[ni * 4 + ki] =
-          (fval_reg[ni * 4 + ki] - zero_reg[ni]) * scale_reg[ni];
+      if (zeros != nullptr) {
+        fval_reg[ni * 4 + ki] = __hsub(fval_reg[ni * 4 + ki], zero_reg[ni]);
+      }
+      fval_reg[ni * 4 + ki] = __hmul(fval_reg[ni * 4 + ki], scale_reg[ni]);
       int sts_offset = sts_base_offset + ((ki / 2) * 8 + (ki % 2)) * 32 +
                        ((ni + lane_id % 4) % 4) * 8;
       smem[sts_offset] = fval_reg[ni * 4 + ki];
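For readers outside the patch context, a minimal self-contained sketch of the same pattern (the kernel name and parameters here are hypothetical, not part of vLLM): __hsub and __hmul are the intrinsic forms of 16-bit subtraction and multiplication from cuda_fp16.h/cuda_bf16.h, used instead of operator-/operator*, whose availability for these types is header- and toolkit-version dependent.

// Hypothetical demo kernel, not from the patch: dequantize one
// uint8 weight per thread with an optional zero point, using the
// __hsub/__hmul intrinsics. Compile with, e.g., -arch=sm_80.
#include <cuda_fp16.h>
#include <cstdint>

__global__ void dequant_demo(const uint8_t* q, half scale, half zero,
                             bool has_zero, half* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) return;
  half v = __ushort2half_rn(q[i]);   // widen quantized byte to fp16
  if (has_zero) {
    v = __hsub(v, zero);             // zero-point shift, only if present
  }
  out[i] = __hmul(v, scale);         // apply per-channel scale
}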
csrc/quantization/gptq_allspark/allspark_utils.cuh:

@@ -7,6 +7,8 @@
 #include <cuda_fp16.h>
 #include <cuda_bf16.h>
 #include <iostream>
+#include "../gptq_marlin/marlin_dtypes.cuh"
+using marlin::ScalarType;

 namespace allspark {
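The newly included marlin_dtypes.cuh supplies the ScalarType conversion helpers used above. A simplified sketch of their shape (abbreviated; the in-tree header also carries packed-type aliases and arch guards):

// Abbreviated sketch of marlin::ScalarType's float conversions;
// the full header lives at csrc/quantization/gptq_marlin/marlin_dtypes.cuh.
template <typename scalar_t>
class ScalarType;

template <>
class ScalarType<half> {
 public:
  static __device__ float inline num2float(const half x) {
    return __half2float(x);   // intrinsic, available on CUDA 11.x
  }
  static __device__ half inline float2num(const float x) {
    return __float2half(x);
  }
};

template <>
class ScalarType<nv_bfloat16> {
 public:
  static __device__ float inline num2float(const nv_bfloat16 x) {
    return __bfloat162float(x);
  }
  static __device__ nv_bfloat16 inline float2num(const float x) {
    return __float2bfloat16(x);
  }
};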
@@ -66,14 +68,14 @@ __global__ void f16_gemm_splitk_reduce_kernel(const FType* C_split, FType* C,
     return;
   }

-  FType sum(0);
+  float sum = 0.f;

   int n_mat = N_MATRIX > 0 ? N_MATRIX : (int)n_matrix;
   for (int i = 0; i < n_mat; ++i) {
-    sum += C_split[idx + i * matrix_size];
+    sum += ScalarType<FType>::num2float(C_split[idx + i * matrix_size]);
   }

-  C[idx] = sum;
+  C[idx] = ScalarType<FType>::float2num(sum);
 }

 template <typename FType>
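Accumulating the split-K partials in a plain float, converting once per load with num2float and once at the end with float2num, avoids operator+= on the 16-bit FType (another overload that older headers guard away) and, as a bonus, sums the partials at fp32 precision.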