[LoRA] Lora shrink swizzle (#27694)

Signed-off-by: li2haipeng <44383182+li2haipeng@users.noreply.github.com>
Signed-off-by: Haipeng Li <li2haipeng@gmail.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
li2haipeng 2025-11-03 17:30:20 -08:00 committed by GitHub
parent b13a447546
commit 6ddae74054
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 14 additions and 2 deletions

View File

@ -41,6 +41,7 @@ def _lora_shrink_kernel(
BLOCK_K: tl.constexpr,
EVEN_K: tl.constexpr,
SPLIT_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
SLICE_NUM: tl.constexpr,
):
cta_n_num = tl.cdiv(N, BLOCK_N)
@ -48,8 +49,16 @@ def _lora_shrink_kernel(
pid_sk_m_n = tl.program_id(axis=0)
pid_sk = pid_sk_m_n % SPLIT_K
pid_m = (pid_sk_m_n // SPLIT_K) % cta_m_num
pid_n = pid_sk_m_n // (SPLIT_K * cta_m_num) % cta_n_num
pid_m_n = pid_sk_m_n // SPLIT_K
num_pid_in_group = GROUP_SIZE_M * cta_n_num
group_id = pid_m_n // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(cta_m_num - first_pid_m, GROUP_SIZE_M)
# Column-major ordering within groups for better cache reuse
pid_m = first_pid_m + ((pid_m_n % num_pid_in_group) % group_size_m)
pid_n = (pid_m_n % num_pid_in_group) // group_size_m
slice_id = tl.program_id(axis=1)
lora_idx = tl.program_id(axis=2)
@ -194,6 +203,7 @@ def _lora_shrink(
NUM_WARPS = kernel_config["num_warps"]
NUM_STAGES = kernel_config["num_stages"]
NUM_CTAS = kernel_config["num_ctas"]
GROUP_SIZE_M = kernel_config.get("group_size_m", 8)
EVEN_K = K % (BLOCK_K * SPLIT_K) == 0 # type: ignore
# TODO (varun): This grid formulation maximizes parallelization at the
@ -233,6 +243,7 @@ def _lora_shrink(
BLOCK_K,
EVEN_K,
SPLIT_K,
GROUP_SIZE_M,
NUM_SLICES,
num_warps=NUM_WARPS,
num_ctas=NUM_CTAS,

View File

@ -199,6 +199,7 @@ def get_lora_op_configs(
"split_k": 64 if batch < 128 else 8,
"num_warps": 4,
"num_ctas": 1,
"group_size_m": 8,
"num_stages": 2,
"max_nreg": None,
}