[LoRA] Optimize 3D MoE logic (#29222)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Jee Jee Li 2025-11-24 10:27:23 +08:00 committed by GitHub
parent c309bb5245
commit 1073ba68b0
11 changed files with 397 additions and 105 deletions

View File

@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import vllm
from vllm.lora.request import LoRARequest
@ -84,14 +86,17 @@ def test_gpt_oss_lora(gptoss20b_lora_files):
@multi_gpu_test(num_gpus=2)
def test_gpt_oss_lora_tp2(gptoss20b_lora_files):
@pytest.mark.parametrize("fully_sharded_loras", [False, True])
def test_gpt_oss_lora_tp2(gptoss20b_lora_files, fully_sharded_loras):
llm = vllm.LLM(
MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=2,
max_lora_rank=8,
max_num_seqs=16,
tensor_parallel_size=2,
fully_sharded_loras=fully_sharded_loras,
compilation_config=vllm.config.CompilationConfig( # Avoid OOM
cudagraph_specialize_lora=False,
),
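A minimal sketch of how this parametrized test presumably drives generation once the LLM above is built, so that both the fully sharded and non-sharded LoRA paths are covered; the prompt text and sampling settings are assumptions, while `gptoss20b_lora_files` and `LoRARequest` come from the diff itself.

from vllm import SamplingParams
from vllm.lora.request import LoRARequest

# Run one generation through the adapter so each fully_sharded_loras value
# exercises its corresponding kernel path.
lora_request = LoRARequest("gptoss-lora", 1, gptoss20b_lora_files)
outputs = llm.generate(
    ["Give me a short introduction to large language models."],
    SamplingParams(temperature=0, max_tokens=32),
    lora_request=lora_request,
)
assert outputs[0].outputs[0].text  # the adapter should produce non-empty text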

View File

@ -11,7 +11,7 @@ from vllm.lora.layers.column_parallel_linear import (
QKVParallelLinearWithLoRA,
QKVParallelLinearWithShardedLoRA,
)
from vllm.lora.layers.fused_moe import FusedMoEWithLoRA
from vllm.lora.layers.fused_moe import FusedMoE3DWithLoRA, FusedMoEWithLoRA
from vllm.lora.layers.logits_processor import LogitsProcessorWithLoRA
from vllm.lora.layers.replicated_linear import ReplicatedLinearWithLoRA
from vllm.lora.layers.row_parallel_linear import (
@ -38,4 +38,5 @@ __all__ = [
"ReplicatedLinearWithLoRA",
"LoRAMapping",
"FusedMoEWithLoRA",
"FusedMoE3DWithLoRA",
]

View File

@ -42,8 +42,8 @@ class BaseLayerWithLoRA(nn.Module):
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
lora_a: torch.Tensor | list[torch.Tensor],
lora_b: torch.Tensor | list[torch.Tensor],
):
"""Overwrites lora tensors at index."""
...

View File

@ -94,13 +94,15 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
lora_a: torch.Tensor | list[torch.Tensor],
lora_b: torch.Tensor | list[torch.Tensor],
):
# Except for QKVParallelLinearWithLoRA and
# MergedColumnParallelLinearWithLoRA, all other linear LoRA layers
# store weights in a tuple of size 1. These two layers will
# override this function.
assert isinstance(lora_a, torch.Tensor)
assert isinstance(lora_b, torch.Tensor)
assert (
len(self.lora_a_stacked) == len(self.lora_b_stacked) == self.n_slices == 1
)
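To make the widened `set_lora` signature concrete, a tiny standalone sketch (the function names are illustrative, not from vLLM) of how linear layers narrow the union back to a single tensor while the MoE layers below expect a list:

import torch

def set_lora_linear(lora_a: torch.Tensor | list[torch.Tensor]) -> torch.Tensor:
    # Single-slice linear LoRA layers still receive a plain tensor.
    assert isinstance(lora_a, torch.Tensor)
    return lora_a

def set_lora_moe(lora_a: torch.Tensor | list[torch.Tensor]) -> list[torch.Tensor]:
    # The 3D MoE path passes [gate_up_proj lora_A, down_proj lora_A] as a list.
    assert isinstance(lora_a, list)
    return lora_a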

View File

@ -246,8 +246,8 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
lora_a: torch.Tensor | list[torch.Tensor],
lora_b: torch.Tensor | list[torch.Tensor],
):
self.reset_lora(index)

View File

@ -42,7 +42,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.device = base_layer.w2_weight.device
self.w13_slices = 2
self._w13_slices = 2
self._inject_lora_into_fused_moe()
def _normalize_keys(self, config: dict[str, int | None]) -> dict[str, int | None]:
@ -160,7 +160,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
op_prefix="w13",
num_loras=self.max_loras,
rank=max_lora_rank,
num_slices=self.w13_slices,
num_slices=self._w13_slices,
M=M,
layer=layer,
top_k=top_k,
@ -230,7 +230,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
num_tokens = hidden_states.size(0)
M = min(num_tokens, CHUNK_SIZE)
max_lora_rank = self.w2_lora_a_stacked.shape[-2]
max_lora_rank = self.w2_lora_a_stacked[0].shape[-2]
shrink_config, expand_config = self._get_lora_moe_configs(
op_prefix="w2",
num_loras=self.max_loras,
@ -258,8 +258,8 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
self.punica_wrapper.add_lora_fused_moe(
intermediate_cache3,
intermediate_cache2,
(self.w2_lora_a_stacked,),
(self.w2_lora_b_stacked,),
self.w2_lora_a_stacked,
self.w2_lora_b_stacked,
topk_weights,
sorted_token_ids_lora,
expert_ids_lora,
@ -292,22 +292,12 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
self.base_layer.quant_method, m_fused_moe_fn
)
def create_lora_weights(
def _create_lora_a_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: PretrainedConfig | None = None,
) -> None:
"""Initializes lora matrices."""
assert self.w13_slices == 2
self.max_loras = lora_config.max_loras
self.fully_sharded = lora_config.fully_sharded_loras
self.adapter_enabled = torch.tensor(
[0] * (max_loras + 1), dtype=torch.int, device=self.device
)
self.w13_lora_a_stacked = tuple(
):
self.w13_lora_a_stacked: tuple[torch.Tensor, ...] = tuple(
torch.zeros(
(
max_loras,
@ -320,10 +310,23 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
dtype=lora_config.lora_dtype,
device=self.device,
)
for _ in range(self.w13_slices)
for _ in range(self._w13_slices)
)
self.w2_lora_a_stacked: tuple[torch.Tensor, ...] = (
torch.zeros(
(
max_loras,
self.base_layer.local_num_experts,
lora_config.max_lora_rank,
self.base_layer.intermediate_size_per_partition,
),
dtype=lora_config.lora_dtype,
device=self.device,
),
)
self.w13_lora_b_stacked = tuple(
def _create_lora_b_weights(self, max_loras: int, lora_config: LoRAConfig):
self.w13_lora_b_stacked: tuple[torch.Tensor, ...] = tuple(
torch.zeros(
(
max_loras,
@ -334,34 +337,42 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
dtype=lora_config.lora_dtype,
device=self.device,
)
for _ in range(self.w13_slices)
for _ in range(self._w13_slices)
)
self.w2_lora_b_stacked: tuple[torch.Tensor, ...] = (
torch.zeros(
(
max_loras,
self.base_layer.local_num_experts,
self.base_layer.hidden_size
if not self.fully_sharded
else divide(self.base_layer.hidden_size, self.tp_size),
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
device=self.device,
),
)
self.w2_lora_a_stacked = torch.zeros(
(
max_loras,
self.base_layer.local_num_experts,
lora_config.max_lora_rank,
self.base_layer.intermediate_size_per_partition,
),
dtype=lora_config.lora_dtype,
device=self.device,
)
self.w2_lora_b_stacked = torch.zeros(
(
max_loras,
self.base_layer.local_num_experts,
self.base_layer.hidden_size
if not self.fully_sharded
else divide(self.base_layer.hidden_size, self.tp_size),
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
device=self.device,
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: PretrainedConfig | None = None,
) -> None:
"""Initializes lora matrices."""
self.max_loras = lora_config.max_loras
self.fully_sharded = lora_config.fully_sharded_loras
self.adapter_enabled = torch.tensor(
[0] * (max_loras + 1), dtype=torch.int, device=self.device
)
self._create_lora_a_weights(max_loras, lora_config)
self._create_lora_b_weights(max_loras, lora_config)
# They will be used by 'LoRALayerWeights.create_dummy_lora_weights'
# to create dummy LoRA weights.
# TODO Optimize this section
self.lora_a_stacked = []
self.lora_b_stacked = []
for lora_id in range(max_loras):
@ -370,36 +381,43 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
self.lora_a_stacked.append(
self.w13_lora_a_stacked[0][lora_id][experts_id]
)
self.lora_a_stacked.append(self.w2_lora_a_stacked[lora_id][experts_id])
self.lora_a_stacked.append(
self.w13_lora_a_stacked[1][lora_id][experts_id]
self.w2_lora_a_stacked[0][lora_id][experts_id]
)
self.lora_b_stacked.append(
self.w13_lora_b_stacked[0][lora_id][experts_id]
)
self.lora_b_stacked.append(self.w2_lora_b_stacked[lora_id][experts_id])
self.lora_b_stacked.append(
self.w2_lora_b_stacked[0][lora_id][experts_id]
)
self.lora_a_stacked.append(
self.w13_lora_a_stacked[1][lora_id][experts_id]
)
self.lora_b_stacked.append(
self.w13_lora_b_stacked[1][lora_id][experts_id]
)
def reset_lora(self, index: int):
"""Resets the lora weights at index back to 0."""
for pos in range(self.w13_slices):
for pos in range(self._w13_slices):
self.w13_lora_a_stacked[pos][index] = 0
self.w13_lora_b_stacked[pos][index] = 0
self.w2_lora_a_stacked[index] = 0
self.w2_lora_b_stacked[index] = 0
self.w2_lora_a_stacked[0][index] = 0
self.w2_lora_b_stacked[0][index] = 0
self.adapter_enabled[index] = 0
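The refactor above wraps the w2 buffers in one-element tuples so the w13 and w2 stacks can be reset and copied with the same indexing; a toy sketch of that pattern with made-up shapes:

import torch

# Every stacked LoRA buffer is a tuple of per-slice tensors: two slices for the
# separate w1/w3 case, a single slice for w2 (and for fused w13 in the 3D case).
w13_stacked = tuple(torch.zeros(4, 8, 16, 32) for _ in range(2))
w2_stacked = (torch.zeros(4, 8, 16, 32),)

def reset_index(stacks: tuple[torch.Tensor, ...], index: int) -> None:
    for buf in stacks:
        buf[index] = 0  # zero out one adapter slot across all slices

reset_index(w13_stacked, 1)
reset_index(w2_stacked, 1)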
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
lora_a: torch.Tensor | list[torch.Tensor],
lora_b: torch.Tensor | list[torch.Tensor],
):
"""Overwrites lora tensors at index."""
assert isinstance(lora_a, list)
assert isinstance(lora_b, list)
self.reset_lora(index)
self.adapter_enabled[index] = 1
for eid in range(len(lora_a) // 3):
@ -432,7 +450,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
w1_lora_a = w1_lora_a[w13_start_idx:w13_end_idx, :]
w3_lora_a = w3_lora_a[w13_start_idx:w13_end_idx, :]
w2_shard_size = self.w2_lora_b_stacked[index, eid].shape[0]
w2_shard_size = self.w2_lora_b_stacked[0][index, eid].shape[0]
w2_start_idx = self.tp_rank * w2_shard_size
w2_end_idx = (self.tp_rank + 1) * w2_shard_size
w2_lora_b = w2_lora_b[w2_start_idx:w2_end_idx, :]
@ -454,26 +472,14 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
index, eid, : w3_lora_b.shape[0], : w3_lora_b.shape[1]
].copy_(w3_lora_b, non_blocking=True)
self.w2_lora_a_stacked[
self.w2_lora_a_stacked[0][
index, eid, : w2_lora_a.shape[0], : w2_lora_a.shape[1]
].copy_(w2_lora_a, non_blocking=True)
self.w2_lora_b_stacked[
self.w2_lora_b_stacked[0][
index, eid, : w2_lora_b.shape[0], : w2_lora_b.shape[1]
].copy_(w2_lora_b, non_blocking=True)
@classmethod
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None,
) -> bool:
"""Returns True if the layer can be replaced by this LoRA layer."""
# return type(source_layer) is FusedMoE
return isinstance(source_layer, FusedMoE)
def forward(self, *args, **kwargs):
return self.base_layer.forward(*args, **kwargs)
@ -491,3 +497,220 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
@property
def is_internal_router(self) -> bool:
return self.base_layer.is_internal_router
@classmethod
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None,
) -> bool:
"""Returns True if the layer can be replaced by this LoRA layer."""
# return type(source_layer) is FusedMoE
return type(source_layer) is FusedMoE and len(packed_modules_list) == 2
class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
def __init__(self, base_layer):
super().__init__(base_layer)
self._w13_slices = 1
def _create_lora_b_weights(self, max_loras, lora_config):
self.w13_lora_b_stacked: tuple[torch.Tensor] = tuple(
torch.zeros(
(
max_loras,
self.base_layer.local_num_experts,
self.base_layer.intermediate_size_per_partition * 2,
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
device=self.device,
)
for _ in range(self._w13_slices)
)
self.w2_lora_b_stacked: tuple[torch.Tensor] = (
torch.zeros(
(
max_loras,
self.base_layer.local_num_experts,
self.base_layer.hidden_size
if not self.fully_sharded
else divide(self.base_layer.hidden_size, self.tp_size),
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
device=self.device,
),
)
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: PretrainedConfig | None = None,
) -> None:
"""Initializes lora matrices."""
self.max_loras = lora_config.max_loras
self.fully_sharded = lora_config.fully_sharded_loras
self.adapter_enabled = torch.tensor(
[0] * (max_loras + 1), dtype=torch.int, device=self.device
)
self._create_lora_a_weights(max_loras, lora_config)
self._create_lora_b_weights(max_loras, lora_config)
def _slice_w13_a(self, w13_lora_a: torch.Tensor) -> torch.Tensor:
if self.tp_size == 1 or not self.fully_sharded:
return w13_lora_a
# w13_lora_a shape (num_experts,rank,input_size)
current_lora_rank = w13_lora_a.shape[1]
assert current_lora_rank % self.tp_size == 0
sliced_rank = current_lora_rank // self.tp_size
start_idx = self.tp_rank * sliced_rank
end_idx = (self.tp_rank + 1) * sliced_rank
return w13_lora_a[:, start_idx:end_idx, :]
def _slice_w13_b(self, w13_lora_b: torch.Tensor, is_interleave: bool = True):
if self.tp_size == 1:
return w13_lora_b
# w13_lora_b shape (num_experts,output_size,rank)
shard_size = self.base_layer.intermediate_size_per_partition
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
if is_interleave:
# For models like GPT-OSS, the weights of w1 (gate_proj) and w3 (up_proj)
# are stored in interleaved order, so the corresponding LoRA weights need
# to be sliced the same way (see the standalone example after this method).
w1_lora_b = w13_lora_b[:, ::2, :]
w3_lora_b = w13_lora_b[:, 1::2, :]
sliced_w1_lora_b = w1_lora_b[:, start_idx:end_idx, :]
sliced_w3_lora_b = w3_lora_b[:, start_idx:end_idx, :]
return torch.stack([sliced_w1_lora_b, sliced_w3_lora_b], dim=2).flatten(
1, 2
)
else:
slice_size = w13_lora_b.shape[1] // 2
w1_lora_b = w13_lora_b[:, :slice_size, :]
w3_lora_b = w13_lora_b[:, slice_size:, :]
sliced_w1_lora_b = w1_lora_b[:, start_idx:end_idx, :]
sliced_w3_lora_b = w3_lora_b[:, start_idx:end_idx, :]
return torch.cat([sliced_w1_lora_b, sliced_w3_lora_b], dim=1)
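A worked, standalone version of the interleaved branch above; all sizes are made up for illustration (2 experts, full intermediate size 4, tensor parallel size 2, rank 0):

import torch

num_experts, full_intermediate, lora_rank = 2, 4, 4
tp_size, tp_rank = 2, 0
per_partition = full_intermediate // tp_size  # intermediate_size_per_partition
w13_lora_b = torch.randn(num_experts, 2 * full_intermediate, lora_rank)

w1 = w13_lora_b[:, ::2, :]   # gate_proj rows sit at even positions
w3 = w13_lora_b[:, 1::2, :]  # up_proj rows sit at odd positions
start, end = tp_rank * per_partition, (tp_rank + 1) * per_partition
sliced = torch.stack([w1[:, start:end, :], w3[:, start:end, :]], dim=2).flatten(1, 2)
# This rank keeps per_partition rows of w1 and of w3, re-interleaved as
# w1_row0, w3_row0, w1_row1, w3_row1, ...
assert sliced.shape == (num_experts, 2 * per_partition, lora_rank)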
def _slice_w2_a(self, w2_lora_a: torch.Tensor) -> torch.Tensor:
if self.tp_size == 1:
return w2_lora_a
# w2_lora_a shape (num_experts,rank,input_size)
shard_size = self.base_layer.intermediate_size_per_partition
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
return w2_lora_a[:, :, start_idx:end_idx]
def _slice_w2_b(self, w2_lora_b: torch.Tensor) -> torch.Tensor:
if self.tp_size == 1 or not self.fully_sharded:
return w2_lora_b
# Based on S-LoRA, we slice W2 B along the hidden_size dim.
# w2_lora_b shape (num_experts,output_size,rank)
current_lora_size = w2_lora_b.shape[1]
sliced_size = current_lora_size // self.tp_size
start_idx = self.tp_rank * sliced_size
end_idx = (self.tp_rank + 1) * sliced_size
return w2_lora_b[:, start_idx:end_idx, :]
def set_lora(
self,
index: int,
lora_a: torch.Tensor | list[torch.Tensor],
lora_b: torch.Tensor | list[torch.Tensor],
):
"""Overwrites lora tensors at index."""
# Make mypy happy
assert isinstance(lora_a, list)
assert isinstance(lora_b, list)
assert len(lora_a) == len(lora_b) == 2
self.reset_lora(index)
self.adapter_enabled[index] = 1
num_experts = self.w13_lora_a_stacked[0].shape[1]
w13_lora_a, w2_lora_a = lora_a
w13_lora_b, w2_lora_b = lora_b
# (num_experts,rank,input_size)
w13_lora_a = w13_lora_a.reshape(num_experts, -1, w13_lora_a.shape[-1])
w2_lora_a = w2_lora_a.reshape(num_experts, -1, w2_lora_a.shape[-1])
# (output_size,num_experts,rank)
w13_lora_b = w13_lora_b.reshape(w13_lora_b.shape[0], num_experts, -1)
w2_lora_b = w2_lora_b.reshape(w2_lora_b.shape[0], num_experts, -1)
# (num_experts,output_size,rank)
w13_lora_b = w13_lora_b.permute(1, 0, 2)
w2_lora_b = w2_lora_b.permute(1, 0, 2)
sliced_w13_lora_a = self._slice_w13_a(w13_lora_a)
sliced_w13_lora_b = self._slice_w13_b(w13_lora_b, is_interleave=True)
sliced_w2_lora_a = self._slice_w2_a(w2_lora_a)
sliced_w2_lora_b = self._slice_w2_b(w2_lora_b)
self.w13_lora_a_stacked[0][
index, :, : sliced_w13_lora_a.shape[1], : sliced_w13_lora_a.shape[2]
].copy_(sliced_w13_lora_a, non_blocking=True)
self.w2_lora_a_stacked[0][
index, :, : sliced_w2_lora_a.shape[1], : sliced_w2_lora_a.shape[2]
].copy_(sliced_w2_lora_a, non_blocking=True)
self.w13_lora_b_stacked[0][
index, :, : sliced_w13_lora_b.shape[1], : sliced_w13_lora_b.shape[2]
].copy_(sliced_w13_lora_b, non_blocking=True)
self.w2_lora_b_stacked[0][
index, :, : sliced_w2_lora_b.shape[1], : sliced_w2_lora_b.shape[2]
].copy_(sliced_w2_lora_b, non_blocking=True)
@property
def w13_input_size(self):
"""
Full size
"""
return self.w13_lora_a_stacked[0].shape[-1]
@property
def w13_output_size(self):
"""
Full size
"""
return self.w13_lora_b_stacked[0].shape[-2] * self.tp_size
@property
def w2_input_size(self):
"""
Full size
"""
return self.w2_lora_a_stacked[0].shape[-1] * self.tp_size
@property
def w2_output_size(self):
"""
Full size
"""
return self.w2_lora_a_stacked[0].shape[-2]
@classmethod
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None,
) -> bool:
"""Returns True if the layer can be replaced by this LoRA layer."""
return type(source_layer) is FusedMoE and len(packed_modules_list) == 1
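For clarity, a small illustrative helper (not part of the diff; the import matches the module added here) showing how the two can_replace_layer checks effectively dispatch between the wrappers:

from vllm.lora.layers.fused_moe import FusedMoE3DWithLoRA, FusedMoEWithLoRA

def pick_moe_lora_cls(packed_modules_list: list[str]):
    # One packed entry (e.g. ["w13"]): gate/up LoRA weights are already fused
    # on disk, so the 3D wrapper is chosen.
    if len(packed_modules_list) == 1:
        return FusedMoE3DWithLoRA
    # Two entries (e.g. ["w1", "w3"]): separate gate and up LoRA tensors.
    if len(packed_modules_list) == 2:
        return FusedMoEWithLoRA
    raise ValueError("unexpected packed_modules_list for a FusedMoE layer")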

View File

@ -128,9 +128,11 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
lora_a: torch.Tensor | list[torch.Tensor],
lora_b: torch.Tensor | list[torch.Tensor],
):
assert isinstance(lora_a, torch.Tensor)
assert isinstance(lora_b, torch.Tensor)
self.reset_lora(index)
self.lora_a_stacked[index, 0, : lora_a.shape[0], : lora_a.shape[1]].copy_(
lora_a, non_blocking=True

View File

@ -77,12 +77,15 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
lora_a: torch.Tensor | list[torch.Tensor],
lora_b: torch.Tensor | list[torch.Tensor],
):
assert isinstance(lora_a, torch.Tensor)
assert isinstance(lora_b, torch.Tensor)
self.reset_lora(index)
# NOTE self.lora_a_stacked is row-major, and lora_a is col-major,
# so we need to transpose here
self.lora_a_stacked[index, : lora_a.shape[1], : lora_a.shape[0]].copy_(
lora_a.T, non_blocking=True
)

View File

@ -22,11 +22,13 @@ from vllm.lora.utils import (
from_layer_logits_processor,
get_supported_lora_modules,
is_base_embeddding_weights,
is_moe_model,
is_regex_target_modules,
parse_fine_tuned_lora_name,
process_packed_modules_mapping,
replace_submodule,
)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models import SupportsLoRA, supports_multimodal
from vllm.model_executor.models.interfaces import is_pooling_model
@ -356,7 +358,11 @@ class LoRAModelManager:
self.modules: dict[str, BaseLayerWithLoRA] = {}
# Dict instead of a set for compatibility with LRUCache.
self._last_mapping: LoRAMapping | None = None
self._is_3d_moe_model = is_moe_model(self.model) and hasattr(
self.model, "is_3d_moe_weight"
)
self._create_lora_modules()
self.model.lora_manager = self
def __len__(self) -> int:
@ -400,22 +406,36 @@ class LoRAModelManager:
self.lora_index_to_id[index] = lora_model.id
for module_name, module in self.modules.items():
module_lora = self._get_lora_layer_weights(lora_model, module_name)
if module_lora:
# Note (gnovack) - If MOE lora weights are not split into
# num_experts chunks, we split them here
if isinstance(module, FusedMoEWithLoRA) and torch.is_tensor(
module_lora.lora_a
):
# Handle FSDP file format where experts.base_layer is the
# gate_up_proj and experts is the down_proj
gate_up_proj_lora = self._get_lora_layer_weights(
lora_model, module_name + ".base_layer"
)
assert gate_up_proj_lora is not None
assert module_lora is not None
down_proj_lora = module_lora
if not module_lora:
module.reset_lora(index)
continue
# Note (gnovack) - If MOE lora weights are not split into
# num_experts chunks, we split them here
if isinstance(module, FusedMoEWithLoRA) and torch.is_tensor(
module_lora.lora_a
):
# Handle PEFT file format where experts.base_layer is the
# gate_up_proj and experts is the down_proj
gate_up_proj_lora = self._get_lora_layer_weights(
lora_model, module_name + ".base_layer"
)
down_proj_lora = module_lora
# FIXME Edge case where LoRA is not added to gate_up_proj
# or down_proj
assert gate_up_proj_lora is not None
assert down_proj_lora is not None
if self._is_3d_moe_model:
module_lora.lora_a = [
gate_up_proj_lora.lora_a,
down_proj_lora.lora_a,
]
module_lora.lora_b = [
gate_up_proj_lora.lora_b,
down_proj_lora.lora_b,
]
else:
# Some 3D MoE models haven't added the `is_3d_moe_weight`
# attribute yet, so fall back here
num_experts = module_lora.lora_a.shape[0] // module_lora.rank
gate_proj_a = gate_up_proj_lora.lora_a.chunk(num_experts, dim=0)
@ -444,14 +464,12 @@ class LoRAModelManager:
module_lora.lora_a = lora_a
module_lora.lora_b = lora_b
module.set_lora(
index,
module_lora.lora_a,
module_lora.lora_b,
)
module.set_lora(
index,
module_lora.lora_a,
module_lora.lora_b,
)
else:
module.reset_lora(index)
return True
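For orientation, a rough sketch of the tensor shapes this branch expects from a 3D MoE PEFT checkpoint; the concrete sizes below are assumptions for illustration only:

import torch

num_experts, rank, hidden, intermediate = 32, 8, 2880, 2880
# `experts.base_layer` carries the fused gate_up_proj LoRA and `experts` the
# down_proj LoRA; all experts are stacked along the rank dimension.
gate_up_lora_a = torch.randn(num_experts * rank, hidden)
gate_up_lora_b = torch.randn(2 * intermediate, num_experts * rank)
down_lora_a = torch.randn(num_experts * rank, intermediate)
down_lora_b = torch.randn(hidden, num_experts * rank)

# With `is_3d_moe_weight`, the manager hands both pairs to the layer unchanged:
lora_a = [gate_up_lora_a, down_lora_a]
lora_b = [gate_up_lora_b, down_lora_b]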
def _deactivate_adapter(self, lora_id: int):
@ -512,6 +530,13 @@ class LoRAModelManager:
continue
parts = module_name.split(".")[-1]
packed_moduled_lst = self.packed_modules_mapping.get(parts, [])
if isinstance(module, FusedMoE):
# packed_moduled_lst is used here just to determine whether to
# instantiate FusedMoE3DWithLoRA or FusedMoEWithLoRA; the
# difference between these two LoRA layers is whether the
# LoRA weights of w1 and w3 have already been fused on disk.
packed_moduled_lst = ["w13"] if self._is_3d_moe_model else ["w1", "w3"]
new_module = replace_submodule(
self.model,
module_name,
@ -560,6 +585,7 @@ class LoRAModelManager:
self._register_packed_modules(module_name)
# All lora layers share the same punica_wrapper based on reference.
new_module.set_mapping(self.punica_wrapper)
def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
assert isinstance(module, BaseLayerWithLoRA), (
@ -605,6 +631,30 @@ class LoRAModelManager:
module.lora_a_stacked[0].dtype,
"cpu",
)
model.loras[module_name] = lora
elif module.__class__.__name__ == "FusedMoE3DWithLoRA":
# Case for 3D moe model
# w2
lora = LoRALayerWeights.create_dummy_lora_weights(
module_name,
module.w2_input_size,
module.w2_output_size,
rank * module.w2_lora_a_stacked[0].shape[1], # rank*num_experts
module.w2_lora_a_stacked[0].dtype,
"cpu",
)
model.loras[module_name] = lora
# w13
lora = LoRALayerWeights.create_dummy_lora_weights(
module_name,
module.w13_input_size,
module.w13_output_size,
rank
* module.w13_lora_a_stacked[0].shape[1], # rank*num_experts
module.w13_lora_a_stacked[0].dtype,
"cpu",
)
model.loras[module_name + ".base_layer"] = lora
else:
lora = LoRALayerWeights.create_dummy_lora_weights(
module_name,
@ -614,6 +664,7 @@ class LoRAModelManager:
module.lora_a_stacked[0].dtype,
"cpu",
)
model.loras[module_name] = lora
else:
parts = module_name.split(".")
replacements = self.packed_modules_mapping[parts[-1]]
@ -629,7 +680,7 @@ class LoRAModelManager:
)
subloras.append(lora)
lora = PackedLoRALayerWeights.pack(subloras)
model.loras[module_name] = lora
model.loras[module_name] = lora
return model
def _match_target_modules(self, module_name: str):

View File

@ -23,6 +23,7 @@ from vllm.lora.layers import (
BaseLayerWithLoRA,
ColumnParallelLinearWithLoRA,
ColumnParallelLinearWithShardedLoRA,
FusedMoE3DWithLoRA,
FusedMoEWithLoRA,
LogitsProcessorWithLoRA,
MergedColumnParallelLinearWithLoRA,
@ -62,6 +63,7 @@ _all_lora_classes: set[type[BaseLayerWithLoRA]] = {
MergedQKVParallelLinearWithShardedLoRA,
RowParallelLinearWithShardedLoRA,
FusedMoEWithLoRA,
FusedMoE3DWithLoRA,
}
@ -288,10 +290,12 @@ def process_packed_modules_mapping(model: nn.Module) -> dict[str, list[str]]:
# the expert indices are expanded based on the configured number
# of routed experts.
packed_modules_mapping = get_packed_modules_mapping(model)
packed_modules_mapping["experts"] = [
weight_name.rstrip(".") for _, weight_name, _, _ in moe_packed_mapping
]
if not hasattr(model, "is_3d_moe_weight"):
# 3D MoE LoRA does not need `packed_modules_mapping`
packed_modules_mapping["experts"] = [
weight_name.rstrip(".")
for _, weight_name, _, _ in moe_packed_mapping
]
return packed_modules_mapping
else:

View File

@ -656,6 +656,7 @@ class GptOssModel(nn.Module):
class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA):
is_3d_moe_weight: bool = True
packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
hf_to_vllm_mapper = WeightsMapper(
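A hedged sketch of how another MoE model definition could opt into the same path; the class name is made up, and only the `is_3d_moe_weight` attribute mirrors the GPT-OSS change above:

import torch.nn as nn
from vllm.model_executor.models import SupportsLoRA

class SomeFusedMoEForCausalLM(nn.Module, SupportsLoRA):
    # Declares that the checkpoint stores expert weights as fused 3D tensors,
    # so the LoRA manager selects FusedMoE3DWithLoRA for its FusedMoE layers.
    is_3d_moe_weight: bool = True
    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}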