diff --git a/vllm/lora/ops/triton_ops/kernel_utils.py b/vllm/lora/ops/triton_ops/kernel_utils.py
new file mode 100644
index 0000000000000..3572d3018622a
--- /dev/null
+++ b/vllm/lora/ops/triton_ops/kernel_utils.py
@@ -0,0 +1,243 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+Utilities for Punica kernel construction.
+"""
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def mm_k(a_ptr, b_ptr, ak_stride, bk_stride, offset_k, K: tl.constexpr,
+         BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+         EVEN_K: tl.constexpr, SPLIT_K: tl.constexpr, CAST_TYPE: tl.constexpr,
+         b_dtype: tl.constexpr):
+    """
+    Given a_ptr and b_ptr that identify the rows of A (m x k) and columns of
+    B (k x n), iterate through the K dimension to compute the partial/complete
+    matrix block product.
+    If SPLIT_K == 1, the output m x n product is complete.
+    If SPLIT_K > 1, the thread block computes partial outputs. The partial
+    outputs are then atomically summed in the caller code.
+    Args:
+        a_ptr: Array of pointers identifying rows of A
+        b_ptr: Array of pointers identifying columns of B
+        ak_stride: K dimension stride of the A matrix
+        bk_stride: K dimension stride of the B matrix
+        K: Length of the K dimension
+        BLOCK_M: M dimension of the output block m x n
+        BLOCK_N: N dimension of the output block m x n
+        BLOCK_K: K dimension atom
+        EVEN_K: True if the blocks of A and B can be loaded without any
+            masking.
+        SPLIT_K: Parameter signifying parallelism in the K dimension.
+        CAST_TYPE: If True, cast the values from the A matrix to the B
+            matrix dtype.
+        b_dtype: Datatype of the B matrix
+    """
+    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+    for k in range(tl.cdiv(K, BLOCK_K * SPLIT_K)):
+        if EVEN_K:
+            tiled_a = tl.load(a_ptr)
+            tiled_b = tl.load(b_ptr)
+        else:
+            tiled_a = tl.load(a_ptr,
+                              mask=offset_k[None, :]
+                              < K - k * (BLOCK_K * SPLIT_K),
+                              other=0)
+            tiled_b = tl.load(b_ptr,
+                              mask=offset_k[:, None]
+                              < K - k * (BLOCK_K * SPLIT_K),
+                              other=0)
+        if CAST_TYPE:
+            tiled_a = tiled_a.to(b_dtype)
+        accumulator += tl.dot(
+            tiled_a,
+            tiled_b,
+        )
+        a_ptr += BLOCK_K * SPLIT_K * ak_stride
+        b_ptr += BLOCK_K * SPLIT_K * bk_stride
+    return accumulator
+
+
+@triton.jit
+def do_expand_kernel(
+    pid_n,
+    lora_index,
+    slice_id,
+    input_ptr,
+    lora_ptr,
+    out_ptr,
+    N,
+    K,
+    M_LEN,
+    ram,  # array identifying the rows of input_ptr to operate on
+    slice_start_loc,
+    # input ptr strides
+    input_d0_stride,
+    input_d1_stride,
+    input_d2_stride,
+    # lora ptr strides
+    ls_d0_ptr,
+    ls_d1_ptr,
+    ls_d2_ptr,
+    # out ptr strides
+    output_d0_stride,
+    output_d1_stride,
+    # constants
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+    SAME_STRIDE: tl.constexpr,
+    SLICE_NUM: tl.constexpr,
+    EVEN_K: tl.constexpr,
+    CAST_TYPE: tl.constexpr,
+    ADD_INPUTS: tl.constexpr,
+):
+    """
+    Given an array of integers, ram, that identifies the rows of A to operate
+    on, a lora_index that identifies which LoRA to use from lora_ptr, and a
+    slice_id that identifies the input/output slice, compute the matrix
+    product and store it in the appropriate output location.
+    Since this is an expand kernel, no split-K reduction is performed, as the
+    K dimension is assumed to be small.
+ """ + + # ls_d*_ptr can be either an integer or a pointer + if SAME_STRIDE: + # integer + cur_lora_d0_stride = ls_d0_ptr + cur_lora_d1_stride = ls_d1_ptr + cur_lora_d2_stride = ls_d2_ptr + else: + # pointer + cur_lora_d0_stride = tl.load(ls_d0_ptr + slice_id) + cur_lora_d1_stride = tl.load(ls_d1_ptr + slice_id) + cur_lora_d2_stride = tl.load(ls_d2_ptr + slice_id) + + # Identify the input_ptr and lora_ptr from slice_id. + if SLICE_NUM == 1: + cur_input_ptr = input_ptr + cur_lora_ptr = lora_ptr + else: + cur_input_ptr = input_ptr + slice_id * input_d0_stride + cur_lora_ptr = tl.load(lora_ptr + slice_id).to( + tl.pointer_type(out_ptr.dtype.element_ty)) + + # Identify the column indices of B to process. + offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) + + # Identify A and B block pointers + offset_k = tl.arange(0, BLOCK_K) + a_ptr = (cur_input_ptr + ram[:, None] * input_d1_stride + + offset_k[None, :] * input_d2_stride, ) + b_ptr = (cur_lora_ptr + cur_lora_d0_stride * lora_index + + offset_k[:, None] * cur_lora_d2_stride + + rbn[None, :] * cur_lora_d1_stride) + + # Compute the block matrix product. + SPLIT_K = 1 + accumulator = mm_k(a_ptr, b_ptr, input_d2_stride, cur_lora_d2_stride, + offset_k, K, BLOCK_M, BLOCK_N, BLOCK_K, EVEN_K, SPLIT_K, + CAST_TYPE, cur_lora_ptr.dtype.element_ty) + + tiled_c = accumulator.to(cur_lora_ptr.dtype.element_ty) + if SLICE_NUM == 1: + cur_slice_start = slice_start_loc + else: + cur_slice_start = tl.load(slice_start_loc + slice_id) + + # Identify the C output pointers to store the results of the accumulator. + offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + cur_slice_start + offset_cm = tl.arange(0, BLOCK_M) + c_ptr = (out_ptr + ram[:, None] * output_d0_stride + + offset_cn[None, :] * output_d1_stride) + c_mask = (offset_cm[:, None] < M_LEN) & (offset_cn[None, :] + < (cur_slice_start + N)) + + if ADD_INPUTS: + tiled_out = tl.load(c_ptr, mask=c_mask) + tiled_c += tiled_out + tl.store(c_ptr, tiled_c, mask=c_mask) + + +@triton.jit +def do_shrink_kernel( + pid_n, + pid_sk, + slice_id, + lora_index, + input_ptr, + lora_ptr, + out_ptr, + N, + K, + M_LEN, + ram, + # input strides + input_d0_stride, + input_d1_stride, + # lora strides + lora_d0_stride, + lora_d1_stride, + lora_d2_stride, + # output strides + output_d0_stride, + output_d1_stride, + output_d2_stride, + scaling, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + EVEN_K: tl.constexpr, + SPLIT_K: tl.constexpr, + SLICE_NUM: tl.constexpr, +): + """ + Given an array of integers that identifies the rows of A, ram, + a lora index that identifies which LoRA to use from lora_ptr, lora_index, + a slice_id that identifies the input/output slice, compute the + matrix product and store in the appropriate output location. + """ + + # Identify the lora_ptr from slice_id. + if SLICE_NUM == 1: + # current lora ptr + cur_lora_ptr = lora_ptr + else: + # current lora ptr + cur_lora_ptr = tl.load(lora_ptr + slice_id).to( + tl.pointer_type(input_ptr.dtype.element_ty)) + + # Identify the column indices of B to process. 
+    offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
+    rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)
+
+    # Identify A and B block pointers
+    offset_k = pid_sk * BLOCK_K + tl.arange(0, BLOCK_K)
+    a_ptr = (input_ptr + ram[:, None] * input_d0_stride +
+             offset_k[None, :] * input_d1_stride)
+    b_ptr = (cur_lora_ptr + lora_d0_stride * lora_index +
+             rbn[None, :] * lora_d1_stride +
+             offset_k[:, None] * lora_d2_stride)
+
+    # Compute partial/complete block matrix product.
+    accumulator = mm_k(a_ptr, b_ptr, input_d1_stride, lora_d2_stride, offset_k,
+                       K, BLOCK_M, BLOCK_N, BLOCK_K, EVEN_K, SPLIT_K, False,
+                       cur_lora_ptr.dtype.element_ty)
+
+    # Identify the C output pointers to store the results of the accumulator.
+    offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
+    offset_cm = tl.arange(0, BLOCK_M)
+    cur_out_ptr = (out_ptr if SLICE_NUM == 1 else out_ptr +
+                   slice_id * output_d0_stride)
+    c_ptr = cur_out_ptr + ram[:, None] * output_d1_stride + offset_cn[
+        None, :] * output_d2_stride
+    c_mask = (offset_cm[:, None] < M_LEN) & (offset_cn[None, :] < N)
+
+    accumulator *= scaling
+    # handles write-back with reduction-splitting
+    if SPLIT_K == 1:
+        tl.store(c_ptr, accumulator, mask=c_mask)
+    else:
+        tl.atomic_add(c_ptr, accumulator, mask=c_mask)
diff --git a/vllm/lora/ops/triton_ops/sgmv_expand.py b/vllm/lora/ops/triton_ops/sgmv_expand.py
index a8e71cacfe5a2..6aa3eafaba4c0 100644
--- a/vllm/lora/ops/triton_ops/sgmv_expand.py
+++ b/vllm/lora/ops/triton_ops/sgmv_expand.py
@@ -14,6 +14,7 @@
 import triton.language as tl
 
 from vllm.utils import direct_register_custom_op
 
+from .kernel_utils import do_expand_kernel
 from .utils import _get_lora_b_ptr
 
@@ -63,86 +64,56 @@ def _sgmv_expand_kernel(
     curr_N = N if SAME_STRIDE else tl.load(output_hs_ptr + slice_id)
     pid_m = pid // cta_n_num
     pid_n = pid % cta_n_num
+
     M = tl.load(seq_lens + cur_batch)
-    if pid_m * BLOCK_M > M:
+    if pid_m * BLOCK_M >= M:
         return
-    if pid_n * BLOCK_N > curr_N:
+    if pid_n * BLOCK_N >= curr_N:
         return
     lora_index = tl.load(lora_indices + cur_batch)
     if lora_index == -1:
         return
-    cur_seq_start = tl.load(b_seq_start_loc + cur_batch)
-    offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
-    offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
-    offset_k = tl.arange(0, BLOCK_K)
-    ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M)
-    rbn = tl.max_contiguous(tl.multiple_of(offset_n % curr_N, BLOCK_N),
-                            BLOCK_N)
-    # ls_d*_ptr can be either an integer or a pointer
-    if SAME_STRIDE:
-        # integer
-        cur_lora_d0_stride = ls_d0_ptr
-        cur_lora_d1_stride = ls_d1_ptr
-        cur_lora_d2_stride = ls_d2_ptr
-    else:
-        # pointer
-        cur_lora_d0_stride = tl.load(ls_d0_ptr + slice_id)
-        cur_lora_d1_stride = tl.load(ls_d1_ptr + slice_id)
-        cur_lora_d2_stride = tl.load(ls_d2_ptr + slice_id)
-    if SLICE_NUM == 1:
-        cur_input_ptr = input_ptr
-        cur_lora_ptr = lora_ptr
+    m_offset = tl.load(b_seq_start_loc + cur_batch)
 
-    else:
-        cur_input_ptr = input_ptr + slice_id * input_d0_stride
-        cur_lora_ptr = tl.load(lora_ptr + slice_id).to(
-            tl.pointer_type(out_ptr.dtype.element_ty))
-
-    a_ptr = (cur_input_ptr + cur_seq_start * input_d1_stride +
-             ram[:, None] * input_d1_stride +
-             offset_k[None, :] * input_d2_stride, )
-    b_ptr = (cur_lora_ptr + cur_lora_d0_stride * lora_index +
-             offset_k[:, None] * cur_lora_d2_stride +
-             rbn[None, :] * cur_lora_d1_stride)
-    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
-    for k in range(tl.cdiv(K, BLOCK_K)):
-        if EVEN_K:
-            tiled_a = tl.load(a_ptr)
-            tiled_b = tl.load(b_ptr)
-        else:
-            tiled_a = tl.load(a_ptr,
-                              mask=offset_k[None, :] < K - k * BLOCK_K,
-                              other=0)
-            tiled_b = tl.load(b_ptr,
-                              mask=offset_k[:, None] < K - k * BLOCK_K,
-                              other=0)
-        if CAST_TYPE:
-            tiled_a = tiled_a.to(cur_lora_ptr.dtype.element_ty)
-        accumulator += tl.dot(
-            tiled_a,
-            tiled_b,
-        )
-        a_ptr += BLOCK_K * input_d2_stride
-        b_ptr += BLOCK_K * cur_lora_d2_stride
-
-    tiled_c = accumulator.to(cur_lora_ptr.dtype.element_ty)
-    if SLICE_NUM == 1:
-        cur_slice_start = slice_start_loc
-    else:
-        cur_slice_start = tl.load(slice_start_loc + slice_id)
-
-    offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
-    offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + cur_slice_start
-    c_ptr = (out_ptr + offset_cm[:, None] * output_d0_stride +
-             offset_cn[None, :] * output_d1_stride)
-    M = tl.load(seq_lens + cur_batch)
-    c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (
-        offset_cn[None, :] < (cur_slice_start + curr_N))
-    if ADD_INPUTS:
-        tiled_out = tl.load(c_ptr, mask=c_mask)
-        tiled_c += tiled_out
-    tl.store(c_ptr, tiled_c, mask=c_mask)
+    cta_m_len = min(BLOCK_M, M - (pid_m * BLOCK_M))
+    cta_m_offset = m_offset + (pid_m * BLOCK_M)
+    offset_m = tl.arange(0, BLOCK_M)
+    ram = cta_m_offset + tl.max_contiguous(
+        tl.multiple_of(offset_m % cta_m_len, BLOCK_M), BLOCK_M)
+
+    do_expand_kernel(
+        pid_n,
+        lora_index,
+        slice_id,
+        input_ptr,
+        lora_ptr,
+        out_ptr,
+        curr_N,
+        K,
+        cta_m_len,
+        ram,  # array identifying the rows of input_ptr to operate on
+        slice_start_loc,
+        # input ptr strides
+        input_d0_stride,
+        input_d1_stride,
+        input_d2_stride,
+        # lora ptr strides
+        ls_d0_ptr,
+        ls_d1_ptr,
+        ls_d2_ptr,
+        # out ptr strides
+        output_d0_stride,
+        output_d1_stride,
+        # constants
+        BLOCK_M,
+        BLOCK_N,
+        BLOCK_K,
+        SAME_STRIDE,
+        SLICE_NUM,
+        EVEN_K,
+        CAST_TYPE,
+        ADD_INPUTS,
+    )
 
 
 @torch.inference_mode()
diff --git a/vllm/lora/ops/triton_ops/sgmv_shrink.py b/vllm/lora/ops/triton_ops/sgmv_shrink.py
index 8b26583c11c14..b8ed0b020f9ac 100644
--- a/vllm/lora/ops/triton_ops/sgmv_shrink.py
+++ b/vllm/lora/ops/triton_ops/sgmv_shrink.py
@@ -14,6 +14,7 @@
 import triton.language as tl
 
 from vllm.utils import direct_register_custom_op
 
+from .kernel_utils import do_shrink_kernel
 from .utils import _get_lora_a_ptr
 
@@ -62,67 +63,50 @@ def _sgmv_shrink_kernel(
     pid_sk = pid_mix % SPLIT_K
     M = tl.load(seq_lens + cur_batch)
-    if pid_m * BLOCK_M > M:
+    if pid_m * BLOCK_M >= M:
         return
     lora_index = tl.load(lora_indices + cur_batch)
     if lora_index == -1:
         return
-    cur_seq_start = tl.load(b_seq_start_loc + cur_batch)
-    offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
-    offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
-    offset_k = pid_sk * BLOCK_K + tl.arange(0, BLOCK_K)
-    ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M)
-    rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)
-    # input ptr
-    a_ptr = (input_ptr + cur_seq_start * input_d0_stride +
-             ram[:, None] * input_d0_stride +
-             offset_k[None, :] * input_d1_stride)
+    m_offset = tl.load(b_seq_start_loc + cur_batch)
 
-    if SLICE_NUM == 1:
-        # current lora ptr
-        cur_lora_ptr = lora_ptr
-    else:
-        # current lora ptr
-        cur_lora_ptr = tl.load(lora_ptr + slice_id).to(
-            tl.pointer_type(input_ptr.dtype.element_ty))
+    cta_m_len = min(BLOCK_M, M - (pid_m * BLOCK_M))
+    cta_m_offset = m_offset + (pid_m * BLOCK_M)
+    offset_m = tl.arange(0, BLOCK_M)
+    ram = cta_m_offset + tl.max_contiguous(
+        tl.multiple_of(offset_m % cta_m_len, BLOCK_M), BLOCK_M)
 
-    b_ptr = (cur_lora_ptr + lora_d0_stride * lora_index +
-             rbn[None, :] * lora_d1_stride +
-             offset_k[:, None] * lora_d2_stride)
-
-    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
-    for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
-        if EVEN_K:
-            tiled_a = tl.load(a_ptr)
-            tiled_b = tl.load(b_ptr)
-        else:
-            k_remaining = K - k * (BLOCK_K * SPLIT_K)
-            tiled_a = tl.load(a_ptr,
-                              mask=offset_k[None, :] < k_remaining,
-                              other=0.0)
-            tiled_b = tl.load(b_ptr,
-                              mask=offset_k[:, None] < k_remaining,
-                              other=0.0)
-        accumulator += tl.dot(tiled_a, tiled_b)
-
-        a_ptr += BLOCK_K * SPLIT_K * input_d1_stride
-        b_ptr += BLOCK_K * SPLIT_K * lora_d2_stride
-    offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
-
-    offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
-    cur_out_ptr = (out_ptr if SLICE_NUM == 1 else out_ptr +
-                   slice_id * output_d0_stride)
-    c_ptr = cur_out_ptr + offset_cm[:, None] * output_d1_stride + offset_cn[
-        None, :] * output_d2_stride
-    c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (offset_cn[None, :]
-                                                           < N)
-    accumulator *= scaling
-    # handles write-back with reduction-splitting
-    if SPLIT_K == 1:
-        tl.store(c_ptr, accumulator, mask=c_mask)
-    else:
-        tl.atomic_add(c_ptr, accumulator, mask=c_mask)
+    do_shrink_kernel(
+        pid_n,
+        pid_sk,
+        slice_id,
+        lora_index,
+        input_ptr,
+        lora_ptr,
+        out_ptr,
+        N,
+        K,
+        cta_m_len,
+        ram,
+        # input strides
+        input_d0_stride,
+        input_d1_stride,
+        # lora strides
+        lora_d0_stride,
+        lora_d1_stride,
+        lora_d2_stride,
+        # output strides
+        output_d0_stride,
+        output_d1_stride,
+        output_d2_stride,
+        scaling,
+        BLOCK_M,
+        BLOCK_N,
+        BLOCK_K,
+        EVEN_K,
+        SPLIT_K,
+        SLICE_NUM)
 
 
 @torch.inference_mode()
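For readers who want to sanity-check the split-K scheme that mm_k relies on, the sketch below reproduces the idea in plain PyTorch: each SPLIT_K program instance accumulates a partial product over its own chunk of the K dimension, and the partial results are summed (the shrink path does this final reduction with tl.atomic_add when SPLIT_K > 1). This is a minimal illustration of the reduction order only, not the kernel itself; the function name, shapes, and tolerances are made up for the example.

import torch


def split_k_matmul(a: torch.Tensor, b: torch.Tensor, split_k: int):
    # a: (M, K), b: (K, N). Each loop iteration plays the role of one pid_sk
    # program instance: it multiplies only its own K-chunk, and the partial
    # products are summed into the shared output.
    K = a.shape[1]
    chunk = (K + split_k - 1) // split_k
    out = torch.zeros(a.shape[0], b.shape[1], dtype=torch.float32)
    for start in range(0, K, chunk):
        end = min(start + chunk, K)
        out += a[:, start:end].float() @ b[start:end, :].float()
    return out


a, b = torch.randn(8, 64), torch.randn(64, 16)
torch.testing.assert_close(split_k_matmul(a, b, split_k=4),
                           a.float() @ b.float(),
                           rtol=1e-4,
                           atol=1e-4)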
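A rough CPU-only reference for what the shrink and expand paths compute for a single slice may also help when reading do_shrink_kernel and do_expand_kernel above. The tensor layouts assumed here (lora_a as (num_loras, rank, hidden), lora_b as (num_loras, out_dim, rank), one lora index per sequence, and a flat slice_start column offset) are assumptions chosen to mirror the strides used in the kernels, not the exact host-side signatures of the custom ops.

import torch


def sgmv_shrink_ref(x, lora_a, seq_starts, seq_lens, lora_indices, scaling):
    # x: (num_tokens, hidden), lora_a: (num_loras, rank, hidden).
    # Rows of sequences whose lora index is -1 stay zero, mirroring the
    # early return in the kernel.
    out = torch.zeros(x.shape[0], lora_a.shape[1], dtype=torch.float32)
    for start, length, idx in zip(seq_starts, seq_lens, lora_indices):
        if idx == -1:
            continue
        rows = slice(start, start + length)
        out[rows] = scaling * (x[rows].float() @ lora_a[idx].float().T)
    return out


def sgmv_expand_ref(y, lora_b, seq_starts, seq_lens, lora_indices,
                    slice_start, out, add_inputs):
    # y: (num_tokens, rank), lora_b: (num_loras, out_dim, rank).
    # Writes (or accumulates, when add_inputs is True, like ADD_INPUTS)
    # into out[:, slice_start:slice_start + out_dim].
    out_dim = lora_b.shape[1]
    cols = slice(slice_start, slice_start + out_dim)
    for start, length, idx in zip(seq_starts, seq_lens, lora_indices):
        if idx == -1:
            continue
        rows = slice(start, start + length)
        tile = (y[rows].float() @ lora_b[idx].float().T).to(out.dtype)
        out[rows, cols] = out[rows, cols] + tile if add_inputs else tile
    return out

With SLICE_NUM > 1, the kernels repeat the same computation per slice, reading a per-slice weight pointer and slice start offset; that is what the lora_ptr/slice_start_loc indirection in do_expand_kernel implements.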