mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 06:35:00 +08:00
[Kernel] Remove unused variables in awq/gemm_kernels.cu (#6908)
This commit is contained in:
parent
9f69d8245a
commit
aae6d36f7e
@ -17,14 +17,6 @@ Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023}
|
||||
namespace vllm {
|
||||
namespace awq {
|
||||
|
||||
// Pack two half values.
|
||||
static inline __device__ __host__ unsigned __pack_half2(const half x,
|
||||
const half y) {
|
||||
unsigned v0 = *((unsigned short*)&x);
|
||||
unsigned v1 = *((unsigned short*)&y);
|
||||
return (v1 << 16) | v0;
|
||||
}
|
||||
|
||||
template <int N>
|
||||
__global__ void __launch_bounds__(64)
|
||||
gemm_forward_4bit_cuda_m16nXk32(int G, int split_k_iters,
|
||||
@ -42,11 +34,7 @@ __global__ void __launch_bounds__(64)
|
||||
__shared__ half A_shared[16 * (32 + 8)];
|
||||
__shared__ half B_shared[32 * (N + 8)];
|
||||
|
||||
__shared__ half scaling_factors_shared[N];
|
||||
__shared__ half zeros_shared[N];
|
||||
|
||||
int j_factors1 = ((OC + N - 1) / N);
|
||||
int blockIdx_x = 0;
|
||||
int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
|
||||
int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
|
||||
|
||||
@ -60,7 +48,6 @@ __global__ void __launch_bounds__(64)
|
||||
|
||||
static constexpr int row_stride_warp = 32 * 8 / 32;
|
||||
static constexpr int row_stride = 2 * 32 * 8 / N;
|
||||
bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < N;
|
||||
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
|
||||
bool ld_A_flag =
|
||||
(blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp +
|
||||
@ -145,11 +132,7 @@ __global__ void __launch_bounds__(64)
|
||||
uint32_t B_loaded =
|
||||
*(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
|
||||
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
|
||||
// uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N /
|
||||
// 8)) * 8);
|
||||
|
||||
// uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x
|
||||
// % (cta_N / 8)) * 8);
|
||||
// - zero and * scale
|
||||
// TODO (Haotian): can save 4 assembly instructions if sormulate as deq =
|
||||
// q * scale - zero * scale.
|
||||
@ -367,17 +350,11 @@ __global__ void __launch_bounds__(64)
|
||||
__global__ void __launch_bounds__(64)
|
||||
dequantize_weights(int* __restrict__ B, half* __restrict__ scaling_factors,
|
||||
int* __restrict__ zeros, half* __restrict__ C, int G) {
|
||||
int j_factors1 = 4;
|
||||
int row_stride2 = 4;
|
||||
int split_k_iters = 1;
|
||||
static constexpr uint32_t ZERO = 0x0;
|
||||
half B_shared[32 * (128 + 8)];
|
||||
|
||||
half* B_shared_ptr2 = B_shared;
|
||||
|
||||
half B_shared_warp[32];
|
||||
int OC = 512;
|
||||
|
||||
int N = blockDim.x * gridDim.x; // 2
|
||||
int col = (blockIdx.x * blockDim.x + threadIdx.x);
|
||||
int row = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user