mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-08 16:15:52 +08:00
[Kernel] LoRA - Refactor sgmv kernels (#13110)
This commit is contained in:
parent
a64a84433d
commit
b69692a2d8
243
vllm/lora/ops/triton_ops/kernel_utils.py
Normal file
243
vllm/lora/ops/triton_ops/kernel_utils.py
Normal file
@ -0,0 +1,243 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
"""
|
||||||
|
Utilities for Punica kernel construction.
|
||||||
|
"""
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def mm_k(a_ptr, b_ptr, ak_stride, bk_stride, offset_k, K: tl.constexpr,
|
||||||
|
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||||
|
EVEN_K: tl.constexpr, SPLIT_K: tl.constexpr, CAST_TYPE: tl.constexpr,
|
||||||
|
b_dtype: tl.constexpr):
|
||||||
|
"""
|
||||||
|
Given a_ptr and b_ptr, that identify the rows of A (m x k) and columns of
|
||||||
|
B (k x n), iterate, through the K dimension to compute the partial/complete
|
||||||
|
matrix block product.
|
||||||
|
If SPLIT_K == 1, the output m x n product is complete.
|
||||||
|
If SPLIT_K > 1, the thread block computes partial outputs. The partial
|
||||||
|
outputs are then atomically summed in the caller code.
|
||||||
|
Args:
|
||||||
|
a_ptr: Array of pointers, identifying rows of A
|
||||||
|
b_ptr: Array of pointers, identifying columns of B
|
||||||
|
ak_stride: K dimension stride of the A matrix
|
||||||
|
bk_stride: K dimension stride of the B matrix
|
||||||
|
K: Length of the K dimension
|
||||||
|
BLOCK_M: M dimension of the output block m x n
|
||||||
|
BLOCK_N: N dimension of the output block m x n
|
||||||
|
BLOCK_K: K dimension atom
|
||||||
|
EVEN_K: True if the blocks of A and B can be loaded without any
|
||||||
|
masking.
|
||||||
|
SPLIT_K: Parameter signifying parallelism in the K dimension.
|
||||||
|
CAST_TYPE: if True, cast the values from the A matrix to the B
|
||||||
|
matrix dtype.
|
||||||
|
b_dtype: datatype of the B matrix
|
||||||
|
"""
|
||||||
|
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
|
||||||
|
for k in range(tl.cdiv(K, BLOCK_K * SPLIT_K)):
|
||||||
|
if EVEN_K:
|
||||||
|
tiled_a = tl.load(a_ptr)
|
||||||
|
tiled_b = tl.load(b_ptr)
|
||||||
|
else:
|
||||||
|
tiled_a = tl.load(a_ptr,
|
||||||
|
mask=offset_k[None, :]
|
||||||
|
< K - k * (BLOCK_K * SPLIT_K),
|
||||||
|
other=0)
|
||||||
|
tiled_b = tl.load(b_ptr,
|
||||||
|
mask=offset_k[:, None]
|
||||||
|
< K - k * (BLOCK_K * SPLIT_K),
|
||||||
|
other=0)
|
||||||
|
if CAST_TYPE:
|
||||||
|
tiled_a = tiled_a.to(b_dtype)
|
||||||
|
accumulator += tl.dot(
|
||||||
|
tiled_a,
|
||||||
|
tiled_b,
|
||||||
|
)
|
||||||
|
a_ptr += BLOCK_K * SPLIT_K * ak_stride
|
||||||
|
b_ptr += BLOCK_K * SPLIT_K * bk_stride
|
||||||
|
return accumulator
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def do_expand_kernel(
|
||||||
|
pid_n,
|
||||||
|
lora_index,
|
||||||
|
slice_id,
|
||||||
|
input_ptr,
|
||||||
|
lora_ptr,
|
||||||
|
out_ptr,
|
||||||
|
N,
|
||||||
|
K,
|
||||||
|
M_LEN,
|
||||||
|
ram, # array identifying the rows of Input ptr to operate on
|
||||||
|
slice_start_loc,
|
||||||
|
# input ptr strides
|
||||||
|
input_d0_stride,
|
||||||
|
input_d1_stride,
|
||||||
|
input_d2_stride,
|
||||||
|
# lora ptr strides
|
||||||
|
ls_d0_ptr,
|
||||||
|
ls_d1_ptr,
|
||||||
|
ls_d2_ptr,
|
||||||
|
# out ptr strides
|
||||||
|
output_d0_stride,
|
||||||
|
output_d1_stride,
|
||||||
|
# constants
|
||||||
|
BLOCK_M: tl.constexpr,
|
||||||
|
BLOCK_N: tl.constexpr,
|
||||||
|
BLOCK_K: tl.constexpr,
|
||||||
|
SAME_STRIDE: tl.constexpr,
|
||||||
|
SLICE_NUM: tl.constexpr,
|
||||||
|
EVEN_K: tl.constexpr,
|
||||||
|
CAST_TYPE: tl.constexpr,
|
||||||
|
ADD_INPUTS: tl.constexpr,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Given an array of integers that identifies the rows of A, ram,
|
||||||
|
a lora index that identifies which LoRA to use from lora_ptr, lora_index,
|
||||||
|
a slice_id that identifies the input/output slice,
|
||||||
|
compute the matrix product and store in the appropriate output location.
|
||||||
|
Given that this is an expand kernel, we don't perform any split-K reduction
|
||||||
|
as the K dimension is assumed to be small.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# ls_d*_ptr can be either an integer or a pointer
|
||||||
|
if SAME_STRIDE:
|
||||||
|
# integer
|
||||||
|
cur_lora_d0_stride = ls_d0_ptr
|
||||||
|
cur_lora_d1_stride = ls_d1_ptr
|
||||||
|
cur_lora_d2_stride = ls_d2_ptr
|
||||||
|
else:
|
||||||
|
# pointer
|
||||||
|
cur_lora_d0_stride = tl.load(ls_d0_ptr + slice_id)
|
||||||
|
cur_lora_d1_stride = tl.load(ls_d1_ptr + slice_id)
|
||||||
|
cur_lora_d2_stride = tl.load(ls_d2_ptr + slice_id)
|
||||||
|
|
||||||
|
# Identify the input_ptr and lora_ptr from slice_id.
|
||||||
|
if SLICE_NUM == 1:
|
||||||
|
cur_input_ptr = input_ptr
|
||||||
|
cur_lora_ptr = lora_ptr
|
||||||
|
else:
|
||||||
|
cur_input_ptr = input_ptr + slice_id * input_d0_stride
|
||||||
|
cur_lora_ptr = tl.load(lora_ptr + slice_id).to(
|
||||||
|
tl.pointer_type(out_ptr.dtype.element_ty))
|
||||||
|
|
||||||
|
# Identify the column indices of B to process.
|
||||||
|
offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
|
||||||
|
rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)
|
||||||
|
|
||||||
|
# Identify A and B block pointers
|
||||||
|
offset_k = tl.arange(0, BLOCK_K)
|
||||||
|
a_ptr = (cur_input_ptr + ram[:, None] * input_d1_stride +
|
||||||
|
offset_k[None, :] * input_d2_stride, )
|
||||||
|
b_ptr = (cur_lora_ptr + cur_lora_d0_stride * lora_index +
|
||||||
|
offset_k[:, None] * cur_lora_d2_stride +
|
||||||
|
rbn[None, :] * cur_lora_d1_stride)
|
||||||
|
|
||||||
|
# Compute the block matrix product.
|
||||||
|
SPLIT_K = 1
|
||||||
|
accumulator = mm_k(a_ptr, b_ptr, input_d2_stride, cur_lora_d2_stride,
|
||||||
|
offset_k, K, BLOCK_M, BLOCK_N, BLOCK_K, EVEN_K, SPLIT_K,
|
||||||
|
CAST_TYPE, cur_lora_ptr.dtype.element_ty)
|
||||||
|
|
||||||
|
tiled_c = accumulator.to(cur_lora_ptr.dtype.element_ty)
|
||||||
|
if SLICE_NUM == 1:
|
||||||
|
cur_slice_start = slice_start_loc
|
||||||
|
else:
|
||||||
|
cur_slice_start = tl.load(slice_start_loc + slice_id)
|
||||||
|
|
||||||
|
# Identify the C output pointers to store the results of the accumulator.
|
||||||
|
offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + cur_slice_start
|
||||||
|
offset_cm = tl.arange(0, BLOCK_M)
|
||||||
|
c_ptr = (out_ptr + ram[:, None] * output_d0_stride +
|
||||||
|
offset_cn[None, :] * output_d1_stride)
|
||||||
|
c_mask = (offset_cm[:, None] < M_LEN) & (offset_cn[None, :]
|
||||||
|
< (cur_slice_start + N))
|
||||||
|
|
||||||
|
if ADD_INPUTS:
|
||||||
|
tiled_out = tl.load(c_ptr, mask=c_mask)
|
||||||
|
tiled_c += tiled_out
|
||||||
|
tl.store(c_ptr, tiled_c, mask=c_mask)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def do_shrink_kernel(
|
||||||
|
pid_n,
|
||||||
|
pid_sk,
|
||||||
|
slice_id,
|
||||||
|
lora_index,
|
||||||
|
input_ptr,
|
||||||
|
lora_ptr,
|
||||||
|
out_ptr,
|
||||||
|
N,
|
||||||
|
K,
|
||||||
|
M_LEN,
|
||||||
|
ram,
|
||||||
|
# input strides
|
||||||
|
input_d0_stride,
|
||||||
|
input_d1_stride,
|
||||||
|
# lora strides
|
||||||
|
lora_d0_stride,
|
||||||
|
lora_d1_stride,
|
||||||
|
lora_d2_stride,
|
||||||
|
# output strides
|
||||||
|
output_d0_stride,
|
||||||
|
output_d1_stride,
|
||||||
|
output_d2_stride,
|
||||||
|
scaling,
|
||||||
|
BLOCK_M: tl.constexpr,
|
||||||
|
BLOCK_N: tl.constexpr,
|
||||||
|
BLOCK_K: tl.constexpr,
|
||||||
|
EVEN_K: tl.constexpr,
|
||||||
|
SPLIT_K: tl.constexpr,
|
||||||
|
SLICE_NUM: tl.constexpr,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Given an array of integers that identifies the rows of A, ram,
|
||||||
|
a lora index that identifies which LoRA to use from lora_ptr, lora_index,
|
||||||
|
a slice_id that identifies the input/output slice, compute the
|
||||||
|
matrix product and store in the appropriate output location.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Identify the lora_ptr from slice_id.
|
||||||
|
if SLICE_NUM == 1:
|
||||||
|
# current lora ptr
|
||||||
|
cur_lora_ptr = lora_ptr
|
||||||
|
else:
|
||||||
|
# current lora ptr
|
||||||
|
cur_lora_ptr = tl.load(lora_ptr + slice_id).to(
|
||||||
|
tl.pointer_type(input_ptr.dtype.element_ty))
|
||||||
|
|
||||||
|
# Identify the column indices of B to process.
|
||||||
|
offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
|
||||||
|
rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)
|
||||||
|
|
||||||
|
# Identify A and B block pointers
|
||||||
|
offset_k = pid_sk * BLOCK_K + tl.arange(0, BLOCK_K)
|
||||||
|
a_ptr = (input_ptr + ram[:, None] * input_d0_stride +
|
||||||
|
offset_k[None, :] * input_d1_stride)
|
||||||
|
b_ptr = (cur_lora_ptr + lora_d0_stride * lora_index +
|
||||||
|
rbn[None, :] * lora_d1_stride +
|
||||||
|
offset_k[:, None] * lora_d2_stride)
|
||||||
|
|
||||||
|
# Compute partial/complete block matrix product.
|
||||||
|
accumulator = mm_k(a_ptr, b_ptr, input_d1_stride, lora_d2_stride, offset_k,
|
||||||
|
K, BLOCK_M, BLOCK_N, BLOCK_K, EVEN_K, SPLIT_K, False,
|
||||||
|
cur_lora_ptr.dtype.element_ty)
|
||||||
|
|
||||||
|
# Identify the C output pointers to store the results of the accumulator.
|
||||||
|
offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
|
||||||
|
offset_cm = tl.arange(0, BLOCK_M)
|
||||||
|
cur_out_ptr = (out_ptr if SLICE_NUM == 1 else out_ptr +
|
||||||
|
slice_id * output_d0_stride)
|
||||||
|
c_ptr = cur_out_ptr + ram[:, None] * output_d1_stride + offset_cn[
|
||||||
|
None, :] * output_d2_stride
|
||||||
|
c_mask = (offset_cm[:, None] < M_LEN) & (offset_cn[None, :] < N)
|
||||||
|
|
||||||
|
accumulator *= scaling
|
||||||
|
# handles write-back with reduction-splitting
|
||||||
|
if SPLIT_K == 1:
|
||||||
|
tl.store(c_ptr, accumulator, mask=c_mask)
|
||||||
|
else:
|
||||||
|
tl.atomic_add(c_ptr, accumulator, mask=c_mask)
|
||||||
@ -14,6 +14,7 @@ import triton.language as tl
|
|||||||
|
|
||||||
from vllm.utils import direct_register_custom_op
|
from vllm.utils import direct_register_custom_op
|
||||||
|
|
||||||
|
from .kernel_utils import do_expand_kernel
|
||||||
from .utils import _get_lora_b_ptr
|
from .utils import _get_lora_b_ptr
|
||||||
|
|
||||||
|
|
||||||
@ -63,86 +64,56 @@ def _sgmv_expand_kernel(
|
|||||||
curr_N = N if SAME_STRIDE else tl.load(output_hs_ptr + slice_id)
|
curr_N = N if SAME_STRIDE else tl.load(output_hs_ptr + slice_id)
|
||||||
pid_m = pid // cta_n_num
|
pid_m = pid // cta_n_num
|
||||||
pid_n = pid % cta_n_num
|
pid_n = pid % cta_n_num
|
||||||
|
|
||||||
M = tl.load(seq_lens + cur_batch)
|
M = tl.load(seq_lens + cur_batch)
|
||||||
if pid_m * BLOCK_M > M:
|
if pid_m * BLOCK_M >= M:
|
||||||
return
|
return
|
||||||
if pid_n * BLOCK_N > curr_N:
|
if pid_n * BLOCK_N >= curr_N:
|
||||||
return
|
return
|
||||||
lora_index = tl.load(lora_indices + cur_batch)
|
lora_index = tl.load(lora_indices + cur_batch)
|
||||||
if lora_index == -1:
|
if lora_index == -1:
|
||||||
return
|
return
|
||||||
|
|
||||||
cur_seq_start = tl.load(b_seq_start_loc + cur_batch)
|
m_offset = tl.load(b_seq_start_loc + cur_batch)
|
||||||
offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
|
|
||||||
offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
|
|
||||||
offset_k = tl.arange(0, BLOCK_K)
|
|
||||||
ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M)
|
|
||||||
rbn = tl.max_contiguous(tl.multiple_of(offset_n % curr_N, BLOCK_N),
|
|
||||||
BLOCK_N)
|
|
||||||
# ls_d*_ptr can be either an integer or a pointer
|
|
||||||
if SAME_STRIDE:
|
|
||||||
# integer
|
|
||||||
cur_lora_d0_stride = ls_d0_ptr
|
|
||||||
cur_lora_d1_stride = ls_d1_ptr
|
|
||||||
cur_lora_d2_stride = ls_d2_ptr
|
|
||||||
else:
|
|
||||||
# pointer
|
|
||||||
cur_lora_d0_stride = tl.load(ls_d0_ptr + slice_id)
|
|
||||||
cur_lora_d1_stride = tl.load(ls_d1_ptr + slice_id)
|
|
||||||
cur_lora_d2_stride = tl.load(ls_d2_ptr + slice_id)
|
|
||||||
if SLICE_NUM == 1:
|
|
||||||
cur_input_ptr = input_ptr
|
|
||||||
cur_lora_ptr = lora_ptr
|
|
||||||
|
|
||||||
else:
|
cta_m_len = min(BLOCK_M, M - (pid_m * BLOCK_M))
|
||||||
cur_input_ptr = input_ptr + slice_id * input_d0_stride
|
cta_m_offset = m_offset + (pid_m * BLOCK_M)
|
||||||
cur_lora_ptr = tl.load(lora_ptr + slice_id).to(
|
offset_m = tl.arange(0, BLOCK_M)
|
||||||
tl.pointer_type(out_ptr.dtype.element_ty))
|
ram = cta_m_offset + tl.max_contiguous(
|
||||||
|
tl.multiple_of(offset_m % cta_m_len, BLOCK_M), BLOCK_M)
|
||||||
a_ptr = (cur_input_ptr + cur_seq_start * input_d1_stride +
|
do_expand_kernel(
|
||||||
ram[:, None] * input_d1_stride +
|
pid_n,
|
||||||
offset_k[None, :] * input_d2_stride, )
|
lora_index,
|
||||||
b_ptr = (cur_lora_ptr + cur_lora_d0_stride * lora_index +
|
slice_id,
|
||||||
offset_k[:, None] * cur_lora_d2_stride +
|
input_ptr,
|
||||||
rbn[None, :] * cur_lora_d1_stride)
|
lora_ptr,
|
||||||
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
|
out_ptr,
|
||||||
for k in range(tl.cdiv(K, BLOCK_K)):
|
curr_N,
|
||||||
if EVEN_K:
|
K,
|
||||||
tiled_a = tl.load(a_ptr)
|
cta_m_len,
|
||||||
tiled_b = tl.load(b_ptr)
|
ram, # array identifying the rows of Input ptr to operate on
|
||||||
else:
|
slice_start_loc,
|
||||||
tiled_a = tl.load(a_ptr,
|
# input ptr strides
|
||||||
mask=offset_k[None, :] < K - k * BLOCK_K,
|
input_d0_stride,
|
||||||
other=0)
|
input_d1_stride,
|
||||||
tiled_b = tl.load(b_ptr,
|
input_d2_stride,
|
||||||
mask=offset_k[:, None] < K - k * BLOCK_K,
|
# lora ptr strides
|
||||||
other=0)
|
ls_d0_ptr,
|
||||||
if CAST_TYPE:
|
ls_d1_ptr,
|
||||||
tiled_a = tiled_a.to(cur_lora_ptr.dtype.element_ty)
|
ls_d2_ptr,
|
||||||
accumulator += tl.dot(
|
# out ptr strides
|
||||||
tiled_a,
|
output_d0_stride,
|
||||||
tiled_b,
|
output_d1_stride,
|
||||||
)
|
# constants
|
||||||
a_ptr += BLOCK_K * input_d2_stride
|
BLOCK_M,
|
||||||
b_ptr += BLOCK_K * cur_lora_d2_stride
|
BLOCK_N,
|
||||||
|
BLOCK_K,
|
||||||
tiled_c = accumulator.to(cur_lora_ptr.dtype.element_ty)
|
SAME_STRIDE,
|
||||||
if SLICE_NUM == 1:
|
SLICE_NUM,
|
||||||
cur_slice_start = slice_start_loc
|
EVEN_K,
|
||||||
else:
|
CAST_TYPE,
|
||||||
cur_slice_start = tl.load(slice_start_loc + slice_id)
|
ADD_INPUTS,
|
||||||
|
)
|
||||||
offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
|
|
||||||
offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + cur_slice_start
|
|
||||||
c_ptr = (out_ptr + offset_cm[:, None] * output_d0_stride +
|
|
||||||
offset_cn[None, :] * output_d1_stride)
|
|
||||||
M = tl.load(seq_lens + cur_batch)
|
|
||||||
c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (
|
|
||||||
offset_cn[None, :] < (cur_slice_start + curr_N))
|
|
||||||
if ADD_INPUTS:
|
|
||||||
tiled_out = tl.load(c_ptr, mask=c_mask)
|
|
||||||
tiled_c += tiled_out
|
|
||||||
tl.store(c_ptr, tiled_c, mask=c_mask)
|
|
||||||
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
|
|||||||
@ -14,6 +14,7 @@ import triton.language as tl
|
|||||||
|
|
||||||
from vllm.utils import direct_register_custom_op
|
from vllm.utils import direct_register_custom_op
|
||||||
|
|
||||||
|
from .kernel_utils import do_shrink_kernel
|
||||||
from .utils import _get_lora_a_ptr
|
from .utils import _get_lora_a_ptr
|
||||||
|
|
||||||
|
|
||||||
@ -62,67 +63,50 @@ def _sgmv_shrink_kernel(
|
|||||||
pid_sk = pid_mix % SPLIT_K
|
pid_sk = pid_mix % SPLIT_K
|
||||||
|
|
||||||
M = tl.load(seq_lens + cur_batch)
|
M = tl.load(seq_lens + cur_batch)
|
||||||
if pid_m * BLOCK_M > M:
|
if pid_m * BLOCK_M >= M:
|
||||||
return
|
return
|
||||||
lora_index = tl.load(lora_indices + cur_batch)
|
lora_index = tl.load(lora_indices + cur_batch)
|
||||||
if lora_index == -1:
|
if lora_index == -1:
|
||||||
return
|
return
|
||||||
cur_seq_start = tl.load(b_seq_start_loc + cur_batch)
|
|
||||||
offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
|
|
||||||
offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
|
|
||||||
offset_k = pid_sk * BLOCK_K + tl.arange(0, BLOCK_K)
|
|
||||||
|
|
||||||
ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M)
|
m_offset = tl.load(b_seq_start_loc + cur_batch)
|
||||||
rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)
|
|
||||||
# input ptr
|
|
||||||
a_ptr = (input_ptr + cur_seq_start * input_d0_stride +
|
|
||||||
ram[:, None] * input_d0_stride +
|
|
||||||
offset_k[None, :] * input_d1_stride)
|
|
||||||
|
|
||||||
if SLICE_NUM == 1:
|
cta_m_len = min(BLOCK_M, M - (pid_m * BLOCK_M))
|
||||||
# current lora ptr
|
cta_m_offset = m_offset + (pid_m * BLOCK_M)
|
||||||
cur_lora_ptr = lora_ptr
|
offset_m = tl.arange(0, BLOCK_M)
|
||||||
else:
|
ram = cta_m_offset + tl.max_contiguous(
|
||||||
# current lora ptr
|
tl.multiple_of(offset_m % cta_m_len, BLOCK_M), BLOCK_M)
|
||||||
cur_lora_ptr = tl.load(lora_ptr + slice_id).to(
|
|
||||||
tl.pointer_type(input_ptr.dtype.element_ty))
|
|
||||||
|
|
||||||
b_ptr = (cur_lora_ptr + lora_d0_stride * lora_index +
|
do_shrink_kernel(
|
||||||
rbn[None, :] * lora_d1_stride +
|
pid_n,
|
||||||
offset_k[:, None] * lora_d2_stride)
|
pid_sk,
|
||||||
|
slice_id,
|
||||||
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
|
lora_index,
|
||||||
for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
|
input_ptr,
|
||||||
if EVEN_K:
|
lora_ptr,
|
||||||
tiled_a = tl.load(a_ptr)
|
out_ptr,
|
||||||
tiled_b = tl.load(b_ptr)
|
N,
|
||||||
else:
|
K,
|
||||||
k_remaining = K - k * (BLOCK_K * SPLIT_K)
|
cta_m_len,
|
||||||
tiled_a = tl.load(a_ptr,
|
ram,
|
||||||
mask=offset_k[None, :] < k_remaining,
|
# input strides
|
||||||
other=0.0)
|
input_d0_stride,
|
||||||
tiled_b = tl.load(b_ptr,
|
input_d1_stride,
|
||||||
mask=offset_k[:, None] < k_remaining,
|
# lora strides
|
||||||
other=0.0)
|
lora_d0_stride,
|
||||||
accumulator += tl.dot(tiled_a, tiled_b)
|
lora_d1_stride,
|
||||||
|
lora_d2_stride,
|
||||||
a_ptr += BLOCK_K * SPLIT_K * input_d1_stride
|
# output strides
|
||||||
b_ptr += BLOCK_K * SPLIT_K * lora_d2_stride
|
output_d0_stride,
|
||||||
offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
|
output_d1_stride,
|
||||||
|
output_d2_stride,
|
||||||
offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
|
scaling,
|
||||||
cur_out_ptr = (out_ptr if SLICE_NUM == 1 else out_ptr +
|
BLOCK_M,
|
||||||
slice_id * output_d0_stride)
|
BLOCK_N,
|
||||||
c_ptr = cur_out_ptr + offset_cm[:, None] * output_d1_stride + offset_cn[
|
BLOCK_K,
|
||||||
None, :] * output_d2_stride
|
EVEN_K,
|
||||||
c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (offset_cn[None, :]
|
SPLIT_K,
|
||||||
< N)
|
SLICE_NUM)
|
||||||
accumulator *= scaling
|
|
||||||
# handles write-back with reduction-splitting
|
|
||||||
if SPLIT_K == 1:
|
|
||||||
tl.store(c_ptr, accumulator, mask=c_mask)
|
|
||||||
else:
|
|
||||||
tl.atomic_add(c_ptr, accumulator, mask=c_mask)
|
|
||||||
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user