mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 07:24:56 +08:00
[Kernel] [Triton] Memory optimization for awq_gemm and awq_dequantize, 2x throughput (#8248)
This commit is contained in:
parent
1447c97e75
commit
9db52eab3d
@ -22,7 +22,7 @@ def awq_dequantize_kernel(
|
||||
|
||||
# Compute offsets and masks for qweight_ptr.
|
||||
offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y)
|
||||
offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X * 8) // 8
|
||||
offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X)
|
||||
offsets = num_cols * offsets_y[:, None] + offsets_x[None, :]
|
||||
|
||||
masks_y = offsets_y < num_rows
|
||||
@ -43,6 +43,9 @@ def awq_dequantize_kernel(
|
||||
|
||||
# Load the weights.
|
||||
iweights = tl.load(qweight_ptr + offsets, masks)
|
||||
iweights = tl.interleave(iweights, iweights)
|
||||
iweights = tl.interleave(iweights, iweights)
|
||||
iweights = tl.interleave(iweights, iweights)
|
||||
|
||||
# Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7]
|
||||
# that will map given indices to the correct order.
|
||||
@ -59,9 +62,8 @@ def awq_dequantize_kernel(
|
||||
iweights = (iweights >> shifts) & 0xF
|
||||
|
||||
# Compute zero offsets and masks.
|
||||
zero_offsets_y = (pid_y * BLOCK_SIZE_Y // group_size +
|
||||
tl.arange(0, BLOCK_SIZE_Y) // group_size)
|
||||
zero_offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X * 8) // 8
|
||||
zero_offsets_y = pid_y * BLOCK_SIZE_Y // group_size + tl.arange(0, 1)
|
||||
zero_offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X)
|
||||
zero_offsets = num_cols * zero_offsets_y[:, None] + zero_offsets_x[None, :]
|
||||
|
||||
zero_masks_y = zero_offsets_y < num_rows // group_size
|
||||
@ -70,13 +72,16 @@ def awq_dequantize_kernel(
|
||||
|
||||
# Load the zeros.
|
||||
zeros = tl.load(zeros_ptr + zero_offsets, zero_masks)
|
||||
zeros = tl.interleave(zeros, zeros)
|
||||
zeros = tl.interleave(zeros, zeros)
|
||||
zeros = tl.interleave(zeros, zeros)
|
||||
zeros = tl.broadcast_to(zeros, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8))
|
||||
|
||||
# Unpack and reorder: shift out the correct 4-bit value and mask.
|
||||
zeros = (zeros >> shifts) & 0xF
|
||||
|
||||
# Compute scale offsets and masks.
|
||||
scale_offsets_y = (pid_y * BLOCK_SIZE_Y // group_size +
|
||||
tl.arange(0, BLOCK_SIZE_Y) // group_size)
|
||||
scale_offsets_y = pid_y * BLOCK_SIZE_Y // group_size + tl.arange(0, 1)
|
||||
scale_offsets_x = (pid_x * BLOCK_SIZE_X * 8 +
|
||||
tl.arange(0, BLOCK_SIZE_X * 8))
|
||||
scale_offsets = (num_cols * 8 * scale_offsets_y[:, None] +
|
||||
@ -87,6 +92,7 @@ def awq_dequantize_kernel(
|
||||
|
||||
# Load the scales.
|
||||
scales = tl.load(scales_ptr + scale_offsets, scale_masks)
|
||||
scales = tl.broadcast_to(scales, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8))
|
||||
|
||||
# Dequantize.
|
||||
iweights = (iweights - zeros) * scales
|
||||
@ -137,12 +143,10 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K,
|
||||
offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
masks_am = offsets_am < M
|
||||
|
||||
offsets_bn = (pid_n * (BLOCK_SIZE_N // 8) +
|
||||
tl.arange(0, BLOCK_SIZE_N) // 8)
|
||||
offsets_bn = pid_n * (BLOCK_SIZE_N // 8) + tl.arange(0, BLOCK_SIZE_N // 8)
|
||||
masks_bn = offsets_bn < N // 8
|
||||
|
||||
offsets_zn = (pid_n * (BLOCK_SIZE_N // 8) +
|
||||
tl.arange(0, BLOCK_SIZE_N) // 8)
|
||||
offsets_zn = pid_n * (BLOCK_SIZE_N // 8) + tl.arange(0, BLOCK_SIZE_N // 8)
|
||||
masks_zn = offsets_zn < N // 8
|
||||
|
||||
offsets_sn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
@ -165,22 +169,30 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K,
|
||||
|
||||
masks_b = masks_k[:, None] & masks_bn[None, :]
|
||||
b = tl.load(b_ptrs, mask=masks_b)
|
||||
b = tl.interleave(b, b)
|
||||
b = tl.interleave(b, b)
|
||||
b = tl.interleave(b, b)
|
||||
|
||||
# Dequantize b.
|
||||
offsets_szk = (
|
||||
(BLOCK_SIZE_K * SPLIT_K * k + pid_z * BLOCK_SIZE_K) // group_size +
|
||||
tl.arange(0, BLOCK_SIZE_K) // group_size)
|
||||
tl.arange(0, 1))
|
||||
offsets_z = (N // 8) * offsets_szk[:, None] + offsets_zn[None, :]
|
||||
masks_zk = offsets_szk < K // group_size
|
||||
masks_z = masks_zk[:, None] & masks_zn[None, :]
|
||||
zeros_ptrs = zeros_ptr + offsets_z
|
||||
zeros = tl.load(zeros_ptrs, mask=masks_z)
|
||||
zeros = tl.interleave(zeros, zeros)
|
||||
zeros = tl.interleave(zeros, zeros)
|
||||
zeros = tl.interleave(zeros, zeros)
|
||||
zeros = tl.broadcast_to(zeros, (BLOCK_SIZE_K, BLOCK_SIZE_N))
|
||||
|
||||
offsets_s = N * offsets_szk[:, None] + offsets_sn[None, :]
|
||||
masks_sk = offsets_szk < K // group_size
|
||||
masks_s = masks_sk[:, None] & masks_sn[None, :]
|
||||
scales_ptrs = scales_ptr + offsets_s
|
||||
scales = tl.load(scales_ptrs, mask=masks_s)
|
||||
scales = tl.broadcast_to(scales, (BLOCK_SIZE_K, BLOCK_SIZE_N))
|
||||
|
||||
b = (b >> shifts) & 0xF
|
||||
zeros = (zeros >> shifts) & 0xF
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user