[Feature][Kernel] FusedMoE LoRA (#21229)

Signed-off-by: wuchen <cntryroa@gmail.com>
Signed-off-by: banjuede <lmklhc@163.com>
Signed-off-by: Chen Wu <cntryroa@gmail.com>
Signed-off-by: Danielle Robinson <dmmaddix@amazon.com>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: bk-201 <joy25810@foxmail.com>
Co-authored-by: wuchen <wuchen@zetyun.com>
Co-authored-by: Nathan Van Gheem <vangheem@gmail.com>
Co-authored-by: banjuede <lmklhc@163.com>
Co-authored-by: Danielle Robinson <dmmaddix@amazon.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Co-authored-by: bk-201 <joy25810@foxmail.com>
This commit is contained in:
parent 3ada34f9cb
commit 5f6cbf60d6
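
The new tests in this commit show the end-to-end usage that this feature enables: serving an MoE model with a LoRA adapter attached. A minimal sketch, assembled from tests/lora/test_qwen3moe_tp.py below; the prompt and some settings here are illustrative, not part of the commit:

# Minimal usage sketch based on the new MoE LoRA tests in this commit.
# Model and adapter IDs come from tests/lora/test_qwen3moe_tp.py and conftest.py;
# the prompt text is an illustrative assumption.
import vllm
from huggingface_hub import snapshot_download
from vllm.lora.request import LoRARequest

lora_path = snapshot_download(repo_id="jeeejeee/qwen3-moe-text2sql-spider")
llm = vllm.LLM(
    "Qwen/Qwen3-30B-A3B",
    max_model_len=1024,
    enable_lora=True,
    max_loras=4,
    enforce_eager=True,
    enable_chunked_prefill=True,
)
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=64)
outputs = llm.generate(
    ["Write a SQL query that counts the candidates."],  # illustrative prompt
    sampling_params,
    lora_request=LoRARequest("sql_adapter", 1, lora_path),
)
print(outputs[0].outputs[0].text)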
@@ -384,7 +384,12 @@ steps:
--num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT \
--ignore=lora/test_chatglm3_tp.py \
--ignore=lora/test_llama_tp.py \
--ignore=lora/test_llm_with_multi_loras.py
--ignore=lora/test_llm_with_multi_loras.py \
--ignore=lora/test_olmoe_tp.py \
--ignore=lora/test_deepseekv2_tp.py \
--ignore=lora/test_gptoss.py \
--ignore=lora/test_qwen3moe_tp.py
parallelism: 4

- label: PyTorch Compilation Unit Tests # 15min
@@ -1065,6 +1070,7 @@ steps:
- pytest -v -s -x lora/test_chatglm3_tp.py
- pytest -v -s -x lora/test_llama_tp.py
- pytest -v -s -x lora/test_llm_with_multi_loras.py
- pytest -v -s -x lora/test_olmoe_tp.py


- label: Weight Loading Multiple GPU Test # 33min

@@ -883,6 +883,7 @@ target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
set(VLLM_MOE_EXT_SRC
  "csrc/moe/torch_bindings.cpp"
  "csrc/moe/moe_align_sum_kernels.cu"
  "csrc/moe/moe_lora_align_sum_kernels.cu"
  "csrc/moe/topk_softmax_kernels.cu")

if(VLLM_GPU_LANG STREQUAL "CUDA")

csrc/moe/moe_lora_align_sum_kernels.cu (new file, 173 lines)
@@ -0,0 +1,173 @@
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <ATen/ATen.h>
#include <ATen/cuda/Atomic.cuh>

#include "../cuda_compat.h"
#include "../dispatch_utils.h"
#include "core/math.hpp"

namespace {

__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row,
                                         int32_t col) {
  return row * total_col + col;
}

}  // namespace

// TODO: Refactor common parts with moe_align_sum_kernels
template <typename scalar_t, typename token_cnts_t>
__global__ void moe_lora_align_sum_kernel(
    scalar_t* __restrict__ topk_ids, int32_t* token_lora_mapping,
    int64_t block_size, int num_experts, int max_loras, size_t numel,
    int max_num_tokens_padded, int max_num_m_blocks,
    int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
    int topk_num, int32_t* total_tokens_post_pad) {
  const size_t tokens_per_thread = div_ceil(numel, blockDim.x);
  const size_t start_idx = threadIdx.x * tokens_per_thread;

  int lora_id = blockIdx.x;
  extern __shared__ int32_t shared_mem[];
  int32_t* cumsum = shared_mem;
  token_cnts_t* tokens_cnts = (token_cnts_t*)(shared_mem + num_experts + 1);

  // Initialize sorted_token_ids with numel
  for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += blockDim.x) {
    sorted_token_ids[lora_id * max_num_tokens_padded + it] = numel;
  }

  // Initialize expert_ids with -1
  for (size_t it = threadIdx.x; it < max_num_m_blocks; it += blockDim.x) {
    expert_ids[lora_id * max_num_m_blocks + it] = -1;
  }

  // Initialize total_tokens_post_pad with 0
  if (threadIdx.x == 0) {
    total_tokens_post_pad[lora_id] = 0;
  }

  for (int i = 0; i < num_experts; ++i) {
    tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
  }

  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
    int mask = token_lora_mapping[i / topk_num] == lora_id;
    int idx = index(num_experts, threadIdx.x + 1, topk_ids[i]);
    tokens_cnts[idx] += mask;
  }

  __syncthreads();

  // For each expert we accumulate the token counts from the different threads.
  if (threadIdx.x < num_experts) {
    tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
    for (int i = 1; i <= blockDim.x; ++i) {
      tokens_cnts[index(num_experts, i, threadIdx.x)] +=
          tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
    }
  }

  __syncthreads();

  // We accumulate the token counts of all experts in thread 0.
  if (threadIdx.x == 0) {
    cumsum[0] = 0;
    for (int i = 1; i <= num_experts; ++i) {
      cumsum[i] = cumsum[i - 1] +
                  div_ceil(tokens_cnts[index(num_experts, blockDim.x, i - 1)],
                           block_size) *
                      block_size;
    }
    total_tokens_post_pad[lora_id] = static_cast<int32_t>(cumsum[num_experts]);
  }

  __syncthreads();

  /**
   * For each expert, each thread processes the tokens of the corresponding
   * blocks and stores the corresponding expert_id for each block.
   */
  if (threadIdx.x < num_experts) {
    for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
         i += block_size) {
      expert_ids[index(max_num_m_blocks, lora_id, i / block_size)] =
          threadIdx.x;
    }
  }

  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
    int32_t expert_id = topk_ids[i];
    /** The cumsum[expert_id] stores the starting index of the tokens that the
     * expert with expert_id needs to process, and
     * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens
     * processed by the expert with expert_id within the current thread's token
     * shard.
     */
    int32_t rank_post_pad =
        tokens_cnts[index(num_experts, threadIdx.x, expert_id)] +
        cumsum[expert_id];

    int mask = (int)token_lora_mapping[i / topk_num] == lora_id;
    atomicAdd(
        &sorted_token_ids[index(max_num_tokens_padded, lora_id, rank_post_pad)],
        (i - numel) * mask);
    tokens_cnts[index(num_experts, threadIdx.x, expert_id)] += mask;
  }
}

void moe_lora_align_block_size(torch::Tensor topk_ids,
                               torch::Tensor token_lora_mapping,
                               int64_t num_experts, int64_t block_size,
                               int64_t max_loras,
                               torch::Tensor sorted_token_ids,
                               torch::Tensor expert_ids,
                               torch::Tensor num_tokens_post_pad) {
  const int topk_num = topk_ids.size(1);

  int max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1);

  TORCH_CHECK(block_size > 0, "block_size should be greater than 0. ");
  max_num_tokens_padded = round_to_next_multiple_of(
      max_num_tokens_padded, static_cast<int>(block_size));
  int max_num_m_blocks = div_ceil(max_num_tokens_padded, block_size);

  int device_max_shared_mem;
  auto dev = topk_ids.get_device();
  cudaDeviceGetAttribute(&device_max_shared_mem,
                         cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  const int32_t num_thread = max((int32_t)num_experts, 128);  // WARP_SIZE,
  TORCH_CHECK(num_thread <= 1024,
              "num_thread must be less than 1024, "
              "and fallback is not implemented yet.");
  const int32_t shared_mem = (num_thread + 1) * num_experts * sizeof(int32_t) +
                             (num_experts + 1) * sizeof(int32_t);

  if (shared_mem > device_max_shared_mem) {
    TORCH_CHECK(false,
                "Shared memory usage exceeds device limit, and global memory "
                "fallback is not implemented yet.");
  }

  VLLM_DISPATCH_INTEGRAL_TYPES(
      topk_ids.scalar_type(), "moe_lora_align_sum_kernel", [&] {
        dim3 blockDim(num_thread);
        auto kernel = moe_lora_align_sum_kernel<scalar_t, int32_t>;
        AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
            (void*)kernel, shared_mem));
        kernel<<<max_loras, blockDim, shared_mem, stream>>>(
            topk_ids.data_ptr<scalar_t>(),
            token_lora_mapping.data_ptr<int32_t>(), block_size, num_experts,
            max_loras, topk_ids.numel(), max_num_tokens_padded,
            max_num_m_blocks, sorted_token_ids.data_ptr<int32_t>(),
            expert_ids.data_ptr<int32_t>(), topk_num,
            num_tokens_post_pad.data_ptr<int32_t>());
      });
}
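
The host-side sizing for this kernel pads each (LoRA, expert) group of routed tokens up to a multiple of block_size. A small worked sketch of that arithmetic in Python, using example values; the formulas mirror moe_lora_align_block_size above and the new tests/lora/test_moe_lora_align_sum.py:

# Illustrative sizing sketch; the concrete numbers are example inputs,
# the formulas follow the C++ host code above.
num_tokens, topk_num = 100, 6        # example values
num_experts, block_size = 64, 16     # example values
max_loras = 4                        # example value

numel = num_tokens * topk_num                                     # 600 flattened topk ids
max_num_tokens_padded = numel + num_experts * (block_size - 1)    # 600 + 64*15 = 1560
# round up to the next multiple of block_size
max_num_tokens_padded = -(-max_num_tokens_padded // block_size) * block_size  # 1568
max_num_m_blocks = -(-max_num_tokens_padded // block_size)                    # 98

# Output buffers are allocated per LoRA adapter:
#   sorted_token_ids:    (max_loras * max_num_tokens_padded,) int32
#   expert_ids:          (max_loras * max_num_m_blocks,)      int32
#   num_tokens_post_pad: (max_loras,)                         int32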
@@ -20,6 +20,13 @@ void batched_moe_align_block_size(int64_t max_tokens_per_batch,
                                  torch::Tensor expert_ids,
                                  torch::Tensor num_tokens_post_pad);

void moe_lora_align_block_size(torch::Tensor topk_ids,
                               torch::Tensor token_lora_mapping,
                               int64_t num_experts, int64_t block_size,
                               int64_t max_loras,
                               torch::Tensor sorted_token_ids,
                               torch::Tensor expert_ids,
                               torch::Tensor num_tokens_post_pad);
#ifndef USE_ROCM
torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
                             torch::Tensor b_qweight, torch::Tensor b_scales,

@@ -33,6 +33,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
  m.impl("batched_moe_align_block_size", torch::kCUDA,
         &batched_moe_align_block_size);

  // Aligning the number of tokens to be processed by each expert such
  // that it is divisible by the block size.
  m.def(
      "moe_lora_align_block_size(Tensor topk_ids,"
      " Tensor token_lora_mapping,"
      " int num_experts,"
      " int block_size, int max_loras, "
      " Tensor !sorted_token_ids,"
      " Tensor !experts_ids,"
      " Tensor !num_tokens_post_pad) -> () ");
  m.impl("moe_lora_align_block_size", torch::kCUDA, &moe_lora_align_block_size);

#ifndef USE_ROCM
  m.def(
      "moe_wna16_gemm(Tensor input, Tensor! output, Tensor b_qweight, "

@@ -230,6 +230,26 @@ def tinyllama_lora_files():
    return snapshot_download(repo_id="jashing/tinyllama-colorist-lora")


@pytest.fixture(scope="session")
def deepseekv2_lora_files():
    return snapshot_download(repo_id="wuchen01/DeepSeek-V2-Lite-Chat-All-LoRA")


@pytest.fixture(scope="session")
def gptoss20b_lora_files():
    return snapshot_download(repo_id="LevinZheng/gpt-oss-20b-lora-adapter")


@pytest.fixture(scope="session")
def qwen3moe_lora_files():
    return snapshot_download(repo_id="jeeejeee/qwen3-moe-text2sql-spider")


@pytest.fixture(scope="session")
def olmoe_lora_files():
    return snapshot_download(repo_id="jeeejeee/olmoe-instruct-text2sql-spider")


@pytest.fixture
def reset_default_device():
    """
tests/lora/test_deepseekv2_tp.py (new file, 97 lines)
@@ -0,0 +1,97 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import vllm
|
||||
from vllm.lora.request import LoRARequest
|
||||
|
||||
from ..utils import multi_gpu_test
|
||||
|
||||
MODEL_PATH = "deepseek-ai/DeepSeek-V2-Lite-Chat"
|
||||
|
||||
PROMPT_TEMPLATE = "<|begin▁of▁sentence|>You are a helpful assistant.\n\nUser: {context}\n\nAssistant:" # noqa: E501
|
||||
|
||||
|
||||
def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int):
|
||||
prompts = [
|
||||
PROMPT_TEMPLATE.format(context="Who are you?"),
|
||||
]
|
||||
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=64)
|
||||
outputs = llm.generate(
|
||||
prompts,
|
||||
sampling_params,
|
||||
lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None,
|
||||
)
|
||||
# Print the outputs.
|
||||
generated_texts: list[str] = []
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text.strip()
|
||||
generated_texts.append(generated_text)
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
# return generated_texts
|
||||
expected_lora_output = [
|
||||
"I am \u5f20\u5b50\u8c6a, an AI assistant developed by \u9648\u58eb\u680b.", # noqa: E501
|
||||
]
|
||||
for i in range(len(expected_lora_output)):
|
||||
assert generated_texts[i].startswith(expected_lora_output[i])
|
||||
|
||||
|
||||
def test_deepseekv2_lora(deepseekv2_lora_files):
|
||||
# We enable enforce_eager=True here to reduce VRAM usage for lora-test CI,
|
||||
# Otherwise, the lora-test will fail due to CUDA OOM.
|
||||
llm = vllm.LLM(
|
||||
MODEL_PATH,
|
||||
max_model_len=1024,
|
||||
enable_lora=True,
|
||||
max_loras=4,
|
||||
enforce_eager=True,
|
||||
trust_remote_code=True,
|
||||
enable_chunked_prefill=True,
|
||||
)
|
||||
generate_and_test(llm, deepseekv2_lora_files, 1)
|
||||
|
||||
|
||||
def test_deepseekv2(deepseekv2_lora_files):
|
||||
# We enable enforce_eager=True here to reduce VRAM usage for lora-test CI,
|
||||
# Otherwise, the lora-test will fail due to CUDA OOM.
|
||||
llm = vllm.LLM(
|
||||
MODEL_PATH,
|
||||
max_model_len=1024,
|
||||
enable_lora=True,
|
||||
max_loras=4,
|
||||
enforce_eager=True,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
generate_and_test(llm, deepseekv2_lora_files, 1)
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
def test_deepseekv2_tp2(deepseekv2_lora_files):
|
||||
# We enable enforce_eager=True here to reduce VRAM usage for lora-test CI,
|
||||
# Otherwise, the lora-test will fail due to CUDA OOM.
|
||||
llm = vllm.LLM(
|
||||
MODEL_PATH,
|
||||
max_model_len=1024,
|
||||
enable_lora=True,
|
||||
max_loras=4,
|
||||
enforce_eager=True,
|
||||
trust_remote_code=True,
|
||||
tensor_parallel_size=2,
|
||||
)
|
||||
generate_and_test(llm, deepseekv2_lora_files, 2)
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=4)
|
||||
def test_deepseekv2_tp4(deepseekv2_lora_files):
|
||||
# We enable enforce_eager=True here to reduce VRAM usage for lora-test CI,
|
||||
# Otherwise, the lora-test will fail due to CUDA OOM.
|
||||
llm = vllm.LLM(
|
||||
MODEL_PATH,
|
||||
max_model_len=1024,
|
||||
enable_lora=True,
|
||||
max_loras=4,
|
||||
enforce_eager=True,
|
||||
trust_remote_code=True,
|
||||
tensor_parallel_size=4,
|
||||
)
|
||||
generate_and_test(llm, deepseekv2_lora_files, 2)
|
||||
tests/lora/test_fused_moe_lora_kernel.py (new file, 287 lines)
@@ -0,0 +1,287 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import random
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.lora.ops.triton_ops import fused_moe_lora
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_device(reset_default_device):
|
||||
pass
|
||||
|
||||
|
||||
def round_up(x, base):
|
||||
return ((x + base - 1) // base) * base
|
||||
|
||||
|
||||
def CEILDIV(x, y):
|
||||
return (x + y - 1) // y
|
||||
|
||||
|
||||
def assign_loras_to_tokens(num_tokens: int, num_sequences: int, max_loras: int):
|
||||
"""
|
||||
Split `num_tokens` into `num_sequences` sequences.
|
||||
Each sequence randomly selects 1 LoRA index from [0, max_loras),
|
||||
and all tokens in that sequence are assigned this LoRA index.
|
||||
|
||||
Args:
|
||||
num_tokens (int): Total number of tokens.
|
||||
num_sequences (int): Number of sequences to split the tokens into.
|
||||
max_loras (int): Total number of available LoRA modules.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: 1D tensor of shape [num_tokens], where each value
|
||||
is the LoRA index assigned to that token.
|
||||
"""
|
||||
assert num_sequences > 0 and max_loras > 0
|
||||
assert num_tokens >= num_sequences, "num_tokens must be >= num_sequences"
|
||||
|
||||
# Compute token distribution per sequence (distribute remainder evenly)
|
||||
tokens_per_seq = num_tokens // num_sequences
|
||||
remainder = num_tokens % num_sequences
|
||||
|
||||
token_lora_mapping = torch.empty(num_tokens, dtype=torch.int32)
|
||||
|
||||
start = 0
|
||||
for seq_idx in range(num_sequences):
|
||||
# Determine the token range for this sequence
|
||||
end = start + tokens_per_seq + (1 if seq_idx < remainder else 0)
|
||||
|
||||
# Randomly select one LoRA ID for this sequence
|
||||
lora_id = random.randint(0, max_loras - 1)
|
||||
|
||||
# Assign the same LoRA ID to all tokens in this sequence
|
||||
token_lora_mapping[start:end] = lora_id
|
||||
|
||||
start = end
|
||||
|
||||
return token_lora_mapping
|
||||
|
||||
|
||||
def assign_experts_to_tokens(num_tokens: int, num_experts: int, top_k_num: int):
|
||||
"""
|
||||
For each token, randomly select `top_k_num` distinct experts out of `num_experts`,
|
||||
and assign normalized random weights that sum to 1.
|
||||
|
||||
Args:
|
||||
num_tokens (int): Total number of tokens.
|
||||
num_experts (int): Total number of available experts.
|
||||
top_k_num (int): Number of experts to select per token.
|
||||
|
||||
Returns:
|
||||
expert_indices (torch.Tensor): shape [num_tokens, top_k_num],
|
||||
expert index for each token.
|
||||
expert_weights (torch.Tensor): shape [num_tokens, top_k_num],
|
||||
normalized weights (sum = 1 per row).
|
||||
"""
|
||||
assert top_k_num <= num_experts, "top_k_num must be <= num_experts"
|
||||
|
||||
# Randomly select top_k_num distinct experts for each token
|
||||
expert_indices = torch.empty((num_tokens, top_k_num), dtype=torch.int32)
|
||||
for i in range(num_tokens):
|
||||
# Randomly choose unique expert indices
|
||||
selected = torch.randperm(num_experts)[:top_k_num]
|
||||
expert_indices[i] = selected
|
||||
|
||||
# Generate random weights and normalize along dim=1
|
||||
expert_weights = torch.rand((num_tokens, top_k_num), dtype=torch.float32)
|
||||
expert_weights = expert_weights / expert_weights.sum(dim=1, keepdim=True)
|
||||
|
||||
return expert_indices, expert_weights
|
||||
|
||||
|
||||
def sample_data(
|
||||
num_tokens: int,
|
||||
num_sequences: int,
|
||||
max_loras: int,
|
||||
num_experts: int,
|
||||
top_k_num: int,
|
||||
):
|
||||
topk_ids, topk_weights = assign_experts_to_tokens(
|
||||
num_tokens, num_experts, top_k_num
|
||||
)
|
||||
token_lora_mapping = assign_loras_to_tokens(num_tokens, num_sequences, max_loras)
|
||||
return topk_ids, topk_weights, token_lora_mapping
|
||||
|
||||
|
||||
def use_fused_moe_lora_kernel(
|
||||
topk_ids,
|
||||
topk_weights,
|
||||
token_lora_mapping,
|
||||
max_lora_rank,
|
||||
top_k_num,
|
||||
lora_a_stacked,
|
||||
lora_b_stacked,
|
||||
hidden_states,
|
||||
output,
|
||||
max_loras,
|
||||
num_experts,
|
||||
block_size,
|
||||
):
|
||||
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
|
||||
max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
|
||||
max_num_m_blocks = CEILDIV(max_num_tokens_padded, block_size)
|
||||
|
||||
# init output tensors
|
||||
sorted_token_ids = torch.empty(
|
||||
(max_loras * max_num_tokens_padded,),
|
||||
dtype=torch.int32,
|
||||
)
|
||||
expert_ids = torch.empty((max_loras * max_num_m_blocks,), dtype=torch.int32)
|
||||
num_tokens_post_padded = torch.empty((max_loras,), dtype=torch.int32)
|
||||
|
||||
# call kernel
|
||||
ops.moe_lora_align_block_size(
|
||||
topk_ids,
|
||||
token_lora_mapping,
|
||||
num_experts,
|
||||
block_size,
|
||||
max_loras,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
)
|
||||
|
||||
config = {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
}
|
||||
|
||||
mul_routed_weight = False
|
||||
expert_ids = expert_ids.view(max_loras, -1)
|
||||
sorted_token_ids = sorted_token_ids.view(max_loras, -1)
|
||||
|
||||
fused_moe_lora(
|
||||
output,
|
||||
hidden_states,
|
||||
lora_a_stacked,
|
||||
lora_b_stacked,
|
||||
topk_weights,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
max_lora_rank,
|
||||
top_k_num,
|
||||
config["BLOCK_SIZE_M"],
|
||||
config["BLOCK_SIZE_N"],
|
||||
config["BLOCK_SIZE_K"],
|
||||
config["GROUP_SIZE_M"],
|
||||
mul_routed_weight,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def use_torch(
|
||||
hidden_states,
|
||||
token_lora_mapping,
|
||||
topk_ids,
|
||||
lora_a_stacked,
|
||||
lora_b_stacked,
|
||||
top_k_num,
|
||||
):
|
||||
outputs = []
|
||||
for i in range(hidden_states.shape[0]):
|
||||
lora_idx = token_lora_mapping[i]
|
||||
expert_ids = topk_ids[i]
|
||||
lora_a = lora_a_stacked[0][lora_idx][expert_ids]
|
||||
lora_b = lora_b_stacked[0][lora_idx][expert_ids]
|
||||
tensors = [
|
||||
hidden_states[i] @ lora_a[x].T @ lora_b[x].T for x in range(top_k_num)
|
||||
]
|
||||
outputs.append(torch.stack(tensors, dim=0))
|
||||
return torch.stack(outputs, dim=0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", [100])
|
||||
@pytest.mark.parametrize("top_k_num", [6, 12])
|
||||
@pytest.mark.parametrize("num_experts", [64])
|
||||
@pytest.mark.parametrize("max_loras", [4, 6, 16])
|
||||
@pytest.mark.parametrize("N", [1408])
|
||||
@pytest.mark.parametrize("K", [2048])
|
||||
@pytest.mark.parametrize("max_lora_rank", [16, 32, 64])
|
||||
@pytest.mark.parametrize("block_size", [16])
|
||||
def test_fused_moe_lora_kernel(
|
||||
num_tokens,
|
||||
top_k_num,
|
||||
num_experts,
|
||||
max_loras,
|
||||
N,
|
||||
K,
|
||||
max_lora_rank,
|
||||
block_size,
|
||||
):
|
||||
torch.set_default_device("cuda:0")
|
||||
current_platform.seed_everything(42)
|
||||
# the number of randomly generated sentences.
|
||||
num_sequences = 10
|
||||
# generate data
|
||||
topk_ids, topk_weights, token_lora_mapping = sample_data(
|
||||
num_tokens, num_sequences, max_loras, num_experts, top_k_num
|
||||
)
|
||||
|
||||
# init lora weights
|
||||
lora_a_stacked = [
|
||||
torch.rand(
|
||||
(
|
||||
max_loras,
|
||||
num_experts,
|
||||
max_lora_rank,
|
||||
K,
|
||||
),
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
]
|
||||
lora_b_stacked = [
|
||||
torch.rand(
|
||||
(
|
||||
max_loras,
|
||||
num_experts,
|
||||
N,
|
||||
max_lora_rank,
|
||||
),
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
]
|
||||
hidden_states = torch.rand(
|
||||
(
|
||||
num_tokens,
|
||||
K,
|
||||
),
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
# fused_moe_lora_kernel output
|
||||
output = torch.zeros((num_tokens, top_k_num, N), dtype=torch.bfloat16)
|
||||
use_fused_moe_lora_kernel(
|
||||
topk_ids,
|
||||
topk_weights,
|
||||
token_lora_mapping,
|
||||
max_lora_rank,
|
||||
top_k_num,
|
||||
lora_a_stacked,
|
||||
lora_b_stacked,
|
||||
hidden_states,
|
||||
output,
|
||||
max_loras,
|
||||
num_experts,
|
||||
block_size,
|
||||
)
|
||||
# pytorch output
|
||||
output2 = use_torch(
|
||||
hidden_states,
|
||||
token_lora_mapping,
|
||||
topk_ids,
|
||||
lora_a_stacked,
|
||||
lora_b_stacked,
|
||||
top_k_num,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(output, output2, atol=1e-1, rtol=1e-1)
|
||||
tests/lora/test_gptoss.py (new file, 52 lines)
@@ -0,0 +1,52 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import vllm
from vllm.lora.request import LoRARequest

MODEL_PATH = "openai/gpt-oss-20b"

PROMPT_TEMPLATE = "<|begin▁of▁sentence|>You are a helpful assistant.\n\nUser: {context}\n\nAssistant:"  # noqa: E501


def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]:
    prompts = [
        PROMPT_TEMPLATE.format(context="Who are you?"),
    ]
    sampling_params = vllm.SamplingParams(temperature=0, max_tokens=64)
    outputs = llm.generate(
        prompts,
        sampling_params,
        lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None,
    )
    # Print the outputs.
    generated_texts: list[str] = []
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text.strip()
        generated_texts.append(generated_text)
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    return generated_texts


# FIXME: Load gpt-oss adapter
def test_gptoss20b_lora(gptoss20b_lora_files):
    # We enable enforce_eager=True here to reduce VRAM usage for lora-test CI,
    # Otherwise, the lora-test will fail due to CUDA OOM.
    llm = vllm.LLM(
        MODEL_PATH,
        enable_lora=True,
        max_loras=1,
        trust_remote_code=True,
    )

    expected_lora_output = [
        "I am an AI language model developed by OpenAI. "
        "I am here to help you with any questions or "
        "tasks you may have."
    ]

    output1 = do_sample(llm, gptoss20b_lora_files, lora_id=1)
    print(output1)
    for i in range(len(expected_lora_output)):
        assert output1[i].startswith(expected_lora_output[i])
tests/lora/test_moe_lora_align_sum.py (new file, 90 lines)
@@ -0,0 +1,90 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random

import pytest
import torch

from vllm import _custom_ops as ops


def round_up(x, base):
    return ((x + base - 1) // base) * base


def CEILDIV(x, y):
    return (x + y - 1) // y


def sample_data(num_experts, max_loras, num_tokens, topk_num):
    topk_ids = torch.zeros((num_tokens, topk_num), dtype=torch.int32)
    token_lora_mapping = torch.zeros((num_tokens,), dtype=torch.int32)

    for i in range(num_tokens):
        pool = list(range(num_experts))
        random.shuffle(pool)
        for j in range(topk_num):
            topk_ids[i, j] = pool[j]
        token_lora_mapping[i] = random.randint(0, max_loras - 1)

    return topk_ids.to("cuda"), token_lora_mapping.to("cuda")


@pytest.mark.parametrize("num_tokens", [100, 200, 1024, 4096])  # 81920
@pytest.mark.parametrize("topk_num", [6])
@pytest.mark.parametrize("num_experts", [64, 128])
@pytest.mark.parametrize("max_loras", [2, 32])
@pytest.mark.parametrize("block_size", [16])
def test_moe_lora_align_block_size(
    num_tokens, topk_num, num_experts, max_loras, block_size
):
    # sample data
    random.seed(1)
    topk_ids, token_lora_mapping = sample_data(
        num_experts, max_loras, num_tokens, topk_num
    )

    # compute paddings
    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
    max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
    max_num_m_blocks = CEILDIV(max_num_tokens_padded, block_size)

    # init output tensors
    sorted_token_ids = torch.full(
        (max_loras * max_num_tokens_padded,),
        topk_ids.numel(),
        dtype=torch.int32,
        device="cuda",
    )
    expert_ids = torch.full(
        (max_loras * max_num_m_blocks,), num_experts, dtype=torch.int32, device="cuda"
    )
    num_tokens_post_pad = torch.zeros((max_loras,), dtype=torch.int32, device="cuda")

    # call kernel
    ops.moe_lora_align_block_size(
        topk_ids,
        token_lora_mapping,
        num_experts,
        block_size,
        max_loras,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_pad,
    )

    # verify values
    expert_ids = expert_ids.view(max_loras, -1)
    sorted_token_ids = sorted_token_ids.view(max_loras, -1, block_size)

    for lora_idx in range(max_loras):
        for token_idx in range(sorted_token_ids.size(1)):
            block = sorted_token_ids[lora_idx][token_idx]
            indices = block[block != topk_ids.numel()]
            if indices.numel() > 0:
                expert_id = expert_ids[lora_idx][token_idx]
                assert torch.all(topk_ids.view(-1)[indices] == expert_id)


if __name__ == "__main__":
    pytest.main([__file__])
tests/lora/test_olmoe_tp.py (new file, 109 lines)
@@ -0,0 +1,109 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import vllm
|
||||
from vllm.lora.request import LoRARequest
|
||||
|
||||
from ..utils import multi_gpu_test
|
||||
|
||||
MODEL_PATH = "allenai/OLMoE-1B-7B-0125-Instruct"
|
||||
|
||||
PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.
|
||||
"
|
||||
##Instruction:
|
||||
candidate_poll contains tables such as candidate, people. Table candidate has columns such as Candidate_ID, People_ID, Poll_Source, Date, Support_rate, Consider_rate, Oppose_rate, Unsure_rate. Candidate_ID is the primary key.
|
||||
Table people has columns such as People_ID, Sex, Name, Date_of_Birth, Height, Weight. People_ID is the primary key.
|
||||
The People_ID of candidate is the foreign key of People_ID of people.
|
||||
|
||||
|
||||
###Input:
|
||||
{context}
|
||||
|
||||
###Response:""" # noqa: E501
|
||||
|
||||
EXPECTED_LORA_OUTPUT = [
|
||||
"SELECT count(*) FROM candidate",
|
||||
"SELECT count(*) FROM candidate",
|
||||
"SELECT poll_source FROM candidate GROUP BY poll_source ORDER BY count(*) DESC LIMIT 1", # noqa: E501
|
||||
"SELECT poll_source FROM candidate GROUP BY poll_source ORDER BY count(*) DESC LIMIT 1", # noqa: E501
|
||||
]
|
||||
|
||||
|
||||
def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None:
|
||||
prompts = [
|
||||
PROMPT_TEMPLATE.format(context="How many candidates are there?"),
|
||||
PROMPT_TEMPLATE.format(context="Count the number of candidates."),
|
||||
PROMPT_TEMPLATE.format(
|
||||
context="Which poll resource provided the most number of candidate information?" # noqa: E501
|
||||
),
|
||||
PROMPT_TEMPLATE.format(
|
||||
context="Return the poll resource associated with the most candidates."
|
||||
),
|
||||
]
|
||||
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=64)
|
||||
outputs = llm.generate(
|
||||
prompts,
|
||||
sampling_params,
|
||||
lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None,
|
||||
)
|
||||
# Print the outputs.
|
||||
generated_texts: list[str] = []
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text.strip()
|
||||
generated_texts.append(generated_text)
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
|
||||
for i in range(len(EXPECTED_LORA_OUTPUT)):
|
||||
assert generated_texts[i].startswith(EXPECTED_LORA_OUTPUT[i])
|
||||
|
||||
|
||||
def test_olmoe_lora(olmoe_lora_files):
|
||||
# We enable enforce_eager=True here to reduce VRAM usage for lora-test CI,
|
||||
# Otherwise, the lora-test will fail due to CUDA OOM.
|
||||
llm = vllm.LLM(
|
||||
MODEL_PATH,
|
||||
max_model_len=1024,
|
||||
enable_lora=True,
|
||||
max_loras=4,
|
||||
enforce_eager=True,
|
||||
trust_remote_code=True,
|
||||
enable_chunked_prefill=True,
|
||||
)
|
||||
|
||||
generate_and_test(llm, olmoe_lora_files, lora_id=1)
|
||||
generate_and_test(llm, olmoe_lora_files, lora_id=2)
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
def test_olmoe_lora_tp2(olmoe_lora_files):
|
||||
llm = vllm.LLM(
|
||||
MODEL_PATH,
|
||||
max_model_len=1024,
|
||||
enable_lora=True,
|
||||
max_loras=4,
|
||||
enforce_eager=True,
|
||||
trust_remote_code=True,
|
||||
enable_chunked_prefill=True,
|
||||
tensor_parallel_size=2,
|
||||
)
|
||||
|
||||
generate_and_test(llm, olmoe_lora_files, lora_id=1)
|
||||
generate_and_test(llm, olmoe_lora_files, lora_id=2)
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=4)
|
||||
def test_olmoe_lora_tp4(olmoe_lora_files):
|
||||
llm = vllm.LLM(
|
||||
MODEL_PATH,
|
||||
max_model_len=1024,
|
||||
enable_lora=True,
|
||||
max_loras=4,
|
||||
enforce_eager=True,
|
||||
trust_remote_code=True,
|
||||
enable_chunked_prefill=True,
|
||||
tensor_parallel_size=4,
|
||||
)
|
||||
|
||||
generate_and_test(llm, olmoe_lora_files, lora_id=1)
|
||||
generate_and_test(llm, olmoe_lora_files, lora_id=2)
|
||||
tests/lora/test_qwen3moe_tp.py (new file, 111 lines)
@@ -0,0 +1,111 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import vllm
|
||||
from vllm.lora.request import LoRARequest
|
||||
|
||||
from ..utils import multi_gpu_test
|
||||
|
||||
MODEL_PATH = "Qwen/Qwen3-30B-A3B"
|
||||
|
||||
PROMPT_TEMPLATE = """<|im_start|>user
|
||||
I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.
|
||||
"
|
||||
##Instruction:
|
||||
candidate_poll contains tables such as candidate, people. Table candidate has columns such as Candidate_ID, People_ID, Poll_Source, Date, Support_rate, Consider_rate, Oppose_rate, Unsure_rate. Candidate_ID is the primary key.
|
||||
Table people has columns such as People_ID, Sex, Name, Date_of_Birth, Height, Weight. People_ID is the primary key.
|
||||
The People_ID of candidate is the foreign key of People_ID of people.
|
||||
|
||||
|
||||
###Input:
|
||||
{context}
|
||||
|
||||
###Response:<|im_end|>
|
||||
<|im_start|>assistant""" # noqa: E501
|
||||
|
||||
EXPECTED_LORA_OUTPUT = [
|
||||
"<think>\n\n</think>\n\nSELECT count(*) FROM candidate",
|
||||
"<think>\n\n</think>\n\nSELECT count(*) FROM candidate",
|
||||
"<think>\n\n</think>\n\nSELECT poll_source FROM candidate GROUP BY poll_source ORDER BY count(*) DESC LIMIT 1", # noqa: E501
|
||||
"<think>\n\n</think>\n\nSELECT poll_source FROM candidate GROUP BY poll_source ORDER BY count(*) DESC LIMIT 1", # noqa: E501
|
||||
]
|
||||
|
||||
|
||||
def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None:
|
||||
prompts = [
|
||||
PROMPT_TEMPLATE.format(context="How many candidates are there?"),
|
||||
PROMPT_TEMPLATE.format(context="Count the number of candidates."),
|
||||
PROMPT_TEMPLATE.format(
|
||||
context="Which poll resource provided the most number of candidate information?" # noqa: E501
|
||||
),
|
||||
PROMPT_TEMPLATE.format(
|
||||
context="Return the poll resource associated with the most candidates."
|
||||
),
|
||||
]
|
||||
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=64)
|
||||
outputs = llm.generate(
|
||||
prompts,
|
||||
sampling_params,
|
||||
lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None,
|
||||
)
|
||||
# Print the outputs.
|
||||
generated_texts: list[str] = []
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text.strip()
|
||||
generated_texts.append(generated_text)
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
|
||||
for i in range(len(EXPECTED_LORA_OUTPUT)):
|
||||
assert generated_texts[i].startswith(EXPECTED_LORA_OUTPUT[i])
|
||||
|
||||
|
||||
def test_qwen3moe_lora(qwen3moe_lora_files):
|
||||
# We enable enforce_eager=True here to reduce VRAM usage for lora-test CI,
|
||||
# Otherwise, the lora-test will fail due to CUDA OOM.
|
||||
llm = vllm.LLM(
|
||||
MODEL_PATH,
|
||||
max_model_len=1024,
|
||||
enable_lora=True,
|
||||
max_loras=4,
|
||||
enforce_eager=True,
|
||||
trust_remote_code=True,
|
||||
enable_chunked_prefill=True,
|
||||
)
|
||||
|
||||
generate_and_test(llm, qwen3moe_lora_files, lora_id=1)
|
||||
generate_and_test(llm, qwen3moe_lora_files, lora_id=2)
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
def test_qwen3moe_lora_tp2(qwen3moe_lora_files):
|
||||
llm = vllm.LLM(
|
||||
MODEL_PATH,
|
||||
max_model_len=1024,
|
||||
enable_lora=True,
|
||||
max_loras=4,
|
||||
enforce_eager=True,
|
||||
trust_remote_code=True,
|
||||
enable_chunked_prefill=True,
|
||||
tensor_parallel_size=2,
|
||||
)
|
||||
|
||||
generate_and_test(llm, qwen3moe_lora_files, lora_id=1)
|
||||
generate_and_test(llm, qwen3moe_lora_files, lora_id=2)
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=4)
|
||||
def test_qwen3moe_lora_tp4(qwen3moe_lora_files):
|
||||
llm = vllm.LLM(
|
||||
MODEL_PATH,
|
||||
max_model_len=1024,
|
||||
enable_lora=True,
|
||||
max_loras=4,
|
||||
enforce_eager=True,
|
||||
trust_remote_code=True,
|
||||
enable_chunked_prefill=True,
|
||||
tensor_parallel_size=4,
|
||||
)
|
||||
|
||||
generate_and_test(llm, qwen3moe_lora_files, lora_id=1)
|
||||
generate_and_test(llm, qwen3moe_lora_files, lora_id=2)
|
||||
@@ -1795,6 +1795,28 @@ def batched_moe_align_block_size(
    )


def moe_lora_align_block_size(
    topk_ids: torch.Tensor,
    token_lora_mapping: torch.Tensor,
    num_experts: int,
    block_size: int,
    max_loras: int,
    sorted_token_ids: torch.Tensor,
    experts_ids: torch.Tensor,
    num_tokens_post_pad: torch.Tensor,
) -> None:
    torch.ops._moe_C.moe_lora_align_block_size(
        topk_ids,
        token_lora_mapping,
        num_experts,
        block_size,
        max_loras,
        sorted_token_ids,
        experts_ids,
        num_tokens_post_pad,
    )


def moe_wna16_gemm(
    input: torch.Tensor,
    output: torch.Tensor,

@@ -11,6 +11,7 @@ from vllm.lora.layers.column_parallel_linear import (
    QKVParallelLinearWithLoRA,
    QKVParallelLinearWithShardedLoRA,
)
from vllm.lora.layers.fused_moe import FusedMoEWithLoRA
from vllm.lora.layers.logits_processor import LogitsProcessorWithLoRA
from vllm.lora.layers.replicated_linear import ReplicatedLinearWithLoRA
from vllm.lora.layers.row_parallel_linear import (
@@ -36,4 +37,5 @@ __all__ = [
    "RowParallelLinearWithShardedLoRA",
    "ReplicatedLinearWithLoRA",
    "LoRAMapping",
    "FusedMoEWithLoRA",
]

vllm/lora/layers/fused_moe.py (new file, 410 lines)
@@ -0,0 +1,410 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import functools
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm import envs
|
||||
from vllm.config.lora import LoRAConfig
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.lora.layers.base import BaseLayerWithLoRA
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FUSED_MOE_UNQUANTIZED_CONFIG,
|
||||
_get_config_dtype_str,
|
||||
mxfp4_w4a16_moe_quant_config,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
|
||||
modular_marlin_fused_moe,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
modular_triton_fused_moe,
|
||||
try_get_optimal_moe_config,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.mxfp4 import Mxfp4Config
|
||||
|
||||
|
||||
class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
||||
def __init__(self, base_layer: FusedMoE) -> None:
|
||||
super().__init__()
|
||||
self.base_layer = base_layer
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
self.device = base_layer.w2_weight.device
|
||||
self._inject_lora_into_fused_moe()
|
||||
|
||||
def _inject_lora_into_fused_moe(self):
|
||||
moe_state_dict = {}
|
||||
top_k = self.base_layer.top_k
|
||||
|
||||
if self.base_layer.quant_config is None:
|
||||
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
|
||||
elif not isinstance(self.base_layer.quant_config, Mxfp4Config):
|
||||
quant_config = self.base_layer.quant_config
|
||||
else:
|
||||
quant_config = mxfp4_w4a16_moe_quant_config(
|
||||
w1_bias=self.base_layer.w13_bias,
|
||||
w2_bias=self.base_layer.w2_bias,
|
||||
w1_scale=self.base_layer.w13_weight_scale,
|
||||
w2_scale=self.base_layer.w2_weight_scale,
|
||||
)
|
||||
|
||||
m_fused_moe_fn = (
|
||||
modular_triton_fused_moe(
|
||||
quant_config, shared_experts=self.base_layer.shared_experts
|
||||
)
|
||||
if not quant_config.use_mxfp4_w4a16
|
||||
else modular_marlin_fused_moe(
|
||||
quant_config, shared_experts=self.base_layer.shared_experts
|
||||
)
|
||||
)
|
||||
|
||||
def fwd_decorator(layer, func):
|
||||
def wrapper(*args, **kwargs):
|
||||
moe_state_dict["hidden_states"] = kwargs["hidden_states"]
|
||||
moe_state_dict["topk_ids"] = kwargs["topk_ids"]
|
||||
moe_state_dict["topk_weights"] = kwargs["topk_weights"]
|
||||
moe_state_dict["global_num_experts"] = kwargs["global_num_experts"]
|
||||
moe_state_dict["expert_map"] = kwargs["expert_map"]
|
||||
moe_state_dict["apply_router_weight_on_input"] = kwargs[
|
||||
"apply_router_weight_on_input"
|
||||
]
|
||||
moe_state_dict["max_loras"] = layer.w1_lora_a_stacked.shape[0]
|
||||
result = func(*args, **kwargs)
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
def act_decorator(layer, func):
|
||||
def wrapper(*args, **kwargs):
|
||||
_, output, input = args
|
||||
|
||||
hidden_states = moe_state_dict["hidden_states"]
|
||||
topk_weights = moe_state_dict["topk_weights"]
|
||||
curr_topk_ids = moe_state_dict["topk_ids"]
|
||||
global_num_experts = moe_state_dict["global_num_experts"]
|
||||
expert_map = moe_state_dict["expert_map"]
|
||||
max_loras = moe_state_dict["max_loras"]
|
||||
|
||||
config_dtype = _get_config_dtype_str(
|
||||
dtype=hidden_states.dtype,
|
||||
use_fp8_w8a8=False,
|
||||
use_int8_w8a16=False,
|
||||
use_int4_w4a16=False,
|
||||
)
|
||||
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
|
||||
num_tokens = hidden_states.size(0)
|
||||
M = min(num_tokens, CHUNK_SIZE)
|
||||
|
||||
get_config_func = functools.partial(
|
||||
try_get_optimal_moe_config,
|
||||
layer.w13_weight.size(),
|
||||
layer.w2_weight.size(),
|
||||
top_k,
|
||||
config_dtype,
|
||||
block_shape=layer.quant_method.moe_quant_config.block_shape,
|
||||
)
|
||||
|
||||
config = get_config_func(M)
|
||||
(
|
||||
sorted_token_ids_lora,
|
||||
expert_ids_lora,
|
||||
num_tokens_post_padded_lora,
|
||||
) = self.punica_wrapper.moe_lora_align_block_size(
|
||||
curr_topk_ids,
|
||||
num_tokens,
|
||||
config["BLOCK_SIZE_M"],
|
||||
global_num_experts,
|
||||
max_loras,
|
||||
expert_map,
|
||||
)
|
||||
|
||||
moe_state_dict["sorted_token_ids_lora"] = sorted_token_ids_lora
|
||||
moe_state_dict["expert_ids_lora"] = expert_ids_lora
|
||||
moe_state_dict["num_tokens_post_padded_lora"] = (
|
||||
num_tokens_post_padded_lora
|
||||
)
|
||||
|
||||
w13_lora_a_stacked = [self.w1_lora_a_stacked, self.w3_lora_a_stacked]
|
||||
w13_lora_b_stacked = [self.w1_lora_b_stacked, self.w3_lora_b_stacked]
|
||||
max_lora_rank = self.w1_lora_a_stacked.shape[-2]
|
||||
expert_ids_lora = expert_ids_lora.view(max_loras, -1)
|
||||
sorted_token_ids_lora = sorted_token_ids_lora.view(max_loras, -1)
|
||||
|
||||
self.punica_wrapper.add_lora_fused_moe(
|
||||
input.view(-1, top_k, input.shape[-1]),
|
||||
hidden_states,
|
||||
w13_lora_a_stacked,
|
||||
w13_lora_b_stacked,
|
||||
topk_weights,
|
||||
sorted_token_ids_lora,
|
||||
expert_ids_lora,
|
||||
num_tokens_post_padded_lora,
|
||||
max_lora_rank,
|
||||
top_k,
|
||||
config,
|
||||
)
|
||||
|
||||
result = func(*args, **kwargs)
|
||||
|
||||
moe_state_dict["intermediate_cache2"] = output
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
def moe_sum_decorator(layer, func):
|
||||
def wrapper(*args, **kwargs):
|
||||
hidden_states = moe_state_dict["hidden_states"]
|
||||
topk_weights = moe_state_dict["topk_weights"]
|
||||
max_loras = moe_state_dict["max_loras"]
|
||||
|
||||
config_dtype = _get_config_dtype_str(
|
||||
dtype=hidden_states.dtype,
|
||||
use_fp8_w8a8=False,
|
||||
use_int8_w8a16=False,
|
||||
use_int4_w4a16=False,
|
||||
)
|
||||
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
|
||||
num_tokens = hidden_states.size(0)
|
||||
M = min(num_tokens, CHUNK_SIZE)
|
||||
|
||||
get_config_func = functools.partial(
|
||||
try_get_optimal_moe_config,
|
||||
layer.w13_weight.size(),
|
||||
layer.w2_weight.size(),
|
||||
top_k,
|
||||
config_dtype,
|
||||
block_shape=layer.quant_method.moe_quant_config.block_shape,
|
||||
)
|
||||
|
||||
config = get_config_func(M)
|
||||
|
||||
sorted_token_ids_lora = moe_state_dict["sorted_token_ids_lora"]
|
||||
expert_ids_lora = moe_state_dict["expert_ids_lora"]
|
||||
num_tokens_post_padded_lora = moe_state_dict[
|
||||
"num_tokens_post_padded_lora"
|
||||
]
|
||||
|
||||
expert_ids_lora = expert_ids_lora.view(max_loras, -1)
|
||||
sorted_token_ids_lora = sorted_token_ids_lora.view(max_loras, -1)
|
||||
intermediate_cache2 = moe_state_dict["intermediate_cache2"]
|
||||
intermediate_cache3 = args[0]
|
||||
max_lora_rank = self.w1_lora_a_stacked.shape[-2]
|
||||
self.punica_wrapper.add_lora_fused_moe(
|
||||
intermediate_cache3,
|
||||
intermediate_cache2,
|
||||
[self.w2_lora_a_stacked],
|
||||
[self.w2_lora_b_stacked],
|
||||
topk_weights,
|
||||
sorted_token_ids_lora,
|
||||
expert_ids_lora,
|
||||
num_tokens_post_padded_lora,
|
||||
max_lora_rank,
|
||||
top_k,
|
||||
config,
|
||||
True,
|
||||
)
|
||||
|
||||
result = func(*args, **kwargs)
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
fused_experts = m_fused_moe_fn.fused_experts
|
||||
|
||||
m_fused_moe_fn.forward = fwd_decorator(self.base_layer, m_fused_moe_fn.forward)
|
||||
fused_experts.activation = act_decorator(
|
||||
self.base_layer, fused_experts.activation
|
||||
)
|
||||
fused_experts.moe_sum = moe_sum_decorator(
|
||||
self.base_layer, fused_experts.moe_sum
|
||||
)
|
||||
|
||||
self.base_layer.quant_method.old_fused_experts = (
|
||||
self.base_layer.quant_method.fused_experts
|
||||
)
|
||||
self.base_layer.quant_method.fused_experts = m_fused_moe_fn
|
||||
|
||||
def create_lora_weights(
|
||||
self,
|
||||
max_loras: int,
|
||||
lora_config: LoRAConfig,
|
||||
model_config: PretrainedConfig | None = None,
|
||||
) -> None:
|
||||
"""Initializes lora matrices."""
|
||||
|
||||
assert not self.base_layer.use_ep, (
|
||||
"EP support for Fused MoE LoRA is not implemented yet."
|
||||
)
|
||||
|
||||
self.w1_lora_a_stacked = torch.zeros(
|
||||
(
|
||||
max_loras,
|
||||
self.base_layer.global_num_experts,
|
||||
lora_config.max_lora_rank,
|
||||
self.base_layer.hidden_size,
|
||||
),
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
self.w1_lora_b_stacked = torch.zeros(
|
||||
(
|
||||
max_loras,
|
||||
self.base_layer.global_num_experts,
|
||||
self.base_layer.intermediate_size_per_partition,
|
||||
lora_config.max_lora_rank,
|
||||
),
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
self.w2_lora_a_stacked = torch.zeros(
|
||||
(
|
||||
max_loras,
|
||||
self.base_layer.global_num_experts,
|
||||
lora_config.max_lora_rank,
|
||||
self.base_layer.intermediate_size_per_partition,
|
||||
),
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
self.w2_lora_b_stacked = torch.zeros(
|
||||
(
|
||||
max_loras,
|
||||
self.base_layer.global_num_experts,
|
||||
self.base_layer.hidden_size,
|
||||
lora_config.max_lora_rank,
|
||||
),
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
self.w3_lora_a_stacked = torch.zeros(
|
||||
(
|
||||
max_loras,
|
||||
self.base_layer.global_num_experts,
|
||||
lora_config.max_lora_rank,
|
||||
self.base_layer.hidden_size,
|
||||
),
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
self.w3_lora_b_stacked = torch.zeros(
|
||||
(
|
||||
max_loras,
|
||||
self.base_layer.global_num_experts,
|
||||
self.base_layer.intermediate_size_per_partition,
|
||||
lora_config.max_lora_rank,
|
||||
),
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
self.base_layer.w1_lora_a_stacked = self.w1_lora_a_stacked
|
||||
self.base_layer.w1_lora_b_stacked = self.w1_lora_b_stacked
|
||||
self.base_layer.w2_lora_a_stacked = self.w2_lora_a_stacked
|
||||
self.base_layer.w2_lora_b_stacked = self.w2_lora_b_stacked
|
||||
self.base_layer.w3_lora_a_stacked = self.w3_lora_a_stacked
|
||||
self.base_layer.w3_lora_b_stacked = self.w3_lora_b_stacked
|
||||
# They will be used by 'LoRALayerWeights.create_dummy_lora_weights'
|
||||
# to create a dummy LoRA weights.
|
||||
self.lora_a_stacked = []
|
||||
self.lora_b_stacked = []
|
||||
for lora_id in range(max_loras):
|
||||
for experts_id in range(self.base_layer.global_num_experts):
|
||||
# gate_proj,down_proj,up_proj
|
||||
self.lora_a_stacked.append(self.w1_lora_a_stacked[lora_id][experts_id])
|
||||
self.lora_a_stacked.append(self.w2_lora_a_stacked[lora_id][experts_id])
|
||||
self.lora_a_stacked.append(self.w3_lora_a_stacked[lora_id][experts_id])
|
||||
|
||||
self.lora_b_stacked.append(self.w1_lora_b_stacked[lora_id][experts_id])
|
||||
self.lora_b_stacked.append(self.w2_lora_b_stacked[lora_id][experts_id])
|
||||
self.lora_b_stacked.append(self.w3_lora_b_stacked[lora_id][experts_id])
|
||||
|
||||
def reset_lora(self, index: int):
|
||||
"""Resets the lora weights at index back to 0."""
|
||||
self.w1_lora_a_stacked[index] = 0
|
||||
self.w1_lora_b_stacked[index] = 0
|
||||
self.w3_lora_a_stacked[index] = 0
|
||||
self.w3_lora_b_stacked[index] = 0
|
||||
self.w2_lora_a_stacked[index] = 0
|
||||
self.w2_lora_b_stacked[index] = 0
|
||||
|
||||
def set_lora(
|
||||
self,
|
||||
index: int,
|
||||
lora_a: torch.Tensor,
|
||||
lora_b: torch.Tensor,
|
||||
embeddings_tensor: torch.Tensor | None,
|
||||
bias: torch.Tensor | None = None,
|
||||
):
|
||||
"""Overwrites lora tensors at index."""
|
||||
for eid in range(len(lora_a) // 3):
|
||||
w1_lora_a = lora_a[eid * 3]
|
||||
w2_lora_a = lora_a[eid * 3 + 1]
|
||||
w3_lora_a = lora_a[eid * 3 + 2]
|
||||
w1_lora_b = lora_b[eid * 3]
|
||||
w2_lora_b = lora_b[eid * 3 + 1]
|
||||
w3_lora_b = lora_b[eid * 3 + 2]
|
||||
|
||||
if self.tp_size > 1:
|
||||
shard_size = self.base_layer.intermediate_size_per_partition
|
||||
start_idx = self.tp_rank * shard_size
|
||||
end_idx = (self.tp_rank + 1) * shard_size
|
||||
|
||||
w1_lora_b = w1_lora_b[start_idx:end_idx, :]
|
||||
w3_lora_b = w3_lora_b[start_idx:end_idx, :]
|
||||
w2_lora_a = w2_lora_a[:, start_idx:end_idx]
|
||||
|
||||
self.w1_lora_a_stacked[
|
||||
index, eid, : w1_lora_a.shape[0], : w1_lora_a.shape[1]
|
||||
].copy_(w1_lora_a, non_blocking=True)
|
||||
|
||||
self.w3_lora_a_stacked[
|
||||
index, eid, : w3_lora_a.shape[0], : w3_lora_a.shape[1]
|
||||
].copy_(w3_lora_a, non_blocking=True)
|
||||
|
||||
self.w2_lora_b_stacked[
|
||||
index, eid, : w2_lora_b.shape[0], : w2_lora_b.shape[1]
|
||||
].copy_(w2_lora_b, non_blocking=True)
|
||||
|
||||
self.w1_lora_b_stacked[
|
||||
index, eid, : w1_lora_b.shape[0], : w1_lora_b.shape[1]
|
||||
].copy_(w1_lora_b, non_blocking=True)
|
||||
self.w3_lora_b_stacked[
|
||||
index, eid, : w3_lora_b.shape[0], : w3_lora_b.shape[1]
|
||||
].copy_(w3_lora_b, non_blocking=True)
|
||||
self.w2_lora_a_stacked[
|
||||
index, eid, : w2_lora_a.shape[0], : w2_lora_a.shape[1]
|
||||
].copy_(w2_lora_a, non_blocking=True)
|
||||
|
||||
@classmethod
|
||||
def can_replace_layer(
|
||||
cls,
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: list,
|
||||
model_config: PretrainedConfig | None,
|
||||
) -> bool:
|
||||
"""Returns True if the layer can be replaced by this LoRA layer."""
|
||||
# return type(source_layer) is FusedMoE
|
||||
return isinstance(source_layer, FusedMoE)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.base_layer.forward(*args, **kwargs)
|
||||
|
||||
def maybe_all_reduce_tensor_model_parallel(self, *args, **kwargs):
|
||||
return self.base_layer.maybe_all_reduce_tensor_model_parallel(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def _shared_experts(self):
|
||||
return self.base_layer._shared_experts
|
||||
|
||||
@property
|
||||
def quant_method(self):
|
||||
return self.base_layer.quant_method
|
||||
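
vllm/lora/layers/fused_moe.py above injects LoRA without re-implementing the fused MoE kernels: it wraps the modular kernel's forward, activation, and moe_sum callables so that routing state is captured on the way in and the LoRA contribution is applied around the base computation. A stripped-down sketch of that wrapping pattern, with made-up names rather than the real vLLM interfaces:

# Simplified illustration of the decorator-wrapping pattern used by
# FusedMoEWithLoRA._inject_lora_into_fused_moe; class and function names are invented.
class TinyMoE:
    def forward(self, hidden_states, topk_ids):
        return hidden_states  # stand-in for the fused expert computation

def inject_lora(moe, apply_lora):
    state = {}

    def fwd_decorator(func):
        def wrapper(*args, **kwargs):
            # capture routing inputs so later stages can reuse them
            state["topk_ids"] = kwargs["topk_ids"]
            out = func(*args, **kwargs)
            # apply the LoRA delta around the wrapped base computation
            return apply_lora(out, state["topk_ids"])
        return wrapper

    moe.forward = fwd_decorator(moe.forward)
    return moe

moe = inject_lora(TinyMoE(), apply_lora=lambda out, ids: out)  # no-op LoRA
print(moe.forward(hidden_states=[1.0, 2.0], topk_ids=[0, 1]))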
@ -13,7 +13,7 @@ from torch import nn
|
||||
|
||||
from vllm.config.lora import LoRAConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping
|
||||
from vllm.lora.layers import BaseLayerWithLoRA, FusedMoEWithLoRA, LoRAMapping
|
||||
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.punica_wrapper import get_punica_wrapper
@ -23,15 +23,14 @@ from vllm.lora.utils import (
    get_supported_lora_modules,
    is_regex_target_modules,
    parse_fine_tuned_lora_name,
    process_packed_modules_mapping,
    replace_submodule,
)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models import SupportsLoRA, supports_multimodal
from vllm.model_executor.models.interfaces import is_pooling_model
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper
from vllm.model_executor.utils import get_packed_modules_mapping
from vllm.utils import is_pin_memory_available
from vllm.utils.cache import LRUCache

@ -60,18 +59,6 @@ def get_lora_id():
    return _GLOBAL_LORA_ID


def is_moe_model(model: nn.Module) -> bool:
    """Checks if the model contains FusedMoE layers and warns the user."""
    if any(isinstance(module, FusedMoE) for module in model.modules()):
        logger.warning_once(
            "For MoE models, vLLM currently does not support fused MoE LoRA "
            "inference. Please ensure that the loaded LoRA model does not "
            "contain expert weights."
        )
        return True
    return False


class LoRAModel:
    """A LoRA fine-tuned model."""

@ -229,9 +216,19 @@ class LoRAModel:
        def check_unexpected_modules(modules: dict):
            for lora_module in modules.keys():  # noqa
                module_name, _ = parse_fine_tuned_lora_name(lora_module, weights_mapper)
                part_name = module_name.split(".")[-1]
                if part_name not in expected_lora_modules:
                    # Handle FSDP file format where experts.base_layer is the
                    # gate_up_proj and experts is the down_proj
                    if "base_layer" in lora_module:
                        continue
                    # Case for expert lora weights
                    if ".experts" in module_name:
                        if not any(
                            module_name.endswith(ele) for ele in expected_lora_modules
                        ):
                            unexpected_modules.append(module_name)
                    elif module_name.split(".")[-1] not in expected_lora_modules:
                        unexpected_modules.append(module_name)

            if unexpected_modules:
                raise ValueError(
                    f"While loading {lora_dir}, expected"
@ -371,7 +368,7 @@ class LoRAModelManager:
        assert self.supported_lora_modules, "No supported LoRA modules found in"
        f" {self.model.__class__.__name__}."

        self.packed_modules_mapping = get_packed_modules_mapping(self.model)
        self.packed_modules_mapping = process_packed_modules_mapping(self.model)
        # Used to indicate whether the model is a multimodal model
        self.supports_mm: bool = (
            supports_multimodal(self.model)
@ -380,7 +377,6 @@ class LoRAModelManager:
            and hasattr(self.model, "get_mm_mapping")
        )
        self.is_pooling_model = is_pooling_model(self.model)
        self.is_moe_model = is_moe_model(self.model)
        self.packed_modules: dict[str, list[str]] = {}
        self.modules: dict[str, BaseLayerWithLoRA] = {}
        # Dict instead of a set for compatibility with LRUCache.
@ -431,6 +427,50 @@ class LoRAModelManager:
            module_lora = self._get_lora_layer_weights(lora_model, module_name)
            if module_lora:
                module_lora.optimize()
                # Note (gnovack) - If MOE lora weights are not split into
                # num_experts chunks, we split them here
                if isinstance(module, FusedMoEWithLoRA) and torch.is_tensor(
                    module_lora.lora_a
                ):
                    # Handle FSDP file format where experts.base_layer is the
                    # gate_up_proj and experts is the down_proj
                    gate_up_proj_lora = self._get_lora_layer_weights(
                        lora_model, module_name + ".base_layer"
                    )

                    assert gate_up_proj_lora is not None
                    assert module_lora is not None

                    down_proj_lora = module_lora
                    num_experts = module_lora.lora_a.shape[0] // module_lora.rank

                    gate_proj_a = gate_up_proj_lora.lora_a.chunk(num_experts, dim=0)
                    up_proj_a = gate_up_proj_lora.lora_a.chunk(num_experts, dim=0)

                    gate_proj_b = gate_up_proj_lora.lora_b[::2, ...].chunk(
                        num_experts, dim=-1
                    )
                    up_proj_b = gate_up_proj_lora.lora_b[1::2, ...].chunk(
                        num_experts, dim=-1
                    )

                    down_proj_a = down_proj_lora.lora_a.chunk(num_experts, dim=0)
                    down_proj_b = down_proj_lora.lora_b.chunk(num_experts, dim=-1)

                    lora_a = []
                    lora_b = []
                    for i in range(num_experts):
                        lora_a.append(gate_proj_a[i])
                        lora_a.append(down_proj_a[i])
                        lora_a.append(up_proj_a[i])

                        lora_b.append(gate_proj_b[i])
                        lora_b.append(down_proj_b[i])
                        lora_b.append(up_proj_b[i])

                    module_lora.lora_a = lora_a
                    module_lora.lora_b = lora_b

                module.set_lora(
                    index,
                    module_lora.lora_a,
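The block above covers checkpoints that store each MoE LoRA matrix as a single stacked tensor. A minimal sketch of the resulting per-expert interleaving, using hypothetical sizes and plain tensors rather than this patch's LoRALayerWeights objects:

import torch

rank, hidden, num_experts = 4, 8, 2                # hypothetical sizes
gate_a = torch.randn(num_experts * rank, hidden)   # stacked gate_proj lora_a
down_a = torch.randn(num_experts * rank, hidden)   # stacked down_proj lora_a
up_a = torch.randn(num_experts * rank, hidden)     # stacked up_proj lora_a

lora_a = []
for g, d, u in zip(
    gate_a.chunk(num_experts, dim=0),
    down_a.chunk(num_experts, dim=0),
    up_a.chunk(num_experts, dim=0),
):
    lora_a.extend([g, d, u])      # per-expert order: [gate_i, down_i, up_i]

assert len(lora_a) == 3 * num_experts
assert lora_a[0].shape == (rank, hidden)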
@ -486,6 +526,7 @@ class LoRAModelManager:
        for module_name, module in self.model.named_modules(remove_duplicate=False):
            if isinstance(module, PPMissingLayer):
                continue

            if not self._match_target_modules(module_name):
                continue
            # A temporary approach for multimodal models to support LoRA
@ -549,7 +590,10 @@ class LoRAModelManager:
        new_module.set_mapping(self.punica_wrapper)

    def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
        assert isinstance(module, BaseLayerWithLoRA)
        assert isinstance(module, BaseLayerWithLoRA), (
            f"Module {module_name} must be a BaseLayerWithLoRA instance,"
            f" got {type(module)}"
        )
        self.modules[module_name] = module

    def create_dummy_lora(

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm.lora.ops.triton_ops.fused_moe_lora_op import fused_moe_lora
from vllm.lora.ops.triton_ops.lora_expand_op import lora_expand
from vllm.lora.ops.triton_ops.lora_kernel_metadata import LoRAKernelMeta
from vllm.lora.ops.triton_ops.lora_shrink_op import lora_shrink
@ -9,4 +10,5 @@ __all__ = [
    "lora_expand",
    "lora_shrink",
    "LoRAKernelMeta",
    "fused_moe_lora",
]
350
vllm/lora/ops/triton_ops/fused_moe_lora_op.py
Normal file
@ -0,0 +1,350 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch
import triton
import triton.language as tl

from vllm.utils.torch_utils import direct_register_custom_op

_LORA_PTR_DICT: dict[tuple[int, ...], torch.tensor] = {}


def _get_ptr(lora_weights: list[torch.Tensor], device: torch.device):
    """
    `_LORA_PTR_DICT` collects the required information during `profile_run`,
    After this, it remains constant and subsequent usage is through LUT.
    Refer to:
    https://github.com/triton-lang/triton/blob/release/3.1.x/python/tutorials/08-grouped-gemm.py
    """
    key = tuple(lora_weight.data_ptr() for lora_weight in lora_weights)

    if (ptr_tensor := _LORA_PTR_DICT.get(key)) is not None:
        return ptr_tensor

    tensor_ptrs = []
    for lora_weight in lora_weights:
        tensor_ptrs.append(lora_weight.data_ptr())
    ptr_tensor = torch.tensor(tensor_ptrs, device=device)

    _LORA_PTR_DICT[key] = ptr_tensor
    return _LORA_PTR_DICT.get(key)
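# Illustrative note (not part of this patch): because the cache key is the tuple
# of data_ptr() values, later calls that pass the same stacked LoRA weight tensors
# reuse the cached device tensor of pointers instead of rebuilding it, so lookups
# after the first (profile_run) call are allocation-free.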

@triton.jit
def _fused_moe_lora_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    topk_weights_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_padded_ptr,
    # Matrix dimensions
    N,
    K,
    EM,
    num_valid_tokens,
    num_experts,
    # The stride variables represent how much to increase the ptr by when
    # moving by 1 element in a particular dimension. E.g. `stride_am` is
    # how much to increase `a_ptr` by to get the element one row down
    # (A has M rows).
    stride_am,
    stride_ak,
    stride_bl,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_tl,
    stride_el,
    # Meta-parameters
    num_slice_a: tl.constexpr,
    num_slice_c: tl.constexpr,
    slice_a_size: tl.constexpr,
    slice_c_size: tl.constexpr,
    top_k: tl.constexpr,
    MUL_ROUTED_WEIGHT: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    slice_id = tl.program_id(axis=1)
    lora_idx = tl.program_id(axis=2)
    max_loras = tl.num_programs(axis=2)

    # calculate pid_m,pid_n
    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr + lora_idx)
    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
        return

    # get the expert_id to process curr shard
    ind = lora_idx * stride_el + pid_m
    expert_id = tl.load(expert_ids_ptr + ind)
    if expert_id == -1:
        return

    # get a_ptr,b_ptr,c_ptr
    cur_a_ptr = a_ptr + (slice_id % num_slice_a) * slice_a_size
    cur_b_ptr = tl.load(b_ptr + slice_id).to(tl.pointer_type(tl.bfloat16))
    cur_c_ptr = c_ptr + (slice_id % num_slice_c) * slice_c_size

    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
    token_ind = stride_tl * lora_idx + offs_token_id
    offs_token = tl.load(
        sorted_token_ids_ptr + token_ind, token_ind < max_loras * stride_tl, 0.0
    )
    token_mask = offs_token < num_valid_tokens

    # get a_ptrs,b_ptrs
    a_ptrs = cur_a_ptr + (
        offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
    )

    b_ptrs = (
        cur_b_ptr
        + lora_idx * stride_bl
        + expert_id * stride_be
        + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
    )

    # accumulator
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(
            a_ptrs,
            mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
            other=0.0,
        )
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        accumulator += tl.dot(a, b)
        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    if MUL_ROUTED_WEIGHT:
        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
        accumulator = accumulator * moe_weight[:, None]

    accumulator = accumulator.to(tl.bfloat16)
    # Write back the block of the output
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = cur_c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)

@torch.inference_mode()
def _fused_moe_lora(
    output: torch.Tensor,  # (num_tokens, top_k_num, N*len(lora_a_stacked),)
    qcurr_hidden_states: torch.Tensor,  # (num_tokens, K,)
    lora_a_stacked: list[
        torch.Tensor
    ],  # [(max_loras, num_experts, max_lora_rank, K,),...]
    lora_b_stacked: list[
        torch.Tensor
    ],  # [(max_loras, num_experts, N, max_lora_rank,),...]
    topk_weights: torch.Tensor,  # (num_tokens, top_k_num)
    sorted_token_ids: torch.Tensor,  # (max_loras, _)
    expert_ids: torch.Tensor,  # (max_loras, _ ,)
    num_tokens_post_padded: torch.Tensor,  # (max_loras, )
    max_lora_rank: int,
    top_k_num: int,
    block_size_m: int,
    block_size_n: int,
    block_size_k: int,
    group_size_m: int,
    mul_routed_weight: bool = False,
) -> None:
    assert len(lora_a_stacked) == len(lora_b_stacked) > 0
    assert (
        sorted_token_ids.dim()
        == expert_ids.dim()
        == topk_weights.dim()
        == qcurr_hidden_states.dim()
        == 2
    )
    assert (
        sorted_token_ids.shape[0]
        == expert_ids.shape[0]
        == num_tokens_post_padded.shape[0]
    )
    assert len(lora_b_stacked) * lora_b_stacked[0].shape[-2] == output.shape[-1]
    assert output.shape[0] == topk_weights.shape[0]
    assert top_k_num == topk_weights.shape[1]

    device = qcurr_hidden_states.device
    num_slices = len(lora_a_stacked)

    config = {
        "BLOCK_SIZE_M": block_size_m,
        "BLOCK_SIZE_N": block_size_n,
        "BLOCK_SIZE_K": block_size_k,
        "GROUP_SIZE_M": group_size_m,
    }

    w1_lora_a_stacked = lora_a_stacked[0]
    w1_lora_b_stacked = lora_b_stacked[0]
    num_experts = lora_a_stacked[0].shape[1]

    N = max_lora_rank
    M = topk_weights.shape[0]
    EM = sorted_token_ids.shape[1]
    K = qcurr_hidden_states.shape[1]
    num_tokens = M * top_k_num
    w1_output_dim_size = w1_lora_b_stacked.shape[2]

    lora_intermediate_cache1 = torch.zeros(
        (num_slices * M * top_k_num * (max_lora_rank + w1_output_dim_size)),
        dtype=torch.bfloat16,
        device=device,
    )

    # slices
    a_intermediate_size = num_slices * M * top_k_num * max_lora_rank
    a_intermediate_cache1 = lora_intermediate_cache1[:a_intermediate_size].view(
        num_slices, M, top_k_num, max_lora_rank
    )
    b_intermediate_cache1 = lora_intermediate_cache1[a_intermediate_size:].view(
        num_slices, M, top_k_num, w1_output_dim_size
    )

    b_ptr = _get_ptr(lora_a_stacked, device)

    grid = lambda META: (
        triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
        len(lora_a_stacked),
        lora_a_stacked[0].shape[0],
    )

    _fused_moe_lora_kernel[grid](
        qcurr_hidden_states,
        b_ptr,
        a_intermediate_cache1,
        topk_weights,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        N,
        K,
        EM,
        num_tokens,
        num_experts,
        qcurr_hidden_states.stride(0),
        qcurr_hidden_states.stride(1),
        w1_lora_a_stacked.stride(0),
        w1_lora_a_stacked.stride(1),
        w1_lora_a_stacked.stride(3),
        w1_lora_a_stacked.stride(2),
        a_intermediate_cache1.stride(2),
        a_intermediate_cache1.stride(3),
        sorted_token_ids.stride(0),
        expert_ids.stride(0),
        num_slice_a=1,
        num_slice_c=num_slices,
        slice_a_size=qcurr_hidden_states.numel(),
        slice_c_size=a_intermediate_cache1.numel() // num_slices,
        top_k=1 if mul_routed_weight else top_k_num,
        MUL_ROUTED_WEIGHT=False,
        **config,
    )

    b_ptr = _get_ptr(lora_b_stacked, device)
    K = max_lora_rank
    N = w1_output_dim_size

    # a_intermediate_cache1 = a_intermediate_cache1.view(
    #     M, -1, a_intermediate_cache1.shape[3]
    # )

    a_intermediate_cache1 = a_intermediate_cache1.view(
        -1, a_intermediate_cache1.shape[3]
    )

    grid = lambda META: (
        triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
        len(lora_b_stacked),
        lora_b_stacked[0].shape[0],
    )
    _fused_moe_lora_kernel[grid](
        a_intermediate_cache1,
        b_ptr,
        b_intermediate_cache1,
        topk_weights,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        N,
        K,
        EM,
        num_tokens,
        num_experts,
        a_intermediate_cache1.stride(0),
        a_intermediate_cache1.stride(1),
        w1_lora_b_stacked.stride(0),
        w1_lora_b_stacked.stride(1),
        w1_lora_b_stacked.stride(3),
        w1_lora_b_stacked.stride(2),
        b_intermediate_cache1.stride(2),
        b_intermediate_cache1.stride(3),
        sorted_token_ids.stride(0),
        expert_ids.stride(0),
        num_slice_a=num_slices,
        num_slice_c=num_slices,
        slice_a_size=a_intermediate_cache1.numel() // num_slices,
        slice_c_size=b_intermediate_cache1.numel() // num_slices,
        top_k=1,
        MUL_ROUTED_WEIGHT=mul_routed_weight,
        **config,
    )
    for i in range(num_slices):
        output[:, :, i * N : (i + 1) * N] += b_intermediate_cache1[i]


def _fused_moe_lora_fake(
    output: torch.Tensor,
    qcurr_hidden_states: torch.Tensor,
    lora_a_stacked: list[torch.Tensor],
    lora_b_stacked: list[torch.Tensor],
    topk_weights: torch.Tensor,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    max_lora_rank: int,
    top_k_num: int,
    block_size_m: int,
    block_size_n: int,
    block_size_k: int,
    group_size_m: int,
    mul_routed_weight: bool = False,
) -> None:
    return


try:
    direct_register_custom_op(
        op_name="fused_moe_lora",
        op_func=_fused_moe_lora,
        mutates_args=["output"],
        fake_impl=_fused_moe_lora_fake,
    )
    fused_moe_lora = torch.ops.vllm.fused_moe_lora

except AttributeError:
    fused_moe_lora = _fused_moe_lora
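The wrapper above launches the same Triton kernel twice: a shrink pass with lora_a_stacked (hidden size K down to max_lora_rank) into a_intermediate_cache1, then an expand pass with lora_b_stacked (rank back up to each slice's output width), after which each slice is accumulated into its column range of output. A rough shape walk-through under hypothetical sizes, not taken from this patch:

# M = 6 tokens, top_k = 2, K = 64, max_lora_rank = 16, slice width N_out = 32,
# num_slices = 3 (gate / down / up)
# pass 1 (lora_a): a_intermediate_cache1 has shape (3, 6, 2, 16)
# pass 2 (lora_b): b_intermediate_cache1 has shape (3, 6, 2, 32)
# accumulation:   slice i is added into output[:, :, i * 32 : (i + 1) * 32]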

@ -448,3 +448,42 @@ class PunicaWrapperBase(PunicaWrapperABC):
        """
        # TODO: implement it based on torch ops
        raise NotImplementedError

    def moe_lora_align_block_size(
        self,
        topk_ids: torch.Tensor,
        num_tokens: int,
        block_size: int,
        num_experts: int,
        max_loras: int,
        expert_map: torch.Tensor | None = None,
        pad_sorted_ids: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Aligns tokens and experts into block-sized chunks for LoRA-based
        mixture-of-experts (MoE) execution.
        """
        # TODO: implement it based on torch ops
        raise NotImplementedError

    def add_lora_fused_moe(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        lora_a_stacked: list[torch.Tensor],
        lora_b_stacked: list[torch.Tensor],
        topk_weights: torch.Tensor,
        sorted_token_ids: torch.Tensor,
        expert_ids: torch.Tensor,
        num_tokens_post_padded: torch.Tensor,
        max_lora_rank: int,
        top_k_num: int,
        config,
        mul_routed_weight=False,
    ):
        """
        Performs a fused forward computation for LoRA of
        Mixture-of-Experts (MoE) layer.
        """
        # TODO: implement it based on torch ops
        raise NotImplementedError

@ -12,10 +12,18 @@ from typing import final
import torch

from vllm.lora.layers import LoRAMapping
from vllm.triton_utils import HAS_TRITON
from vllm.triton_utils import HAS_TRITON, triton
from vllm.utils import round_up

if HAS_TRITON:
    from vllm.lora.ops.triton_ops import LoRAKernelMeta, lora_expand, lora_shrink
    from vllm.lora.ops.triton_ops import (
        LoRAKernelMeta,
        fused_moe_lora,
        lora_expand,
        lora_shrink,
    )

from vllm import _custom_ops as ops

from .punica_base import PunicaWrapperBase

@ -289,3 +297,91 @@ class PunicaWrapperGPU(PunicaWrapperBase):
            add_inputs=True,
        )
        y = y.view_as(y_org)

    def moe_lora_align_block_size(
        self,
        topk_ids: torch.Tensor,
        num_tokens: int,
        block_size: int,
        num_experts: int,
        max_loras: int,
        expert_map: torch.Tensor | None = None,
        pad_sorted_ids: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Aligns tokens and experts into block-sized chunks for LoRA-based
        mixture-of-experts (MoE) execution.
        """
        max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
        if pad_sorted_ids:
            max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
        sorted_ids = torch.empty(
            (max_loras * max_num_tokens_padded,),
            dtype=torch.int32,
            device=topk_ids.device,
        )
        max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
        # Expert ids must be set default to -1 to prevent a blank block
        expert_ids = torch.empty(
            (max_loras * max_num_m_blocks,),
            dtype=torch.int32,
            device=topk_ids.device,
        )
        num_tokens_post_pad = torch.empty(
            (max_loras), dtype=torch.int32, device=topk_ids.device
        )

        (token_lora_mapping, _, _, _, _, _) = self.token_mapping_meta.meta_args(
            num_tokens
        )

        ops.moe_lora_align_block_size(
            topk_ids,
            token_lora_mapping,
            num_experts,
            block_size,
            max_loras,
            sorted_ids,
            expert_ids,
            num_tokens_post_pad,
        )
        if expert_map is not None:
            expert_ids = expert_map[expert_ids]

        return sorted_ids, expert_ids, num_tokens_post_pad
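    # Worked example (hypothetical numbers, not from this patch): with
    # topk_ids.numel() = 8, num_experts = 4, block_size = 16 and max_loras = 2,
    # max_num_tokens_padded = 8 + 4 * 15 = 68 and max_num_m_blocks = cdiv(68, 16) = 5,
    # so sorted_ids holds 2 * 68 entries and expert_ids 2 * 5 entries, one
    # contiguous stripe per LoRA adapter, which the fused kernel walks via
    # stride_tl / stride_el.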

    def add_lora_fused_moe(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        lora_a_stacked: list[torch.Tensor],
        lora_b_stacked: list[torch.Tensor],
        topk_weights: torch.Tensor,
        sorted_token_ids: torch.Tensor,
        expert_ids: torch.Tensor,
        num_tokens_post_padded: torch.Tensor,
        max_lora_rank: int,
        top_k_num: int,
        config,
        mul_routed_weight=False,
    ):
        """
        Performs a fused forward computation for LoRA of Mixture-of-Experts (MoE) layer.
        """
        fused_moe_lora(
            y,
            x,
            lora_a_stacked,
            lora_b_stacked,
            topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            max_lora_rank,
            top_k_num,
            config["BLOCK_SIZE_M"],
            config["BLOCK_SIZE_N"],
            config["BLOCK_SIZE_K"],
            config["GROUP_SIZE_M"],
            mul_routed_weight,
        )

@ -23,6 +23,7 @@ from vllm.lora.layers import (
    BaseLayerWithLoRA,
    ColumnParallelLinearWithLoRA,
    ColumnParallelLinearWithShardedLoRA,
    FusedMoEWithLoRA,
    LogitsProcessorWithLoRA,
    MergedColumnParallelLinearWithLoRA,
    MergedColumnParallelLinearWithShardedLoRA,
@ -35,7 +36,9 @@ from vllm.lora.layers import (
    RowParallelLinearWithShardedLoRA,
    VocabParallelEmbeddingWithLoRA,
)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.utils import get_moe_expert_mapping, get_packed_modules_mapping

if TYPE_CHECKING:
    from vllm.model_executor.layers.logits_processor import LogitsProcessor
@ -58,9 +61,18 @@ _all_lora_classes: set[type[BaseLayerWithLoRA]] = {
    MergedColumnParallelLinearWithShardedLoRA,
    MergedQKVParallelLinearWithShardedLoRA,
    RowParallelLinearWithShardedLoRA,
    FusedMoEWithLoRA,
}


def is_moe_model(model: nn.Module) -> bool:
    """Checks if the model contains FusedMoE layers and warns the user."""
    if any(isinstance(module, FusedMoE) for module in model.modules()):
        logger.info_once("MoE model detected. Using fused MoE LoRA implementation.")
        return True
    return False


def from_layer(
    layer: nn.Module,
    max_loras: int,
@ -205,6 +217,9 @@ def get_supported_lora_modules(model: nn.Module) -> list[str]:
        if isinstance(module, (LinearBase,)):
            supported_lora_modules.add(name.split(".")[-1])

        if isinstance(module, (FusedMoE,)):
            supported_lora_modules.add(name.split(".")[-1])

    return list(supported_lora_modules)


@ -252,3 +267,27 @@ def get_adapter_absolute_path(lora_path: str) -> str:
        return lora_path

    return local_snapshot_path


def process_packed_modules_mapping(model: nn.Module) -> dict[str, list[str]]:
    if is_moe_model(model):
        if moe_packed_mapping := get_moe_expert_mapping(model):
            # This method generates and returns a dictionary mapping packed module
            # names to lists of their corresponding submodule names. It includes
            # both static mappings and dynamic mappings for expert layers, where
            # the expert indices are expanded based on the configured number
            # of routed experts.
            packed_modules_mapping = get_packed_modules_mapping(model)

            packed_modules_mapping["experts"] = [
                weight_name.rstrip(".") for _, weight_name, _, _ in moe_packed_mapping
            ]

            return packed_modules_mapping
        else:
            raise AttributeError(
                "To support LoRA for MoE model, "
                "'get_expert_mapping' must be implemented"
            )
    else:
        return get_packed_modules_mapping(model)
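# Illustrative shape of the result (module names are hypothetical, not from this
# patch): the model's static mapping, e.g. {"qkv_proj": ["q_proj", "k_proj",
# "v_proj"], ...}, gains one extra "experts" key whose value is the list of
# per-expert weight names reported by get_moe_expert_mapping(model), with any
# trailing "." stripped.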

@ -94,7 +94,8 @@ class WorkerLoRAManager:
                expected_lora_modules.extend(packed_modules_mapping[module])
            else:
                expected_lora_modules.append(module)

            if module == "experts":
                expected_lora_modules.append(module)
        expected_lora_modules = list(set(expected_lora_modules))
        lora_path = get_adapter_absolute_path(lora_request.lora_path)

@ -2,6 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Fused MoE utilities for GPTQ."""

from collections.abc import Callable

import torch

import vllm._custom_ops as ops
@ -11,6 +13,9 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
    batched_moe_align_block_size,
    moe_align_block_size,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
    TopKWeightAndReduceDelegate,
    TopKWeightAndReduceNoOP,
@ -24,6 +29,21 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
from vllm.scalar_type import ScalarType, scalar_types


def default_activation_func(
    activation: str, output: torch.Tensor, input: torch.Tensor
) -> None:
    if activation == "silu":
        torch.ops._C.silu_and_mul(output, input)
    elif activation == "swigluoai":
        # alpha = 1.702, limit = 7.0
        torch.ops._C.swigluoai_and_mul(output, input)
    else:
        raise ValueError(
            f"Unsupported activation: {activation}. "
            "Only silu and swigluoai activations are supported."
        )
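# Usage sketch (hypothetical caller, not from this patch): any hook with the same
# signature can be injected, e.g.
#
#   def my_activation(activation: str, out: torch.Tensor, inp: torch.Tensor) -> None:
#       default_activation_func(activation, out, inp)  # plus any extra fused epilogue
#
#   fused_marlin_moe(..., activation="silu", activation_func=my_activation)
#
# MarlinExperts below forwards activation_func=self.activation instead of relying
# on this default.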


def _fused_marlin_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
@ -36,12 +56,15 @@
    num_topk: int,
    quant_type: ScalarType,
    apply_router_weight_on_input: bool,
    activation: str,
    expert_map: torch.Tensor | None,
    block_size_m: int,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    activation: str = "silu",
    activation_func: Callable[
        [str, torch.Tensor, torch.Tensor], None
    ] = default_activation_func,
    global_scale1: torch.Tensor | None = None,
    global_scale2: torch.Tensor | None = None,
    g_idx1: torch.Tensor | None = None,
@ -118,20 +141,9 @@ def _fused_marlin_moe(
        is_zp_float=False,
    )

    if activation == "silu":
        torch.ops._C.silu_and_mul(
            intermediate_cache2, intermediate_cache1.view(-1, 2 * N)
        )
    elif activation == "swigluoai":
        # alpha = 1.702, limit = 7.0
        torch.ops._C.swigluoai_and_mul(
            intermediate_cache2, intermediate_cache1.view(-1, 2 * N)
        )
    else:
        raise ValueError(
            f"Unsupported activation: {activation}. "
            "Only silu and swigluoai activations are supported."
        )
    activation_func(
        activation, intermediate_cache2, intermediate_cache1.view(-1, 2 * N)
    )

    if output is None:
        output = intermediate_cache3
@ -185,7 +197,11 @@
    quant_type_id: int,
    apply_router_weight_on_input: bool = False,
    global_num_experts: int = -1,
    activation: str | None = "silu",
    activation: str = "silu",
    activation_func: Callable[
        [str, torch.Tensor, torch.Tensor], None
    ] = default_activation_func,
    moe_sum: Callable[[torch.Tensor, torch.Tensor], None] | None = None,
    expert_map: torch.Tensor | None = None,
    global_scale1: torch.Tensor | None = None,
    global_scale2: torch.Tensor | None = None,
@ -290,12 +306,13 @@ def fused_marlin_moe(
        num_topk=topk,
        quant_type=quant_type,
        apply_router_weight_on_input=apply_router_weight_on_input,
        activation=activation,
        expert_map=expert_map,
        block_size_m=block_size_m,
        sorted_token_ids=sorted_token_ids,
        expert_ids=expert_ids,
        num_tokens_post_padded=num_tokens_post_padded,
        activation=activation,
        activation_func=activation_func,
        global_scale1=global_scale1,
        global_scale2=global_scale2,
        g_idx1=g_idx1,
@ -317,7 +334,10 @@ def fused_marlin_moe(
    else:
        output = torch.empty_like(hidden_states)

    return torch.sum(moe_output.view(-1, topk, K), dim=1, out=output)
    if moe_sum is None:
        return torch.sum(moe_output.view(-1, topk, K), dim=1, out=output)
    else:
        return moe_sum(moe_output, output)


def batched_fused_marlin_moe(
@ -600,6 +620,8 @@ class MarlinExperts(MarlinExpertsBase):
            apply_router_weight_on_input=apply_router_weight_on_input,
            global_num_experts=global_num_experts,
            activation=activation,
            activation_func=self.activation,
            moe_sum=self.moe_sum,
            expert_map=expert_map,
            output=output,
            # Workspaces are swapped in workspace_shapes() to account for proper
@ -608,6 +630,19 @@ class MarlinExperts(MarlinExpertsBase):
            intermediate_cache2=workspace13,
        )

    def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:
        ops.moe_sum(input, output)


def modular_marlin_fused_moe(
    quant_config: FusedMoEQuantConfig, shared_experts: torch.nn.Module | None = None
) -> mk.FusedMoEModularKernel:
    return mk.FusedMoEModularKernel(
        MoEPrepareAndFinalizeNoEP(),
        MarlinExperts(quant_config),
        shared_experts,
    )


class BatchedMarlinExperts(MarlinExpertsBase):
    def __init__(

@ -2135,13 +2135,18 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
            B_bias=self.w2_bias,
        )

        ops.moe_sum(intermediate_cache3, output)
        # separate function is required for MoE + LoRA
        self.moe_sum(intermediate_cache3, output)

    def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:
        ops.moe_sum(input, output)
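    # Rationale sketch (not from this patch): keeping the reduction behind a method
    # lets a LoRA-aware subclass add the adapter contribution before the top-k sum,
    # along the lines of
    #
    #   class TritonExpertsWithLoRA(TritonExperts):      # hypothetical name
    #       def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:
    #           add_lora_delta(input)                    # hypothetical LoRA hook
    #           ops.moe_sum(input, output)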


def modular_triton_fused_moe(
    quant_config: FusedMoEQuantConfig,
    quant_config: FusedMoEQuantConfig, shared_experts: torch.nn.Module | None = None
) -> mk.FusedMoEModularKernel:
    return mk.FusedMoEModularKernel(
        MoEPrepareAndFinalizeNoEP(),
        TritonExperts(quant_config),
        shared_experts,
    )

@ -557,6 +557,9 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
            torch.ops._C.silu_and_mul(output, input)
        elif activation == "gelu":
            torch.ops._C.gelu_and_mul(output, input)
        elif activation == "swigluoai":
            # alpha = 1.702, limit = 7.0
            torch.ops._C.swigluoai_and_mul(output, input)
        else:
            raise ValueError(f"Unsupported FusedMoe activation: {activation}")

@ -1313,6 +1313,17 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, SupportsLoR
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        return SharedFusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts,
            num_redundant_experts=0,
        )
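    # Each returned entry (illustrative, values not from this patch) has the form
    # (param_name, weight_name, expert_id, shard_id), where shard_id is "w1" for
    # gate_proj, "w2" for down_proj and "w3" for up_proj, and expert_id indexes the
    # routed expert whose checkpoint weights map onto the fused parameter.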

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)

@ -32,7 +32,7 @@ from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.sequence import IntermediateTensors
from vllm.utils import cdiv

from .interfaces import SupportsEagle3, SupportsPP
from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
@ -627,7 +627,7 @@ class GptOssModel(nn.Module):
    )


class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3):
class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA):
    packed_modules_mapping = {"qkv": ["q_proj", "k_proj", "v_proj"]}

    hf_to_vllm_mapper = WeightsMapper(
@ -696,6 +696,17 @@ class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3):
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        # Params for weights, weight scales, activation scales
        # (param_name, weight_name, expert_id, shard_id)
        return FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_local_experts,
            num_redundant_experts=0,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(
            self,

@ -49,7 +49,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsPP
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (
    AutoWeightsLoader,
    is_pp_missing_parameter,
@ -349,8 +349,6 @@ class OlmoeModel(nn.Module):
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        params_dict = dict(self.named_parameters())
@ -433,17 +431,13 @@ class OlmoeModel(nn.Module):
        return loaded_params


class OlmoeForCausalLM(nn.Module, SupportsPP):
class OlmoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    ]
    }

    def __init__(