[LoRA] Continue optimizing MoE LoRA weight loading (#29322)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Jee Jee Li 2025-11-27 21:56:28 +08:00 committed by GitHub
parent cf348c8d27
commit 2f5f9acd55
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 228 additions and 157 deletions

View File

@ -28,12 +28,13 @@ def test_load_checkpoints(
packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
embedding_modules = BaiChuanBaseForCausalLM.embedding_modules embedding_modules = BaiChuanBaseForCausalLM.embedding_modules
embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules
expected_lora_modules: list[str] = [] expected_lora_lst: list[str] = []
for module in BAICHUAN_LORA_MODULES: for module in BAICHUAN_LORA_MODULES:
if module in packed_modules_mapping: if module in packed_modules_mapping:
expected_lora_modules.extend(packed_modules_mapping[module]) expected_lora_lst.extend(packed_modules_mapping[module])
else: else:
expected_lora_modules.append(module) expected_lora_lst.append(module)
expected_lora_modules = set(expected_lora_lst)
if lora_name == "baichuan7B": if lora_name == "baichuan7B":
peft_helper = PEFTHelper.from_local_dir( peft_helper = PEFTHelper.from_local_dir(
baichuan_lora_files, max_position_embeddings=4096 baichuan_lora_files, max_position_embeddings=4096
@ -103,13 +104,13 @@ def test_lora_weights_mapping(baichuan_lora_files):
packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
embedding_modules = BaiChuanBaseForCausalLM.embedding_modules embedding_modules = BaiChuanBaseForCausalLM.embedding_modules
embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules
expected_lora_modules: list[str] = [] expected_lora_lst: list[str] = []
for module in BAICHUAN_LORA_MODULES: for module in BAICHUAN_LORA_MODULES:
if module in packed_modules_mapping: if module in packed_modules_mapping:
expected_lora_modules.extend(packed_modules_mapping[module]) expected_lora_lst.extend(packed_modules_mapping[module])
else: else:
expected_lora_modules.append(module) expected_lora_lst.append(module)
expected_lora_modules = set(expected_lora_lst)
hf_to_vllm_mapper = WeightsMapper( hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={ orig_to_new_prefix={
"model.": "language_model.model.", "model.": "language_model.model.",

View File

@ -26,13 +26,13 @@ def test_load_checkpoints_from_huggingface(lora_fixture_name, request):
packed_modules_mapping = LlamaForCausalLM.packed_modules_mapping packed_modules_mapping = LlamaForCausalLM.packed_modules_mapping
embedding_modules = LlamaForCausalLM.embedding_modules embedding_modules = LlamaForCausalLM.embedding_modules
embed_padding_modules = LlamaForCausalLM.embedding_padding_modules embed_padding_modules = LlamaForCausalLM.embedding_padding_modules
expected_lora_modules: list[str] = [] expected_lora_lst: list[str] = []
for module in LLAMA_LORA_MODULES: for module in LLAMA_LORA_MODULES:
if module in packed_modules_mapping: if module in packed_modules_mapping:
expected_lora_modules.extend(packed_modules_mapping[module]) expected_lora_lst.extend(packed_modules_mapping[module])
else: else:
expected_lora_modules.append(module) expected_lora_lst.append(module)
expected_lora_modules = set(expected_lora_lst)
lora_path = get_adapter_absolute_path(lora_name) lora_path = get_adapter_absolute_path(lora_name)
# lora loading should work for either absolute path and huggingface id. # lora loading should work for either absolute path and huggingface id.

View File

@ -60,7 +60,7 @@ class BaseLayerWithLoRA(nn.Module):
source_layer: nn.Module, source_layer: nn.Module,
lora_config: LoRAConfig, lora_config: LoRAConfig,
packed_modules_list: list, packed_modules_list: list,
model_config: PretrainedConfig | None, model_config: PretrainedConfig | None = None,
) -> bool: ) -> bool:
"""Returns True if the layer can be replaced by this LoRA layer.""" """Returns True if the layer can be replaced by this LoRA layer."""
raise NotImplementedError raise NotImplementedError

View File

@ -153,7 +153,7 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
source_layer: nn.Module, source_layer: nn.Module,
lora_config: LoRAConfig, lora_config: LoRAConfig,
packed_modules_list: list, packed_modules_list: list,
model_config: PretrainedConfig | None, model_config: PretrainedConfig | None = None,
) -> bool: ) -> bool:
return type(source_layer) is ColumnParallelLinear or ( return type(source_layer) is ColumnParallelLinear or (
type(source_layer) is MergedColumnParallelLinear type(source_layer) is MergedColumnParallelLinear
@ -272,7 +272,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
source_layer: nn.Module, source_layer: nn.Module,
lora_config: LoRAConfig, lora_config: LoRAConfig,
packed_modules_list: list, packed_modules_list: list,
model_config: PretrainedConfig | None, model_config: PretrainedConfig | None = None,
) -> bool: ) -> bool:
return ( return (
type(source_layer) is MergedColumnParallelLinear type(source_layer) is MergedColumnParallelLinear
@ -338,7 +338,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
source_layer: nn.Module, source_layer: nn.Module,
lora_config: LoRAConfig, lora_config: LoRAConfig,
packed_modules_list: list, packed_modules_list: list,
model_config: PretrainedConfig | None, model_config: PretrainedConfig | None = None,
) -> bool: ) -> bool:
return type(source_layer) is QKVParallelLinear and len(packed_modules_list) == 1 return type(source_layer) is QKVParallelLinear and len(packed_modules_list) == 1
@ -396,7 +396,7 @@ class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
source_layer: nn.Module, source_layer: nn.Module,
lora_config: LoRAConfig, lora_config: LoRAConfig,
packed_modules_list: list, packed_modules_list: list,
model_config: PretrainedConfig | None, model_config: PretrainedConfig | None = None,
) -> bool: ) -> bool:
return type(source_layer) is QKVParallelLinear and len(packed_modules_list) == 3 return type(source_layer) is QKVParallelLinear and len(packed_modules_list) == 3
@ -434,7 +434,7 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
source_layer: nn.Module, source_layer: nn.Module,
lora_config: LoRAConfig, lora_config: LoRAConfig,
packed_modules_list: list, packed_modules_list: list,
model_config: PretrainedConfig | None, model_config: PretrainedConfig | None = None,
) -> bool: ) -> bool:
# specifying kwargs so they can be easily accessed in decorator # specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer( return super().can_replace_layer(
@ -480,7 +480,7 @@ class MergedColumnParallelLinearWithShardedLoRA(MergedColumnParallelLinearWithLo
source_layer: nn.Module, source_layer: nn.Module,
lora_config: LoRAConfig, lora_config: LoRAConfig,
packed_modules_list: list, packed_modules_list: list,
model_config: PretrainedConfig | None, model_config: PretrainedConfig | None = None,
) -> bool: ) -> bool:
# specifying kwargs so they can be easily accessed in decorator # specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer( return super().can_replace_layer(
@ -516,7 +516,7 @@ class QKVParallelLinearWithShardedLoRA(QKVParallelLinearWithLoRA):
source_layer: nn.Module, source_layer: nn.Module,
lora_config: LoRAConfig, lora_config: LoRAConfig,
packed_modules_list: list, packed_modules_list: list,
model_config: PretrainedConfig | None, model_config: PretrainedConfig | None = None,
) -> bool: ) -> bool:
# specifying kwargs so they can be easily accessed in decorator # specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer( return super().can_replace_layer(
@ -565,7 +565,7 @@ class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA):
source_layer: nn.Module, source_layer: nn.Module,
lora_config: LoRAConfig, lora_config: LoRAConfig,
packed_modules_list: list, packed_modules_list: list,
model_config: PretrainedConfig | None, model_config: PretrainedConfig | None = None,
) -> bool: ) -> bool:
# specifying kwargs so they can be easily accessed in decorator # specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer( return super().can_replace_layer(

View File

@ -401,6 +401,61 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
self.w13_lora_b_stacked[1][lora_id][experts_id] self.w13_lora_b_stacked[1][lora_id][experts_id]
) )
def _slice_w13_a(self, w13_lora_a: torch.Tensor) -> torch.Tensor:
"""
Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA
"""
if self.tp_size == 1 or not self.fully_sharded:
return w13_lora_a
# w13_lora_a shape (num_experts,rank,input_size)
current_lora_rank = w13_lora_a.shape[1]
assert current_lora_rank % self.tp_size == 0
# Based on S-LoRA, we slice W13/W1/W3 A along the rank dim.
sliced_rank = current_lora_rank // self.tp_size
start_idx = self.tp_rank * sliced_rank
end_idx = (self.tp_rank + 1) * sliced_rank
return w13_lora_a[:, start_idx:end_idx, :]
def _slice_w13_b(self, w13_lora_b: torch.Tensor):
if self.tp_size == 1:
return w13_lora_b
# w13_lora_b shape (num_experts,output_size,rank)
shard_size = self.base_layer.intermediate_size_per_partition
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
return w13_lora_b[:, start_idx:end_idx, :]
def _slice_w2_a(self, w2_lora_a: torch.Tensor) -> torch.Tensor:
"""
Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA
"""
if self.tp_size == 1:
return w2_lora_a
# w2_lora_a shape (num_experts,rank,input_size)
shard_size = self.base_layer.intermediate_size_per_partition
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
return w2_lora_a[:, :, start_idx:end_idx]
def _slice_w2_b(self, w2_lora_b: torch.Tensor) -> torch.Tensor:
"""
Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA
"""
if self.tp_size == 1 or not self.fully_sharded:
return w2_lora_b
# Based on S-LoRA, we slice W2 B along the hidden_size dim.
# w2_lora_b shape (num_experts,output_size,rank)
current_lora_size = w2_lora_b.shape[1]
sliced_size = current_lora_size // self.tp_size
start_idx = self.tp_rank * sliced_size
end_idx = (self.tp_rank + 1) * sliced_size
return w2_lora_b[:, start_idx:end_idx, :]
def reset_lora(self, index: int): def reset_lora(self, index: int):
"""Resets the lora weights at index back to 0.""" """Resets the lora weights at index back to 0."""
for pos in range(self._w13_slices): for pos in range(self._w13_slices):
@ -411,6 +466,8 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
self.w2_lora_b_stacked[0][index] = 0 self.w2_lora_b_stacked[0][index] = 0
self.adapter_enabled[index] = 0 self.adapter_enabled[index] = 0
#
def set_lora( def set_lora(
self, self,
index: int, index: int,
@ -418,69 +475,55 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
lora_b: torch.Tensor | list[torch.Tensor], lora_b: torch.Tensor | list[torch.Tensor],
): ):
"""Overwrites lora tensors at index.""" """Overwrites lora tensors at index."""
# Make mypy happy
assert isinstance(lora_a, list) assert isinstance(lora_a, list)
assert isinstance(lora_b, list) assert isinstance(lora_b, list)
self.reset_lora(index) self.reset_lora(index)
self.adapter_enabled[index] = 1 self.adapter_enabled[index] = 1
for eid in range(len(lora_a) // 3):
w1_lora_a = lora_a[eid * 3]
w2_lora_a = lora_a[eid * 3 + 1]
w3_lora_a = lora_a[eid * 3 + 2]
w1_lora_b = lora_b[eid * 3]
w2_lora_b = lora_b[eid * 3 + 1]
w3_lora_b = lora_b[eid * 3 + 2]
# Handle the case of adding LoRA to only a subset of experts num_experts = self.w13_lora_a_stacked[0].shape[1]
if w1_lora_a is None or w2_lora_a is None or w3_lora_a is None:
continue
if self.tp_size > 1: w1_lora_a, w2_lora_a, w3_lora_a = lora_a
shard_size = self.base_layer.intermediate_size_per_partition w1_lora_b, w2_lora_b, w3_lora_b = lora_b
start_idx = self.tp_rank * shard_size assert (
end_idx = (self.tp_rank + 1) * shard_size num_experts
== w1_lora_a.shape[0]
== w2_lora_a.shape[0]
== w3_lora_a.shape[0]
)
w1_lora_b = w1_lora_b[start_idx:end_idx, :] slliced_w1_lora_a = self._slice_w13_a(w1_lora_a)
w3_lora_b = w3_lora_b[start_idx:end_idx, :] slliced_w1_lora_b = self._slice_w13_b(w1_lora_b)
w2_lora_a = w2_lora_a[:, start_idx:end_idx] slliced_w3_lora_a = self._slice_w13_a(w3_lora_a)
slliced_w3_lora_b = self._slice_w13_b(w3_lora_b)
if self.fully_sharded: sliced_w2_lora_a = self._slice_w2_a(w2_lora_a)
# Based on S-LoRA, we slice W1 and W3 A along the rank dim, sliced_w2_lora_b = self._slice_w2_b(w2_lora_b)
# and W2 B along the hidden_size dim.
w13_shard_size = self.w13_lora_a_stacked[0][index, eid].shape[0]
w13_start_idx = self.tp_rank * w13_shard_size
w13_end_idx = (self.tp_rank + 1) * w13_shard_size
w1_lora_a = w1_lora_a[w13_start_idx:w13_end_idx, :]
w3_lora_a = w3_lora_a[w13_start_idx:w13_end_idx, :]
w2_shard_size = self.w2_lora_b_stacked[0][index, eid].shape[0] self.w13_lora_a_stacked[0][
w2_start_idx = self.tp_rank * w2_shard_size index, :, : slliced_w1_lora_a.shape[1], : slliced_w1_lora_a.shape[2]
w2_end_idx = (self.tp_rank + 1) * w2_shard_size ].copy_(slliced_w1_lora_a, non_blocking=True)
w2_lora_b = w2_lora_b[w2_start_idx:w2_end_idx, :]
# w1 lora_a
self.w13_lora_a_stacked[0][
index, eid, : w1_lora_a.shape[0], : w1_lora_a.shape[1]
].copy_(w1_lora_a, non_blocking=True)
# w3 lora_a
self.w13_lora_a_stacked[1][
index, eid, : w3_lora_a.shape[0], : w3_lora_a.shape[1]
].copy_(w3_lora_a, non_blocking=True)
# w1 lora_b self.w13_lora_a_stacked[1][
self.w13_lora_b_stacked[0][ index, :, : slliced_w3_lora_a.shape[1], : slliced_w3_lora_a.shape[2]
index, eid, : w1_lora_b.shape[0], : w1_lora_b.shape[1] ].copy_(slliced_w3_lora_a, non_blocking=True)
].copy_(w1_lora_b, non_blocking=True)
# w3 lora_b
self.w13_lora_b_stacked[1][
index, eid, : w3_lora_b.shape[0], : w3_lora_b.shape[1]
].copy_(w3_lora_b, non_blocking=True)
self.w2_lora_a_stacked[0][ self.w13_lora_b_stacked[0][
index, eid, : w2_lora_a.shape[0], : w2_lora_a.shape[1] index, :, : slliced_w1_lora_b.shape[1], : slliced_w1_lora_b.shape[2]
].copy_(w2_lora_a, non_blocking=True) ].copy_(slliced_w1_lora_b, non_blocking=True)
self.w2_lora_b_stacked[0][ self.w13_lora_b_stacked[1][
index, eid, : w2_lora_b.shape[0], : w2_lora_b.shape[1] index, :, : slliced_w3_lora_b.shape[1], : slliced_w3_lora_b.shape[2]
].copy_(w2_lora_b, non_blocking=True) ].copy_(slliced_w3_lora_b, non_blocking=True)
self.w2_lora_a_stacked[0][
index, :, : sliced_w2_lora_a.shape[1], : sliced_w2_lora_a.shape[2]
].copy_(sliced_w2_lora_a, non_blocking=True)
self.w2_lora_b_stacked[0][
index, :, : sliced_w2_lora_b.shape[1], : sliced_w2_lora_b.shape[2]
].copy_(sliced_w2_lora_b, non_blocking=True)
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
return self.base_layer.forward(*args, **kwargs) return self.base_layer.forward(*args, **kwargs)
@ -506,12 +549,12 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
source_layer: nn.Module, source_layer: nn.Module,
lora_config: LoRAConfig, lora_config: LoRAConfig,
packed_modules_list: list, packed_modules_list: list,
model_config: PretrainedConfig | None, model_config: PretrainedConfig | None = None,
) -> bool: ) -> bool:
"""Returns True if the layer can be replaced by this LoRA layer.""" """Returns True if the layer can be replaced by this LoRA layer."""
# return type(source_layer) is FusedMoE
return type(source_layer) is FusedMoE and len(packed_modules_list) == 2 # source_layer is FusedMoE or SharedFusedMoE
return isinstance(source_layer, FusedMoE) and len(packed_modules_list) == 2
class FusedMoE3DWithLoRA(FusedMoEWithLoRA): class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
@ -555,6 +598,9 @@ class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
model_config: PretrainedConfig | None = None, model_config: PretrainedConfig | None = None,
) -> None: ) -> None:
"""Initializes lora matrices.""" """Initializes lora matrices."""
assert isinstance(model_config, PretrainedConfig)
self._base_model = model_config.architectures[0]
self.max_loras = lora_config.max_loras self.max_loras = lora_config.max_loras
self.fully_sharded = lora_config.fully_sharded_loras self.fully_sharded = lora_config.fully_sharded_loras
@ -565,20 +611,7 @@ class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
self._create_lora_a_weights(max_loras, lora_config) self._create_lora_a_weights(max_loras, lora_config)
self._create_lora_b_weights(max_loras, lora_config) self._create_lora_b_weights(max_loras, lora_config)
def _slice_w13_a(self, w13_lora_a: torch.Tensor) -> torch.Tensor: def _slice_w13_b(self, w13_lora_b: torch.Tensor):
if self.tp_size == 1 or not self.fully_sharded:
return w13_lora_a
# w13_lora_a shape (num_experts,rank,input_size)
current_lora_rank = w13_lora_a.shape[1]
assert current_lora_rank % self.tp_size == 0
sliced_rank = current_lora_rank // self.tp_size
start_idx = self.tp_rank * sliced_rank
end_idx = (self.tp_rank + 1) * sliced_rank
return w13_lora_a[:, start_idx:end_idx, :]
def _slice_w13_b(self, w13_lora_b: torch.Tensor, is_interleave: bool = True):
if self.tp_size == 1: if self.tp_size == 1:
return w13_lora_b return w13_lora_b
@ -586,7 +619,8 @@ class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
shard_size = self.base_layer.intermediate_size_per_partition shard_size = self.base_layer.intermediate_size_per_partition
start_idx = self.tp_rank * shard_size start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size end_idx = (self.tp_rank + 1) * shard_size
if is_interleave: # HACK: Currently, only GPT-OSS is in interleaved order
if self._base_model == "GptOssForCausalLM":
# For models like GPT-OSS, the weights of w1 (gate_proj) and w3 (up_proj) # For models like GPT-OSS, the weights of w1 (gate_proj) and w3 (up_proj)
# in the interleaved order, and corresponding LoRA need to be processed. # in the interleaved order, and corresponding LoRA need to be processed.
w1_lora_b = w13_lora_b[:, ::2, :] w1_lora_b = w13_lora_b[:, ::2, :]
@ -606,28 +640,6 @@ class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
return torch.cat([sliced_w1_lora_b, sliced_w3_lora_b], dim=1) return torch.cat([sliced_w1_lora_b, sliced_w3_lora_b], dim=1)
def _slice_w2_a(self, w2_lora_a: torch.Tensor) -> torch.Tensor:
if self.tp_size == 1:
return w2_lora_a
# w2_lora_a shape (num_experts,rank,input_size)
shard_size = self.base_layer.intermediate_size_per_partition
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
return w2_lora_a[:, :, start_idx:end_idx]
def _slice_w2_b(self, w2_lora_b: torch.Tensor) -> torch.Tensor:
if self.tp_size == 1 or not self.fully_sharded:
return w2_lora_b
# Based on S-LoRA, we slice W2 B along the hidden_size dim.
# w2_lora_b shape (num_experts,output_size,rank)
current_lora_size = w2_lora_b.shape[1]
sliced_size = current_lora_size // self.tp_size
start_idx = self.tp_rank * sliced_size
end_idx = (self.tp_rank + 1) * sliced_size
return w2_lora_b[:, start_idx:end_idx, :]
def set_lora( def set_lora(
self, self,
index: int, index: int,
@ -658,7 +670,7 @@ class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
w2_lora_b = w2_lora_b.permute(1, 0, 2) w2_lora_b = w2_lora_b.permute(1, 0, 2)
sliced_w13_lora_a = self._slice_w13_a(w13_lora_a) sliced_w13_lora_a = self._slice_w13_a(w13_lora_a)
sliced_w13_lora_b = self._slice_w13_b(w13_lora_b, is_interleave=True) sliced_w13_lora_b = self._slice_w13_b(w13_lora_b)
sliced_w2_lora_a = self._slice_w2_a(w2_lora_a) sliced_w2_lora_a = self._slice_w2_a(w2_lora_a)
sliced_w2_lora_b = self._slice_w2_b(w2_lora_b) sliced_w2_lora_b = self._slice_w2_b(w2_lora_b)
@ -711,8 +723,8 @@ class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
source_layer: nn.Module, source_layer: nn.Module,
lora_config: LoRAConfig, lora_config: LoRAConfig,
packed_modules_list: list, packed_modules_list: list,
model_config: PretrainedConfig | None, model_config: PretrainedConfig | None = None,
) -> bool: ) -> bool:
"""Returns True if the layer can be replaced by this LoRA layer.""" """Returns True if the layer can be replaced by this LoRA layer."""
# source_layer is FusedMoE or SharedFusedMoE
return type(source_layer) is FusedMoE and len(packed_modules_list) == 1 return isinstance(source_layer, FusedMoE) and len(packed_modules_list) == 1

View File

@ -197,7 +197,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
source_layer: nn.Module, source_layer: nn.Module,
lora_config: LoRAConfig, lora_config: LoRAConfig,
packed_modules_list: list, packed_modules_list: list,
model_config: PretrainedConfig | None, model_config: PretrainedConfig | None = None,
) -> bool: ) -> bool:
# Special handling for the LogitsProcessor. # Special handling for the LogitsProcessor.
return False return False

View File

@ -53,7 +53,7 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
source_layer: nn.Module, source_layer: nn.Module,
lora_config: LoRAConfig, lora_config: LoRAConfig,
packed_modules_list: list, packed_modules_list: list,
model_config: PretrainedConfig | None, model_config: PretrainedConfig | None = None,
) -> bool: ) -> bool:
return type(source_layer) is ReplicatedLinear return type(source_layer) is ReplicatedLinear

View File

@ -87,7 +87,7 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
source_layer: nn.Module, source_layer: nn.Module,
lora_config: LoRAConfig, lora_config: LoRAConfig,
packed_modules_list: list, packed_modules_list: list,
model_config: PretrainedConfig | None, model_config: PretrainedConfig | None = None,
) -> bool: ) -> bool:
return type(source_layer) is RowParallelLinear return type(source_layer) is RowParallelLinear
@ -164,7 +164,7 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
source_layer: nn.Module, source_layer: nn.Module,
lora_config: LoRAConfig, lora_config: LoRAConfig,
packed_modules_list: list, packed_modules_list: list,
model_config: PretrainedConfig | None, model_config: PretrainedConfig | None = None,
) -> bool: ) -> bool:
# specifying kwargs so they can be easily accessed in decorator # specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer( return super().can_replace_layer(

View File

@ -131,7 +131,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
source_layer: nn.Module, source_layer: nn.Module,
lora_config: LoRAConfig, lora_config: LoRAConfig,
packed_modules_list: list, packed_modules_list: list,
model_config: PretrainedConfig | None, model_config: PretrainedConfig | None = None,
) -> bool: ) -> bool:
return type(source_layer) is VocabParallelEmbedding return type(source_layer) is VocabParallelEmbedding

View File

@ -152,6 +152,59 @@ class PackedLoRALayerWeights(LoRALayerWeights):
) )
return obj return obj
@classmethod
def pack_moe(
cls, loras: GenericSequence[Optional["LoRALayerWeights"]], module_name: str
) -> "PackedLoRALayerWeights":
"""Pack a list of LoRAs into a single LoRA.
If LoRA is None, it signifies that the submodule does not have a LoRA.
"""
first_lora = next(lora for lora in loras if lora is not None)
assert first_lora is not None
rank = first_lora.rank
lora_alpha = first_lora.lora_alpha
assert len(loras) % 3 == 0
w1_lora_a_lst = []
w2_lora_a_lst = []
w3_lora_a_lst = []
w1_lora_b_lst = []
w2_lora_b_lst = []
w3_lora_b_lst = []
# TODO: Consider the case where some experts don't have LoRA added.
for eid in range(len(loras) // 3):
w1_lora = loras[eid * 3]
w2_lora = loras[eid * 3 + 1]
w3_lora = loras[eid * 3 + 2]
assert w1_lora is not None
assert w2_lora is not None
assert w3_lora is not None
w1_lora_a_lst.append(w1_lora.lora_a)
w2_lora_a_lst.append(w2_lora.lora_a)
w3_lora_a_lst.append(w3_lora.lora_a)
w1_lora_b_lst.append(w1_lora.lora_b)
w2_lora_b_lst.append(w2_lora.lora_b)
w3_lora_b_lst.append(w3_lora.lora_b)
w1_lora_a = torch.stack(w1_lora_a_lst, dim=0) # (num_experts,rank,input_size)
w2_lora_a = torch.stack(w2_lora_a_lst, dim=0)
w3_lora_a = torch.stack(w3_lora_a_lst, dim=0)
w1_lora_b = torch.stack(w1_lora_b_lst, dim=0) # (num_experts,output_size,rank)
w2_lora_b = torch.stack(w2_lora_b_lst, dim=0)
w3_lora_b = torch.stack(w3_lora_b_lst, dim=0)
obj = cls(
module_name,
rank,
[lora_alpha, lora_alpha, lora_alpha],
[w1_lora_a, w2_lora_a, w3_lora_a],
[w1_lora_b, w2_lora_b, w3_lora_b],
)
return obj
def optimize(self) -> "PackedLoRALayerWeights": def optimize(self) -> "PackedLoRALayerWeights":
"""Optimize the LoRA by merging the scaling into lora_b.""" """Optimize the LoRA by merging the scaling into lora_b."""
for i in range(len(self.lora_b)): for i in range(len(self.lora_b)):

View File

@ -13,7 +13,7 @@ from torch import nn
from vllm.config.lora import LoRAConfig from vllm.config.lora import LoRAConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.layers import BaseLayerWithLoRA, FusedMoEWithLoRA, LoRAMapping from vllm.lora.layers import BaseLayerWithLoRA, FusedMoE3DWithLoRA, LoRAMapping
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.peft_helper import PEFTHelper from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.punica_wrapper import get_punica_wrapper from vllm.lora.punica_wrapper import get_punica_wrapper
@ -151,16 +151,13 @@ class LoRAModel:
if pin_memory: if pin_memory:
loras[module_name].lora_b = loras[module_name].lora_b.pin_memory() loras[module_name].lora_b = loras[module_name].lora_b.pin_memory()
for lora in loras.values():
lora.optimize()
return cls(lora_model_id, peft_helper.r, loras) return cls(lora_model_id, peft_helper.r, loras)
@classmethod @classmethod
def from_local_checkpoint( def from_local_checkpoint(
cls, cls,
lora_dir: str, lora_dir: str,
expected_lora_modules: list[str], expected_lora_modules: set[str],
peft_helper: PEFTHelper, peft_helper: PEFTHelper,
*, *,
lora_model_id: int | None = None, lora_model_id: int | None = None,
@ -190,10 +187,7 @@ class LoRAModel:
lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors") lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin") lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
lora_pt_file_path = os.path.join(lora_dir, "adapter_model.pt") lora_pt_file_path = os.path.join(lora_dir, "adapter_model.pt")
# new_embeddings_tensor_path = os.path.join(
# lora_dir, "new_embeddings.safetensors"
# )
# new_embeddings_bin_file_path = os.path.join(lora_dir, "new_embeddings.bin")
tensors: dict[str, torch.Tensor] = {} tensors: dict[str, torch.Tensor] = {}
unexpected_modules: list[list[str] | str] = [] unexpected_modules: list[list[str] | str] = []
@ -201,18 +195,19 @@ class LoRAModel:
for lora_module in modules.keys(): # noqa for lora_module in modules.keys(): # noqa
if is_base_embeddding_weights(lora_module): if is_base_embeddding_weights(lora_module):
continue continue
module_name, _ = parse_fine_tuned_lora_name(lora_module, weights_mapper) # Handle PEFT file format where experts.base_layer is the
# Handle FSDP file format where experts.base_layer is the
# gate_up_proj and experts is the down_proj # gate_up_proj and experts is the down_proj
if "base_layer" in lora_module: if "base_layer" in lora_module:
continue continue
module_name, _ = parse_fine_tuned_lora_name(lora_module, weights_mapper)
# Case for expert lora weights # Case for expert lora weights
if ".experts" in module_name: if ".experts" in module_name:
if not any( expert_idx = module_name.find(".experts")
module_name.endswith(ele) for ele in expected_lora_modules expert_suffix = module_name[expert_idx + 1 :]
): if expert_suffix not in expected_lora_modules:
unexpected_modules.append(module_name) unexpected_modules.append(module_name)
elif module_name.split(".")[-1] not in expected_lora_modules:
elif module_name.rsplit(".", 1)[-1] not in expected_lora_modules:
unexpected_modules.append(module_name) unexpected_modules.append(module_name)
if unexpected_modules: if unexpected_modules:
@ -358,9 +353,7 @@ class LoRAModelManager:
self.modules: dict[str, BaseLayerWithLoRA] = {} self.modules: dict[str, BaseLayerWithLoRA] = {}
# Dict instead of a set for compatibility with LRUCache. # Dict instead of a set for compatibility with LRUCache.
self._last_mapping: LoRAMapping | None = None self._last_mapping: LoRAMapping | None = None
self._is_3d_moe_model = is_moe_model(self.model) and hasattr( self._is_3d_moe_model = is_moe_model(self.model) and self.model.is_3d_moe_weight
self.model, "is_3d_moe_weight"
)
self._create_lora_modules() self._create_lora_modules()
self.model.lora_manager = self self.model.lora_manager = self
@ -411,7 +404,7 @@ class LoRAModelManager:
continue continue
# Note (gnovack) - If MOE lora weights are not split into # Note (gnovack) - If MOE lora weights are not split into
# num_experts chunks, we split them here # num_experts chunks, we split them here
if isinstance(module, FusedMoEWithLoRA) and torch.is_tensor( if isinstance(module, FusedMoE3DWithLoRA) and torch.is_tensor(
module_lora.lora_a module_lora.lora_a
): ):
# Handle PEFT file format where experts.base_layer is the # Handle PEFT file format where experts.base_layer is the
@ -679,7 +672,10 @@ class LoRAModelManager:
"cpu", "cpu",
) )
subloras.append(lora) subloras.append(lora)
lora = PackedLoRALayerWeights.pack(subloras) if module.__class__.__name__ == "FusedMoEWithLoRA":
lora = PackedLoRALayerWeights.pack_moe(subloras, module_name)
else:
lora = PackedLoRALayerWeights.pack(subloras)
model.loras[module_name] = lora model.loras[module_name] = lora
return model return model
@ -739,13 +735,21 @@ class LoRAModelManager:
replaced_module_name = module_name.replace("model.", "") replaced_module_name = module_name.replace("model.", "")
if lora_model.check_lora_name(module_name): if lora_model.check_lora_name(module_name):
module_name = replaced_module_name module_name = replaced_module_name
lora_model.loras[module_name] = PackedLoRALayerWeights.pack( if module_name.endswith(".experts"):
replacement_loras lora_model.loras[module_name] = PackedLoRALayerWeights.pack_moe(
) replacement_loras, module_name
)
else:
lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
replacement_loras
)
# Remove the modules that have been replaced. # Remove the modules that have been replaced.
for module in replaced_module: for module in replaced_module:
lora_model.loras.pop(module, None) lora_model.loras.pop(module, None)
for lora in lora_model.loras.values():
lora.optimize()
def _get_lora_layer_weights( def _get_lora_layer_weights(
self, lora_model: LoRAModel, module_name: str self, lora_model: LoRAModel, module_name: str
) -> LoRALayerWeights | None: ) -> LoRALayerWeights | None:

View File

@ -170,16 +170,15 @@ def parse_fine_tuned_lora_name(
def is_base_embeddding_weights(name: str) -> bool: def is_base_embeddding_weights(name: str) -> bool:
# hardcoded subfixes for input & output embedding weights # hardcoded subfixes for input & output embedding weights
input_embedding_subfix = ".embed_tokens.base_layer.weight" embedding_suffixes = (
output_embedding_subfix = ".lm_head.base_layer.weight" ".embed_tokens.base_layer.weight",
".lm_head.base_layer.weight",
return name.endswith(input_embedding_subfix) or name.endswith(
output_embedding_subfix
) )
return name.endswith(embedding_suffixes)
def is_regex_target_modules( def is_regex_target_modules(
load_modules: str | list[str], expected_lora_modules: list[str] load_modules: str | list[str], expected_lora_modules: set[str]
) -> bool: ) -> bool:
""" """
PEFT supports passing `target_modules` in the form of regular expressions, PEFT supports passing `target_modules` in the form of regular expressions,
@ -195,8 +194,8 @@ def is_regex_target_modules(
except re.error: except re.error:
return False return False
def is_subset(sub_list, full_list): def is_subset(sub_list, full_set):
return set(sub_list).issubset(set(full_list)) return set(sub_list).issubset(full_set)
# Similar to PEFT's processing logic, regex-related operations are only # Similar to PEFT's processing logic, regex-related operations are only
# executed when the load_modules is a `str`. # executed when the load_modules is a `str`.
@ -290,7 +289,7 @@ def process_packed_modules_mapping(model: nn.Module) -> dict[str, list[str]]:
# the expert indices are expanded based on the configured number # the expert indices are expanded based on the configured number
# of routed experts. # of routed experts.
packed_modules_mapping = get_packed_modules_mapping(model) packed_modules_mapping = get_packed_modules_mapping(model)
if not hasattr(model, "is_3d_moe_weight"): if not model.is_3d_moe_weight:
# 3D MoE LoRA does not need `packed_modules_mapping` # 3D MoE LoRA does not need `packed_modules_mapping`
packed_modules_mapping["experts"] = [ packed_modules_mapping["experts"] = [
weight_name.rstrip(".") weight_name.rstrip(".")

View File

@ -88,15 +88,15 @@ class WorkerLoRAManager:
try: try:
supported_lora_modules = self._adapter_manager.supported_lora_modules supported_lora_modules = self._adapter_manager.supported_lora_modules
packed_modules_mapping = self._adapter_manager.packed_modules_mapping packed_modules_mapping = self._adapter_manager.packed_modules_mapping
expected_lora_modules: list[str] = [] expected_lora_lst: list[str] = []
for module in supported_lora_modules: for module in supported_lora_modules:
if module in packed_modules_mapping: if module in packed_modules_mapping:
expected_lora_modules.extend(packed_modules_mapping[module]) expected_lora_lst.extend(packed_modules_mapping[module])
else: else:
expected_lora_modules.append(module) expected_lora_lst.append(module)
if module == "experts": if module == "experts":
expected_lora_modules.append(module) expected_lora_lst.append(module)
expected_lora_modules = list(set(expected_lora_modules)) expected_lora_modules = set(expected_lora_lst)
lora_path = get_adapter_absolute_path(lora_request.lora_path) lora_path = get_adapter_absolute_path(lora_request.lora_path)
peft_helper = PEFTHelper.from_local_dir( peft_helper = PEFTHelper.from_local_dir(

View File

@ -336,6 +336,7 @@ class SupportsLoRA(Protocol):
There is no need to redefine this flag if this class is in the There is no need to redefine this flag if this class is in the
MRO of your model class. MRO of your model class.
""" """
is_3d_moe_weight: ClassVar[bool] = False
# The `embedding_module` and `embedding_padding_modules` # The `embedding_module` and `embedding_padding_modules`
# are empty by default. # are empty by default.
embedding_modules: ClassVar[dict[str, str]] = {} embedding_modules: ClassVar[dict[str, str]] = {}

View File

@ -401,6 +401,7 @@ class Qwen3VLMoeMixtureOfExperts(MixtureOfExperts):
class Qwen3VLMoeForConditionalGeneration( class Qwen3VLMoeForConditionalGeneration(
Qwen3VLForConditionalGeneration, Qwen3VLMoeMixtureOfExperts Qwen3VLForConditionalGeneration, Qwen3VLMoeMixtureOfExperts
): ):
is_3d_moe_weight: bool = True
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",