[Kernel] LoRA - Refactor sgmv kernels (#13110)

Varun Sundar Rabindranath authored on 2025-02-20 17:58:06 +05:30, committed by GitHub
parent a64a84433d
commit b69692a2d8
3 changed files with 324 additions and 126 deletions


@@ -0,0 +1,243 @@
# SPDX-License-Identifier: Apache-2.0
"""
Utilities for Punica kernel construction.
"""
import triton
import triton.language as tl
@triton.jit
def mm_k(a_ptr, b_ptr, ak_stride, bk_stride, offset_k, K: tl.constexpr,
         BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
         EVEN_K: tl.constexpr, SPLIT_K: tl.constexpr, CAST_TYPE: tl.constexpr,
         b_dtype: tl.constexpr):
"""
Given a_ptr and b_ptr, that identify the rows of A (m x k) and columns of
B (k x n), iterate, through the K dimension to compute the partial/complete
matrix block product.
If SPLIT_K == 1, the output m x n product is complete.
If SPLIT_K > 1, the thread block computes partial outputs. The partial
outputs are then atomically summed in the caller code.
Args:
a_ptr: Array of pointers, identifying rows of A
b_ptr: Array of pointers, identifying columns of B
ak_stride: K dimension stride of the A matrix
bk_stride: K dimension stride of the B matrix
K: Length of the K dimension
BLOCK_M: M dimension of the output block m x n
BLOCK_N: N dimension of the output block m x n
BLOCK_K: K dimension atom
EVEN_K: True if the blocks of A and B can be loaded without any
masking.
SPLIT_K: Parameter signifying parallelism in the K dimension.
CAST_TYPE: if True, cast the values from the A matrix to the B
matrix dtype.
b_dtype: datatype of the B matrix
"""
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(tl.cdiv(K, BLOCK_K * SPLIT_K)):
        if EVEN_K:
            tiled_a = tl.load(a_ptr)
            tiled_b = tl.load(b_ptr)
        else:
            tiled_a = tl.load(a_ptr,
                              mask=offset_k[None, :]
                              < K - k * (BLOCK_K * SPLIT_K),
                              other=0)
            tiled_b = tl.load(b_ptr,
                              mask=offset_k[:, None]
                              < K - k * (BLOCK_K * SPLIT_K),
                              other=0)
        if CAST_TYPE:
            tiled_a = tiled_a.to(b_dtype)
        accumulator += tl.dot(
            tiled_a,
            tiled_b,
        )
        a_ptr += BLOCK_K * SPLIT_K * ak_stride
        b_ptr += BLOCK_K * SPLIT_K * bk_stride
    return accumulator


@triton.jit
def do_expand_kernel(
    pid_n,
    lora_index,
    slice_id,
    input_ptr,
    lora_ptr,
    out_ptr,
    N,
    K,
    M_LEN,
    ram,  # array identifying the rows of Input ptr to operate on
    slice_start_loc,
    # input ptr strides
    input_d0_stride,
    input_d1_stride,
    input_d2_stride,
    # lora ptr strides
    ls_d0_ptr,
    ls_d1_ptr,
    ls_d2_ptr,
    # out ptr strides
    output_d0_stride,
    output_d1_stride,
    # constants
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    SAME_STRIDE: tl.constexpr,
    SLICE_NUM: tl.constexpr,
    EVEN_K: tl.constexpr,
    CAST_TYPE: tl.constexpr,
    ADD_INPUTS: tl.constexpr,
):
"""
Given an array of integers that identifies the rows of A, ram,
a lora index that identifies which LoRA to use from lora_ptr, lora_index,
a slice_id that identifies the input/output slice,
compute the matrix product and store in the appropriate output location.
Given that this is an expand kernel, we don't perform any split-K reduction
as the K dimension is assumed to be small.
"""
    # ls_d*_ptr can be either an integer or a pointer
    if SAME_STRIDE:
        # integer
        cur_lora_d0_stride = ls_d0_ptr
        cur_lora_d1_stride = ls_d1_ptr
        cur_lora_d2_stride = ls_d2_ptr
    else:
        # pointer
        cur_lora_d0_stride = tl.load(ls_d0_ptr + slice_id)
        cur_lora_d1_stride = tl.load(ls_d1_ptr + slice_id)
        cur_lora_d2_stride = tl.load(ls_d2_ptr + slice_id)

    # Identify the input_ptr and lora_ptr from slice_id.
    if SLICE_NUM == 1:
        cur_input_ptr = input_ptr
        cur_lora_ptr = lora_ptr
    else:
        cur_input_ptr = input_ptr + slice_id * input_d0_stride
        cur_lora_ptr = tl.load(lora_ptr + slice_id).to(
            tl.pointer_type(out_ptr.dtype.element_ty))

    # Identify the column indices of B to process.
    offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
    rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)

    # Identify A and B block pointers
    offset_k = tl.arange(0, BLOCK_K)
    a_ptr = (cur_input_ptr + ram[:, None] * input_d1_stride +
             offset_k[None, :] * input_d2_stride)
    b_ptr = (cur_lora_ptr + cur_lora_d0_stride * lora_index +
             offset_k[:, None] * cur_lora_d2_stride +
             rbn[None, :] * cur_lora_d1_stride)

    # Compute the block matrix product.
    SPLIT_K = 1
    accumulator = mm_k(a_ptr, b_ptr, input_d2_stride, cur_lora_d2_stride,
                       offset_k, K, BLOCK_M, BLOCK_N, BLOCK_K, EVEN_K,
                       SPLIT_K, CAST_TYPE, cur_lora_ptr.dtype.element_ty)

    tiled_c = accumulator.to(cur_lora_ptr.dtype.element_ty)
    if SLICE_NUM == 1:
        cur_slice_start = slice_start_loc
    else:
        cur_slice_start = tl.load(slice_start_loc + slice_id)

    # Identify the C output pointers to store the results of the accumulator.
    offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + cur_slice_start
    offset_cm = tl.arange(0, BLOCK_M)
    c_ptr = (out_ptr + ram[:, None] * output_d0_stride +
             offset_cn[None, :] * output_d1_stride)
    c_mask = (offset_cm[:, None] < M_LEN) & (offset_cn[None, :]
                                             < (cur_slice_start + N))

    if ADD_INPUTS:
        tiled_out = tl.load(c_ptr, mask=c_mask)
        tiled_c += tiled_out
    tl.store(c_ptr, tiled_c, mask=c_mask)


@triton.jit
def do_shrink_kernel(
    pid_n,
    pid_sk,
    slice_id,
    lora_index,
    input_ptr,
    lora_ptr,
    out_ptr,
    N,
    K,
    M_LEN,
    ram,
    # input strides
    input_d0_stride,
    input_d1_stride,
    # lora strides
    lora_d0_stride,
    lora_d1_stride,
    lora_d2_stride,
    # output strides
    output_d0_stride,
    output_d1_stride,
    output_d2_stride,
    scaling,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    SPLIT_K: tl.constexpr,
    SLICE_NUM: tl.constexpr,
):
"""
Given an array of integers that identifies the rows of A, ram,
a lora index that identifies which LoRA to use from lora_ptr, lora_index,
a slice_id that identifies the input/output slice, compute the
matrix product and store in the appropriate output location.
"""
    # Identify the lora_ptr from slice_id.
    if SLICE_NUM == 1:
        # current lora ptr
        cur_lora_ptr = lora_ptr
    else:
        # current lora ptr
        cur_lora_ptr = tl.load(lora_ptr + slice_id).to(
            tl.pointer_type(input_ptr.dtype.element_ty))

    # Identify the column indices of B to process.
    offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
    rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)

    # Identify A and B block pointers
    offset_k = pid_sk * BLOCK_K + tl.arange(0, BLOCK_K)
    a_ptr = (input_ptr + ram[:, None] * input_d0_stride +
             offset_k[None, :] * input_d1_stride)
    b_ptr = (cur_lora_ptr + lora_d0_stride * lora_index +
             rbn[None, :] * lora_d1_stride +
             offset_k[:, None] * lora_d2_stride)

    # Compute partial/complete block matrix product.
    accumulator = mm_k(a_ptr, b_ptr, input_d1_stride, lora_d2_stride,
                       offset_k, K, BLOCK_M, BLOCK_N, BLOCK_K, EVEN_K,
                       SPLIT_K, False, cur_lora_ptr.dtype.element_ty)

    # Identify the C output pointers to store the results of the accumulator.
    offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
    offset_cm = tl.arange(0, BLOCK_M)
    cur_out_ptr = (out_ptr if SLICE_NUM == 1 else out_ptr +
                   slice_id * output_d0_stride)
    c_ptr = cur_out_ptr + ram[:, None] * output_d1_stride + offset_cn[
        None, :] * output_d2_stride
    c_mask = (offset_cm[:, None] < M_LEN) & (offset_cn[None, :] < N)
    accumulator *= scaling
    # handles write-back with reduction-splitting
    if SPLIT_K == 1:
        tl.store(c_ptr, accumulator, mask=c_mask)
    else:
        tl.atomic_add(c_ptr, accumulator, mask=c_mask)
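
The split-K contract that mm_k and the shrink write-back rely on can be checked with a small, self-contained NumPy sketch (illustrative only, not part of this commit; sizes are made up): each of the SPLIT_K programs accumulates a disjoint, strided subset of the K blocks, and a plain sum stands in for the kernel's atomic adds.

import numpy as np

M, N, K = 16, 32, 128          # made-up problem sizes
BLOCK_K, SPLIT_K = 16, 4       # K is walked in BLOCK_K * SPLIT_K strides

rng = np.random.default_rng(0)
A = rng.standard_normal((M, K)).astype(np.float32)
B = rng.standard_normal((K, N)).astype(np.float32)

partials = np.zeros((SPLIT_K, M, N), dtype=np.float32)
for pid_sk in range(SPLIT_K):
    # Program pid_sk starts at block pid_sk * BLOCK_K and advances by
    # BLOCK_K * SPLIT_K -- the same traversal as mm_k's loop.
    for k0 in range(pid_sk * BLOCK_K, K, BLOCK_K * SPLIT_K):
        partials[pid_sk] += A[:, k0:k0 + BLOCK_K] @ B[k0:k0 + BLOCK_K, :]

# The caller's atomic adds reduce the partials into the final product.
assert np.allclose(partials.sum(axis=0), A @ B, atol=1e-3)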


@@ -14,6 +14,7 @@ import triton.language as tl
 from vllm.utils import direct_register_custom_op
+from .kernel_utils import do_expand_kernel
 from .utils import _get_lora_b_ptr
@@ -63,86 +64,56 @@ def _sgmv_expand_kernel(
     curr_N = N if SAME_STRIDE else tl.load(output_hs_ptr + slice_id)
     pid_m = pid // cta_n_num
     pid_n = pid % cta_n_num
     M = tl.load(seq_lens + cur_batch)
-    if pid_m * BLOCK_M > M:
+    if pid_m * BLOCK_M >= M:
         return
-    if pid_n * BLOCK_N > curr_N:
+    if pid_n * BLOCK_N >= curr_N:
         return
     lora_index = tl.load(lora_indices + cur_batch)
     if lora_index == -1:
         return
-    cur_seq_start = tl.load(b_seq_start_loc + cur_batch)
-    offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
-    offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
-    offset_k = tl.arange(0, BLOCK_K)
-    ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M)
-    rbn = tl.max_contiguous(tl.multiple_of(offset_n % curr_N, BLOCK_N),
-                            BLOCK_N)
-    # ls_d*_ptr can be either an integer or a pointer
-    if SAME_STRIDE:
-        # integer
-        cur_lora_d0_stride = ls_d0_ptr
-        cur_lora_d1_stride = ls_d1_ptr
-        cur_lora_d2_stride = ls_d2_ptr
-    else:
-        # pointer
-        cur_lora_d0_stride = tl.load(ls_d0_ptr + slice_id)
-        cur_lora_d1_stride = tl.load(ls_d1_ptr + slice_id)
-        cur_lora_d2_stride = tl.load(ls_d2_ptr + slice_id)
-    if SLICE_NUM == 1:
-        cur_input_ptr = input_ptr
-        cur_lora_ptr = lora_ptr
-    else:
-        cur_input_ptr = input_ptr + slice_id * input_d0_stride
-        cur_lora_ptr = tl.load(lora_ptr + slice_id).to(
-            tl.pointer_type(out_ptr.dtype.element_ty))
-    a_ptr = (cur_input_ptr + cur_seq_start * input_d1_stride +
-             ram[:, None] * input_d1_stride +
-             offset_k[None, :] * input_d2_stride, )
-    b_ptr = (cur_lora_ptr + cur_lora_d0_stride * lora_index +
-             offset_k[:, None] * cur_lora_d2_stride +
-             rbn[None, :] * cur_lora_d1_stride)
-    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
-    for k in range(tl.cdiv(K, BLOCK_K)):
-        if EVEN_K:
-            tiled_a = tl.load(a_ptr)
-            tiled_b = tl.load(b_ptr)
-        else:
-            tiled_a = tl.load(a_ptr,
-                              mask=offset_k[None, :] < K - k * BLOCK_K,
-                              other=0)
-            tiled_b = tl.load(b_ptr,
-                              mask=offset_k[:, None] < K - k * BLOCK_K,
-                              other=0)
-        if CAST_TYPE:
-            tiled_a = tiled_a.to(cur_lora_ptr.dtype.element_ty)
-        accumulator += tl.dot(
-            tiled_a,
-            tiled_b,
-        )
-        a_ptr += BLOCK_K * input_d2_stride
-        b_ptr += BLOCK_K * cur_lora_d2_stride
-    tiled_c = accumulator.to(cur_lora_ptr.dtype.element_ty)
-    if SLICE_NUM == 1:
-        cur_slice_start = slice_start_loc
-    else:
-        cur_slice_start = tl.load(slice_start_loc + slice_id)
-    offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
-    offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + cur_slice_start
-    c_ptr = (out_ptr + offset_cm[:, None] * output_d0_stride +
-             offset_cn[None, :] * output_d1_stride)
-    M = tl.load(seq_lens + cur_batch)
-    c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (
-        offset_cn[None, :] < (cur_slice_start + curr_N))
-    if ADD_INPUTS:
-        tiled_out = tl.load(c_ptr, mask=c_mask)
-        tiled_c += tiled_out
-    tl.store(c_ptr, tiled_c, mask=c_mask)
+    m_offset = tl.load(b_seq_start_loc + cur_batch)
+
+    cta_m_len = min(BLOCK_M, M - (pid_m * BLOCK_M))
+    cta_m_offset = m_offset + (pid_m * BLOCK_M)
+    offset_m = tl.arange(0, BLOCK_M)
+    ram = cta_m_offset + tl.max_contiguous(
+        tl.multiple_of(offset_m % cta_m_len, BLOCK_M), BLOCK_M)
+    do_expand_kernel(
+        pid_n,
+        lora_index,
+        slice_id,
+        input_ptr,
+        lora_ptr,
+        out_ptr,
+        curr_N,
+        K,
+        cta_m_len,
+        ram,  # array identifying the rows of Input ptr to operate on
+        slice_start_loc,
+        # input ptr strides
+        input_d0_stride,
+        input_d1_stride,
+        input_d2_stride,
+        # lora ptr strides
+        ls_d0_ptr,
+        ls_d1_ptr,
+        ls_d2_ptr,
+        # out ptr strides
+        output_d0_stride,
+        output_d1_stride,
+        # constants
+        BLOCK_M,
+        BLOCK_N,
+        BLOCK_K,
+        SAME_STRIDE,
+        SLICE_NUM,
+        EVEN_K,
+        CAST_TYPE,
+        ADD_INPUTS,
+    )
@torch.inference_mode()
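
The refactored prologue above is mostly index arithmetic: cta_m_len is the number of valid rows in this CTA's tile, and ram maps every lane onto a valid absolute input row. A minimal NumPy sketch with made-up sizes (tl.max_contiguous and tl.multiple_of are compiler hints that leave the values unchanged):

import numpy as np

M, BLOCK_M = 10, 8        # sequence length, rows per CTA (made up)
m_offset, pid_m = 32, 1   # sequence start row, this CTA's tile index

cta_m_len = min(BLOCK_M, M - pid_m * BLOCK_M)   # 2 valid rows in this tile
cta_m_offset = m_offset + pid_m * BLOCK_M       # first row: 40
offset_m = np.arange(BLOCK_M)

# The modulo wraps out-of-range lanes back onto valid rows; the kernel
# masks their stores out via cta_m_len (M_LEN), so the wrap is harmless.
ram = cta_m_offset + offset_m % cta_m_len
print(ram)  # [40 41 40 41 40 41 40 41]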


@@ -14,6 +14,7 @@ import triton.language as tl
 from vllm.utils import direct_register_custom_op
+from .kernel_utils import do_shrink_kernel
 from .utils import _get_lora_a_ptr
@@ -62,67 +63,50 @@ def _sgmv_shrink_kernel(
     pid_sk = pid_mix % SPLIT_K
     M = tl.load(seq_lens + cur_batch)
-    if pid_m * BLOCK_M > M:
+    if pid_m * BLOCK_M >= M:
         return
     lora_index = tl.load(lora_indices + cur_batch)
     if lora_index == -1:
         return
-    cur_seq_start = tl.load(b_seq_start_loc + cur_batch)
-    offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
-    offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
-    offset_k = pid_sk * BLOCK_K + tl.arange(0, BLOCK_K)
-    ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M)
-    rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)
-    # input ptr
-    a_ptr = (input_ptr + cur_seq_start * input_d0_stride +
-             ram[:, None] * input_d0_stride +
-             offset_k[None, :] * input_d1_stride)
-    if SLICE_NUM == 1:
-        # current lora ptr
-        cur_lora_ptr = lora_ptr
-    else:
-        # current lora ptr
-        cur_lora_ptr = tl.load(lora_ptr + slice_id).to(
-            tl.pointer_type(input_ptr.dtype.element_ty))
-    b_ptr = (cur_lora_ptr + lora_d0_stride * lora_index +
-             rbn[None, :] * lora_d1_stride +
-             offset_k[:, None] * lora_d2_stride)
-    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
-    for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
-        if EVEN_K:
-            tiled_a = tl.load(a_ptr)
-            tiled_b = tl.load(b_ptr)
-        else:
-            k_remaining = K - k * (BLOCK_K * SPLIT_K)
-            tiled_a = tl.load(a_ptr,
-                              mask=offset_k[None, :] < k_remaining,
-                              other=0.0)
-            tiled_b = tl.load(b_ptr,
-                              mask=offset_k[:, None] < k_remaining,
-                              other=0.0)
-        accumulator += tl.dot(tiled_a, tiled_b)
-        a_ptr += BLOCK_K * SPLIT_K * input_d1_stride
-        b_ptr += BLOCK_K * SPLIT_K * lora_d2_stride
-    offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
-    offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
-    cur_out_ptr = (out_ptr if SLICE_NUM == 1 else out_ptr +
-                   slice_id * output_d0_stride)
-    c_ptr = cur_out_ptr + offset_cm[:, None] * output_d1_stride + offset_cn[
-        None, :] * output_d2_stride
-    c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (offset_cn[None, :]
-                                                           < N)
-    accumulator *= scaling
-    # handles write-back with reduction-splitting
-    if SPLIT_K == 1:
-        tl.store(c_ptr, accumulator, mask=c_mask)
-    else:
-        tl.atomic_add(c_ptr, accumulator, mask=c_mask)
+    m_offset = tl.load(b_seq_start_loc + cur_batch)
+
+    cta_m_len = min(BLOCK_M, M - (pid_m * BLOCK_M))
+    cta_m_offset = m_offset + (pid_m * BLOCK_M)
+    offset_m = tl.arange(0, BLOCK_M)
+    ram = cta_m_offset + tl.max_contiguous(
+        tl.multiple_of(offset_m % cta_m_len, BLOCK_M), BLOCK_M)
+    do_shrink_kernel(
+        pid_n,
+        pid_sk,
+        slice_id,
+        lora_index,
+        input_ptr,
+        lora_ptr,
+        out_ptr,
+        N,
+        K,
+        cta_m_len,
+        ram,
+        # input strides
+        input_d0_stride,
+        input_d1_stride,
+        # lora strides
+        lora_d0_stride,
+        lora_d1_stride,
+        lora_d2_stride,
+        # output strides
+        output_d0_stride,
+        output_d1_stride,
+        output_d2_stride,
+        scaling,
+        BLOCK_M,
+        BLOCK_N,
+        BLOCK_K,
+        EVEN_K,
+        SPLIT_K,
+        SLICE_NUM)
@torch.inference_mode()
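
For orientation, the end-to-end contract of the shrink path can be sketched in NumPy under assumed shapes (single LoRA, illustrative names rather than the vLLM API): each slice projects the hidden dimension K down to the LoRA rank N and scales the result.

import numpy as np

M, K, N, SLICE_NUM = 6, 64, 16, 3   # tokens, hidden dim, rank, slices
scaling = 0.5
rng = np.random.default_rng(1)
x = rng.standard_normal((M, K)).astype(np.float32)
lora_a = rng.standard_normal((SLICE_NUM, N, K)).astype(np.float32)

# Shrink: out[s] = scaling * x @ lora_a[s].T, stacked per slice, which is
# what the (SLICE_NUM, M, N) output strides above address into.
out = np.stack([scaling * (x @ lora_a[s].T) for s in range(SLICE_NUM)])
assert out.shape == (SLICE_NUM, M, N)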