diff --git a/csrc/quantization/awq/gemm_kernels.cu b/csrc/quantization/awq/gemm_kernels.cu index 6d6da5f3d874..9da724a1b43c 100644 --- a/csrc/quantization/awq/gemm_kernels.cu +++ b/csrc/quantization/awq/gemm_kernels.cu @@ -17,14 +17,6 @@ Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023} namespace vllm { namespace awq { -// Pack two half values. -static inline __device__ __host__ unsigned __pack_half2(const half x, - const half y) { - unsigned v0 = *((unsigned short*)&x); - unsigned v1 = *((unsigned short*)&y); - return (v1 << 16) | v0; -} - template __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32(int G, int split_k_iters, @@ -42,11 +34,7 @@ __global__ void __launch_bounds__(64) __shared__ half A_shared[16 * (32 + 8)]; __shared__ half B_shared[32 * (N + 8)]; - __shared__ half scaling_factors_shared[N]; - __shared__ half zeros_shared[N]; - int j_factors1 = ((OC + N - 1) / N); - int blockIdx_x = 0; int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1); int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1); @@ -60,7 +48,6 @@ __global__ void __launch_bounds__(64) static constexpr int row_stride_warp = 32 * 8 / 32; static constexpr int row_stride = 2 * 32 * 8 / N; - bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < N; // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + @@ -145,11 +132,7 @@ __global__ void __launch_bounds__(64) uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8)); uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); - // uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / - // 8)) * 8); - // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x - // % (cta_N / 8)) * 8); // - zero and * scale // TODO (Haotian): can save 4 assembly instructions if sormulate as deq = // q * scale - zero * scale. @@ -367,17 +350,11 @@ __global__ void __launch_bounds__(64) __global__ void __launch_bounds__(64) dequantize_weights(int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, half* __restrict__ C, int G) { - int j_factors1 = 4; - int row_stride2 = 4; - int split_k_iters = 1; static constexpr uint32_t ZERO = 0x0; half B_shared[32 * (128 + 8)]; half* B_shared_ptr2 = B_shared; - half B_shared_warp[32]; - int OC = 512; - int N = blockDim.x * gridDim.x; // 2 int col = (blockIdx.x * blockDim.x + threadIdx.x); int row = blockIdx.y * blockDim.y + threadIdx.y;