Early exit for MoE LoRA kernels (#27131)

Signed-off-by: gnovack <gnovack@amazon.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
gnovack 2025-11-03 04:22:17 -08:00 committed by GitHub
parent 40b69e33e7
commit 294c805f1d
11 changed files with 123 additions and 34 deletions

View File

@@ -28,11 +28,16 @@ __global__ void moe_lora_align_sum_kernel(
     int64_t block_size, int num_experts, int max_loras, size_t numel,
     int max_num_tokens_padded, int max_num_m_blocks,
     int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
-    int topk_num, int32_t* total_tokens_post_pad) {
+    int topk_num, int32_t* total_tokens_post_pad, int32_t* adapter_enabled,
+    int32_t* lora_ids) {
   const size_t tokens_per_thread = div_ceil(numel, blockDim.x);
   const size_t start_idx = threadIdx.x * tokens_per_thread;

-  int lora_id = blockIdx.x;
+  int lora_idx = blockIdx.x;
+  int lora_id = lora_ids[lora_idx];
+  if (lora_id == -1 || adapter_enabled[lora_id] == 0) {
+    return;
+  }
   extern __shared__ int32_t shared_mem[];
   int32_t* cumsum = shared_mem;
   token_cnts_t* tokens_cnts = (token_cnts_t*)(shared_mem + num_experts + 1);
@@ -121,14 +126,13 @@ __global__ void moe_lora_align_sum_kernel(
   }
 }

-void moe_lora_align_block_size(torch::Tensor topk_ids,
-                               torch::Tensor token_lora_mapping,
-                               int64_t num_experts, int64_t block_size,
-                               int64_t max_loras, int64_t max_num_tokens_padded,
-                               int64_t max_num_m_blocks,
-                               torch::Tensor sorted_token_ids,
-                               torch::Tensor expert_ids,
-                               torch::Tensor num_tokens_post_pad) {
+void moe_lora_align_block_size(
+    torch::Tensor topk_ids, torch::Tensor token_lora_mapping,
+    int64_t num_experts, int64_t block_size, int64_t max_loras,
+    int64_t max_num_tokens_padded, int64_t max_num_m_blocks,
+    torch::Tensor sorted_token_ids, torch::Tensor expert_ids,
+    torch::Tensor num_tokens_post_pad, torch::Tensor adapter_enabled,
+    torch::Tensor lora_ids) {
   const int topk_num = topk_ids.size(1);

   TORCH_CHECK(block_size > 0, "block_size should be greater than 0. ");
@@ -164,6 +168,7 @@ void moe_lora_align_block_size(torch::Tensor topk_ids,
         max_loras, topk_ids.numel(), max_num_tokens_padded,
         max_num_m_blocks, sorted_token_ids.data_ptr<int32_t>(),
         expert_ids.data_ptr<int32_t>(), topk_num,
-        num_tokens_post_pad.data_ptr<int32_t>());
+        num_tokens_post_pad.data_ptr<int32_t>(),
+        adapter_enabled.data_ptr<int32_t>(), lora_ids.data_ptr<int32_t>());
   });
 }
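
The kernel-side change above is the core of the early exit: each CUDA block is keyed by a LoRA slot, and the block bails out before touching shared memory when its slot is empty (lora_id == -1) or the adapter in that slot has no MoE LoRA weights. A minimal host-side sketch of that gate, in Python, assuming the same lora_ids / adapter_enabled layout as the new arguments (the helper itself is illustrative, not part of vLLM):

import torch

def block_should_run(block_idx: int, lora_ids: torch.Tensor,
                     adapter_enabled: torch.Tensor) -> bool:
    # Mirrors the gate at the top of moe_lora_align_sum_kernel: one block per
    # LoRA slot, where -1 marks an empty slot and adapter_enabled marks whether
    # the adapter in that slot carries MoE LoRA weights.
    lora_id = int(lora_ids[block_idx])
    if lora_id == -1 or int(adapter_enabled[lora_id]) == 0:
        return False  # the CUDA block returns immediately
    return True

# Example: four slots, only slot 0 holds an enabled adapter (id 0).
lora_ids = torch.tensor([0, -1, 2, -1], dtype=torch.int32)
adapter_enabled = torch.tensor([1, 0, 0, 0, 0], dtype=torch.int32)
print([block_should_run(i, lora_ids, adapter_enabled) for i in range(4)])
# -> [True, False, False, False]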

View File

@@ -20,14 +20,13 @@ void batched_moe_align_block_size(int64_t max_tokens_per_batch,
                                   torch::Tensor expert_ids,
                                   torch::Tensor num_tokens_post_pad);

-void moe_lora_align_block_size(torch::Tensor topk_ids,
-                               torch::Tensor token_lora_mapping,
-                               int64_t num_experts, int64_t block_size,
-                               int64_t max_loras, int64_t max_num_tokens_padded,
-                               int64_t max_num_m_blocks,
-                               torch::Tensor sorted_token_ids,
-                               torch::Tensor expert_ids,
-                               torch::Tensor num_tokens_post_pad);
+void moe_lora_align_block_size(
+    torch::Tensor topk_ids, torch::Tensor token_lora_mapping,
+    int64_t num_experts, int64_t block_size, int64_t max_loras,
+    int64_t max_num_tokens_padded, int64_t max_num_m_blocks,
+    torch::Tensor sorted_token_ids, torch::Tensor expert_ids,
+    torch::Tensor num_tokens_post_pad, torch::Tensor adapter_enabled,
+    torch::Tensor lora_ids);

 #ifndef USE_ROCM
 torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
                              torch::Tensor b_qweight, torch::Tensor b_scales,

View File

@@ -44,7 +44,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
       " int max_num_m_blocks, "
       " Tensor !sorted_token_ids,"
       " Tensor !experts_ids,"
-      " Tensor !num_tokens_post_pad) -> () ");
+      " Tensor !num_tokens_post_pad,"
+      " Tensor !adapter_enabled,"
+      " Tensor !lora_ids) -> () ");
   m.impl("moe_lora_align_block_size", torch::kCUDA, &moe_lora_align_block_size);

 #ifndef USE_ROCM

View File

@@ -134,6 +134,8 @@ def use_fused_moe_lora_kernel(
     )
     expert_ids = torch.empty((max_loras * max_num_m_blocks,), dtype=torch.int32)
     num_tokens_post_padded = torch.empty((max_loras,), dtype=torch.int32)
+    adapter_enabled = torch.ones(max_loras + 1, dtype=torch.int32)
+    lora_ids = torch.arange(max_loras + 2, dtype=torch.int32)

     # call kernel
     ops.moe_lora_align_block_size(
@@ -147,6 +149,8 @@ def use_fused_moe_lora_kernel(
         sorted_token_ids,
         expert_ids,
         num_tokens_post_padded,
+        adapter_enabled,
+        lora_ids,
     )

     config = {
@@ -172,6 +176,8 @@ def use_fused_moe_lora_kernel(
         num_tokens_post_padded,
         max_lora_rank,
         top_k_num,
+        lora_ids,
+        adapter_enabled,
         config["BLOCK_SIZE_M"],
         config["BLOCK_SIZE_N"],
         config["BLOCK_SIZE_K"],

View File

@@ -60,6 +60,8 @@ def test_moe_lora_align_block_size(
         (max_loras * max_num_m_blocks,), num_experts, dtype=torch.int32, device="cuda"
     )
     num_tokens_post_pad = torch.zeros((max_loras,), dtype=torch.int32, device="cuda")
+    adapter_enabled = torch.ones((max_loras + 1,), dtype=torch.int32, device="cuda")
+    lora_ids = torch.arange(max_loras + 2, dtype=torch.int32, device="cuda")

     # call kernel
     ops.moe_lora_align_block_size(
@@ -73,6 +75,8 @@ def test_moe_lora_align_block_size(
         sorted_token_ids,
         expert_ids,
         num_tokens_post_pad,
+        adapter_enabled,
+        lora_ids,
     )

     # verify values
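
The tests above exercise the enabled path; the case the early exit is really aimed at is a batch with no active adapters. A hedged sketch of that degenerate input, reusing the tensor sizes from the test (max_loras + 1 enabled flags, max_loras + 2 slot ids):

import torch

max_loras = 4
# Every slot empty and every adapter flagged off: each CUDA block of
# moe_lora_align_block_size returns immediately, so the pre-initialized
# sorted_token_ids / expert_ids / num_tokens_post_pad buffers are left as-is.
lora_ids = torch.full((max_loras + 2,), -1, dtype=torch.int32, device="cuda")
adapter_enabled = torch.zeros((max_loras + 1,), dtype=torch.int32, device="cuda")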

View File

@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import vllm
 from vllm.lora.request import LoRARequest
@@ -28,8 +29,17 @@ EXPECTED_LORA_OUTPUT = [
     "SELECT poll_source FROM candidate GROUP BY poll_source ORDER BY count(*) DESC LIMIT 1",  # noqa: E501
 ]

+EXPECTED_BASE_MODEL_OUTPUT = [
+    "SELECT COUNT(Candidate_ID) FROM candidate",
+    "SELECT COUNT(Candidate_ID) FROM candidate",
+    "SELECT Candidate_ID, COUNT(*) as Total_Candidates\nFROM candidate\nINNER JOIN people ON candidate.People_ID = people.People_ID",  # noqa: E501
+    "SELECT Candidate_ID, Poll_Source FROM candidate WHERE People_ID IN (SELECT People_ID FROM people) ORDER BY COUNT(*) DESC LIMIT 1",  # noqa: E501
+]
+

-def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None:
+def generate_and_test(
+    llm: vllm.LLM, lora_path: str, lora_id: list[int | None] | int | None
+) -> None:
     prompts = [
         PROMPT_TEMPLATE.format(context="How many candidates are there?"),
         PROMPT_TEMPLATE.format(context="Count the number of candidates."),
@@ -40,12 +50,18 @@ def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None:
             context="Return the poll resource associated with the most candidates."
         ),
     ]
+
+    lora_request = None
+    if isinstance(lora_id, int):
+        lora_request = LoRARequest(str(lora_id), lora_id, lora_path)
+    elif isinstance(lora_id, list):
+        lora_request = [
+            LoRARequest(str(i), i, lora_path) if i is not None else None
+            for i in lora_id
+        ]
+
     sampling_params = vllm.SamplingParams(temperature=0, max_tokens=64)
-    outputs = llm.generate(
-        prompts,
-        sampling_params,
-        lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None,
-    )
+    outputs = llm.generate(prompts, sampling_params, lora_request=lora_request)
     # Print the outputs.
     generated_texts: list[str] = []
     for output in outputs:
@@ -55,7 +71,13 @@ def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None:
         print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

     for i in range(len(EXPECTED_LORA_OUTPUT)):
-        assert generated_texts[i].startswith(EXPECTED_LORA_OUTPUT[i])
+        req_lora_id = lora_id[i] if isinstance(lora_id, list) else lora_id
+        expected_output = (
+            EXPECTED_LORA_OUTPUT[i]
+            if req_lora_id is not None
+            else EXPECTED_BASE_MODEL_OUTPUT[i]
+        )
+        assert generated_texts[i].startswith(expected_output)


 def test_olmoe_lora(olmoe_lora_files):
@@ -75,6 +97,20 @@ def test_olmoe_lora(olmoe_lora_files):
     generate_and_test(llm, olmoe_lora_files, lora_id=2)


+def test_olmoe_lora_mixed(olmoe_lora_files):
+    llm = vllm.LLM(
+        MODEL_PATH,
+        max_model_len=1024,
+        enable_lora=True,
+        max_loras=4,
+        enforce_eager=True,
+        trust_remote_code=True,
+        enable_chunked_prefill=True,
+    )
+
+    generate_and_test(llm, olmoe_lora_files, lora_id=[1, None, 3, None])
+
+
 @multi_gpu_test(num_gpus=2)
 def test_olmoe_lora_tp2(olmoe_lora_files):
     llm = vllm.LLM(
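
Outside the test harness, the mixed case corresponds to passing a per-prompt list of LoRA requests to llm.generate, with None entries falling back to the base model. A usage sketch with placeholder model and adapter paths:

import vllm
from vllm.lora.request import LoRARequest

llm = vllm.LLM("some/moe-base-model", enable_lora=True, max_loras=4)
prompts = ["first prompt", "second prompt"]
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=64)

# One entry per prompt: route the first through adapter id 1, serve the
# second straight from the base model weights.
lora_requests = [LoRARequest("sql-adapter", 1, "/path/to/adapter"), None]
outputs = llm.generate(prompts, sampling_params, lora_request=lora_requests)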

View File

@@ -1823,6 +1823,8 @@ def moe_lora_align_block_size(
     sorted_token_ids: torch.Tensor,
     experts_ids: torch.Tensor,
     num_tokens_post_pad: torch.Tensor,
+    adapter_enabled: torch.Tensor,
+    lora_ids: torch.Tensor,
 ) -> None:
     torch.ops._moe_C.moe_lora_align_block_size(
         topk_ids,
@@ -1835,6 +1837,8 @@ def moe_lora_align_block_size(
         sorted_token_ids,
         experts_ids,
         num_tokens_post_pad,
+        adapter_enabled,
+        lora_ids,
     )

View File

@@ -111,6 +111,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
                 config["BLOCK_SIZE_M"],
                 self.base_layer.local_num_experts,
                 max_loras,
+                self.adapter_enabled,
                 expert_map,
             )
@@ -138,6 +139,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
                 max_lora_rank,
                 top_k,
                 config,
+                self.adapter_enabled,
             )

             result = func(*args, **kwargs)
@@ -196,6 +198,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
                 max_lora_rank,
                 top_k,
                 config,
+                self.adapter_enabled,
                 True,
             )
@@ -227,6 +230,10 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
     ) -> None:
         """Initializes lora matrices."""

+        self.adapter_enabled = torch.tensor(
+            [0] * (max_loras + 1), dtype=torch.int, device=self.device
+        )
+
         self.w1_lora_a_stacked = torch.zeros(
             (
                 max_loras,
@@ -313,6 +320,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
         self.w3_lora_b_stacked[index] = 0
         self.w2_lora_a_stacked[index] = 0
         self.w2_lora_b_stacked[index] = 0
+        self.adapter_enabled[index] = 0

     def set_lora(
         self,
@@ -322,8 +330,9 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
         embeddings_tensor: torch.Tensor | None,
         bias: torch.Tensor | None = None,
     ):
-        self.reset_lora(index)
         """Overwrites lora tensors at index."""
+        self.reset_lora(index)
+        self.adapter_enabled[index] = 1
         for eid in range(len(lora_a) // 3):
             w1_lora_a = lora_a[eid * 3]
             w2_lora_a = lora_a[eid * 3 + 1]
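
The layer change boils down to a small piece of bookkeeping: a per-slot flag tensor that set_lora and reset_lora keep in sync, so the kernels can decide on device which slots to skip. A standalone sketch of that pattern (not the vLLM class itself):

import torch

class AdapterSlots:
    def __init__(self, max_loras: int, device: str = "cpu"):
        # One extra slot, mirroring the max_loras + 1 sizing used above.
        self.adapter_enabled = torch.zeros(max_loras + 1, dtype=torch.int, device=device)

    def set_lora(self, index: int) -> None:
        self.adapter_enabled[index] = 1  # slot now carries MoE LoRA weights

    def reset_lora(self, index: int) -> None:
        self.adapter_enabled[index] = 0  # kernels early-exit for this slot

slots = AdapterSlots(max_loras=4)
slots.set_lora(1)
print(slots.adapter_enabled)  # tensor([0, 1, 0, 0, 0], dtype=torch.int32)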

View File

@@ -54,6 +54,8 @@ def _fused_moe_lora_kernel(
     EM,
     num_valid_tokens,
     num_experts,
+    lora_ids,
+    adapter_enabled,
     # The stride variables represent how much to increase the ptr by when
     # moving by 1 element in a particular dimension. E.g. `stride_am` is
     # how much to increase `a_ptr` by to get the element one row down
@@ -84,6 +86,11 @@ def _fused_moe_lora_kernel(
     pid = tl.program_id(axis=0)
     slice_id = tl.program_id(axis=1)
     lora_idx = tl.program_id(axis=2)
+    lora_id = tl.load(lora_ids + lora_idx)
+    moe_enabled = tl.load(adapter_enabled + lora_id)
+    if lora_id == -1 or moe_enabled == 0:
+        # Early exit for the no-lora case.
+        return

     max_loras = tl.num_programs(axis=2)
     grid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)
@@ -100,12 +107,12 @@ def _fused_moe_lora_kernel(
     pid_m = first_pid_m + ((pid_m_n % num_pid_in_group) % group_size_m)
     pid_n = (pid_m_n % num_pid_in_group) // group_size_m

-    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr + lora_idx)
+    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr + lora_id)
     if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
         return

     # get the expert_id to process curr shard
-    ind = lora_idx * stride_el + pid_m
+    ind = lora_id * stride_el + pid_m
     expert_id = tl.load(expert_ids_ptr + ind, ind < max_loras * stride_el, -1)
     if expert_id == -1:
         return
@@ -119,7 +126,7 @@ def _fused_moe_lora_kernel(
     offs_k = pid_sk * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
     offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)

-    token_ind = stride_tl * lora_idx + offs_token_id
+    token_ind = stride_tl * lora_id + offs_token_id
     offs_token = tl.load(
         sorted_token_ids_ptr + token_ind, token_ind < max_loras * stride_tl, 0
     )
@@ -132,7 +139,7 @@ def _fused_moe_lora_kernel(
     b_ptrs = (
         cur_b_ptr
-        + lora_idx * stride_bl
+        + lora_id * stride_bl
         + expert_id * stride_be
         + offs_k[:, None] * stride_bk
         + offs_bn[None, :] * stride_bn
@@ -184,6 +191,8 @@ def _fused_moe_lora(
     num_tokens_post_padded: torch.Tensor,  # (max_loras, )
     max_lora_rank: int,
     top_k_num: int,
+    lora_ids: torch.Tensor,
+    adapter_enabled: torch.Tensor,
     block_size_m: int,
     block_size_n: int,
     block_size_k: int,
@@ -234,7 +243,7 @@ def _fused_moe_lora(
     num_tokens = M * top_k_num
     w1_output_dim_size = w1_lora_b_stacked.shape[2]

-    lora_intermediate_cache1 = torch.empty(
+    lora_intermediate_cache1 = torch.zeros(
         (num_slices * M * top_k_num * (max_lora_rank + w1_output_dim_size)),
         dtype=output.dtype,
         device=device,
@@ -272,6 +281,8 @@ def _fused_moe_lora(
         EM,
         num_tokens,
         num_experts,
+        lora_ids,
+        adapter_enabled,
         qcurr_hidden_states.stride(0),
         qcurr_hidden_states.stride(1),
         w1_lora_a_stacked.stride(0),
@@ -319,6 +330,8 @@ def _fused_moe_lora(
         EM,
         num_tokens,
         num_experts,
+        lora_ids,
+        adapter_enabled,
         a_intermediate_cache1.stride(0),
         a_intermediate_cache1.stride(1),
         w1_lora_b_stacked.stride(0),
@@ -352,6 +365,8 @@ def _fused_moe_lora_fake(
     num_tokens_post_padded: torch.Tensor,
     max_lora_rank: int,
     top_k_num: int,
+    lora_ids: torch.Tensor,
+    adapter_enabled: torch.Tensor,
     block_size_m: int,
     block_size_n: int,
     block_size_k: int,
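
The Triton-side pattern matches the CUDA kernel: each program on the LoRA axis loads its slot id and returns before doing any work when the slot is inactive. A reduced, self-contained Triton sketch of that early exit (the demo kernel and its names are illustrative, not the vLLM kernel):

import torch
import triton
import triton.language as tl

@triton.jit
def early_exit_demo(lora_ids_ptr, adapter_enabled_ptr, out_ptr):
    lora_idx = tl.program_id(axis=0)
    lora_id = tl.load(lora_ids_ptr + lora_idx)
    if lora_id == -1:
        return  # empty slot: skip everything
    if tl.load(adapter_enabled_ptr + lora_id) == 0:
        return  # adapter has no MoE LoRA weights: skip everything
    tl.store(out_ptr + lora_idx, 1)  # stand-in for the real GEMM work

max_loras = 4
lora_ids = torch.tensor([0, -1, 2, -1], dtype=torch.int32, device="cuda")
adapter_enabled = torch.tensor([1, 0, 0, 0, 0], dtype=torch.int32, device="cuda")
out = torch.zeros(max_loras, dtype=torch.int32, device="cuda")
early_exit_demo[(max_loras,)](lora_ids, adapter_enabled, out)
print(out)  # tensor([1, 0, 0, 0], device='cuda:0', dtype=torch.int32)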

View File

@@ -456,6 +456,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
         block_size: int,
         num_experts: int,
         max_loras: int,
+        adapter_enabled: torch.Tensor,
         expert_map: torch.Tensor | None = None,
         pad_sorted_ids: bool = False,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -479,6 +480,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
         max_lora_rank: int,
         top_k_num: int,
         config,
+        adapter_enabled: torch.Tensor,
         mul_routed_weight=False,
     ):
         """

View File

@@ -305,6 +305,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
         block_size: int,
         num_experts: int,
         max_loras: int,
+        adapter_enabled: torch.Tensor,
         expert_map: torch.Tensor | None = None,
         pad_sorted_ids: bool = False,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -331,7 +332,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
             (max_loras), dtype=torch.int32, device=topk_ids.device
         )

-        (token_lora_mapping, _, _, _, _, _) = self.token_mapping_meta.meta_args(
+        (token_lora_mapping, _, _, _, lora_ids, _) = self.token_mapping_meta.meta_args(
             num_tokens
         )

@@ -346,6 +347,8 @@ class PunicaWrapperGPU(PunicaWrapperBase):
             sorted_ids,
             expert_ids,
             num_tokens_post_pad,
+            adapter_enabled,
+            lora_ids,
         )
         if expert_map is not None:
             expert_ids = expert_map[expert_ids]
@@ -365,11 +368,13 @@ class PunicaWrapperGPU(PunicaWrapperBase):
         max_lora_rank: int,
         top_k_num: int,
         config,
+        adapter_enabled: torch.Tensor,
         mul_routed_weight=False,
     ):
         """
         Performs a fused forward computation for LoRA of Mixture-of-Experts (MoE) layer.
         """
+        (_, _, _, _, lora_ids, _) = self.token_mapping_meta.meta_args(x.size(0))
         fused_moe_lora(
             y,
             x,
@@ -381,6 +386,8 @@ class PunicaWrapperGPU(PunicaWrapperBase):
             num_tokens_post_padded,
             max_lora_rank,
             top_k_num,
+            lora_ids,
+            adapter_enabled,
             config["BLOCK_SIZE_M"],
             config["BLOCK_SIZE_N"],
             config["BLOCK_SIZE_K"],