[Feature][Kernel] FusedMoE LoRA (#21229)

Signed-off-by: wuchen <cntryroa@gmail.com>
Signed-off-by: banjuede <lmklhc@163.com>
Signed-off-by: Chen Wu <cntryroa@gmail.com>
Signed-off-by: Danielle Robinson <dmmaddix@amazon.com>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: bk-201 <joy25810@foxmail.com>
Co-authored-by: wuchen <wuchen@zetyun.com>
Co-authored-by: Nathan Van Gheem <vangheem@gmail.com>
Co-authored-by: banjuede <lmklhc@163.com>
Co-authored-by: Danielle Robinson <dmmaddix@amazon.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Co-authored-by: bk-201 <joy25810@foxmail.com>
Chen Wu 2025-10-21 11:01:37 +08:00 committed by GitHub
parent 3ada34f9cb
commit 5f6cbf60d6
28 changed files with 2084 additions and 55 deletions


@ -384,7 +384,12 @@ steps:
--num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT \
--ignore=lora/test_chatglm3_tp.py \
--ignore=lora/test_llama_tp.py \
--ignore=lora/test_llm_with_multi_loras.py
--ignore=lora/test_llm_with_multi_loras.py \
--ignore=lora/test_olmoe_tp.py \
--ignore=lora/test_deepseekv2_tp.py \
--ignore=lora/test_gptoss.py \
--ignore=lora/test_qwen3moe_tp.py
parallelism: 4
- label: PyTorch Compilation Unit Tests # 15min
@ -1065,6 +1070,7 @@ steps:
- pytest -v -s -x lora/test_chatglm3_tp.py
- pytest -v -s -x lora/test_llama_tp.py
- pytest -v -s -x lora/test_llm_with_multi_loras.py
- pytest -v -s -x lora/test_olmoe_tp.py
- label: Weight Loading Multiple GPU Test # 33min


@ -883,6 +883,7 @@ target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
set(VLLM_MOE_EXT_SRC
"csrc/moe/torch_bindings.cpp"
"csrc/moe/moe_align_sum_kernels.cu"
"csrc/moe/moe_lora_align_sum_kernels.cu"
"csrc/moe/topk_softmax_kernels.cu")
if(VLLM_GPU_LANG STREQUAL "CUDA")


@ -0,0 +1,173 @@
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/ATen.h>
#include <ATen/cuda/Atomic.cuh>
#include "../cuda_compat.h"
#include "../dispatch_utils.h"
#include "core/math.hpp"
namespace {
__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row,
int32_t col) {
return row * total_col + col;
}
} // namespace
// TODO: Refactor common parts with moe_align_sum_kernels
template <typename scalar_t, typename token_cnts_t>
__global__ void moe_lora_align_sum_kernel(
scalar_t* __restrict__ topk_ids, int32_t* token_lora_mapping,
int64_t block_size, int num_experts, int max_loras, size_t numel,
int max_num_tokens_padded, int max_num_m_blocks,
int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
int topk_num, int32_t* total_tokens_post_pad) {
const size_t tokens_per_thread = div_ceil(numel, blockDim.x);
const size_t start_idx = threadIdx.x * tokens_per_thread;
int lora_id = blockIdx.x;
extern __shared__ int32_t shared_mem[];
int32_t* cumsum = shared_mem;
token_cnts_t* tokens_cnts = (token_cnts_t*)(shared_mem + num_experts + 1);
// Initialize sorted_token_ids with numel
for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += blockDim.x) {
sorted_token_ids[lora_id * max_num_tokens_padded + it] = numel;
}
// Initialize expert_ids with -1
for (size_t it = threadIdx.x; it < max_num_m_blocks; it += blockDim.x) {
expert_ids[lora_id * max_num_m_blocks + it] = -1;
}
// Initialize total_tokens_post_pad with 0
if (threadIdx.x == 0) {
total_tokens_post_pad[lora_id] = 0;
}
for (int i = 0; i < num_experts; ++i) {
tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
}
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
int mask = token_lora_mapping[i / topk_num] == lora_id;
int idx = index(num_experts, threadIdx.x + 1, topk_ids[i]);
tokens_cnts[idx] += mask;
}
__syncthreads();
// For each expert we accumulate the token counts from the different threads.
if (threadIdx.x < num_experts) {
tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
for (int i = 1; i <= blockDim.x; ++i) {
tokens_cnts[index(num_experts, i, threadIdx.x)] +=
tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
}
}
__syncthreads();
// We accumulate the token counts of all experts in thread 0.
if (threadIdx.x == 0) {
cumsum[0] = 0;
for (int i = 1; i <= num_experts; ++i) {
cumsum[i] = cumsum[i - 1] +
div_ceil(tokens_cnts[index(num_experts, blockDim.x, i - 1)],
block_size) *
block_size;
}
total_tokens_post_pad[lora_id] = static_cast<int32_t>(cumsum[num_experts]);
}
__syncthreads();
/**
* For each expert, each thread processes the tokens of the corresponding
* blocks and stores the corresponding expert_id for each block.
*/
if (threadIdx.x < num_experts) {
for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
i += block_size) {
expert_ids[index(max_num_m_blocks, lora_id, i / block_size)] =
threadIdx.x;
}
}
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
int32_t expert_id = topk_ids[i];
/** The cumsum[expert_id] stores the starting index of the tokens that the
* expert with expert_id needs to process, and
* tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens
* processed by the expert with expert_id within the current thread's token
* shard.
*/
int32_t rank_post_pad =
tokens_cnts[index(num_experts, threadIdx.x, expert_id)] +
cumsum[expert_id];
int mask = (int)token_lora_mapping[i / topk_num] == lora_id;
atomicAdd(
&sorted_token_ids[index(max_num_tokens_padded, lora_id, rank_post_pad)],
(i - numel) * mask);
tokens_cnts[index(num_experts, threadIdx.x, expert_id)] += mask;
}
}
void moe_lora_align_block_size(torch::Tensor topk_ids,
torch::Tensor token_lora_mapping,
int64_t num_experts, int64_t block_size,
int64_t max_loras,
torch::Tensor sorted_token_ids,
torch::Tensor expert_ids,
torch::Tensor num_tokens_post_pad) {
const int topk_num = topk_ids.size(1);
int max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1);
TORCH_CHECK(block_size > 0, "block_size should be greater than 0. ");
max_num_tokens_padded = round_to_next_multiple_of(
max_num_tokens_padded, static_cast<int>(block_size));
int max_num_m_blocks = div_ceil(max_num_tokens_padded, block_size);
int device_max_shared_mem;
auto dev = topk_ids.get_device();
cudaDeviceGetAttribute(&device_max_shared_mem,
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int32_t num_thread = max((int32_t)num_experts, 128); // WARP_SIZE,
TORCH_CHECK(num_thread <= 1024,
"num_thread must be less than 1024, "
"and fallback is not implemented yet.");
const int32_t shared_mem = (num_thread + 1) * num_experts * sizeof(int32_t) +
(num_experts + 1) * sizeof(int32_t);
if (shared_mem > device_max_shared_mem) {
TORCH_CHECK(false,
"Shared memory usage exceeds device limit, and global memory "
"fallback is not implemented yet.");
}
VLLM_DISPATCH_INTEGRAL_TYPES(
topk_ids.scalar_type(), "moe_lora_align_sum_kernel", [&] {
dim3 blockDim(num_thread);
auto kernel = moe_lora_align_sum_kernel<scalar_t, int32_t>;
AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
(void*)kernel, shared_mem));
kernel<<<max_loras, blockDim, shared_mem, stream>>>(
topk_ids.data_ptr<scalar_t>(),
token_lora_mapping.data_ptr<int32_t>(), block_size, num_experts,
max_loras, topk_ids.numel(), max_num_tokens_padded,
max_num_m_blocks, sorted_token_ids.data_ptr<int32_t>(),
expert_ids.data_ptr<int32_t>(), topk_num,
num_tokens_post_pad.data_ptr<int32_t>());
});
}
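For orientation, a minimal host-side sketch of how this kernel is driven (it mirrors the `moe_lora_align_block_size` Python wrapper and the unit test added elsewhere in this diff; the concrete sizes here are illustrative assumptions):

import torch
from vllm import _custom_ops as ops

num_tokens, top_k, num_experts, max_loras, block_size = 8, 2, 4, 2, 16

# Routing results and per-token LoRA assignment (int32, on the GPU), as the kernel expects.
topk_ids = torch.randint(0, num_experts, (num_tokens, top_k), dtype=torch.int32, device="cuda")
token_lora_mapping = torch.randint(0, max_loras, (num_tokens,), dtype=torch.int32, device="cuda")

# Worst-case padded sizes, matching the host-side math in moe_lora_align_block_size above.
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
max_num_tokens_padded = ((max_num_tokens_padded + block_size - 1) // block_size) * block_size
max_num_m_blocks = (max_num_tokens_padded + block_size - 1) // block_size

# One flattened slot range per LoRA adapter.
sorted_token_ids = torch.empty(max_loras * max_num_tokens_padded, dtype=torch.int32, device="cuda")
expert_ids = torch.empty(max_loras * max_num_m_blocks, dtype=torch.int32, device="cuda")
num_tokens_post_pad = torch.empty(max_loras, dtype=torch.int32, device="cuda")

ops.moe_lora_align_block_size(topk_ids, token_lora_mapping, num_experts, block_size,
                              max_loras, sorted_token_ids, expert_ids, num_tokens_post_pad)

# Per LoRA l: sorted_token_ids[l * max_num_tokens_padded : (l + 1) * max_num_tokens_padded]
# holds block-aligned token slots (padding slots keep the initial value topk_ids.numel()),
# expert_ids holds one expert per block (-1 for empty blocks), and num_tokens_post_pad[l]
# is the padded token count for that adapter.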


@ -20,6 +20,13 @@ void batched_moe_align_block_size(int64_t max_tokens_per_batch,
torch::Tensor expert_ids,
torch::Tensor num_tokens_post_pad);
void moe_lora_align_block_size(torch::Tensor topk_ids,
torch::Tensor token_lora_mapping,
int64_t num_experts, int64_t block_size,
int64_t max_loras,
torch::Tensor sorted_token_ids,
torch::Tensor expert_ids,
torch::Tensor num_tokens_post_pad);
#ifndef USE_ROCM
torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
torch::Tensor b_qweight, torch::Tensor b_scales,


@ -33,6 +33,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
m.impl("batched_moe_align_block_size", torch::kCUDA,
&batched_moe_align_block_size);
// Aligning the number of tokens to be processed by each expert such
// that it is divisible by the block size.
m.def(
"moe_lora_align_block_size(Tensor topk_ids,"
" Tensor token_lora_mapping,"
" int num_experts,"
" int block_size, int max_loras, "
" Tensor !sorted_token_ids,"
" Tensor !experts_ids,"
" Tensor !num_tokens_post_pad) -> () ");
m.impl("moe_lora_align_block_size", torch::kCUDA, &moe_lora_align_block_size);
#ifndef USE_ROCM
m.def(
"moe_wna16_gemm(Tensor input, Tensor! output, Tensor b_qweight, "


@ -230,6 +230,26 @@ def tinyllama_lora_files():
return snapshot_download(repo_id="jashing/tinyllama-colorist-lora")
@pytest.fixture(scope="session")
def deepseekv2_lora_files():
return snapshot_download(repo_id="wuchen01/DeepSeek-V2-Lite-Chat-All-LoRA")
@pytest.fixture(scope="session")
def gptoss20b_lora_files():
return snapshot_download(repo_id="LevinZheng/gpt-oss-20b-lora-adapter")
@pytest.fixture(scope="session")
def qwen3moe_lora_files():
return snapshot_download(repo_id="jeeejeee/qwen3-moe-text2sql-spider")
@pytest.fixture(scope="session")
def olmoe_lora_files():
return snapshot_download(repo_id="jeeejeee/olmoe-instruct-text2sql-spider")
@pytest.fixture
def reset_default_device():
"""


@ -0,0 +1,97 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import vllm
from vllm.lora.request import LoRARequest
from ..utils import multi_gpu_test
MODEL_PATH = "deepseek-ai/DeepSeek-V2-Lite-Chat"
PROMPT_TEMPLATE = "<begin▁of▁sentence>You are a helpful assistant.\n\nUser: {context}\n\nAssistant:" # noqa: E501
def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int):
prompts = [
PROMPT_TEMPLATE.format(context="Who are you?"),
]
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=64)
outputs = llm.generate(
prompts,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None,
)
# Print the outputs.
generated_texts: list[str] = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
# return generated_texts
expected_lora_output = [
"I am \u5f20\u5b50\u8c6a, an AI assistant developed by \u9648\u58eb\u680b.", # noqa: E501
]
for i in range(len(expected_lora_output)):
assert generated_texts[i].startswith(expected_lora_output[i])
def test_deepseekv2_lora(deepseekv2_lora_files):
# We enable enforce_eager=True here to reduce VRAM usage for lora-test CI,
# Otherwise, the lora-test will fail due to CUDA OOM.
llm = vllm.LLM(
MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4,
enforce_eager=True,
trust_remote_code=True,
enable_chunked_prefill=True,
)
generate_and_test(llm, deepseekv2_lora_files, 1)
def test_deepseekv2(deepseekv2_lora_files):
# We enable enforce_eager=True here to reduce VRAM usage for lora-test CI,
# Otherwise, the lora-test will fail due to CUDA OOM.
llm = vllm.LLM(
MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4,
enforce_eager=True,
trust_remote_code=True,
)
generate_and_test(llm, deepseekv2_lora_files, 1)
@multi_gpu_test(num_gpus=2)
def test_deepseekv2_tp2(deepseekv2_lora_files):
# We enable enforce_eager=True here to reduce VRAM usage for lora-test CI,
# Otherwise, the lora-test will fail due to CUDA OOM.
llm = vllm.LLM(
MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4,
enforce_eager=True,
trust_remote_code=True,
tensor_parallel_size=2,
)
generate_and_test(llm, deepseekv2_lora_files, 2)
@multi_gpu_test(num_gpus=4)
def test_deepseekv2_tp4(deepseekv2_lora_files):
# We enable enforce_eager=True here to reduce VRAM usage for lora-test CI,
# Otherwise, the lora-test will fail due to CUDA OOM.
llm = vllm.LLM(
MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4,
enforce_eager=True,
trust_remote_code=True,
tensor_parallel_size=4,
)
generate_and_test(llm, deepseekv2_lora_files, 2)


@ -0,0 +1,287 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
import pytest
import torch
from vllm import _custom_ops as ops
from vllm.lora.ops.triton_ops import fused_moe_lora
from vllm.platforms import current_platform
@pytest.fixture(autouse=True)
def reset_device(reset_default_device):
pass
def round_up(x, base):
return ((x + base - 1) // base) * base
def CEILDIV(x, y):
return (x + y - 1) // y
def assign_loras_to_tokens(num_tokens: int, num_sequences: int, max_loras: int):
"""
Split `num_tokens` into `num_sequences` sequences.
Each sequence randomly selects 1 LoRA index from [0, max_loras),
and all tokens in that sequence are assigned this LoRA index.
Args:
num_tokens (int): Total number of tokens.
num_sequences (int): Number of sequences to split the tokens into.
max_loras (int): Total number of available LoRA modules.
Returns:
torch.Tensor: 1D tensor of shape [num_tokens], where each value
is the LoRA index assigned to that token.
"""
assert num_sequences > 0 and max_loras > 0
assert num_tokens >= num_sequences, "num_tokens must be >= num_sequences"
# Compute token distribution per sequence (distribute remainder evenly)
tokens_per_seq = num_tokens // num_sequences
remainder = num_tokens % num_sequences
token_lora_mapping = torch.empty(num_tokens, dtype=torch.int32)
start = 0
for seq_idx in range(num_sequences):
# Determine the token range for this sequence
end = start + tokens_per_seq + (1 if seq_idx < remainder else 0)
# Randomly select one LoRA ID for this sequence
lora_id = random.randint(0, max_loras - 1)
# Assign the same LoRA ID to all tokens in this sequence
token_lora_mapping[start:end] = lora_id
start = end
return token_lora_mapping
def assign_experts_to_tokens(num_tokens: int, num_experts: int, top_k_num: int):
"""
For each token, randomly select `top_k_num` distinct experts out of `num_experts`,
and assign normalized random weights that sum to 1.
Args:
num_tokens (int): Total number of tokens.
num_experts (int): Total number of available experts.
top_k_num (int): Number of experts to select per token.
Returns:
expert_indices (torch.Tensor): shape [num_tokens, top_k_num],
expert index for each token.
expert_weights (torch.Tensor): shape [num_tokens, top_k_num],
normalized weights (sum = 1 per row).
"""
assert top_k_num <= num_experts, "top_k_num must be <= num_experts"
# Randomly select top_k_num distinct experts for each token
expert_indices = torch.empty((num_tokens, top_k_num), dtype=torch.int32)
for i in range(num_tokens):
# Randomly choose unique expert indices
selected = torch.randperm(num_experts)[:top_k_num]
expert_indices[i] = selected
# Generate random weights and normalize along dim=1
expert_weights = torch.rand((num_tokens, top_k_num), dtype=torch.float32)
expert_weights = expert_weights / expert_weights.sum(dim=1, keepdim=True)
return expert_indices, expert_weights
def sample_data(
num_tokens: int,
num_sequences: int,
max_loras: int,
num_experts: int,
top_k_num: int,
):
topk_ids, topk_weights = assign_experts_to_tokens(
num_tokens, num_experts, top_k_num
)
token_lora_mapping = assign_loras_to_tokens(num_tokens, num_sequences, max_loras)
return topk_ids, topk_weights, token_lora_mapping
def use_fused_moe_lora_kernel(
topk_ids,
topk_weights,
token_lora_mapping,
max_lora_rank,
top_k_num,
lora_a_stacked,
lora_b_stacked,
hidden_states,
output,
max_loras,
num_experts,
block_size,
):
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
max_num_m_blocks = CEILDIV(max_num_tokens_padded, block_size)
# init output tensors
sorted_token_ids = torch.empty(
(max_loras * max_num_tokens_padded,),
dtype=torch.int32,
)
expert_ids = torch.empty((max_loras * max_num_m_blocks,), dtype=torch.int32)
num_tokens_post_padded = torch.empty((max_loras,), dtype=torch.int32)
# call kernel
ops.moe_lora_align_block_size(
topk_ids,
token_lora_mapping,
num_experts,
block_size,
max_loras,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
)
config = {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
}
mul_routed_weight = False
expert_ids = expert_ids.view(max_loras, -1)
sorted_token_ids = sorted_token_ids.view(max_loras, -1)
fused_moe_lora(
output,
hidden_states,
lora_a_stacked,
lora_b_stacked,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
max_lora_rank,
top_k_num,
config["BLOCK_SIZE_M"],
config["BLOCK_SIZE_N"],
config["BLOCK_SIZE_K"],
config["GROUP_SIZE_M"],
mul_routed_weight,
)
return output
def use_torch(
hidden_states,
token_lora_mapping,
topk_ids,
lora_a_stacked,
lora_b_stacked,
top_k_num,
):
outputs = []
for i in range(hidden_states.shape[0]):
lora_idx = token_lora_mapping[i]
expert_ids = topk_ids[i]
lora_a = lora_a_stacked[0][lora_idx][expert_ids]
lora_b = lora_b_stacked[0][lora_idx][expert_ids]
tensors = [
hidden_states[i] @ lora_a[x].T @ lora_b[x].T for x in range(top_k_num)
]
outputs.append(torch.stack(tensors, dim=0))
return torch.stack(outputs, dim=0)
@pytest.mark.parametrize("num_tokens", [100])
@pytest.mark.parametrize("top_k_num", [6, 12])
@pytest.mark.parametrize("num_experts", [64])
@pytest.mark.parametrize("max_loras", [4, 6, 16])
@pytest.mark.parametrize("N", [1408])
@pytest.mark.parametrize("K", [2048])
@pytest.mark.parametrize("max_lora_rank", [16, 32, 64])
@pytest.mark.parametrize("block_size", [16])
def test_fused_moe_lora_kernel(
num_tokens,
top_k_num,
num_experts,
max_loras,
N,
K,
max_lora_rank,
block_size,
):
torch.set_default_device("cuda:0")
current_platform.seed_everything(42)
# the number of randomly generated sentences.
num_sequences = 10
# generate data
topk_ids, topk_weights, token_lora_mapping = sample_data(
num_tokens, num_sequences, max_loras, num_experts, top_k_num
)
# init lora weights
lora_a_stacked = [
torch.rand(
(
max_loras,
num_experts,
max_lora_rank,
K,
),
dtype=torch.bfloat16,
)
]
lora_b_stacked = [
torch.rand(
(
max_loras,
num_experts,
N,
max_lora_rank,
),
dtype=torch.bfloat16,
)
]
hidden_states = torch.rand(
(
num_tokens,
K,
),
dtype=torch.bfloat16,
)
# fused_moe_lora_kernel output
output = torch.zeros((num_tokens, top_k_num, N), dtype=torch.bfloat16)
use_fused_moe_lora_kernel(
topk_ids,
topk_weights,
token_lora_mapping,
max_lora_rank,
top_k_num,
lora_a_stacked,
lora_b_stacked,
hidden_states,
output,
max_loras,
num_experts,
block_size,
)
# pytorch output
output2 = use_torch(
hidden_states,
token_lora_mapping,
topk_ids,
lora_a_stacked,
lora_b_stacked,
top_k_num,
)
torch.testing.assert_close(output, output2, atol=1e-1, rtol=1e-1)

tests/lora/test_gptoss.py (new file, 52 lines)

@ -0,0 +1,52 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import vllm
from vllm.lora.request import LoRARequest
MODEL_PATH = "openai/gpt-oss-20b"
PROMPT_TEMPLATE = "<begin▁of▁sentence>You are a helpful assistant.\n\nUser: {context}\n\nAssistant:" # noqa: E501
def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]:
prompts = [
PROMPT_TEMPLATE.format(context="Who are you?"),
]
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=64)
outputs = llm.generate(
prompts,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None,
)
# Print the outputs.
generated_texts: list[str] = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
return generated_texts
# FIXME: Load gpt-oss adapter
def test_gptoss20b_lora(gptoss20b_lora_files):
# We enable enforce_eager=True here to reduce VRAM usage for lora-test CI,
# Otherwise, the lora-test will fail due to CUDA OOM.
llm = vllm.LLM(
MODEL_PATH,
enable_lora=True,
max_loras=1,
trust_remote_code=True,
)
expected_lora_output = [
"I am an AI language model developed by OpenAI. "
"I am here to help you with any questions or "
"tasks you may have."
]
output1 = do_sample(llm, gptoss20b_lora_files, lora_id=1)
print(output1)
for i in range(len(expected_lora_output)):
assert output1[i].startswith(expected_lora_output[i])


@ -0,0 +1,90 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
import pytest
import torch
from vllm import _custom_ops as ops
def round_up(x, base):
return ((x + base - 1) // base) * base
def CEILDIV(x, y):
return (x + y - 1) // y
def sample_data(num_experts, max_loras, num_tokens, topk_num):
topk_ids = torch.zeros((num_tokens, topk_num), dtype=torch.int32)
token_lora_mapping = torch.zeros((num_tokens,), dtype=torch.int32)
for i in range(num_tokens):
pool = list(range(num_experts))
random.shuffle(pool)
for j in range(topk_num):
topk_ids[i, j] = pool[j]
token_lora_mapping[i] = random.randint(0, max_loras - 1)
return topk_ids.to("cuda"), token_lora_mapping.to("cuda")
@pytest.mark.parametrize("num_tokens", [100, 200, 1024, 4096]) # 81920
@pytest.mark.parametrize("topk_num", [6])
@pytest.mark.parametrize("num_experts", [64, 128])
@pytest.mark.parametrize("max_loras", [2, 32])
@pytest.mark.parametrize("block_size", [16])
def test_moe_lora_align_block_size(
num_tokens, topk_num, num_experts, max_loras, block_size
):
# sample data
random.seed(1)
topk_ids, token_lora_mapping = sample_data(
num_experts, max_loras, num_tokens, topk_num
)
# compute paddings
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
max_num_m_blocks = CEILDIV(max_num_tokens_padded, block_size)
# init output tensors
sorted_token_ids = torch.full(
(max_loras * max_num_tokens_padded,),
topk_ids.numel(),
dtype=torch.int32,
device="cuda",
)
expert_ids = torch.full(
(max_loras * max_num_m_blocks,), num_experts, dtype=torch.int32, device="cuda"
)
num_tokens_post_pad = torch.zeros((max_loras,), dtype=torch.int32, device="cuda")
# call kernel
ops.moe_lora_align_block_size(
topk_ids,
token_lora_mapping,
num_experts,
block_size,
max_loras,
sorted_token_ids,
expert_ids,
num_tokens_post_pad,
)
# verify values
expert_ids = expert_ids.view(max_loras, -1)
sorted_token_ids = sorted_token_ids.view(max_loras, -1, block_size)
for lora_idx in range(max_loras):
for token_idx in range(sorted_token_ids.size(1)):
block = sorted_token_ids[lora_idx][token_idx]
indices = block[block != topk_ids.numel()]
if indices.numel() > 0:
expert_id = expert_ids[lora_idx][token_idx]
assert torch.all(topk_ids.view(-1)[indices] == expert_id)
if __name__ == "__main__":
pytest.main([__file__])

tests/lora/test_olmoe_tp.py (new file, 109 lines)

@ -0,0 +1,109 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import vllm
from vllm.lora.request import LoRARequest
from ..utils import multi_gpu_test
MODEL_PATH = "allenai/OLMoE-1B-7B-0125-Instruct"
PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.
"
##Instruction:
candidate_poll contains tables such as candidate, people. Table candidate has columns such as Candidate_ID, People_ID, Poll_Source, Date, Support_rate, Consider_rate, Oppose_rate, Unsure_rate. Candidate_ID is the primary key.
Table people has columns such as People_ID, Sex, Name, Date_of_Birth, Height, Weight. People_ID is the primary key.
The People_ID of candidate is the foreign key of People_ID of people.
###Input:
{context}
###Response:""" # noqa: E501
EXPECTED_LORA_OUTPUT = [
"SELECT count(*) FROM candidate",
"SELECT count(*) FROM candidate",
"SELECT poll_source FROM candidate GROUP BY poll_source ORDER BY count(*) DESC LIMIT 1", # noqa: E501
"SELECT poll_source FROM candidate GROUP BY poll_source ORDER BY count(*) DESC LIMIT 1", # noqa: E501
]
def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None:
prompts = [
PROMPT_TEMPLATE.format(context="How many candidates are there?"),
PROMPT_TEMPLATE.format(context="Count the number of candidates."),
PROMPT_TEMPLATE.format(
context="Which poll resource provided the most number of candidate information?" # noqa: E501
),
PROMPT_TEMPLATE.format(
context="Return the poll resource associated with the most candidates."
),
]
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=64)
outputs = llm.generate(
prompts,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None,
)
# Print the outputs.
generated_texts: list[str] = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
for i in range(len(EXPECTED_LORA_OUTPUT)):
assert generated_texts[i].startswith(EXPECTED_LORA_OUTPUT[i])
def test_olmoe_lora(olmoe_lora_files):
# We enable enforce_eager=True here to reduce VRAM usage for lora-test CI,
# Otherwise, the lora-test will fail due to CUDA OOM.
llm = vllm.LLM(
MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4,
enforce_eager=True,
trust_remote_code=True,
enable_chunked_prefill=True,
)
generate_and_test(llm, olmoe_lora_files, lora_id=1)
generate_and_test(llm, olmoe_lora_files, lora_id=2)
@multi_gpu_test(num_gpus=2)
def test_olmoe_lora_tp2(olmoe_lora_files):
llm = vllm.LLM(
MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4,
enforce_eager=True,
trust_remote_code=True,
enable_chunked_prefill=True,
tensor_parallel_size=2,
)
generate_and_test(llm, olmoe_lora_files, lora_id=1)
generate_and_test(llm, olmoe_lora_files, lora_id=2)
@multi_gpu_test(num_gpus=4)
def test_olmoe_lora_tp4(olmoe_lora_files):
llm = vllm.LLM(
MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4,
enforce_eager=True,
trust_remote_code=True,
enable_chunked_prefill=True,
tensor_parallel_size=4,
)
generate_and_test(llm, olmoe_lora_files, lora_id=1)
generate_and_test(llm, olmoe_lora_files, lora_id=2)


@ -0,0 +1,111 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import vllm
from vllm.lora.request import LoRARequest
from ..utils import multi_gpu_test
MODEL_PATH = "Qwen/Qwen3-30B-A3B"
PROMPT_TEMPLATE = """<|im_start|>user
I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.
"
##Instruction:
candidate_poll contains tables such as candidate, people. Table candidate has columns such as Candidate_ID, People_ID, Poll_Source, Date, Support_rate, Consider_rate, Oppose_rate, Unsure_rate. Candidate_ID is the primary key.
Table people has columns such as People_ID, Sex, Name, Date_of_Birth, Height, Weight. People_ID is the primary key.
The People_ID of candidate is the foreign key of People_ID of people.
###Input:
{context}
###Response:<|im_end|>
<|im_start|>assistant""" # noqa: E501
EXPECTED_LORA_OUTPUT = [
"<think>\n\n</think>\n\nSELECT count(*) FROM candidate",
"<think>\n\n</think>\n\nSELECT count(*) FROM candidate",
"<think>\n\n</think>\n\nSELECT poll_source FROM candidate GROUP BY poll_source ORDER BY count(*) DESC LIMIT 1", # noqa: E501
"<think>\n\n</think>\n\nSELECT poll_source FROM candidate GROUP BY poll_source ORDER BY count(*) DESC LIMIT 1", # noqa: E501
]
def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None:
prompts = [
PROMPT_TEMPLATE.format(context="How many candidates are there?"),
PROMPT_TEMPLATE.format(context="Count the number of candidates."),
PROMPT_TEMPLATE.format(
context="Which poll resource provided the most number of candidate information?" # noqa: E501
),
PROMPT_TEMPLATE.format(
context="Return the poll resource associated with the most candidates."
),
]
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=64)
outputs = llm.generate(
prompts,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None,
)
# Print the outputs.
generated_texts: list[str] = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
for i in range(len(EXPECTED_LORA_OUTPUT)):
assert generated_texts[i].startswith(EXPECTED_LORA_OUTPUT[i])
def test_qwen3moe_lora(qwen3moe_lora_files):
# We enable enforce_eager=True here to reduce VRAM usage for lora-test CI,
# Otherwise, the lora-test will fail due to CUDA OOM.
llm = vllm.LLM(
MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4,
enforce_eager=True,
trust_remote_code=True,
enable_chunked_prefill=True,
)
generate_and_test(llm, qwen3moe_lora_files, lora_id=1)
generate_and_test(llm, qwen3moe_lora_files, lora_id=2)
@multi_gpu_test(num_gpus=2)
def test_qwen3moe_lora_tp2(qwen3moe_lora_files):
llm = vllm.LLM(
MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4,
enforce_eager=True,
trust_remote_code=True,
enable_chunked_prefill=True,
tensor_parallel_size=2,
)
generate_and_test(llm, qwen3moe_lora_files, lora_id=1)
generate_and_test(llm, qwen3moe_lora_files, lora_id=2)
@multi_gpu_test(num_gpus=4)
def test_qwen3moe_lora_tp4(qwen3moe_lora_files):
llm = vllm.LLM(
MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4,
enforce_eager=True,
trust_remote_code=True,
enable_chunked_prefill=True,
tensor_parallel_size=4,
)
generate_and_test(llm, qwen3moe_lora_files, lora_id=1)
generate_and_test(llm, qwen3moe_lora_files, lora_id=2)


@ -1795,6 +1795,28 @@ def batched_moe_align_block_size(
)
def moe_lora_align_block_size(
topk_ids: torch.Tensor,
token_lora_mapping: torch.Tensor,
num_experts: int,
block_size: int,
max_loras: int,
sorted_token_ids: torch.Tensor,
experts_ids: torch.Tensor,
num_tokens_post_pad: torch.Tensor,
) -> None:
torch.ops._moe_C.moe_lora_align_block_size(
topk_ids,
token_lora_mapping,
num_experts,
block_size,
max_loras,
sorted_token_ids,
experts_ids,
num_tokens_post_pad,
)
def moe_wna16_gemm(
input: torch.Tensor,
output: torch.Tensor,


@ -11,6 +11,7 @@ from vllm.lora.layers.column_parallel_linear import (
QKVParallelLinearWithLoRA,
QKVParallelLinearWithShardedLoRA,
)
from vllm.lora.layers.fused_moe import FusedMoEWithLoRA
from vllm.lora.layers.logits_processor import LogitsProcessorWithLoRA
from vllm.lora.layers.replicated_linear import ReplicatedLinearWithLoRA
from vllm.lora.layers.row_parallel_linear import (
@ -36,4 +37,5 @@ __all__ = [
"RowParallelLinearWithShardedLoRA",
"ReplicatedLinearWithLoRA",
"LoRAMapping",
"FusedMoEWithLoRA",
]


@ -0,0 +1,410 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm import envs
from vllm.config.lora import LoRAConfig
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.lora.layers.base import BaseLayerWithLoRA
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG,
_get_config_dtype_str,
mxfp4_w4a16_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
modular_marlin_fused_moe,
)
from vllm.model_executor.layers.fused_moe.fused_moe import (
modular_triton_fused_moe,
try_get_optimal_moe_config,
)
from vllm.model_executor.layers.quantization.mxfp4 import Mxfp4Config
class FusedMoEWithLoRA(BaseLayerWithLoRA):
def __init__(self, base_layer: FusedMoE) -> None:
super().__init__()
self.base_layer = base_layer
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.device = base_layer.w2_weight.device
self._inject_lora_into_fused_moe()
def _inject_lora_into_fused_moe(self):
moe_state_dict = {}
top_k = self.base_layer.top_k
if self.base_layer.quant_config is None:
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
elif not isinstance(self.base_layer.quant_config, Mxfp4Config):
quant_config = self.base_layer.quant_config
else:
quant_config = mxfp4_w4a16_moe_quant_config(
w1_bias=self.base_layer.w13_bias,
w2_bias=self.base_layer.w2_bias,
w1_scale=self.base_layer.w13_weight_scale,
w2_scale=self.base_layer.w2_weight_scale,
)
m_fused_moe_fn = (
modular_triton_fused_moe(
quant_config, shared_experts=self.base_layer.shared_experts
)
if not quant_config.use_mxfp4_w4a16
else modular_marlin_fused_moe(
quant_config, shared_experts=self.base_layer.shared_experts
)
)
def fwd_decorator(layer, func):
def wrapper(*args, **kwargs):
moe_state_dict["hidden_states"] = kwargs["hidden_states"]
moe_state_dict["topk_ids"] = kwargs["topk_ids"]
moe_state_dict["topk_weights"] = kwargs["topk_weights"]
moe_state_dict["global_num_experts"] = kwargs["global_num_experts"]
moe_state_dict["expert_map"] = kwargs["expert_map"]
moe_state_dict["apply_router_weight_on_input"] = kwargs[
"apply_router_weight_on_input"
]
moe_state_dict["max_loras"] = layer.w1_lora_a_stacked.shape[0]
result = func(*args, **kwargs)
return result
return wrapper
def act_decorator(layer, func):
def wrapper(*args, **kwargs):
_, output, input = args
hidden_states = moe_state_dict["hidden_states"]
topk_weights = moe_state_dict["topk_weights"]
curr_topk_ids = moe_state_dict["topk_ids"]
global_num_experts = moe_state_dict["global_num_experts"]
expert_map = moe_state_dict["expert_map"]
max_loras = moe_state_dict["max_loras"]
config_dtype = _get_config_dtype_str(
dtype=hidden_states.dtype,
use_fp8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
)
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
num_tokens = hidden_states.size(0)
M = min(num_tokens, CHUNK_SIZE)
get_config_func = functools.partial(
try_get_optimal_moe_config,
layer.w13_weight.size(),
layer.w2_weight.size(),
top_k,
config_dtype,
block_shape=layer.quant_method.moe_quant_config.block_shape,
)
config = get_config_func(M)
(
sorted_token_ids_lora,
expert_ids_lora,
num_tokens_post_padded_lora,
) = self.punica_wrapper.moe_lora_align_block_size(
curr_topk_ids,
num_tokens,
config["BLOCK_SIZE_M"],
global_num_experts,
max_loras,
expert_map,
)
moe_state_dict["sorted_token_ids_lora"] = sorted_token_ids_lora
moe_state_dict["expert_ids_lora"] = expert_ids_lora
moe_state_dict["num_tokens_post_padded_lora"] = (
num_tokens_post_padded_lora
)
w13_lora_a_stacked = [self.w1_lora_a_stacked, self.w3_lora_a_stacked]
w13_lora_b_stacked = [self.w1_lora_b_stacked, self.w3_lora_b_stacked]
max_lora_rank = self.w1_lora_a_stacked.shape[-2]
expert_ids_lora = expert_ids_lora.view(max_loras, -1)
sorted_token_ids_lora = sorted_token_ids_lora.view(max_loras, -1)
self.punica_wrapper.add_lora_fused_moe(
input.view(-1, top_k, input.shape[-1]),
hidden_states,
w13_lora_a_stacked,
w13_lora_b_stacked,
topk_weights,
sorted_token_ids_lora,
expert_ids_lora,
num_tokens_post_padded_lora,
max_lora_rank,
top_k,
config,
)
result = func(*args, **kwargs)
moe_state_dict["intermediate_cache2"] = output
return result
return wrapper
def moe_sum_decorator(layer, func):
def wrapper(*args, **kwargs):
hidden_states = moe_state_dict["hidden_states"]
topk_weights = moe_state_dict["topk_weights"]
max_loras = moe_state_dict["max_loras"]
config_dtype = _get_config_dtype_str(
dtype=hidden_states.dtype,
use_fp8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
)
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
num_tokens = hidden_states.size(0)
M = min(num_tokens, CHUNK_SIZE)
get_config_func = functools.partial(
try_get_optimal_moe_config,
layer.w13_weight.size(),
layer.w2_weight.size(),
top_k,
config_dtype,
block_shape=layer.quant_method.moe_quant_config.block_shape,
)
config = get_config_func(M)
sorted_token_ids_lora = moe_state_dict["sorted_token_ids_lora"]
expert_ids_lora = moe_state_dict["expert_ids_lora"]
num_tokens_post_padded_lora = moe_state_dict[
"num_tokens_post_padded_lora"
]
expert_ids_lora = expert_ids_lora.view(max_loras, -1)
sorted_token_ids_lora = sorted_token_ids_lora.view(max_loras, -1)
intermediate_cache2 = moe_state_dict["intermediate_cache2"]
intermediate_cache3 = args[0]
max_lora_rank = self.w1_lora_a_stacked.shape[-2]
self.punica_wrapper.add_lora_fused_moe(
intermediate_cache3,
intermediate_cache2,
[self.w2_lora_a_stacked],
[self.w2_lora_b_stacked],
topk_weights,
sorted_token_ids_lora,
expert_ids_lora,
num_tokens_post_padded_lora,
max_lora_rank,
top_k,
config,
True,
)
result = func(*args, **kwargs)
return result
return wrapper
fused_experts = m_fused_moe_fn.fused_experts
m_fused_moe_fn.forward = fwd_decorator(self.base_layer, m_fused_moe_fn.forward)
fused_experts.activation = act_decorator(
self.base_layer, fused_experts.activation
)
fused_experts.moe_sum = moe_sum_decorator(
self.base_layer, fused_experts.moe_sum
)
self.base_layer.quant_method.old_fused_experts = (
self.base_layer.quant_method.fused_experts
)
self.base_layer.quant_method.fused_experts = m_fused_moe_fn
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: PretrainedConfig | None = None,
) -> None:
"""Initializes lora matrices."""
assert not self.base_layer.use_ep, (
"EP support for Fused MoE LoRA is not implemented yet."
)
self.w1_lora_a_stacked = torch.zeros(
(
max_loras,
self.base_layer.global_num_experts,
lora_config.max_lora_rank,
self.base_layer.hidden_size,
),
dtype=lora_config.lora_dtype,
device=self.device,
)
self.w1_lora_b_stacked = torch.zeros(
(
max_loras,
self.base_layer.global_num_experts,
self.base_layer.intermediate_size_per_partition,
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
device=self.device,
)
self.w2_lora_a_stacked = torch.zeros(
(
max_loras,
self.base_layer.global_num_experts,
lora_config.max_lora_rank,
self.base_layer.intermediate_size_per_partition,
),
dtype=lora_config.lora_dtype,
device=self.device,
)
self.w2_lora_b_stacked = torch.zeros(
(
max_loras,
self.base_layer.global_num_experts,
self.base_layer.hidden_size,
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
device=self.device,
)
self.w3_lora_a_stacked = torch.zeros(
(
max_loras,
self.base_layer.global_num_experts,
lora_config.max_lora_rank,
self.base_layer.hidden_size,
),
dtype=lora_config.lora_dtype,
device=self.device,
)
self.w3_lora_b_stacked = torch.zeros(
(
max_loras,
self.base_layer.global_num_experts,
self.base_layer.intermediate_size_per_partition,
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
device=self.device,
)
self.base_layer.w1_lora_a_stacked = self.w1_lora_a_stacked
self.base_layer.w1_lora_b_stacked = self.w1_lora_b_stacked
self.base_layer.w2_lora_a_stacked = self.w2_lora_a_stacked
self.base_layer.w2_lora_b_stacked = self.w2_lora_b_stacked
self.base_layer.w3_lora_a_stacked = self.w3_lora_a_stacked
self.base_layer.w3_lora_b_stacked = self.w3_lora_b_stacked
# They will be used by 'LoRALayerWeights.create_dummy_lora_weights'
# to create dummy LoRA weights.
self.lora_a_stacked = []
self.lora_b_stacked = []
for lora_id in range(max_loras):
for experts_id in range(self.base_layer.global_num_experts):
# gate_proj,down_proj,up_proj
self.lora_a_stacked.append(self.w1_lora_a_stacked[lora_id][experts_id])
self.lora_a_stacked.append(self.w2_lora_a_stacked[lora_id][experts_id])
self.lora_a_stacked.append(self.w3_lora_a_stacked[lora_id][experts_id])
self.lora_b_stacked.append(self.w1_lora_b_stacked[lora_id][experts_id])
self.lora_b_stacked.append(self.w2_lora_b_stacked[lora_id][experts_id])
self.lora_b_stacked.append(self.w3_lora_b_stacked[lora_id][experts_id])
def reset_lora(self, index: int):
"""Resets the lora weights at index back to 0."""
self.w1_lora_a_stacked[index] = 0
self.w1_lora_b_stacked[index] = 0
self.w3_lora_a_stacked[index] = 0
self.w3_lora_b_stacked[index] = 0
self.w2_lora_a_stacked[index] = 0
self.w2_lora_b_stacked[index] = 0
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: torch.Tensor | None,
bias: torch.Tensor | None = None,
):
"""Overwrites lora tensors at index."""
for eid in range(len(lora_a) // 3):
w1_lora_a = lora_a[eid * 3]
w2_lora_a = lora_a[eid * 3 + 1]
w3_lora_a = lora_a[eid * 3 + 2]
w1_lora_b = lora_b[eid * 3]
w2_lora_b = lora_b[eid * 3 + 1]
w3_lora_b = lora_b[eid * 3 + 2]
if self.tp_size > 1:
shard_size = self.base_layer.intermediate_size_per_partition
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
w1_lora_b = w1_lora_b[start_idx:end_idx, :]
w3_lora_b = w3_lora_b[start_idx:end_idx, :]
w2_lora_a = w2_lora_a[:, start_idx:end_idx]
self.w1_lora_a_stacked[
index, eid, : w1_lora_a.shape[0], : w1_lora_a.shape[1]
].copy_(w1_lora_a, non_blocking=True)
self.w3_lora_a_stacked[
index, eid, : w3_lora_a.shape[0], : w3_lora_a.shape[1]
].copy_(w3_lora_a, non_blocking=True)
self.w2_lora_b_stacked[
index, eid, : w2_lora_b.shape[0], : w2_lora_b.shape[1]
].copy_(w2_lora_b, non_blocking=True)
self.w1_lora_b_stacked[
index, eid, : w1_lora_b.shape[0], : w1_lora_b.shape[1]
].copy_(w1_lora_b, non_blocking=True)
self.w3_lora_b_stacked[
index, eid, : w3_lora_b.shape[0], : w3_lora_b.shape[1]
].copy_(w3_lora_b, non_blocking=True)
self.w2_lora_a_stacked[
index, eid, : w2_lora_a.shape[0], : w2_lora_a.shape[1]
].copy_(w2_lora_a, non_blocking=True)
@classmethod
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None,
) -> bool:
"""Returns True if the layer can be replaced by this LoRA layer."""
# return type(source_layer) is FusedMoE
return isinstance(source_layer, FusedMoE)
def forward(self, *args, **kwargs):
return self.base_layer.forward(*args, **kwargs)
def maybe_all_reduce_tensor_model_parallel(self, *args, **kwargs):
return self.base_layer.maybe_all_reduce_tensor_model_parallel(*args, **kwargs)
@property
def _shared_experts(self):
return self.base_layer._shared_experts
@property
def quant_method(self):
return self.base_layer.quant_method
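The `_inject_lora_into_fused_moe` hook above works by wrapping three callables on the modular fused-MoE object (`forward`, `fused_experts.activation`, `fused_experts.moe_sum`) so routing state can be captured and the LoRA GEMMs slotted in around the base computation. A stripped-down sketch of that wrapping pattern in plain Python (the class and names below are stand-ins for illustration, not the real vLLM objects):

import torch

class _BaseExperts:
    """Stand-in for the modular fused-MoE executor."""
    def activation(self, act_name, output, inp):
        output.copy_(torch.nn.functional.silu(inp[..., : output.shape[-1]]))

state = {}

def act_decorator(func):
    def wrapper(*args, **kwargs):
        act_name, output, inp = args       # same unpacking as the wrapper above
        state["last_input_shape"] = tuple(inp.shape)
        # ...this is where the w13 LoRA contribution would be added before activation...
        return func(*args, **kwargs)       # then run the original activation
    return wrapper

experts = _BaseExperts()
experts.activation = act_decorator(experts.activation)

x = torch.randn(4, 8)
y = torch.empty(4, 4)
experts.activation("silu", y, x)
print(state["last_input_shape"])           # (4, 8)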


@ -13,7 +13,7 @@ from torch import nn
from vllm.config.lora import LoRAConfig
from vllm.logger import init_logger
from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping
from vllm.lora.layers import BaseLayerWithLoRA, FusedMoEWithLoRA, LoRAMapping
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.punica_wrapper import get_punica_wrapper
@ -23,15 +23,14 @@ from vllm.lora.utils import (
get_supported_lora_modules,
is_regex_target_modules,
parse_fine_tuned_lora_name,
process_packed_modules_mapping,
replace_submodule,
)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models import SupportsLoRA, supports_multimodal
from vllm.model_executor.models.interfaces import is_pooling_model
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper
from vllm.model_executor.utils import get_packed_modules_mapping
from vllm.utils import is_pin_memory_available
from vllm.utils.cache import LRUCache
@ -60,18 +59,6 @@ def get_lora_id():
return _GLOBAL_LORA_ID
def is_moe_model(model: nn.Module) -> bool:
"""Checks if the model contains FusedMoE layers and warns the user."""
if any(isinstance(module, FusedMoE) for module in model.modules()):
logger.warning_once(
"For MoE models, vLLM currently does not support fused MoE LoRA "
"inference. Please ensure that the loaded LoRA model does not "
"contain expert weights."
)
return True
return False
class LoRAModel:
"""A LoRA fine-tuned model."""
@ -229,9 +216,19 @@ class LoRAModel:
def check_unexpected_modules(modules: dict):
for lora_module in modules.keys(): # noqa
module_name, _ = parse_fine_tuned_lora_name(lora_module, weights_mapper)
part_name = module_name.split(".")[-1]
if part_name not in expected_lora_modules:
# Handle FSDP file format where experts.base_layer is the
# gate_up_proj and experts is the down_proj
if "base_layer" in lora_module:
continue
# Case for expert lora weights
if ".experts" in module_name:
if not any(
module_name.endswith(ele) for ele in expected_lora_modules
):
unexpected_modules.append(module_name)
elif module_name.split(".")[-1] not in expected_lora_modules:
unexpected_modules.append(module_name)
if unexpected_modules:
raise ValueError(
f"While loading {lora_dir}, expected"
@ -371,7 +368,7 @@ class LoRAModelManager:
assert self.supported_lora_modules, "No supported LoRA modules found in"
f" {self.model.__class__.__name__}."
self.packed_modules_mapping = get_packed_modules_mapping(self.model)
self.packed_modules_mapping = process_packed_modules_mapping(self.model)
# Used to indicate whether the model is a multimodal model
self.supports_mm: bool = (
supports_multimodal(self.model)
@ -380,7 +377,6 @@ class LoRAModelManager:
and hasattr(self.model, "get_mm_mapping")
)
self.is_pooling_model = is_pooling_model(self.model)
self.is_moe_model = is_moe_model(self.model)
self.packed_modules: dict[str, list[str]] = {}
self.modules: dict[str, BaseLayerWithLoRA] = {}
# Dict instead of a set for compatibility with LRUCache.
@ -431,6 +427,50 @@ class LoRAModelManager:
module_lora = self._get_lora_layer_weights(lora_model, module_name)
if module_lora:
module_lora.optimize()
# Note (gnovack) - If MOE lora weights are not split into
# num_experts chunks, we split them here
if isinstance(module, FusedMoEWithLoRA) and torch.is_tensor(
module_lora.lora_a
):
# Handle FSDP file format where experts.base_layer is the
# gate_up_proj and experts is the down_proj
gate_up_proj_lora = self._get_lora_layer_weights(
lora_model, module_name + ".base_layer"
)
assert gate_up_proj_lora is not None
assert module_lora is not None
down_proj_lora = module_lora
num_experts = module_lora.lora_a.shape[0] // module_lora.rank
gate_proj_a = gate_up_proj_lora.lora_a.chunk(num_experts, dim=0)
up_proj_a = gate_up_proj_lora.lora_a.chunk(num_experts, dim=0)
gate_proj_b = gate_up_proj_lora.lora_b[::2, ...].chunk(
num_experts, dim=-1
)
up_proj_b = gate_up_proj_lora.lora_b[1::2, ...].chunk(
num_experts, dim=-1
)
down_proj_a = down_proj_lora.lora_a.chunk(num_experts, dim=0)
down_proj_b = down_proj_lora.lora_b.chunk(num_experts, dim=-1)
lora_a = []
lora_b = []
for i in range(num_experts):
lora_a.append(gate_proj_a[i])
lora_a.append(down_proj_a[i])
lora_a.append(up_proj_a[i])
lora_b.append(gate_proj_b[i])
lora_b.append(down_proj_b[i])
lora_b.append(up_proj_b[i])
module_lora.lora_a = lora_a
module_lora.lora_b = lora_b
module.set_lora(
index,
module_lora.lora_a,
@ -486,6 +526,7 @@ class LoRAModelManager:
for module_name, module in self.model.named_modules(remove_duplicate=False):
if isinstance(module, PPMissingLayer):
continue
if not self._match_target_modules(module_name):
continue
# A temporary approach for multimodal models to support LoRA
@ -549,7 +590,10 @@ class LoRAModelManager:
new_module.set_mapping(self.punica_wrapper)
def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
assert isinstance(module, BaseLayerWithLoRA)
assert isinstance(module, BaseLayerWithLoRA), (
f"Module {module_name} must be a BaseLayerWithLoRA instance,"
f" got {type(module)}"
)
self.modules[module_name] = module
def create_dummy_lora(

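The expert-splitting branch above (the `FusedMoEWithLoRA` case guarded by `torch.is_tensor(module_lora.lora_a)`) leans on `torch.chunk` to slice a single stacked adapter tensor into per-expert pieces. A small illustrative sketch of that layout; the sizes are made-up assumptions, not taken from any real checkpoint:

import torch

num_experts, rank, hidden_size = 4, 8, 32

# FSDP-style layout: one lora_a tensor stacks all experts along dim 0,
# so num_experts can be recovered as lora_a.shape[0] // rank.
down_proj_lora_a = torch.randn(num_experts * rank, hidden_size)
per_expert = down_proj_lora_a.chunk(num_experts, dim=0)
assert len(per_expert) == num_experts and per_expert[0].shape == (rank, hidden_size)

# The gate_up lora_b is handled the same way, except gate rows ([::2]) and up rows
# ([1::2]) are de-interleaved first and the chunking runs along dim=-1, matching the
# gate_proj_b / up_proj_b split in the block above.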

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.lora.ops.triton_ops.fused_moe_lora_op import fused_moe_lora
from vllm.lora.ops.triton_ops.lora_expand_op import lora_expand
from vllm.lora.ops.triton_ops.lora_kernel_metadata import LoRAKernelMeta
from vllm.lora.ops.triton_ops.lora_shrink_op import lora_shrink
@ -9,4 +10,5 @@ __all__ = [
"lora_expand",
"lora_shrink",
"LoRAKernelMeta",
"fused_moe_lora",
]


@ -0,0 +1,350 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import triton
import triton.language as tl
from vllm.utils.torch_utils import direct_register_custom_op
_LORA_PTR_DICT: dict[tuple[int, ...], torch.tensor] = {}
def _get_ptr(lora_weights: list[torch.Tensor], device: torch.device):
"""
`_LORA_PTR_DICT` collects the required information during `profile_run`.
After this, it remains constant and subsequent usage is through the LUT.
Refer to:
https://github.com/triton-lang/triton/blob/release/3.1.x/python/tutorials/08-grouped-gemm.py
"""
key = tuple(lora_weight.data_ptr() for lora_weight in lora_weights)
if (ptr_tensor := _LORA_PTR_DICT.get(key)) is not None:
return ptr_tensor
tensor_ptrs = []
for lora_weight in lora_weights:
tensor_ptrs.append(lora_weight.data_ptr())
ptr_tensor = torch.tensor(tensor_ptrs, device=device)
_LORA_PTR_DICT[key] = ptr_tensor
return _LORA_PTR_DICT.get(key)
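# Illustrative example (an assumption for this sketch: two CUDA LoRA weight tensors are
# available). Calling _get_ptr twice with the same weight list hits the _LORA_PTR_DICT
# cache and returns the identical pointer tensor, so the kernel launches below see a
# stable device-side table of data_ptr() values:
#
#   w13 = torch.zeros(2, 4, 16, 64, device="cuda")
#   w2 = torch.zeros(2, 4, 64, 16, device="cuda")
#   assert _get_ptr([w13, w2], w13.device) is _get_ptr([w13, w2], w13.device)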
@triton.jit
def _fused_moe_lora_kernel(
a_ptr,
b_ptr,
c_ptr,
topk_weights_ptr,
sorted_token_ids_ptr,
expert_ids_ptr,
num_tokens_post_padded_ptr,
# Matrix dimensions
N,
K,
EM,
num_valid_tokens,
num_experts,
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_am,
stride_ak,
stride_bl,
stride_be,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_tl,
stride_el,
# Meta-parameters
num_slice_a: tl.constexpr,
num_slice_c: tl.constexpr,
slice_a_size: tl.constexpr,
slice_c_size: tl.constexpr,
top_k: tl.constexpr,
MUL_ROUTED_WEIGHT: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
pid = tl.program_id(axis=0)
slice_id = tl.program_id(axis=1)
lora_idx = tl.program_id(axis=2)
max_loras = tl.num_programs(axis=2)
# calculate pid_m,pid_n
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr + lora_idx)
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
return
# get the expert_id to process curr shard
ind = lora_idx * stride_el + pid_m
expert_id = tl.load(expert_ids_ptr + ind)
if expert_id == -1:
return
# get a_ptr,b_ptr,c_ptr
cur_a_ptr = a_ptr + (slice_id % num_slice_a) * slice_a_size
cur_b_ptr = tl.load(b_ptr + slice_id).to(tl.pointer_type(tl.bfloat16))
cur_c_ptr = c_ptr + (slice_id % num_slice_c) * slice_c_size
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
token_ind = stride_tl * lora_idx + offs_token_id
offs_token = tl.load(
sorted_token_ids_ptr + token_ind, token_ind < max_loras * stride_tl, 0.0
)
token_mask = offs_token < num_valid_tokens
# get a_ptrs,b_ptrs
a_ptrs = cur_a_ptr + (
offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
)
b_ptrs = (
cur_b_ptr
+ lora_idx * stride_bl
+ expert_id * stride_be
+ (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
)
# accumulator
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
a = tl.load(
a_ptrs,
mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
other=0.0,
)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
accumulator += tl.dot(a, b)
# Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
if MUL_ROUTED_WEIGHT:
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
accumulator = accumulator * moe_weight[:, None]
accumulator = accumulator.to(tl.bfloat16)
# Write back the block of the output
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = cur_c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask)
@torch.inference_mode()
def _fused_moe_lora(
output: torch.Tensor, # (num_tokens, top_k_num, N*len(lora_a_stacked),)
qcurr_hidden_states: torch.Tensor, # (num_tokens, K,)
lora_a_stacked: list[
torch.Tensor
], # [(max_loras, num_experts, max_lora_rank, K,),...]
lora_b_stacked: list[
torch.Tensor
], # [(max_loras, num_experts, N, max_lora_rank,),...]
topk_weights: torch.Tensor, # (num_tokens, top_k_num)
sorted_token_ids: torch.Tensor, # (max_loras, _)
expert_ids: torch.Tensor, # (max_loras, _ ,)
num_tokens_post_padded: torch.Tensor, # (max_loras, )
max_lora_rank: int,
top_k_num: int,
block_size_m: int,
block_size_n: int,
block_size_k: int,
group_size_m: int,
mul_routed_weight: bool = False,
) -> None:
assert len(lora_a_stacked) == len(lora_b_stacked) > 0
assert (
sorted_token_ids.dim()
== expert_ids.dim()
== topk_weights.dim()
== qcurr_hidden_states.dim()
== 2
)
assert (
sorted_token_ids.shape[0]
== expert_ids.shape[0]
== num_tokens_post_padded.shape[0]
)
assert len(lora_b_stacked) * lora_b_stacked[0].shape[-2] == output.shape[-1]
assert output.shape[0] == topk_weights.shape[0]
assert top_k_num == topk_weights.shape[1]
device = qcurr_hidden_states.device
num_slices = len(lora_a_stacked)
config = {
"BLOCK_SIZE_M": block_size_m,
"BLOCK_SIZE_N": block_size_n,
"BLOCK_SIZE_K": block_size_k,
"GROUP_SIZE_M": group_size_m,
}
w1_lora_a_stacked = lora_a_stacked[0]
w1_lora_b_stacked = lora_b_stacked[0]
num_experts = lora_a_stacked[0].shape[1]
N = max_lora_rank
M = topk_weights.shape[0]
EM = sorted_token_ids.shape[1]
K = qcurr_hidden_states.shape[1]
num_tokens = M * top_k_num
w1_output_dim_size = w1_lora_b_stacked.shape[2]
lora_intermediate_cache1 = torch.zeros(
(num_slices * M * top_k_num * (max_lora_rank + w1_output_dim_size)),
dtype=torch.bfloat16,
device=device,
)
# slices
a_intermediate_size = num_slices * M * top_k_num * max_lora_rank
a_intermediate_cache1 = lora_intermediate_cache1[:a_intermediate_size].view(
num_slices, M, top_k_num, max_lora_rank
)
b_intermediate_cache1 = lora_intermediate_cache1[a_intermediate_size:].view(
num_slices, M, top_k_num, w1_output_dim_size
)
b_ptr = _get_ptr(lora_a_stacked, device)
grid = lambda META: (
triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
len(lora_a_stacked),
lora_a_stacked[0].shape[0],
)
_fused_moe_lora_kernel[grid](
qcurr_hidden_states,
b_ptr,
a_intermediate_cache1,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
N,
K,
EM,
num_tokens,
num_experts,
qcurr_hidden_states.stride(0),
qcurr_hidden_states.stride(1),
w1_lora_a_stacked.stride(0),
w1_lora_a_stacked.stride(1),
w1_lora_a_stacked.stride(3),
w1_lora_a_stacked.stride(2),
a_intermediate_cache1.stride(2),
a_intermediate_cache1.stride(3),
sorted_token_ids.stride(0),
expert_ids.stride(0),
num_slice_a=1,
num_slice_c=num_slices,
slice_a_size=qcurr_hidden_states.numel(),
slice_c_size=a_intermediate_cache1.numel() // num_slices,
top_k=1 if mul_routed_weight else top_k_num,
MUL_ROUTED_WEIGHT=False,
**config,
)
b_ptr = _get_ptr(lora_b_stacked, device)
K = max_lora_rank
N = w1_output_dim_size
# a_intermediate_cache1 = a_intermediate_cache1.view(
# M, -1, a_intermediate_cache1.shape[3]
# )
a_intermediate_cache1 = a_intermediate_cache1.view(
-1, a_intermediate_cache1.shape[3]
)
grid = lambda META: (
triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
len(lora_b_stacked),
lora_b_stacked[0].shape[0],
)
_fused_moe_lora_kernel[grid](
a_intermediate_cache1,
b_ptr,
b_intermediate_cache1,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
N,
K,
EM,
num_tokens,
num_experts,
a_intermediate_cache1.stride(0),
a_intermediate_cache1.stride(1),
w1_lora_b_stacked.stride(0),
w1_lora_b_stacked.stride(1),
w1_lora_b_stacked.stride(3),
w1_lora_b_stacked.stride(2),
b_intermediate_cache1.stride(2),
b_intermediate_cache1.stride(3),
sorted_token_ids.stride(0),
expert_ids.stride(0),
num_slice_a=num_slices,
num_slice_c=num_slices,
slice_a_size=a_intermediate_cache1.numel() // num_slices,
slice_c_size=b_intermediate_cache1.numel() // num_slices,
top_k=1,
MUL_ROUTED_WEIGHT=mul_routed_weight,
**config,
)
for i in range(num_slices):
output[:, :, i * N : (i + 1) * N] += b_intermediate_cache1[i]
def _fused_moe_lora_fake(
output: torch.Tensor,
qcurr_hidden_states: torch.Tensor,
lora_a_stacked: list[torch.Tensor],
lora_b_stacked: list[torch.Tensor],
topk_weights: torch.Tensor,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
max_lora_rank: int,
top_k_num: int,
block_size_m: int,
block_size_n: int,
block_size_k: int,
group_size_m: int,
mul_routed_weight: bool = False,
) -> None:
return
try:
direct_register_custom_op(
op_name="fused_moe_lora",
op_func=_fused_moe_lora,
mutates_args=["output"],
fake_impl=_fused_moe_lora_fake,
)
fused_moe_lora = torch.ops.vllm.fused_moe_lora
except AttributeError:
fused_moe_lora = _fused_moe_lora
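# The two _fused_moe_lora_kernel launches above form a shrink/expand pair: the
# first contracts the hidden states against lora_a into rank-sized activations,
# the second expands them through lora_b, and the per-slice results are added
# into `output`. Below is a minimal pure-PyTorch sketch of that math for a
# single LoRA and a single slice; it ignores the max_loras dimension and the
# token-to-LoRA mapping, and all shapes/names are illustrative, not the kernel's.
import torch

def fused_moe_lora_reference(
    x: torch.Tensor,             # (num_tokens, hidden)
    lora_a: torch.Tensor,        # (num_experts, rank, hidden)
    lora_b: torch.Tensor,        # (num_experts, out_dim, rank)
    topk_ids: torch.Tensor,      # (num_tokens, top_k)
    topk_weights: torch.Tensor,  # (num_tokens, top_k)
    mul_routed_weight: bool = False,
) -> torch.Tensor:
    num_tokens, top_k = topk_ids.shape
    out = x.new_zeros(num_tokens, top_k, lora_b.shape[1])
    for t in range(num_tokens):
        for k in range(top_k):
            e = int(topk_ids[t, k])
            shrunk = lora_a[e] @ x[t]      # shrink: hidden -> rank
            expanded = lora_b[e] @ shrunk  # expand: rank -> out_dim
            if mul_routed_weight:
                # routed weight is applied in the expand pass, as above
                expanded = expanded * topk_weights[t, k]
            out[t, k] = expanded
    return out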

View File

@ -448,3 +448,42 @@ class PunicaWrapperBase(PunicaWrapperABC):
"""
# TODO: implement it based on torch ops
raise NotImplementedError
def moe_lora_align_block_size(
self,
topk_ids: torch.Tensor,
num_tokens: int,
block_size: int,
num_experts: int,
max_loras: int,
expert_map: torch.Tensor | None = None,
pad_sorted_ids: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Aligns tokens and experts into block-sized chunks for LoRA-based
mixture-of-experts (MoE) execution.
"""
# TODO: implement it based on torch ops
raise NotImplementedError
def add_lora_fused_moe(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: list[torch.Tensor],
lora_b_stacked: list[torch.Tensor],
topk_weights: torch.Tensor,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
max_lora_rank: int,
top_k_num: int,
config,
mul_routed_weight=False,
):
"""
Performs the fused LoRA forward computation for a
Mixture-of-Experts (MoE) layer.
"""
# TODO: implement it based on torch ops
raise NotImplementedError

View File

@ -12,10 +12,18 @@ from typing import final
import torch
from vllm.lora.layers import LoRAMapping
from vllm.triton_utils import HAS_TRITON
from vllm.triton_utils import HAS_TRITON, triton
from vllm.utils import round_up
if HAS_TRITON:
from vllm.lora.ops.triton_ops import LoRAKernelMeta, lora_expand, lora_shrink
from vllm.lora.ops.triton_ops import (
LoRAKernelMeta,
fused_moe_lora,
lora_expand,
lora_shrink,
)
from vllm import _custom_ops as ops
from .punica_base import PunicaWrapperBase
@ -289,3 +297,91 @@ class PunicaWrapperGPU(PunicaWrapperBase):
add_inputs=True,
)
y = y.view_as(y_org)
def moe_lora_align_block_size(
self,
topk_ids: torch.Tensor,
num_tokens: int,
block_size: int,
num_experts: int,
max_loras: int,
expert_map: torch.Tensor | None = None,
pad_sorted_ids: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Aligns tokens and experts into block-sized chunks for LoRA-based
mixture-of-experts (MoE) execution.
"""
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
if pad_sorted_ids:
max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
sorted_ids = torch.empty(
(max_loras * max_num_tokens_padded,),
dtype=torch.int32,
device=topk_ids.device,
)
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
# Expert ids must default to -1 to mark blank (unassigned) blocks
expert_ids = torch.empty(
(max_loras * max_num_m_blocks,),
dtype=torch.int32,
device=topk_ids.device,
)
num_tokens_post_pad = torch.empty(
(max_loras,), dtype=torch.int32, device=topk_ids.device
)
(token_lora_mapping, _, _, _, _, _) = self.token_mapping_meta.meta_args(
num_tokens
)
ops.moe_lora_align_block_size(
topk_ids,
token_lora_mapping,
num_experts,
block_size,
max_loras,
sorted_ids,
expert_ids,
num_tokens_post_pad,
)
if expert_map is not None:
expert_ids = expert_map[expert_ids]
return sorted_ids, expert_ids, num_tokens_post_pad
def add_lora_fused_moe(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: list[torch.Tensor],
lora_b_stacked: list[torch.Tensor],
topk_weights: torch.Tensor,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
max_lora_rank: int,
top_k_num: int,
config,
mul_routed_weight=False,
):
"""
Performs the fused LoRA forward computation for a Mixture-of-Experts (MoE) layer.
"""
fused_moe_lora(
y,
x,
lora_a_stacked,
lora_b_stacked,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
max_lora_rank,
top_k_num,
config["BLOCK_SIZE_M"],
config["BLOCK_SIZE_N"],
config["BLOCK_SIZE_K"],
config["GROUP_SIZE_M"],
mul_routed_weight,
)
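# Worked example (hypothetical sizes) of the per-LoRA buffers allocated by
# moe_lora_align_block_size above: each LoRA slot gets its own padded token
# range and block-to-expert map.
import math

num_tokens, top_k, num_experts = 8, 2, 4
block_size, max_loras = 16, 3
topk_numel = num_tokens * top_k                                      # topk_ids.numel() == 16
max_num_tokens_padded = topk_numel + num_experts * (block_size - 1)  # 16 + 4 * 15 == 76
max_num_m_blocks = math.ceil(max_num_tokens_padded / block_size)     # ceil(76 / 16) == 5
assert max_loras * max_num_tokens_padded == 228   # sorted_ids length
assert max_loras * max_num_m_blocks == 15         # expert_ids length
assert max_loras == 3                             # num_tokens_post_pad length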

View File

@ -23,6 +23,7 @@ from vllm.lora.layers import (
BaseLayerWithLoRA,
ColumnParallelLinearWithLoRA,
ColumnParallelLinearWithShardedLoRA,
FusedMoEWithLoRA,
LogitsProcessorWithLoRA,
MergedColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithShardedLoRA,
@ -35,7 +36,9 @@ from vllm.lora.layers import (
RowParallelLinearWithShardedLoRA,
VocabParallelEmbeddingWithLoRA,
)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.utils import get_moe_expert_mapping, get_packed_modules_mapping
if TYPE_CHECKING:
from vllm.model_executor.layers.logits_processor import LogitsProcessor
@ -58,9 +61,18 @@ _all_lora_classes: set[type[BaseLayerWithLoRA]] = {
MergedColumnParallelLinearWithShardedLoRA,
MergedQKVParallelLinearWithShardedLoRA,
RowParallelLinearWithShardedLoRA,
FusedMoEWithLoRA,
}
def is_moe_model(model: nn.Module) -> bool:
"""Checks if the model contains FusedMoE layers and warns the user."""
if any(isinstance(module, FusedMoE) for module in model.modules()):
logger.info_once("MoE model detected. Using fused MoE LoRA implementation.")
return True
return False
def from_layer(
layer: nn.Module,
max_loras: int,
@ -205,6 +217,9 @@ def get_supported_lora_modules(model: nn.Module) -> list[str]:
if isinstance(module, (LinearBase,)):
supported_lora_modules.add(name.split(".")[-1])
if isinstance(module, (FusedMoE,)):
supported_lora_modules.add(name.split(".")[-1])
return list(supported_lora_modules)
@ -252,3 +267,27 @@ def get_adapter_absolute_path(lora_path: str) -> str:
return lora_path
return local_snapshot_path
def process_packed_modules_mapping(model: nn.Module) -> dict[str, list[str]]:
if is_moe_model(model):
if moe_packed_mapping := get_moe_expert_mapping(model):
# This method generates and returns a dictionary mapping packed module
# names to lists of their corresponding submodule names. It includes
# both static mappings and dynamic mappings for expert layers, where
# the expert indices are expanded based on the configured number
# of routed experts.
packed_modules_mapping = get_packed_modules_mapping(model)
packed_modules_mapping["experts"] = [
weight_name.rstrip(".") for _, weight_name, _, _ in moe_packed_mapping
]
return packed_modules_mapping
else:
raise AttributeError(
"To support LoRA for MoE model, "
"'get_expert_mapping' must be implemented"
)
else:
return get_packed_modules_mapping(model)
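# A minimal sketch (hypothetical expert-mapping entries) of how the "experts"
# key is built in process_packed_modules_mapping above. Each tuple is assumed
# to follow the (param_name, weight_name, expert_id, shard_id) layout noted in
# the model files below; only weight_name is kept, with trailing dots stripped.
moe_packed_mapping = [
    ("experts.w13_weight", "experts.0.gate_proj.", 0, "w1"),
    ("experts.w13_weight", "experts.0.up_proj.", 0, "w3"),
    ("experts.w2_weight", "experts.0.down_proj.", 0, "w2"),
]
experts_entry = [weight_name.rstrip(".") for _, weight_name, _, _ in moe_packed_mapping]
# -> ['experts.0.gate_proj', 'experts.0.up_proj', 'experts.0.down_proj']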

View File

@ -94,7 +94,8 @@ class WorkerLoRAManager:
expected_lora_modules.extend(packed_modules_mapping[module])
else:
expected_lora_modules.append(module)
if module == "experts":
expected_lora_modules.append(module)
expected_lora_modules = list(set(expected_lora_modules))
lora_path = get_adapter_absolute_path(lora_request.lora_path)

View File

@ -2,6 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Fused MoE utilities for GPTQ."""
from collections.abc import Callable
import torch
import vllm._custom_ops as ops
@ -11,6 +13,9 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
batched_moe_align_block_size,
moe_align_block_size,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate,
TopKWeightAndReduceNoOP,
@ -24,6 +29,21 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
from vllm.scalar_type import ScalarType, scalar_types
def default_activation_func(
activation: str, output: torch.Tensor, input: torch.Tensor
) -> None:
if activation == "silu":
torch.ops._C.silu_and_mul(output, input)
elif activation == "swigluoai":
# alpha = 1.702, limit = 7.0
torch.ops._C.swigluoai_and_mul(output, input)
else:
raise ValueError(
f"Unsupported activation: {activation}. "
"Only silu and swigluoai activations are supported."
)
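# Minimal sketch of the activation_func contract introduced above: a callable
# taking (activation_name, output_buffer, gated_input) that writes the
# activated result in place. This is a pure-PyTorch stand-in for the "silu"
# branch, for illustration only (the real path calls torch.ops._C.silu_and_mul).
import torch
import torch.nn.functional as F

def reference_silu_and_mul(activation: str, output: torch.Tensor, input: torch.Tensor) -> None:
    assert activation == "silu"
    d = input.shape[-1] // 2
    output.copy_(F.silu(input[..., :d]) * input[..., d:])

x = torch.randn(4, 2 * 8)
y = torch.empty(4, 8)
reference_silu_and_mul("silu", y, x)  # same (act, out, in) order as default_activation_func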
def _fused_marlin_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
@ -36,12 +56,15 @@ def _fused_marlin_moe(
num_topk: int,
quant_type: ScalarType,
apply_router_weight_on_input: bool,
activation: str,
expert_map: torch.Tensor | None,
block_size_m: int,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
activation: str = "silu",
activation_func: Callable[
[str, torch.Tensor, torch.Tensor], None
] = default_activation_func,
global_scale1: torch.Tensor | None = None,
global_scale2: torch.Tensor | None = None,
g_idx1: torch.Tensor | None = None,
@ -118,20 +141,9 @@ def _fused_marlin_moe(
is_zp_float=False,
)
if activation == "silu":
torch.ops._C.silu_and_mul(
intermediate_cache2, intermediate_cache1.view(-1, 2 * N)
)
elif activation == "swigluoai":
# alpha = 1.702, limit = 7.0
torch.ops._C.swigluoai_and_mul(
intermediate_cache2, intermediate_cache1.view(-1, 2 * N)
)
else:
raise ValueError(
f"Unsupported activation: {activation}. "
"Only silu and swigluoai activations are supported."
)
activation_func(
activation, intermediate_cache2, intermediate_cache1.view(-1, 2 * N)
)
if output is None:
output = intermediate_cache3
@ -185,7 +197,11 @@ def fused_marlin_moe(
quant_type_id: int,
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
activation: str | None = "silu",
activation: str = "silu",
activation_func: Callable[
[str, torch.Tensor, torch.Tensor], None
] = default_activation_func,
moe_sum: Callable[[torch.Tensor, torch.Tensor], None] | None = None,
expert_map: torch.Tensor | None = None,
global_scale1: torch.Tensor | None = None,
global_scale2: torch.Tensor | None = None,
@ -290,12 +306,13 @@ def fused_marlin_moe(
num_topk=topk,
quant_type=quant_type,
apply_router_weight_on_input=apply_router_weight_on_input,
activation=activation,
expert_map=expert_map,
block_size_m=block_size_m,
sorted_token_ids=sorted_token_ids,
expert_ids=expert_ids,
num_tokens_post_padded=num_tokens_post_padded,
activation=activation,
activation_func=activation_func,
global_scale1=global_scale1,
global_scale2=global_scale2,
g_idx1=g_idx1,
@ -317,7 +334,10 @@ def fused_marlin_moe(
else:
output = torch.empty_like(hidden_states)
return torch.sum(moe_output.view(-1, topk, K), dim=1, out=output)
if moe_sum is None:
return torch.sum(moe_output.view(-1, topk, K), dim=1, out=output)
else:
return moe_sum(moe_output, output)
def batched_fused_marlin_moe(
@ -600,6 +620,8 @@ class MarlinExperts(MarlinExpertsBase):
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
activation=activation,
activation_func=self.activation,
moe_sum=self.moe_sum,
expert_map=expert_map,
output=output,
# Workspaces are swapped in workspace_shapes() to account for proper
@ -608,6 +630,19 @@ class MarlinExperts(MarlinExpertsBase):
intermediate_cache2=workspace13,
)
def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:
ops.moe_sum(input, output)
def modular_marlin_fused_moe(
quant_config: FusedMoEQuantConfig, shared_experts: torch.nn.Module | None = None
) -> mk.FusedMoEModularKernel:
return mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
MarlinExperts(quant_config),
shared_experts,
)
class BatchedMarlinExperts(MarlinExpertsBase):
def __init__(

View File

@ -2135,13 +2135,18 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
B_bias=self.w2_bias,
)
ops.moe_sum(intermediate_cache3, output)
# A separate function is required for MoE + LoRA
self.moe_sum(intermediate_cache3, output)
def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:
ops.moe_sum(input, output)
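# Sketch of why moe_sum became an overridable method (hypothetical subclass,
# illustration only): a LoRA-aware experts class can hook the final top-k
# reduction and fold a precomputed fused-LoRA delta into the same output buffer.
import torch

class _MoESumBase:  # stand-in for TritonExperts
    def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:
        torch.sum(input, dim=1, out=output)  # reference semantics of ops.moe_sum

class _MoESumWithLoRA(_MoESumBase):
    def __init__(self, lora_delta: torch.Tensor):
        self.lora_delta = lora_delta  # assumed precomputed LoRA output contribution
    def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:
        super().moe_sum(input, output)
        output += self.lora_delta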
def modular_triton_fused_moe(
quant_config: FusedMoEQuantConfig,
quant_config: FusedMoEQuantConfig, shared_experts: torch.nn.Module | None = None
) -> mk.FusedMoEModularKernel:
return mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
TritonExperts(quant_config),
shared_experts,
)

View File

@ -557,6 +557,9 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
torch.ops._C.silu_and_mul(output, input)
elif activation == "gelu":
torch.ops._C.gelu_and_mul(output, input)
elif activation == "swigluoai":
# alpha = 1.702, limit = 7.0
torch.ops._C.swigluoai_and_mul(output, input)
else:
raise ValueError(f"Unsupported FusedMoe activation: {activation}")

View File

@ -1313,6 +1313,17 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, SupportsLoRA):
logits = self.logits_processor(self.lm_head, hidden_states)
return logits
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
return SharedFusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts,
num_redundant_experts=0,
)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)

View File

@ -32,7 +32,7 @@ from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.sequence import IntermediateTensors
from vllm.utils import cdiv
from .interfaces import SupportsEagle3, SupportsPP
from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
from .utils import (
AutoWeightsLoader,
WeightsMapper,
@ -627,7 +627,7 @@ class GptOssModel(nn.Module):
)
class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3):
class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA):
packed_modules_mapping = {"qkv": ["q_proj", "k_proj", "v_proj"]}
hf_to_vllm_mapper = WeightsMapper(
@ -696,6 +696,17 @@ class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3):
logits = self.logits_processor(self.lm_head, hidden_states)
return logits
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, weight scales, activation scales
# (param_name, weight_name, expert_id, shard_id)
return FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_local_experts,
num_redundant_experts=0,
)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,

View File

@ -49,7 +49,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (
AutoWeightsLoader,
is_pp_missing_parameter,
@ -349,8 +349,6 @@ class OlmoeModel(nn.Module):
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
@ -433,17 +431,13 @@ class OlmoeModel(nn.Module):
return loaded_params
class OlmoeForCausalLM(nn.Module, SupportsPP):
class OlmoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
]
}
def __init__(