From 5f6cbf60d667408ef1bf3871c0b7563effa6094a Mon Sep 17 00:00:00 2001 From: Chen Wu <72850361+CNTRYROA@users.noreply.github.com> Date: Tue, 21 Oct 2025 11:01:37 +0800 Subject: [PATCH] [Feature][Kernel]FusedMoE LoRA (#21229) Signed-off-by: wuchen Signed-off-by: banjuede Signed-off-by: Chen Wu Signed-off-by: Danielle Robinson Signed-off-by: Jee Jee Li Signed-off-by: bk-201 Co-authored-by: wuchen Co-authored-by: Nathan Van Gheem Co-authored-by: banjuede Co-authored-by: Danielle Robinson Co-authored-by: Jee Jee Li Co-authored-by: bk-201 --- .buildkite/test-pipeline.yaml | 8 +- CMakeLists.txt | 1 + csrc/moe/moe_lora_align_sum_kernels.cu | 173 ++++++++ csrc/moe/moe_ops.h | 7 + csrc/moe/torch_bindings.cpp | 12 + tests/lora/conftest.py | 20 + tests/lora/test_deepseekv2_tp.py | 97 +++++ tests/lora/test_fused_moe_lora_kernel.py | 287 ++++++++++++ tests/lora/test_gptoss.py | 52 +++ tests/lora/test_moe_lora_align_sum.py | 90 ++++ tests/lora/test_olmoe_tp.py | 109 +++++ tests/lora/test_qwen3moe_tp.py | 111 +++++ vllm/_custom_ops.py | 22 + vllm/lora/layers/__init__.py | 2 + vllm/lora/layers/fused_moe.py | 410 ++++++++++++++++++ vllm/lora/models.py | 84 +++- vllm/lora/ops/triton_ops/__init__.py | 2 + vllm/lora/ops/triton_ops/fused_moe_lora_op.py | 350 +++++++++++++++ vllm/lora/punica_wrapper/punica_base.py | 39 ++ vllm/lora/punica_wrapper/punica_gpu.py | 100 ++++- vllm/lora/utils.py | 39 ++ vllm/lora/worker_manager.py | 3 +- .../layers/fused_moe/fused_marlin_moe.py | 71 ++- .../layers/fused_moe/fused_moe.py | 9 +- .../layers/fused_moe/modular_kernel.py | 3 + vllm/model_executor/models/deepseek_v2.py | 11 + vllm/model_executor/models/gpt_oss.py | 15 +- vllm/model_executor/models/olmoe.py | 12 +- 28 files changed, 2084 insertions(+), 55 deletions(-) create mode 100644 csrc/moe/moe_lora_align_sum_kernels.cu create mode 100644 tests/lora/test_deepseekv2_tp.py create mode 100644 tests/lora/test_fused_moe_lora_kernel.py create mode 100644 tests/lora/test_gptoss.py create mode 100644 tests/lora/test_moe_lora_align_sum.py create mode 100644 tests/lora/test_olmoe_tp.py create mode 100644 tests/lora/test_qwen3moe_tp.py create mode 100644 vllm/lora/layers/fused_moe.py create mode 100644 vllm/lora/ops/triton_ops/fused_moe_lora_op.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 3d8bbed56bd50..984e2108f88e6 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -384,7 +384,12 @@ steps: --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT \ --ignore=lora/test_chatglm3_tp.py \ --ignore=lora/test_llama_tp.py \ - --ignore=lora/test_llm_with_multi_loras.py + --ignore=lora/test_llm_with_multi_loras.py \ + --ignore=lora/test_olmoe_tp.py \ + --ignore=lora/test_deepseekv2_tp.py \ + --ignore=lora/test_gptoss.py \ + --ignore=lora/test_qwen3moe_tp.py + parallelism: 4 - label: PyTorch Compilation Unit Tests # 15min @@ -1065,6 +1070,7 @@ steps: - pytest -v -s -x lora/test_chatglm3_tp.py - pytest -v -s -x lora/test_llama_tp.py - pytest -v -s -x lora/test_llm_with_multi_loras.py + - pytest -v -s -x lora/test_olmoe_tp.py - label: Weight Loading Multiple GPU Test # 33min diff --git a/CMakeLists.txt b/CMakeLists.txt index 005590445361a..46630af89f099 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -883,6 +883,7 @@ target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1) set(VLLM_MOE_EXT_SRC "csrc/moe/torch_bindings.cpp" "csrc/moe/moe_align_sum_kernels.cu" + "csrc/moe/moe_lora_align_sum_kernels.cu" "csrc/moe/topk_softmax_kernels.cu") if(VLLM_GPU_LANG 
STREQUAL "CUDA") diff --git a/csrc/moe/moe_lora_align_sum_kernels.cu b/csrc/moe/moe_lora_align_sum_kernels.cu new file mode 100644 index 0000000000000..1d25844bd5263 --- /dev/null +++ b/csrc/moe/moe_lora_align_sum_kernels.cu @@ -0,0 +1,173 @@ +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "../cuda_compat.h" +#include "../dispatch_utils.h" +#include "core/math.hpp" + +namespace { + +__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, + int32_t col) { + return row * total_col + col; +} + +} // namespace + +// TODO: Refactor common parts with moe_align_sum_kernels +template +__global__ void moe_lora_align_sum_kernel( + scalar_t* __restrict__ topk_ids, int32_t* token_lora_mapping, + int64_t block_size, int num_experts, int max_loras, size_t numel, + int max_num_tokens_padded, int max_num_m_blocks, + int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids, + int topk_num, int32_t* total_tokens_post_pad) { + const size_t tokens_per_thread = div_ceil(numel, blockDim.x); + const size_t start_idx = threadIdx.x * tokens_per_thread; + + int lora_id = blockIdx.x; + extern __shared__ int32_t shared_mem[]; + int32_t* cumsum = shared_mem; + token_cnts_t* tokens_cnts = (token_cnts_t*)(shared_mem + num_experts + 1); + + // Initialize sorted_token_ids with numel + for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += blockDim.x) { + sorted_token_ids[lora_id * max_num_tokens_padded + it] = numel; + } + + // Initialize expert_ids with -1 + for (size_t it = threadIdx.x; it < max_num_m_blocks; it += blockDim.x) { + expert_ids[lora_id * max_num_m_blocks + it] = -1; + } + + // Initialize total_tokens_post_pad with 0 + if (threadIdx.x == 0) { + total_tokens_post_pad[lora_id] = 0; + } + + for (int i = 0; i < num_experts; ++i) { + tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0; + } + + for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { + int mask = token_lora_mapping[i / topk_num] == lora_id; + int idx = index(num_experts, threadIdx.x + 1, topk_ids[i]); + tokens_cnts[idx] += mask; + } + + __syncthreads(); + + // For each expert we accumulate the token counts from the different threads. + if (threadIdx.x < num_experts) { + tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0; + for (int i = 1; i <= blockDim.x; ++i) { + tokens_cnts[index(num_experts, i, threadIdx.x)] += + tokens_cnts[index(num_experts, i - 1, threadIdx.x)]; + } + } + + __syncthreads(); + + // We accumulate the token counts of all experts in thread 0. + if (threadIdx.x == 0) { + cumsum[0] = 0; + for (int i = 1; i <= num_experts; ++i) { + cumsum[i] = cumsum[i - 1] + + div_ceil(tokens_cnts[index(num_experts, blockDim.x, i - 1)], + block_size) * + block_size; + } + total_tokens_post_pad[lora_id] = static_cast(cumsum[num_experts]); + } + + __syncthreads(); + + /** + * For each expert, each thread processes the tokens of the corresponding + * blocks and stores the corresponding expert_id for each block. 
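+ * For example, with block_size = 16 and cumsum = [0, 32, 48, ...], the
+ * blocks at indices 0 and 1 record expert_id 0 and the block at index 2
+ * records expert_id 1 within this lora_id's slice of expert_ids
+ * (illustrative values only).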
+ */ + if (threadIdx.x < num_experts) { + for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; + i += block_size) { + expert_ids[index(max_num_m_blocks, lora_id, i / block_size)] = + threadIdx.x; + } + } + + for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { + int32_t expert_id = topk_ids[i]; + /** The cumsum[expert_id] stores the starting index of the tokens that the + * expert with expert_id needs to process, and + * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens + * processed by the expert with expert_id within the current thread's token + * shard. + */ + int32_t rank_post_pad = + tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + + cumsum[expert_id]; + + int mask = (int)token_lora_mapping[i / topk_num] == lora_id; + atomicAdd( + &sorted_token_ids[index(max_num_tokens_padded, lora_id, rank_post_pad)], + (i - numel) * mask); + tokens_cnts[index(num_experts, threadIdx.x, expert_id)] += mask; + } +} + +void moe_lora_align_block_size(torch::Tensor topk_ids, + torch::Tensor token_lora_mapping, + int64_t num_experts, int64_t block_size, + int64_t max_loras, + torch::Tensor sorted_token_ids, + torch::Tensor expert_ids, + torch::Tensor num_tokens_post_pad) { + const int topk_num = topk_ids.size(1); + + int max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1); + + TORCH_CHECK(block_size > 0, "block_size should be greater than 0. "); + max_num_tokens_padded = round_to_next_multiple_of( + max_num_tokens_padded, static_cast(block_size)); + int max_num_m_blocks = div_ceil(max_num_tokens_padded, block_size); + + int device_max_shared_mem; + auto dev = topk_ids.get_device(); + cudaDeviceGetAttribute(&device_max_shared_mem, + cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + const int32_t num_thread = max((int32_t)num_experts, 128); // WARP_SIZE, + TORCH_CHECK(num_thread <= 1024, + "num_thread must be less than 1024, " + "and fallback is not implemented yet."); + const int32_t shared_mem = (num_thread + 1) * num_experts * sizeof(int32_t) + + (num_experts + 1) * sizeof(int32_t); + + if (shared_mem > device_max_shared_mem) { + TORCH_CHECK(false, + "Shared memory usage exceeds device limit, and global memory " + "fallback is not implemented yet."); + } + + VLLM_DISPATCH_INTEGRAL_TYPES( + topk_ids.scalar_type(), "moe_lora_align_sum_kernel", [&] { + dim3 blockDim(num_thread); + auto kernel = moe_lora_align_sum_kernel; + AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( + (void*)kernel, shared_mem)); + kernel<<>>( + topk_ids.data_ptr(), + token_lora_mapping.data_ptr(), block_size, num_experts, + max_loras, topk_ids.numel(), max_num_tokens_padded, + max_num_m_blocks, sorted_token_ids.data_ptr(), + expert_ids.data_ptr(), topk_num, + num_tokens_post_pad.data_ptr()); + }); +} \ No newline at end of file diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index 2a170249b9177..45dd1824ded47 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -20,6 +20,13 @@ void batched_moe_align_block_size(int64_t max_tokens_per_batch, torch::Tensor expert_ids, torch::Tensor num_tokens_post_pad); +void moe_lora_align_block_size(torch::Tensor topk_ids, + torch::Tensor token_lora_mapping, + int64_t num_experts, int64_t block_size, + int64_t max_loras, + torch::Tensor sorted_token_ids, + torch::Tensor expert_ids, + torch::Tensor num_tokens_post_pad); #ifndef USE_ROCM torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output, torch::Tensor 
b_qweight, torch::Tensor b_scales, diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 8377575ea19ff..f110683af72d3 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -33,6 +33,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { m.impl("batched_moe_align_block_size", torch::kCUDA, &batched_moe_align_block_size); + // Aligning the number of tokens to be processed by each expert such + // that it is divisible by the block size. + m.def( + "moe_lora_align_block_size(Tensor topk_ids," + " Tensor token_lora_mapping," + " int num_experts," + " int block_size, int max_loras, " + " Tensor !sorted_token_ids," + " Tensor !experts_ids," + " Tensor !num_tokens_post_pad) -> () "); + m.impl("moe_lora_align_block_size", torch::kCUDA, &moe_lora_align_block_size); + #ifndef USE_ROCM m.def( "moe_wna16_gemm(Tensor input, Tensor! output, Tensor b_qweight, " diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index f805a74a4dba8..2a688216f25ec 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -230,6 +230,26 @@ def tinyllama_lora_files(): return snapshot_download(repo_id="jashing/tinyllama-colorist-lora") +@pytest.fixture(scope="session") +def deepseekv2_lora_files(): + return snapshot_download(repo_id="wuchen01/DeepSeek-V2-Lite-Chat-All-LoRA") + + +@pytest.fixture(scope="session") +def gptoss20b_lora_files(): + return snapshot_download(repo_id="LevinZheng/gpt-oss-20b-lora-adapter") + + +@pytest.fixture(scope="session") +def qwen3moe_lora_files(): + return snapshot_download(repo_id="jeeejeee/qwen3-moe-text2sql-spider") + + +@pytest.fixture(scope="session") +def olmoe_lora_files(): + return snapshot_download(repo_id="jeeejeee/olmoe-instruct-text2sql-spider") + + @pytest.fixture def reset_default_device(): """ diff --git a/tests/lora/test_deepseekv2_tp.py b/tests/lora/test_deepseekv2_tp.py new file mode 100644 index 0000000000000..98b7e6333f300 --- /dev/null +++ b/tests/lora/test_deepseekv2_tp.py @@ -0,0 +1,97 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import vllm +from vllm.lora.request import LoRARequest + +from ..utils import multi_gpu_test + +MODEL_PATH = "deepseek-ai/DeepSeek-V2-Lite-Chat" + +PROMPT_TEMPLATE = "<|begin▁of▁sentence|>You are a helpful assistant.\n\nUser: {context}\n\nAssistant:" # noqa: E501 + + +def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int): + prompts = [ + PROMPT_TEMPLATE.format(context="Who are you?"), + ] + sampling_params = vllm.SamplingParams(temperature=0, max_tokens=64) + outputs = llm.generate( + prompts, + sampling_params, + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None, + ) + # Print the outputs. + generated_texts: list[str] = [] + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text.strip() + generated_texts.append(generated_text) + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + # return generated_texts + expected_lora_output = [ + "I am \u5f20\u5b50\u8c6a, an AI assistant developed by \u9648\u58eb\u680b.", # noqa: E501 + ] + for i in range(len(expected_lora_output)): + assert generated_texts[i].startswith(expected_lora_output[i]) + + +def test_deepseekv2_lora(deepseekv2_lora_files): + # We enable enforce_eager=True here to reduce VRAM usage for lora-test CI, + # Otherwise, the lora-test will fail due to CUDA OOM. 
+ llm = vllm.LLM( + MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=4, + enforce_eager=True, + trust_remote_code=True, + enable_chunked_prefill=True, + ) + generate_and_test(llm, deepseekv2_lora_files, 1) + + +def test_deepseekv2(deepseekv2_lora_files): + # We enable enforce_eager=True here to reduce VRAM usage for lora-test CI, + # Otherwise, the lora-test will fail due to CUDA OOM. + llm = vllm.LLM( + MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=4, + enforce_eager=True, + trust_remote_code=True, + ) + generate_and_test(llm, deepseekv2_lora_files, 1) + + +@multi_gpu_test(num_gpus=2) +def test_deepseekv2_tp2(deepseekv2_lora_files): + # We enable enforce_eager=True here to reduce VRAM usage for lora-test CI, + # Otherwise, the lora-test will fail due to CUDA OOM. + llm = vllm.LLM( + MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=4, + enforce_eager=True, + trust_remote_code=True, + tensor_parallel_size=2, + ) + generate_and_test(llm, deepseekv2_lora_files, 2) + + +@multi_gpu_test(num_gpus=4) +def test_deepseekv2_tp4(deepseekv2_lora_files): + # We enable enforce_eager=True here to reduce VRAM usage for lora-test CI, + # Otherwise, the lora-test will fail due to CUDA OOM. + llm = vllm.LLM( + MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=4, + enforce_eager=True, + trust_remote_code=True, + tensor_parallel_size=4, + ) + generate_and_test(llm, deepseekv2_lora_files, 2) diff --git a/tests/lora/test_fused_moe_lora_kernel.py b/tests/lora/test_fused_moe_lora_kernel.py new file mode 100644 index 0000000000000..052e52c7bc1ba --- /dev/null +++ b/tests/lora/test_fused_moe_lora_kernel.py @@ -0,0 +1,287 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import random + +import pytest +import torch + +from vllm import _custom_ops as ops +from vllm.lora.ops.triton_ops import fused_moe_lora +from vllm.platforms import current_platform + + +@pytest.fixture(autouse=True) +def reset_device(reset_default_device): + pass + + +def round_up(x, base): + return ((x + base - 1) // base) * base + + +def CEILDIV(x, y): + return (x + y - 1) // y + + +def assign_loras_to_tokens(num_tokens: int, num_sequences: int, max_loras: int): + """ + Split `num_tokens` into `num_sequences` sequences. + Each sequence randomly selects 1 LoRA index from [0, max_loras), + and all tokens in that sequence are assigned this LoRA index. + + Args: + num_tokens (int): Total number of tokens. + num_sequences (int): Number of sequences to split the tokens into. + max_loras (int): Total number of available LoRA modules. + + Returns: + torch.Tensor: 1D tensor of shape [num_tokens], where each value + is the LoRA index assigned to that token. 
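+ Example (illustrative, since LoRA ids are drawn at random): with
+ num_tokens=5, num_sequences=2 and max_loras=3, one possible result is
+ tensor([2, 2, 2, 0, 0], dtype=torch.int32).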
+ """ + assert num_sequences > 0 and max_loras > 0 + assert num_tokens >= num_sequences, "num_tokens must be >= num_sequences" + + # Compute token distribution per sequence (distribute remainder evenly) + tokens_per_seq = num_tokens // num_sequences + remainder = num_tokens % num_sequences + + token_lora_mapping = torch.empty(num_tokens, dtype=torch.int32) + + start = 0 + for seq_idx in range(num_sequences): + # Determine the token range for this sequence + end = start + tokens_per_seq + (1 if seq_idx < remainder else 0) + + # Randomly select one LoRA ID for this sequence + lora_id = random.randint(0, max_loras - 1) + + # Assign the same LoRA ID to all tokens in this sequence + token_lora_mapping[start:end] = lora_id + + start = end + + return token_lora_mapping + + +def assign_experts_to_tokens(num_tokens: int, num_experts: int, top_k_num: int): + """ + For each token, randomly select `top_k_num` distinct experts out of `num_experts`, + and assign normalized random weights that sum to 1. + + Args: + num_tokens (int): Total number of tokens. + num_experts (int): Total number of available experts. + top_k_num (int): Number of experts to select per token. + + Returns: + expert_indices (torch.Tensor): shape [num_tokens, top_k_num], + expert index for each token. + expert_weights (torch.Tensor): shape [num_tokens, top_k_num], + normalized weights (sum = 1 per row). + """ + assert top_k_num <= num_experts, "top_k_num must be <= num_experts" + + # Randomly select top_k_num distinct experts for each token + expert_indices = torch.empty((num_tokens, top_k_num), dtype=torch.int32) + for i in range(num_tokens): + # Randomly choose unique expert indices + selected = torch.randperm(num_experts)[:top_k_num] + expert_indices[i] = selected + + # Generate random weights and normalize along dim=1 + expert_weights = torch.rand((num_tokens, top_k_num), dtype=torch.float32) + expert_weights = expert_weights / expert_weights.sum(dim=1, keepdim=True) + + return expert_indices, expert_weights + + +def sample_data( + num_tokens: int, + num_sequences: int, + max_loras: int, + num_experts: int, + top_k_num: int, +): + topk_ids, topk_weights = assign_experts_to_tokens( + num_tokens, num_experts, top_k_num + ) + token_lora_mapping = assign_loras_to_tokens(num_tokens, num_sequences, max_loras) + return topk_ids, topk_weights, token_lora_mapping + + +def use_fused_moe_lora_kernel( + topk_ids, + topk_weights, + token_lora_mapping, + max_lora_rank, + top_k_num, + lora_a_stacked, + lora_b_stacked, + hidden_states, + output, + max_loras, + num_experts, + block_size, +): + max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) + max_num_tokens_padded = round_up(max_num_tokens_padded, block_size) + max_num_m_blocks = CEILDIV(max_num_tokens_padded, block_size) + + # init output tensors + sorted_token_ids = torch.empty( + (max_loras * max_num_tokens_padded,), + dtype=torch.int32, + ) + expert_ids = torch.empty((max_loras * max_num_m_blocks,), dtype=torch.int32) + num_tokens_post_padded = torch.empty((max_loras,), dtype=torch.int32) + + # call kernel + ops.moe_lora_align_block_size( + topk_ids, + token_lora_mapping, + num_experts, + block_size, + max_loras, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + ) + + config = { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + } + + mul_routed_weight = False + expert_ids = expert_ids.view(max_loras, -1) + sorted_token_ids = sorted_token_ids.view(max_loras, -1) + + fused_moe_lora( + output, + hidden_states, + 
lora_a_stacked, + lora_b_stacked, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + max_lora_rank, + top_k_num, + config["BLOCK_SIZE_M"], + config["BLOCK_SIZE_N"], + config["BLOCK_SIZE_K"], + config["GROUP_SIZE_M"], + mul_routed_weight, + ) + + return output + + +def use_torch( + hidden_states, + token_lora_mapping, + topk_ids, + lora_a_stacked, + lora_b_stacked, + top_k_num, +): + outputs = [] + for i in range(hidden_states.shape[0]): + lora_idx = token_lora_mapping[i] + expert_ids = topk_ids[i] + lora_a = lora_a_stacked[0][lora_idx][expert_ids] + lora_b = lora_b_stacked[0][lora_idx][expert_ids] + tensors = [ + hidden_states[i] @ lora_a[x].T @ lora_b[x].T for x in range(top_k_num) + ] + outputs.append(torch.stack(tensors, dim=0)) + return torch.stack(outputs, dim=0) + + +@pytest.mark.parametrize("num_tokens", [100]) +@pytest.mark.parametrize("top_k_num", [6, 12]) +@pytest.mark.parametrize("num_experts", [64]) +@pytest.mark.parametrize("max_loras", [4, 6, 16]) +@pytest.mark.parametrize("N", [1408]) +@pytest.mark.parametrize("K", [2048]) +@pytest.mark.parametrize("max_lora_rank", [16, 32, 64]) +@pytest.mark.parametrize("block_size", [16]) +def test_fused_moe_lora_kernel( + num_tokens, + top_k_num, + num_experts, + max_loras, + N, + K, + max_lora_rank, + block_size, +): + torch.set_default_device("cuda:0") + current_platform.seed_everything(42) + # the number of randomly generated sentences. + num_sequences = 10 + # generate data + topk_ids, topk_weights, token_lora_mapping = sample_data( + num_tokens, num_sequences, max_loras, num_experts, top_k_num + ) + + # init lora weights + lora_a_stacked = [ + torch.rand( + ( + max_loras, + num_experts, + max_lora_rank, + K, + ), + dtype=torch.bfloat16, + ) + ] + lora_b_stacked = [ + torch.rand( + ( + max_loras, + num_experts, + N, + max_lora_rank, + ), + dtype=torch.bfloat16, + ) + ] + hidden_states = torch.rand( + ( + num_tokens, + K, + ), + dtype=torch.bfloat16, + ) + + # fused_moe_lora_kernel output + output = torch.zeros((num_tokens, top_k_num, N), dtype=torch.bfloat16) + use_fused_moe_lora_kernel( + topk_ids, + topk_weights, + token_lora_mapping, + max_lora_rank, + top_k_num, + lora_a_stacked, + lora_b_stacked, + hidden_states, + output, + max_loras, + num_experts, + block_size, + ) + # pytorch output + output2 = use_torch( + hidden_states, + token_lora_mapping, + topk_ids, + lora_a_stacked, + lora_b_stacked, + top_k_num, + ) + + torch.testing.assert_close(output, output2, atol=1e-1, rtol=1e-1) diff --git a/tests/lora/test_gptoss.py b/tests/lora/test_gptoss.py new file mode 100644 index 0000000000000..cdd0304afa70d --- /dev/null +++ b/tests/lora/test_gptoss.py @@ -0,0 +1,52 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import vllm +from vllm.lora.request import LoRARequest + +MODEL_PATH = "openai/gpt-oss-20b" + +PROMPT_TEMPLATE = "<|begin▁of▁sentence|>You are a helpful assistant.\n\nUser: {context}\n\nAssistant:" # noqa: E501 + + +def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: + prompts = [ + PROMPT_TEMPLATE.format(context="Who are you?"), + ] + sampling_params = vllm.SamplingParams(temperature=0, max_tokens=64) + outputs = llm.generate( + prompts, + sampling_params, + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None, + ) + # Print the outputs. 
+ generated_texts: list[str] = [] + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text.strip() + generated_texts.append(generated_text) + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + return generated_texts + + +# FIXME: Load gpt-oss adapter +def test_gptoss20b_lora(gptoss20b_lora_files): + # We enable enforce_eager=True here to reduce VRAM usage for lora-test CI, + # Otherwise, the lora-test will fail due to CUDA OOM. + llm = vllm.LLM( + MODEL_PATH, + enable_lora=True, + max_loras=1, + trust_remote_code=True, + ) + + expected_lora_output = [ + "I am an AI language model developed by OpenAI. " + "I am here to help you with any questions or " + "tasks you may have." + ] + + output1 = do_sample(llm, gptoss20b_lora_files, lora_id=1) + print(output1) + for i in range(len(expected_lora_output)): + assert output1[i].startswith(expected_lora_output[i]) diff --git a/tests/lora/test_moe_lora_align_sum.py b/tests/lora/test_moe_lora_align_sum.py new file mode 100644 index 0000000000000..e65dd40bdeb74 --- /dev/null +++ b/tests/lora/test_moe_lora_align_sum.py @@ -0,0 +1,90 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import random + +import pytest +import torch + +from vllm import _custom_ops as ops + + +def round_up(x, base): + return ((x + base - 1) // base) * base + + +def CEILDIV(x, y): + return (x + y - 1) // y + + +def sample_data(num_experts, max_loras, num_tokens, topk_num): + topk_ids = torch.zeros((num_tokens, topk_num), dtype=torch.int32) + token_lora_mapping = torch.zeros((num_tokens,), dtype=torch.int32) + + for i in range(num_tokens): + pool = list(range(num_experts)) + random.shuffle(pool) + for j in range(topk_num): + topk_ids[i, j] = pool[j] + token_lora_mapping[i] = random.randint(0, max_loras - 1) + + return topk_ids.to("cuda"), token_lora_mapping.to("cuda") + + +@pytest.mark.parametrize("num_tokens", [100, 200, 1024, 4096]) # 81920 +@pytest.mark.parametrize("topk_num", [6]) +@pytest.mark.parametrize("num_experts", [64, 128]) +@pytest.mark.parametrize("max_loras", [2, 32]) +@pytest.mark.parametrize("block_size", [16]) +def test_moe_lora_align_block_size( + num_tokens, topk_num, num_experts, max_loras, block_size +): + # sample data + random.seed(1) + topk_ids, token_lora_mapping = sample_data( + num_experts, max_loras, num_tokens, topk_num + ) + + # compute paddings + max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) + max_num_tokens_padded = round_up(max_num_tokens_padded, block_size) + max_num_m_blocks = CEILDIV(max_num_tokens_padded, block_size) + + # init output tensors + sorted_token_ids = torch.full( + (max_loras * max_num_tokens_padded,), + topk_ids.numel(), + dtype=torch.int32, + device="cuda", + ) + expert_ids = torch.full( + (max_loras * max_num_m_blocks,), num_experts, dtype=torch.int32, device="cuda" + ) + num_tokens_post_pad = torch.zeros((max_loras,), dtype=torch.int32, device="cuda") + + # call kernel + ops.moe_lora_align_block_size( + topk_ids, + token_lora_mapping, + num_experts, + block_size, + max_loras, + sorted_token_ids, + expert_ids, + num_tokens_post_pad, + ) + + # verify values + expert_ids = expert_ids.view(max_loras, -1) + sorted_token_ids = sorted_token_ids.view(max_loras, -1, block_size) + + for lora_idx in range(max_loras): + for token_idx in range(sorted_token_ids.size(1)): + block = sorted_token_ids[lora_idx][token_idx] + indices = block[block != topk_ids.numel()] + if indices.numel() > 0: + expert_id = 
expert_ids[lora_idx][token_idx] + assert torch.all(topk_ids.view(-1)[indices] == expert_id) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/lora/test_olmoe_tp.py b/tests/lora/test_olmoe_tp.py new file mode 100644 index 0000000000000..b954e0776ca4a --- /dev/null +++ b/tests/lora/test_olmoe_tp.py @@ -0,0 +1,109 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import vllm +from vllm.lora.request import LoRARequest + +from ..utils import multi_gpu_test + +MODEL_PATH = "allenai/OLMoE-1B-7B-0125-Instruct" + +PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request. +" +##Instruction: +candidate_poll contains tables such as candidate, people. Table candidate has columns such as Candidate_ID, People_ID, Poll_Source, Date, Support_rate, Consider_rate, Oppose_rate, Unsure_rate. Candidate_ID is the primary key. +Table people has columns such as People_ID, Sex, Name, Date_of_Birth, Height, Weight. People_ID is the primary key. +The People_ID of candidate is the foreign key of People_ID of people. + + +###Input: +{context} + +###Response:""" # noqa: E501 + +EXPECTED_LORA_OUTPUT = [ + "SELECT count(*) FROM candidate", + "SELECT count(*) FROM candidate", + "SELECT poll_source FROM candidate GROUP BY poll_source ORDER BY count(*) DESC LIMIT 1", # noqa: E501 + "SELECT poll_source FROM candidate GROUP BY poll_source ORDER BY count(*) DESC LIMIT 1", # noqa: E501 +] + + +def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None: + prompts = [ + PROMPT_TEMPLATE.format(context="How many candidates are there?"), + PROMPT_TEMPLATE.format(context="Count the number of candidates."), + PROMPT_TEMPLATE.format( + context="Which poll resource provided the most number of candidate information?" # noqa: E501 + ), + PROMPT_TEMPLATE.format( + context="Return the poll resource associated with the most candidates." + ), + ] + sampling_params = vllm.SamplingParams(temperature=0, max_tokens=64) + outputs = llm.generate( + prompts, + sampling_params, + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None, + ) + # Print the outputs. + generated_texts: list[str] = [] + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text.strip() + generated_texts.append(generated_text) + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + for i in range(len(EXPECTED_LORA_OUTPUT)): + assert generated_texts[i].startswith(EXPECTED_LORA_OUTPUT[i]) + + +def test_olmoe_lora(olmoe_lora_files): + # We enable enforce_eager=True here to reduce VRAM usage for lora-test CI, + # Otherwise, the lora-test will fail due to CUDA OOM. 
+ llm = vllm.LLM( + MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=4, + enforce_eager=True, + trust_remote_code=True, + enable_chunked_prefill=True, + ) + + generate_and_test(llm, olmoe_lora_files, lora_id=1) + generate_and_test(llm, olmoe_lora_files, lora_id=2) + + +@multi_gpu_test(num_gpus=2) +def test_olmoe_lora_tp2(olmoe_lora_files): + llm = vllm.LLM( + MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=4, + enforce_eager=True, + trust_remote_code=True, + enable_chunked_prefill=True, + tensor_parallel_size=2, + ) + + generate_and_test(llm, olmoe_lora_files, lora_id=1) + generate_and_test(llm, olmoe_lora_files, lora_id=2) + + +@multi_gpu_test(num_gpus=4) +def test_olmoe_lora_tp4(olmoe_lora_files): + llm = vllm.LLM( + MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=4, + enforce_eager=True, + trust_remote_code=True, + enable_chunked_prefill=True, + tensor_parallel_size=4, + ) + + generate_and_test(llm, olmoe_lora_files, lora_id=1) + generate_and_test(llm, olmoe_lora_files, lora_id=2) diff --git a/tests/lora/test_qwen3moe_tp.py b/tests/lora/test_qwen3moe_tp.py new file mode 100644 index 0000000000000..de2b040907f98 --- /dev/null +++ b/tests/lora/test_qwen3moe_tp.py @@ -0,0 +1,111 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import vllm +from vllm.lora.request import LoRARequest + +from ..utils import multi_gpu_test + +MODEL_PATH = "Qwen/Qwen3-30B-A3B" + +PROMPT_TEMPLATE = """<|im_start|>user +I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request. +" +##Instruction: +candidate_poll contains tables such as candidate, people. Table candidate has columns such as Candidate_ID, People_ID, Poll_Source, Date, Support_rate, Consider_rate, Oppose_rate, Unsure_rate. Candidate_ID is the primary key. +Table people has columns such as People_ID, Sex, Name, Date_of_Birth, Height, Weight. People_ID is the primary key. +The People_ID of candidate is the foreign key of People_ID of people. + + +###Input: +{context} + +###Response:<|im_end|> +<|im_start|>assistant""" # noqa: E501 + +EXPECTED_LORA_OUTPUT = [ + "\n\n\n\nSELECT count(*) FROM candidate", + "\n\n\n\nSELECT count(*) FROM candidate", + "\n\n\n\nSELECT poll_source FROM candidate GROUP BY poll_source ORDER BY count(*) DESC LIMIT 1", # noqa: E501 + "\n\n\n\nSELECT poll_source FROM candidate GROUP BY poll_source ORDER BY count(*) DESC LIMIT 1", # noqa: E501 +] + + +def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None: + prompts = [ + PROMPT_TEMPLATE.format(context="How many candidates are there?"), + PROMPT_TEMPLATE.format(context="Count the number of candidates."), + PROMPT_TEMPLATE.format( + context="Which poll resource provided the most number of candidate information?" # noqa: E501 + ), + PROMPT_TEMPLATE.format( + context="Return the poll resource associated with the most candidates." + ), + ] + sampling_params = vllm.SamplingParams(temperature=0, max_tokens=64) + outputs = llm.generate( + prompts, + sampling_params, + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None, + ) + # Print the outputs. 
+ generated_texts: list[str] = [] + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text.strip() + generated_texts.append(generated_text) + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + for i in range(len(EXPECTED_LORA_OUTPUT)): + assert generated_texts[i].startswith(EXPECTED_LORA_OUTPUT[i]) + + +def test_qwen3moe_lora(qwen3moe_lora_files): + # We enable enforce_eager=True here to reduce VRAM usage for lora-test CI, + # Otherwise, the lora-test will fail due to CUDA OOM. + llm = vllm.LLM( + MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=4, + enforce_eager=True, + trust_remote_code=True, + enable_chunked_prefill=True, + ) + + generate_and_test(llm, qwen3moe_lora_files, lora_id=1) + generate_and_test(llm, qwen3moe_lora_files, lora_id=2) + + +@multi_gpu_test(num_gpus=2) +def test_qwen3moe_lora_tp2(qwen3moe_lora_files): + llm = vllm.LLM( + MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=4, + enforce_eager=True, + trust_remote_code=True, + enable_chunked_prefill=True, + tensor_parallel_size=2, + ) + + generate_and_test(llm, qwen3moe_lora_files, lora_id=1) + generate_and_test(llm, qwen3moe_lora_files, lora_id=2) + + +@multi_gpu_test(num_gpus=4) +def test_qwen3moe_lora_tp4(qwen3moe_lora_files): + llm = vllm.LLM( + MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=4, + enforce_eager=True, + trust_remote_code=True, + enable_chunked_prefill=True, + tensor_parallel_size=4, + ) + + generate_and_test(llm, qwen3moe_lora_files, lora_id=1) + generate_and_test(llm, qwen3moe_lora_files, lora_id=2) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 0618451c199ac..7efd6aa446e01 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1795,6 +1795,28 @@ def batched_moe_align_block_size( ) +def moe_lora_align_block_size( + topk_ids: torch.Tensor, + token_lora_mapping: torch.Tensor, + num_experts: int, + block_size: int, + max_loras: int, + sorted_token_ids: torch.Tensor, + experts_ids: torch.Tensor, + num_tokens_post_pad: torch.Tensor, +) -> None: + torch.ops._moe_C.moe_lora_align_block_size( + topk_ids, + token_lora_mapping, + num_experts, + block_size, + max_loras, + sorted_token_ids, + experts_ids, + num_tokens_post_pad, + ) + + def moe_wna16_gemm( input: torch.Tensor, output: torch.Tensor, diff --git a/vllm/lora/layers/__init__.py b/vllm/lora/layers/__init__.py index 4915ef85f4f73..8a4f5ff175d4f 100644 --- a/vllm/lora/layers/__init__.py +++ b/vllm/lora/layers/__init__.py @@ -11,6 +11,7 @@ from vllm.lora.layers.column_parallel_linear import ( QKVParallelLinearWithLoRA, QKVParallelLinearWithShardedLoRA, ) +from vllm.lora.layers.fused_moe import FusedMoEWithLoRA from vllm.lora.layers.logits_processor import LogitsProcessorWithLoRA from vllm.lora.layers.replicated_linear import ReplicatedLinearWithLoRA from vllm.lora.layers.row_parallel_linear import ( @@ -36,4 +37,5 @@ __all__ = [ "RowParallelLinearWithShardedLoRA", "ReplicatedLinearWithLoRA", "LoRAMapping", + "FusedMoEWithLoRA", ] diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py new file mode 100644 index 0000000000000..1928fc2e3f936 --- /dev/null +++ b/vllm/lora/layers/fused_moe.py @@ -0,0 +1,410 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import functools + +import torch +import torch.nn as nn +from transformers import PretrainedConfig + +from vllm import envs +from vllm.config.lora import LoRAConfig +from vllm.distributed.parallel_state 
import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from vllm.lora.layers.base import BaseLayerWithLoRA +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.fused_moe.config import ( + FUSED_MOE_UNQUANTIZED_CONFIG, + _get_config_dtype_str, + mxfp4_w4a16_moe_quant_config, +) +from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + modular_marlin_fused_moe, +) +from vllm.model_executor.layers.fused_moe.fused_moe import ( + modular_triton_fused_moe, + try_get_optimal_moe_config, +) +from vllm.model_executor.layers.quantization.mxfp4 import Mxfp4Config + + +class FusedMoEWithLoRA(BaseLayerWithLoRA): + def __init__(self, base_layer: FusedMoE) -> None: + super().__init__() + self.base_layer = base_layer + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.device = base_layer.w2_weight.device + self._inject_lora_into_fused_moe() + + def _inject_lora_into_fused_moe(self): + moe_state_dict = {} + top_k = self.base_layer.top_k + + if self.base_layer.quant_config is None: + quant_config = FUSED_MOE_UNQUANTIZED_CONFIG + elif not isinstance(self.base_layer.quant_config, Mxfp4Config): + quant_config = self.base_layer.quant_config + else: + quant_config = mxfp4_w4a16_moe_quant_config( + w1_bias=self.base_layer.w13_bias, + w2_bias=self.base_layer.w2_bias, + w1_scale=self.base_layer.w13_weight_scale, + w2_scale=self.base_layer.w2_weight_scale, + ) + + m_fused_moe_fn = ( + modular_triton_fused_moe( + quant_config, shared_experts=self.base_layer.shared_experts + ) + if not quant_config.use_mxfp4_w4a16 + else modular_marlin_fused_moe( + quant_config, shared_experts=self.base_layer.shared_experts + ) + ) + + def fwd_decorator(layer, func): + def wrapper(*args, **kwargs): + moe_state_dict["hidden_states"] = kwargs["hidden_states"] + moe_state_dict["topk_ids"] = kwargs["topk_ids"] + moe_state_dict["topk_weights"] = kwargs["topk_weights"] + moe_state_dict["global_num_experts"] = kwargs["global_num_experts"] + moe_state_dict["expert_map"] = kwargs["expert_map"] + moe_state_dict["apply_router_weight_on_input"] = kwargs[ + "apply_router_weight_on_input" + ] + moe_state_dict["max_loras"] = layer.w1_lora_a_stacked.shape[0] + result = func(*args, **kwargs) + return result + + return wrapper + + def act_decorator(layer, func): + def wrapper(*args, **kwargs): + _, output, input = args + + hidden_states = moe_state_dict["hidden_states"] + topk_weights = moe_state_dict["topk_weights"] + curr_topk_ids = moe_state_dict["topk_ids"] + global_num_experts = moe_state_dict["global_num_experts"] + expert_map = moe_state_dict["expert_map"] + max_loras = moe_state_dict["max_loras"] + + config_dtype = _get_config_dtype_str( + dtype=hidden_states.dtype, + use_fp8_w8a8=False, + use_int8_w8a16=False, + use_int4_w4a16=False, + ) + CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE + num_tokens = hidden_states.size(0) + M = min(num_tokens, CHUNK_SIZE) + + get_config_func = functools.partial( + try_get_optimal_moe_config, + layer.w13_weight.size(), + layer.w2_weight.size(), + top_k, + config_dtype, + block_shape=layer.quant_method.moe_quant_config.block_shape, + ) + + config = get_config_func(M) + ( + sorted_token_ids_lora, + expert_ids_lora, + num_tokens_post_padded_lora, + ) = self.punica_wrapper.moe_lora_align_block_size( + curr_topk_ids, + num_tokens, + config["BLOCK_SIZE_M"], + global_num_experts, + max_loras, + expert_map, + ) + + moe_state_dict["sorted_token_ids_lora"] = sorted_token_ids_lora + 
moe_state_dict["expert_ids_lora"] = expert_ids_lora + moe_state_dict["num_tokens_post_padded_lora"] = ( + num_tokens_post_padded_lora + ) + + w13_lora_a_stacked = [self.w1_lora_a_stacked, self.w3_lora_a_stacked] + w13_lora_b_stacked = [self.w1_lora_b_stacked, self.w3_lora_b_stacked] + max_lora_rank = self.w1_lora_a_stacked.shape[-2] + expert_ids_lora = expert_ids_lora.view(max_loras, -1) + sorted_token_ids_lora = sorted_token_ids_lora.view(max_loras, -1) + + self.punica_wrapper.add_lora_fused_moe( + input.view(-1, top_k, input.shape[-1]), + hidden_states, + w13_lora_a_stacked, + w13_lora_b_stacked, + topk_weights, + sorted_token_ids_lora, + expert_ids_lora, + num_tokens_post_padded_lora, + max_lora_rank, + top_k, + config, + ) + + result = func(*args, **kwargs) + + moe_state_dict["intermediate_cache2"] = output + return result + + return wrapper + + def moe_sum_decorator(layer, func): + def wrapper(*args, **kwargs): + hidden_states = moe_state_dict["hidden_states"] + topk_weights = moe_state_dict["topk_weights"] + max_loras = moe_state_dict["max_loras"] + + config_dtype = _get_config_dtype_str( + dtype=hidden_states.dtype, + use_fp8_w8a8=False, + use_int8_w8a16=False, + use_int4_w4a16=False, + ) + CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE + num_tokens = hidden_states.size(0) + M = min(num_tokens, CHUNK_SIZE) + + get_config_func = functools.partial( + try_get_optimal_moe_config, + layer.w13_weight.size(), + layer.w2_weight.size(), + top_k, + config_dtype, + block_shape=layer.quant_method.moe_quant_config.block_shape, + ) + + config = get_config_func(M) + + sorted_token_ids_lora = moe_state_dict["sorted_token_ids_lora"] + expert_ids_lora = moe_state_dict["expert_ids_lora"] + num_tokens_post_padded_lora = moe_state_dict[ + "num_tokens_post_padded_lora" + ] + + expert_ids_lora = expert_ids_lora.view(max_loras, -1) + sorted_token_ids_lora = sorted_token_ids_lora.view(max_loras, -1) + intermediate_cache2 = moe_state_dict["intermediate_cache2"] + intermediate_cache3 = args[0] + max_lora_rank = self.w1_lora_a_stacked.shape[-2] + self.punica_wrapper.add_lora_fused_moe( + intermediate_cache3, + intermediate_cache2, + [self.w2_lora_a_stacked], + [self.w2_lora_b_stacked], + topk_weights, + sorted_token_ids_lora, + expert_ids_lora, + num_tokens_post_padded_lora, + max_lora_rank, + top_k, + config, + True, + ) + + result = func(*args, **kwargs) + return result + + return wrapper + + fused_experts = m_fused_moe_fn.fused_experts + + m_fused_moe_fn.forward = fwd_decorator(self.base_layer, m_fused_moe_fn.forward) + fused_experts.activation = act_decorator( + self.base_layer, fused_experts.activation + ) + fused_experts.moe_sum = moe_sum_decorator( + self.base_layer, fused_experts.moe_sum + ) + + self.base_layer.quant_method.old_fused_experts = ( + self.base_layer.quant_method.fused_experts + ) + self.base_layer.quant_method.fused_experts = m_fused_moe_fn + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: PretrainedConfig | None = None, + ) -> None: + """Initializes lora matrices.""" + + assert not self.base_layer.use_ep, ( + "EP support for Fused MoE LoRA is not implemented yet." 
+ ) + + self.w1_lora_a_stacked = torch.zeros( + ( + max_loras, + self.base_layer.global_num_experts, + lora_config.max_lora_rank, + self.base_layer.hidden_size, + ), + dtype=lora_config.lora_dtype, + device=self.device, + ) + self.w1_lora_b_stacked = torch.zeros( + ( + max_loras, + self.base_layer.global_num_experts, + self.base_layer.intermediate_size_per_partition, + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.device, + ) + + self.w2_lora_a_stacked = torch.zeros( + ( + max_loras, + self.base_layer.global_num_experts, + lora_config.max_lora_rank, + self.base_layer.intermediate_size_per_partition, + ), + dtype=lora_config.lora_dtype, + device=self.device, + ) + self.w2_lora_b_stacked = torch.zeros( + ( + max_loras, + self.base_layer.global_num_experts, + self.base_layer.hidden_size, + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.device, + ) + + self.w3_lora_a_stacked = torch.zeros( + ( + max_loras, + self.base_layer.global_num_experts, + lora_config.max_lora_rank, + self.base_layer.hidden_size, + ), + dtype=lora_config.lora_dtype, + device=self.device, + ) + self.w3_lora_b_stacked = torch.zeros( + ( + max_loras, + self.base_layer.global_num_experts, + self.base_layer.intermediate_size_per_partition, + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.device, + ) + + self.base_layer.w1_lora_a_stacked = self.w1_lora_a_stacked + self.base_layer.w1_lora_b_stacked = self.w1_lora_b_stacked + self.base_layer.w2_lora_a_stacked = self.w2_lora_a_stacked + self.base_layer.w2_lora_b_stacked = self.w2_lora_b_stacked + self.base_layer.w3_lora_a_stacked = self.w3_lora_a_stacked + self.base_layer.w3_lora_b_stacked = self.w3_lora_b_stacked + # They will be used by 'LoRALayerWeights.create_dummy_lora_weights' + # to create a dummy LoRA weights. 
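+ # Each list entry below is a per-(lora, expert) slice, appended in
+ # gate_proj, down_proj, up_proj order to match how set_lora() unpacks
+ # the flattened expert LoRA weights.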
+ self.lora_a_stacked = [] + self.lora_b_stacked = [] + for lora_id in range(max_loras): + for experts_id in range(self.base_layer.global_num_experts): + # gate_proj,down_proj,up_proj + self.lora_a_stacked.append(self.w1_lora_a_stacked[lora_id][experts_id]) + self.lora_a_stacked.append(self.w2_lora_a_stacked[lora_id][experts_id]) + self.lora_a_stacked.append(self.w3_lora_a_stacked[lora_id][experts_id]) + + self.lora_b_stacked.append(self.w1_lora_b_stacked[lora_id][experts_id]) + self.lora_b_stacked.append(self.w2_lora_b_stacked[lora_id][experts_id]) + self.lora_b_stacked.append(self.w3_lora_b_stacked[lora_id][experts_id]) + + def reset_lora(self, index: int): + """Resets the lora weights at index back to 0.""" + self.w1_lora_a_stacked[index] = 0 + self.w1_lora_b_stacked[index] = 0 + self.w3_lora_a_stacked[index] = 0 + self.w3_lora_b_stacked[index] = 0 + self.w2_lora_a_stacked[index] = 0 + self.w2_lora_b_stacked[index] = 0 + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: torch.Tensor | None, + bias: torch.Tensor | None = None, + ): + """Overwrites lora tensors at index.""" + for eid in range(len(lora_a) // 3): + w1_lora_a = lora_a[eid * 3] + w2_lora_a = lora_a[eid * 3 + 1] + w3_lora_a = lora_a[eid * 3 + 2] + w1_lora_b = lora_b[eid * 3] + w2_lora_b = lora_b[eid * 3 + 1] + w3_lora_b = lora_b[eid * 3 + 2] + + if self.tp_size > 1: + shard_size = self.base_layer.intermediate_size_per_partition + start_idx = self.tp_rank * shard_size + end_idx = (self.tp_rank + 1) * shard_size + + w1_lora_b = w1_lora_b[start_idx:end_idx, :] + w3_lora_b = w3_lora_b[start_idx:end_idx, :] + w2_lora_a = w2_lora_a[:, start_idx:end_idx] + + self.w1_lora_a_stacked[ + index, eid, : w1_lora_a.shape[0], : w1_lora_a.shape[1] + ].copy_(w1_lora_a, non_blocking=True) + + self.w3_lora_a_stacked[ + index, eid, : w3_lora_a.shape[0], : w3_lora_a.shape[1] + ].copy_(w3_lora_a, non_blocking=True) + + self.w2_lora_b_stacked[ + index, eid, : w2_lora_b.shape[0], : w2_lora_b.shape[1] + ].copy_(w2_lora_b, non_blocking=True) + + self.w1_lora_b_stacked[ + index, eid, : w1_lora_b.shape[0], : w1_lora_b.shape[1] + ].copy_(w1_lora_b, non_blocking=True) + self.w3_lora_b_stacked[ + index, eid, : w3_lora_b.shape[0], : w3_lora_b.shape[1] + ].copy_(w3_lora_b, non_blocking=True) + self.w2_lora_a_stacked[ + index, eid, : w2_lora_a.shape[0], : w2_lora_a.shape[1] + ].copy_(w2_lora_a, non_blocking=True) + + @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: PretrainedConfig | None, + ) -> bool: + """Returns True if the layer can be replaced by this LoRA layer.""" + # return type(source_layer) is FusedMoE + return isinstance(source_layer, FusedMoE) + + def forward(self, *args, **kwargs): + return self.base_layer.forward(*args, **kwargs) + + def maybe_all_reduce_tensor_model_parallel(self, *args, **kwargs): + return self.base_layer.maybe_all_reduce_tensor_model_parallel(*args, **kwargs) + + @property + def _shared_experts(self): + return self.base_layer._shared_experts + + @property + def quant_method(self): + return self.base_layer.quant_method diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 4840af7c7451b..27d52f17816aa 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -13,7 +13,7 @@ from torch import nn from vllm.config.lora import LoRAConfig from vllm.logger import init_logger -from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping +from vllm.lora.layers import 
BaseLayerWithLoRA, FusedMoEWithLoRA, LoRAMapping from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights from vllm.lora.peft_helper import PEFTHelper from vllm.lora.punica_wrapper import get_punica_wrapper @@ -23,15 +23,14 @@ from vllm.lora.utils import ( get_supported_lora_modules, is_regex_target_modules, parse_fine_tuned_lora_name, + process_packed_modules_mapping, replace_submodule, ) -from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.model_executor.models import SupportsLoRA, supports_multimodal from vllm.model_executor.models.interfaces import is_pooling_model from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper -from vllm.model_executor.utils import get_packed_modules_mapping from vllm.utils import is_pin_memory_available from vllm.utils.cache import LRUCache @@ -60,18 +59,6 @@ def get_lora_id(): return _GLOBAL_LORA_ID -def is_moe_model(model: nn.Module) -> bool: - """Checks if the model contains FusedMoE layers and warns the user.""" - if any(isinstance(module, FusedMoE) for module in model.modules()): - logger.warning_once( - "For MoE models, vLLM currently does not support fused MoE LoRA " - "inference. Please ensure that the loaded LoRA model does not " - "contain expert weights." - ) - return True - return False - - class LoRAModel: """A LoRA fine-tuned model.""" @@ -229,9 +216,19 @@ class LoRAModel: def check_unexpected_modules(modules: dict): for lora_module in modules.keys(): # noqa module_name, _ = parse_fine_tuned_lora_name(lora_module, weights_mapper) - part_name = module_name.split(".")[-1] - if part_name not in expected_lora_modules: + # Handle FSDP file format where experts.base_layer is the + # gate_up_proj and experts is the down_proj + if "base_layer" in lora_module: + continue + # Case for expert lora weights + if ".experts" in module_name: + if not any( + module_name.endswith(ele) for ele in expected_lora_modules + ): + unexpected_modules.append(module_name) + elif module_name.split(".")[-1] not in expected_lora_modules: unexpected_modules.append(module_name) + if unexpected_modules: raise ValueError( f"While loading {lora_dir}, expected" @@ -371,7 +368,7 @@ class LoRAModelManager: assert self.supported_lora_modules, "No supported LoRA modules found in" f" {self.model.__class__.__name__}." - self.packed_modules_mapping = get_packed_modules_mapping(self.model) + self.packed_modules_mapping = process_packed_modules_mapping(self.model) # Used to indicate whether the model is a multimodal model self.supports_mm: bool = ( supports_multimodal(self.model) @@ -380,7 +377,6 @@ class LoRAModelManager: and hasattr(self.model, "get_mm_mapping") ) self.is_pooling_model = is_pooling_model(self.model) - self.is_moe_model = is_moe_model(self.model) self.packed_modules: dict[str, list[str]] = {} self.modules: dict[str, BaseLayerWithLoRA] = {} # Dict instead of a set for compatibility with LRUCache. 
@@ -431,6 +427,50 @@ class LoRAModelManager: module_lora = self._get_lora_layer_weights(lora_model, module_name) if module_lora: module_lora.optimize() + # Note (gnovack) - If MOE lora weights are not split into + # num_experts chunks, we split them here + if isinstance(module, FusedMoEWithLoRA) and torch.is_tensor( + module_lora.lora_a + ): + # Handle FSDP file format where experts.base_layer is the + # gate_up_proj and experts is the down_proj + gate_up_proj_lora = self._get_lora_layer_weights( + lora_model, module_name + ".base_layer" + ) + + assert gate_up_proj_lora is not None + assert module_lora is not None + + down_proj_lora = module_lora + num_experts = module_lora.lora_a.shape[0] // module_lora.rank + + gate_proj_a = gate_up_proj_lora.lora_a.chunk(num_experts, dim=0) + up_proj_a = gate_up_proj_lora.lora_a.chunk(num_experts, dim=0) + + gate_proj_b = gate_up_proj_lora.lora_b[::2, ...].chunk( + num_experts, dim=-1 + ) + up_proj_b = gate_up_proj_lora.lora_b[1::2, ...].chunk( + num_experts, dim=-1 + ) + + down_proj_a = down_proj_lora.lora_a.chunk(num_experts, dim=0) + down_proj_b = down_proj_lora.lora_b.chunk(num_experts, dim=-1) + + lora_a = [] + lora_b = [] + for i in range(num_experts): + lora_a.append(gate_proj_a[i]) + lora_a.append(down_proj_a[i]) + lora_a.append(up_proj_a[i]) + + lora_b.append(gate_proj_b[i]) + lora_b.append(down_proj_b[i]) + lora_b.append(up_proj_b[i]) + + module_lora.lora_a = lora_a + module_lora.lora_b = lora_b + module.set_lora( index, module_lora.lora_a, @@ -486,6 +526,7 @@ class LoRAModelManager: for module_name, module in self.model.named_modules(remove_duplicate=False): if isinstance(module, PPMissingLayer): continue + if not self._match_target_modules(module_name): continue # A temporary approach for multimodal models to support LoRA @@ -549,7 +590,10 @@ class LoRAModelManager: new_module.set_mapping(self.punica_wrapper) def register_module(self, module_name: str, module: "BaseLayerWithLoRA"): - assert isinstance(module, BaseLayerWithLoRA) + assert isinstance(module, BaseLayerWithLoRA), ( + f"Module {module_name} must be a BaseLayerWithLoRA instance," + ) + f" got {type(module)}" self.modules[module_name] = module def create_dummy_lora( diff --git a/vllm/lora/ops/triton_ops/__init__.py b/vllm/lora/ops/triton_ops/__init__.py index 805de4b6f6570..436ea4ed00c82 100644 --- a/vllm/lora/ops/triton_ops/__init__.py +++ b/vllm/lora/ops/triton_ops/__init__.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from vllm.lora.ops.triton_ops.fused_moe_lora_op import fused_moe_lora from vllm.lora.ops.triton_ops.lora_expand_op import lora_expand from vllm.lora.ops.triton_ops.lora_kernel_metadata import LoRAKernelMeta from vllm.lora.ops.triton_ops.lora_shrink_op import lora_shrink @@ -9,4 +10,5 @@ __all__ = [ "lora_expand", "lora_shrink", "LoRAKernelMeta", + "fused_moe_lora", ] diff --git a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py new file mode 100644 index 0000000000000..94935d8dfe866 --- /dev/null +++ b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py @@ -0,0 +1,350 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch +import triton +import triton.language as tl + +from vllm.utils.torch_utils import direct_register_custom_op + +_LORA_PTR_DICT: dict[tuple[int, ...], torch.tensor] = {} + + +def _get_ptr(lora_weights: list[torch.Tensor], device: torch.device): + """ + 
`_LORA_PTR_DICT` collects the required information during `profile_run`, + After this, it remains constant and subsequent usage is through LUT. + Refer to: + https://github.com/triton-lang/triton/blob/release/3.1.x/python/tutorials/08-grouped-gemm.py + """ + key = tuple(lora_weight.data_ptr() for lora_weight in lora_weights) + + if (ptr_tensor := _LORA_PTR_DICT.get(key)) is not None: + return ptr_tensor + + tensor_ptrs = [] + for lora_weight in lora_weights: + tensor_ptrs.append(lora_weight.data_ptr()) + ptr_tensor = torch.tensor(tensor_ptrs, device=device) + + _LORA_PTR_DICT[key] = ptr_tensor + return _LORA_PTR_DICT.get(key) + + +@triton.jit +def _fused_moe_lora_kernel( + a_ptr, + b_ptr, + c_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + # Matrix dimensions + N, + K, + EM, + num_valid_tokens, + num_experts, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_am, + stride_ak, + stride_bl, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_tl, + stride_el, + # Meta-parameters + num_slice_a: tl.constexpr, + num_slice_c: tl.constexpr, + slice_a_size: tl.constexpr, + slice_c_size: tl.constexpr, + top_k: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + pid = tl.program_id(axis=0) + slice_id = tl.program_id(axis=1) + lora_idx = tl.program_id(axis=2) + max_loras = tl.num_programs(axis=2) + + # calculate pid_m,pid_n + num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr + lora_idx) + if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: + return + + # get the expert_id to process curr shard + ind = lora_idx * stride_el + pid_m + expert_id = tl.load(expert_ids_ptr + ind) + if expert_id == -1: + return + + # get a_ptr,b_ptr,c_ptr + cur_a_ptr = a_ptr + (slice_id % num_slice_a) * slice_a_size + cur_b_ptr = tl.load(b_ptr + slice_id).to(tl.pointer_type(tl.bfloat16)) + cur_c_ptr = c_ptr + (slice_id % num_slice_c) * slice_c_size + + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + token_ind = stride_tl * lora_idx + offs_token_id + offs_token = tl.load( + sorted_token_ids_ptr + token_ind, token_ind < max_loras * stride_tl, 0.0 + ) + token_mask = offs_token < num_valid_tokens + + # get a_ptrs,b_ptrs + a_ptrs = cur_a_ptr + ( + offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak + ) + + b_ptrs = ( + cur_b_ptr + + lora_idx * stride_bl + + expert_id * stride_be + + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + ) + + # accumulator + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load( + a_ptrs, + mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0, + ) + b = 
tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator += tl.dot(a, b) + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) + accumulator = accumulator * moe_weight[:, None] + + accumulator = accumulator.to(tl.bfloat16) + # Write back the block of the output + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = cur_c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + +@torch.inference_mode() +def _fused_moe_lora( + output: torch.Tensor, # (num_tokens, top_k_num, N*len(lora_a_stacked),) + qcurr_hidden_states: torch.Tensor, # (num_tokens, K,) + lora_a_stacked: list[ + torch.Tensor + ], # [(max_loras, num_experts, max_lora_rank, K,),...] + lora_b_stacked: list[ + torch.Tensor + ], # [(max_loras, num_experts, N, max_lora_rank,),...] + topk_weights: torch.Tensor, # (num_tokens, top_k_num) + sorted_token_ids: torch.Tensor, # (max_loras, _) + expert_ids: torch.Tensor, # (max_loras, _ ,) + num_tokens_post_padded: torch.Tensor, # (max_loras, ) + max_lora_rank: int, + top_k_num: int, + block_size_m: int, + block_size_n: int, + block_size_k: int, + group_size_m: int, + mul_routed_weight: bool = False, +) -> None: + assert len(lora_a_stacked) == len(lora_b_stacked) > 0 + assert ( + sorted_token_ids.dim() + == expert_ids.dim() + == topk_weights.dim() + == qcurr_hidden_states.dim() + == 2 + ) + assert ( + sorted_token_ids.shape[0] + == expert_ids.shape[0] + == num_tokens_post_padded.shape[0] + ) + assert len(lora_b_stacked) * lora_b_stacked[0].shape[-2] == output.shape[-1] + assert output.shape[0] == topk_weights.shape[0] + assert top_k_num == topk_weights.shape[1] + + device = qcurr_hidden_states.device + num_slices = len(lora_a_stacked) + + config = { + "BLOCK_SIZE_M": block_size_m, + "BLOCK_SIZE_N": block_size_n, + "BLOCK_SIZE_K": block_size_k, + "GROUP_SIZE_M": group_size_m, + } + + w1_lora_a_stacked = lora_a_stacked[0] + w1_lora_b_stacked = lora_b_stacked[0] + num_experts = lora_a_stacked[0].shape[1] + + N = max_lora_rank + M = topk_weights.shape[0] + EM = sorted_token_ids.shape[1] + K = qcurr_hidden_states.shape[1] + num_tokens = M * top_k_num + w1_output_dim_size = w1_lora_b_stacked.shape[2] + + lora_intermediate_cache1 = torch.zeros( + (num_slices * M * top_k_num * (max_lora_rank + w1_output_dim_size)), + dtype=torch.bfloat16, + device=device, + ) + + # slices + a_intermediate_size = num_slices * M * top_k_num * max_lora_rank + a_intermediate_cache1 = lora_intermediate_cache1[:a_intermediate_size].view( + num_slices, M, top_k_num, max_lora_rank + ) + b_intermediate_cache1 = lora_intermediate_cache1[a_intermediate_size:].view( + num_slices, M, top_k_num, w1_output_dim_size + ) + + b_ptr = _get_ptr(lora_a_stacked, device) + + grid = lambda META: ( + triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + len(lora_a_stacked), + lora_a_stacked[0].shape[0], + ) + + _fused_moe_lora_kernel[grid]( + qcurr_hidden_states, + b_ptr, + a_intermediate_cache1, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + N, + K, + EM, + num_tokens, + num_experts, + qcurr_hidden_states.stride(0), + qcurr_hidden_states.stride(1), + w1_lora_a_stacked.stride(0), + w1_lora_a_stacked.stride(1), + w1_lora_a_stacked.stride(3), + 
w1_lora_a_stacked.stride(2), + a_intermediate_cache1.stride(2), + a_intermediate_cache1.stride(3), + sorted_token_ids.stride(0), + expert_ids.stride(0), + num_slice_a=1, + num_slice_c=num_slices, + slice_a_size=qcurr_hidden_states.numel(), + slice_c_size=a_intermediate_cache1.numel() // num_slices, + top_k=1 if mul_routed_weight else top_k_num, + MUL_ROUTED_WEIGHT=False, + **config, + ) + + b_ptr = _get_ptr(lora_b_stacked, device) + K = max_lora_rank + N = w1_output_dim_size + + # a_intermediate_cache1 = a_intermediate_cache1.view( + # M, -1, a_intermediate_cache1.shape[3] + # ) + + a_intermediate_cache1 = a_intermediate_cache1.view( + -1, a_intermediate_cache1.shape[3] + ) + + grid = lambda META: ( + triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + len(lora_b_stacked), + lora_b_stacked[0].shape[0], + ) + _fused_moe_lora_kernel[grid]( + a_intermediate_cache1, + b_ptr, + b_intermediate_cache1, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + N, + K, + EM, + num_tokens, + num_experts, + a_intermediate_cache1.stride(0), + a_intermediate_cache1.stride(1), + w1_lora_b_stacked.stride(0), + w1_lora_b_stacked.stride(1), + w1_lora_b_stacked.stride(3), + w1_lora_b_stacked.stride(2), + b_intermediate_cache1.stride(2), + b_intermediate_cache1.stride(3), + sorted_token_ids.stride(0), + expert_ids.stride(0), + num_slice_a=num_slices, + num_slice_c=num_slices, + slice_a_size=a_intermediate_cache1.numel() // num_slices, + slice_c_size=b_intermediate_cache1.numel() // num_slices, + top_k=1, + MUL_ROUTED_WEIGHT=mul_routed_weight, + **config, + ) + for i in range(num_slices): + output[:, :, i * N : (i + 1) * N] += b_intermediate_cache1[i] + + +def _fused_moe_lora_fake( + output: torch.Tensor, + qcurr_hidden_states: torch.Tensor, + lora_a_stacked: list[torch.Tensor], + lora_b_stacked: list[torch.Tensor], + topk_weights: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + max_lora_rank: int, + top_k_num: int, + block_size_m: int, + block_size_n: int, + block_size_k: int, + group_size_m: int, + mul_routed_weight: bool = False, +) -> None: + return + + +try: + direct_register_custom_op( + op_name="fused_moe_lora", + op_func=_fused_moe_lora, + mutates_args=["output"], + fake_impl=_fused_moe_lora_fake, + ) + fused_moe_lora = torch.ops.vllm.fused_moe_lora + +except AttributeError: + fused_moe_lora = _fused_moe_lora diff --git a/vllm/lora/punica_wrapper/punica_base.py b/vllm/lora/punica_wrapper/punica_base.py index 3f3f33baaa793..5b4a18cf4789b 100644 --- a/vllm/lora/punica_wrapper/punica_base.py +++ b/vllm/lora/punica_wrapper/punica_base.py @@ -448,3 +448,42 @@ class PunicaWrapperBase(PunicaWrapperABC): """ # TODO: implement it based on torch ops raise NotImplementedError + + def moe_lora_align_block_size( + self, + topk_ids: torch.Tensor, + num_tokens: int, + block_size: int, + num_experts: int, + max_loras: int, + expert_map: torch.Tensor | None = None, + pad_sorted_ids: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Aligns tokens and experts into block-sized chunks for LoRA-based + mixture-of-experts (MoE) execution. 
+ """ + # TODO: implement it based on torch ops + raise NotImplementedError + + def add_lora_fused_moe( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: list[torch.Tensor], + lora_b_stacked: list[torch.Tensor], + topk_weights: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + max_lora_rank: int, + top_k_num: int, + config, + mul_routed_weight=False, + ): + """ + Performs a fused forward computation for LoRA of + Mixture-of-Experts (MoE) layer. + """ + # TODO: implement it based on torch ops + raise NotImplementedError diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py index cdb0e67082909..daf89cd97c385 100644 --- a/vllm/lora/punica_wrapper/punica_gpu.py +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -12,10 +12,18 @@ from typing import final import torch from vllm.lora.layers import LoRAMapping -from vllm.triton_utils import HAS_TRITON +from vllm.triton_utils import HAS_TRITON, triton +from vllm.utils import round_up if HAS_TRITON: - from vllm.lora.ops.triton_ops import LoRAKernelMeta, lora_expand, lora_shrink + from vllm.lora.ops.triton_ops import ( + LoRAKernelMeta, + fused_moe_lora, + lora_expand, + lora_shrink, + ) + +from vllm import _custom_ops as ops from .punica_base import PunicaWrapperBase @@ -289,3 +297,91 @@ class PunicaWrapperGPU(PunicaWrapperBase): add_inputs=True, ) y = y.view_as(y_org) + + def moe_lora_align_block_size( + self, + topk_ids: torch.Tensor, + num_tokens: int, + block_size: int, + num_experts: int, + max_loras: int, + expert_map: torch.Tensor | None = None, + pad_sorted_ids: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Aligns tokens and experts into block-sized chunks for LoRA-based + mixture-of-experts (MoE) execution. + """ + max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) + if pad_sorted_ids: + max_num_tokens_padded = round_up(max_num_tokens_padded, block_size) + sorted_ids = torch.empty( + (max_loras * max_num_tokens_padded,), + dtype=torch.int32, + device=topk_ids.device, + ) + max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size) + # Expert ids must be set default to -1 to prevent a blank block + expert_ids = torch.empty( + (max_loras * max_num_m_blocks,), + dtype=torch.int32, + device=topk_ids.device, + ) + num_tokens_post_pad = torch.empty( + (max_loras), dtype=torch.int32, device=topk_ids.device + ) + + (token_lora_mapping, _, _, _, _, _) = self.token_mapping_meta.meta_args( + num_tokens + ) + + ops.moe_lora_align_block_size( + topk_ids, + token_lora_mapping, + num_experts, + block_size, + max_loras, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) + if expert_map is not None: + expert_ids = expert_map[expert_ids] + + return sorted_ids, expert_ids, num_tokens_post_pad + + def add_lora_fused_moe( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: list[torch.Tensor], + lora_b_stacked: list[torch.Tensor], + topk_weights: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + max_lora_rank: int, + top_k_num: int, + config, + mul_routed_weight=False, + ): + """ + Performs a fused forward computation for LoRA of Mixture-of-Experts (MoE) layer. 
+ """ + fused_moe_lora( + y, + x, + lora_a_stacked, + lora_b_stacked, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + max_lora_rank, + top_k_num, + config["BLOCK_SIZE_M"], + config["BLOCK_SIZE_N"], + config["BLOCK_SIZE_K"], + config["GROUP_SIZE_M"], + mul_routed_weight, + ) diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py index e61c5ae701233..0f43ff06d8f2b 100644 --- a/vllm/lora/utils.py +++ b/vllm/lora/utils.py @@ -23,6 +23,7 @@ from vllm.lora.layers import ( BaseLayerWithLoRA, ColumnParallelLinearWithLoRA, ColumnParallelLinearWithShardedLoRA, + FusedMoEWithLoRA, LogitsProcessorWithLoRA, MergedColumnParallelLinearWithLoRA, MergedColumnParallelLinearWithShardedLoRA, @@ -35,7 +36,9 @@ from vllm.lora.layers import ( RowParallelLinearWithShardedLoRA, VocabParallelEmbeddingWithLoRA, ) +from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.linear import LinearBase +from vllm.model_executor.utils import get_moe_expert_mapping, get_packed_modules_mapping if TYPE_CHECKING: from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -58,9 +61,18 @@ _all_lora_classes: set[type[BaseLayerWithLoRA]] = { MergedColumnParallelLinearWithShardedLoRA, MergedQKVParallelLinearWithShardedLoRA, RowParallelLinearWithShardedLoRA, + FusedMoEWithLoRA, } +def is_moe_model(model: nn.Module) -> bool: + """Checks if the model contains FusedMoE layers and warns the user.""" + if any(isinstance(module, FusedMoE) for module in model.modules()): + logger.info_once("MoE model detected. Using fused MoE LoRA implementation.") + return True + return False + + def from_layer( layer: nn.Module, max_loras: int, @@ -205,6 +217,9 @@ def get_supported_lora_modules(model: nn.Module) -> list[str]: if isinstance(module, (LinearBase,)): supported_lora_modules.add(name.split(".")[-1]) + if isinstance(module, (FusedMoE,)): + supported_lora_modules.add(name.split(".")[-1]) + return list(supported_lora_modules) @@ -252,3 +267,27 @@ def get_adapter_absolute_path(lora_path: str) -> str: return lora_path return local_snapshot_path + + +def process_packed_modules_mapping(model: nn.Module) -> dict[str, list[str]]: + if is_moe_model(model): + if moe_packed_mapping := get_moe_expert_mapping(model): + # This method generates and returns a dictionary mapping packed module + # names to lists of their corresponding submodule names. It includes + # both static mappings and dynamic mappings for expert layers, where + # the expert indices are expanded based on the configured number + # of routed experts. 
+ packed_modules_mapping = get_packed_modules_mapping(model) + + packed_modules_mapping["experts"] = [ + weight_name.rstrip(".") for _, weight_name, _, _ in moe_packed_mapping + ] + + return packed_modules_mapping + else: + raise AttributeError( + "To support LoRA for MoE model, " + "'get_expert_mapping' must be implemented" + ) + else: + return get_packed_modules_mapping(model) diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index 635685079b2d7..b85151f2c7592 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -94,7 +94,8 @@ class WorkerLoRAManager: expected_lora_modules.extend(packed_modules_mapping[module]) else: expected_lora_modules.append(module) - + if module == "experts": + expected_lora_modules.append(module) expected_lora_modules = list(set(expected_lora_modules)) lora_path = get_adapter_absolute_path(lora_request.lora_path) diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index e457b729da8c5..3b0df6c416a04 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -2,6 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Fused MoE utilities for GPTQ.""" +from collections.abc import Callable + import torch import vllm._custom_ops as ops @@ -11,6 +13,9 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( batched_moe_align_block_size, moe_align_block_size, ) +from vllm.model_executor.layers.fused_moe.prepare_finalize import ( + MoEPrepareAndFinalizeNoEP, +) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceDelegate, TopKWeightAndReduceNoOP, @@ -24,6 +29,21 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.scalar_type import ScalarType, scalar_types +def default_activation_func( + activation: str, output: torch.Tensor, input: torch.Tensor +) -> None: + if activation == "silu": + torch.ops._C.silu_and_mul(output, input) + elif activation == "swigluoai": + # alpha = 1.702, limit = 7.0 + torch.ops._C.swigluoai_and_mul(output, input) + else: + raise ValueError( + f"Unsupported activation: {activation}. " + "Only silu and swigluoai activations are supported." + ) + + def _fused_marlin_moe( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -36,12 +56,15 @@ def _fused_marlin_moe( num_topk: int, quant_type: ScalarType, apply_router_weight_on_input: bool, - activation: str, expert_map: torch.Tensor | None, block_size_m: int, sorted_token_ids: torch.Tensor, expert_ids: torch.Tensor, num_tokens_post_padded: torch.Tensor, + activation: str = "silu", + activation_func: Callable[ + [str, torch.Tensor, torch.Tensor], None + ] = default_activation_func, global_scale1: torch.Tensor | None = None, global_scale2: torch.Tensor | None = None, g_idx1: torch.Tensor | None = None, @@ -118,20 +141,9 @@ def _fused_marlin_moe( is_zp_float=False, ) - if activation == "silu": - torch.ops._C.silu_and_mul( - intermediate_cache2, intermediate_cache1.view(-1, 2 * N) - ) - elif activation == "swigluoai": - # alpha = 1.702, limit = 7.0 - torch.ops._C.swigluoai_and_mul( - intermediate_cache2, intermediate_cache1.view(-1, 2 * N) - ) - else: - raise ValueError( - f"Unsupported activation: {activation}. " - "Only silu and swigluoai activations are supported." 
- ) + activation_func( + activation, intermediate_cache2, intermediate_cache1.view(-1, 2 * N) + ) if output is None: output = intermediate_cache3 @@ -185,7 +197,11 @@ def fused_marlin_moe( quant_type_id: int, apply_router_weight_on_input: bool = False, global_num_experts: int = -1, - activation: str | None = "silu", + activation: str = "silu", + activation_func: Callable[ + [str, torch.Tensor, torch.Tensor], None + ] = default_activation_func, + moe_sum: Callable[[torch.Tensor, torch.Tensor], None] | None = None, expert_map: torch.Tensor | None = None, global_scale1: torch.Tensor | None = None, global_scale2: torch.Tensor | None = None, @@ -290,12 +306,13 @@ def fused_marlin_moe( num_topk=topk, quant_type=quant_type, apply_router_weight_on_input=apply_router_weight_on_input, - activation=activation, expert_map=expert_map, block_size_m=block_size_m, sorted_token_ids=sorted_token_ids, expert_ids=expert_ids, num_tokens_post_padded=num_tokens_post_padded, + activation=activation, + activation_func=activation_func, global_scale1=global_scale1, global_scale2=global_scale2, g_idx1=g_idx1, @@ -317,7 +334,10 @@ def fused_marlin_moe( else: output = torch.empty_like(hidden_states) - return torch.sum(moe_output.view(-1, topk, K), dim=1, out=output) + if moe_sum is None: + return torch.sum(moe_output.view(-1, topk, K), dim=1, out=output) + else: + return moe_sum(moe_output, output) def batched_fused_marlin_moe( @@ -600,6 +620,8 @@ class MarlinExperts(MarlinExpertsBase): apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, activation=activation, + activation_func=self.activation, + moe_sum=self.moe_sum, expert_map=expert_map, output=output, # Workspaces are swapped in workspace_shapes() to account for proper @@ -608,6 +630,19 @@ class MarlinExperts(MarlinExpertsBase): intermediate_cache2=workspace13, ) + def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None: + ops.moe_sum(input, output) + + +def modular_marlin_fused_moe( + quant_config: FusedMoEQuantConfig, shared_experts: torch.nn.Module | None = None +) -> mk.FusedMoEModularKernel: + return mk.FusedMoEModularKernel( + MoEPrepareAndFinalizeNoEP(), + MarlinExperts(quant_config), + shared_experts, + ) + class BatchedMarlinExperts(MarlinExpertsBase): def __init__( diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index f5760fea6522e..031381332cc9b 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -2135,13 +2135,18 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): B_bias=self.w2_bias, ) - ops.moe_sum(intermediate_cache3, output) + # separate function is required for MoE + LoRA + self.moe_sum(intermediate_cache3, output) + + def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None: + ops.moe_sum(input, output) def modular_triton_fused_moe( - quant_config: FusedMoEQuantConfig, + quant_config: FusedMoEQuantConfig, shared_experts: torch.nn.Module | None = None ) -> mk.FusedMoEModularKernel: return mk.FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP(), TritonExperts(quant_config), + shared_experts, ) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 0fa98b1c7f670..8514b63556ae9 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -557,6 +557,9 @@ class FusedMoEPermuteExpertsUnpermute(ABC): 
torch.ops._C.silu_and_mul(output, input) elif activation == "gelu": torch.ops._C.gelu_and_mul(output, input) + elif activation == "swigluoai": + # alpha = 1.702, limit = 7.0 + torch.ops._C.swigluoai_and_mul(output, input) else: raise ValueError(f"Unsupported FusedMoe activation: {activation}") diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 58133aa55d596..5827e606b4a5e 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -1313,6 +1313,17 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, SupportsLoR logits = self.logits_processor(self.lm_head, hidden_states) return logits + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + return SharedFusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts, + num_redundant_experts=0, + ) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py index 7f4040ca94223..1e32c433cabae 100644 --- a/vllm/model_executor/models/gpt_oss.py +++ b/vllm/model_executor/models/gpt_oss.py @@ -32,7 +32,7 @@ from vllm.model_executor.models.utils import sequence_parallel_chunk from vllm.sequence import IntermediateTensors from vllm.utils import cdiv -from .interfaces import SupportsEagle3, SupportsPP +from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP from .utils import ( AutoWeightsLoader, WeightsMapper, @@ -627,7 +627,7 @@ class GptOssModel(nn.Module): ) -class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3): +class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA): packed_modules_mapping = {"qkv": ["q_proj", "k_proj", "v_proj"]} hf_to_vllm_mapper = WeightsMapper( @@ -696,6 +696,17 @@ class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3): logits = self.logits_processor(self.lm_head, hidden_states) return logits + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + # Params for weights, weight scales, activation scales + # (param_name, weight_name, expert_id, shard_id) + return FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_local_experts, + num_redundant_experts=0, + ) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py index 06307ae22c1b9..7f867244330fa 100644 --- a/vllm/model_executor/models/olmoe.py +++ b/vllm/model_executor/models/olmoe.py @@ -49,7 +49,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors -from .interfaces import SupportsPP +from .interfaces import SupportsLoRA, SupportsPP from .utils import ( AutoWeightsLoader, is_pp_missing_parameter, @@ -349,8 +349,6 @@ class OlmoeModel(nn.Module): ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - 
("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) @@ -433,17 +431,13 @@ class OlmoeModel(nn.Module): return loaded_params -class OlmoeForCausalLM(nn.Module, SupportsPP): +class OlmoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): packed_modules_mapping = { "qkv_proj": [ "q_proj", "k_proj", "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], + ] } def __init__(