Early exit for MoE LoRA kernels (#27131)

Signed-off-by: gnovack <gnovack@amazon.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
gnovack 2025-11-03 04:22:17 -08:00 committed by GitHub
parent 40b69e33e7
commit 294c805f1d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 123 additions and 34 deletions

View File

@ -28,11 +28,16 @@ __global__ void moe_lora_align_sum_kernel(
int64_t block_size, int num_experts, int max_loras, size_t numel,
int max_num_tokens_padded, int max_num_m_blocks,
int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
int topk_num, int32_t* total_tokens_post_pad) {
int topk_num, int32_t* total_tokens_post_pad, int32_t* adapter_enabled,
int32_t* lora_ids) {
const size_t tokens_per_thread = div_ceil(numel, blockDim.x);
const size_t start_idx = threadIdx.x * tokens_per_thread;
int lora_id = blockIdx.x;
int lora_idx = blockIdx.x;
int lora_id = lora_ids[lora_idx];
if (lora_id == -1 || adapter_enabled[lora_id] == 0) {
return;
}
extern __shared__ int32_t shared_mem[];
int32_t* cumsum = shared_mem;
token_cnts_t* tokens_cnts = (token_cnts_t*)(shared_mem + num_experts + 1);
@ -121,14 +126,13 @@ __global__ void moe_lora_align_sum_kernel(
}
}
void moe_lora_align_block_size(torch::Tensor topk_ids,
torch::Tensor token_lora_mapping,
int64_t num_experts, int64_t block_size,
int64_t max_loras, int64_t max_num_tokens_padded,
int64_t max_num_m_blocks,
torch::Tensor sorted_token_ids,
torch::Tensor expert_ids,
torch::Tensor num_tokens_post_pad) {
void moe_lora_align_block_size(
torch::Tensor topk_ids, torch::Tensor token_lora_mapping,
int64_t num_experts, int64_t block_size, int64_t max_loras,
int64_t max_num_tokens_padded, int64_t max_num_m_blocks,
torch::Tensor sorted_token_ids, torch::Tensor expert_ids,
torch::Tensor num_tokens_post_pad, torch::Tensor adapter_enabled,
torch::Tensor lora_ids) {
const int topk_num = topk_ids.size(1);
TORCH_CHECK(block_size > 0, "block_size should be greater than 0. ");
@ -164,6 +168,7 @@ void moe_lora_align_block_size(torch::Tensor topk_ids,
max_loras, topk_ids.numel(), max_num_tokens_padded,
max_num_m_blocks, sorted_token_ids.data_ptr<int32_t>(),
expert_ids.data_ptr<int32_t>(), topk_num,
num_tokens_post_pad.data_ptr<int32_t>());
num_tokens_post_pad.data_ptr<int32_t>(),
adapter_enabled.data_ptr<int32_t>(), lora_ids.data_ptr<int32_t>());
});
}

View File

@ -20,14 +20,13 @@ void batched_moe_align_block_size(int64_t max_tokens_per_batch,
torch::Tensor expert_ids,
torch::Tensor num_tokens_post_pad);
void moe_lora_align_block_size(torch::Tensor topk_ids,
torch::Tensor token_lora_mapping,
int64_t num_experts, int64_t block_size,
int64_t max_loras, int64_t max_num_tokens_padded,
int64_t max_num_m_blocks,
torch::Tensor sorted_token_ids,
torch::Tensor expert_ids,
torch::Tensor num_tokens_post_pad);
void moe_lora_align_block_size(
torch::Tensor topk_ids, torch::Tensor token_lora_mapping,
int64_t num_experts, int64_t block_size, int64_t max_loras,
int64_t max_num_tokens_padded, int64_t max_num_m_blocks,
torch::Tensor sorted_token_ids, torch::Tensor expert_ids,
torch::Tensor num_tokens_post_pad, torch::Tensor adapter_enabled,
torch::Tensor lora_ids);
#ifndef USE_ROCM
torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
torch::Tensor b_qweight, torch::Tensor b_scales,

View File

@ -44,7 +44,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
" int max_num_m_blocks, "
" Tensor !sorted_token_ids,"
" Tensor !experts_ids,"
" Tensor !num_tokens_post_pad) -> () ");
" Tensor !num_tokens_post_pad,"
" Tensor !adapter_enabled,"
" Tensor !lora_ids) -> () ");
m.impl("moe_lora_align_block_size", torch::kCUDA, &moe_lora_align_block_size);
#ifndef USE_ROCM

View File

@ -134,6 +134,8 @@ def use_fused_moe_lora_kernel(
)
expert_ids = torch.empty((max_loras * max_num_m_blocks,), dtype=torch.int32)
num_tokens_post_padded = torch.empty((max_loras,), dtype=torch.int32)
adapter_enabled = torch.ones(max_loras + 1, dtype=torch.int32)
lora_ids = torch.arange(max_loras + 2, dtype=torch.int32)
# call kernel
ops.moe_lora_align_block_size(
@ -147,6 +149,8 @@ def use_fused_moe_lora_kernel(
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
adapter_enabled,
lora_ids,
)
config = {
@ -172,6 +176,8 @@ def use_fused_moe_lora_kernel(
num_tokens_post_padded,
max_lora_rank,
top_k_num,
lora_ids,
adapter_enabled,
config["BLOCK_SIZE_M"],
config["BLOCK_SIZE_N"],
config["BLOCK_SIZE_K"],

View File

@ -60,6 +60,8 @@ def test_moe_lora_align_block_size(
(max_loras * max_num_m_blocks,), num_experts, dtype=torch.int32, device="cuda"
)
num_tokens_post_pad = torch.zeros((max_loras,), dtype=torch.int32, device="cuda")
adapter_enabled = torch.ones((max_loras + 1,), dtype=torch.int32, device="cuda")
lora_ids = torch.arange(max_loras + 2, dtype=torch.int32, device="cuda")
# call kernel
ops.moe_lora_align_block_size(
@ -73,6 +75,8 @@ def test_moe_lora_align_block_size(
sorted_token_ids,
expert_ids,
num_tokens_post_pad,
adapter_enabled,
lora_ids,
)
# verify values

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import vllm
from vllm.lora.request import LoRARequest
@ -28,8 +29,17 @@ EXPECTED_LORA_OUTPUT = [
"SELECT poll_source FROM candidate GROUP BY poll_source ORDER BY count(*) DESC LIMIT 1", # noqa: E501
]
EXPECTED_BASE_MODEL_OUTPUT = [
"SELECT COUNT(Candidate_ID) FROM candidate",
"SELECT COUNT(Candidate_ID) FROM candidate",
"SELECT Candidate_ID, COUNT(*) as Total_Candidates\nFROM candidate\nINNER JOIN people ON candidate.People_ID = people.People_ID", # noqa: E501
"SELECT Candidate_ID, Poll_Source FROM candidate WHERE People_ID IN (SELECT People_ID FROM people) ORDER BY COUNT(*) DESC LIMIT 1", # noqa: E501
]
def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None:
def generate_and_test(
llm: vllm.LLM, lora_path: str, lora_id: list[int | None] | int | None
) -> None:
prompts = [
PROMPT_TEMPLATE.format(context="How many candidates are there?"),
PROMPT_TEMPLATE.format(context="Count the number of candidates."),
@ -40,12 +50,18 @@ def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None:
context="Return the poll resource associated with the most candidates."
),
]
lora_request = None
if isinstance(lora_id, int):
lora_request = LoRARequest(str(lora_id), lora_id, lora_path)
elif isinstance(lora_id, list):
lora_request = [
LoRARequest(str(i), i, lora_path) if i is not None else None
for i in lora_id
]
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=64)
outputs = llm.generate(
prompts,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None,
)
outputs = llm.generate(prompts, sampling_params, lora_request=lora_request)
# Print the outputs.
generated_texts: list[str] = []
for output in outputs:
@ -55,7 +71,13 @@ def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None:
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
for i in range(len(EXPECTED_LORA_OUTPUT)):
assert generated_texts[i].startswith(EXPECTED_LORA_OUTPUT[i])
req_lora_id = lora_id[i] if isinstance(lora_id, list) else lora_id
expected_output = (
EXPECTED_LORA_OUTPUT[i]
if req_lora_id is not None
else EXPECTED_BASE_MODEL_OUTPUT[i]
)
assert generated_texts[i].startswith(expected_output)
def test_olmoe_lora(olmoe_lora_files):
@ -75,6 +97,20 @@ def test_olmoe_lora(olmoe_lora_files):
generate_and_test(llm, olmoe_lora_files, lora_id=2)
def test_olmoe_lora_mixed(olmoe_lora_files):
llm = vllm.LLM(
MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4,
enforce_eager=True,
trust_remote_code=True,
enable_chunked_prefill=True,
)
generate_and_test(llm, olmoe_lora_files, lora_id=[1, None, 3, None])
@multi_gpu_test(num_gpus=2)
def test_olmoe_lora_tp2(olmoe_lora_files):
llm = vllm.LLM(

View File

@ -1823,6 +1823,8 @@ def moe_lora_align_block_size(
sorted_token_ids: torch.Tensor,
experts_ids: torch.Tensor,
num_tokens_post_pad: torch.Tensor,
adapter_enabled: torch.Tensor,
lora_ids: torch.Tensor,
) -> None:
torch.ops._moe_C.moe_lora_align_block_size(
topk_ids,
@ -1835,6 +1837,8 @@ def moe_lora_align_block_size(
sorted_token_ids,
experts_ids,
num_tokens_post_pad,
adapter_enabled,
lora_ids,
)

View File

@ -111,6 +111,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
config["BLOCK_SIZE_M"],
self.base_layer.local_num_experts,
max_loras,
self.adapter_enabled,
expert_map,
)
@ -138,6 +139,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
max_lora_rank,
top_k,
config,
self.adapter_enabled,
)
result = func(*args, **kwargs)
@ -196,6 +198,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
max_lora_rank,
top_k,
config,
self.adapter_enabled,
True,
)
@ -227,6 +230,10 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
) -> None:
"""Initializes lora matrices."""
self.adapter_enabled = torch.tensor(
[0] * (max_loras + 1), dtype=torch.int, device=self.device
)
self.w1_lora_a_stacked = torch.zeros(
(
max_loras,
@ -313,6 +320,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
self.w3_lora_b_stacked[index] = 0
self.w2_lora_a_stacked[index] = 0
self.w2_lora_b_stacked[index] = 0
self.adapter_enabled[index] = 0
def set_lora(
self,
@ -322,8 +330,9 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
embeddings_tensor: torch.Tensor | None,
bias: torch.Tensor | None = None,
):
self.reset_lora(index)
"""Overwrites lora tensors at index."""
self.reset_lora(index)
self.adapter_enabled[index] = 1
for eid in range(len(lora_a) // 3):
w1_lora_a = lora_a[eid * 3]
w2_lora_a = lora_a[eid * 3 + 1]

View File

@ -54,6 +54,8 @@ def _fused_moe_lora_kernel(
EM,
num_valid_tokens,
num_experts,
lora_ids,
adapter_enabled,
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
@ -84,6 +86,11 @@ def _fused_moe_lora_kernel(
pid = tl.program_id(axis=0)
slice_id = tl.program_id(axis=1)
lora_idx = tl.program_id(axis=2)
lora_id = tl.load(lora_ids + lora_idx)
moe_enabled = tl.load(adapter_enabled + lora_id)
if lora_id == -1 or moe_enabled == 0:
# Early exit for the no-lora case.
return
max_loras = tl.num_programs(axis=2)
grid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)
@ -100,12 +107,12 @@ def _fused_moe_lora_kernel(
pid_m = first_pid_m + ((pid_m_n % num_pid_in_group) % group_size_m)
pid_n = (pid_m_n % num_pid_in_group) // group_size_m
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr + lora_idx)
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr + lora_id)
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
return
# get the expert_id to process curr shard
ind = lora_idx * stride_el + pid_m
ind = lora_id * stride_el + pid_m
expert_id = tl.load(expert_ids_ptr + ind, ind < max_loras * stride_el, -1)
if expert_id == -1:
return
@ -119,7 +126,7 @@ def _fused_moe_lora_kernel(
offs_k = pid_sk * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
token_ind = stride_tl * lora_idx + offs_token_id
token_ind = stride_tl * lora_id + offs_token_id
offs_token = tl.load(
sorted_token_ids_ptr + token_ind, token_ind < max_loras * stride_tl, 0
)
@ -132,7 +139,7 @@ def _fused_moe_lora_kernel(
b_ptrs = (
cur_b_ptr
+ lora_idx * stride_bl
+ lora_id * stride_bl
+ expert_id * stride_be
+ offs_k[:, None] * stride_bk
+ offs_bn[None, :] * stride_bn
@ -184,6 +191,8 @@ def _fused_moe_lora(
num_tokens_post_padded: torch.Tensor, # (max_loras, )
max_lora_rank: int,
top_k_num: int,
lora_ids: torch.Tensor,
adapter_enabled: torch.Tensor,
block_size_m: int,
block_size_n: int,
block_size_k: int,
@ -234,7 +243,7 @@ def _fused_moe_lora(
num_tokens = M * top_k_num
w1_output_dim_size = w1_lora_b_stacked.shape[2]
lora_intermediate_cache1 = torch.empty(
lora_intermediate_cache1 = torch.zeros(
(num_slices * M * top_k_num * (max_lora_rank + w1_output_dim_size)),
dtype=output.dtype,
device=device,
@ -272,6 +281,8 @@ def _fused_moe_lora(
EM,
num_tokens,
num_experts,
lora_ids,
adapter_enabled,
qcurr_hidden_states.stride(0),
qcurr_hidden_states.stride(1),
w1_lora_a_stacked.stride(0),
@ -319,6 +330,8 @@ def _fused_moe_lora(
EM,
num_tokens,
num_experts,
lora_ids,
adapter_enabled,
a_intermediate_cache1.stride(0),
a_intermediate_cache1.stride(1),
w1_lora_b_stacked.stride(0),
@ -352,6 +365,8 @@ def _fused_moe_lora_fake(
num_tokens_post_padded: torch.Tensor,
max_lora_rank: int,
top_k_num: int,
lora_ids: torch.Tensor,
adapter_enabled: torch.Tensor,
block_size_m: int,
block_size_n: int,
block_size_k: int,

View File

@ -456,6 +456,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
block_size: int,
num_experts: int,
max_loras: int,
adapter_enabled: torch.Tensor,
expert_map: torch.Tensor | None = None,
pad_sorted_ids: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@ -479,6 +480,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
max_lora_rank: int,
top_k_num: int,
config,
adapter_enabled: torch.Tensor,
mul_routed_weight=False,
):
"""

View File

@ -305,6 +305,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
block_size: int,
num_experts: int,
max_loras: int,
adapter_enabled: torch.Tensor,
expert_map: torch.Tensor | None = None,
pad_sorted_ids: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@ -331,7 +332,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
(max_loras), dtype=torch.int32, device=topk_ids.device
)
(token_lora_mapping, _, _, _, _, _) = self.token_mapping_meta.meta_args(
(token_lora_mapping, _, _, _, lora_ids, _) = self.token_mapping_meta.meta_args(
num_tokens
)
@ -346,6 +347,8 @@ class PunicaWrapperGPU(PunicaWrapperBase):
sorted_ids,
expert_ids,
num_tokens_post_pad,
adapter_enabled,
lora_ids,
)
if expert_map is not None:
expert_ids = expert_map[expert_ids]
@ -365,11 +368,13 @@ class PunicaWrapperGPU(PunicaWrapperBase):
max_lora_rank: int,
top_k_num: int,
config,
adapter_enabled: torch.Tensor,
mul_routed_weight=False,
):
"""
Performs a fused forward computation for LoRA of Mixture-of-Experts (MoE) layer.
"""
(_, _, _, _, lora_ids, _) = self.token_mapping_meta.meta_args(x.size(0))
fused_moe_lora(
y,
x,
@ -381,6 +386,8 @@ class PunicaWrapperGPU(PunicaWrapperBase):
num_tokens_post_padded,
max_lora_rank,
top_k_num,
lora_ids,
adapter_enabled,
config["BLOCK_SIZE_M"],
config["BLOCK_SIZE_N"],
config["BLOCK_SIZE_K"],