[LoRA] Continue optimizing MoE LoRA weight loading (#29322)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Jee Jee Li 2025-11-27 21:56:28 +08:00 committed by GitHub
parent cf348c8d27
commit 2f5f9acd55
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 228 additions and 157 deletions

View File

@ -28,12 +28,13 @@ def test_load_checkpoints(
packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
embedding_modules = BaiChuanBaseForCausalLM.embedding_modules
embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules
expected_lora_modules: list[str] = []
expected_lora_lst: list[str] = []
for module in BAICHUAN_LORA_MODULES:
if module in packed_modules_mapping:
expected_lora_modules.extend(packed_modules_mapping[module])
expected_lora_lst.extend(packed_modules_mapping[module])
else:
expected_lora_modules.append(module)
expected_lora_lst.append(module)
expected_lora_modules = set(expected_lora_lst)
if lora_name == "baichuan7B":
peft_helper = PEFTHelper.from_local_dir(
baichuan_lora_files, max_position_embeddings=4096
@ -103,13 +104,13 @@ def test_lora_weights_mapping(baichuan_lora_files):
packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
embedding_modules = BaiChuanBaseForCausalLM.embedding_modules
embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules
expected_lora_modules: list[str] = []
expected_lora_lst: list[str] = []
for module in BAICHUAN_LORA_MODULES:
if module in packed_modules_mapping:
expected_lora_modules.extend(packed_modules_mapping[module])
expected_lora_lst.extend(packed_modules_mapping[module])
else:
expected_lora_modules.append(module)
expected_lora_lst.append(module)
expected_lora_modules = set(expected_lora_lst)
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"model.": "language_model.model.",

View File

@ -26,13 +26,13 @@ def test_load_checkpoints_from_huggingface(lora_fixture_name, request):
packed_modules_mapping = LlamaForCausalLM.packed_modules_mapping
embedding_modules = LlamaForCausalLM.embedding_modules
embed_padding_modules = LlamaForCausalLM.embedding_padding_modules
expected_lora_modules: list[str] = []
expected_lora_lst: list[str] = []
for module in LLAMA_LORA_MODULES:
if module in packed_modules_mapping:
expected_lora_modules.extend(packed_modules_mapping[module])
expected_lora_lst.extend(packed_modules_mapping[module])
else:
expected_lora_modules.append(module)
expected_lora_lst.append(module)
expected_lora_modules = set(expected_lora_lst)
lora_path = get_adapter_absolute_path(lora_name)
# lora loading should work for either absolute path and huggingface id.

View File

@ -60,7 +60,7 @@ class BaseLayerWithLoRA(nn.Module):
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None,
model_config: PretrainedConfig | None = None,
) -> bool:
"""Returns True if the layer can be replaced by this LoRA layer."""
raise NotImplementedError

View File

@ -153,7 +153,7 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None,
model_config: PretrainedConfig | None = None,
) -> bool:
return type(source_layer) is ColumnParallelLinear or (
type(source_layer) is MergedColumnParallelLinear
@ -272,7 +272,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None,
model_config: PretrainedConfig | None = None,
) -> bool:
return (
type(source_layer) is MergedColumnParallelLinear
@ -338,7 +338,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None,
model_config: PretrainedConfig | None = None,
) -> bool:
return type(source_layer) is QKVParallelLinear and len(packed_modules_list) == 1
@ -396,7 +396,7 @@ class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None,
model_config: PretrainedConfig | None = None,
) -> bool:
return type(source_layer) is QKVParallelLinear and len(packed_modules_list) == 3
@ -434,7 +434,7 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None,
model_config: PretrainedConfig | None = None,
) -> bool:
# specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer(
@ -480,7 +480,7 @@ class MergedColumnParallelLinearWithShardedLoRA(MergedColumnParallelLinearWithLo
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None,
model_config: PretrainedConfig | None = None,
) -> bool:
# specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer(
@ -516,7 +516,7 @@ class QKVParallelLinearWithShardedLoRA(QKVParallelLinearWithLoRA):
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None,
model_config: PretrainedConfig | None = None,
) -> bool:
# specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer(
@ -565,7 +565,7 @@ class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA):
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None,
model_config: PretrainedConfig | None = None,
) -> bool:
# specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer(

View File

@ -401,6 +401,61 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
self.w13_lora_b_stacked[1][lora_id][experts_id]
)
def _slice_w13_a(self, w13_lora_a: torch.Tensor) -> torch.Tensor:
    """Return this TP rank's shard of the w1/w3 LoRA-A weight.

    Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA.

    With fully-sharded LoRA enabled, the A matrix is split along the
    rank dimension (following S-LoRA); otherwise the tensor is returned
    unchanged (replicated across TP ranks).

    Args:
        w13_lora_a: LoRA-A weight, shape (num_experts, rank, input_size).

    Returns:
        The rank-dimension slice owned by this TP rank, or the input
        tensor itself when tp_size == 1 or fully-sharded LoRA is off.
    """
    if self.tp_size == 1 or not self.fully_sharded:
        return w13_lora_a
    # w13_lora_a shape (num_experts,rank,input_size)
    current_lora_rank = w13_lora_a.shape[1]
    # The LoRA rank must divide evenly across TP ranks for the slicing below.
    assert current_lora_rank % self.tp_size == 0
    # Based on S-LoRA, we slice W13/W1/W3 A along the rank dim.
    sliced_rank = current_lora_rank // self.tp_size
    start_idx = self.tp_rank * sliced_rank
    end_idx = (self.tp_rank + 1) * sliced_rank
    return w13_lora_a[:, start_idx:end_idx, :]
def _slice_w13_b(self, w13_lora_b: torch.Tensor):
    """Return this TP rank's shard of the w1/w3 LoRA-B weight.

    The B matrix is sliced along the output (intermediate) dimension to
    match the base layer's per-partition intermediate size.

    Args:
        w13_lora_b: LoRA-B weight, shape (num_experts, output_size, rank).

    Returns:
        The output-dimension slice for this TP rank, or the input tensor
        itself when tp_size == 1.
    """
    if self.tp_size == 1:
        return w13_lora_b
    # w13_lora_b shape (num_experts,output_size,rank)
    shard_size = self.base_layer.intermediate_size_per_partition
    start_idx = self.tp_rank * shard_size
    end_idx = (self.tp_rank + 1) * shard_size
    return w13_lora_b[:, start_idx:end_idx, :]
def _slice_w2_a(self, w2_lora_a: torch.Tensor) -> torch.Tensor:
    """Return this TP rank's shard of the w2 (down_proj) LoRA-A weight.

    Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA.

    The A matrix is sliced along the input dimension, mirroring the
    row-parallel sharding of the base w2 weight.

    Args:
        w2_lora_a: LoRA-A weight, shape (num_experts, rank, input_size).

    Returns:
        The input-dimension slice for this TP rank, or the input tensor
        itself when tp_size == 1.
    """
    if self.tp_size == 1:
        return w2_lora_a
    # w2_lora_a shape (num_experts,rank,input_size)
    shard_size = self.base_layer.intermediate_size_per_partition
    start_idx = self.tp_rank * shard_size
    end_idx = (self.tp_rank + 1) * shard_size
    return w2_lora_a[:, :, start_idx:end_idx]
def _slice_w2_b(self, w2_lora_b: torch.Tensor) -> torch.Tensor:
    """Return this TP rank's shard of the w2 (down_proj) LoRA-B weight.

    Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA.

    With fully-sharded LoRA enabled, the B matrix is split along the
    hidden_size (output) dimension (following S-LoRA); otherwise it is
    returned unchanged.

    Args:
        w2_lora_b: LoRA-B weight, shape (num_experts, output_size, rank).

    Returns:
        The output-dimension slice for this TP rank, or the input tensor
        itself when tp_size == 1 or fully-sharded LoRA is off.
    """
    if self.tp_size == 1 or not self.fully_sharded:
        return w2_lora_b
    # Based on S-LoRA, we slice W2 B along the hidden_size dim.
    # w2_lora_b shape (num_experts,output_size,rank)
    current_lora_size = w2_lora_b.shape[1]
    # NOTE(review): unlike _slice_w13_a, divisibility by tp_size is not
    # asserted here — presumably guaranteed upstream; confirm.
    sliced_size = current_lora_size // self.tp_size
    start_idx = self.tp_rank * sliced_size
    end_idx = (self.tp_rank + 1) * sliced_size
    return w2_lora_b[:, start_idx:end_idx, :]
def reset_lora(self, index: int):
"""Resets the lora weights at index back to 0."""
for pos in range(self._w13_slices):
@ -411,6 +466,8 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
self.w2_lora_b_stacked[0][index] = 0
self.adapter_enabled[index] = 0
#
def set_lora(
self,
index: int,
@ -418,69 +475,55 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
lora_b: torch.Tensor | list[torch.Tensor],
):
"""Overwrites lora tensors at index."""
# Make mypy happy
assert isinstance(lora_a, list)
assert isinstance(lora_b, list)
self.reset_lora(index)
self.adapter_enabled[index] = 1
for eid in range(len(lora_a) // 3):
w1_lora_a = lora_a[eid * 3]
w2_lora_a = lora_a[eid * 3 + 1]
w3_lora_a = lora_a[eid * 3 + 2]
w1_lora_b = lora_b[eid * 3]
w2_lora_b = lora_b[eid * 3 + 1]
w3_lora_b = lora_b[eid * 3 + 2]
# Handle the case of adding LoRA to only a subset of experts
if w1_lora_a is None or w2_lora_a is None or w3_lora_a is None:
continue
num_experts = self.w13_lora_a_stacked[0].shape[1]
if self.tp_size > 1:
shard_size = self.base_layer.intermediate_size_per_partition
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
w1_lora_a, w2_lora_a, w3_lora_a = lora_a
w1_lora_b, w2_lora_b, w3_lora_b = lora_b
assert (
num_experts
== w1_lora_a.shape[0]
== w2_lora_a.shape[0]
== w3_lora_a.shape[0]
)
w1_lora_b = w1_lora_b[start_idx:end_idx, :]
w3_lora_b = w3_lora_b[start_idx:end_idx, :]
w2_lora_a = w2_lora_a[:, start_idx:end_idx]
slliced_w1_lora_a = self._slice_w13_a(w1_lora_a)
slliced_w1_lora_b = self._slice_w13_b(w1_lora_b)
slliced_w3_lora_a = self._slice_w13_a(w3_lora_a)
slliced_w3_lora_b = self._slice_w13_b(w3_lora_b)
if self.fully_sharded:
# Based on S-LoRA, we slice W1 and W3 A along the rank dim,
# and W2 B along the hidden_size dim.
w13_shard_size = self.w13_lora_a_stacked[0][index, eid].shape[0]
w13_start_idx = self.tp_rank * w13_shard_size
w13_end_idx = (self.tp_rank + 1) * w13_shard_size
w1_lora_a = w1_lora_a[w13_start_idx:w13_end_idx, :]
w3_lora_a = w3_lora_a[w13_start_idx:w13_end_idx, :]
sliced_w2_lora_a = self._slice_w2_a(w2_lora_a)
sliced_w2_lora_b = self._slice_w2_b(w2_lora_b)
w2_shard_size = self.w2_lora_b_stacked[0][index, eid].shape[0]
w2_start_idx = self.tp_rank * w2_shard_size
w2_end_idx = (self.tp_rank + 1) * w2_shard_size
w2_lora_b = w2_lora_b[w2_start_idx:w2_end_idx, :]
# w1 lora_a
self.w13_lora_a_stacked[0][
index, eid, : w1_lora_a.shape[0], : w1_lora_a.shape[1]
].copy_(w1_lora_a, non_blocking=True)
# w3 lora_a
self.w13_lora_a_stacked[1][
index, eid, : w3_lora_a.shape[0], : w3_lora_a.shape[1]
].copy_(w3_lora_a, non_blocking=True)
self.w13_lora_a_stacked[0][
index, :, : slliced_w1_lora_a.shape[1], : slliced_w1_lora_a.shape[2]
].copy_(slliced_w1_lora_a, non_blocking=True)
# w1 lora_b
self.w13_lora_b_stacked[0][
index, eid, : w1_lora_b.shape[0], : w1_lora_b.shape[1]
].copy_(w1_lora_b, non_blocking=True)
# w3 lora_b
self.w13_lora_b_stacked[1][
index, eid, : w3_lora_b.shape[0], : w3_lora_b.shape[1]
].copy_(w3_lora_b, non_blocking=True)
self.w13_lora_a_stacked[1][
index, :, : slliced_w3_lora_a.shape[1], : slliced_w3_lora_a.shape[2]
].copy_(slliced_w3_lora_a, non_blocking=True)
self.w2_lora_a_stacked[0][
index, eid, : w2_lora_a.shape[0], : w2_lora_a.shape[1]
].copy_(w2_lora_a, non_blocking=True)
self.w13_lora_b_stacked[0][
index, :, : slliced_w1_lora_b.shape[1], : slliced_w1_lora_b.shape[2]
].copy_(slliced_w1_lora_b, non_blocking=True)
self.w2_lora_b_stacked[0][
index, eid, : w2_lora_b.shape[0], : w2_lora_b.shape[1]
].copy_(w2_lora_b, non_blocking=True)
self.w13_lora_b_stacked[1][
index, :, : slliced_w3_lora_b.shape[1], : slliced_w3_lora_b.shape[2]
].copy_(slliced_w3_lora_b, non_blocking=True)
self.w2_lora_a_stacked[0][
index, :, : sliced_w2_lora_a.shape[1], : sliced_w2_lora_a.shape[2]
].copy_(sliced_w2_lora_a, non_blocking=True)
self.w2_lora_b_stacked[0][
index, :, : sliced_w2_lora_b.shape[1], : sliced_w2_lora_b.shape[2]
].copy_(sliced_w2_lora_b, non_blocking=True)
def forward(self, *args, **kwargs):
return self.base_layer.forward(*args, **kwargs)
@ -506,12 +549,12 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None,
model_config: PretrainedConfig | None = None,
) -> bool:
"""Returns True if the layer can be replaced by this LoRA layer."""
# return type(source_layer) is FusedMoE
return type(source_layer) is FusedMoE and len(packed_modules_list) == 2
# source_layer is FusedMoE or SharedFusedMoE
return isinstance(source_layer, FusedMoE) and len(packed_modules_list) == 2
class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
@ -555,6 +598,9 @@ class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
model_config: PretrainedConfig | None = None,
) -> None:
"""Initializes lora matrices."""
assert isinstance(model_config, PretrainedConfig)
self._base_model = model_config.architectures[0]
self.max_loras = lora_config.max_loras
self.fully_sharded = lora_config.fully_sharded_loras
@ -565,20 +611,7 @@ class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
self._create_lora_a_weights(max_loras, lora_config)
self._create_lora_b_weights(max_loras, lora_config)
def _slice_w13_a(self, w13_lora_a: torch.Tensor) -> torch.Tensor:
if self.tp_size == 1 or not self.fully_sharded:
return w13_lora_a
# w13_lora_a shape (num_experts,rank,input_size)
current_lora_rank = w13_lora_a.shape[1]
assert current_lora_rank % self.tp_size == 0
sliced_rank = current_lora_rank // self.tp_size
start_idx = self.tp_rank * sliced_rank
end_idx = (self.tp_rank + 1) * sliced_rank
return w13_lora_a[:, start_idx:end_idx, :]
def _slice_w13_b(self, w13_lora_b: torch.Tensor, is_interleave: bool = True):
def _slice_w13_b(self, w13_lora_b: torch.Tensor):
if self.tp_size == 1:
return w13_lora_b
@ -586,7 +619,8 @@ class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
shard_size = self.base_layer.intermediate_size_per_partition
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
if is_interleave:
# HACK: Currently, only GPT-OSS is in interleaved order
if self._base_model == "GptOssForCausalLM":
# For models like GPT-OSS, the weights of w1 (gate_proj) and w3 (up_proj)
# in the interleaved order, and corresponding LoRA need to be processed.
w1_lora_b = w13_lora_b[:, ::2, :]
@ -606,28 +640,6 @@ class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
return torch.cat([sliced_w1_lora_b, sliced_w3_lora_b], dim=1)
def _slice_w2_a(self, w2_lora_a: torch.Tensor) -> torch.Tensor:
if self.tp_size == 1:
return w2_lora_a
# w2_lora_a shape (num_experts,rank,input_size)
shard_size = self.base_layer.intermediate_size_per_partition
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
return w2_lora_a[:, :, start_idx:end_idx]
def _slice_w2_b(self, w2_lora_b: torch.Tensor) -> torch.Tensor:
if self.tp_size == 1 or not self.fully_sharded:
return w2_lora_b
# Based on S-LoRA, we slice W2 B along the hidden_size dim.
# w2_lora_b shape (num_experts,output_size,rank)
current_lora_size = w2_lora_b.shape[1]
sliced_size = current_lora_size // self.tp_size
start_idx = self.tp_rank * sliced_size
end_idx = (self.tp_rank + 1) * sliced_size
return w2_lora_b[:, start_idx:end_idx, :]
def set_lora(
self,
index: int,
@ -658,7 +670,7 @@ class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
w2_lora_b = w2_lora_b.permute(1, 0, 2)
sliced_w13_lora_a = self._slice_w13_a(w13_lora_a)
sliced_w13_lora_b = self._slice_w13_b(w13_lora_b, is_interleave=True)
sliced_w13_lora_b = self._slice_w13_b(w13_lora_b)
sliced_w2_lora_a = self._slice_w2_a(w2_lora_a)
sliced_w2_lora_b = self._slice_w2_b(w2_lora_b)
@ -711,8 +723,8 @@ class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None,
model_config: PretrainedConfig | None = None,
) -> bool:
"""Returns True if the layer can be replaced by this LoRA layer."""
return type(source_layer) is FusedMoE and len(packed_modules_list) == 1
# source_layer is FusedMoE or SharedFusedMoE
return isinstance(source_layer, FusedMoE) and len(packed_modules_list) == 1

View File

@ -197,7 +197,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None,
model_config: PretrainedConfig | None = None,
) -> bool:
# Special handling for the LogitsProcessor.
return False

View File

@ -53,7 +53,7 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None,
model_config: PretrainedConfig | None = None,
) -> bool:
return type(source_layer) is ReplicatedLinear

View File

@ -87,7 +87,7 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None,
model_config: PretrainedConfig | None = None,
) -> bool:
return type(source_layer) is RowParallelLinear
@ -164,7 +164,7 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None,
model_config: PretrainedConfig | None = None,
) -> bool:
# specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer(

View File

@ -131,7 +131,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None,
model_config: PretrainedConfig | None = None,
) -> bool:
return type(source_layer) is VocabParallelEmbedding

View File

@ -152,6 +152,59 @@ class PackedLoRALayerWeights(LoRALayerWeights):
)
return obj
@classmethod
def pack_moe(
    cls, loras: GenericSequence[Optional["LoRALayerWeights"]], module_name: str
) -> "PackedLoRALayerWeights":
    """Pack a list of per-expert MoE LoRAs into a single stacked LoRA.

    If LoRA is None, it signifies that the submodule does not have a LoRA.

    ``loras`` is expected to be laid out in per-expert triples
    ``(w1, w2, w3)``: index ``3*e`` is expert ``e``'s w1 (gate_proj),
    ``3*e + 1`` its w2 (down_proj), and ``3*e + 2`` its w3 (up_proj).
    The per-expert tensors are stacked along a new leading expert
    dimension so downstream code can copy them in one shot.

    Args:
        loras: Flat sequence of per-expert LoRA weights, length divisible
            by 3 (one w1/w2/w3 triple per expert).
        module_name: Name assigned to the resulting packed module.

    Returns:
        A PackedLoRALayerWeights whose lora_a / lora_b are 3-element
        lists of stacked tensors, ordered [w1, w2, w3].
    """
    # rank / lora_alpha are taken from the first present sub-LoRA;
    # all experts are assumed to share them.
    first_lora = next(lora for lora in loras if lora is not None)
    assert first_lora is not None
    rank = first_lora.rank
    lora_alpha = first_lora.lora_alpha
    # Entries must come in complete (w1, w2, w3) triples.
    assert len(loras) % 3 == 0
    w1_lora_a_lst = []
    w2_lora_a_lst = []
    w3_lora_a_lst = []
    w1_lora_b_lst = []
    w2_lora_b_lst = []
    w3_lora_b_lst = []
    # TODO: Consider the case where some experts don't have LoRA added.
    for eid in range(len(loras) // 3):
        w1_lora = loras[eid * 3]
        w2_lora = loras[eid * 3 + 1]
        w3_lora = loras[eid * 3 + 2]
        # Currently every expert must carry a LoRA (see TODO above).
        assert w1_lora is not None
        assert w2_lora is not None
        assert w3_lora is not None
        w1_lora_a_lst.append(w1_lora.lora_a)
        w2_lora_a_lst.append(w2_lora.lora_a)
        w3_lora_a_lst.append(w3_lora.lora_a)
        w1_lora_b_lst.append(w1_lora.lora_b)
        w2_lora_b_lst.append(w2_lora.lora_b)
        w3_lora_b_lst.append(w3_lora.lora_b)
    # Stack along a new expert dimension.
    w1_lora_a = torch.stack(w1_lora_a_lst, dim=0)  # (num_experts,rank,input_size)
    w2_lora_a = torch.stack(w2_lora_a_lst, dim=0)
    w3_lora_a = torch.stack(w3_lora_a_lst, dim=0)
    w1_lora_b = torch.stack(w1_lora_b_lst, dim=0)  # (num_experts,output_size,rank)
    w2_lora_b = torch.stack(w2_lora_b_lst, dim=0)
    w3_lora_b = torch.stack(w3_lora_b_lst, dim=0)
    obj = cls(
        module_name,
        rank,
        [lora_alpha, lora_alpha, lora_alpha],
        [w1_lora_a, w2_lora_a, w3_lora_a],
        [w1_lora_b, w2_lora_b, w3_lora_b],
    )
    return obj
def optimize(self) -> "PackedLoRALayerWeights":
"""Optimize the LoRA by merging the scaling into lora_b."""
for i in range(len(self.lora_b)):

View File

@ -13,7 +13,7 @@ from torch import nn
from vllm.config.lora import LoRAConfig
from vllm.logger import init_logger
from vllm.lora.layers import BaseLayerWithLoRA, FusedMoEWithLoRA, LoRAMapping
from vllm.lora.layers import BaseLayerWithLoRA, FusedMoE3DWithLoRA, LoRAMapping
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.punica_wrapper import get_punica_wrapper
@ -151,16 +151,13 @@ class LoRAModel:
if pin_memory:
loras[module_name].lora_b = loras[module_name].lora_b.pin_memory()
for lora in loras.values():
lora.optimize()
return cls(lora_model_id, peft_helper.r, loras)
@classmethod
def from_local_checkpoint(
cls,
lora_dir: str,
expected_lora_modules: list[str],
expected_lora_modules: set[str],
peft_helper: PEFTHelper,
*,
lora_model_id: int | None = None,
@ -190,10 +187,7 @@ class LoRAModel:
lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
lora_pt_file_path = os.path.join(lora_dir, "adapter_model.pt")
# new_embeddings_tensor_path = os.path.join(
# lora_dir, "new_embeddings.safetensors"
# )
# new_embeddings_bin_file_path = os.path.join(lora_dir, "new_embeddings.bin")
tensors: dict[str, torch.Tensor] = {}
unexpected_modules: list[list[str] | str] = []
@ -201,18 +195,19 @@ class LoRAModel:
for lora_module in modules.keys(): # noqa
if is_base_embeddding_weights(lora_module):
continue
module_name, _ = parse_fine_tuned_lora_name(lora_module, weights_mapper)
# Handle FSDP file format where experts.base_layer is the
# Handle PEFT file format where experts.base_layer is the
# gate_up_proj and experts is the down_proj
if "base_layer" in lora_module:
continue
module_name, _ = parse_fine_tuned_lora_name(lora_module, weights_mapper)
# Case for expert lora weights
if ".experts" in module_name:
if not any(
module_name.endswith(ele) for ele in expected_lora_modules
):
expert_idx = module_name.find(".experts")
expert_suffix = module_name[expert_idx + 1 :]
if expert_suffix not in expected_lora_modules:
unexpected_modules.append(module_name)
elif module_name.split(".")[-1] not in expected_lora_modules:
elif module_name.rsplit(".", 1)[-1] not in expected_lora_modules:
unexpected_modules.append(module_name)
if unexpected_modules:
@ -358,9 +353,7 @@ class LoRAModelManager:
self.modules: dict[str, BaseLayerWithLoRA] = {}
# Dict instead of a set for compatibility with LRUCache.
self._last_mapping: LoRAMapping | None = None
self._is_3d_moe_model = is_moe_model(self.model) and hasattr(
self.model, "is_3d_moe_weight"
)
self._is_3d_moe_model = is_moe_model(self.model) and self.model.is_3d_moe_weight
self._create_lora_modules()
self.model.lora_manager = self
@ -411,7 +404,7 @@ class LoRAModelManager:
continue
# Note (gnovack) - If MOE lora weights are not split into
# num_experts chunks, we split them here
if isinstance(module, FusedMoEWithLoRA) and torch.is_tensor(
if isinstance(module, FusedMoE3DWithLoRA) and torch.is_tensor(
module_lora.lora_a
):
# Handle PEFT file format where experts.base_layer is the
@ -679,7 +672,10 @@ class LoRAModelManager:
"cpu",
)
subloras.append(lora)
lora = PackedLoRALayerWeights.pack(subloras)
if module.__class__.__name__ == "FusedMoEWithLoRA":
lora = PackedLoRALayerWeights.pack_moe(subloras, module_name)
else:
lora = PackedLoRALayerWeights.pack(subloras)
model.loras[module_name] = lora
return model
@ -739,13 +735,21 @@ class LoRAModelManager:
replaced_module_name = module_name.replace("model.", "")
if lora_model.check_lora_name(module_name):
module_name = replaced_module_name
lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
replacement_loras
)
if module_name.endswith(".experts"):
lora_model.loras[module_name] = PackedLoRALayerWeights.pack_moe(
replacement_loras, module_name
)
else:
lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
replacement_loras
)
# Remove the modules that have been replaced.
for module in replaced_module:
lora_model.loras.pop(module, None)
for lora in lora_model.loras.values():
lora.optimize()
def _get_lora_layer_weights(
self, lora_model: LoRAModel, module_name: str
) -> LoRALayerWeights | None:

View File

@ -170,16 +170,15 @@ def parse_fine_tuned_lora_name(
def is_base_embeddding_weights(name: str) -> bool:
# hardcoded subfixes for input & output embedding weights
input_embedding_subfix = ".embed_tokens.base_layer.weight"
output_embedding_subfix = ".lm_head.base_layer.weight"
return name.endswith(input_embedding_subfix) or name.endswith(
output_embedding_subfix
embedding_suffixes = (
".embed_tokens.base_layer.weight",
".lm_head.base_layer.weight",
)
return name.endswith(embedding_suffixes)
def is_regex_target_modules(
load_modules: str | list[str], expected_lora_modules: list[str]
load_modules: str | list[str], expected_lora_modules: set[str]
) -> bool:
"""
PEFT supports passing `target_modules` in the form of regular expressions,
@ -195,8 +194,8 @@ def is_regex_target_modules(
except re.error:
return False
def is_subset(sub_list, full_list):
return set(sub_list).issubset(set(full_list))
def is_subset(sub_list, full_set):
return set(sub_list).issubset(full_set)
# Similar to PEFT's processing logic, regex-related operations are only
# executed when the load_modules is a `str`.
@ -290,7 +289,7 @@ def process_packed_modules_mapping(model: nn.Module) -> dict[str, list[str]]:
# the expert indices are expanded based on the configured number
# of routed experts.
packed_modules_mapping = get_packed_modules_mapping(model)
if not hasattr(model, "is_3d_moe_weight"):
if not model.is_3d_moe_weight:
# 3D MoE LoRA does not need `packed_modules_mapping`
packed_modules_mapping["experts"] = [
weight_name.rstrip(".")

View File

@ -88,15 +88,15 @@ class WorkerLoRAManager:
try:
supported_lora_modules = self._adapter_manager.supported_lora_modules
packed_modules_mapping = self._adapter_manager.packed_modules_mapping
expected_lora_modules: list[str] = []
expected_lora_lst: list[str] = []
for module in supported_lora_modules:
if module in packed_modules_mapping:
expected_lora_modules.extend(packed_modules_mapping[module])
expected_lora_lst.extend(packed_modules_mapping[module])
else:
expected_lora_modules.append(module)
expected_lora_lst.append(module)
if module == "experts":
expected_lora_modules.append(module)
expected_lora_modules = list(set(expected_lora_modules))
expected_lora_lst.append(module)
expected_lora_modules = set(expected_lora_lst)
lora_path = get_adapter_absolute_path(lora_request.lora_path)
peft_helper = PEFTHelper.from_local_dir(

View File

@ -336,6 +336,7 @@ class SupportsLoRA(Protocol):
There is no need to redefine this flag if this class is in the
MRO of your model class.
"""
is_3d_moe_weight: ClassVar[bool] = False
# The `embedding_module` and `embedding_padding_modules`
# are empty by default.
embedding_modules: ClassVar[dict[str, str]] = {}

View File

@ -401,6 +401,7 @@ class Qwen3VLMoeMixtureOfExperts(MixtureOfExperts):
class Qwen3VLMoeForConditionalGeneration(
Qwen3VLForConditionalGeneration, Qwen3VLMoeMixtureOfExperts
):
is_3d_moe_weight: bool = True
packed_modules_mapping = {
"qkv_proj": [
"q_proj",