Early exit for MoE LoRA kernels (#27131)
Signed-off-by: gnovack <gnovack@amazon.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>

commit 294c805f1d (parent 40b69e33e7)
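This change threads two per-slot tensors, adapter_enabled and lora_ids, through the MoE LoRA alignment kernel and the fused LoRA GEMM kernels so that CUDA blocks and Triton programs assigned to an inactive adapter slot return immediately instead of scanning tokens for adapters that cannot contribute. The hunks below apply the change across the CUDA kernel, the op bindings, the Triton kernel, the LoRA layer, the Punica wrappers, and the tests. A minimal sketch of the gating predicate the kernels apply, written in plain Python purely for illustration (the helper name is hypothetical; only the two tensor names come from the diff):

import torch

def slot_exits_early(lora_idx: int, lora_ids: torch.Tensor, adapter_enabled: torch.Tensor) -> bool:
    # A slot first resolves to an adapter id, then checks the enable flag.
    lora_id = int(lora_ids[lora_idx])
    return lora_id == -1 or int(adapter_enabled[lora_id]) == 0

# Example: slot 1 carries no adapter, adapter 2 is loaded but disabled.
lora_ids = torch.tensor([0, -1, 2, 3], dtype=torch.int32)
adapter_enabled = torch.tensor([1, 1, 0, 1], dtype=torch.int32)
print([slot_exits_early(i, lora_ids, adapter_enabled) for i in range(4)])
# [False, True, True, False]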
@@ -28,11 +28,16 @@ __global__ void moe_lora_align_sum_kernel(
     int64_t block_size, int num_experts, int max_loras, size_t numel,
     int max_num_tokens_padded, int max_num_m_blocks,
     int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
-    int topk_num, int32_t* total_tokens_post_pad) {
+    int topk_num, int32_t* total_tokens_post_pad, int32_t* adapter_enabled,
+    int32_t* lora_ids) {
   const size_t tokens_per_thread = div_ceil(numel, blockDim.x);
   const size_t start_idx = threadIdx.x * tokens_per_thread;

-  int lora_id = blockIdx.x;
+  int lora_idx = blockIdx.x;
+  int lora_id = lora_ids[lora_idx];
+  if (lora_id == -1 || adapter_enabled[lora_id] == 0) {
+    return;
+  }
   extern __shared__ int32_t shared_mem[];
   int32_t* cumsum = shared_mem;
   token_cnts_t* tokens_cnts = (token_cnts_t*)(shared_mem + num_experts + 1);
@@ -121,14 +126,13 @@ __global__ void moe_lora_align_sum_kernel(
   }
 }

-void moe_lora_align_block_size(torch::Tensor topk_ids,
-                               torch::Tensor token_lora_mapping,
-                               int64_t num_experts, int64_t block_size,
-                               int64_t max_loras, int64_t max_num_tokens_padded,
-                               int64_t max_num_m_blocks,
-                               torch::Tensor sorted_token_ids,
-                               torch::Tensor expert_ids,
-                               torch::Tensor num_tokens_post_pad) {
+void moe_lora_align_block_size(
+    torch::Tensor topk_ids, torch::Tensor token_lora_mapping,
+    int64_t num_experts, int64_t block_size, int64_t max_loras,
+    int64_t max_num_tokens_padded, int64_t max_num_m_blocks,
+    torch::Tensor sorted_token_ids, torch::Tensor expert_ids,
+    torch::Tensor num_tokens_post_pad, torch::Tensor adapter_enabled,
+    torch::Tensor lora_ids) {
   const int topk_num = topk_ids.size(1);

   TORCH_CHECK(block_size > 0, "block_size should be greater than 0. ");
@@ -164,6 +168,7 @@ void moe_lora_align_block_size(torch::Tensor topk_ids,
         max_loras, topk_ids.numel(), max_num_tokens_padded,
         max_num_m_blocks, sorted_token_ids.data_ptr<int32_t>(),
         expert_ids.data_ptr<int32_t>(), topk_num,
-        num_tokens_post_pad.data_ptr<int32_t>());
+        num_tokens_post_pad.data_ptr<int32_t>(),
+        adapter_enabled.data_ptr<int32_t>(), lora_ids.data_ptr<int32_t>());
   });
 }
@@ -20,14 +20,13 @@ void batched_moe_align_block_size(int64_t max_tokens_per_batch,
                                   torch::Tensor expert_ids,
                                   torch::Tensor num_tokens_post_pad);

-void moe_lora_align_block_size(torch::Tensor topk_ids,
-                               torch::Tensor token_lora_mapping,
-                               int64_t num_experts, int64_t block_size,
-                               int64_t max_loras, int64_t max_num_tokens_padded,
-                               int64_t max_num_m_blocks,
-                               torch::Tensor sorted_token_ids,
-                               torch::Tensor expert_ids,
-                               torch::Tensor num_tokens_post_pad);
+void moe_lora_align_block_size(
+    torch::Tensor topk_ids, torch::Tensor token_lora_mapping,
+    int64_t num_experts, int64_t block_size, int64_t max_loras,
+    int64_t max_num_tokens_padded, int64_t max_num_m_blocks,
+    torch::Tensor sorted_token_ids, torch::Tensor expert_ids,
+    torch::Tensor num_tokens_post_pad, torch::Tensor adapter_enabled,
+    torch::Tensor lora_ids);
 #ifndef USE_ROCM
 torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
                              torch::Tensor b_qweight, torch::Tensor b_scales,
@@ -44,7 +44,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
       " int max_num_m_blocks, "
       " Tensor !sorted_token_ids,"
      " Tensor !experts_ids,"
-      " Tensor !num_tokens_post_pad) -> () ");
+      " Tensor !num_tokens_post_pad,"
+      " Tensor !adapter_enabled,"
+      " Tensor !lora_ids) -> () ");
   m.impl("moe_lora_align_block_size", torch::kCUDA, &moe_lora_align_block_size);

 #ifndef USE_ROCM
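With the extended schema, the Python wrapper in _custom_ops (see its hunk further down) simply forwards the two new mutable tensors to the custom op. A hedged call sketch, assuming a CUDA build that includes this op; the sizes are illustrative and the positional order mirrors the C++ signature shown above, so treat it as a sketch rather than the exact Python API:

import torch
from vllm import _custom_ops as ops

num_tokens, topk, num_experts, block_size = 16, 2, 8, 16
max_loras, max_num_m_blocks, max_num_tokens_padded = 4, 8, 256

topk_ids = torch.randint(0, num_experts, (num_tokens, topk), dtype=torch.int32, device="cuda")
token_lora_mapping = torch.zeros((num_tokens,), dtype=torch.int32, device="cuda")
sorted_token_ids = torch.empty((max_loras * max_num_tokens_padded,), dtype=torch.int32, device="cuda")
expert_ids = torch.empty((max_loras * max_num_m_blocks,), dtype=torch.int32, device="cuda")
num_tokens_post_pad = torch.zeros((max_loras,), dtype=torch.int32, device="cuda")
adapter_enabled = torch.ones((max_loras + 1,), dtype=torch.int32, device="cuda")  # all adapters active
lora_ids = torch.arange(max_loras + 2, dtype=torch.int32, device="cuda")          # slot -> adapter id

ops.moe_lora_align_block_size(
    topk_ids, token_lora_mapping,
    num_experts, block_size, max_loras,
    max_num_tokens_padded, max_num_m_blocks,
    sorted_token_ids, expert_ids, num_tokens_post_pad,
    adapter_enabled, lora_ids,
)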
@@ -134,6 +134,8 @@ def use_fused_moe_lora_kernel(
     )
     expert_ids = torch.empty((max_loras * max_num_m_blocks,), dtype=torch.int32)
     num_tokens_post_padded = torch.empty((max_loras,), dtype=torch.int32)
+    adapter_enabled = torch.ones(max_loras + 1, dtype=torch.int32)
+    lora_ids = torch.arange(max_loras + 2, dtype=torch.int32)

     # call kernel
     ops.moe_lora_align_block_size(
@@ -147,6 +149,8 @@ def use_fused_moe_lora_kernel(
         sorted_token_ids,
         expert_ids,
         num_tokens_post_padded,
+        adapter_enabled,
+        lora_ids,
     )

     config = {
@@ -172,6 +176,8 @@ def use_fused_moe_lora_kernel(
         num_tokens_post_padded,
         max_lora_rank,
         top_k_num,
+        lora_ids,
+        adapter_enabled,
         config["BLOCK_SIZE_M"],
         config["BLOCK_SIZE_N"],
         config["BLOCK_SIZE_K"],
@@ -60,6 +60,8 @@ def test_moe_lora_align_block_size(
         (max_loras * max_num_m_blocks,), num_experts, dtype=torch.int32, device="cuda"
     )
     num_tokens_post_pad = torch.zeros((max_loras,), dtype=torch.int32, device="cuda")
+    adapter_enabled = torch.ones((max_loras + 1,), dtype=torch.int32, device="cuda")
+    lora_ids = torch.arange(max_loras + 2, dtype=torch.int32, device="cuda")

     # call kernel
     ops.moe_lora_align_block_size(
@@ -73,6 +75,8 @@ def test_moe_lora_align_block_size(
         sorted_token_ids,
         expert_ids,
         num_tokens_post_pad,
+        adapter_enabled,
+        lora_ids,
     )

     # verify values
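The test keeps every slot active (adapter_enabled is all ones and lora_ids is a plain arange), so the kernel still processes all max_loras slots and its results match the previous behavior. Flipping individual entries is how the new early-exit path would be exercised; an illustrative fragment, not part of the test:

import torch

max_loras = 4
adapter_enabled = torch.ones((max_loras + 1,), dtype=torch.int32)
lora_ids = torch.arange(max_loras + 2, dtype=torch.int32)

# Disable adapter 2 and empty out slot 1; the CUDA blocks (and the Triton
# programs in the fused kernel below) assigned to those cases return early.
adapter_enabled[2] = 0
lora_ids[1] = -1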
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import vllm
 from vllm.lora.request import LoRARequest

@@ -28,8 +29,17 @@ EXPECTED_LORA_OUTPUT = [
     "SELECT poll_source FROM candidate GROUP BY poll_source ORDER BY count(*) DESC LIMIT 1",  # noqa: E501
 ]

+EXPECTED_BASE_MODEL_OUTPUT = [
+    "SELECT COUNT(Candidate_ID) FROM candidate",
+    "SELECT COUNT(Candidate_ID) FROM candidate",
+    "SELECT Candidate_ID, COUNT(*) as Total_Candidates\nFROM candidate\nINNER JOIN people ON candidate.People_ID = people.People_ID",  # noqa: E501
+    "SELECT Candidate_ID, Poll_Source FROM candidate WHERE People_ID IN (SELECT People_ID FROM people) ORDER BY COUNT(*) DESC LIMIT 1",  # noqa: E501
+]
+

-def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None:
+def generate_and_test(
+    llm: vllm.LLM, lora_path: str, lora_id: list[int | None] | int | None
+) -> None:
     prompts = [
         PROMPT_TEMPLATE.format(context="How many candidates are there?"),
         PROMPT_TEMPLATE.format(context="Count the number of candidates."),
@@ -40,12 +50,18 @@ def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None:
             context="Return the poll resource associated with the most candidates."
         ),
     ]
+
+    lora_request = None
+    if isinstance(lora_id, int):
+        lora_request = LoRARequest(str(lora_id), lora_id, lora_path)
+    elif isinstance(lora_id, list):
+        lora_request = [
+            LoRARequest(str(i), i, lora_path) if i is not None else None
+            for i in lora_id
+        ]
+
     sampling_params = vllm.SamplingParams(temperature=0, max_tokens=64)
-    outputs = llm.generate(
-        prompts,
-        sampling_params,
-        lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None,
-    )
+    outputs = llm.generate(prompts, sampling_params, lora_request=lora_request)
     # Print the outputs.
     generated_texts: list[str] = []
     for output in outputs:
@@ -55,7 +71,13 @@ def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None:
         print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

     for i in range(len(EXPECTED_LORA_OUTPUT)):
-        assert generated_texts[i].startswith(EXPECTED_LORA_OUTPUT[i])
+        req_lora_id = lora_id[i] if isinstance(lora_id, list) else lora_id
+        expected_output = (
+            EXPECTED_LORA_OUTPUT[i]
+            if req_lora_id is not None
+            else EXPECTED_BASE_MODEL_OUTPUT[i]
+        )
+        assert generated_texts[i].startswith(expected_output)


 def test_olmoe_lora(olmoe_lora_files):
@@ -75,6 +97,20 @@ def test_olmoe_lora(olmoe_lora_files):
     generate_and_test(llm, olmoe_lora_files, lora_id=2)


+def test_olmoe_lora_mixed(olmoe_lora_files):
+    llm = vllm.LLM(
+        MODEL_PATH,
+        max_model_len=1024,
+        enable_lora=True,
+        max_loras=4,
+        enforce_eager=True,
+        trust_remote_code=True,
+        enable_chunked_prefill=True,
+    )
+
+    generate_and_test(llm, olmoe_lora_files, lora_id=[1, None, 3, None])
+
+
 @multi_gpu_test(num_gpus=2)
 def test_olmoe_lora_tp2(olmoe_lora_files):
     llm = vllm.LLM(
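With generate_and_test accepting a per-prompt list, the new test_olmoe_lora_mixed drives a single batch that mixes LoRA and base-model requests; None entries skip LoRA and are checked against EXPECTED_BASE_MODEL_OUTPUT. The calling pattern it relies on looks roughly like this (MODEL_PATH, the prompts, and lora_path stand in for the test's constants and fixture):

import vllm
from vllm.lora.request import LoRARequest

llm = vllm.LLM(MODEL_PATH, enable_lora=True, max_loras=4, enforce_eager=True)
prompts = ["..."] * 4  # the test builds these from PROMPT_TEMPLATE
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=64)

# One entry per prompt: prompts 0 and 2 go through adapters 1 and 3,
# prompts 1 and 3 run against the base model.
per_prompt_loras = [1, None, 3, None]
lora_requests = [
    LoRARequest(str(i), i, lora_path) if i is not None else None
    for i in per_prompt_loras
]
outputs = llm.generate(prompts, sampling_params, lora_request=lora_requests)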
@@ -1823,6 +1823,8 @@ def moe_lora_align_block_size(
     sorted_token_ids: torch.Tensor,
     experts_ids: torch.Tensor,
     num_tokens_post_pad: torch.Tensor,
+    adapter_enabled: torch.Tensor,
+    lora_ids: torch.Tensor,
 ) -> None:
     torch.ops._moe_C.moe_lora_align_block_size(
         topk_ids,
@@ -1835,6 +1837,8 @@ def moe_lora_align_block_size(
         sorted_token_ids,
         experts_ids,
         num_tokens_post_pad,
+        adapter_enabled,
+        lora_ids,
     )
@@ -111,6 +111,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
                 config["BLOCK_SIZE_M"],
                 self.base_layer.local_num_experts,
                 max_loras,
+                self.adapter_enabled,
                 expert_map,
             )

@@ -138,6 +139,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
                 max_lora_rank,
                 top_k,
                 config,
+                self.adapter_enabled,
             )

             result = func(*args, **kwargs)
@@ -196,6 +198,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
                 max_lora_rank,
                 top_k,
                 config,
+                self.adapter_enabled,
                 True,
             )

@@ -227,6 +230,10 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
     ) -> None:
         """Initializes lora matrices."""

+        self.adapter_enabled = torch.tensor(
+            [0] * (max_loras + 1), dtype=torch.int, device=self.device
+        )
+
         self.w1_lora_a_stacked = torch.zeros(
             (
                 max_loras,
@@ -313,6 +320,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
         self.w3_lora_b_stacked[index] = 0
         self.w2_lora_a_stacked[index] = 0
         self.w2_lora_b_stacked[index] = 0
+        self.adapter_enabled[index] = 0

     def set_lora(
         self,
@@ -322,8 +330,9 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
         embeddings_tensor: torch.Tensor | None,
         bias: torch.Tensor | None = None,
     ):
-        self.reset_lora(index)
         """Overwrites lora tensors at index."""
+        self.reset_lora(index)
+        self.adapter_enabled[index] = 1
         for eid in range(len(lora_a) // 3):
             w1_lora_a = lora_a[eid * 3]
             w2_lora_a = lora_a[eid * 3 + 1]
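Within FusedMoEWithLoRA the gating tensor is a per-slot integer flag that lives for the lifetime of the layer: every slot starts at zero in create_lora_weights, set_lora flips the slot to 1, and reset_lora flips it back to 0. A stripped-down sketch of that lifecycle (the real class also manages the stacked LoRA weight tensors; this stand-in class is hypothetical):

import torch

class AdapterFlags:
    """Minimal stand-in for the adapter_enabled bookkeeping in FusedMoEWithLoRA."""

    def __init__(self, max_loras: int, device: str = "cpu") -> None:
        # One flag per LoRA slot; the extra slot mirrors the layer's sizing.
        self.adapter_enabled = torch.zeros(max_loras + 1, dtype=torch.int, device=device)

    def set_lora(self, index: int) -> None:
        self.reset_lora(index)
        self.adapter_enabled[index] = 1  # adapter at this slot participates

    def reset_lora(self, index: int) -> None:
        self.adapter_enabled[index] = 0  # kernels will early-exit for this slot

flags = AdapterFlags(max_loras=4)
flags.set_lora(2)
print(flags.adapter_enabled.tolist())  # [0, 0, 1, 0, 0]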
@@ -54,6 +54,8 @@ def _fused_moe_lora_kernel(
     EM,
     num_valid_tokens,
     num_experts,
+    lora_ids,
+    adapter_enabled,
     # The stride variables represent how much to increase the ptr by when
     # moving by 1 element in a particular dimension. E.g. `stride_am` is
     # how much to increase `a_ptr` by to get the element one row down
@@ -84,6 +86,11 @@ def _fused_moe_lora_kernel(
     pid = tl.program_id(axis=0)
     slice_id = tl.program_id(axis=1)
     lora_idx = tl.program_id(axis=2)
+    lora_id = tl.load(lora_ids + lora_idx)
+    moe_enabled = tl.load(adapter_enabled + lora_id)
+    if lora_id == -1 or moe_enabled == 0:
+        # Early exit for the no-lora case.
+        return
     max_loras = tl.num_programs(axis=2)
     grid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)

@@ -100,12 +107,12 @@ def _fused_moe_lora_kernel(
     pid_m = first_pid_m + ((pid_m_n % num_pid_in_group) % group_size_m)
     pid_n = (pid_m_n % num_pid_in_group) // group_size_m

-    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr + lora_idx)
+    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr + lora_id)
     if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
         return

     # get the expert_id to process curr shard
-    ind = lora_idx * stride_el + pid_m
+    ind = lora_id * stride_el + pid_m
     expert_id = tl.load(expert_ids_ptr + ind, ind < max_loras * stride_el, -1)
     if expert_id == -1:
         return
@@ -119,7 +126,7 @@ def _fused_moe_lora_kernel(
     offs_k = pid_sk * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)

     offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
-    token_ind = stride_tl * lora_idx + offs_token_id
+    token_ind = stride_tl * lora_id + offs_token_id
     offs_token = tl.load(
         sorted_token_ids_ptr + token_ind, token_ind < max_loras * stride_tl, 0
     )
@@ -132,7 +139,7 @@ def _fused_moe_lora_kernel(

     b_ptrs = (
         cur_b_ptr
-        + lora_idx * stride_bl
+        + lora_id * stride_bl
         + expert_id * stride_be
         + offs_k[:, None] * stride_bk
         + offs_bn[None, :] * stride_bn
@@ -184,6 +191,8 @@ def _fused_moe_lora(
     num_tokens_post_padded: torch.Tensor,  # (max_loras, )
     max_lora_rank: int,
     top_k_num: int,
+    lora_ids: torch.Tensor,
+    adapter_enabled: torch.Tensor,
     block_size_m: int,
     block_size_n: int,
     block_size_k: int,
@@ -234,7 +243,7 @@ def _fused_moe_lora(
     num_tokens = M * top_k_num
     w1_output_dim_size = w1_lora_b_stacked.shape[2]

-    lora_intermediate_cache1 = torch.empty(
+    lora_intermediate_cache1 = torch.zeros(
         (num_slices * M * top_k_num * (max_lora_rank + w1_output_dim_size)),
         dtype=output.dtype,
         device=device,
@@ -272,6 +281,8 @@ def _fused_moe_lora(
         EM,
         num_tokens,
         num_experts,
+        lora_ids,
+        adapter_enabled,
         qcurr_hidden_states.stride(0),
         qcurr_hidden_states.stride(1),
         w1_lora_a_stacked.stride(0),
@@ -319,6 +330,8 @@ def _fused_moe_lora(
         EM,
         num_tokens,
         num_experts,
+        lora_ids,
+        adapter_enabled,
         a_intermediate_cache1.stride(0),
         a_intermediate_cache1.stride(1),
         w1_lora_b_stacked.stride(0),
@@ -352,6 +365,8 @@ def _fused_moe_lora_fake(
     num_tokens_post_padded: torch.Tensor,
     max_lora_rank: int,
     top_k_num: int,
+    lora_ids: torch.Tensor,
+    adapter_enabled: torch.Tensor,
     block_size_m: int,
     block_size_n: int,
     block_size_k: int,
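The Triton kernel applies the same guard on its third grid axis: program axis 2 enumerates LoRA slots, the slot index is translated to an adapter id via lora_ids, and the program returns before any pointer arithmetic when the slot is empty or the adapter is disabled. A cut-down, self-contained sketch of just that guard, not the real GEMM kernel (the masked load is added here so the -1 sentinel never indexes the flag array):

import torch
import triton
import triton.language as tl


@triton.jit
def _early_exit_demo(lora_ids_ptr, adapter_enabled_ptr, hits_ptr):
    lora_idx = tl.program_id(axis=0)            # slot index on the LoRA grid axis
    lora_id = tl.load(lora_ids_ptr + lora_idx)  # slot -> adapter id
    enabled = tl.load(adapter_enabled_ptr + lora_id, mask=lora_id >= 0, other=0)
    if lora_id == -1 or enabled == 0:
        # Early exit for the no-lora / disabled-adapter case.
        return
    tl.store(hits_ptr + lora_idx, 1)            # only surviving programs write


lora_ids = torch.tensor([0, -1, 2, 3], dtype=torch.int32, device="cuda")
adapter_enabled = torch.tensor([1, 1, 0, 1], dtype=torch.int32, device="cuda")
hits = torch.zeros(4, dtype=torch.int32, device="cuda")
_early_exit_demo[(4,)](lora_ids, adapter_enabled, hits)
print(hits.tolist())  # [1, 0, 0, 1]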
@@ -456,6 +456,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
         block_size: int,
         num_experts: int,
         max_loras: int,
+        adapter_enabled: torch.Tensor,
         expert_map: torch.Tensor | None = None,
         pad_sorted_ids: bool = False,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -479,6 +480,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
         max_lora_rank: int,
         top_k_num: int,
         config,
+        adapter_enabled: torch.Tensor,
         mul_routed_weight=False,
     ):
         """
@@ -305,6 +305,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
         block_size: int,
         num_experts: int,
         max_loras: int,
+        adapter_enabled: torch.Tensor,
         expert_map: torch.Tensor | None = None,
         pad_sorted_ids: bool = False,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -331,7 +332,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
             (max_loras), dtype=torch.int32, device=topk_ids.device
         )

-        (token_lora_mapping, _, _, _, _, _) = self.token_mapping_meta.meta_args(
+        (token_lora_mapping, _, _, _, lora_ids, _) = self.token_mapping_meta.meta_args(
             num_tokens
         )

@@ -346,6 +347,8 @@ class PunicaWrapperGPU(PunicaWrapperBase):
            sorted_ids,
            expert_ids,
            num_tokens_post_pad,
+           adapter_enabled,
+           lora_ids,
         )
         if expert_map is not None:
             expert_ids = expert_map[expert_ids]
@@ -365,11 +368,13 @@ class PunicaWrapperGPU(PunicaWrapperBase):
         max_lora_rank: int,
         top_k_num: int,
         config,
+        adapter_enabled: torch.Tensor,
         mul_routed_weight=False,
     ):
         """
         Performs a fused forward computation for LoRA of Mixture-of-Experts (MoE) layer.
         """
+        (_, _, _, _, lora_ids, _) = self.token_mapping_meta.meta_args(x.size(0))
         fused_moe_lora(
             y,
             x,
@@ -381,6 +386,8 @@ class PunicaWrapperGPU(PunicaWrapperBase):
             num_tokens_post_padded,
             max_lora_rank,
             top_k_num,
+            lora_ids,
+            adapter_enabled,
             config["BLOCK_SIZE_M"],
             config["BLOCK_SIZE_N"],
             config["BLOCK_SIZE_K"],
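Taken together, the GPU Punica wrapper now pulls the per-slot adapter ids out of token_mapping_meta.meta_args, hands them to moe_lora_align_block_size alongside the layer's adapter_enabled flags, and forwards the same pair into fused_moe_lora, so both the alignment pass and the fused LoRA GEMM stages skip slots that hold no active adapter for the current batch.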