mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-26 07:43:05 +08:00
[LoRA] Lora shrink swizzle (#27694)
Signed-off-by: li2haipeng <44383182+li2haipeng@users.noreply.github.com> Signed-off-by: Haipeng Li <li2haipeng@gmail.com> Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
parent
b13a447546
commit
6ddae74054
@ -41,6 +41,7 @@ def _lora_shrink_kernel(
|
||||
BLOCK_K: tl.constexpr,
|
||||
EVEN_K: tl.constexpr,
|
||||
SPLIT_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr,
|
||||
SLICE_NUM: tl.constexpr,
|
||||
):
|
||||
cta_n_num = tl.cdiv(N, BLOCK_N)
|
||||
@ -48,8 +49,16 @@ def _lora_shrink_kernel(
|
||||
|
||||
pid_sk_m_n = tl.program_id(axis=0)
|
||||
pid_sk = pid_sk_m_n % SPLIT_K
|
||||
pid_m = (pid_sk_m_n // SPLIT_K) % cta_m_num
|
||||
pid_n = pid_sk_m_n // (SPLIT_K * cta_m_num) % cta_n_num
|
||||
|
||||
pid_m_n = pid_sk_m_n // SPLIT_K
|
||||
num_pid_in_group = GROUP_SIZE_M * cta_n_num
|
||||
group_id = pid_m_n // num_pid_in_group
|
||||
first_pid_m = group_id * GROUP_SIZE_M
|
||||
group_size_m = min(cta_m_num - first_pid_m, GROUP_SIZE_M)
|
||||
|
||||
# Column-major ordering within groups for better cache reuse
|
||||
pid_m = first_pid_m + ((pid_m_n % num_pid_in_group) % group_size_m)
|
||||
pid_n = (pid_m_n % num_pid_in_group) // group_size_m
|
||||
|
||||
slice_id = tl.program_id(axis=1)
|
||||
lora_idx = tl.program_id(axis=2)
|
||||
@ -194,6 +203,7 @@ def _lora_shrink(
|
||||
NUM_WARPS = kernel_config["num_warps"]
|
||||
NUM_STAGES = kernel_config["num_stages"]
|
||||
NUM_CTAS = kernel_config["num_ctas"]
|
||||
GROUP_SIZE_M = kernel_config.get("group_size_m", 8)
|
||||
EVEN_K = K % (BLOCK_K * SPLIT_K) == 0 # type: ignore
|
||||
|
||||
# TODO (varun): This grid formulation maximizes parallelization at the
|
||||
@ -233,6 +243,7 @@ def _lora_shrink(
|
||||
BLOCK_K,
|
||||
EVEN_K,
|
||||
SPLIT_K,
|
||||
GROUP_SIZE_M,
|
||||
NUM_SLICES,
|
||||
num_warps=NUM_WARPS,
|
||||
num_ctas=NUM_CTAS,
|
||||
|
||||
@ -199,6 +199,7 @@ def get_lora_op_configs(
|
||||
"split_k": 64 if batch < 128 else 8,
|
||||
"num_warps": 4,
|
||||
"num_ctas": 1,
|
||||
"group_size_m": 8,
|
||||
"num_stages": 2,
|
||||
"max_nreg": None,
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user