mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-28 06:01:25 +08:00
fix output zeroing race condition in GPTQ GEMM kernels
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
parent 7326fef568
commit a36191a774
@@ -233,11 +233,6 @@ __global__ void gemm_half_q_half_gptq_4bit_kernel(
|
||||
// Zero output
|
||||
if (n >= size_n) return;
|
||||
|
||||
// if (blockIdx.z == 0) {
|
||||
// for (int m = 0; m < m_count; m++)
|
||||
// *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0;
|
||||
// }
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Find initial group
|
||||
@@ -372,11 +367,6 @@ __global__ void gemm_half_q_half_gptq_2bit_kernel(
|
||||
// Zero output
|
||||
if (n >= size_n) return;
|
||||
|
||||
if (blockIdx.z == 0) {
|
||||
for (int m = 0; m < m_count; m++)
|
||||
*((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Find initial group
|
||||
@@ -494,11 +484,6 @@ __global__ void gemm_half_q_half_gptq_3bit_kernel(
|
||||
// Zero output
|
||||
if (n >= size_n) return;
|
||||
|
||||
if (blockIdx.z == 0) {
|
||||
for (int m = 0; m < m_count; m++)
|
||||
*((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Find initial group
|
||||
@@ -623,11 +608,6 @@ __global__ void gemm_half_q_half_gptq_8bit_kernel(
|
||||
// Zero output
|
||||
if (n >= size_n) return;
|
||||
|
||||
if (blockIdx.z == 0) {
|
||||
for (int m = 0; m < m_count; m++)
|
||||
*((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Find initial group
|
||||
@@ -1224,9 +1204,6 @@ __global__ void gemm_half_q_half_alt_4bit_kernel(
|
||||
__halves2half2(__int2half_rn(val & 0xF), __int2half_rn(val >> 4));
|
||||
}
|
||||
|
||||
if (blockIdx.z == 0) {
|
||||
for (int m = 0; m < b_end; m++) mul[(b + m) * width + w] = __int2half_rn(0);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
int i = width * h + w;
|
||||
@@ -1319,9 +1296,6 @@ __global__ void gemm_half_q_half_alt_8bit_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
if (blockIdx.z == 0) {
|
||||
for (int m = 0; m < b_end; m++) mul[(b + m) * width + w] = __int2half_rn(0);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
int i = width * h + w;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user