[Kernel] LoRA triton kernels support PDL (#27402)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>

parent a736e5ff77, commit 21b82f4ea2
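Background for the diff below: PDL (programmatic dependent launch, exposed in Triton through the GDC "grid dependency control" intrinsics) lets a dependent kernel begin launching while its predecessor is still finishing, as long as the dependent kernel explicitly waits before reading the predecessor's output. The following is a minimal sketch of that pattern, not code from this commit: the kernel names, the run() wrapper, and the 1024-element block size are invented for illustration, and it assumes a Triton build on SM90+ hardware that exposes tl.extra.cuda.gdc_wait / tl.extra.cuda.gdc_launch_dependents and accepts launch_pdl as launch metadata, as the hunks below do.

import torch
import triton
import triton.language as tl


@triton.jit
def _producer(x_ptr, y_ptr, n, BLOCK: tl.constexpr, USE_GDC: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    y = tl.load(x_ptr + offs, mask=mask, other=0.0) * 2.0
    if USE_GDC:
        # Hint the runtime that the dependent kernel may start launching now;
        # correctness still relies on the consumer's gdc_wait below.
        tl.extra.cuda.gdc_launch_dependents()
    tl.store(y_ptr + offs, y, mask=mask)


@triton.jit
def _consumer(y_ptr, z_ptr, n, BLOCK: tl.constexpr, USE_GDC: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    if USE_GDC:
        # Wait for ALL programs of the prior kernel, so y is fully written.
        tl.extra.cuda.gdc_wait()
    y = tl.load(y_ptr + offs, mask=mask, other=0.0)
    tl.store(z_ptr + offs, y + 1.0, mask=mask)


def run(x: torch.Tensor, use_gdc: bool) -> torch.Tensor:
    n = x.numel()
    y, z = torch.empty_like(x), torch.empty_like(x)
    grid = (triton.cdiv(n, 1024),)
    # launch_pdl is the per-launch metadata that opts the kernel into PDL.
    _producer[grid](x, y, n, BLOCK=1024, USE_GDC=use_gdc, launch_pdl=use_gdc)
    _consumer[grid](y, z, n, BLOCK=1024, USE_GDC=use_gdc, launch_pdl=use_gdc)
    return z

The fused MoE hunks below follow the same split: _fused_moe_lora_shrink launches the kernel with IS_PRIMARY=True, so it calls gdc_launch_dependents before its write-back, while _fused_moe_lora_expand runs with IS_PRIMARY=False and issues gdc_wait after prefetching the LoRA weight tile but before loading the a_ptrs tile produced by the shrink kernel.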
@@ -6,6 +6,8 @@ import torch
 from vllm.triton_utils import tl, triton
 from vllm.utils.torch_utils import direct_register_custom_op

+from .utils import supports_pdl
+
 _LORA_PTR_DICT: dict[tuple[int, ...], torch.tensor] = {}


@@ -82,6 +84,8 @@ def _fused_moe_lora_kernel(
     BLOCK_SIZE_K: tl.constexpr,
     GROUP_SIZE_M: tl.constexpr,
     SPLIT_K: tl.constexpr,
+    USE_GDC: tl.constexpr,
+    IS_PRIMARY: tl.constexpr,
 ):
     pid = tl.program_id(axis=0)
     slice_id = tl.program_id(axis=1)
@@ -110,13 +114,11 @@ def _fused_moe_lora_kernel(
     num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr + lora_id)
     if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
         return
-
     # get the expert_id to process curr shard
     ind = lora_id * stride_el + pid_m
     expert_id = tl.load(expert_ids_ptr + ind, ind < max_loras * stride_el, -1)
     if expert_id == -1:
         return
-
     # get a_ptr,b_ptr,c_ptr
     cur_a_ptr = a_ptr + (slice_id % num_slice_a) * slice_a_size
     cur_b_ptr = tl.load(b_ptr + slice_id).to(tl.pointer_type(c_ptr.dtype.element_ty))
@@ -149,12 +151,17 @@ def _fused_moe_lora_kernel(
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
     for k in range(0, grid_k):
         k_remaining = K - k * (BLOCK_SIZE_K * SPLIT_K)
+        # pre-fetch lora weight
+        b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)
+        # GDC wait waits for ALL programs in the prior kernel to complete
+        # before continuing.
+        if USE_GDC and not IS_PRIMARY:
+            tl.extra.cuda.gdc_wait()
         a = tl.load(
             a_ptrs,
             mask=token_mask[:, None] & (offs_k[None, :] < k_remaining),
             other=0.0,
         )
-        b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)
         accumulator += tl.dot(a, b)
         # Advance the ptrs to the next K block.
         a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak
@@ -163,12 +170,15 @@ def _fused_moe_lora_kernel(
     if MUL_ROUTED_WEIGHT:
         moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
         accumulator = accumulator * moe_weight[:, None]
+    if USE_GDC and IS_PRIMARY:
+        # GDC launch dependents hints the runtime system to launch dependent kernels.
+        tl.extra.cuda.gdc_launch_dependents()
     accumulator = accumulator.to(c_ptr.dtype.element_ty)
     # Write back the block of the output
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
     c_ptrs = cur_c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
     c_mask = token_mask[:, None] & (offs_cn[None, :] < N)

     if SPLIT_K == 1:
         tl.store(c_ptrs, accumulator, mask=c_mask)
     else:
@@ -209,7 +219,7 @@ def _fused_moe_lora_shrink(
     mul_routed_weight: bool = False,
 ) -> None:
     w1_lora_a_stacked = lora_a_stacked[0]
-
+    use_gdc = supports_pdl(qcurr_hidden_states.device)
     shrink_config = {
         "BLOCK_SIZE_M": block_size_m,
         "BLOCK_SIZE_N": block_size_n,
@@ -218,6 +228,8 @@ def _fused_moe_lora_shrink(
         "num_warps": num_warps,
         "num_stages": num_stages,
         "SPLIT_K": split_k,
+        "USE_GDC": use_gdc,
+        "launch_pdl": use_gdc,  # triton kernel metadata
     }

     b_ptr = _get_ptr(lora_a_stacked, device)
@@ -229,7 +241,6 @@ def _fused_moe_lora_shrink(
         len(lora_a_stacked),
         lora_a_stacked[0].shape[0],
     )
-
     _fused_moe_lora_kernel[grid](
         qcurr_hidden_states,
         b_ptr,
@@ -261,6 +272,7 @@ def _fused_moe_lora_shrink(
         num_slice_c=num_slices,
         top_k=1 if mul_routed_weight else top_k_num,
         MUL_ROUTED_WEIGHT=False,
+        IS_PRIMARY=True,
         **shrink_config,
     )

@@ -314,7 +326,7 @@ def _fused_moe_lora_expand(
         dtype=output.dtype,
         device=device,
     )
-
+    use_gdc = supports_pdl(a_intermediate_cache1.device)
     expand_config = {
         "BLOCK_SIZE_M": block_size_m,
         "BLOCK_SIZE_N": block_size_n,
@@ -323,6 +335,8 @@ def _fused_moe_lora_expand(
         "num_warps": num_warps,
         "num_stages": num_stages,
         "SPLIT_K": split_k,  # Set split_k = 1 for expand calls
+        "USE_GDC": use_gdc,
+        "launch_pdl": use_gdc,  # triton kernel metadata
     }

     grid = lambda META: (
@@ -361,6 +375,7 @@ def _fused_moe_lora_expand(
         num_slice_c=num_slices,
         top_k=1,
         MUL_ROUTED_WEIGHT=mul_routed_weight,
+        IS_PRIMARY=False,
         **expand_config,
     )
     for i in range(num_slices):

@@ -22,6 +22,7 @@ def mm_k(
     SPLIT_K: tl.constexpr,
     CAST_TYPE: tl.constexpr,
     b_dtype: tl.constexpr,
+    USE_GDC: tl.constexpr,
 ):
     """
     Given a_ptr and b_ptr, that identify the rows of A (m x k) and columns of
@@ -45,19 +46,25 @@ def mm_k(
         CAST_TYPE: if True, cast the values from the A matrix to the B
             matrix dtype.
         b_dtype: datatype of the B matrix
+        USE_GDC: Whether to use PDL. True indicates use.
     """
     accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
     for k in range(tl.cdiv(K, BLOCK_K * SPLIT_K)):
         if EVEN_K:
-            tiled_a = tl.load(a_ptr)
+            # pre-fetch lora weight
             tiled_b = tl.load(b_ptr)
+            if USE_GDC:
+                tl.extra.cuda.gdc_wait()
+            tiled_a = tl.load(a_ptr)
         else:
-            tiled_a = tl.load(
-                a_ptr, mask=offset_k[None, :] < K - k * (BLOCK_K * SPLIT_K), other=0
-            )
             tiled_b = tl.load(
                 b_ptr, mask=offset_k[:, None] < K - k * (BLOCK_K * SPLIT_K), other=0
             )
+            if USE_GDC:
+                tl.extra.cuda.gdc_wait()
+            tiled_a = tl.load(
+                a_ptr, mask=offset_k[None, :] < K - k * (BLOCK_K * SPLIT_K), other=0
+            )
         if CAST_TYPE:
             tiled_a = tiled_a.to(b_dtype)
         accumulator += tl.dot(
@@ -102,6 +109,7 @@ def do_expand_kernel(
     EVEN_K: tl.constexpr,
     CAST_TYPE: tl.constexpr,
     ADD_INPUTS: tl.constexpr,
+    USE_GDC: tl.constexpr,
 ):
     """
     Given an array of integers that identifies the rows of A, ram,
@@ -154,6 +162,7 @@ def do_expand_kernel(

     # Compute the block matrix product.
     SPLIT_K = 1
+
     accumulator = mm_k(
         a_ptr,
         b_ptr,
@@ -168,6 +177,7 @@ def do_expand_kernel(
         SPLIT_K,
         CAST_TYPE,
         cur_lora_ptr.dtype.element_ty,
+        USE_GDC,
     )

     tiled_c = accumulator.to(cur_lora_ptr.dtype.element_ty)
@@ -223,6 +233,7 @@ def do_shrink_kernel(
     EVEN_K: tl.constexpr,
     SPLIT_K: tl.constexpr,
     SLICE_NUM: tl.constexpr,
+    USE_GDC: tl.constexpr,
 ):
     """
     Given an array of integers that identifies the rows of A, ram,
@@ -272,8 +283,11 @@ def do_shrink_kernel(
         SPLIT_K,
         False,
         cur_lora_ptr.dtype.element_ty,
+        False,  # USE_GDC is always False in shrink kernel
     )
-
+    # GDC launch dependents hints the runtime system to launch dependent kernels.
+    if USE_GDC:
+        tl.extra.cuda.gdc_launch_dependents()
     # Identify the C output pointers to store the results of the accumulator.
     offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
     offset_cm = tl.arange(0, BLOCK_M)
@@ -284,10 +298,10 @@ def do_shrink_kernel(
         + offset_cn[None, :] * output_d2_stride
     )
     c_mask = (offset_cm[:, None] < M_LEN) & (offset_cn[None, :] < N)

     accumulator *= scaling
     # handles write-back with reduction-splitting
     if SPLIT_K == 1:
         tl.store(c_ptr, accumulator, mask=c_mask)
     else:
-        tl.atomic_add(c_ptr, accumulator, mask=c_mask)
+        tl.atomic_add(c_ptr, accumulator, mask=c_mask, sem="relaxed")

@@ -14,6 +14,8 @@ from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr, get_lora_op_configs
 from vllm.triton_utils import tl, triton
 from vllm.utils.torch_utils import direct_register_custom_op

+from .utils import supports_pdl
+

 @triton.jit
 def _lora_expand_kernel(
@@ -45,6 +47,7 @@ def _lora_expand_kernel(
     CAST_TYPE: tl.constexpr,
     SLICE_NUM: tl.constexpr,
     SAME_STRIDE: tl.constexpr,
+    USE_GDC: tl.constexpr,
 ):
     cta_n_num = tl.cdiv(N, BLOCK_N)
     cta_m_num = tl.cdiv(M, BLOCK_M)
@@ -121,6 +124,7 @@ def _lora_expand_kernel(
         EVEN_K,
         CAST_TYPE,
         ADD_INPUTS,
+        USE_GDC,
     )


@@ -236,7 +240,7 @@ def _lora_expand(
         # thread blocks simply exit.
         MAX_LORAS,
     )
-
+    use_gdc = supports_pdl(inputs.device)
     _lora_expand_kernel[grid](
         inputs,
         lora_ptr_tensor,
@@ -266,9 +270,11 @@ def _lora_expand(
         CAST_TYPE,
         NUM_SLICES,
         same_stride,
+        use_gdc,
         num_warps=NUM_WARPS,
         num_ctas=NUM_CTAS,
         num_stages=NUM_STAGES,
+        launch_pdl=use_gdc,
     )

     return

@@ -14,6 +14,8 @@ from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr, get_lora_op_configs
 from vllm.triton_utils import tl, triton
 from vllm.utils.torch_utils import direct_register_custom_op

+from .utils import supports_pdl
+

 @triton.jit
 def _lora_shrink_kernel(
@@ -43,6 +45,7 @@ def _lora_shrink_kernel(
     SPLIT_K: tl.constexpr,
     GROUP_SIZE_M: tl.constexpr,
     SLICE_NUM: tl.constexpr,
+    USE_GDC: tl.constexpr,
 ):
     cta_n_num = tl.cdiv(N, BLOCK_N)
     cta_m_num = tl.cdiv(M, BLOCK_M)
@@ -83,7 +86,6 @@ def _lora_shrink_kernel(
     cta_lora_seq_indices = (
         token_indices_sorted_by_lora_ids + lora_m_indices_start + cta_m_offset
     )
-
     # Load all relevant row indices.
     offset_m = tl.arange(0, BLOCK_M) % cta_m_len
     ram = tl.load(cta_lora_seq_indices + offset_m)
@@ -118,6 +120,7 @@ def _lora_shrink_kernel(
         EVEN_K,
         SPLIT_K,
         SLICE_NUM,
+        USE_GDC,
     )


@@ -217,7 +220,7 @@ def _lora_shrink(
         # thread blocks exit early.
         MAX_LORAS,
     )
-
+    use_gdc = supports_pdl(inputs.device)
     _lora_shrink_kernel[grid](
         inputs,
         lora_ptr_tensor,
@@ -245,9 +248,11 @@ def _lora_shrink(
         SPLIT_K,
         GROUP_SIZE_M,
         NUM_SLICES,
+        use_gdc,
         num_warps=NUM_WARPS,
         num_ctas=NUM_CTAS,
         num_stages=NUM_STAGES,
+        launch_pdl=use_gdc,
     )

     return

@@ -3,6 +3,7 @@

 import functools
 import json
+from functools import lru_cache
 from pathlib import Path
 from typing import Any

@@ -10,6 +11,7 @@ import torch

 from vllm import envs
 from vllm.logger import init_logger
+from vllm.platforms import current_platform

 logger = init_logger(__name__)

@@ -282,3 +284,12 @@ def get_lora_op_configs(

     assert config_data is not None
     return config_data
+
+
+@lru_cache
+def supports_pdl(device: torch.device | None = None) -> bool:
+    """
+    Refer to: https://github.com/triton-lang/triton/blob/v3.5.0/python/tutorials/11-programmatic-dependent-launch.py
+    """
+    # PDL requires compute capability SM90 or above
+    return current_platform.is_cuda() and current_platform.has_device_capability(90)
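For completeness, a hedged sketch of how the new supports_pdl() helper is consumed on the host side. The pdl_launch_kwargs wrapper below is not part of this commit; it only illustrates the two knobs every call site above derives from the same capability check: the USE_GDC constexpr handed to the kernel and the launch_pdl launch metadata.

import torch

from vllm.lora.ops.triton_ops.utils import supports_pdl  # helper added by this commit


def pdl_launch_kwargs(device: torch.device) -> dict[str, bool]:
    # Hypothetical convenience wrapper: both flags are True only on CUDA
    # devices with compute capability >= SM90, mirroring supports_pdl().
    use_gdc = supports_pdl(device)
    return {
        "USE_GDC": use_gdc,     # constexpr read inside the Triton kernels
        "launch_pdl": use_gdc,  # per-launch metadata that opts into PDL
    }


if __name__ == "__main__":
    dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    print(pdl_launch_kwargs(dev))

In the plain shrink/expand path, do_shrink_kernel additionally hard-codes False for mm_k's USE_GDC and calls gdc_launch_dependents itself after the matmul, so only the expand side ever waits inside mm_k.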